From faf6f2226d04654c83d466501ca8d18735858f8e Mon Sep 17 00:00:00 2001 From: CI-DEV <154627941+IlumCI@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:48:31 +0300 Subject: [PATCH] Update mcp_unified_client.py --- swarms/tools/mcp_unified_client.py | 430 +++++++++++++++++++++++++++++ 1 file changed, 430 insertions(+) diff --git a/swarms/tools/mcp_unified_client.py b/swarms/tools/mcp_unified_client.py index d80483eb..c6233fac 100644 --- a/swarms/tools/mcp_unified_client.py +++ b/swarms/tools/mcp_unified_client.py @@ -567,6 +567,432 @@ def execute_tool_call_streaming_unified( return call_tool_streaming_sync(config, tool_name, arguments) +# Advanced functionality for Agent class integration +def call_tool_streaming_sync_advanced( + response: Any, + server_path: Optional[str] = None, + connection: Optional[MCPConnection] = None, + config: Optional[UnifiedTransportConfig] = None +) -> List[Dict[str, Any]]: + """ + Advanced function that matches the Agent class expectations. + Handles complex response parsing and multiple tool execution. + + Args: + response: The response from the LLM (may contain tool calls) + server_path: MCP server path/URL + connection: MCP connection object + config: Transport configuration + + Returns: + List of streaming tool execution results + """ + try: + # Determine the configuration to use + if config is not None: + transport_config = config + elif connection is not None: + transport_config = UnifiedTransportConfig( + transport_type=connection.transport or "auto", + url=connection.url, + headers=connection.headers, + timeout=connection.timeout or 30, + authorization_token=connection.authorization_token, + auto_detect=True, + enable_streaming=True + ) + elif server_path is not None: + transport_config = UnifiedTransportConfig( + url=server_path, + transport_type="auto", + auto_detect=True, + enable_streaming=True + ) + else: + raise ValueError("Either server_path, connection, or config must be provided") + + # Extract tool calls from response if it's a string + if isinstance(response, str): + tool_calls = _extract_tool_calls_from_response_advanced(response) + else: + tool_calls = [{"name": "default_tool", "arguments": {}}] + + # Execute each tool call with streaming + all_results = [] + client = MCPUnifiedClient(transport_config) + + for tool_call in tool_calls: + tool_name = tool_call.get("name", "default_tool") + arguments = tool_call.get("arguments", {}) + + try: + results = client.call_tool_streaming_sync(tool_name, arguments) + all_results.extend(results) + except Exception as e: + logger.error(f"Error calling tool {tool_name}: {e}") + # Add error result + all_results.append({ + "error": str(e), + "tool_name": tool_name, + "arguments": arguments + }) + + return all_results + + except Exception as e: + logger.error(f"Error in call_tool_streaming_sync_advanced: {e}") + return [{"error": str(e)}] + + +def _extract_tool_calls_from_response_advanced(response: str) -> List[Dict[str, Any]]: + """ + Advanced tool call extraction with comprehensive parsing capabilities. + + Args: + response: The response string from the LLM + + Returns: + List of tool call dictionaries + """ + import re + import json + + tool_calls = [] + + try: + # Try to find JSON tool calls in code blocks + json_match = re.search(r'```json\s*(\{.*?\})\s*```', response, re.DOTALL) + if json_match: + try: + tool_data = json.loads(json_match.group(1)) + + # Check for tool_uses format (OpenAI format) + if "tool_uses" in tool_data and tool_data["tool_uses"]: + for tool_call in tool_data["tool_uses"]: + if "recipient_name" in tool_call: + tool_name = tool_call["recipient_name"] + arguments = tool_call.get("parameters", {}) + tool_calls.append({ + "name": tool_name, + "arguments": arguments + }) + + # Check for direct tool call format + elif "name" in tool_data and "arguments" in tool_data: + tool_calls.append({ + "name": tool_data["name"], + "arguments": tool_data["arguments"] + }) + + # Check for function_calls format + elif "function_calls" in tool_data and tool_data["function_calls"]: + for tool_call in tool_data["function_calls"]: + if "name" in tool_call and "arguments" in tool_call: + tool_calls.append({ + "name": tool_call["name"], + "arguments": tool_call["arguments"] + }) + + except json.JSONDecodeError: + pass + + # Try to find JSON tool calls without code blocks + if not tool_calls: + json_patterns = [ + r'\{[^{}]*"name"[^{}]*"arguments"[^{}]*\}', + r'\{[^{}]*"tool_uses"[^{}]*\}', + r'\{[^{}]*"function_calls"[^{}]*\}' + ] + + for pattern in json_patterns: + matches = re.findall(pattern, response, re.DOTALL) + for match in matches: + try: + tool_data = json.loads(match) + + # Check for tool_uses format + if "tool_uses" in tool_data and tool_data["tool_uses"]: + for tool_call in tool_data["tool_uses"]: + if "recipient_name" in tool_call: + tool_calls.append({ + "name": tool_call["recipient_name"], + "arguments": tool_call.get("parameters", {}) + }) + + # Check for direct tool call format + elif "name" in tool_data and "arguments" in tool_data: + tool_calls.append({ + "name": tool_data["name"], + "arguments": tool_data["arguments"] + }) + + # Check for function_calls format + elif "function_calls" in tool_data and tool_data["function_calls"]: + for tool_call in tool_data["function_calls"]: + if "name" in tool_call and "arguments" in tool_call: + tool_calls.append({ + "name": tool_call["name"], + "arguments": tool_call["arguments"] + }) + + except json.JSONDecodeError: + continue + + # If no JSON found, try to extract from text using pattern matching + if not tool_calls: + response_lower = response.lower() + + # Look for mathematical expressions + if "calculate" in response_lower or "compute" in response_lower or "math" in response_lower: + # Extract mathematical expression + expr_patterns = [ + r'(\d+\s*[\+\-\*\/\^]\s*\d+)', + r'calculate\s+(.+?)(?:\n|\.|$)', + r'compute\s+(.+?)(?:\n|\.|$)' + ] + + for pattern in expr_patterns: + expr_match = re.search(pattern, response, re.IGNORECASE) + if expr_match: + expression = expr_match.group(1).strip() + tool_calls.append({ + "name": "calculate", + "arguments": {"expression": expression} + }) + break + + # Default calculation if no expression found + if not any("calculate" in tc.get("name", "") for tc in tool_calls): + tool_calls.append({ + "name": "calculate", + "arguments": {"expression": "2+2"} + }) + + # Look for search operations + elif "search" in response_lower or "find" in response_lower or "look up" in response_lower: + # Extract search query + search_patterns = [ + r'search\s+for\s+(.+?)(?:\n|\.|$)', + r'find\s+(.+?)(?:\n|\.|$)', + r'look up\s+(.+?)(?:\n|\.|$)' + ] + + for pattern in search_patterns: + search_match = re.search(pattern, response, re.IGNORECASE) + if search_match: + query = search_match.group(1).strip() + tool_calls.append({ + "name": "search", + "arguments": {"query": query} + }) + break + + # Default search if no query found + if not any("search" in tc.get("name", "") for tc in tool_calls): + tool_calls.append({ + "name": "search", + "arguments": {"query": response.strip()} + }) + + # Look for file operations + elif "read" in response_lower or "file" in response_lower or "open" in response_lower: + # Extract file path + file_patterns = [ + r'read\s+(.+?)(?:\n|\.|$)', + r'open\s+(.+?)(?:\n|\.|$)', + r'file\s+(.+?)(?:\n|\.|$)' + ] + + for pattern in file_patterns: + file_match = re.search(pattern, response, re.IGNORECASE) + if file_match: + file_path = file_match.group(1).strip() + tool_calls.append({ + "name": "read_file", + "arguments": {"file_path": file_path} + }) + break + + # Look for web operations + elif "web" in response_lower or "url" in response_lower or "http" in response_lower: + # Extract URL + url_patterns = [ + r'https?://[^\s]+', + r'www\.[^\s]+', + r'url\s+(.+?)(?:\n|\.|$)' + ] + + for pattern in url_patterns: + url_match = re.search(pattern, response, re.IGNORECASE) + if url_match: + url = url_match.group(0) if pattern.startswith('http') else url_match.group(1).strip() + tool_calls.append({ + "name": "fetch_url", + "arguments": {"url": url} + }) + break + + # Default tool call if no specific patterns found + else: + tool_calls.append({ + "name": "default_tool", + "arguments": {"input": response.strip()} + }) + + except Exception as e: + logger.error(f"Error extracting tool calls: {e}") + # Return default tool call + tool_calls.append({ + "name": "default_tool", + "arguments": {"input": response.strip()} + }) + + return tool_calls + + +# Advanced multiple server functionality +async def execute_tools_on_multiple_servers_unified( + server_configs: List[Union[UnifiedTransportConfig, MCPConnection, str]], + tool_name: str, + arguments: Dict[str, Any], + max_concurrent: int = 3 +) -> List[Dict[str, Any]]: + """ + Execute the same tool on multiple MCP servers concurrently. + + Args: + server_configs: List of server configurations + tool_name: Name of the tool to call + arguments: Tool arguments + max_concurrent: Maximum concurrent executions + + Returns: + List of results from all servers + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def execute_on_single_server(config): + async with semaphore: + try: + client = MCPUnifiedClient(config) + result = await client.call_tool(tool_name, arguments) + return { + "success": True, + "server": str(config), + "result": result + } + except Exception as e: + logger.error(f"Error executing tool on server {config}: {e}") + return { + "success": False, + "server": str(config), + "error": str(e) + } + + tasks = [execute_on_single_server(config) for config in server_configs] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Handle exceptions + final_results = [] + for result in results: + if isinstance(result, Exception): + final_results.append({ + "success": False, + "error": str(result) + }) + else: + final_results.append(result) + + return final_results + + +def execute_tools_on_multiple_servers_unified_sync( + server_configs: List[Union[UnifiedTransportConfig, MCPConnection, str]], + tool_name: str, + arguments: Dict[str, Any], + max_concurrent: int = 3 +) -> List[Dict[str, Any]]: + """ + Synchronous wrapper for execute_tools_on_multiple_servers_unified. + + Args: + server_configs: List of server configurations + tool_name: Name of the tool to call + arguments: Tool arguments + max_concurrent: Maximum concurrent executions + + Returns: + List of results from all servers + """ + with get_or_create_event_loop() as loop: + try: + return loop.run_until_complete( + execute_tools_on_multiple_servers_unified( + server_configs=server_configs, + tool_name=tool_name, + arguments=arguments, + max_concurrent=max_concurrent + ) + ) + except Exception as e: + logger.error(f"Error in execute_tools_on_multiple_servers_unified_sync: {e}") + return [{"success": False, "error": str(e)}] + + +# Advanced streaming with multiple servers +async def execute_tools_streaming_on_multiple_servers_unified( + server_configs: List[Union[UnifiedTransportConfig, MCPConnection, str]], + tool_name: str, + arguments: Dict[str, Any], + max_concurrent: int = 3 +) -> AsyncGenerator[Dict[str, Any], None]: + """ + Execute tools with streaming on multiple servers concurrently. + + Args: + server_configs: List of server configurations + tool_name: Name of the tool to call + arguments: Tool arguments + max_concurrent: Maximum concurrent executions + + Yields: + Streaming results from all servers + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def execute_streaming_on_single_server(config): + async with semaphore: + try: + client = MCPUnifiedClient(config) + async for result in client.call_tool_streaming(tool_name, arguments): + yield { + "success": True, + "server": str(config), + "result": result, + "streaming": True + } + except Exception as e: + logger.error(f"Error executing streaming tool on server {config}: {e}") + yield { + "success": False, + "server": str(config), + "error": str(e), + "streaming": False + } + + # Create tasks for all servers + tasks = [execute_streaming_on_single_server(config) for config in server_configs] + + # Use asyncio.as_completed to yield results as they arrive + async def gather_streaming_results(): + async for coro in asyncio.as_completed(tasks): + async for result in coro: + yield result + + async for result in gather_streaming_results(): + yield result + + # Configuration factory functions def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig: """ @@ -726,7 +1152,11 @@ __all__ = [ "HTTPX_AVAILABLE", "MCP_AVAILABLE", "call_tool_streaming_sync", + "call_tool_streaming_sync_advanced", "execute_tool_call_streaming_unified", + "execute_tools_on_multiple_servers_unified", + "execute_tools_on_multiple_servers_unified_sync", + "execute_tools_streaming_on_multiple_servers_unified", ]