From 925709de6ee0344ef9487a2747211f64c928e654 Mon Sep 17 00:00:00 2001 From: Pavan Kumar <66913595+ascender1729@users.noreply.github.com> Date: Sun, 20 Apr 2025 16:00:45 +0000 Subject: [PATCH] fix(mcp): resolve client initialization and update server configuration in mcp integration --- examples/mcp_example/mock_math_server.py | 5 ++-- swarms/tools/mcp_client.py | 1 + swarms/tools/mcp_integration.py | 38 ++++++++++++------------ 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/examples/mcp_example/mock_math_server.py b/examples/mcp_example/mock_math_server.py index 5386d70f..298c3a89 100644 --- a/examples/mcp_example/mock_math_server.py +++ b/examples/mcp_example/mock_math_server.py @@ -5,10 +5,11 @@ from loguru import logger # Create FastMCP instance with SSE transport mcp = FastMCP( - host="0.0.0.0", + host="0.0.0.0", port=8000, + transport="sse", require_session_id=False, - transport="sse" # Explicitly specify SSE transport + timeout=30.0 ) @mcp.tool() diff --git a/swarms/tools/mcp_client.py b/swarms/tools/mcp_client.py index 9a9d2b37..5d25b33d 100644 --- a/swarms/tools/mcp_client.py +++ b/swarms/tools/mcp_client.py @@ -43,6 +43,7 @@ async def _execute_mcp_tool( method: Literal["stdio", "sse"] = "sse", parameters: Dict[Any, Any] = None, output_type: Literal["str", "dict"] = "str", + timeout: float = 30.0, *args, **kwargs, ) -> Dict[Any, Any]: diff --git a/swarms/tools/mcp_integration.py b/swarms/tools/mcp_integration.py index 320385ff..cbc7e005 100644 --- a/swarms/tools/mcp_integration.py +++ b/swarms/tools/mcp_integration.py @@ -297,8 +297,10 @@ class MCPServerSse: """Connect to the MCP server with proper locking.""" async with self._connection_lock: if not self.client: - self.client = ClientSession() - await self.client.connect(self.create_streams()) + transport = await self.create_streams() + read_stream, write_stream = transport + self.client = ClientSession(read_stream=read_stream, write_stream=write_stream) + await self.client.initialize() def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]: return sse_client( @@ -312,7 +314,7 @@ class MCPServerSse: """Parse input while preserving original format.""" if isinstance(payload, dict): return payload - + if isinstance(payload, str): try: # Try to parse as JSON @@ -321,37 +323,37 @@ class MCPServerSse: except json.JSONDecodeError: # Check if it's a math operation import re - + # Pattern matching for basic math operations add_pattern = r"(?i)(?:what\s+is\s+)?(\d+)\s*(?:plus|\+)\s*(\d+)" mult_pattern = r"(?i)(?:multiply|times|\*)\s*(\d+)\s*(?:and|by)?\s*(\d+)" div_pattern = r"(?i)(?:divide)\s*(\d+)\s*(?:by)\s*(\d+)" - + # Check for addition if match := re.search(add_pattern, payload): a, b = map(int, match.groups()) return {"tool_name": "add", "a": a, "b": b} - + # Check for multiplication if match := re.search(mult_pattern, payload): a, b = map(int, match.groups()) return {"tool_name": "multiply", "a": a, "b": b} - + # Check for division if match := re.search(div_pattern, payload): a, b = map(int, match.groups()) return {"tool_name": "divide", "a": a, "b": b} - + # Default to text input if no pattern matches return {"text": payload} - + return {"text": str(payload)} def _format_output(self, result: Any, original_input: Any) -> str: """Format output based on input type and result.""" if not self.preserve_format: return str(result) - + try: if isinstance(result, (int, float)): # For numeric results, format based on operation @@ -376,30 +378,30 @@ class MCPServerSse: """Call a tool on the MCP server with support for various input formats.""" if not self.client: raise RuntimeError("Not connected to MCP server") - + # Store original input for formatting original_input = payload - + # Parse input parsed_payload = self._parse_input(payload) - + # Add message to history self.messages.append({ "role": "user", "content": str(payload), "parsed": parsed_payload }) - + try: result = await self.client.call_tool(parsed_payload) formatted_result = self._format_output(result, original_input) - + self.messages.append({ "role": "assistant", "content": formatted_result, "raw_result": result }) - + return formatted_result except Exception as e: error_msg = f"Error calling tool: {str(e)}" @@ -502,6 +504,4 @@ async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any] | str return [any_to_str(r) for r in results if not isinstance(r, Exception)] except Exception as e: logger.error(f"Error in batch processing: {e}") - return [] - - + return [] \ No newline at end of file