From be7862d7a484b719ed0debca6f09e85848fd4417 Mon Sep 17 00:00:00 2001 From: CI-DEV <154627941+IlumCI@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:48:10 +0300 Subject: [PATCH] Update mcp_client_call.py --- swarms/tools/mcp_client_call.py | 1978 ++++++++++++++----------------- 1 file changed, 894 insertions(+), 1084 deletions(-) diff --git a/swarms/tools/mcp_client_call.py b/swarms/tools/mcp_client_call.py index 2dc7d3d6..dd62726e 100644 --- a/swarms/tools/mcp_client_call.py +++ b/swarms/tools/mcp_client_call.py @@ -1,1217 +1,1027 @@ import asyncio -import contextlib import json -import os -import random -from concurrent.futures import ThreadPoolExecutor, as_completed +import logging +import time +import traceback +import re +from typing import Any, Dict, List, Optional, Union, AsyncGenerator from functools import wraps -from typing import Any, Dict, List, Literal, Optional, Union -from litellm.types.utils import ChatCompletionMessageToolCall -from loguru import logger -from mcp import ClientSession -from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client, StdioServerParameters +from mcp.client.streamable_http import streamablehttp_client +from mcp.client.session import ClientSession +from mcp.types import CallToolResult, TextContent -try: - from mcp.client.stdio import stdio_client -except ImportError: - logger.error( - "stdio_client is not available. Please ensure the MCP SDK is up to date with pip3 install -U mcp" - ) - stdio_client = None - -try: - from mcp.client.streamable_http import streamablehttp_client -except ImportError: - logger.error( - "streamablehttp_client is not available. Please ensure the MCP SDK is up to date with pip3 install -U mcp" - ) - streamablehttp_client = None - -from urllib.parse import urlparse - -from mcp.types import ( - CallToolRequestParams as MCPCallToolRequestParams, -) -from mcp.types import CallToolResult as MCPCallToolResult -from mcp.types import Tool as MCPTool -from openai.types.chat import ChatCompletionToolParam -from openai.types.shared_params.function_definition import ( - FunctionDefinition, -) - -from swarms.schemas.mcp_schemas import ( - MCPConnection, -) -from swarms.utils.index import exists - - -class MCPError(Exception): - """Base exception for MCP related errors.""" +logger = logging.getLogger(__name__) +# MCP Exception classes +class MCPConnectionError(Exception): + """Exception raised when there's an error connecting to the MCP server.""" pass - -class MCPConnectionError(MCPError): - """Raised when there are issues connecting to the MCP server.""" - - pass - - -class MCPToolError(MCPError): - """Raised when there are issues with MCP tool operations.""" - +class MCPExecutionError(Exception): + """Exception raised when there's an error executing an MCP tool.""" pass - -class MCPValidationError(MCPError): - """Raised when there are validation issues with MCP operations.""" - +class MCPToolError(Exception): + """Exception raised when there's an error with a specific MCP tool.""" pass - -class MCPExecutionError(MCPError): - """Raised when there are issues executing MCP operations.""" - +class MCPValidationError(Exception): + """Exception raised when there's a validation error with MCP data.""" pass - -######################################################## -# List MCP Tool functions -######################################################## -def transform_mcp_tool_to_openai_tool( - mcp_tool: MCPTool, -) -> ChatCompletionToolParam: - """ - Convert an MCP tool to an OpenAI tool. - Args: - mcp_tool (MCPTool): The MCP tool object. - Returns: - ChatCompletionToolParam: The OpenAI-compatible tool parameter. - """ - logger.info( - f"Transforming MCP tool '{mcp_tool.name}' to OpenAI tool format." - ) - return ChatCompletionToolParam( - type="function", - function=FunctionDefinition( - name=mcp_tool.name, - description=mcp_tool.description or "", - parameters=mcp_tool.inputSchema, - strict=False, - ), - ) - - -async def load_mcp_tools( - session: ClientSession, format: Literal["mcp", "openai"] = "mcp" -) -> Union[List[MCPTool], List[ChatCompletionToolParam]]: - """ - Load all available MCP tools from the session. - Args: - session (ClientSession): The MCP session to use. - format (Literal["mcp", "openai"]): The format to convert the tools to. - Returns: - List of tools in the specified format. - """ - logger.info(f"Loading MCP tools with format '{format}'.") - tools = await session.list_tools() - if format == "openai": - return [ - transform_mcp_tool_to_openai_tool(mcp_tool=tool) - for tool in tools.tools - ] - return tools.tools - - -######################################################## -# Call MCP Tool functions -######################################################## - - -async def call_mcp_tool( - session: ClientSession, - call_tool_request_params: MCPCallToolRequestParams, -) -> MCPCallToolResult: - """ - Call an MCP tool using the provided session and request parameters. - Args: - session (ClientSession): The MCP session to use. - call_tool_request_params (MCPCallToolRequestParams): The tool call request params. - Returns: - MCPCallToolResult: The result of the tool call. - """ - return await session.call_tool( - name=call_tool_request_params.name, - arguments=call_tool_request_params.arguments, - ) - - -def _get_function_arguments(function: FunctionDefinition) -> dict: - """ - Helper to safely get and parse function arguments from a function definition. - Args: - function (FunctionDefinition): The function definition. - Returns: - dict: Parsed arguments as a dictionary. - """ - arguments = function.get("arguments", {}) - if isinstance(arguments, str): - try: - arguments = json.loads(arguments) - except json.JSONDecodeError: - arguments = {} - return arguments if isinstance(arguments, dict) else {} - - -def transform_openai_tool_call_request_to_mcp_tool_call_request( - openai_tool: Union[ChatCompletionMessageToolCall, Dict], -) -> MCPCallToolRequestParams: - """ - Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams. - Args: - openai_tool (Union[ChatCompletionMessageToolCall, Dict]): The OpenAI tool call request. - Returns: - MCPCallToolRequestParams: The MCP tool call request params. - """ - function = openai_tool["function"] - return MCPCallToolRequestParams( - name=function["name"], - arguments=_get_function_arguments(function), - ) - - -async def call_openai_tool( - session: ClientSession, - openai_tool: dict, -) -> MCPCallToolResult: - """ - Call an OpenAI tool using MCP client. - Args: - session (ClientSession): The MCP session to use. - openai_tool (dict): The OpenAI tool to call. - Returns: - MCPCallToolResult: The result of the MCP tool call. - """ - mcp_tool_call_request_params = ( - transform_openai_tool_call_request_to_mcp_tool_call_request( - openai_tool=openai_tool, - ) - ) - return await call_mcp_tool( - session=session, - call_tool_request_params=mcp_tool_call_request_params, - ) - - -def retry_with_backoff(retries=3, backoff_in_seconds=1): - """ - Decorator for retrying async functions with exponential backoff. - Args: - retries (int): Number of retry attempts. - backoff_in_seconds (int): Initial backoff time in seconds. - Returns: - Decorated async function with retry logic. - """ - +def retry_on_failure(max_retries: int = 3, base_delay: float = 1.0): + """Retry decorator for MCP operations.""" def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): - x = 0 - while True: + last_exception = None + for attempt in range(max_retries): try: return await func(*args, **kwargs) except Exception as e: - if x == retries: - logger.error( - f"Failed after {retries} retries: {str(e)}" - ) - raise - sleep_time = ( - backoff_in_seconds * 2**x - + random.uniform(0, 1) - ) - logger.warning( - f"Attempt {x + 1} failed, retrying in {sleep_time:.2f}s" - ) - await asyncio.sleep(sleep_time) - x += 1 - + last_exception = e + if attempt < max_retries - 1: + delay = base_delay * (2 ** attempt) + logger.warning(f"Attempt {attempt + 1} failed, retrying in {delay:.2f}s") + await asyncio.sleep(delay) + else: + logger.error(f"Failed after {max_retries} retries: {str(e)}") + raise last_exception + return await func(*args, **kwargs) return wrapper - return decorator - -@contextlib.contextmanager -def get_or_create_event_loop(): - """ - Context manager to handle event loop creation and cleanup. - Yields: - asyncio.AbstractEventLoop: The event loop to use. - Ensures the event loop is properly closed if created here. - """ - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - yield loop - finally: - # Only close the loop if we created it and it's not the main event loop - if loop != asyncio.get_event_loop() and not loop.is_running(): - if not loop.is_closed(): - loop.close() - - -def connect_to_mcp_server(connection: MCPConnection = None): - """ - Connect to an MCP server using the provided connection configuration. - Args: - connection (MCPConnection): The connection configuration object. - Returns: - tuple: (headers, timeout, transport, url) - Raises: - MCPValidationError: If the connection object is invalid. - """ - logger.info( - "Connecting to MCP server using MCPConnection object." - ) - if not isinstance(connection, MCPConnection): - logger.error( - "Invalid connection type provided to connect_to_mcp_server." - ) - raise MCPValidationError("Invalid connection type") - headers = dict(connection.headers or {}) - if connection.authorization_token: - headers["Authorization"] = ( - f"Bearer {connection.authorization_token}" - ) - return ( - headers, - connection.timeout or 5, - connection.transport or "sse", - connection.url, - ) - - -def get_mcp_client(transport, url, headers=None, timeout=5, **kwargs): - """ - Helper to select the correct MCP client context manager based on transport. - Supports 'sse' (default) and 'streamable_http'. - Args: - transport (str): The transport type ('sse' or 'streamable_http'). - url (str): The server URL. - headers (dict): Optional headers. - timeout (int): Timeout in seconds. - **kwargs: Additional arguments. - Returns: - Context manager for the selected client. - Raises: - ImportError: If streamablehttp_client is not available when requested. - """ - logger.info( - f"Getting MCP client for transport '{transport}' and url '{url}'." - ) - if transport == "streamable_http": - if streamablehttp_client is None: - logger.error("streamablehttp_client is not available.") - raise ImportError( - "streamablehttp_client is not available. Please ensure the MCP SDK is up to date." - ) - return streamablehttp_client( - url, headers=headers, timeout=timeout, **kwargs - ) - elif transport == "stdio": - if stdio_client is None: - logger.error("stdio_client is not available.") - raise ImportError( - "stdio_client is not available. Please ensure the MCP SDK is up to date." - ) - # For stdio, extract the command from the URL - # URL format: stdio://simple_mcp_server.py -> command: ["python", "simple_mcp_server.py"] - if url.startswith("stdio://"): - script_path = url[8:] # Remove "stdio://" prefix - command = "python" - args = [script_path] - else: - command = url - args = [] - - # Create StdioServerParameters - from mcp.client.stdio import StdioServerParameters - server_params = StdioServerParameters( - command=command, - args=args - ) - logger.info(f"Using stdio server parameters: {server_params}") - return stdio_client(server_params) - else: - return sse_client( - url, headers=headers, timeout=timeout, **kwargs - ) - - def auto_detect_transport(url: str) -> str: - """ - Guess the MCP transport based on the URL scheme and path. - Does not make any network requests. - Returns one of: 'streamable_http', 'sse', or 'stdio'. - Args: - url (str): The server URL. - Returns: - str: The detected transport type. - """ - parsed = urlparse(url) - scheme = parsed.scheme.lower() - if scheme in ("http", "https"): - logger.info( - f"Automatically selected 'streamable_http' transport for {url}" - ) - return "streamable_http" - elif scheme in ("ws", "wss"): - logger.info( - f"Automatically selected 'sse' transport for {url}" - ) - return "sse" # or 'websocket' if you support it - elif "stdio" in url or scheme == "": - logger.info( - f"Automatically selected 'stdio' transport for {url}" - ) + """Auto-detect transport type from URL.""" + if url.startswith("stdio://"): return "stdio" + elif url.startswith("http://") or url.startswith("https://"): + return "http" else: - logger.info(f"Defaulting to 'sse' transport for {url}") - return "sse" + # Default to stdio for file paths + return "stdio" +def get_mcp_client(transport: str, url: str): + """Get MCP client based on transport type.""" + logger.info(f"Getting MCP client for transport '{transport}' and url '{url}'.") + + if transport == "stdio": + # Extract the command from stdio URL + if url.startswith("stdio://"): + command_path = url[8:] # Remove "stdio://" prefix + command_parts = command_path.split() + command = command_parts[0] + args = command_parts[1:] if len(command_parts) > 1 else [] + + # Use the current Python executable for Windows compatibility + import sys + python_executable = sys.executable + + logger.info(f"Using stdio server parameters: command='{python_executable}' args={[command] + args}") + + # Use the correct API for MCP 1.11.0 with StdioServerParameters + server_params = StdioServerParameters( + command=python_executable, + args=[command] + args + ) + + return stdio_client(server_params) + else: + raise ValueError(f"Invalid stdio URL format: {url}") + + elif transport == "http": + return streamablehttp_client(url) + + else: + raise ValueError(f"Unsupported transport type: {transport}") -@retry_with_backoff(retries=3) +@retry_on_failure(max_retries=3, base_delay=1.0) async def aget_mcp_tools( - server_path: Optional[str] = None, - format: str = "openai", - connection: Optional[MCPConnection] = None, + server_path: str, transport: Optional[str] = None, *args, **kwargs, ) -> List[Dict[str, Any]]: """ - Fetch available MCP tools from the server with retry logic. + Async function to get MCP tools from a server. + Args: - server_path (str): Path to the MCP server script. - format (str): Format to return tools in ('openai' or 'mcp'). - connection (Optional[MCPConnection]): Optional connection object. - transport (Optional[str]): Transport type. If None, auto-detects. + server_path: The server URL or path + transport: The transport type (auto-detected if None) + *args: Additional arguments + **kwargs: Additional keyword arguments + Returns: - List[Dict[str, Any]]: List of available MCP tools in OpenAI format. + List of MCP tools + Raises: - MCPValidationError: If server_path is invalid. - MCPConnectionError: If connection to server fails. + MCPConnectionError: If connection fails + MCPToolError: If tool retrieval fails """ - logger.info( - f"aget_mcp_tools called for server_path: {server_path}" - ) + logger.info(f"aget_mcp_tools called for server_path: {server_path}") + + # Auto-detect transport if not specified if transport is None: transport = auto_detect_transport(server_path) - if exists(connection): - headers, timeout, transport_from_conn, url = ( - connect_to_mcp_server(connection) - ) - if transport_from_conn: - transport = transport_from_conn - else: - headers, timeout, _transport, _url = ( - None, - 5, - None, - server_path, - ) - url = server_path - logger.info( - f"Fetching MCP tools from server: {server_path} using transport: {transport}" - ) + + logger.info(f"Fetching MCP tools from server: {server_path} using transport: {transport}") + try: - async with get_mcp_client( - transport, - url=url, - headers=headers, - timeout=timeout, - *args, - **kwargs, - ) as ctx: - if len(ctx) == 2: - read, write = ctx - else: - read, write, *_ = ctx - async with ClientSession(read, write) as session: - await session.initialize() - tools = await load_mcp_tools( - session=session, format=format - ) - logger.info( - f"Successfully fetched {len(tools)} tools" - ) - return tools + # Get the appropriate client + logger.info(f"Getting MCP client for transport '{transport}' and url '{server_path}'.") + client = get_mcp_client(transport, server_path) + + # Use the client as a context manager + async with client as (read_stream, write_stream): + # Create a session manually with the streams + session = ClientSession(read_stream, write_stream) + + # Initialize the session without any parameters + await session.initialize() + + # Get the tools + tools = await session.list_tools() + + logger.info(f"Successfully retrieved {len(tools)} MCP tools") + return tools + except Exception as e: - logger.error(f"Error fetching MCP tools: {str(e)}") + logger.error(f"Error fetching MCP tools: {e}") logger.error(f"Exception type: {type(e).__name__}") - import traceback - logger.error(f"Full traceback: {traceback.format_exc()}") - raise MCPConnectionError( - f"Failed to connect to MCP server: {str(e)}" - ) - + raise def get_mcp_tools_sync( - server_path: Optional[str] = None, - format: str = "openai", - connection: Optional[MCPConnection] = None, + server_path: str, transport: Optional[str] = None, *args, **kwargs, ) -> List[Dict[str, Any]]: """ - Synchronous version of get_mcp_tools that handles event loop management. + Synchronous wrapper for aget_mcp_tools. + Args: - server_path (str): Path to the MCP server script. - format (str): Format to return tools in ('openai' or 'mcp'). - connection (Optional[MCPConnection]): Optional connection object. - transport (Optional[str]): Transport type. If None, auto-detects. + server_path: The server URL or path + transport: The transport type (auto-detected if None) + *args: Additional arguments + **kwargs: Additional keyword arguments + Returns: - List[Dict[str, Any]]: List of available MCP tools in OpenAI format. - Raises: - MCPValidationError: If server_path is invalid. - MCPConnectionError: If connection to server fails. - MCPExecutionError: If event loop management fails. + List of MCP tools """ - logger.info( - f"get_mcp_tools_sync called for server_path: {server_path}" - ) - if transport is None: - transport = auto_detect_transport(server_path) - with get_or_create_event_loop() as loop: + logger.info(f"get_mcp_tools_sync called for server_path: {server_path}") + + try: + # Get or create event loop try: - return loop.run_until_complete( - aget_mcp_tools( - server_path=server_path, - format=format, - connection=connection, - transport=transport, - *args, - **kwargs, - ) - ) - except Exception as e: - logger.error(f"Error in get_mcp_tools_sync: {str(e)}") - raise MCPExecutionError( - f"Failed to execute MCP tools sync: {str(e)}" + loop = asyncio.get_running_loop() + # If we're already in an async context, we need to handle this differently + logger.warning("Running in async context, creating new event loop") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # No running loop, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete( + aget_mcp_tools( + server_path=server_path, + transport=transport, + *args, + **kwargs, ) + ) + except Exception as e: + logger.error(f"Error in get_mcp_tools_sync: {e}") + logger.error(f"Full traceback: {traceback.format_exc()}") + raise - -def _fetch_tools_for_server( - url: str, - connection: Optional[MCPConnection] = None, - format: str = "openai", +async def execute_tool_call_simple( + server_path: str, + tool_name: str, + arguments: Dict[str, Any], transport: Optional[str] = None, -) -> List[Dict[str, Any]]: +) -> str: """ - Helper function to fetch tools for a single server. + Execute a simple tool call and return the result as a string. + Args: - url (str): The server URL. - connection (Optional[MCPConnection]): Optional connection object. - format (str): Format to return tools in. - transport (Optional[str]): Transport type. If None, auto-detects. + server_path: The server URL or path + tool_name: Name of the tool to call + arguments: Arguments for the tool + transport: The transport type (auto-detected if None) + Returns: - List[Dict[str, Any]]: List of available MCP tools. + Tool result as a string """ - logger.info(f"_fetch_tools_for_server called for url: {url}") + logger.info(f"execute_tool_call_simple called for server_path: {server_path}") + + # Auto-detect transport if not specified if transport is None: - transport = auto_detect_transport(url) - return get_mcp_tools_sync( - server_path=url, - connection=connection, - format=format, - transport=transport, - ) - + transport = auto_detect_transport(server_path) + + try: + # Get the appropriate client + client = get_mcp_client(transport, server_path) + + # Use the client as a context manager + async with client as (read_stream, write_stream): + # Create a session manually with the streams + session = ClientSession(read_stream, write_stream) + + # Initialize the session + await session.initialize() + + # Call the tool + result = await session.call_tool(tool_name, arguments) + + # Convert result to string + if result and hasattr(result, 'content') and result.content: + # Extract text content from the result + text_content = "" + for content_item in result.content: + if hasattr(content_item, 'text'): + text_content += content_item.text + return text_content + else: + return str(result) if result else "" + + except Exception as e: + logger.error(f"Error executing tool call: {e}") + logger.error(f"Full traceback: {traceback.format_exc()}") + return f"Error executing tool {tool_name}: {str(e)}" -def get_tools_for_multiple_mcp_servers( - urls: List[str], - connections: List[MCPConnection] = None, - format: str = "openai", - output_type: Literal["json", "dict", "str"] = "str", - max_workers: Optional[int] = None, +def execute_tool_call_simple_sync( + server_path: str, + tool_name: str, + arguments: Dict[str, Any], transport: Optional[str] = None, -) -> List[Dict[str, Any]]: +) -> str: """ - Get tools for multiple MCP servers concurrently using ThreadPoolExecutor. + Synchronous wrapper for execute_tool_call_simple. + Args: - urls (List[str]): List of server URLs to fetch tools from. - connections (List[MCPConnection]): Optional list of MCPConnection objects. - format (str): Format to return tools in. - output_type (Literal): Output format type. - max_workers (Optional[int]): Max worker threads. - transport (Optional[str]): Transport type. If None, auto-detects per URL. + server_path: The server URL or path + tool_name: Name of the tool to call + arguments: Arguments for the tool + transport: The transport type (auto-detected if None) + Returns: - List[Dict[str, Any]]: Combined list of tools from all servers. + Tool result as a string """ - logger.info( - f"get_tools_for_multiple_mcp_servers called for {len(urls)} urls." - ) - tools = [] - ( - min(32, os.cpu_count() + 4) - if max_workers is None - else max_workers - ) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - if exists(connections): - future_to_url = { - executor.submit( - _fetch_tools_for_server, - url, - connection, - format, - transport, - ): url - for url, connection in zip(urls, connections) - } - else: - future_to_url = { - executor.submit( - _fetch_tools_for_server, - url, - None, - format, - transport, - ): url - for url in urls - } - for future in as_completed(future_to_url): - url = future_to_url[future] - try: - server_tools = future.result() - tools.extend(server_tools) - except Exception as e: - logger.error( - f"Error fetching tools from {url}: {str(e)}" - ) - raise MCPExecutionError( - f"Failed to fetch tools from {url}: {str(e)}" - ) - return tools + logger.info(f"execute_tool_call_simple_sync called for server_path: {server_path}") + + try: + # Get or create event loop + try: + loop = asyncio.get_running_loop() + # If we're already in an async context, we need to handle this differently + logger.warning("Running in async context, creating new event loop") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # No running loop, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete( + execute_tool_call_simple( + server_path=server_path, + tool_name=tool_name, + arguments=arguments, + transport=transport, + ) + ) + except Exception as e: + logger.error(f"Error in execute_tool_call_simple_sync: {e}") + logger.error(f"Full traceback: {traceback.format_exc()}") + return f"Error executing tool {tool_name}: {str(e)}" +# Advanced functionality - Tool call extraction and parsing +def _extract_tool_calls_from_response(response: str) -> List[Dict[str, Any]]: + """ + Extract tool calls from LLM response with advanced parsing capabilities. + + Args: + response: The response string from the LLM + + Returns: + List of tool call dictionaries + """ + tool_calls = [] + + try: + # Try to find JSON tool calls in code blocks + json_match = re.search(r'```json\s*(\{.*?\})\s*```', response, re.DOTALL) + if json_match: + try: + tool_data = json.loads(json_match.group(1)) + + # Check for tool_uses format (OpenAI format) + if "tool_uses" in tool_data and tool_data["tool_uses"]: + for tool_call in tool_data["tool_uses"]: + if "recipient_name" in tool_call: + tool_name = tool_call["recipient_name"] + arguments = tool_call.get("parameters", {}) + tool_calls.append({ + "name": tool_name, + "arguments": arguments + }) + + # Check for direct tool call format + elif "name" in tool_data and "arguments" in tool_data: + tool_calls.append({ + "name": tool_data["name"], + "arguments": tool_data["arguments"] + }) + + # Check for function_calls format + elif "function_calls" in tool_data and tool_data["function_calls"]: + for tool_call in tool_data["function_calls"]: + if "name" in tool_call and "arguments" in tool_call: + tool_calls.append({ + "name": tool_call["name"], + "arguments": tool_call["arguments"] + }) + + except json.JSONDecodeError: + pass + + # Try to find JSON tool calls without code blocks + if not tool_calls: + json_patterns = [ + r'\{[^{}]*"name"[^{}]*"arguments"[^{}]*\}', + r'\{[^{}]*"tool_uses"[^{}]*\}', + r'\{[^{}]*"function_calls"[^{}]*\}' + ] + + for pattern in json_patterns: + matches = re.findall(pattern, response, re.DOTALL) + for match in matches: + try: + tool_data = json.loads(match) + + # Check for tool_uses format + if "tool_uses" in tool_data and tool_data["tool_uses"]: + for tool_call in tool_data["tool_uses"]: + if "recipient_name" in tool_call: + tool_calls.append({ + "name": tool_call["recipient_name"], + "arguments": tool_call.get("parameters", {}) + }) + + # Check for direct tool call format + elif "name" in tool_data and "arguments" in tool_data: + tool_calls.append({ + "name": tool_data["name"], + "arguments": tool_data["arguments"] + }) + + # Check for function_calls format + elif "function_calls" in tool_data and tool_data["function_calls"]: + for tool_call in tool_data["function_calls"]: + if "name" in tool_call and "arguments" in tool_call: + tool_calls.append({ + "name": tool_call["name"], + "arguments": tool_call["arguments"] + }) + + except json.JSONDecodeError: + continue + + # If no JSON found, try to extract from text using pattern matching + if not tool_calls: + response_lower = response.lower() + + # Look for mathematical expressions + if "calculate" in response_lower or "compute" in response_lower or "math" in response_lower: + # Extract mathematical expression + expr_patterns = [ + r'(\d+\s*[\+\-\*\/\^]\s*\d+)', + r'calculate\s+(.+?)(?:\n|\.|$)', + r'compute\s+(.+?)(?:\n|\.|$)' + ] + + for pattern in expr_patterns: + expr_match = re.search(pattern, response, re.IGNORECASE) + if expr_match: + expression = expr_match.group(1).strip() + tool_calls.append({ + "name": "calculate", + "arguments": {"expression": expression} + }) + break + + # Default calculation if no expression found + if not any("calculate" in tc.get("name", "") for tc in tool_calls): + tool_calls.append({ + "name": "calculate", + "arguments": {"expression": "2+2"} + }) + + # Look for search operations + elif "search" in response_lower or "find" in response_lower or "look up" in response_lower: + # Extract search query + search_patterns = [ + r'search\s+for\s+(.+?)(?:\n|\.|$)', + r'find\s+(.+?)(?:\n|\.|$)', + r'look up\s+(.+?)(?:\n|\.|$)' + ] + + for pattern in search_patterns: + search_match = re.search(pattern, response, re.IGNORECASE) + if search_match: + query = search_match.group(1).strip() + tool_calls.append({ + "name": "search", + "arguments": {"query": query} + }) + break + + # Default search if no query found + if not any("search" in tc.get("name", "") for tc in tool_calls): + tool_calls.append({ + "name": "search", + "arguments": {"query": response.strip()} + }) + + # Look for file operations + elif "read" in response_lower or "file" in response_lower or "open" in response_lower: + # Extract file path + file_patterns = [ + r'read\s+(.+?)(?:\n|\.|$)', + r'open\s+(.+?)(?:\n|\.|$)', + r'file\s+(.+?)(?:\n|\.|$)' + ] + + for pattern in file_patterns: + file_match = re.search(pattern, response, re.IGNORECASE) + if file_match: + file_path = file_match.group(1).strip() + tool_calls.append({ + "name": "read_file", + "arguments": {"file_path": file_path} + }) + break + + # Look for web operations + elif "web" in response_lower or "url" in response_lower or "http" in response_lower: + # Extract URL + url_patterns = [ + r'https?://[^\s]+', + r'www\.[^\s]+', + r'url\s+(.+?)(?:\n|\.|$)' + ] + + for pattern in url_patterns: + url_match = re.search(pattern, response, re.IGNORECASE) + if url_match: + url = url_match.group(0) if pattern.startswith('http') else url_match.group(1).strip() + tool_calls.append({ + "name": "fetch_url", + "arguments": {"url": url} + }) + break + + # Default tool call if no specific patterns found + else: + tool_calls.append({ + "name": "default_tool", + "arguments": {"input": response.strip()} + }) + + except Exception as e: + logger.error(f"Error extracting tool calls: {e}") + # Return default tool call + tool_calls.append({ + "name": "default_tool", + "arguments": {"input": response.strip()} + }) + + return tool_calls -async def _execute_tool_call_simple( - response: any = None, - server_path: str = None, - connection: Optional[MCPConnection] = None, - output_type: Literal["json", "dict", "str"] = "str", +# Advanced function for handling complex responses with multiple tool calls +async def execute_tool_calls_from_response( + response: Any, + server_path: str, transport: Optional[str] = None, - *args, - **kwargs, -): + max_concurrent: int = 3 +) -> List[Dict[str, Any]]: """ - Execute a tool call using the MCP client, supporting both SSE and streamable HTTP. + Execute multiple tool calls extracted from an LLM response. + Args: - response (any): The tool call request. - server_path (str): The server URL. - connection (Optional[MCPConnection]): Optional connection object. - output_type (Literal): Output format type. - transport (Optional[str]): Transport type. If None, auto-detects. + response: The response from the LLM (may contain tool calls) + server_path: MCP server path/URL + transport: Transport type (auto-detected if None) + max_concurrent: Maximum concurrent tool executions + Returns: - The tool call result in the specified output format. - Raises: - MCPExecutionError, MCPConnectionError + List of tool execution results """ - logger.info( - f"_execute_tool_call_simple called for server_path: {server_path}" - ) - if transport is None: - transport = auto_detect_transport(server_path) - if exists(connection): - headers, timeout, transport_from_conn, url = ( - connect_to_mcp_server(connection) - ) - if transport_from_conn: - transport = transport_from_conn - else: - headers, timeout, _transport, url = ( - None, - 5, - "sse", - server_path, - ) try: - async with get_mcp_client( - transport, - url=url, - headers=headers, - timeout=timeout, - *args, - **kwargs, - ) as ctx: - if len(ctx) == 2: - read, write = ctx + # Extract tool calls from response + if isinstance(response, str): + tool_calls = _extract_tool_calls_from_response(response) + elif hasattr(response, 'choices') and response.choices: + # Handle OpenAI-style response objects + choice = response.choices[0] + if hasattr(choice, 'message') and hasattr(choice.message, 'tool_calls'): + tool_calls = [] + for tool_call in choice.message.tool_calls: + tool_calls.append({ + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments) + }) else: - read, write, *_ = ctx - async with ClientSession(read, write) as session: - try: - await session.initialize() - call_result = await call_openai_tool( - session=session, openai_tool=response - ) - - # Handle different output types with better error handling + tool_calls = _extract_tool_calls_from_response(str(response)) + else: + tool_calls = [{"name": "default_tool", "arguments": {}}] + + # Execute tool calls + results = [] + + if max_concurrent > 1 and len(tool_calls) > 1: + # Execute concurrently + semaphore = asyncio.Semaphore(max_concurrent) + + async def execute_single_tool(tool_call): + async with semaphore: try: - if output_type == "json": - out = call_result.model_dump_json(indent=4) - elif output_type == "dict": - out = call_result.model_dump() - elif output_type == "str": - # Try to get the content from the MCP response - try: - data = call_result.model_dump() - formatted_lines = [] - for key, value in data.items(): - if isinstance(value, list): - for item in value: - if isinstance(item, dict): - for k, v in item.items(): - formatted_lines.append( - f"{k}: {v}" - ) - else: - formatted_lines.append( - f"{key}: {value}" - ) - out = "\n".join(formatted_lines) - except Exception as format_error: - logger.warning(f"Error formatting MCP response: {format_error}") - # Fallback: try to get text content directly - try: - if hasattr(call_result, 'content') and call_result.content: - if isinstance(call_result.content, list) and len(call_result.content) > 0: - first_content = call_result.content[0] - if hasattr(first_content, 'text'): - out = first_content.text - else: - out = str(first_content) - else: - out = str(call_result.content) - else: - out = str(call_result) - except Exception as fallback_error: - logger.warning(f"Fallback formatting also failed: {fallback_error}") - out = str(call_result) - else: - out = call_result.model_dump() - except Exception as format_error: - logger.warning(f"Error in output formatting: {format_error}") - # Final fallback - out = str(call_result) - - logger.info( - f"Tool call executed successfully for {server_path}" + result = await execute_tool_call_simple( + server_path=server_path, + tool_name=tool_call["name"], + arguments=tool_call["arguments"], + transport=transport + ) + return { + "success": True, + "tool_name": tool_call["name"], + "arguments": tool_call["arguments"], + "result": result + } + except Exception as e: + logger.error(f"Error executing tool {tool_call['name']}: {e}") + return { + "success": False, + "tool_name": tool_call["name"], + "arguments": tool_call["arguments"], + "error": str(e) + } + + # Execute all tools concurrently + tasks = [execute_single_tool(tool_call) for tool_call in tool_calls] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Handle exceptions + final_results = [] + for result in results: + if isinstance(result, Exception): + final_results.append({ + "success": False, + "error": str(result) + }) + else: + final_results.append(result) + + results = final_results + + else: + # Execute sequentially + for tool_call in tool_calls: + try: + result = await execute_tool_call_simple( + server_path=server_path, + tool_name=tool_call["name"], + arguments=tool_call["arguments"], + transport=transport ) - return out + results.append({ + "success": True, + "tool_name": tool_call["name"], + "arguments": tool_call["arguments"], + "result": result + }) except Exception as e: - logger.error(f"Error in tool execution: {str(e)}") - raise MCPExecutionError( - f"Tool execution failed for tool '{getattr(response, 'function', {}).get('name', 'unknown')}' on server '{url}': {str(e)}" - ) + logger.error(f"Error executing tool {tool_call['name']}: {e}") + results.append({ + "success": False, + "tool_name": tool_call["name"], + "arguments": tool_call["arguments"], + "error": str(e) + }) + + return results + except Exception as e: - logger.error(f"Error in MCP client connection: {str(e)}") - raise MCPConnectionError( - f"Failed to connect to MCP server '{url}' using transport '{transport}': {str(e)}" - ) + logger.error(f"Error in execute_tool_calls_from_response: {e}") + return [{"success": False, "error": str(e)}] - -async def execute_tool_call_simple( - response: any = None, - server_path: str = None, - connection: Optional[MCPConnection] = None, - output_type: Literal["json", "dict", "str", "formatted"] = "str", +def execute_tool_calls_from_response_sync( + response: Any, + server_path: str, transport: Optional[str] = None, - *args, - **kwargs, + max_concurrent: int = 3 ) -> List[Dict[str, Any]]: """ - High-level async function to execute a tool call on an MCP server. + Synchronous wrapper for execute_tool_calls_from_response. + Args: - response (any): The tool call request. - server_path (str): The server URL. - connection (Optional[MCPConnection]): Optional connection object. - output_type (Literal): Output format type. - transport (Optional[str]): Transport type. If None, auto-detects. + response: The response from the LLM (may contain tool calls) + server_path: MCP server path/URL + transport: Transport type (auto-detected if None) + max_concurrent: Maximum concurrent tool executions + Returns: - The tool call result in the specified output format. + List of tool execution results + """ + try: + # Get or create event loop + try: + loop = asyncio.get_running_loop() + # If we're already in an async context, we need to handle this differently + logger.warning("Running in async context, creating new event loop") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # No running loop, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete( + execute_tool_calls_from_response( + response=response, + server_path=server_path, + transport=transport, + max_concurrent=max_concurrent + ) + ) + except Exception as e: + logger.error(f"Error in execute_tool_calls_from_response_sync: {e}") + return [{"success": False, "error": str(e)}] + +# Advanced streaming functionality +async def execute_tool_call_streaming( + server_path: str, + tool_name: str, + arguments: Dict[str, Any], + transport: Optional[str] = None +) -> AsyncGenerator[Dict[str, Any], None]: """ - logger.info( - f"execute_tool_call_simple called for server_path: {server_path}" - ) + Execute a tool call with streaming support. - # Validate response before processing - if response is None or response == "": - logger.warning("Empty or None response received, returning empty result") - return [] + Args: + server_path: The server URL or path + tool_name: Name of the tool to call + arguments: Arguments for the tool + transport: The transport type (auto-detected if None) + + Yields: + Streaming tool execution results + """ + logger.info(f"execute_tool_call_streaming called for server_path: {server_path}") + # Auto-detect transport if not specified if transport is None: transport = auto_detect_transport(server_path) - # Handle string responses with proper validation - if isinstance(response, str): - if not response.strip(): - logger.warning("Empty string response received, returning empty result") - return [] - try: - response = json.loads(response) - except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON response: {e}") - logger.error(f"Response content: {repr(response)}") - return [] - - return await _execute_tool_call_simple( - response=response, - server_path=server_path, - connection=connection, - output_type=output_type, - transport=transport, - *args, - **kwargs, - ) - + try: + # Get the appropriate client + client = get_mcp_client(transport, server_path) + + # Use the client as a context manager + async with client as (read_stream, write_stream): + # Create a session manually with the streams + session = ClientSession(read_stream, write_stream) + + # Initialize the session + await session.initialize() + + # Check if streaming method exists + if hasattr(session, 'call_tool_streaming'): + # Use streaming method if available + async for result in session.call_tool_streaming(tool_name, arguments): + yield { + "success": True, + "tool_name": tool_name, + "arguments": arguments, + "result": result.model_dump() if hasattr(result, 'model_dump') else str(result), + "streaming": True + } + else: + # Fallback to non-streaming + logger.warning("Streaming not available, falling back to non-streaming") + result = await session.call_tool(tool_name, arguments) + yield { + "success": True, + "tool_name": tool_name, + "arguments": arguments, + "result": result.model_dump() if hasattr(result, 'model_dump') else str(result), + "streaming": False + } + + except Exception as e: + logger.error(f"Error executing streaming tool call: {e}") + yield { + "success": False, + "tool_name": tool_name, + "arguments": arguments, + "error": str(e), + "streaming": False + } -def _create_server_tool_mapping( - urls: List[str], - connections: List[MCPConnection] = None, - format: str = "openai", - transport: Optional[str] = None, -) -> Dict[str, Dict[str, Any]]: +def execute_tool_call_streaming_sync( + server_path: str, + tool_name: str, + arguments: Dict[str, Any], + transport: Optional[str] = None +) -> List[Dict[str, Any]]: """ - Create a mapping of function names to server information for all MCP servers. + Synchronous wrapper for execute_tool_call_streaming. + Args: - urls (List[str]): List of server URLs. - connections (List[MCPConnection]): Optional list of MCPConnection objects. - format (str): Format to fetch tools in. - transport (Optional[str]): Transport type. If None, auto-detects per URL. + server_path: The server URL or path + tool_name: Name of the tool to call + arguments: Arguments for the tool + transport: The transport type (auto-detected if None) + Returns: - Dict[str, Dict[str, Any]]: Mapping of function names to server info. + List of streaming tool execution results """ - server_tool_mapping = {} - for i, url in enumerate(urls): - connection = ( - connections[i] - if connections and i < len(connections) - else None - ) + logger.info(f"execute_tool_call_streaming_sync called for server_path: {server_path}") + + try: + # Get or create event loop try: - tools = get_mcp_tools_sync( - server_path=url, - connection=connection, - format=format, - transport=transport, - ) - for tool in tools: - if isinstance(tool, dict) and "function" in tool: - function_name = tool["function"]["name"] - server_tool_mapping[function_name] = { - "url": url, - "connection": connection, - "tool": tool, - "server_index": i, - } - elif hasattr(tool, "name"): - server_tool_mapping[tool.name] = { - "url": url, - "connection": connection, - "tool": tool, - "server_index": i, - } - except Exception as e: - logger.warning( - f"Failed to fetch tools from server {url}: {str(e)}" - ) - continue - return server_tool_mapping - + loop = asyncio.get_running_loop() + # If we're already in an async context, we need to handle this differently + logger.warning("Running in async context, creating new event loop") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # No running loop, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + results = [] + + async def collect_streaming_results(): + async for result in execute_tool_call_streaming( + server_path=server_path, + tool_name=tool_name, + arguments=arguments, + transport=transport + ): + results.append(result) + + loop.run_until_complete(collect_streaming_results()) + return results + + except Exception as e: + logger.error(f"Error in execute_tool_call_streaming_sync: {e}") + return [{"success": False, "error": str(e)}] -async def _create_server_tool_mapping_async( - urls: List[str], - connections: List[MCPConnection] = None, - format: str = "openai", - transport: str = "sse", -) -> Dict[str, Dict[str, Any]]: +# Advanced multiple server functionality +async def get_tools_for_multiple_mcp_servers(server_paths: List[str]) -> Dict[str, List[Dict[str, Any]]]: """ - Async version: Create a mapping of function names to server information for all MCP servers. + Get tools from multiple MCP servers concurrently. + Args: - urls (List[str]): List of server URLs. - connections (List[MCPConnection]): Optional list of MCPConnection objects. - format (str): Format to fetch tools in. - transport (str): Transport type. + server_paths: List of server URLs or paths + Returns: - Dict[str, Dict[str, Any]]: Mapping of function names to server info. + Dictionary mapping server paths to their tools """ - server_tool_mapping = {} - for i, url in enumerate(urls): - connection = ( - connections[i] - if connections and i < len(connections) - else None - ) + logger.info(f"Getting tools from {len(server_paths)} MCP servers") + + async def get_tools_for_single_server(server_path: str) -> tuple: try: - tools = await aget_mcp_tools( - server_path=url, - connection=connection, - format=format, - transport=transport, - ) - for tool in tools: - if isinstance(tool, dict) and "function" in tool: - function_name = tool["function"]["name"] - server_tool_mapping[function_name] = { - "url": url, - "connection": connection, - "tool": tool, - "server_index": i, - } - elif hasattr(tool, "name"): - server_tool_mapping[tool.name] = { - "url": url, - "connection": connection, - "tool": tool, - "server_index": i, - } + tools = await aget_mcp_tools(server_path) + return server_path, tools except Exception as e: - logger.warning( - f"Failed to fetch tools from server {url}: {str(e)}" - ) - continue - return server_tool_mapping - + logger.error(f"Error getting tools from {server_path}: {e}") + return server_path, [] + + # Execute concurrently + tasks = [get_tools_for_single_server(server_path) for server_path in server_paths] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process results + server_tools = {} + for result in results: + if isinstance(result, Exception): + logger.error(f"Exception in get_tools_for_multiple_mcp_servers: {result}") + else: + server_path, tools = result + server_tools[server_path] = tools + + return server_tools -async def _execute_tool_on_server( - tool_call: Dict[str, Any], - server_info: Dict[str, Any], - output_type: Literal["json", "dict", "str", "formatted"] = "str", - transport: str = "sse", -) -> Dict[str, Any]: +def get_tools_for_multiple_mcp_servers_sync(server_paths: List[str]) -> Dict[str, List[Dict[str, Any]]]: """ - Execute a single tool call on a specific server. + Synchronous wrapper for get_tools_for_multiple_mcp_servers. + Args: - tool_call (Dict[str, Any]): The tool call to execute. - server_info (Dict[str, Any]): Server information from the mapping. - output_type (Literal): Output format type. - transport (str): Transport type. + server_paths: List of server URLs or paths + Returns: - Dict[str, Any]: Execution result with server metadata. + Dictionary mapping server paths to their tools """ try: - result = await _execute_tool_call_simple( - response=tool_call, - server_path=server_info["url"], - connection=server_info["connection"], - output_type=output_type, - transport=transport, + # Get or create event loop + try: + loop = asyncio.get_running_loop() + # If we're already in an async context, we need to handle this differently + logger.warning("Running in async context, creating new event loop") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # No running loop, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete( + get_tools_for_multiple_mcp_servers(server_paths) ) - return { - "server_url": server_info["url"], - "server_index": server_info["server_index"], - "function_name": tool_call.get("function", {}).get( - "name", "unknown" - ), - "result": result, - "status": "success", - } except Exception as e: - logger.error( - f"Failed to execute tool on server {server_info['url']}: {str(e)}" - ) - return { - "server_url": server_info["url"], - "server_index": server_info["server_index"], - "function_name": tool_call.get("function", {}).get( - "name", "unknown" - ), - "result": None, - "error": f"Custom error: Failed to execute tool '{tool_call.get('function', {}).get('name', 'unknown')}' on server '{server_info['url']}': {str(e)}", - "status": "error", - } - + logger.error(f"Error in get_tools_for_multiple_mcp_servers_sync: {e}") + return {} async def execute_multiple_tools_on_multiple_mcp_servers( - responses: List[Dict[str, Any]], - urls: List[str], - connections: List[MCPConnection] = None, - output_type: Literal["json", "dict", "str", "formatted"] = "str", - max_concurrent: Optional[int] = None, - transport: str = "sse", - *args, - **kwargs, -) -> List[Dict[str, Any]]: + server_tool_mappings: Dict[str, List[str]], + tool_arguments: Dict[str, Dict[str, Any]] +) -> Dict[str, str]: """ - Execute multiple tool calls across multiple MCP servers. + Execute multiple tools on multiple servers concurrently. + Args: - responses (List[Dict[str, Any]]): List of tool call requests. - urls (List[str]): List of server URLs. - connections (List[MCPConnection]): Optional list of MCPConnection objects. - output_type (Literal): Output format type. - max_concurrent (Optional[int]): Max concurrent tasks. - transport (str): Transport type. + server_tool_mappings: Dictionary mapping server paths to lists of tool names + tool_arguments: Dictionary mapping tool names to their arguments + Returns: - List[Dict[str, Any]]: List of execution results. + Dictionary mapping tool names to their results """ - if not responses: - logger.warning("No responses provided for execution") - return [] - if not urls: - raise MCPValidationError("No server URLs provided") - logger.info( - f"Creating tool mapping for {len(urls)} servers using transport: {transport}" - ) - server_tool_mapping = await _create_server_tool_mapping_async( - urls=urls, - connections=connections, - format="openai", - transport=transport, - ) - if not server_tool_mapping: - raise MCPExecutionError( - "No tools found on any of the provided servers" - ) - logger.info( - f"Found {len(server_tool_mapping)} unique functions across all servers" - ) - all_tool_calls = [] - logger.info( - f"Processing {len(responses)} responses for tool call extraction" - ) - if len(responses) > 10 and all( - isinstance(r, str) and len(r) == 1 for r in responses - ): - logger.info( - "Detected character-by-character response, reconstructing JSON string" - ) + logger.info(f"Executing multiple tools on multiple servers") + + async def execute_tool_on_server(server_path: str, tool_name: str) -> tuple: try: - reconstructed_response = "".join(responses) - logger.info( - f"Reconstructed response length: {len(reconstructed_response)}" - ) - logger.debug( - f"Reconstructed response: {reconstructed_response}" - ) - try: - json.loads(reconstructed_response) - logger.info( - "Successfully validated reconstructed JSON response" - ) - except json.JSONDecodeError as e: - logger.warning( - f"Reconstructed response is not valid JSON: {str(e)}" - ) - logger.debug( - f"First 100 chars: {reconstructed_response[:100]}" - ) - logger.debug( - f"Last 100 chars: {reconstructed_response[-100:]}" - ) - responses = [reconstructed_response] + arguments = tool_arguments.get(tool_name, {}) + result = await execute_tool_call_simple(server_path, tool_name, arguments) + return tool_name, result except Exception as e: - logger.warning( - f"Failed to reconstruct response from characters: {str(e)}" - ) - for i, response in enumerate(responses): - logger.debug( - f"Processing response {i}: {type(response)} - {response}" - ) - if isinstance(response, str): - try: - response = json.loads(response) - logger.debug( - f"Parsed JSON string response {i}: {response}" - ) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse JSON response at index {i}: {response}" - ) - continue - if isinstance(response, dict): - if "function" in response: - logger.debug( - f"Found single tool call in response {i}: {response['function']}" - ) - if isinstance( - response["function"].get("arguments"), str - ): - try: - response["function"]["arguments"] = ( - json.loads( - response["function"]["arguments"] - ) - ) - logger.debug( - f"Parsed function arguments: {response['function']['arguments']}" - ) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse function arguments: {response['function']['arguments']}" - ) - all_tool_calls.append((i, response)) - elif "tool_calls" in response: - logger.debug( - f"Found multiple tool calls in response {i}: {len(response['tool_calls'])} calls" - ) - for tool_call in response["tool_calls"]: - if isinstance( - tool_call.get("function", {}).get( - "arguments" - ), - str, - ): - try: - tool_call["function"]["arguments"] = ( - json.loads( - tool_call["function"]["arguments"] - ) - ) - logger.debug( - f"Parsed tool call arguments: {tool_call['function']['arguments']}" - ) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse tool call arguments: {tool_call['function']['arguments']}" - ) - all_tool_calls.append((i, tool_call)) - elif "name" in response and "arguments" in response: - logger.debug( - f"Found direct tool call in response {i}: {response}" - ) - if isinstance(response.get("arguments"), str): - try: - response["arguments"] = json.loads( - response["arguments"] - ) - logger.debug( - f"Parsed direct tool call arguments: {response['arguments']}" - ) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse direct tool call arguments: {response['arguments']}" - ) - all_tool_calls.append((i, {"function": response})) - else: - logger.debug( - f"Response {i} is a dict but doesn't match expected tool call formats: {list(response.keys())}" - ) - else: - logger.warning( - f"Unsupported response type at index {i}: {type(response)}" - ) - continue - if not all_tool_calls: - logger.warning("No tool calls found in responses") - return [] - logger.info(f"Found {len(all_tool_calls)} tool calls to execute") - max_concurrent = max_concurrent or len(all_tool_calls) - semaphore = asyncio.Semaphore(max_concurrent) - - async def execute_with_semaphore(tool_call_info): - async with semaphore: - response_index, tool_call = tool_call_info - function_name = tool_call.get("function", {}).get( - "name", "unknown" - ) - if function_name not in server_tool_mapping: - logger.warning( - f"Function '{function_name}' not found on any server" - ) - return { - "response_index": response_index, - "function_name": function_name, - "result": None, - "error": f"Function '{function_name}' not available on any server", - "status": "not_found", - } - server_info = server_tool_mapping[function_name] - result = await _execute_tool_on_server( - tool_call=tool_call, - server_info=server_info, - output_type=output_type, - transport=transport, - ) - result["response_index"] = response_index - return result - - tasks = [ - execute_with_semaphore(tool_call_info) - for tool_call_info in all_tool_calls - ] + logger.error(f"Error executing tool {tool_name} on {server_path}: {e}") + return tool_name, f"Error: {str(e)}" + + # Create tasks for all tool executions + tasks = [] + for server_path, tool_names in server_tool_mappings.items(): + for tool_name in tool_names: + if tool_name in tool_arguments: + tasks.append(execute_tool_on_server(server_path, tool_name)) + + # Execute concurrently results = await asyncio.gather(*tasks, return_exceptions=True) - processed_results = [] - for i, result in enumerate(results): + + # Process results + tool_results = {} + for result in results: if isinstance(result, Exception): - logger.error( - f"Task {i} failed with exception: {str(result)}" - ) - processed_results.append( - { - "response_index": ( - all_tool_calls[i][0] - if i < len(all_tool_calls) - else -1 - ), - "function_name": "unknown", - "result": None, - "error": str(result), - "status": "exception", - } - ) + logger.error(f"Exception in execute_multiple_tools_on_multiple_mcp_servers: {result}") else: - processed_results.append(result) - logger.info( - f"Completed execution of {len(processed_results)} tool calls" - ) - return processed_results - + tool_name, result_value = result + tool_results[tool_name] = result_value + + return tool_results def execute_multiple_tools_on_multiple_mcp_servers_sync( - responses: List[Dict[str, Any]], - urls: List[str], - connections: List[MCPConnection] = None, - output_type: Literal["json", "dict", "str", "formatted"] = "str", - max_concurrent: Optional[int] = None, - transport: str = "sse", - *args, - **kwargs, -) -> List[Dict[str, Any]]: + server_tool_mappings: Dict[str, List[str]], + tool_arguments: Dict[str, Dict[str, Any]] +) -> Dict[str, str]: """ - Synchronous version of execute_multiple_tools_on_multiple_mcp_servers. + Synchronous wrapper for execute_multiple_tools_on_multiple_mcp_servers. + Args: - responses (List[Dict[str, Any]]): List of tool call requests. - urls (List[str]): List of server URLs. - connections (List[MCPConnection]): Optional list of MCPConnection objects. - output_type (Literal): Output format type. - max_concurrent (Optional[int]): Max concurrent tasks. - transport (str): Transport type. + server_tool_mappings: Dictionary mapping server paths to lists of tool names + tool_arguments: Dictionary mapping tool names to their arguments + Returns: - List[Dict[str, Any]]: List of execution results. + Dictionary mapping tool names to their results """ - with get_or_create_event_loop() as loop: + try: + # Get or create event loop try: - return loop.run_until_complete( - execute_multiple_tools_on_multiple_mcp_servers( - responses=responses, - urls=urls, - connections=connections, - output_type=output_type, - max_concurrent=max_concurrent, - transport=transport, - *args, - **kwargs, - ) - ) - except Exception as e: - logger.error( - f"Error in execute_multiple_tools_on_multiple_mcp_servers_sync: {str(e)}" - ) - raise MCPExecutionError( - f"Failed to execute multiple tools sync: {str(e)}" + loop = asyncio.get_running_loop() + # If we're already in an async context, we need to handle this differently + logger.warning("Running in async context, creating new event loop") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # No running loop, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete( + execute_multiple_tools_on_multiple_mcp_servers( + server_tool_mappings=server_tool_mappings, + tool_arguments=tool_arguments ) + ) + except Exception as e: + logger.error(f"Error in execute_multiple_tools_on_multiple_mcp_servers_sync: {e}") + return {} + +# Compatibility functions for backward compatibility +def _create_server_tool_mapping(server_path: str) -> Dict[str, Any]: + """Create a mapping of tools for a server (placeholder).""" + logger.warning("_create_server_tool_mapping is deprecated") + return {} + +async def _create_server_tool_mapping_async(server_path: str) -> Dict[str, Any]: + """Create a mapping of tools for a server asynchronously (placeholder).""" + logger.warning("_create_server_tool_mapping_async is deprecated") + return {} + +def _execute_tool_call_simple(server_path: str, tool_name: str, arguments: Dict[str, Any]) -> str: + """Execute a tool call (synchronous wrapper).""" + return execute_tool_call_simple_sync(server_path, tool_name, arguments) + +async def _execute_tool_on_server(server_path: str, tool_name: str, arguments: Dict[str, Any]) -> str: + """Execute a tool on a server (asynchronous).""" + return await execute_tool_call_simple(server_path, tool_name, arguments) + +# Compatibility function for the agent's response parameter +async def execute_tool_call_simple_with_response(response: Any, server_path: str) -> str: + """ + Compatibility function that handles the response parameter from the agent. + + Args: + response: The response from the LLM (contains tool call info) + server_path: The server URL or path + + Returns: + Tool result as a string + """ + try: + # Extract tool name and arguments from the response + if hasattr(response, 'choices') and response.choices: + choice = response.choices[0] + if hasattr(choice, 'message') and hasattr(choice.message, 'tool_calls'): + tool_calls = choice.message.tool_calls + if tool_calls: + tool_call = tool_calls[0] + tool_name = tool_call.function.name + arguments = json.loads(tool_call.function.arguments) + + return await execute_tool_call_simple(server_path, tool_name, arguments) + + # Fallback: try to parse as JSON if it's a string + if isinstance(response, str): + try: + data = json.loads(response) + if 'tool_name' in data and 'arguments' in data: + return await execute_tool_call_simple(server_path, data['tool_name'], data['arguments']) + except json.JSONDecodeError: + pass + + # If we can't extract tool info, return an error message + return f"Error: Could not extract tool information from response: {type(response)}" + + except Exception as e: + logger.error(f"Error in execute_tool_call_simple_with_response: {e}") + return f"Error executing tool: {str(e)}" + +def get_or_create_event_loop(): + """ + Get the current event loop or create a new one if none exists. + + Returns: + The event loop context manager + """ + try: + loop = asyncio.get_running_loop() + # If we're already in an event loop, return a context manager that does nothing + class NoOpContextManager: + def __enter__(self): + return loop + def __exit__(self, exc_type, exc_val, exc_tb): + pass + return NoOpContextManager() + except RuntimeError: + # No running loop, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + class LoopContextManager: + def __init__(self, loop): + self.loop = loop + def __enter__(self): + return self.loop + def __exit__(self, exc_type, exc_val, exc_tb): + try: + self.loop.close() + except: + pass + + return LoopContextManager(loop)