|
|
|
@ -1,13 +1,14 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from typing import Any, List
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Optional, TypedDict, NotRequired
|
|
|
|
|
from typing_extensions import TypedDict
|
|
|
|
|
from contextlib import AbstractAsyncContextManager
|
|
|
|
|
from fastmcp import FastClientSession as ClientSession
|
|
|
|
|
from fastmcp.servers import fast_sse_client as sse_client
|
|
|
|
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
import abc
|
|
|
|
|
import asyncio
|
|
|
|
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Literal
|
|
|
|
|
|
|
|
|
@ -15,15 +16,7 @@ from anyio.streams.memory import (
|
|
|
|
|
MemoryObjectReceiveStream,
|
|
|
|
|
MemoryObjectSendStream,
|
|
|
|
|
)
|
|
|
|
|
from mcp import (
|
|
|
|
|
ClientSession,
|
|
|
|
|
StdioServerParameters,
|
|
|
|
|
Tool as MCPTool,
|
|
|
|
|
stdio_client,
|
|
|
|
|
)
|
|
|
|
|
from mcp.client.sse import sse_client
|
|
|
|
|
from mcp.types import CallToolResult, JSONRPCMessage
|
|
|
|
|
from typing_extensions import NotRequired, TypedDict
|
|
|
|
|
from mcp.types import CallToolResult, JSONRPCMessage # Kept for backward compatibility, might be removed later
|
|
|
|
|
|
|
|
|
|
from swarms.utils.any_to_str import any_to_str
|
|
|
|
|
|
|
|
|
@ -53,18 +46,19 @@ class MCPServer(abc.ABC):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@abc.abstractmethod
|
|
|
|
|
async def list_tools(self) -> list[MCPTool]:
|
|
|
|
|
async def list_tools(self) -> list[Any]: # Changed to Any for flexibility
|
|
|
|
|
"""List the tools available on the server."""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@abc.abstractmethod
|
|
|
|
|
async def call_tool(
|
|
|
|
|
self, tool_name: str, arguments: dict[str, Any] | None
|
|
|
|
|
) -> CallToolResult:
|
|
|
|
|
) -> CallToolResult: # Kept for backward compatibility, might be removed later
|
|
|
|
|
"""Invoke a tool on the server."""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
|
|
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
|
|
|
|
|
|
|
|
@ -85,7 +79,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
|
|
|
|
|
|
|
# The cache is always dirty at startup, so that we fetch tools at least once
|
|
|
|
|
self._cache_dirty = True
|
|
|
|
|
self._tools_list: list[MCPTool] | None = None
|
|
|
|
|
self._tools_list: list[Any] | None = None # Changed to Any for flexibility
|
|
|
|
|
|
|
|
|
|
@abc.abstractmethod
|
|
|
|
|
def create_streams(
|
|
|
|
@ -127,7 +121,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
|
|
await self.cleanup()
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
async def list_tools(self) -> list[MCPTool]:
|
|
|
|
|
async def list_tools(self) -> list[Any]: # Changed to Any for flexibility
|
|
|
|
|
"""List the tools available on the server."""
|
|
|
|
|
if not self.session:
|
|
|
|
|
raise Exception(
|
|
|
|
@ -151,7 +145,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
|
|
|
|
|
|
|
|
|
async def call_tool(
|
|
|
|
|
self, arguments: dict[str, Any] | None
|
|
|
|
|
) -> CallToolResult:
|
|
|
|
|
) -> CallToolResult: # Kept for backward compatibility, might be removed later
|
|
|
|
|
"""Invoke a tool on the server."""
|
|
|
|
|
tool_name = arguments.get("tool_name") or arguments.get(
|
|
|
|
|
"name"
|
|
|
|
@ -268,6 +262,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
|
|
|
|
|
return self._name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MCPServerSseParams(TypedDict):
|
|
|
|
|
"""Mirrors the params in`mcp.client.sse.sse_client`."""
|
|
|
|
|
|
|
|
|
@ -284,119 +279,93 @@ class MCPServerSseParams(TypedDict):
|
|
|
|
|
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MCPServerSse(_MCPServerWithClientSession):
|
|
|
|
|
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
|
|
|
|
|
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
|
|
|
|
|
for details.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
params: MCPServerSseParams,
|
|
|
|
|
cache_tools_list: bool = False,
|
|
|
|
|
name: str | None = None,
|
|
|
|
|
):
|
|
|
|
|
"""Create a new MCP server based on the HTTP with SSE transport.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
params: The params that configure the server. This includes the URL of the server,
|
|
|
|
|
the headers to send to the server, the timeout for the HTTP request, and the
|
|
|
|
|
timeout for the SSE connection.
|
|
|
|
|
|
|
|
|
|
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
|
|
|
|
cached and only fetched from the server once. If `False`, the tools list will be
|
|
|
|
|
fetched from the server on each call to `list_tools()`. The cache can be
|
|
|
|
|
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
|
|
|
|
|
if you know the server will not change its tools list, because it can drastically
|
|
|
|
|
improve latency (by avoiding a round-trip to the server every time).
|
|
|
|
|
|
|
|
|
|
name: A readable name for the server. If not provided, we'll create one from the
|
|
|
|
|
URL.
|
|
|
|
|
"""
|
|
|
|
|
super().__init__(cache_tools_list)
|
|
|
|
|
|
|
|
|
|
class MCPServerSse:
|
|
|
|
|
def __init__(self, params: MCPServerSseParams):
|
|
|
|
|
self.params = params
|
|
|
|
|
self._name = name or f"sse: {self.params['url']}"
|
|
|
|
|
self.client: Optional[ClientSession] = None
|
|
|
|
|
|
|
|
|
|
def create_streams(
|
|
|
|
|
self,
|
|
|
|
|
) -> AbstractAsyncContextManager[
|
|
|
|
|
tuple[
|
|
|
|
|
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
|
|
|
|
MemoryObjectSendStream[JSONRPCMessage],
|
|
|
|
|
]
|
|
|
|
|
]:
|
|
|
|
|
"""Create the streams for the server."""
|
|
|
|
|
async def connect(self):
|
|
|
|
|
"""Connect to the MCP server."""
|
|
|
|
|
if not self.client:
|
|
|
|
|
self.client = ClientSession()
|
|
|
|
|
await self.client.connect(self.create_streams())
|
|
|
|
|
|
|
|
|
|
def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]:
|
|
|
|
|
return sse_client(
|
|
|
|
|
url=self.params["url"],
|
|
|
|
|
headers=self.params.get("headers", None),
|
|
|
|
|
timeout=self.params.get("timeout", 5),
|
|
|
|
|
sse_read_timeout=self.params.get(
|
|
|
|
|
"sse_read_timeout", 60 * 5
|
|
|
|
|
),
|
|
|
|
|
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def name(self) -> str:
|
|
|
|
|
"""A readable name for the server."""
|
|
|
|
|
return self._name
|
|
|
|
|
async def call_tool(self, payload: dict[str, Any]):
|
|
|
|
|
"""Call a tool on the MCP server."""
|
|
|
|
|
if not self.client:
|
|
|
|
|
raise RuntimeError("Not connected to MCP server")
|
|
|
|
|
return await self.client.call_tool(payload)
|
|
|
|
|
|
|
|
|
|
async def cleanup(self):
|
|
|
|
|
"""Clean up the connection."""
|
|
|
|
|
if self.client:
|
|
|
|
|
await self.client.close()
|
|
|
|
|
self.client = None
|
|
|
|
|
|
|
|
|
|
def mcp_flow_get_tool_schema(
|
|
|
|
|
params: MCPServerSseParams,
|
|
|
|
|
) -> MCPServer:
|
|
|
|
|
server = MCPServerSse(params, cache_tools_list=True)
|
|
|
|
|
async def list_tools(self) -> list[Any]: # Added for compatibility
|
|
|
|
|
if not self.client:
|
|
|
|
|
raise RuntimeError("Not connected to MCP server")
|
|
|
|
|
return await self.client.list_tools()
|
|
|
|
|
|
|
|
|
|
# Connect the server
|
|
|
|
|
asyncio.run(server.connect())
|
|
|
|
|
|
|
|
|
|
# Return the server
|
|
|
|
|
output = asyncio.run(server.list_tools())
|
|
|
|
|
async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any]):
|
|
|
|
|
"""
|
|
|
|
|
Convenience wrapper that opens → calls → closes in one shot.
|
|
|
|
|
"""
|
|
|
|
|
await server.connect()
|
|
|
|
|
result = await server.call_tool(payload)
|
|
|
|
|
await server.cleanup()
|
|
|
|
|
return result.model_dump() if hasattr(result, "model_dump") else result
|
|
|
|
|
|
|
|
|
|
# Cleanup the server
|
|
|
|
|
asyncio.run(server.cleanup())
|
|
|
|
|
|
|
|
|
|
return output.model_dump()
|
|
|
|
|
async def mcp_flow_get_tool_schema(
|
|
|
|
|
params: MCPServerSseParams,
|
|
|
|
|
) -> Any: # Updated return type to Any
|
|
|
|
|
async with MCPServerSse(params) as server:
|
|
|
|
|
return (await server.list_tools()).model_dump()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mcp_flow(
|
|
|
|
|
async def mcp_flow(
|
|
|
|
|
params: MCPServerSseParams,
|
|
|
|
|
function_call: dict[str, Any],
|
|
|
|
|
) -> MCPServer:
|
|
|
|
|
) -> Any: # Updated return type to Any
|
|
|
|
|
try:
|
|
|
|
|
server = MCPServerSse(params, cache_tools_list=True)
|
|
|
|
|
|
|
|
|
|
# Connect the server
|
|
|
|
|
asyncio.run(server.connect())
|
|
|
|
|
|
|
|
|
|
# Extract tool name and args from function call
|
|
|
|
|
tool_name = function_call.get("tool_name") or function_call.get("name")
|
|
|
|
|
if not tool_name:
|
|
|
|
|
raise ValueError("No tool name provided in function call")
|
|
|
|
|
|
|
|
|
|
# Call the tool
|
|
|
|
|
output = asyncio.run(server.call_tool(function_call))
|
|
|
|
|
|
|
|
|
|
# Convert to serializable format
|
|
|
|
|
output = output.model_dump() if hasattr(output, "model_dump") else output
|
|
|
|
|
|
|
|
|
|
# Cleanup the server
|
|
|
|
|
asyncio.run(server.cleanup())
|
|
|
|
|
|
|
|
|
|
return any_to_str(output)
|
|
|
|
|
async with MCPServerSse(params) as server:
|
|
|
|
|
return await call_tool_fast(server, function_call)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error in MCP flow: {e}")
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def batch_mcp_flow(
|
|
|
|
|
async def batch_mcp_flow(
|
|
|
|
|
params: List[MCPServerSseParams],
|
|
|
|
|
function_call: List[dict[str, Any]] = [],
|
|
|
|
|
) -> MCPServer:
|
|
|
|
|
output_list = []
|
|
|
|
|
) -> List[Any]: # Updated return type to List[Any]
|
|
|
|
|
async def process_param(param):
|
|
|
|
|
try:
|
|
|
|
|
async with MCPServerSse(param) as server:
|
|
|
|
|
return await call_tool_fast(server, function_call[0])
|
|
|
|
|
except IndexError:
|
|
|
|
|
return None # Handle case where function_call is empty
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error processing parameter: {param}, error: {e}")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
for param in params:
|
|
|
|
|
output = mcp_flow(param, function_call)
|
|
|
|
|
output_list.append(output)
|
|
|
|
|
results = await asyncio.gather(*(process_param(param) for param in params))
|
|
|
|
|
return [any_to_str(r) for r in results if r is not None]
|
|
|
|
|
|
|
|
|
|
return output_list
|
|
|
|
|
|
|
|
|
|
from mcp import (
|
|
|
|
|
ClientSession as OldClientSession, # Kept for backward compatibility with stdio
|
|
|
|
|
StdioServerParameters,
|
|
|
|
|
Tool as MCPTool,
|
|
|
|
|
stdio_client,
|
|
|
|
|
)
|