feat(mcp): update integration to use FastMCP and support FastMCP in Agent

pull/819/head
Pavan Kumar 3 months ago committed by ascender1729
parent 824dea060e
commit 41ccffcbc2

@ -2780,4 +2780,4 @@ class Agent:
self.short_memory.add( self.short_memory.add(
role="Output Cleaner", role="Output Cleaner",
content=response, content=response,
) )

@ -1,13 +1,14 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List from typing import Any, Dict, List, Optional, TypedDict, NotRequired
from typing_extensions import TypedDict
from contextlib import AbstractAsyncContextManager
from fastmcp import FastClientSession as ClientSession
from fastmcp.servers import fast_sse_client as sse_client
from loguru import logger from loguru import logger
import abc import abc
import asyncio import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
@ -15,15 +16,7 @@ from anyio.streams.memory import (
MemoryObjectReceiveStream, MemoryObjectReceiveStream,
MemoryObjectSendStream, MemoryObjectSendStream,
) )
from mcp import ( from mcp.types import CallToolResult, JSONRPCMessage # Kept for backward compatibility, might be removed later
ClientSession,
StdioServerParameters,
Tool as MCPTool,
stdio_client,
)
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from typing_extensions import NotRequired, TypedDict
from swarms.utils.any_to_str import any_to_str from swarms.utils.any_to_str import any_to_str
@ -53,18 +46,19 @@ class MCPServer(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def list_tools(self) -> list[MCPTool]: async def list_tools(self) -> list[Any]: # Changed to Any for flexibility
"""List the tools available on the server.""" """List the tools available on the server."""
pass pass
@abc.abstractmethod @abc.abstractmethod
async def call_tool( async def call_tool(
self, tool_name: str, arguments: dict[str, Any] | None self, tool_name: str, arguments: dict[str, Any] | None
) -> CallToolResult: ) -> CallToolResult: # Kept for backward compatibility, might be removed later
"""Invoke a tool on the server.""" """Invoke a tool on the server."""
pass pass
class _MCPServerWithClientSession(MCPServer, abc.ABC): class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server.""" """Base class for MCP servers that use a `ClientSession` to communicate with the server."""
@ -85,7 +79,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
# The cache is always dirty at startup, so that we fetch tools at least once # The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True self._cache_dirty = True
self._tools_list: list[MCPTool] | None = None self._tools_list: list[Any] | None = None # Changed to Any for flexibility
@abc.abstractmethod @abc.abstractmethod
def create_streams( def create_streams(
@ -127,7 +121,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
await self.cleanup() await self.cleanup()
raise raise
async def list_tools(self) -> list[MCPTool]: async def list_tools(self) -> list[Any]: # Changed to Any for flexibility
"""List the tools available on the server.""" """List the tools available on the server."""
if not self.session: if not self.session:
raise Exception( raise Exception(
@ -151,7 +145,7 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
async def call_tool( async def call_tool(
self, arguments: dict[str, Any] | None self, arguments: dict[str, Any] | None
) -> CallToolResult: ) -> CallToolResult: # Kept for backward compatibility, might be removed later
"""Invoke a tool on the server.""" """Invoke a tool on the server."""
tool_name = arguments.get("tool_name") or arguments.get( tool_name = arguments.get("tool_name") or arguments.get(
"name" "name"
@ -268,6 +262,7 @@ class MCPServerStdio(_MCPServerWithClientSession):
return self._name return self._name
class MCPServerSseParams(TypedDict): class MCPServerSseParams(TypedDict):
"""Mirrors the params in`mcp.client.sse.sse_client`.""" """Mirrors the params in`mcp.client.sse.sse_client`."""
@ -284,119 +279,93 @@ class MCPServerSseParams(TypedDict):
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes.""" """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
class MCPServerSse(_MCPServerWithClientSession): class MCPServerSse:
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec] def __init__(self, params: MCPServerSseParams):
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
for details.
"""
def __init__(
self,
params: MCPServerSseParams,
cache_tools_list: bool = False,
name: str | None = None,
):
"""Create a new MCP server based on the HTTP with SSE transport.
Args:
params: The params that configure the server. This includes the URL of the server,
the headers to send to the server, the timeout for the HTTP request, and the
timeout for the SSE connection.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
URL.
"""
super().__init__(cache_tools_list)
self.params = params self.params = params
self._name = name or f"sse: {self.params['url']}" self.client: Optional[ClientSession] = None
def create_streams( async def connect(self):
self, """Connect to the MCP server."""
) -> AbstractAsyncContextManager[ if not self.client:
tuple[ self.client = ClientSession()
MemoryObjectReceiveStream[JSONRPCMessage | Exception], await self.client.connect(self.create_streams())
MemoryObjectSendStream[JSONRPCMessage],
] def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]:
]:
"""Create the streams for the server."""
return sse_client( return sse_client(
url=self.params["url"], url=self.params["url"],
headers=self.params.get("headers", None), headers=self.params.get("headers", None),
timeout=self.params.get("timeout", 5), timeout=self.params.get("timeout", 5),
sse_read_timeout=self.params.get( sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
"sse_read_timeout", 60 * 5
),
) )
@property async def call_tool(self, payload: dict[str, Any]):
def name(self) -> str: """Call a tool on the MCP server."""
"""A readable name for the server.""" if not self.client:
return self._name raise RuntimeError("Not connected to MCP server")
return await self.client.call_tool(payload)
async def cleanup(self):
"""Clean up the connection."""
if self.client:
await self.client.close()
self.client = None
def mcp_flow_get_tool_schema( async def list_tools(self) -> list[Any]: # Added for compatibility
params: MCPServerSseParams, if not self.client:
) -> MCPServer: raise RuntimeError("Not connected to MCP server")
server = MCPServerSse(params, cache_tools_list=True) return await self.client.list_tools()
# Connect the server
asyncio.run(server.connect())
# Return the server async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any]):
output = asyncio.run(server.list_tools()) """
Convenience wrapper that opens calls closes in one shot.
"""
await server.connect()
result = await server.call_tool(payload)
await server.cleanup()
return result.model_dump() if hasattr(result, "model_dump") else result
# Cleanup the server
asyncio.run(server.cleanup())
return output.model_dump() async def mcp_flow_get_tool_schema(
params: MCPServerSseParams,
) -> Any: # Updated return type to Any
async with MCPServerSse(params) as server:
return (await server.list_tools()).model_dump()
def mcp_flow( async def mcp_flow(
params: MCPServerSseParams, params: MCPServerSseParams,
function_call: dict[str, Any], function_call: dict[str, Any],
) -> MCPServer: ) -> Any: # Updated return type to Any
try: try:
server = MCPServerSse(params, cache_tools_list=True) async with MCPServerSse(params) as server:
return await call_tool_fast(server, function_call)
# Connect the server
asyncio.run(server.connect())
# Extract tool name and args from function call
tool_name = function_call.get("tool_name") or function_call.get("name")
if not tool_name:
raise ValueError("No tool name provided in function call")
# Call the tool
output = asyncio.run(server.call_tool(function_call))
# Convert to serializable format
output = output.model_dump() if hasattr(output, "model_dump") else output
# Cleanup the server
asyncio.run(server.cleanup())
return any_to_str(output)
except Exception as e: except Exception as e:
logger.error(f"Error in MCP flow: {e}") logger.error(f"Error in MCP flow: {e}")
raise raise
def batch_mcp_flow( async def batch_mcp_flow(
params: List[MCPServerSseParams], params: List[MCPServerSseParams],
function_call: List[dict[str, Any]] = [], function_call: List[dict[str, Any]] = [],
) -> MCPServer: ) -> List[Any]: # Updated return type to List[Any]
output_list = [] async def process_param(param):
try:
async with MCPServerSse(param) as server:
return await call_tool_fast(server, function_call[0])
except IndexError:
return None # Handle case where function_call is empty
except Exception as e:
logger.error(f"Error processing parameter: {param}, error: {e}")
return None
for param in params: results = await asyncio.gather(*(process_param(param) for param in params))
output = mcp_flow(param, function_call) return [any_to_str(r) for r in results if r is not None]
output_list.append(output)
return output_list
from mcp import (
ClientSession as OldClientSession, # Kept for backward compatibility with stdio
StdioServerParameters,
Tool as MCPTool,
stdio_client,
)
Loading…
Cancel
Save