Update mcp_client_call.py

pull/1005/head
CI-DEV 2 months ago committed by GitHub
parent 6497e5d827
commit 6c7c8b9405
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -12,12 +12,21 @@ from loguru import logger
from mcp import ClientSession from mcp import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
try:
from mcp.client.stdio import stdio_client
except ImportError:
logger.error(
"stdio_client is not available. Please ensure the MCP SDK is up to date with pip3 install -U mcp"
)
stdio_client = None
try: try:
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
except ImportError: except ImportError:
logger.error( logger.error(
"streamablehttp_client is not available. Please ensure the MCP SDK is up to date with pip3 install -U mcp" "streamablehttp_client is not available. Please ensure the MCP SDK is up to date with pip3 install -U mcp"
) )
streamablehttp_client = None
from urllib.parse import urlparse from urllib.parse import urlparse
@ -313,6 +322,30 @@ def get_mcp_client(transport, url, headers=None, timeout=5, **kwargs):
return streamablehttp_client( return streamablehttp_client(
url, headers=headers, timeout=timeout, **kwargs url, headers=headers, timeout=timeout, **kwargs
) )
elif transport == "stdio":
if stdio_client is None:
logger.error("stdio_client is not available.")
raise ImportError(
"stdio_client is not available. Please ensure the MCP SDK is up to date."
)
# For stdio, extract the command from the URL
# URL format: stdio://simple_mcp_server.py -> command: ["python", "simple_mcp_server.py"]
if url.startswith("stdio://"):
script_path = url[8:] # Remove "stdio://" prefix
command = "python"
args = [script_path]
else:
command = url
args = []
# Create StdioServerParameters
from mcp.client.stdio import StdioServerParameters
server_params = StdioServerParameters(
command=command,
args=args
)
logger.info(f"Using stdio server parameters: {server_params}")
return stdio_client(server_params)
else: else:
return sse_client( return sse_client(
url, headers=headers, timeout=timeout, **kwargs url, headers=headers, timeout=timeout, **kwargs
@ -419,6 +452,9 @@ async def aget_mcp_tools(
return tools return tools
except Exception as e: except Exception as e:
logger.error(f"Error fetching MCP tools: {str(e)}") logger.error(f"Error fetching MCP tools: {str(e)}")
logger.error(f"Exception type: {type(e).__name__}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
raise MCPConnectionError( raise MCPConnectionError(
f"Failed to connect to MCP server: {str(e)}" f"Failed to connect to MCP server: {str(e)}"
) )
@ -623,28 +659,56 @@ async def _execute_tool_call_simple(
call_result = await call_openai_tool( call_result = await call_openai_tool(
session=session, openai_tool=response session=session, openai_tool=response
) )
if output_type == "json":
out = call_result.model_dump_json(indent=4) # Handle different output types with better error handling
elif output_type == "dict": try:
out = call_result.model_dump() if output_type == "json":
elif output_type == "str": out = call_result.model_dump_json(indent=4)
data = call_result.model_dump() elif output_type == "dict":
formatted_lines = [] out = call_result.model_dump()
for key, value in data.items(): elif output_type == "str":
if isinstance(value, list): # Try to get the content from the MCP response
for item in value: try:
if isinstance(item, dict): data = call_result.model_dump()
for k, v in item.items(): formatted_lines = []
formatted_lines.append( for key, value in data.items():
f"{k}: {v}" if isinstance(value, list):
) for item in value:
else: if isinstance(item, dict):
formatted_lines.append( for k, v in item.items():
f"{key}: {value}" formatted_lines.append(
) f"{k}: {v}"
out = "\n".join(formatted_lines) )
else: else:
out = call_result.model_dump() formatted_lines.append(
f"{key}: {value}"
)
out = "\n".join(formatted_lines)
except Exception as format_error:
logger.warning(f"Error formatting MCP response: {format_error}")
# Fallback: try to get text content directly
try:
if hasattr(call_result, 'content') and call_result.content:
if isinstance(call_result.content, list) and len(call_result.content) > 0:
first_content = call_result.content[0]
if hasattr(first_content, 'text'):
out = first_content.text
else:
out = str(first_content)
else:
out = str(call_result.content)
else:
out = str(call_result)
except Exception as fallback_error:
logger.warning(f"Fallback formatting also failed: {fallback_error}")
out = str(call_result)
else:
out = call_result.model_dump()
except Exception as format_error:
logger.warning(f"Error in output formatting: {format_error}")
# Final fallback
out = str(call_result)
logger.info( logger.info(
f"Tool call executed successfully for {server_path}" f"Tool call executed successfully for {server_path}"
) )
@ -684,10 +748,27 @@ async def execute_tool_call_simple(
logger.info( logger.info(
f"execute_tool_call_simple called for server_path: {server_path}" f"execute_tool_call_simple called for server_path: {server_path}"
) )
# Validate response before processing
if response is None or response == "":
logger.warning("Empty or None response received, returning empty result")
return []
if transport is None: if transport is None:
transport = auto_detect_transport(server_path) transport = auto_detect_transport(server_path)
# Handle string responses with proper validation
if isinstance(response, str): if isinstance(response, str):
response = json.loads(response) if not response.strip():
logger.warning("Empty string response received, returning empty result")
return []
try:
response = json.loads(response)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON response: {e}")
logger.error(f"Response content: {repr(response)}")
return []
return await _execute_tool_call_simple( return await _execute_tool_call_simple(
response=response, response=response,
server_path=server_path, server_path=server_path,

Loading…
Cancel
Save