parent
1a84b24394
commit
16a2321525
@ -1,255 +1,255 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Literal
|
from typing import Any, Dict, List, Optional, Literal
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
|
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.types import CallToolResult, JSONRPCMessage
|
from mcp.types import CallToolResult, JSONRPCMessage
|
||||||
|
|
||||||
from swarms.utils.any_to_str import any_to_str
|
from swarms.utils.any_to_str import any_to_str
|
||||||
|
|
||||||
|
|
||||||
class MCPServer(abc.ABC):
|
class MCPServer(abc.ABC):
|
||||||
"""Base class for Model Context Protocol servers."""
|
"""Base class for Model Context Protocol servers."""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Establish connection to the MCP server."""
|
"""Establish connection to the MCP server."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Human-readable server name."""
|
"""Human-readable server name."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def cleanup(self) -> None:
|
async def cleanup(self) -> None:
|
||||||
"""Clean up resources and close connection."""
|
"""Clean up resources and close connection."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def list_tools(self) -> List[MCPTool]:
|
async def list_tools(self) -> List[MCPTool]:
|
||||||
"""List available MCP tools on the server."""
|
"""List available MCP tools on the server."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def call_tool(
|
async def call_tool(
|
||||||
self, tool_name: str, arguments: Dict[str, Any] | None
|
self, tool_name: str, arguments: Dict[str, Any] | None
|
||||||
) -> CallToolResult:
|
) -> CallToolResult:
|
||||||
"""Invoke a tool by name with provided arguments."""
|
"""Invoke a tool by name with provided arguments."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||||
"""Mixin providing ClientSession-based MCP communication."""
|
"""Mixin providing ClientSession-based MCP communication."""
|
||||||
|
|
||||||
def __init__(self, cache_tools_list: bool = False):
|
def __init__(self, cache_tools_list: bool = False):
|
||||||
self.session: Optional[ClientSession] = None
|
self.session: Optional[ClientSession] = None
|
||||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||||
self._cleanup_lock = asyncio.Lock()
|
self._cleanup_lock = asyncio.Lock()
|
||||||
self.cache_tools_list = cache_tools_list
|
self.cache_tools_list = cache_tools_list
|
||||||
self._cache_dirty = True
|
self._cache_dirty = True
|
||||||
self._tools_list: Optional[List[MCPTool]] = None
|
self._tools_list: Optional[List[MCPTool]] = None
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def create_streams(
|
def create_streams(
|
||||||
self
|
self
|
||||||
) -> AbstractAsyncContextManager[
|
) -> AbstractAsyncContextManager[
|
||||||
tuple[
|
tuple[
|
||||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
MemoryObjectSendStream[JSONRPCMessage],
|
MemoryObjectSendStream[JSONRPCMessage],
|
||||||
]
|
]
|
||||||
]:
|
]:
|
||||||
"""Supply the read/write streams for the MCP transport."""
|
"""Supply the read/write streams for the MCP transport."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def __aenter__(self) -> MCPServer:
|
async def __aenter__(self) -> MCPServer:
|
||||||
await self.connect()
|
await self.connect()
|
||||||
return self # type: ignore
|
return self # type: ignore
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, tb) -> None:
|
async def __aexit__(self, exc_type, exc_value, tb) -> None:
|
||||||
await self.cleanup()
|
await self.cleanup()
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Initialize transport and ClientSession."""
|
"""Initialize transport and ClientSession."""
|
||||||
try:
|
try:
|
||||||
transport = await self.exit_stack.enter_async_context(
|
transport = await self.exit_stack.enter_async_context(
|
||||||
self.create_streams()
|
self.create_streams()
|
||||||
)
|
)
|
||||||
read, write = transport
|
read, write = transport
|
||||||
session = await self.exit_stack.enter_async_context(
|
session = await self.exit_stack.enter_async_context(
|
||||||
ClientSession(read, write)
|
ClientSession(read, write)
|
||||||
)
|
)
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
self.session = session
|
self.session = session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing MCP server: {e}")
|
logger.error(f"Error initializing MCP server: {e}")
|
||||||
await self.cleanup()
|
await self.cleanup()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def cleanup(self) -> None:
|
async def cleanup(self) -> None:
|
||||||
"""Close session and transport."""
|
"""Close session and transport."""
|
||||||
async with self._cleanup_lock:
|
async with self._cleanup_lock:
|
||||||
try:
|
|
||||||
await self.exit_stack.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during cleanup: {e}")
|
|
||||||
finally:
|
|
||||||
self.session = None
|
|
||||||
|
|
||||||
async def list_tools(self) -> List[MCPTool]:
|
|
||||||
if not self.session:
|
|
||||||
raise RuntimeError("Server not connected. Call connect() first.")
|
|
||||||
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
|
|
||||||
return self._tools_list
|
|
||||||
self._cache_dirty = False
|
|
||||||
self._tools_list = (await self.session.list_tools()).tools
|
|
||||||
return self._tools_list # type: ignore
|
|
||||||
|
|
||||||
async def call_tool(
|
|
||||||
self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None
|
|
||||||
) -> CallToolResult:
|
|
||||||
if not arguments:
|
|
||||||
raise ValueError("Arguments dict is required to call a tool")
|
|
||||||
name = tool_name or arguments.get("tool_name") or arguments.get("name")
|
|
||||||
if not name:
|
|
||||||
raise ValueError("Tool name missing in arguments")
|
|
||||||
if not self.session:
|
|
||||||
raise RuntimeError("Server not connected. Call connect() first.")
|
|
||||||
return await self.session.call_tool(name, arguments)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPServerStdioParams(TypedDict):
|
|
||||||
"""Configuration for stdio transport."""
|
|
||||||
command: str
|
|
||||||
args: NotRequired[List[str]]
|
|
||||||
env: NotRequired[Dict[str, str]]
|
|
||||||
cwd: NotRequired[str | Path]
|
|
||||||
encoding: NotRequired[str]
|
|
||||||
encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
|
|
||||||
|
|
||||||
|
|
||||||
class MCPServerStdio(_MCPServerWithClientSession):
|
|
||||||
"""MCP server over stdio transport."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
params: MCPServerStdioParams,
|
|
||||||
cache_tools_list: bool = False,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
):
|
|
||||||
super().__init__(cache_tools_list)
|
|
||||||
self.params = StdioServerParameters(
|
|
||||||
command=params["command"],
|
|
||||||
args=params.get("args", []),
|
|
||||||
env=params.get("env"),
|
|
||||||
cwd=params.get("cwd"),
|
|
||||||
encoding=params.get("encoding", "utf-8"),
|
|
||||||
encoding_error_handler=params.get("encoding_error_handler", "strict"),
|
|
||||||
)
|
|
||||||
self._name = name or f"stdio:{self.params.command}"
|
|
||||||
|
|
||||||
def create_streams(self) -> AbstractAsyncContextManager[
|
|
||||||
tuple[
|
|
||||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
|
||||||
MemoryObjectSendStream[JSONRPCMessage],
|
|
||||||
]
|
|
||||||
]:
|
|
||||||
return stdio_client(self.params)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
|
|
||||||
class MCPServerSseParams(TypedDict):
|
|
||||||
"""Configuration for HTTP+SSE transport."""
|
|
||||||
url: str
|
|
||||||
headers: NotRequired[Dict[str, str]]
|
|
||||||
timeout: NotRequired[float]
|
|
||||||
sse_read_timeout: NotRequired[float]
|
|
||||||
|
|
||||||
|
|
||||||
class MCPServerSse(_MCPServerWithClientSession):
|
|
||||||
"""MCP server over HTTP with SSE transport."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
params: MCPServerSseParams,
|
|
||||||
cache_tools_list: bool = False,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
):
|
|
||||||
super().__init__(cache_tools_list)
|
|
||||||
self.params = params
|
|
||||||
self._name = name or f"sse:{params['url']}"
|
|
||||||
|
|
||||||
def create_streams(self) -> AbstractAsyncContextManager[
|
|
||||||
tuple[
|
|
||||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
|
||||||
MemoryObjectSendStream[JSONRPCMessage],
|
|
||||||
]
|
|
||||||
]:
|
|
||||||
return sse_client(
|
|
||||||
url=self.params["url"],
|
|
||||||
headers=self.params.get("headers"),
|
|
||||||
timeout=self.params.get("timeout", 5),
|
|
||||||
sse_read_timeout=self.params.get("sse_read_timeout", 300),
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
|
|
||||||
async def call_tool_fast(
|
|
||||||
server: MCPServerSse, payload: Dict[str, Any] | str
|
|
||||||
) -> Any:
|
|
||||||
try:
|
|
||||||
await server.connect()
|
|
||||||
result = await server.call_tool(arguments=payload if isinstance(payload, dict) else None)
|
|
||||||
return result
|
|
||||||
finally:
|
|
||||||
await server.cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
async def mcp_flow_get_tool_schema(
|
|
||||||
params: MCPServerSseParams,
|
|
||||||
) -> Any:
|
|
||||||
async with MCPServerSse(params) as server:
|
|
||||||
tools = await server.list_tools()
|
|
||||||
return tools
|
|
||||||
|
|
||||||
|
|
||||||
async def mcp_flow(
|
|
||||||
params: MCPServerSseParams,
|
|
||||||
function_call: Dict[str, Any] | str,
|
|
||||||
) -> Any:
|
|
||||||
async with MCPServerSse(params) as server:
|
|
||||||
return await call_tool_fast(server, function_call)
|
|
||||||
|
|
||||||
|
|
||||||
async def _call_one_server(
|
|
||||||
params: MCPServerSseParams, payload: Dict[str, Any] | str
|
|
||||||
) -> Any:
|
|
||||||
server = MCPServerSse(params)
|
|
||||||
try:
|
try:
|
||||||
await server.connect()
|
await self.exit_stack.aclose()
|
||||||
return await server.call_tool(arguments=payload if isinstance(payload, dict) else None)
|
except Exception as e:
|
||||||
|
logger.error(f"Error during cleanup: {e}")
|
||||||
finally:
|
finally:
|
||||||
await server.cleanup()
|
self.session = None
|
||||||
|
|
||||||
|
async def list_tools(self) -> List[MCPTool]:
|
||||||
def batch_mcp_flow(
|
if not self.session:
|
||||||
params: List[MCPServerSseParams], payload: Dict[str, Any] | str
|
raise RuntimeError("Server not connected. Call connect() first.")
|
||||||
) -> List[Any]:
|
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
|
||||||
return asyncio.run(
|
return self._tools_list
|
||||||
asyncio.gather(*[_call_one_server(p, payload) for p in params])
|
self._cache_dirty = False
|
||||||
)
|
self._tools_list = (await self.session.list_tools()).tools
|
||||||
|
return self._tools_list # type: ignore
|
||||||
|
|
||||||
|
async def call_tool(
|
||||||
|
self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None
|
||||||
|
) -> CallToolResult:
|
||||||
|
if not arguments:
|
||||||
|
raise ValueError("Arguments dict is required to call a tool")
|
||||||
|
name = tool_name or arguments.get("tool_name") or arguments.get("name")
|
||||||
|
if not name:
|
||||||
|
raise ValueError("Tool name missing in arguments")
|
||||||
|
if not self.session:
|
||||||
|
raise RuntimeError("Server not connected. Call connect() first.")
|
||||||
|
return await self.session.call_tool(name, arguments)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerStdioParams(TypedDict):
|
||||||
|
"""Configuration for stdio transport."""
|
||||||
|
command: str
|
||||||
|
args: NotRequired[List[str]]
|
||||||
|
env: NotRequired[Dict[str, str]]
|
||||||
|
cwd: NotRequired[str | Path]
|
||||||
|
encoding: NotRequired[str]
|
||||||
|
encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerStdio(_MCPServerWithClientSession):
|
||||||
|
"""MCP server over stdio transport."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: MCPServerStdioParams,
|
||||||
|
cache_tools_list: bool = False,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(cache_tools_list)
|
||||||
|
self.params = StdioServerParameters(
|
||||||
|
command=params["command"],
|
||||||
|
args=params.get("args", []),
|
||||||
|
env=params.get("env"),
|
||||||
|
cwd=params.get("cwd"),
|
||||||
|
encoding=params.get("encoding", "utf-8"),
|
||||||
|
encoding_error_handler=params.get("encoding_error_handler", "strict"),
|
||||||
|
)
|
||||||
|
self._name = name or f"stdio:{self.params.command}"
|
||||||
|
|
||||||
|
def create_streams(self) -> AbstractAsyncContextManager[
|
||||||
|
tuple[
|
||||||
|
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
|
MemoryObjectSendStream[JSONRPCMessage],
|
||||||
|
]
|
||||||
|
]:
|
||||||
|
return stdio_client(self.params)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerSseParams(TypedDict):
|
||||||
|
"""Configuration for HTTP+SSE transport."""
|
||||||
|
url: str
|
||||||
|
headers: NotRequired[Dict[str, str]]
|
||||||
|
timeout: NotRequired[float]
|
||||||
|
sse_read_timeout: NotRequired[float]
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerSse(_MCPServerWithClientSession):
|
||||||
|
"""MCP server over HTTP with SSE transport."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: MCPServerSseParams,
|
||||||
|
cache_tools_list: bool = False,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(cache_tools_list)
|
||||||
|
self.params = params
|
||||||
|
self._name = name or f"sse:{params['url']}"
|
||||||
|
|
||||||
|
def create_streams(self) -> AbstractAsyncContextManager[
|
||||||
|
tuple[
|
||||||
|
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
|
MemoryObjectSendStream[JSONRPCMessage],
|
||||||
|
]
|
||||||
|
]:
|
||||||
|
return sse_client(
|
||||||
|
url=self.params["url"],
|
||||||
|
headers=self.params.get("headers"),
|
||||||
|
timeout=self.params.get("timeout", 5),
|
||||||
|
sse_read_timeout=self.params.get("sse_read_timeout", 300),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
|
||||||
|
async def call_tool_fast(
|
||||||
|
server: MCPServerSse, payload: Dict[str, Any] | str
|
||||||
|
) -> Any:
|
||||||
|
try:
|
||||||
|
await server.connect()
|
||||||
|
result = await server.call_tool(arguments=payload if isinstance(payload, dict) else None)
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
await server.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
async def mcp_flow_get_tool_schema(
|
||||||
|
params: MCPServerSseParams,
|
||||||
|
) -> Any:
|
||||||
|
async with MCPServerSse(params) as server:
|
||||||
|
tools = await server.list_tools()
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
async def mcp_flow(
|
||||||
|
params: MCPServerSseParams,
|
||||||
|
function_call: Dict[str, Any] | str,
|
||||||
|
) -> Any:
|
||||||
|
async with MCPServerSse(params) as server:
|
||||||
|
return await call_tool_fast(server, function_call)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_one_server(
|
||||||
|
params: MCPServerSseParams, payload: Dict[str, Any] | str
|
||||||
|
) -> Any:
|
||||||
|
server = MCPServerSse(params)
|
||||||
|
try:
|
||||||
|
await server.connect()
|
||||||
|
return await server.call_tool(arguments=payload if isinstance(payload, dict) else None)
|
||||||
|
finally:
|
||||||
|
await server.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def batch_mcp_flow(
|
||||||
|
params: List[MCPServerSseParams], payload: Dict[str, Any] | str
|
||||||
|
) -> List[Any]:
|
||||||
|
return asyncio.run(
|
||||||
|
asyncio.gather(*[_call_one_server(p, payload) for p in params])
|
||||||
|
)
|
Loading…
Reference in new issue