From 41ccffcbc2be882b305dd8ed58dbd5e568d9dbba Mon Sep 17 00:00:00 2001 From: Pavan Kumar <66913595+ascender1729@users.noreply.github.com> Date: Sun, 20 Apr 2025 08:32:47 +0000 Subject: [PATCH] feat(mcp): update integration to use FastMCP and support FastMCP in Agent --- swarms/structs/agent.py | 2 +- swarms/tools/mcp_integration.py | 181 +++++++++++++------------------- 2 files changed, 76 insertions(+), 107 deletions(-) diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index 582d298b..334cc33a 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -2780,4 +2780,4 @@ class Agent: self.short_memory.add( role="Output Cleaner", content=response, - ) + ) \ No newline at end of file diff --git a/swarms/tools/mcp_integration.py b/swarms/tools/mcp_integration.py index 6880527d..8a1fca1f 100644 --- a/swarms/tools/mcp_integration.py +++ b/swarms/tools/mcp_integration.py @@ -1,13 +1,14 @@ from __future__ import annotations -from typing import Any, List - +from typing import Any, Dict, List, Optional, TypedDict, NotRequired +from typing_extensions import TypedDict +from contextlib import AbstractAsyncContextManager +from fastmcp import FastClientSession as ClientSession +from fastmcp.servers import fast_sse_client as sse_client from loguru import logger - import abc import asyncio -from contextlib import AbstractAsyncContextManager, AsyncExitStack from pathlib import Path from typing import Literal @@ -15,15 +16,7 @@ from anyio.streams.memory import ( MemoryObjectReceiveStream, MemoryObjectSendStream, ) -from mcp import ( - ClientSession, - StdioServerParameters, - Tool as MCPTool, - stdio_client, -) -from mcp.client.sse import sse_client -from mcp.types import CallToolResult, JSONRPCMessage -from typing_extensions import NotRequired, TypedDict +from mcp.types import CallToolResult, JSONRPCMessage # Kept for backward compatibility, might be removed later from swarms.utils.any_to_str import any_to_str @@ -53,18 +46,19 @@ class MCPServer(abc.ABC): pass @abc.abstractmethod - async def list_tools(self) -> list[MCPTool]: + async def list_tools(self) -> list[Any]: # Changed to Any for flexibility """List the tools available on the server.""" pass @abc.abstractmethod async def call_tool( self, tool_name: str, arguments: dict[str, Any] | None - ) -> CallToolResult: + ) -> CallToolResult: # Kept for backward compatibility, might be removed later """Invoke a tool on the server.""" pass + class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" @@ -85,7 +79,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC): # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True - self._tools_list: list[MCPTool] | None = None + self._tools_list: list[Any] | None = None # Changed to Any for flexibility @abc.abstractmethod def create_streams( @@ -127,7 +121,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC): await self.cleanup() raise - async def list_tools(self) -> list[MCPTool]: + async def list_tools(self) -> list[Any]: # Changed to Any for flexibility """List the tools available on the server.""" if not self.session: raise Exception( @@ -151,7 +145,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC): async def call_tool( self, arguments: dict[str, Any] | None - ) -> CallToolResult: + ) -> CallToolResult: # Kept for backward compatibility, might be removed later """Invoke a tool on the server.""" tool_name = arguments.get("tool_name") or arguments.get( "name" @@ -268,6 +262,7 @@ class MCPServerStdio(_MCPServerWithClientSession): return self._name + class MCPServerSseParams(TypedDict): """Mirrors the params in`mcp.client.sse.sse_client`.""" @@ -284,119 +279,93 @@ class MCPServerSseParams(TypedDict): """The timeout for the SSE connection, in seconds. Defaults to 5 minutes.""" -class MCPServerSse(_MCPServerWithClientSession): - """MCP server implementation that uses the HTTP with SSE transport. See the [spec] - (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) - for details. - """ - - def __init__( - self, - params: MCPServerSseParams, - cache_tools_list: bool = False, - name: str | None = None, - ): - """Create a new MCP server based on the HTTP with SSE transport. - - Args: - params: The params that configure the server. This includes the URL of the server, - the headers to send to the server, the timeout for the HTTP request, and the - timeout for the SSE connection. - - cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be - cached and only fetched from the server once. If `False`, the tools list will be - fetched from the server on each call to `list_tools()`. The cache can be - invalidated by calling `invalidate_tools_cache()`. You should set this to `True` - if you know the server will not change its tools list, because it can drastically - improve latency (by avoiding a round-trip to the server every time). - - name: A readable name for the server. If not provided, we'll create one from the - URL. - """ - super().__init__(cache_tools_list) - +class MCPServerSse: + def __init__(self, params: MCPServerSseParams): self.params = params - self._name = name or f"sse: {self.params['url']}" + self.client: Optional[ClientSession] = None - def create_streams( - self, - ) -> AbstractAsyncContextManager[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] - ]: - """Create the streams for the server.""" + async def connect(self): + """Connect to the MCP server.""" + if not self.client: + self.client = ClientSession() + await self.client.connect(self.create_streams()) + + def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]: return sse_client( url=self.params["url"], headers=self.params.get("headers", None), timeout=self.params.get("timeout", 5), - sse_read_timeout=self.params.get( - "sse_read_timeout", 60 * 5 - ), + sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5), ) - @property - def name(self) -> str: - """A readable name for the server.""" - return self._name + async def call_tool(self, payload: dict[str, Any]): + """Call a tool on the MCP server.""" + if not self.client: + raise RuntimeError("Not connected to MCP server") + return await self.client.call_tool(payload) + async def cleanup(self): + """Clean up the connection.""" + if self.client: + await self.client.close() + self.client = None -def mcp_flow_get_tool_schema( - params: MCPServerSseParams, -) -> MCPServer: - server = MCPServerSse(params, cache_tools_list=True) + async def list_tools(self) -> list[Any]: # Added for compatibility + if not self.client: + raise RuntimeError("Not connected to MCP server") + return await self.client.list_tools() - # Connect the server - asyncio.run(server.connect()) - # Return the server - output = asyncio.run(server.list_tools()) +async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any]): + """ + Convenience wrapper that opens → calls → closes in one shot. + """ + await server.connect() + result = await server.call_tool(payload) + await server.cleanup() + return result.model_dump() if hasattr(result, "model_dump") else result - # Cleanup the server - asyncio.run(server.cleanup()) - return output.model_dump() +async def mcp_flow_get_tool_schema( + params: MCPServerSseParams, +) -> Any: # Updated return type to Any + async with MCPServerSse(params) as server: + return (await server.list_tools()).model_dump() -def mcp_flow( +async def mcp_flow( params: MCPServerSseParams, function_call: dict[str, Any], -) -> MCPServer: +) -> Any: # Updated return type to Any try: - server = MCPServerSse(params, cache_tools_list=True) - - # Connect the server - asyncio.run(server.connect()) - - # Extract tool name and args from function call - tool_name = function_call.get("tool_name") or function_call.get("name") - if not tool_name: - raise ValueError("No tool name provided in function call") - - # Call the tool - output = asyncio.run(server.call_tool(function_call)) - - # Convert to serializable format - output = output.model_dump() if hasattr(output, "model_dump") else output - - # Cleanup the server - asyncio.run(server.cleanup()) - - return any_to_str(output) + async with MCPServerSse(params) as server: + return await call_tool_fast(server, function_call) except Exception as e: logger.error(f"Error in MCP flow: {e}") raise -def batch_mcp_flow( +async def batch_mcp_flow( params: List[MCPServerSseParams], function_call: List[dict[str, Any]] = [], -) -> MCPServer: - output_list = [] +) -> List[Any]: # Updated return type to List[Any] + async def process_param(param): + try: + async with MCPServerSse(param) as server: + return await call_tool_fast(server, function_call[0]) + except IndexError: + return None # Handle case where function_call is empty + except Exception as e: + logger.error(f"Error processing parameter: {param}, error: {e}") + return None - for param in params: - output = mcp_flow(param, function_call) - output_list.append(output) + results = await asyncio.gather(*(process_param(param) for param in params)) + return [any_to_str(r) for r in results if r is not None] - return output_list + +from mcp import ( + ClientSession as OldClientSession, # Kept for backward compatibility with stdio + StdioServerParameters, + Tool as MCPTool, + stdio_client, +) \ No newline at end of file