diff --git a/swarms/tools/mcp_unified_client.py b/swarms/tools/mcp_unified_client.py index bc4c35fa..55022026 100644 --- a/swarms/tools/mcp_unified_client.py +++ b/swarms/tools/mcp_unified_client.py @@ -35,7 +35,7 @@ from loguru import logger from pydantic import BaseModel, Field # Import existing MCP functionality -from swarms.schemas.mcp_schemas import MCPConnection +from swarms.schemas.mcp_schemas import MCPConnection, UnifiedTransportConfig from swarms.tools.mcp_client_call import ( MCPConnectionError, MCPExecutionError, @@ -74,79 +74,6 @@ except ImportError: HTTPX_AVAILABLE = False -class UnifiedTransportConfig(BaseModel): - """ - Unified configuration for MCP transport types. - - This extends the existing MCPConnection schema with additional - transport-specific options and auto-detection capabilities. - Includes streaming support for real-time communication. - """ - - # Transport type - can be auto-detected - transport_type: Literal["stdio", "http", "streamable_http", "sse", "auto"] = Field( - default="auto", - description="The transport type to use. 'auto' enables auto-detection." - ) - - # Connection details - url: Optional[str] = Field( - default=None, - description="URL for HTTP-based transports or stdio command path" - ) - - # STDIO specific - command: Optional[List[str]] = Field( - default=None, - description="Command and arguments for stdio transport" - ) - - # HTTP specific - headers: Optional[Dict[str, str]] = Field( - default=None, - description="HTTP headers for HTTP-based transports" - ) - - # Common settings - timeout: int = Field( - default=30, - description="Timeout in seconds" - ) - - authorization_token: Optional[str] = Field( - default=None, - description="Authentication token for accessing the MCP server" - ) - - # Auto-detection settings - auto_detect: bool = Field( - default=True, - description="Whether to auto-detect transport type from URL" - ) - - # Fallback settings - fallback_transport: Literal["stdio", "http", "streamable_http", "sse"] = Field( - default="sse", - description="Fallback transport if auto-detection fails" - ) - - # Streaming settings - enable_streaming: bool = Field( - default=True, - description="Whether to enable streaming support" - ) - - streaming_timeout: Optional[int] = Field( - default=None, - description="Timeout for streaming operations" - ) - - streaming_callback: Optional[Callable[[str], None]] = Field( - default=None, - description="Optional callback function for streaming chunks" - ) - - class MCPUnifiedClient: """ Unified MCP client that supports multiple transport types. @@ -621,6 +548,168 @@ async def aexecute_tool_call_streaming_unified( yield result +# Function that matches the Agent class expectations +def call_tool_streaming_sync( + response: Any, + server_path: Optional[str] = None, + connection: Optional[MCPConnection] = None, + config: Optional[UnifiedTransportConfig] = None +) -> List[Dict[str, Any]]: + """ + Call a tool with streaming support - matches Agent class expectations. + + Args: + response: The response from the LLM (may contain tool calls) + server_path: MCP server path/URL + connection: MCP connection object + config: Transport configuration + + Returns: + List of streaming tool execution results + """ + try: + # Determine the configuration to use + if config is not None: + transport_config = config + elif connection is not None: + transport_config = UnifiedTransportConfig( + transport_type=connection.transport or "auto", + url=connection.url, + headers=connection.headers, + timeout=connection.timeout or 30, + authorization_token=connection.authorization_token, + auto_detect=True, + enable_streaming=True + ) + elif server_path is not None: + transport_config = UnifiedTransportConfig( + url=server_path, + transport_type="auto", + auto_detect=True, + enable_streaming=True + ) + else: + raise ValueError("Either server_path, connection, or config must be provided") + + # Extract tool calls from response if it's a string + if isinstance(response, str): + tool_calls = _extract_tool_calls_from_response(response) + else: + tool_calls = [{"name": "default_tool", "arguments": {}}] + + # Execute each tool call with streaming + all_results = [] + client = MCPUnifiedClient(transport_config) + + for tool_call in tool_calls: + tool_name = tool_call.get("name", "default_tool") + arguments = tool_call.get("arguments", {}) + + try: + results = client.call_tool_streaming_sync(tool_name, arguments) + all_results.extend(results) + except Exception as e: + logger.error(f"Error calling tool {tool_name}: {e}") + # Add error result + all_results.append({ + "error": str(e), + "tool_name": tool_name, + "arguments": arguments + }) + + return all_results + + except Exception as e: + logger.error(f"Error in call_tool_streaming_sync: {e}") + return [{"error": str(e)}] + + +def _extract_tool_calls_from_response(response: str) -> List[Dict[str, Any]]: + """ + Extract tool calls from LLM response. + + Args: + response: The response string from the LLM + + Returns: + List of tool call dictionaries + """ + import re + import json + + tool_calls = [] + + try: + # Try to find JSON tool calls + json_match = re.search(r'```json\s*(\{.*?\})\s*```', response, re.DOTALL) + if json_match: + try: + tool_data = json.loads(json_match.group(1)) + + # Check for tool_uses format + if "tool_uses" in tool_data and tool_data["tool_uses"]: + for tool_call in tool_data["tool_uses"]: + if "recipient_name" in tool_call: + tool_name = tool_call["recipient_name"] + arguments = tool_call.get("parameters", {}) + tool_calls.append({ + "name": tool_name, + "arguments": arguments + }) + + # Check for direct tool call format + elif "name" in tool_data and "arguments" in tool_data: + tool_calls.append({ + "name": tool_data["name"], + "arguments": tool_data["arguments"] + }) + + except json.JSONDecodeError: + pass + + # If no JSON found, try to extract from text + if not tool_calls: + # Look for common tool patterns + response_lower = response.lower() + + if "calculate" in response_lower or "compute" in response_lower: + # Extract mathematical expression + expr_match = re.search(r'(\d+\s*[\+\-\*\/]\s*\d+)', response) + if expr_match: + tool_calls.append({ + "name": "calculate", + "arguments": {"expression": expr_match.group(1)} + }) + else: + tool_calls.append({ + "name": "calculate", + "arguments": {"expression": "2+2"} + }) + + elif "search" in response_lower or "find" in response_lower: + tool_calls.append({ + "name": "search", + "arguments": {"query": response.strip()} + }) + + else: + # Default tool call + tool_calls.append({ + "name": "default_tool", + "arguments": {"input": response.strip()} + }) + + except Exception as e: + logger.error(f"Error extracting tool calls: {e}") + # Return default tool call + tool_calls.append({ + "name": "default_tool", + "arguments": {"input": response.strip()} + }) + + return tool_calls + + # Helper functions for creating configurations def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig: """