You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/swarms/tools/mcp_client_call.py

1016 lines
32 KiB

import os
import asyncio
import contextlib
import json
import random
from functools import wraps
from typing import Any, Dict, List, Literal, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
from litellm.types.utils import ChatCompletionMessageToolCall
from loguru import logger
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import (
CallToolRequestParams as MCPCallToolRequestParams,
)
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import Tool as MCPTool
from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params.function_definition import (
FunctionDefinition,
)
from swarms.schemas.mcp_schemas import (
MCPConnection,
)
from swarms.utils.index import exists
class MCPError(Exception):
"""Base exception for MCP related errors."""
pass
class MCPConnectionError(MCPError):
"""Raised when there are issues connecting to the MCP server."""
pass
class MCPToolError(MCPError):
"""Raised when there are issues with MCP tool operations."""
pass
class MCPValidationError(MCPError):
"""Raised when there are validation issues with MCP operations."""
pass
class MCPExecutionError(MCPError):
"""Raised when there are issues executing MCP operations."""
pass
########################################################
# List MCP Tool functions
########################################################
def transform_mcp_tool_to_openai_tool(
mcp_tool: MCPTool,
) -> ChatCompletionToolParam:
"""Convert an MCP tool to an OpenAI tool."""
return ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=mcp_tool.inputSchema,
strict=False,
),
)
async def load_mcp_tools(
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
"""
Load all available MCP tools
Args:
session: The MCP session to use
format: The format to convert the tools to
By default, the tools are returned in MCP format.
If format is set to "openai", the tools are converted to OpenAI API compatible tools.
"""
tools = await session.list_tools()
if format == "openai":
return [
transform_mcp_tool_to_openai_tool(mcp_tool=tool)
for tool in tools.tools
]
return tools.tools
########################################################
# Call MCP Tool functions
########################################################
async def call_mcp_tool(
session: ClientSession,
call_tool_request_params: MCPCallToolRequestParams,
) -> MCPCallToolResult:
"""Call an MCP tool."""
tool_result = await session.call_tool(
name=call_tool_request_params.name,
arguments=call_tool_request_params.arguments,
)
return tool_result
def _get_function_arguments(function: FunctionDefinition) -> dict:
"""Helper to safely get and parse function arguments."""
arguments = function.get("arguments", {})
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
return arguments if isinstance(arguments, dict) else {}
def transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool: Union[ChatCompletionMessageToolCall, Dict],
) -> MCPCallToolRequestParams:
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
function = openai_tool["function"]
return MCPCallToolRequestParams(
name=function["name"],
arguments=_get_function_arguments(function),
)
async def call_openai_tool(
session: ClientSession,
openai_tool: dict,
) -> MCPCallToolResult:
"""
Call an OpenAI tool using MCP client.
Args:
session: The MCP session to use
openai_tool: The OpenAI tool to call. You can get this from the `choices[0].message.tool_calls[0]` of the response from the OpenAI API.
Returns:
The result of the MCP tool call.
"""
mcp_tool_call_request_params = (
transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool=openai_tool,
)
)
return await call_mcp_tool(
session=session,
call_tool_request_params=mcp_tool_call_request_params,
)
def retry_with_backoff(retries=3, backoff_in_seconds=1):
"""Decorator for retrying functions with exponential backoff."""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
x = 0
while True:
try:
return await func(*args, **kwargs)
except Exception as e:
if x == retries:
logger.error(
f"Failed after {retries} retries: {str(e)}"
)
raise
sleep_time = (
backoff_in_seconds * 2**x
+ random.uniform(0, 1)
)
logger.warning(
f"Attempt {x + 1} failed, retrying in {sleep_time:.2f}s"
)
await asyncio.sleep(sleep_time)
x += 1
return wrapper
return decorator
@contextlib.contextmanager
def get_or_create_event_loop():
"""Context manager to handle event loop creation and cleanup."""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
yield loop
finally:
# Only close the loop if we created it and it's not the main event loop
if loop != asyncio.get_event_loop() and not loop.is_running():
if not loop.is_closed():
loop.close()
def connect_to_mcp_server(connection: MCPConnection = None):
"""Connect to an MCP server.
Args:
connection (MCPConnection): The connection configuration object
Returns:
tuple: A tuple containing (headers, timeout, transport, url)
Raises:
MCPValidationError: If the connection object is invalid
"""
if not isinstance(connection, MCPConnection):
raise MCPValidationError("Invalid connection type")
# Direct attribute access is faster than property access
headers = dict(connection.headers or {})
if connection.authorization_token:
headers["Authorization"] = (
f"Bearer {connection.authorization_token}"
)
return (
headers,
connection.timeout or 5,
connection.transport or "sse",
connection.url,
)
@retry_with_backoff(retries=3)
async def aget_mcp_tools(
server_path: Optional[str] = None,
format: str = "openai",
connection: Optional[MCPConnection] = None,
*args,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Fetch available MCP tools from the server with retry logic.
Args:
server_path (str): Path to the MCP server script
Returns:
List[Dict[str, Any]]: List of available MCP tools in OpenAI format
Raises:
MCPValidationError: If server_path is invalid
MCPConnectionError: If connection to server fails
"""
if exists(connection):
headers, timeout, transport, url = connect_to_mcp_server(
connection
)
else:
headers, timeout, _transport, _url = (
None,
5,
None,
server_path,
)
logger.info(f"Fetching MCP tools from server: {server_path}")
try:
async with sse_client(
url=server_path,
headers=headers,
timeout=timeout,
*args,
**kwargs,
) as (
read,
write,
):
async with ClientSession(read, write) as session:
await session.initialize()
tools = await load_mcp_tools(
session=session, format=format
)
logger.info(
f"Successfully fetched {len(tools)} tools"
)
return tools
except Exception as e:
logger.error(f"Error fetching MCP tools: {str(e)}")
raise MCPConnectionError(
f"Failed to connect to MCP server: {str(e)}"
)
def get_mcp_tools_sync(
server_path: Optional[str] = None,
format: str = "openai",
connection: Optional[MCPConnection] = None,
*args,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Synchronous version of get_mcp_tools that handles event loop management.
Args:
server_path (str): Path to the MCP server script
Returns:
List[Dict[str, Any]]: List of available MCP tools in OpenAI format
Raises:
MCPValidationError: If server_path is invalid
MCPConnectionError: If connection to server fails
MCPExecutionError: If event loop management fails
"""
with get_or_create_event_loop() as loop:
try:
return loop.run_until_complete(
aget_mcp_tools(
server_path=server_path,
format=format,
connection=connection,
*args,
**kwargs,
)
)
except Exception as e:
logger.error(f"Error in get_mcp_tools_sync: {str(e)}")
raise MCPExecutionError(
f"Failed to execute MCP tools sync: {str(e)}"
)
def _fetch_tools_for_server(
url: str,
connection: Optional[MCPConnection] = None,
format: str = "openai",
) -> List[Dict[str, Any]]:
"""Helper function to fetch tools for a single server."""
return get_mcp_tools_sync(
server_path=url,
connection=connection,
format=format,
)
def get_tools_for_multiple_mcp_servers(
urls: List[str],
connections: List[MCPConnection] = None,
format: str = "openai",
output_type: Literal["json", "dict", "str"] = "str",
max_workers: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""Get tools for multiple MCP servers concurrently using ThreadPoolExecutor.
Args:
urls: List of server URLs to fetch tools from
connections: Optional list of MCPConnection objects corresponding to each URL
format: Format to return tools in (default: "openai")
output_type: Type of output format (default: "str")
max_workers: Maximum number of worker threads (default: None, uses min(32, os.cpu_count() + 4))
Returns:
List[Dict[str, Any]]: Combined list of tools from all servers
"""
tools = []
(
min(32, os.cpu_count() + 4)
if max_workers is None
else max_workers
)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
if exists(connections):
# Create future tasks for each URL-connection pair
future_to_url = {
executor.submit(
_fetch_tools_for_server, url, connection, format
): url
for url, connection in zip(urls, connections)
}
else:
# Create future tasks for each URL without connections
future_to_url = {
executor.submit(
_fetch_tools_for_server, url, None, format
): url
for url in urls
}
# Process completed futures as they come in
for future in as_completed(future_to_url):
url = future_to_url[future]
try:
server_tools = future.result()
tools.extend(server_tools)
except Exception as e:
logger.error(
f"Error fetching tools from {url}: {str(e)}"
)
raise MCPExecutionError(
f"Failed to fetch tools from {url}: {str(e)}"
)
return tools
async def _execute_tool_call_simple(
response: any = None,
server_path: str = None,
connection: Optional[MCPConnection] = None,
output_type: Literal["json", "dict", "str"] = "str",
*args,
**kwargs,
):
"""Execute a tool call using the MCP client."""
if exists(connection):
headers, timeout, transport, url = connect_to_mcp_server(
connection
)
else:
headers, timeout, _transport, url = (
None,
5,
"sse",
server_path,
)
try:
async with sse_client(
url=url, headers=headers, timeout=timeout, *args, **kwargs
) as (
read,
write,
):
async with ClientSession(read, write) as session:
try:
await session.initialize()
call_result = await call_openai_tool(
session=session,
openai_tool=response,
)
if output_type == "json":
out = call_result.model_dump_json(indent=4)
elif output_type == "dict":
out = call_result.model_dump()
elif output_type == "str":
data = call_result.model_dump()
formatted_lines = []
for key, value in data.items():
if isinstance(value, list):
for item in value:
if isinstance(item, dict):
for k, v in item.items():
formatted_lines.append(
f"{k}: {v}"
)
else:
formatted_lines.append(
f"{key}: {value}"
)
out = "\n".join(formatted_lines)
return out
except Exception as e:
logger.error(f"Error in tool execution: {str(e)}")
raise MCPExecutionError(
f"Tool execution failed: {str(e)}"
)
except Exception as e:
logger.error(f"Error in SSE client connection: {str(e)}")
raise MCPConnectionError(
f"Failed to connect to MCP server: {str(e)}"
)
async def execute_tool_call_simple(
response: any = None,
server_path: str = None,
connection: Optional[MCPConnection] = None,
output_type: Literal["json", "dict", "str", "formatted"] = "str",
*args,
**kwargs,
) -> List[Dict[str, Any]]:
if isinstance(response, str):
response = json.loads(response)
return await _execute_tool_call_simple(
response=response,
server_path=server_path,
connection=connection,
output_type=output_type,
*args,
**kwargs,
)
def _create_server_tool_mapping(
urls: List[str],
connections: List[MCPConnection] = None,
format: str = "openai",
) -> Dict[str, Dict[str, Any]]:
"""
Create a mapping of function names to server information for all MCP servers.
Args:
urls: List of server URLs
connections: Optional list of MCPConnection objects
format: Format to fetch tools in
Returns:
Dict mapping function names to server info (url, connection, tool)
"""
server_tool_mapping = {}
for i, url in enumerate(urls):
connection = (
connections[i]
if connections and i < len(connections)
else None
)
try:
# Get tools for this server
tools = get_mcp_tools_sync(
server_path=url,
connection=connection,
format=format,
)
# Create mapping for each tool
for tool in tools:
if isinstance(tool, dict) and "function" in tool:
function_name = tool["function"]["name"]
server_tool_mapping[function_name] = {
"url": url,
"connection": connection,
"tool": tool,
"server_index": i,
}
elif hasattr(tool, "name"):
# Handle MCPTool objects
server_tool_mapping[tool.name] = {
"url": url,
"connection": connection,
"tool": tool,
"server_index": i,
}
except Exception as e:
logger.warning(
f"Failed to fetch tools from server {url}: {str(e)}"
)
continue
return server_tool_mapping
async def _create_server_tool_mapping_async(
urls: List[str],
connections: List[MCPConnection] = None,
format: str = "openai",
) -> Dict[str, Dict[str, Any]]:
"""
Async version: Create a mapping of function names to server information for all MCP servers.
Args:
urls: List of server URLs
connections: Optional list of MCPConnection objects
format: Format to fetch tools in
Returns:
Dict mapping function names to server info (url, connection, tool)
"""
server_tool_mapping = {}
for i, url in enumerate(urls):
connection = (
connections[i]
if connections and i < len(connections)
else None
)
try:
# Get tools for this server using async function
tools = await aget_mcp_tools(
server_path=url,
connection=connection,
format=format,
)
# Create mapping for each tool
for tool in tools:
if isinstance(tool, dict) and "function" in tool:
function_name = tool["function"]["name"]
server_tool_mapping[function_name] = {
"url": url,
"connection": connection,
"tool": tool,
"server_index": i,
}
elif hasattr(tool, "name"):
# Handle MCPTool objects
server_tool_mapping[tool.name] = {
"url": url,
"connection": connection,
"tool": tool,
"server_index": i,
}
except Exception as e:
logger.warning(
f"Failed to fetch tools from server {url}: {str(e)}"
)
continue
return server_tool_mapping
async def _execute_tool_on_server(
tool_call: Dict[str, Any],
server_info: Dict[str, Any],
output_type: Literal["json", "dict", "str", "formatted"] = "str",
) -> Dict[str, Any]:
"""
Execute a single tool call on a specific server.
Args:
tool_call: The tool call to execute
server_info: Server information from the mapping
output_type: Output format type
Returns:
Execution result with server metadata
"""
try:
result = await _execute_tool_call_simple(
response=tool_call,
server_path=server_info["url"],
connection=server_info["connection"],
output_type=output_type,
)
return {
"server_url": server_info["url"],
"server_index": server_info["server_index"],
"function_name": tool_call.get("function", {}).get(
"name", "unknown"
),
"result": result,
"status": "success",
}
except Exception as e:
logger.error(
f"Failed to execute tool on server {server_info['url']}: {str(e)}"
)
return {
"server_url": server_info["url"],
"server_index": server_info["server_index"],
"function_name": tool_call.get("function", {}).get(
"name", "unknown"
),
"result": None,
"error": str(e),
"status": "error",
}
async def execute_multiple_tools_on_multiple_mcp_servers(
responses: List[Dict[str, Any]],
urls: List[str],
connections: List[MCPConnection] = None,
output_type: Literal["json", "dict", "str", "formatted"] = "str",
max_concurrent: Optional[int] = None,
*args,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Execute multiple tool calls across multiple MCP servers.
This function creates a mapping of function names to servers, then for each response
that contains tool calls, it finds the appropriate server for each function and
executes the calls concurrently.
Args:
responses: List of responses containing tool calls (OpenAI format)
urls: List of MCP server URLs
connections: Optional list of MCPConnection objects corresponding to each URL
output_type: Output format type for results
max_concurrent: Maximum number of concurrent executions (default: len(responses))
Returns:
List of execution results with server metadata
Example:
# Example responses format:
responses = [
{
"function": {
"name": "search_web",
"arguments": {"query": "python programming"}
}
},
{
"function": {
"name": "search_database",
"arguments": {"table": "users", "id": 123}
}
}
]
urls = ["http://server1:8000", "http://server2:8000"]
results = await execute_multiple_tools_on_multiple_mcp_servers(
responses=responses,
urls=urls
)
"""
if not responses:
logger.warning("No responses provided for execution")
return []
if not urls:
raise MCPValidationError("No server URLs provided")
# Create mapping of function names to servers using async version
logger.info(f"Creating tool mapping for {len(urls)} servers")
server_tool_mapping = await _create_server_tool_mapping_async(
urls=urls, connections=connections, format="openai"
)
if not server_tool_mapping:
raise MCPExecutionError(
"No tools found on any of the provided servers"
)
logger.info(
f"Found {len(server_tool_mapping)} unique functions across all servers"
)
# Extract all tool calls from responses
all_tool_calls = []
logger.info(
f"Processing {len(responses)} responses for tool call extraction"
)
# Check if responses are individual characters that need to be reconstructed
if len(responses) > 10 and all(
isinstance(r, str) and len(r) == 1 for r in responses
):
logger.info(
"Detected character-by-character response, reconstructing JSON string"
)
try:
reconstructed_response = "".join(responses)
logger.info(
f"Reconstructed response length: {len(reconstructed_response)}"
)
logger.debug(
f"Reconstructed response: {reconstructed_response}"
)
# Try to parse the reconstructed response to validate it
try:
json.loads(reconstructed_response)
logger.info(
"Successfully validated reconstructed JSON response"
)
except json.JSONDecodeError as e:
logger.warning(
f"Reconstructed response is not valid JSON: {str(e)}"
)
logger.debug(
f"First 100 chars: {reconstructed_response[:100]}"
)
logger.debug(
f"Last 100 chars: {reconstructed_response[-100:]}"
)
responses = [reconstructed_response]
except Exception as e:
logger.warning(
f"Failed to reconstruct response from characters: {str(e)}"
)
for i, response in enumerate(responses):
logger.debug(
f"Processing response {i}: {type(response)} - {response}"
)
# Handle JSON string responses
if isinstance(response, str):
try:
response = json.loads(response)
logger.debug(
f"Parsed JSON string response {i}: {response}"
)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse JSON response at index {i}: {response}"
)
continue
if isinstance(response, dict):
# Single tool call
if "function" in response:
logger.debug(
f"Found single tool call in response {i}: {response['function']}"
)
# Parse arguments if they're a JSON string
if isinstance(
response["function"].get("arguments"), str
):
try:
response["function"]["arguments"] = (
json.loads(
response["function"]["arguments"]
)
)
logger.debug(
f"Parsed function arguments: {response['function']['arguments']}"
)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse function arguments: {response['function']['arguments']}"
)
all_tool_calls.append((i, response))
# Multiple tool calls
elif "tool_calls" in response:
logger.debug(
f"Found multiple tool calls in response {i}: {len(response['tool_calls'])} calls"
)
for tool_call in response["tool_calls"]:
# Parse arguments if they're a JSON string
if isinstance(
tool_call.get("function", {}).get(
"arguments"
),
str,
):
try:
tool_call["function"]["arguments"] = (
json.loads(
tool_call["function"]["arguments"]
)
)
logger.debug(
f"Parsed tool call arguments: {tool_call['function']['arguments']}"
)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse tool call arguments: {tool_call['function']['arguments']}"
)
all_tool_calls.append((i, tool_call))
# Direct tool call
elif "name" in response and "arguments" in response:
logger.debug(
f"Found direct tool call in response {i}: {response}"
)
# Parse arguments if they're a JSON string
if isinstance(response.get("arguments"), str):
try:
response["arguments"] = json.loads(
response["arguments"]
)
logger.debug(
f"Parsed direct tool call arguments: {response['arguments']}"
)
except json.JSONDecodeError:
logger.warning(
f"Failed to parse direct tool call arguments: {response['arguments']}"
)
all_tool_calls.append((i, {"function": response}))
else:
logger.debug(
f"Response {i} is a dict but doesn't match expected tool call formats: {list(response.keys())}"
)
else:
logger.warning(
f"Unsupported response type at index {i}: {type(response)}"
)
continue
if not all_tool_calls:
logger.warning("No tool calls found in responses")
return []
logger.info(f"Found {len(all_tool_calls)} tool calls to execute")
# Execute tool calls concurrently
max_concurrent = max_concurrent or len(all_tool_calls)
semaphore = asyncio.Semaphore(max_concurrent)
async def execute_with_semaphore(tool_call_info):
async with semaphore:
response_index, tool_call = tool_call_info
function_name = tool_call.get("function", {}).get(
"name", "unknown"
)
if function_name not in server_tool_mapping:
logger.warning(
f"Function '{function_name}' not found on any server"
)
return {
"response_index": response_index,
"function_name": function_name,
"result": None,
"error": f"Function '{function_name}' not available on any server",
"status": "not_found",
}
server_info = server_tool_mapping[function_name]
result = await _execute_tool_on_server(
tool_call=tool_call,
server_info=server_info,
output_type=output_type,
)
result["response_index"] = response_index
return result
# Execute all tool calls concurrently
tasks = [
execute_with_semaphore(tool_call_info)
for tool_call_info in all_tool_calls
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Process results and handle exceptions
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(
f"Task {i} failed with exception: {str(result)}"
)
processed_results.append(
{
"response_index": (
all_tool_calls[i][0]
if i < len(all_tool_calls)
else -1
),
"function_name": "unknown",
"result": None,
"error": str(result),
"status": "exception",
}
)
else:
processed_results.append(result)
logger.info(
f"Completed execution of {len(processed_results)} tool calls"
)
return processed_results
def execute_multiple_tools_on_multiple_mcp_servers_sync(
responses: List[Dict[str, Any]],
urls: List[str],
connections: List[MCPConnection] = None,
output_type: Literal["json", "dict", "str", "formatted"] = "str",
max_concurrent: Optional[int] = None,
*args,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Synchronous version of execute_multiple_tools_on_multiple_mcp_servers.
Args:
responses: List of responses containing tool calls (OpenAI format)
urls: List of MCP server URLs
connections: Optional list of MCPConnection objects corresponding to each URL
output_type: Output format type for results
max_concurrent: Maximum number of concurrent executions
Returns:
List of execution results with server metadata
"""
with get_or_create_event_loop() as loop:
try:
return loop.run_until_complete(
execute_multiple_tools_on_multiple_mcp_servers(
responses=responses,
urls=urls,
connections=connections,
output_type=output_type,
max_concurrent=max_concurrent,
*args,
**kwargs,
)
)
except Exception as e:
logger.error(
f"Error in execute_multiple_tools_on_multiple_mcp_servers_sync: {str(e)}"
)
raise MCPExecutionError(
f"Failed to execute multiple tools sync: {str(e)}"
)