test(mcp): streamline testing workflows and add new tests for mcp_integration.py

pull/819/head
DP37 3 months ago committed by ascender1729
parent 925709de6e
commit d75bbed8ee

@ -7,66 +7,6 @@ packages = ["libxcrypt"]
[workflows]
runButton = "Run MCP Demo"
[[workflows.workflow]]
name = "Run MCP Tests"
author = 13983571
mode = "sequential"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python -m pytest tests/tools/test_mcp_integration.py -v"
[[workflows.workflow]]
name = "Run Interactive Agents"
author = 13983571
mode = "sequential"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python -m pytest tests/tools/test_mcp_integration.py::test_interactive_multi_agent_mcp -s"
[[workflows.workflow]]
name = "Run MCP Test"
author = 13983571
mode = "sequential"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python examples/mcp_example/math_server.py & "
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "sleep 2"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python examples/mcp_example/test_integration.py"
[[workflows.workflow]]
name = "Run Mock MCP System"
author = 13983571
mode = "parallel"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python examples/mcp_example/mock_stock_server.py &"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "sleep 2"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python examples/mcp_example/mock_math_server.py &"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "sleep 2"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python examples/mcp_example/mock_multi_agent.py"
[[workflows.workflow]]
name = "Run Tests"
author = 13983571
@ -75,20 +15,3 @@ mode = "sequential"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python -m unittest tests/test_basic_example.py -v"
[[workflows.workflow]]
name = "Run MCP Demo"
author = 13983571
mode = "parallel"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python examples/mcp_example/mock_math_server.py &"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "sleep 2"
[[workflows.workflow.tasks]]
task = "shell.exec"
args = "python examples/mcp_example/mcp_client.py"

@ -0,0 +1,392 @@
this si the oreginal firle siese for integration base dont his update it not for sepcief case the mcop_inteation is used it inteage in aget firle"from __future__ import annotations
from typing import Any, List
from loguru import logger
import abc
import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path
from typing import Literal
from anyio.streams.memory import (
MemoryObjectReceiveStream,
MemoryObjectSendStream,
)
from mcp import (
ClientSession,
StdioServerParameters,
Tool as MCPTool,
stdio_client,
)
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from typing_extensions import NotRequired, TypedDict
from swarms.utils.any_to_str import any_to_str
class MCPServer(abc.ABC):
"""Base class for Model Context Protocol servers."""
@abc.abstractmethod
async def connect(self):
"""Connect to the server. For example, this might mean spawning a subprocess or
opening a network connection. The server is expected to remain connected until
`cleanup()` is called.
"""
pass
@property
@abc.abstractmethod
def name(self) -> str:
"""A readable name for the server."""
pass
@abc.abstractmethod
async def cleanup(self):
"""Cleanup the server. For example, this might mean closing a subprocess or
closing a network connection.
"""
pass
@abc.abstractmethod
async def list_tools(self) -> list[MCPTool]:
"""List the tools available on the server."""
pass
@abc.abstractmethod
async def call_tool(
self, tool_name: str, arguments: dict[str, Any] | None
) -> CallToolResult:
"""Invoke a tool on the server."""
pass
class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
def __init__(self, cache_tools_list: bool):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be invalidated
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).
"""
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list
# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
self._tools_list: list[MCPTool] | None = None
@abc.abstractmethod
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
pass
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()
def invalidate_tools_cache(self):
"""Invalidate the tools cache."""
self._cache_dirty = True
async def connect(self):
"""Connect to the server."""
try:
transport = await self.exit_stack.enter_async_context(
self.create_streams()
)
read, write = transport
session = await self.exit_stack.enter_async_context(
ClientSession(read, write)
)
await session.initialize()
self.session = session
except Exception as e:
logger.error(f"Error initializing MCP server: {e}")
await self.cleanup()
raise
async def list_tools(self) -> list[MCPTool]:
"""List the tools available on the server."""
if not self.session:
raise Exception(
"Server not initialized. Make sure you call `connect()` first."
)
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
if (
self.cache_tools_list
and not self._cache_dirty
and self._tools_list
):
return self._tools_list
# Reset the cache dirty to False
self._cache_dirty = False
# Fetch the tools from the server
self._tools_list = (await self.session.list_tools()).tools
return self._tools_list
async def call_tool(
self, arguments: dict[str, Any] | None
) -> CallToolResult:
"""Invoke a tool on the server."""
tool_name = arguments.get("tool_name") or arguments.get(
"name"
)
if not tool_name:
raise Exception("No tool name found in arguments")
if not self.session:
raise Exception(
"Server not initialized. Make sure you call `connect()` first."
)
return await self.session.call_tool(tool_name, arguments)
async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
try:
await self.exit_stack.aclose()
self.session = None
except Exception as e:
logger.error(f"Error cleaning up server: {e}")
class MCPServerStdioParams(TypedDict):
"""Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
import.
"""
command: str
"""The executable to run to start the server. For example, `python` or `node`."""
args: NotRequired[list[str]]
"""Command line args to pass to the `command` executable. For example, `['foo.py']` or
`['server.js', '--port', '8080']`."""
env: NotRequired[dict[str, str]]
"""The environment variables to set for the server. ."""
cwd: NotRequired[str | Path]
"""The working directory to use when spawning the process."""
encoding: NotRequired[str]
"""The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
encoding_error_handler: NotRequired[
Literal["strict", "ignore", "replace"]
]
"""The text encoding error handler. Defaults to `strict`.
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values.
"""
class MCPServerStdio(_MCPServerWithClientSession):
"""MCP server implementation that uses the stdio transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
details.
"""
def __init__(
self,
params: MCPServerStdioParams,
cache_tools_list: bool = False,
name: str | None = None,
):
"""Create a new MCP server based on the stdio transport.
Args:
params: The params that configure the server. This includes the command to run to
start the server, the args to pass to the command, the environment variables to
set for the server, the working directory to use when spawning the process, and
the text encoding used when sending/receiving messages to the server.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
command.
"""
super().__init__(cache_tools_list)
self.params = StdioServerParameters(
command=params["command"],
args=params.get("args", []),
env=params.get("env"),
cwd=params.get("cwd"),
encoding=params.get("encoding", "utf-8"),
encoding_error_handler=params.get(
"encoding_error_handler", "strict"
),
)
self._name = name or f"stdio: {self.params.command}"
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return stdio_client(self.params)
@property
def name(self) -> str:
"""A readable name for the server."""
return self._name
class MCPServerSseParams(TypedDict):
"""Mirrors the params in`mcp.client.sse.sse_client`."""
url: str
"""The URL of the server."""
headers: NotRequired[dict[str, str]]
"""The headers to send to the server."""
timeout: NotRequired[float]
"""The timeout for the HTTP request. Defaults to 5 seconds."""
sse_read_timeout: NotRequired[float]
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
class MCPServerSse(_MCPServerWithClientSession):
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
for details.
"""
def __init__(
self,
params: MCPServerSseParams,
cache_tools_list: bool = False,
name: str | None = None,
):
"""Create a new MCP server based on the HTTP with SSE transport.
Args:
params: The params that configure the server. This includes the URL of the server,
the headers to send to the server, the timeout for the HTTP request, and the
timeout for the SSE connection.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
URL.
"""
super().__init__(cache_tools_list)
self.params = params
self._name = name or f"sse: {self.params['url']}"
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return sse_client(
url=self.params["url"],
headers=self.params.get("headers", None),
timeout=self.params.get("timeout", 5),
sse_read_timeout=self.params.get(
"sse_read_timeout", 60 * 5
),
)
@property
def name(self) -> str:
"""A readable name for the server."""
return self._name
def mcp_flow_get_tool_schema(
params: MCPServerSseParams,
) -> MCPServer:
server = MCPServerSse(params, cache_tools_list=True)
# Connect the server
asyncio.run(server.connect())
# Return the server
output = asyncio.run(server.list_tools())
# Cleanup the server
asyncio.run(server.cleanup())
return output.model_dump()
def mcp_flow(
params: MCPServerSseParams,
function_call: dict[str, Any],
) -> MCPServer:
server = MCPServerSse(params, cache_tools_list=True)
# Connect the server
asyncio.run(server.connect())
# Return the server
output = asyncio.run(server.call_tool(function_call))
output = output.model_dump()
# Cleanup the server
asyncio.run(server.cleanup())
return any_to_str(output)
def batch_mcp_flow(
params: List[MCPServerSseParams],
function_call: List[dict[str, Any]] = [],
) -> MCPServer:
output_list = []
for param in params:
output = mcp_flow(param, function_call)
output_list.append(output)
return output_list"

@ -0,0 +1,392 @@
this si the oreginal firle siese for integration base dont his update it not for sepcief case the mcop_inteation is used it inteage in aget firle"from __future__ import annotations
from typing import Any, List
from loguru import logger
import abc
import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path
from typing import Literal
from anyio.streams.memory import (
MemoryObjectReceiveStream,
MemoryObjectSendStream,
)
from mcp import (
ClientSession,
StdioServerParameters,
Tool as MCPTool,
stdio_client,
)
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from typing_extensions import NotRequired, TypedDict
from swarms.utils.any_to_str import any_to_str
class MCPServer(abc.ABC):
"""Base class for Model Context Protocol servers."""
@abc.abstractmethod
async def connect(self):
"""Connect to the server. For example, this might mean spawning a subprocess or
opening a network connection. The server is expected to remain connected until
`cleanup()` is called.
"""
pass
@property
@abc.abstractmethod
def name(self) -> str:
"""A readable name for the server."""
pass
@abc.abstractmethod
async def cleanup(self):
"""Cleanup the server. For example, this might mean closing a subprocess or
closing a network connection.
"""
pass
@abc.abstractmethod
async def list_tools(self) -> list[MCPTool]:
"""List the tools available on the server."""
pass
@abc.abstractmethod
async def call_tool(
self, tool_name: str, arguments: dict[str, Any] | None
) -> CallToolResult:
"""Invoke a tool on the server."""
pass
class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
def __init__(self, cache_tools_list: bool):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be invalidated
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).
"""
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list
# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
self._tools_list: list[MCPTool] | None = None
@abc.abstractmethod
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
pass
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()
def invalidate_tools_cache(self):
"""Invalidate the tools cache."""
self._cache_dirty = True
async def connect(self):
"""Connect to the server."""
try:
transport = await self.exit_stack.enter_async_context(
self.create_streams()
)
read, write = transport
session = await self.exit_stack.enter_async_context(
ClientSession(read, write)
)
await session.initialize()
self.session = session
except Exception as e:
logger.error(f"Error initializing MCP server: {e}")
await self.cleanup()
raise
async def list_tools(self) -> list[MCPTool]:
"""List the tools available on the server."""
if not self.session:
raise Exception(
"Server not initialized. Make sure you call `connect()` first."
)
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
if (
self.cache_tools_list
and not self._cache_dirty
and self._tools_list
):
return self._tools_list
# Reset the cache dirty to False
self._cache_dirty = False
# Fetch the tools from the server
self._tools_list = (await self.session.list_tools()).tools
return self._tools_list
async def call_tool(
self, arguments: dict[str, Any] | None
) -> CallToolResult:
"""Invoke a tool on the server."""
tool_name = arguments.get("tool_name") or arguments.get(
"name"
)
if not tool_name:
raise Exception("No tool name found in arguments")
if not self.session:
raise Exception(
"Server not initialized. Make sure you call `connect()` first."
)
return await self.session.call_tool(tool_name, arguments)
async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
try:
await self.exit_stack.aclose()
self.session = None
except Exception as e:
logger.error(f"Error cleaning up server: {e}")
class MCPServerStdioParams(TypedDict):
"""Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
import.
"""
command: str
"""The executable to run to start the server. For example, `python` or `node`."""
args: NotRequired[list[str]]
"""Command line args to pass to the `command` executable. For example, `['foo.py']` or
`['server.js', '--port', '8080']`."""
env: NotRequired[dict[str, str]]
"""The environment variables to set for the server. ."""
cwd: NotRequired[str | Path]
"""The working directory to use when spawning the process."""
encoding: NotRequired[str]
"""The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
encoding_error_handler: NotRequired[
Literal["strict", "ignore", "replace"]
]
"""The text encoding error handler. Defaults to `strict`.
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values.
"""
class MCPServerStdio(_MCPServerWithClientSession):
"""MCP server implementation that uses the stdio transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
details.
"""
def __init__(
self,
params: MCPServerStdioParams,
cache_tools_list: bool = False,
name: str | None = None,
):
"""Create a new MCP server based on the stdio transport.
Args:
params: The params that configure the server. This includes the command to run to
start the server, the args to pass to the command, the environment variables to
set for the server, the working directory to use when spawning the process, and
the text encoding used when sending/receiving messages to the server.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
command.
"""
super().__init__(cache_tools_list)
self.params = StdioServerParameters(
command=params["command"],
args=params.get("args", []),
env=params.get("env"),
cwd=params.get("cwd"),
encoding=params.get("encoding", "utf-8"),
encoding_error_handler=params.get(
"encoding_error_handler", "strict"
),
)
self._name = name or f"stdio: {self.params.command}"
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return stdio_client(self.params)
@property
def name(self) -> str:
"""A readable name for the server."""
return self._name
class MCPServerSseParams(TypedDict):
"""Mirrors the params in`mcp.client.sse.sse_client`."""
url: str
"""The URL of the server."""
headers: NotRequired[dict[str, str]]
"""The headers to send to the server."""
timeout: NotRequired[float]
"""The timeout for the HTTP request. Defaults to 5 seconds."""
sse_read_timeout: NotRequired[float]
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
class MCPServerSse(_MCPServerWithClientSession):
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
for details.
"""
def __init__(
self,
params: MCPServerSseParams,
cache_tools_list: bool = False,
name: str | None = None,
):
"""Create a new MCP server based on the HTTP with SSE transport.
Args:
params: The params that configure the server. This includes the URL of the server,
the headers to send to the server, the timeout for the HTTP request, and the
timeout for the SSE connection.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
URL.
"""
super().__init__(cache_tools_list)
self.params = params
self._name = name or f"sse: {self.params['url']}"
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return sse_client(
url=self.params["url"],
headers=self.params.get("headers", None),
timeout=self.params.get("timeout", 5),
sse_read_timeout=self.params.get(
"sse_read_timeout", 60 * 5
),
)
@property
def name(self) -> str:
"""A readable name for the server."""
return self._name
def mcp_flow_get_tool_schema(
params: MCPServerSseParams,
) -> MCPServer:
server = MCPServerSse(params, cache_tools_list=True)
# Connect the server
asyncio.run(server.connect())
# Return the server
output = asyncio.run(server.list_tools())
# Cleanup the server
asyncio.run(server.cleanup())
return output.model_dump()
def mcp_flow(
params: MCPServerSseParams,
function_call: dict[str, Any],
) -> MCPServer:
server = MCPServerSse(params, cache_tools_list=True)
# Connect the server
asyncio.run(server.connect())
# Return the server
output = asyncio.run(server.call_tool(function_call))
output = output.model_dump()
# Cleanup the server
asyncio.run(server.cleanup())
return any_to_str(output)
def batch_mcp_flow(
params: List[MCPServerSseParams],
function_call: List[dict[str, Any]] = [],
) -> MCPServer:
output_list = []
for param in params:
output = mcp_flow(param, function_call)
output_list.append(output)
return output_list"

@ -1,507 +1,255 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing_extensions import NotRequired, TypedDict
from contextlib import AbstractAsyncContextManager
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client, StdioServerParameters
from loguru import logger
import abc
import asyncio
from pathlib import Path
from typing import Literal
from anyio.streams.memory import (
MemoryObjectReceiveStream,
MemoryObjectSendStream,
)
from mcp.types import CallToolResult, JSONRPCMessage # Kept for backward compatibility, might be removed later
from swarms.utils.any_to_str import any_to_str
from mcp import (
ClientSession as OldClientSession, # Kept for backward compatibility with stdio
StdioServerParameters,
Tool as MCPTool,
stdio_client,
)
class MCPServer(abc.ABC):
"""Base class for Model Context Protocol servers."""
@abc.abstractmethod
async def connect(self):
"""Connect to the server. For example, this might mean spawning a subprocess or
opening a network connection. The server is expected to remain connected until
`cleanup()` is called.
"""
pass
@property
@abc.abstractmethod
def name(self) -> str:
"""A readable name for the server."""
pass
@abc.abstractmethod
async def cleanup(self):
"""Cleanup the server. For example, this might mean closing a subprocess or
closing a network connection.
"""
pass
@abc.abstractmethod
async def list_tools(self) -> list[Any]: # Changed to Any for flexibility
"""List the tools available on the server."""
pass
@abc.abstractmethod
async def call_tool(
self, tool_name: str, arguments: dict[str, Any] | None
) -> CallToolResult: # Kept for backward compatibility, might be removed later
"""Invoke a tool on the server."""
pass
class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
def __init__(self, cache_tools_list: bool):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be invalidated
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).
"""
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list
# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
self._tools_list: list[Any] | None = None # Changed to Any for flexibility
@abc.abstractmethod
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
pass
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()
def invalidate_tools_cache(self):
"""Invalidate the tools cache."""
self._cache_dirty = True
async def connect(self):
"""Connect to the server."""
try:
transport = await self.exit_stack.enter_async_context(
self.create_streams()
)
read, write = transport
session = await self.exit_stack.enter_async_context(
ClientSession(read, write)
)
await session.initialize()
self.session = session
except Exception as e:
logger.error(f"Error initializing MCP server: {e}")
await self.cleanup()
raise
async def list_tools(self) -> list[Any]: # Changed to Any for flexibility
"""List the tools available on the server."""
if not self.session:
raise Exception(
"Server not initialized. Make sure you call `connect()` first."
)
from __future__ import annotations
import abc
import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path
from typing import Any, Dict, List, Optional, Literal
from typing_extensions import NotRequired, TypedDict
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from loguru import logger
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from swarms.utils.any_to_str import any_to_str
class MCPServer(abc.ABC):
"""Base class for Model Context Protocol servers."""
@abc.abstractmethod
async def connect(self) -> None:
"""Establish connection to the MCP server."""
pass
@property
@abc.abstractmethod
def name(self) -> str:
"""Human-readable server name."""
pass
@abc.abstractmethod
async def cleanup(self) -> None:
"""Clean up resources and close connection."""
pass
@abc.abstractmethod
async def list_tools(self) -> List[MCPTool]:
"""List available MCP tools on the server."""
pass
@abc.abstractmethod
async def call_tool(
self, tool_name: str, arguments: Dict[str, Any] | None
) -> CallToolResult:
"""Invoke a tool by name with provided arguments."""
pass
class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Mixin providing ClientSession-based MCP communication."""
def __init__(self, cache_tools_list: bool = False):
self.session: Optional[ClientSession] = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list
self._cache_dirty = True
self._tools_list: Optional[List[MCPTool]] = None
@abc.abstractmethod
def create_streams(
self
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Supply the read/write streams for the MCP transport."""
pass
async def __aenter__(self) -> MCPServer:
await self.connect()
return self # type: ignore
async def __aexit__(self, exc_type, exc_value, tb) -> None:
await self.cleanup()
async def connect(self) -> None:
"""Initialize transport and ClientSession."""
try:
transport = await self.exit_stack.enter_async_context(
self.create_streams()
)
read, write = transport
session = await self.exit_stack.enter_async_context(
ClientSession(read, write)
)
await session.initialize()
self.session = session
except Exception as e:
logger.error(f"Error initializing MCP server: {e}")
await self.cleanup()
raise
async def cleanup(self) -> None:
"""Close session and transport."""
async with self._cleanup_lock:
try:
await self.exit_stack.aclose()
except Exception as e:
logger.error(f"Error during cleanup: {e}")
finally:
self.session = None
async def list_tools(self) -> List[MCPTool]:
if not self.session:
raise RuntimeError("Server not connected. Call connect() first.")
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
return self._tools_list
self._cache_dirty = False
self._tools_list = (await self.session.list_tools()).tools
return self._tools_list # type: ignore
async def call_tool(
self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None
) -> CallToolResult:
if not arguments:
raise ValueError("Arguments dict is required to call a tool")
name = tool_name or arguments.get("tool_name") or arguments.get("name")
if not name:
raise ValueError("Tool name missing in arguments")
if not self.session:
raise RuntimeError("Server not connected. Call connect() first.")
return await self.session.call_tool(name, arguments)
class MCPServerStdioParams(TypedDict):
"""Configuration for stdio transport."""
command: str
args: NotRequired[List[str]]
env: NotRequired[Dict[str, str]]
cwd: NotRequired[str | Path]
encoding: NotRequired[str]
encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
class MCPServerStdio(_MCPServerWithClientSession):
"""MCP server over stdio transport."""
def __init__(
self,
params: MCPServerStdioParams,
cache_tools_list: bool = False,
name: Optional[str] = None,
):
super().__init__(cache_tools_list)
self.params = StdioServerParameters(
command=params["command"],
args=params.get("args", []),
env=params.get("env"),
cwd=params.get("cwd"),
encoding=params.get("encoding", "utf-8"),
encoding_error_handler=params.get("encoding_error_handler", "strict"),
)
self._name = name or f"stdio:{self.params.command}"
def create_streams(self) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
return stdio_client(self.params)
@property
def name(self) -> str:
return self._name
class MCPServerSseParams(TypedDict):
"""Configuration for HTTP+SSE transport."""
url: str
headers: NotRequired[Dict[str, str]]
timeout: NotRequired[float]
sse_read_timeout: NotRequired[float]
class MCPServerSse(_MCPServerWithClientSession):
"""MCP server over HTTP with SSE transport."""
def __init__(
self,
params: MCPServerSseParams,
cache_tools_list: bool = False,
name: Optional[str] = None,
):
super().__init__(cache_tools_list)
self.params = params
self._name = name or f"sse:{params['url']}"
def create_streams(self) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
return sse_client(
url=self.params["url"],
headers=self.params.get("headers"),
timeout=self.params.get("timeout", 5),
sse_read_timeout=self.params.get("sse_read_timeout", 300),
)
@property
def name(self) -> str:
return self._name
async def call_tool_fast(
server: MCPServerSse, payload: Dict[str, Any] | str
) -> Any:
try:
await server.connect()
result = await server.call_tool(arguments=payload if isinstance(payload, dict) else None)
return result
finally:
await server.cleanup()
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
if (
self.cache_tools_list
and not self._cache_dirty
and self._tools_list
):
return self._tools_list
# Reset the cache dirty to False
self._cache_dirty = False
# Fetch the tools from the server
self._tools_list = (await self.session.list_tools()).tools
return self._tools_list
async def call_tool(
self, arguments: dict[str, Any] | None
) -> CallToolResult: # Kept for backward compatibility, might be removed later
"""Invoke a tool on the server."""
tool_name = arguments.get("tool_name") or arguments.get(
"name"
)
if not tool_name:
raise Exception("No tool name found in arguments")
if not self.session:
raise Exception(
"Server not initialized. Make sure you call `connect()` first."
)
return await self.session.call_tool(tool_name, arguments)
async def mcp_flow_get_tool_schema(
params: MCPServerSseParams,
) -> Any:
async with MCPServerSse(params) as server:
tools = await server.list_tools()
return tools
async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
try:
await self.exit_stack.aclose()
self.session = None
except Exception as e:
logger.error(f"Error cleaning up server: {e}")
class MCPServerStdioParams(TypedDict):
"""Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
import.
"""
command: str
"""The executable to run to start the server. For example, `python` or `node`."""
args: NotRequired[list[str]]
"""Command line args to pass to the `command` executable. For example, `['foo.py']` or
`['server.js', '--port', '8080']`."""
env: NotRequired[dict[str, str]]
"""The environment variables to set for the server. ."""
cwd: NotRequired[str | Path]
"""The working directory to use when spawning the process."""
encoding: NotRequired[str]
"""The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
encoding_error_handler: NotRequired[
Literal["strict", "ignore", "replace"]
]
"""The text encoding error handler. Defaults to `strict`.
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values.
"""
class MCPServerStdio(_MCPServerWithClientSession):
"""MCP server implementation that uses the stdio transport. See the [spec]
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
details.
"""
def __init__(
self,
params: MCPServerStdioParams,
cache_tools_list: bool = False,
name: str | None = None,
):
"""Create a new MCP server based on the stdio transport.
Args:
params: The params that configure the server. This includes the command to run to
start the server, the args to pass to the command, the environment variables to
set for the server, the working directory to use when spawning the process, and
the text encoding used when sending/receiving messages to the server.
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).
name: A readable name for the server. If not provided, we'll create one from the
command.
"""
super().__init__(cache_tools_list)
self.params = StdioServerParameters(
command=params["command"],
args=params.get("args", []),
env=params.get("env"),
cwd=params.get("cwd"),
encoding=params.get("encoding", "utf-8"),
encoding_error_handler=params.get(
"encoding_error_handler", "strict"
),
)
self._name = name or f"stdio: {self.params.command}"
def create_streams(
self,
) -> AbstractAsyncContextManager[
tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage],
]
]:
"""Create the streams for the server."""
return stdio_client(self.params)
@property
def name(self) -> str:
"""A readable name for the server."""
return self._name
class MCPServerSseParams(TypedDict):
"""Mirrors the params in`mcp.client.sse.sse_client`."""
url: str
"""The URL of the server."""
headers: NotRequired[dict[str, str]]
"""The headers to send to the server."""
timeout: NotRequired[float]
"""The timeout for the HTTP request. Defaults to 5 seconds."""
sse_read_timeout: NotRequired[float]
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
class MCPServerSse:
def __init__(self, params: MCPServerSseParams):
self.params = params
self.client: Optional[ClientSession] = None
self._connection_lock = asyncio.Lock()
self.messages = [] # Store messages instead of using conversation
self.preserve_format = True # Flag to preserve original formatting
async def connect(self):
"""Connect to the MCP server with proper locking."""
async with self._connection_lock:
if not self.client:
transport = await self.create_streams()
read_stream, write_stream = transport
self.client = ClientSession(read_stream=read_stream, write_stream=write_stream)
await self.client.initialize()
def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]:
return sse_client(
url=self.params["url"],
headers=self.params.get("headers", None),
timeout=self.params.get("timeout", 5),
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
)
def _parse_input(self, payload: Any) -> dict:
"""Parse input while preserving original format."""
if isinstance(payload, dict):
return payload
if isinstance(payload, str):
async def mcp_flow(
params: MCPServerSseParams,
function_call: Dict[str, Any] | str,
) -> Any:
async with MCPServerSse(params) as server:
return await call_tool_fast(server, function_call)
async def _call_one_server(
params: MCPServerSseParams, payload: Dict[str, Any] | str
) -> Any:
server = MCPServerSse(params)
try:
# Try to parse as JSON
import json
return json.loads(payload)
except json.JSONDecodeError:
# Check if it's a math operation
import re
# Pattern matching for basic math operations
add_pattern = r"(?i)(?:what\s+is\s+)?(\d+)\s*(?:plus|\+)\s*(\d+)"
mult_pattern = r"(?i)(?:multiply|times|\*)\s*(\d+)\s*(?:and|by)?\s*(\d+)"
div_pattern = r"(?i)(?:divide)\s*(\d+)\s*(?:by)\s*(\d+)"
# Check for addition
if match := re.search(add_pattern, payload):
a, b = map(int, match.groups())
return {"tool_name": "add", "a": a, "b": b}
# Check for multiplication
if match := re.search(mult_pattern, payload):
a, b = map(int, match.groups())
return {"tool_name": "multiply", "a": a, "b": b}
# Check for division
if match := re.search(div_pattern, payload):
a, b = map(int, match.groups())
return {"tool_name": "divide", "a": a, "b": b}
# Default to text input if no pattern matches
return {"text": payload}
return {"text": str(payload)}
def _format_output(self, result: Any, original_input: Any) -> str:
"""Format output based on input type and result."""
if not self.preserve_format:
return str(result)
try:
if isinstance(result, (int, float)):
# For numeric results, format based on operation
if isinstance(original_input, dict):
tool_name = original_input.get("tool_name", "")
if tool_name == "add":
return f"{original_input['a']} + {original_input['b']} = {result}"
elif tool_name == "multiply":
return f"{original_input['a']} * {original_input['b']} = {result}"
elif tool_name == "divide":
return f"{original_input['a']} / {original_input['b']} = {result}"
return str(result)
elif isinstance(result, dict):
return json.dumps(result, indent=2)
else:
return str(result)
except Exception as e:
logger.error(f"Error formatting output: {e}")
return str(result)
async def call_tool(self, payload: Any) -> Any:
"""Call a tool on the MCP server with support for various input formats."""
if not self.client:
raise RuntimeError("Not connected to MCP server")
# Store original input for formatting
original_input = payload
# Parse input
parsed_payload = self._parse_input(payload)
# Add message to history
self.messages.append({
"role": "user",
"content": str(payload),
"parsed": parsed_payload
})
try:
result = await self.client.call_tool(parsed_payload)
formatted_result = self._format_output(result, original_input)
self.messages.append({
"role": "assistant",
"content": formatted_result,
"raw_result": result
})
return formatted_result
except Exception as e:
error_msg = f"Error calling tool: {str(e)}"
self.messages.append({
"role": "error",
"content": error_msg,
"original_input": payload
})
raise
async def cleanup(self):
"""Clean up the connection with proper locking."""
async with self._connection_lock:
if self.client:
await self.client.close()
self.client = None
async def list_tools(self) -> list[Any]:
"""List available tools with proper error handling."""
if not self.client:
raise RuntimeError("Not connected to MCP server")
try:
return await self.client.list_tools()
except Exception as e:
logger.error(f"Error listing tools: {e}")
return []
async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any] | str):
"""
Convenience wrapper that opens calls closes in one shot with proper error handling.
"""
try:
await server.connect()
result = await server.call_tool(payload)
return result.model_dump() if hasattr(result, "model_dump") else result
except Exception as e:
logger.error(f"Error in call_tool_fast: {e}")
raise
finally:
await server.cleanup()
async def mcp_flow_get_tool_schema(
params: MCPServerSseParams,
) -> Any:
"""Get tool schema with proper error handling."""
try:
async with MCPServerSse(params) as server:
tools = await server.list_tools()
return tools.model_dump() if hasattr(tools, "model_dump") else tools
except Exception as e:
logger.error(f"Error getting tool schema: {e}")
raise
async def mcp_flow(
params: MCPServerSseParams,
function_call: dict[str, Any] | str,
) -> Any:
"""Execute MCP flow with proper error handling."""
try:
async with MCPServerSse(params) as server:
return await call_tool_fast(server, function_call)
except Exception as e:
logger.error(f"Error in MCP flow: {e}")
raise
async def _call_one_server(param: MCPServerSseParams, payload: dict[str, Any] | str) -> Any:
"""Make a call to a single MCP server with proper async context management."""
try:
server = MCPServerSse(param)
await server.connect()
result = await server.call_tool(payload)
return result
except Exception as e:
logger.error(f"Error calling server: {e}")
raise
finally:
if 'server' in locals():
await server.cleanup()
def batch_mcp_flow(params: List[MCPServerSseParams], payload: dict[str, Any] | str) -> List[Any]:
"""Blocking helper that fans out to all MCP servers in params."""
try:
return asyncio.run(_batch(params, payload))
except Exception as e:
logger.error(f"Error in batch_mcp_flow: {e}")
return []
async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any] | str) -> List[Any]:
"""Fan out to all MCP servers asynchronously and gather results."""
try:
coros = [_call_one_server(p, payload) for p in params]
results = await asyncio.gather(*coros, return_exceptions=True)
# Filter out exceptions and convert to strings
return [any_to_str(r) for r in results if not isinstance(r, Exception)]
except Exception as e:
logger.error(f"Error in batch processing: {e}")
return []
await server.connect()
return await server.call_tool(arguments=payload if isinstance(payload, dict) else None)
finally:
await server.cleanup()
def batch_mcp_flow(
params: List[MCPServerSseParams], payload: Dict[str, Any] | str
) -> List[Any]:
return asyncio.run(
asyncio.gather(*[_call_one_server(p, payload) for p in params])
)

Loading…
Cancel
Save