diff --git a/examples/mcp_example/mock_math_server.py b/examples/mcp_example/mock_math_server.py index d46c144a..5386d70f 100644 --- a/examples/mcp_example/mock_math_server.py +++ b/examples/mcp_example/mock_math_server.py @@ -1,26 +1,49 @@ - from fastmcp import FastMCP +from typing import Dict, Any +import asyncio +from loguru import logger -mcp = FastMCP("Math-Mock-Server") +# Create FastMCP instance with SSE transport +mcp = FastMCP( + host="0.0.0.0", + port=8000, + require_session_id=False, + transport="sse" # Explicitly specify SSE transport +) @mcp.tool() def add(a: int, b: int) -> int: - """Add two numbers together""" + """Add two numbers.""" return a + b @mcp.tool() def multiply(a: int, b: int) -> int: - """Multiply two numbers together""" + """Multiply two numbers.""" return a * b @mcp.tool() def divide(a: int, b: int) -> float: - """Divide two numbers""" + """Divide two numbers.""" if b == 0: - return {"error": "Cannot divide by zero"} + raise ValueError("Cannot divide by zero") return a / b +async def run_server(): + """Run the server with proper error handling.""" + try: + logger.info("Starting math server on http://0.0.0.0:8000") + await mcp.run_async() + except Exception as e: + logger.error(f"Server error: {e}") + raise + finally: + await mcp.cleanup() + if __name__ == "__main__": - print("Starting Mock Math Server on port 8000...") - # FastMCP expects transport_kwargs as separate parameters - mcp.run(transport="sse", host="0.0.0.0", port=8000) + try: + asyncio.run(run_server()) + except KeyboardInterrupt: + logger.info("Server stopped by user") + except Exception as e: + logger.error(f"Fatal error: {e}") + raise diff --git a/swarms/tools/mcp_integration.py b/swarms/tools/mcp_integration.py index 2af01f86..b046b3b5 100644 --- a/swarms/tools/mcp_integration.py +++ b/swarms/tools/mcp_integration.py @@ -284,12 +284,16 @@ class MCPServerSse: def __init__(self, params: MCPServerSseParams): self.params = params self.client: Optional[ClientSession] = None + self._connection_lock = asyncio.Lock() + self.messages = [] # Store messages instead of using conversation + self.preserve_format = True # Flag to preserve original formatting async def connect(self): - """Connect to the MCP server.""" - if not self.client: - self.client = ClientSession() - await self.client.connect(self.create_streams()) + """Connect to the MCP server with proper locking.""" + async with self._connection_lock: + if not self.client: + self.client = ClientSession() + await self.client.connect(self.create_streams()) def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]: return sse_client( @@ -299,45 +303,159 @@ class MCPServerSse: sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5), ) - async def call_tool(self, payload: dict[str, Any]): - """Call a tool on the MCP server.""" + def _parse_input(self, payload: Any) -> dict: + """Parse input while preserving original format.""" + if isinstance(payload, dict): + return payload + + if isinstance(payload, str): + try: + # Try to parse as JSON + import json + return json.loads(payload) + except json.JSONDecodeError: + # Check if it's a math operation + import re + + # Pattern matching for basic math operations + add_pattern = r"(?i)(?:what\s+is\s+)?(\d+)\s*(?:plus|\+)\s*(\d+)" + mult_pattern = r"(?i)(?:multiply|times|\*)\s*(\d+)\s*(?:and|by)?\s*(\d+)" + div_pattern = r"(?i)(?:divide)\s*(\d+)\s*(?:by)\s*(\d+)" + + # Check for addition + if match := re.search(add_pattern, payload): + a, b = map(int, match.groups()) + return {"tool_name": "add", "a": a, "b": b} + + # Check for multiplication + if match := re.search(mult_pattern, payload): + a, b = map(int, match.groups()) + return {"tool_name": "multiply", "a": a, "b": b} + + # Check for division + if match := re.search(div_pattern, payload): + a, b = map(int, match.groups()) + return {"tool_name": "divide", "a": a, "b": b} + + # Default to text input if no pattern matches + return {"text": payload} + + return {"text": str(payload)} + + def _format_output(self, result: Any, original_input: Any) -> str: + """Format output based on input type and result.""" + if not self.preserve_format: + return str(result) + + try: + if isinstance(result, (int, float)): + # For numeric results, format based on operation + if isinstance(original_input, dict): + tool_name = original_input.get("tool_name", "") + if tool_name == "add": + return f"{original_input['a']} + {original_input['b']} = {result}" + elif tool_name == "multiply": + return f"{original_input['a']} * {original_input['b']} = {result}" + elif tool_name == "divide": + return f"{original_input['a']} / {original_input['b']} = {result}" + return str(result) + elif isinstance(result, dict): + return json.dumps(result, indent=2) + else: + return str(result) + except Exception as e: + logger.error(f"Error formatting output: {e}") + return str(result) + + async def call_tool(self, payload: Any) -> Any: + """Call a tool on the MCP server with support for various input formats.""" if not self.client: raise RuntimeError("Not connected to MCP server") - return await self.client.call_tool(payload) + + # Store original input for formatting + original_input = payload + + # Parse input + parsed_payload = self._parse_input(payload) + + # Add message to history + self.messages.append({ + "role": "user", + "content": str(payload), + "parsed": parsed_payload + }) + + try: + result = await self.client.call_tool(parsed_payload) + formatted_result = self._format_output(result, original_input) + + self.messages.append({ + "role": "assistant", + "content": formatted_result, + "raw_result": result + }) + + return formatted_result + except Exception as e: + error_msg = f"Error calling tool: {str(e)}" + self.messages.append({ + "role": "error", + "content": error_msg, + "original_input": payload + }) + raise async def cleanup(self): - """Clean up the connection.""" - if self.client: - await self.client.close() - self.client = None - - async def list_tools(self) -> list[Any]: # Added for compatibility + """Clean up the connection with proper locking.""" + async with self._connection_lock: + if self.client: + await self.client.close() + self.client = None + + async def list_tools(self) -> list[Any]: + """List available tools with proper error handling.""" if not self.client: raise RuntimeError("Not connected to MCP server") - return await self.client.list_tools() + try: + return await self.client.list_tools() + except Exception as e: + logger.error(f"Error listing tools: {e}") + return [] -async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any]): +async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any] | str): """ - Convenience wrapper that opens → calls → closes in one shot. + Convenience wrapper that opens → calls → closes in one shot with proper error handling. """ - await server.connect() - result = await server.call_tool(payload) - await server.cleanup() - return result.model_dump() if hasattr(result, "model_dump") else result + try: + await server.connect() + result = await server.call_tool(payload) + return result.model_dump() if hasattr(result, "model_dump") else result + except Exception as e: + logger.error(f"Error in call_tool_fast: {e}") + raise + finally: + await server.cleanup() async def mcp_flow_get_tool_schema( params: MCPServerSseParams, -) -> Any: # Updated return type to Any - async with MCPServerSse(params) as server: - return (await server.list_tools()).model_dump() +) -> Any: + """Get tool schema with proper error handling.""" + try: + async with MCPServerSse(params) as server: + tools = await server.list_tools() + return tools.model_dump() if hasattr(tools, "model_dump") else tools + except Exception as e: + logger.error(f"Error getting tool schema: {e}") + raise async def mcp_flow( params: MCPServerSseParams, - function_call: dict[str, Any], -) -> Any: # Updated return type to Any + function_call: dict[str, Any] | str, +) -> Any: + """Execute MCP flow with proper error handling.""" try: async with MCPServerSse(params) as server: return await call_tool_fast(server, function_call) @@ -346,28 +464,40 @@ async def mcp_flow( raise -# Helper function to call one MCP server -async def _call_one_server(param: MCPServerSseParams, payload: dict[str, Any]) -> Any: +async def _call_one_server(param: MCPServerSseParams, payload: dict[str, Any] | str) -> Any: """Make a call to a single MCP server with proper async context management.""" - async with MCPServerSse(param, cache_tools_list=True) as srv: - res = await srv.call_tool(payload) - try: - return res.model_dump() # For fast-mcp ≥0.2 - except AttributeError: - return res # Plain dict or string + try: + server = MCPServerSse(param) + await server.connect() + result = await server.call_tool(payload) + return result + except Exception as e: + logger.error(f"Error calling server: {e}") + raise + finally: + if 'server' in locals(): + await server.cleanup() + -# Synchronous wrapper for the Agent to use -def batch_mcp_flow(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]: +def batch_mcp_flow(params: List[MCPServerSseParams], payload: dict[str, Any] | str) -> List[Any]: """Blocking helper that fans out to all MCP servers in params.""" - return asyncio.run(_batch(params, payload)) + try: + return asyncio.run(_batch(params, payload)) + except Exception as e: + logger.error(f"Error in batch_mcp_flow: {e}") + return [] -# Async implementation of batch processing -async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]: + +async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any] | str) -> List[Any]: """Fan out to all MCP servers asynchronously and gather results.""" - coros = [_call_one_server(p, payload) for p in params] - results = await asyncio.gather(*coros, return_exceptions=True) - # Filter out exceptions and convert to strings - return [any_to_str(r) for r in results if not isinstance(r, Exception)] + try: + coros = [_call_one_server(p, payload) for p in params] + results = await asyncio.gather(*coros, return_exceptions=True) + # Filter out exceptions and convert to strings + return [any_to_str(r) for r in results if not isinstance(r, Exception)] + except Exception as e: + logger.error(f"Error in batch processing: {e}") + return [] from mcp import (