parent
925709de6e
commit
d75bbed8ee
@ -0,0 +1,392 @@
|
||||
this si the oreginal firle siese for integration base dont his update it not for sepcief case the mcop_inteation is used it inteage in aget firle"from __future__ import annotations
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
|
||||
from loguru import logger
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from anyio.streams.memory import (
|
||||
MemoryObjectReceiveStream,
|
||||
MemoryObjectSendStream,
|
||||
)
|
||||
from mcp import (
|
||||
ClientSession,
|
||||
StdioServerParameters,
|
||||
Tool as MCPTool,
|
||||
stdio_client,
|
||||
)
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.types import CallToolResult, JSONRPCMessage
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from swarms.utils.any_to_str import any_to_str
|
||||
|
||||
|
||||
class MCPServer(abc.ABC):
|
||||
"""Base class for Model Context Protocol servers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def connect(self):
|
||||
"""Connect to the server. For example, this might mean spawning a subprocess or
|
||||
opening a network connection. The server is expected to remain connected until
|
||||
`cleanup()` is called.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""A readable name for the server."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def cleanup(self):
|
||||
"""Cleanup the server. For example, this might mean closing a subprocess or
|
||||
closing a network connection.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
"""List the tools available on the server."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def call_tool(
|
||||
self, tool_name: str, arguments: dict[str, Any] | None
|
||||
) -> CallToolResult:
|
||||
"""Invoke a tool on the server."""
|
||||
pass
|
||||
|
||||
|
||||
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
||||
|
||||
def __init__(self, cache_tools_list: bool):
|
||||
"""
|
||||
Args:
|
||||
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
||||
cached and only fetched from the server once. If `False`, the tools list will be
|
||||
fetched from the server on each call to `list_tools()`. The cache can be invalidated
|
||||
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
|
||||
server will not change its tools list, because it can drastically improve latency
|
||||
(by avoiding a round-trip to the server every time).
|
||||
"""
|
||||
self.session: ClientSession | None = None
|
||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
|
||||
self.cache_tools_list = cache_tools_list
|
||||
|
||||
# The cache is always dirty at startup, so that we fetch tools at least once
|
||||
self._cache_dirty = True
|
||||
self._tools_list: list[MCPTool] | None = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_streams(
|
||||
self,
|
||||
) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await self.cleanup()
|
||||
|
||||
def invalidate_tools_cache(self):
|
||||
"""Invalidate the tools cache."""
|
||||
self._cache_dirty = True
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the server."""
|
||||
try:
|
||||
transport = await self.exit_stack.enter_async_context(
|
||||
self.create_streams()
|
||||
)
|
||||
read, write = transport
|
||||
session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(read, write)
|
||||
)
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing MCP server: {e}")
|
||||
await self.cleanup()
|
||||
raise
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
"""List the tools available on the server."""
|
||||
if not self.session:
|
||||
raise Exception(
|
||||
"Server not initialized. Make sure you call `connect()` first."
|
||||
)
|
||||
|
||||
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
|
||||
if (
|
||||
self.cache_tools_list
|
||||
and not self._cache_dirty
|
||||
and self._tools_list
|
||||
):
|
||||
return self._tools_list
|
||||
|
||||
# Reset the cache dirty to False
|
||||
self._cache_dirty = False
|
||||
|
||||
# Fetch the tools from the server
|
||||
self._tools_list = (await self.session.list_tools()).tools
|
||||
return self._tools_list
|
||||
|
||||
async def call_tool(
|
||||
self, arguments: dict[str, Any] | None
|
||||
) -> CallToolResult:
|
||||
"""Invoke a tool on the server."""
|
||||
tool_name = arguments.get("tool_name") or arguments.get(
|
||||
"name"
|
||||
)
|
||||
|
||||
if not tool_name:
|
||||
raise Exception("No tool name found in arguments")
|
||||
|
||||
if not self.session:
|
||||
raise Exception(
|
||||
"Server not initialized. Make sure you call `connect()` first."
|
||||
)
|
||||
|
||||
return await self.session.call_tool(tool_name, arguments)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup the server."""
|
||||
async with self._cleanup_lock:
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
self.session = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up server: {e}")
|
||||
|
||||
|
||||
class MCPServerStdioParams(TypedDict):
|
||||
"""Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
|
||||
import.
|
||||
"""
|
||||
|
||||
command: str
|
||||
"""The executable to run to start the server. For example, `python` or `node`."""
|
||||
|
||||
args: NotRequired[list[str]]
|
||||
"""Command line args to pass to the `command` executable. For example, `['foo.py']` or
|
||||
`['server.js', '--port', '8080']`."""
|
||||
|
||||
env: NotRequired[dict[str, str]]
|
||||
"""The environment variables to set for the server. ."""
|
||||
|
||||
cwd: NotRequired[str | Path]
|
||||
"""The working directory to use when spawning the process."""
|
||||
|
||||
encoding: NotRequired[str]
|
||||
"""The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
|
||||
|
||||
encoding_error_handler: NotRequired[
|
||||
Literal["strict", "ignore", "replace"]
|
||||
]
|
||||
"""The text encoding error handler. Defaults to `strict`.
|
||||
|
||||
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
|
||||
explanations of possible values.
|
||||
"""
|
||||
|
||||
|
||||
class MCPServerStdio(_MCPServerWithClientSession):
|
||||
"""MCP server implementation that uses the stdio transport. See the [spec]
|
||||
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
|
||||
details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: MCPServerStdioParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""Create a new MCP server based on the stdio transport.
|
||||
|
||||
Args:
|
||||
params: The params that configure the server. This includes the command to run to
|
||||
start the server, the args to pass to the command, the environment variables to
|
||||
set for the server, the working directory to use when spawning the process, and
|
||||
the text encoding used when sending/receiving messages to the server.
|
||||
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
||||
cached and only fetched from the server once. If `False`, the tools list will be
|
||||
fetched from the server on each call to `list_tools()`. The cache can be
|
||||
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
|
||||
if you know the server will not change its tools list, because it can drastically
|
||||
improve latency (by avoiding a round-trip to the server every time).
|
||||
name: A readable name for the server. If not provided, we'll create one from the
|
||||
command.
|
||||
"""
|
||||
super().__init__(cache_tools_list)
|
||||
|
||||
self.params = StdioServerParameters(
|
||||
command=params["command"],
|
||||
args=params.get("args", []),
|
||||
env=params.get("env"),
|
||||
cwd=params.get("cwd"),
|
||||
encoding=params.get("encoding", "utf-8"),
|
||||
encoding_error_handler=params.get(
|
||||
"encoding_error_handler", "strict"
|
||||
),
|
||||
)
|
||||
|
||||
self._name = name or f"stdio: {self.params.command}"
|
||||
|
||||
def create_streams(
|
||||
self,
|
||||
) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
return stdio_client(self.params)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""A readable name for the server."""
|
||||
return self._name
|
||||
|
||||
|
||||
class MCPServerSseParams(TypedDict):
|
||||
"""Mirrors the params in`mcp.client.sse.sse_client`."""
|
||||
|
||||
url: str
|
||||
"""The URL of the server."""
|
||||
|
||||
headers: NotRequired[dict[str, str]]
|
||||
"""The headers to send to the server."""
|
||||
|
||||
timeout: NotRequired[float]
|
||||
"""The timeout for the HTTP request. Defaults to 5 seconds."""
|
||||
|
||||
sse_read_timeout: NotRequired[float]
|
||||
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
|
||||
|
||||
|
||||
class MCPServerSse(_MCPServerWithClientSession):
|
||||
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
|
||||
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
|
||||
for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: MCPServerSseParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""Create a new MCP server based on the HTTP with SSE transport.
|
||||
|
||||
Args:
|
||||
params: The params that configure the server. This includes the URL of the server,
|
||||
the headers to send to the server, the timeout for the HTTP request, and the
|
||||
timeout for the SSE connection.
|
||||
|
||||
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
||||
cached and only fetched from the server once. If `False`, the tools list will be
|
||||
fetched from the server on each call to `list_tools()`. The cache can be
|
||||
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
|
||||
if you know the server will not change its tools list, because it can drastically
|
||||
improve latency (by avoiding a round-trip to the server every time).
|
||||
|
||||
name: A readable name for the server. If not provided, we'll create one from the
|
||||
URL.
|
||||
"""
|
||||
super().__init__(cache_tools_list)
|
||||
|
||||
self.params = params
|
||||
self._name = name or f"sse: {self.params['url']}"
|
||||
|
||||
def create_streams(
|
||||
self,
|
||||
) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
return sse_client(
|
||||
url=self.params["url"],
|
||||
headers=self.params.get("headers", None),
|
||||
timeout=self.params.get("timeout", 5),
|
||||
sse_read_timeout=self.params.get(
|
||||
"sse_read_timeout", 60 * 5
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""A readable name for the server."""
|
||||
return self._name
|
||||
|
||||
|
||||
def mcp_flow_get_tool_schema(
|
||||
params: MCPServerSseParams,
|
||||
) -> MCPServer:
|
||||
server = MCPServerSse(params, cache_tools_list=True)
|
||||
|
||||
# Connect the server
|
||||
asyncio.run(server.connect())
|
||||
|
||||
# Return the server
|
||||
output = asyncio.run(server.list_tools())
|
||||
|
||||
# Cleanup the server
|
||||
asyncio.run(server.cleanup())
|
||||
|
||||
return output.model_dump()
|
||||
|
||||
|
||||
def mcp_flow(
|
||||
params: MCPServerSseParams,
|
||||
function_call: dict[str, Any],
|
||||
) -> MCPServer:
|
||||
server = MCPServerSse(params, cache_tools_list=True)
|
||||
|
||||
# Connect the server
|
||||
asyncio.run(server.connect())
|
||||
|
||||
# Return the server
|
||||
output = asyncio.run(server.call_tool(function_call))
|
||||
|
||||
output = output.model_dump()
|
||||
|
||||
# Cleanup the server
|
||||
asyncio.run(server.cleanup())
|
||||
|
||||
return any_to_str(output)
|
||||
|
||||
|
||||
def batch_mcp_flow(
|
||||
params: List[MCPServerSseParams],
|
||||
function_call: List[dict[str, Any]] = [],
|
||||
) -> MCPServer:
|
||||
output_list = []
|
||||
|
||||
for param in params:
|
||||
output = mcp_flow(param, function_call)
|
||||
output_list.append(output)
|
||||
|
||||
return output_list"
|
@ -0,0 +1,392 @@
|
||||
this si the oreginal firle siese for integration base dont his update it not for sepcief case the mcop_inteation is used it inteage in aget firle"from __future__ import annotations
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
|
||||
from loguru import logger
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from anyio.streams.memory import (
|
||||
MemoryObjectReceiveStream,
|
||||
MemoryObjectSendStream,
|
||||
)
|
||||
from mcp import (
|
||||
ClientSession,
|
||||
StdioServerParameters,
|
||||
Tool as MCPTool,
|
||||
stdio_client,
|
||||
)
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.types import CallToolResult, JSONRPCMessage
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from swarms.utils.any_to_str import any_to_str
|
||||
|
||||
|
||||
class MCPServer(abc.ABC):
|
||||
"""Base class for Model Context Protocol servers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def connect(self):
|
||||
"""Connect to the server. For example, this might mean spawning a subprocess or
|
||||
opening a network connection. The server is expected to remain connected until
|
||||
`cleanup()` is called.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""A readable name for the server."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def cleanup(self):
|
||||
"""Cleanup the server. For example, this might mean closing a subprocess or
|
||||
closing a network connection.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
"""List the tools available on the server."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def call_tool(
|
||||
self, tool_name: str, arguments: dict[str, Any] | None
|
||||
) -> CallToolResult:
|
||||
"""Invoke a tool on the server."""
|
||||
pass
|
||||
|
||||
|
||||
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
||||
|
||||
def __init__(self, cache_tools_list: bool):
|
||||
"""
|
||||
Args:
|
||||
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
||||
cached and only fetched from the server once. If `False`, the tools list will be
|
||||
fetched from the server on each call to `list_tools()`. The cache can be invalidated
|
||||
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
|
||||
server will not change its tools list, because it can drastically improve latency
|
||||
(by avoiding a round-trip to the server every time).
|
||||
"""
|
||||
self.session: ClientSession | None = None
|
||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
|
||||
self.cache_tools_list = cache_tools_list
|
||||
|
||||
# The cache is always dirty at startup, so that we fetch tools at least once
|
||||
self._cache_dirty = True
|
||||
self._tools_list: list[MCPTool] | None = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_streams(
|
||||
self,
|
||||
) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await self.cleanup()
|
||||
|
||||
def invalidate_tools_cache(self):
|
||||
"""Invalidate the tools cache."""
|
||||
self._cache_dirty = True
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the server."""
|
||||
try:
|
||||
transport = await self.exit_stack.enter_async_context(
|
||||
self.create_streams()
|
||||
)
|
||||
read, write = transport
|
||||
session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(read, write)
|
||||
)
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing MCP server: {e}")
|
||||
await self.cleanup()
|
||||
raise
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
"""List the tools available on the server."""
|
||||
if not self.session:
|
||||
raise Exception(
|
||||
"Server not initialized. Make sure you call `connect()` first."
|
||||
)
|
||||
|
||||
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
|
||||
if (
|
||||
self.cache_tools_list
|
||||
and not self._cache_dirty
|
||||
and self._tools_list
|
||||
):
|
||||
return self._tools_list
|
||||
|
||||
# Reset the cache dirty to False
|
||||
self._cache_dirty = False
|
||||
|
||||
# Fetch the tools from the server
|
||||
self._tools_list = (await self.session.list_tools()).tools
|
||||
return self._tools_list
|
||||
|
||||
async def call_tool(
|
||||
self, arguments: dict[str, Any] | None
|
||||
) -> CallToolResult:
|
||||
"""Invoke a tool on the server."""
|
||||
tool_name = arguments.get("tool_name") or arguments.get(
|
||||
"name"
|
||||
)
|
||||
|
||||
if not tool_name:
|
||||
raise Exception("No tool name found in arguments")
|
||||
|
||||
if not self.session:
|
||||
raise Exception(
|
||||
"Server not initialized. Make sure you call `connect()` first."
|
||||
)
|
||||
|
||||
return await self.session.call_tool(tool_name, arguments)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup the server."""
|
||||
async with self._cleanup_lock:
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
self.session = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up server: {e}")
|
||||
|
||||
|
||||
class MCPServerStdioParams(TypedDict):
|
||||
"""Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
|
||||
import.
|
||||
"""
|
||||
|
||||
command: str
|
||||
"""The executable to run to start the server. For example, `python` or `node`."""
|
||||
|
||||
args: NotRequired[list[str]]
|
||||
"""Command line args to pass to the `command` executable. For example, `['foo.py']` or
|
||||
`['server.js', '--port', '8080']`."""
|
||||
|
||||
env: NotRequired[dict[str, str]]
|
||||
"""The environment variables to set for the server. ."""
|
||||
|
||||
cwd: NotRequired[str | Path]
|
||||
"""The working directory to use when spawning the process."""
|
||||
|
||||
encoding: NotRequired[str]
|
||||
"""The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
|
||||
|
||||
encoding_error_handler: NotRequired[
|
||||
Literal["strict", "ignore", "replace"]
|
||||
]
|
||||
"""The text encoding error handler. Defaults to `strict`.
|
||||
|
||||
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
|
||||
explanations of possible values.
|
||||
"""
|
||||
|
||||
|
||||
class MCPServerStdio(_MCPServerWithClientSession):
|
||||
"""MCP server implementation that uses the stdio transport. See the [spec]
|
||||
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
|
||||
details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: MCPServerStdioParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""Create a new MCP server based on the stdio transport.
|
||||
|
||||
Args:
|
||||
params: The params that configure the server. This includes the command to run to
|
||||
start the server, the args to pass to the command, the environment variables to
|
||||
set for the server, the working directory to use when spawning the process, and
|
||||
the text encoding used when sending/receiving messages to the server.
|
||||
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
||||
cached and only fetched from the server once. If `False`, the tools list will be
|
||||
fetched from the server on each call to `list_tools()`. The cache can be
|
||||
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
|
||||
if you know the server will not change its tools list, because it can drastically
|
||||
improve latency (by avoiding a round-trip to the server every time).
|
||||
name: A readable name for the server. If not provided, we'll create one from the
|
||||
command.
|
||||
"""
|
||||
super().__init__(cache_tools_list)
|
||||
|
||||
self.params = StdioServerParameters(
|
||||
command=params["command"],
|
||||
args=params.get("args", []),
|
||||
env=params.get("env"),
|
||||
cwd=params.get("cwd"),
|
||||
encoding=params.get("encoding", "utf-8"),
|
||||
encoding_error_handler=params.get(
|
||||
"encoding_error_handler", "strict"
|
||||
),
|
||||
)
|
||||
|
||||
self._name = name or f"stdio: {self.params.command}"
|
||||
|
||||
def create_streams(
|
||||
self,
|
||||
) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
return stdio_client(self.params)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""A readable name for the server."""
|
||||
return self._name
|
||||
|
||||
|
||||
class MCPServerSseParams(TypedDict):
|
||||
"""Mirrors the params in`mcp.client.sse.sse_client`."""
|
||||
|
||||
url: str
|
||||
"""The URL of the server."""
|
||||
|
||||
headers: NotRequired[dict[str, str]]
|
||||
"""The headers to send to the server."""
|
||||
|
||||
timeout: NotRequired[float]
|
||||
"""The timeout for the HTTP request. Defaults to 5 seconds."""
|
||||
|
||||
sse_read_timeout: NotRequired[float]
|
||||
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
|
||||
|
||||
|
||||
class MCPServerSse(_MCPServerWithClientSession):
|
||||
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
|
||||
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
|
||||
for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: MCPServerSseParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""Create a new MCP server based on the HTTP with SSE transport.
|
||||
|
||||
Args:
|
||||
params: The params that configure the server. This includes the URL of the server,
|
||||
the headers to send to the server, the timeout for the HTTP request, and the
|
||||
timeout for the SSE connection.
|
||||
|
||||
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
||||
cached and only fetched from the server once. If `False`, the tools list will be
|
||||
fetched from the server on each call to `list_tools()`. The cache can be
|
||||
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
|
||||
if you know the server will not change its tools list, because it can drastically
|
||||
improve latency (by avoiding a round-trip to the server every time).
|
||||
|
||||
name: A readable name for the server. If not provided, we'll create one from the
|
||||
URL.
|
||||
"""
|
||||
super().__init__(cache_tools_list)
|
||||
|
||||
self.params = params
|
||||
self._name = name or f"sse: {self.params['url']}"
|
||||
|
||||
def create_streams(
|
||||
self,
|
||||
) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
"""Create the streams for the server."""
|
||||
return sse_client(
|
||||
url=self.params["url"],
|
||||
headers=self.params.get("headers", None),
|
||||
timeout=self.params.get("timeout", 5),
|
||||
sse_read_timeout=self.params.get(
|
||||
"sse_read_timeout", 60 * 5
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""A readable name for the server."""
|
||||
return self._name
|
||||
|
||||
|
||||
def mcp_flow_get_tool_schema(
|
||||
params: MCPServerSseParams,
|
||||
) -> MCPServer:
|
||||
server = MCPServerSse(params, cache_tools_list=True)
|
||||
|
||||
# Connect the server
|
||||
asyncio.run(server.connect())
|
||||
|
||||
# Return the server
|
||||
output = asyncio.run(server.list_tools())
|
||||
|
||||
# Cleanup the server
|
||||
asyncio.run(server.cleanup())
|
||||
|
||||
return output.model_dump()
|
||||
|
||||
|
||||
def mcp_flow(
|
||||
params: MCPServerSseParams,
|
||||
function_call: dict[str, Any],
|
||||
) -> MCPServer:
|
||||
server = MCPServerSse(params, cache_tools_list=True)
|
||||
|
||||
# Connect the server
|
||||
asyncio.run(server.connect())
|
||||
|
||||
# Return the server
|
||||
output = asyncio.run(server.call_tool(function_call))
|
||||
|
||||
output = output.model_dump()
|
||||
|
||||
# Cleanup the server
|
||||
asyncio.run(server.cleanup())
|
||||
|
||||
return any_to_str(output)
|
||||
|
||||
|
||||
def batch_mcp_flow(
|
||||
params: List[MCPServerSseParams],
|
||||
function_call: List[dict[str, Any]] = [],
|
||||
) -> MCPServer:
|
||||
output_list = []
|
||||
|
||||
for param in params:
|
||||
output = mcp_flow(param, function_call)
|
||||
output_list.append(output)
|
||||
|
||||
return output_list"
|
Loading…
Reference in new issue