parent
9cb2500e58
commit
89fc8c7609
@ -0,0 +1,554 @@
|
||||
from contextlib import AsyncExitStack
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
TypedDict,
|
||||
cast,
|
||||
)
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.types import (
|
||||
CallToolResult,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
PromptMessage,
|
||||
TextContent,
|
||||
)
|
||||
from mcp.types import (
|
||||
Tool as MCPTool,
|
||||
)
|
||||
|
||||
|
||||
def convert_mcp_prompt_message_to_message(
|
||||
message: PromptMessage,
|
||||
) -> str:
|
||||
"""Convert an MCP prompt message to a string message.
|
||||
|
||||
Args:
|
||||
message: MCP prompt message to convert
|
||||
|
||||
Returns:
|
||||
a string message
|
||||
"""
|
||||
if message.content.type == "text":
|
||||
if message.role == "user":
|
||||
return str(message.content.text)
|
||||
elif message.role == "assistant":
|
||||
return str(
|
||||
message.content.text
|
||||
) # Fixed attribute name from str to text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported prompt message role: {message.role}"
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Unsupported prompt message content type: {message.content.type}"
|
||||
)
|
||||
|
||||
|
||||
async def load_mcp_prompt(
|
||||
session: ClientSession,
|
||||
name: str,
|
||||
arguments: Optional[dict[str, Any]] = None,
|
||||
) -> List[str]:
|
||||
"""Load MCP prompt and convert to messages."""
|
||||
response = await session.get_prompt(name, arguments)
|
||||
|
||||
return [
|
||||
convert_mcp_prompt_message_to_message(message)
|
||||
for message in response.messages
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_ENCODING = "utf-8"
|
||||
DEFAULT_ENCODING_ERROR_HANDLER = "strict"
|
||||
|
||||
DEFAULT_HTTP_TIMEOUT = 5
|
||||
DEFAULT_SSE_READ_TIMEOUT = 60 * 5
|
||||
|
||||
|
||||
class StdioConnection(TypedDict):
|
||||
transport: Literal["stdio"]
|
||||
|
||||
command: str
|
||||
"""The executable to run to start the server."""
|
||||
|
||||
args: list[str]
|
||||
"""Command line arguments to pass to the executable."""
|
||||
|
||||
env: dict[str, str] | None
|
||||
"""The environment to use when spawning the process."""
|
||||
|
||||
encoding: str
|
||||
"""The text encoding used when sending/receiving messages to the server."""
|
||||
|
||||
encoding_error_handler: Literal["strict", "ignore", "replace"]
|
||||
"""
|
||||
The text encoding error handler.
|
||||
|
||||
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
|
||||
explanations of possible values
|
||||
"""
|
||||
|
||||
|
||||
class SSEConnection(TypedDict):
|
||||
transport: Literal["sse"]
|
||||
|
||||
url: str
|
||||
"""The URL of the SSE endpoint to connect to."""
|
||||
|
||||
headers: dict[str, Any] | None
|
||||
"""HTTP headers to send to the SSE endpoint"""
|
||||
|
||||
timeout: float
|
||||
"""HTTP timeout"""
|
||||
|
||||
sse_read_timeout: float
|
||||
"""SSE read timeout"""
|
||||
|
||||
|
||||
NonTextContent = ImageContent | EmbeddedResource
|
||||
|
||||
|
||||
def _convert_call_tool_result(
|
||||
call_tool_result: CallToolResult,
|
||||
) -> tuple[str | list[str], list[NonTextContent] | None]:
|
||||
text_contents: list[TextContent] = []
|
||||
non_text_contents = []
|
||||
for content in call_tool_result.content:
|
||||
if isinstance(content, TextContent):
|
||||
text_contents.append(content)
|
||||
else:
|
||||
non_text_contents.append(content)
|
||||
|
||||
tool_content: str | list[str] = [
|
||||
content.text for content in text_contents
|
||||
]
|
||||
if len(text_contents) == 1:
|
||||
tool_content = tool_content[0]
|
||||
|
||||
if call_tool_result.isError:
|
||||
raise ValueError("Error calling tool")
|
||||
|
||||
return tool_content, non_text_contents or None
|
||||
|
||||
|
||||
def convert_mcp_tool_to_function(
|
||||
session: ClientSession,
|
||||
tool: MCPTool,
|
||||
) -> Callable[
|
||||
...,
|
||||
Coroutine[
|
||||
Any, Any, tuple[str | list[str], list[NonTextContent] | None]
|
||||
],
|
||||
]:
|
||||
"""Convert an MCP tool to a callable function.
|
||||
|
||||
NOTE: this tool can be executed only in a context of an active MCP client session.
|
||||
|
||||
Args:
|
||||
session: MCP client session
|
||||
tool: MCP tool to convert
|
||||
|
||||
Returns:
|
||||
a callable function
|
||||
"""
|
||||
|
||||
async def call_tool(
|
||||
**arguments: dict[str, Any],
|
||||
) -> tuple[str | list[str], list[NonTextContent] | None]:
|
||||
"""Execute the tool with the given arguments."""
|
||||
call_tool_result = await session.call_tool(
|
||||
tool.name, arguments
|
||||
)
|
||||
return _convert_call_tool_result(call_tool_result)
|
||||
|
||||
# Add metadata as attributes to the function
|
||||
call_tool.__name__ = tool.name
|
||||
call_tool.__doc__ = tool.description or ""
|
||||
call_tool.schema = tool.inputSchema
|
||||
|
||||
return call_tool
|
||||
|
||||
|
||||
async def load_mcp_tools(session: ClientSession) -> list[Callable]:
|
||||
"""Load all available MCP tools and convert them to callable functions."""
|
||||
tools = await session.list_tools()
|
||||
return [
|
||||
convert_mcp_tool_to_function(session, tool)
|
||||
for tool in tools.tools
|
||||
]
|
||||
|
||||
|
||||
class MultiServerMCPClient:
|
||||
"""Client for connecting to multiple MCP servers and loading tools from them."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connections: dict[
|
||||
str, StdioConnection | SSEConnection
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Initialize a MultiServerMCPClient with MCP servers connections.
|
||||
|
||||
Args:
|
||||
connections: A dictionary mapping server names to connection configurations.
|
||||
Each configuration can be either a StdioConnection or SSEConnection.
|
||||
If None, no initial connections are established.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
async with MultiServerMCPClient(
|
||||
{
|
||||
"math": {
|
||||
"command": "python",
|
||||
# Make sure to update to the full absolute path to your math_server.py file
|
||||
"args": ["/path/to/math_server.py"],
|
||||
"transport": "stdio",
|
||||
},
|
||||
"weather": {
|
||||
# make sure you start your weather server on port 8000
|
||||
"url": "http://localhost:8000/sse",
|
||||
"transport": "sse",
|
||||
}
|
||||
}
|
||||
) as client:
|
||||
all_tools = client.get_tools()
|
||||
...
|
||||
```
|
||||
"""
|
||||
self.connections = connections
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.sessions: dict[str, ClientSession] = {}
|
||||
self.server_name_to_tools: dict[str, list[Callable]] = {}
|
||||
|
||||
async def _initialize_session_and_load_tools(
|
||||
self, server_name: str, session: ClientSession
|
||||
) -> None:
|
||||
"""Initialize a session and load tools from it.
|
||||
|
||||
Args:
|
||||
server_name: Name to identify this server connection
|
||||
session: The ClientSession to initialize
|
||||
"""
|
||||
# Initialize the session
|
||||
await session.initialize()
|
||||
self.sessions[server_name] = session
|
||||
|
||||
# Load tools from this server
|
||||
server_tools = await load_mcp_tools(session)
|
||||
self.server_name_to_tools[server_name] = server_tools
|
||||
|
||||
async def connect_to_server(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
transport: Literal["stdio", "sse"] = "stdio",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Connect to an MCP server using either stdio or SSE.
|
||||
|
||||
This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse
|
||||
based on the provided transport parameter.
|
||||
|
||||
Args:
|
||||
server_name: Name to identify this server connection
|
||||
transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
|
||||
**kwargs: Additional arguments to pass to the specific connection method
|
||||
|
||||
Raises:
|
||||
ValueError: If transport is not recognized
|
||||
ValueError: If required parameters for the specified transport are missing
|
||||
"""
|
||||
if transport == "sse":
|
||||
if "url" not in kwargs:
|
||||
raise ValueError(
|
||||
"'url' parameter is required for SSE connection"
|
||||
)
|
||||
await self.connect_to_server_via_sse(
|
||||
server_name,
|
||||
url=kwargs["url"],
|
||||
headers=kwargs.get("headers"),
|
||||
timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT),
|
||||
sse_read_timeout=kwargs.get(
|
||||
"sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT
|
||||
),
|
||||
)
|
||||
elif transport == "stdio":
|
||||
if "command" not in kwargs:
|
||||
raise ValueError(
|
||||
"'command' parameter is required for stdio connection"
|
||||
)
|
||||
if "args" not in kwargs:
|
||||
raise ValueError(
|
||||
"'args' parameter is required for stdio connection"
|
||||
)
|
||||
await self.connect_to_server_via_stdio(
|
||||
server_name,
|
||||
command=kwargs["command"],
|
||||
args=kwargs["args"],
|
||||
env=kwargs.get("env"),
|
||||
encoding=kwargs.get("encoding", DEFAULT_ENCODING),
|
||||
encoding_error_handler=kwargs.get(
|
||||
"encoding_error_handler",
|
||||
DEFAULT_ENCODING_ERROR_HANDLER,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
|
||||
)
|
||||
|
||||
async def connect_to_server_via_stdio(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
command: str,
|
||||
args: list[str],
|
||||
env: dict[str, str] | None = None,
|
||||
encoding: str = DEFAULT_ENCODING,
|
||||
encoding_error_handler: Literal[
|
||||
"strict", "ignore", "replace"
|
||||
] = DEFAULT_ENCODING_ERROR_HANDLER,
|
||||
) -> None:
|
||||
"""Connect to a specific MCP server using stdio
|
||||
|
||||
Args:
|
||||
server_name: Name to identify this server connection
|
||||
command: Command to execute
|
||||
args: Arguments for the command
|
||||
env: Environment variables for the command
|
||||
encoding: Character encoding
|
||||
encoding_error_handler: How to handle encoding errors
|
||||
"""
|
||||
server_params = StdioServerParameters(
|
||||
command=command,
|
||||
args=args,
|
||||
env=env,
|
||||
encoding=encoding,
|
||||
encoding_error_handler=encoding_error_handler,
|
||||
)
|
||||
|
||||
# Create and store the connection
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
stdio_client(server_params)
|
||||
)
|
||||
read, write = stdio_transport
|
||||
session = cast(
|
||||
ClientSession,
|
||||
await self.exit_stack.enter_async_context(
|
||||
ClientSession(read, write)
|
||||
),
|
||||
)
|
||||
|
||||
await self._initialize_session_and_load_tools(
|
||||
server_name, session
|
||||
)
|
||||
|
||||
async def connect_to_server_via_sse(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = DEFAULT_HTTP_TIMEOUT,
|
||||
sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT,
|
||||
) -> None:
|
||||
"""Connect to a specific MCP server using SSE
|
||||
|
||||
Args:
|
||||
server_name: Name to identify this server connection
|
||||
url: URL of the SSE server
|
||||
headers: HTTP headers to send to the SSE endpoint
|
||||
timeout: HTTP timeout
|
||||
sse_read_timeout: SSE read timeout
|
||||
"""
|
||||
# Create and store the connection
|
||||
sse_transport = await self.exit_stack.enter_async_context(
|
||||
sse_client(url, headers, timeout, sse_read_timeout)
|
||||
)
|
||||
read, write = sse_transport
|
||||
session = cast(
|
||||
ClientSession,
|
||||
await self.exit_stack.enter_async_context(
|
||||
ClientSession(read, write)
|
||||
),
|
||||
)
|
||||
|
||||
await self._initialize_session_and_load_tools(
|
||||
server_name, session
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[Callable]:
|
||||
"""Get a list of all tools from all connected servers."""
|
||||
all_tools: list[Callable] = []
|
||||
for server_tools in self.server_name_to_tools.values():
|
||||
all_tools.extend(server_tools)
|
||||
return all_tools
|
||||
|
||||
async def get_prompt(
|
||||
self,
|
||||
server_name: str,
|
||||
prompt_name: str,
|
||||
arguments: Optional[dict[str, Any]] = None,
|
||||
) -> List[str]:
|
||||
"""Get a prompt from a given MCP server."""
|
||||
session = self.sessions[server_name]
|
||||
return await load_mcp_prompt(session, prompt_name, arguments)
|
||||
|
||||
async def __aenter__(self) -> "MultiServerMCPClient":
|
||||
try:
|
||||
connections = self.connections or {}
|
||||
for server_name, connection in connections.items():
|
||||
connection_dict = connection.copy()
|
||||
transport = connection_dict.pop("transport")
|
||||
if transport == "stdio":
|
||||
await self.connect_to_server_via_stdio(
|
||||
server_name, **connection_dict
|
||||
)
|
||||
elif transport == "sse":
|
||||
await self.connect_to_server_via_sse(
|
||||
server_name, **connection_dict
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
|
||||
)
|
||||
return self
|
||||
except Exception:
|
||||
await self.exit_stack.aclose()
|
||||
raise
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
await self.exit_stack.aclose()
|
||||
|
||||
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import os
|
||||
import json
|
||||
from typing import List, Any, Callable
|
||||
|
||||
# # Import our MCP client module
|
||||
# from mcp_client import MultiServerMCPClient
|
||||
|
||||
|
||||
async def main():
|
||||
"""Test script for demonstrating MCP client usage."""
|
||||
print("Starting MCP Client test...")
|
||||
|
||||
# Create a connection to multiple MCP servers
|
||||
# You'll need to update these paths to match your setup
|
||||
async with MultiServerMCPClient(
|
||||
{
|
||||
"math": {
|
||||
"transport": "stdio",
|
||||
"command": "python",
|
||||
"args": ["/path/to/math_server.py"],
|
||||
"env": {"DEBUG": "1"},
|
||||
},
|
||||
"search": {
|
||||
"transport": "sse",
|
||||
"url": "http://localhost:8000/sse",
|
||||
"headers": {
|
||||
"Authorization": f"Bearer {os.environ.get('API_KEY', '')}"
|
||||
},
|
||||
},
|
||||
}
|
||||
) as client:
|
||||
# Get all available tools
|
||||
tools = client.get_tools()
|
||||
print(f"Found {len(tools)} tools across all servers")
|
||||
|
||||
# Print tool information
|
||||
for i, tool in enumerate(tools):
|
||||
print(f"\nTool {i+1}: {tool.__name__}")
|
||||
print(f" Description: {tool.__doc__}")
|
||||
if hasattr(tool, "schema") and tool.schema:
|
||||
print(
|
||||
f" Schema: {json.dumps(tool.schema, indent=2)[:100]}..."
|
||||
)
|
||||
|
||||
# Example: Use a specific tool if available
|
||||
calculator_tool = next(
|
||||
(t for t in tools if t.__name__ == "calculator"), None
|
||||
)
|
||||
if calculator_tool:
|
||||
print("\n\nTesting calculator tool:")
|
||||
try:
|
||||
# Call the tool as an async function
|
||||
result, artifacts = await calculator_tool(
|
||||
expression="2 + 2 * 3"
|
||||
)
|
||||
print(f" Calculator result: {result}")
|
||||
if artifacts:
|
||||
print(
|
||||
f" With {len(artifacts)} additional artifacts"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f" Error using calculator: {e}")
|
||||
|
||||
# Example: Load a prompt from a server
|
||||
try:
|
||||
print("\n\nTesting prompt loading:")
|
||||
prompt_messages = await client.get_prompt(
|
||||
"math",
|
||||
"calculation_introduction",
|
||||
{"user_name": "Test User"},
|
||||
)
|
||||
print(
|
||||
f" Loaded prompt with {len(prompt_messages)} messages:"
|
||||
)
|
||||
for i, msg in enumerate(prompt_messages):
|
||||
print(f" Message {i+1}: {msg[:50]}...")
|
||||
except Exception as e:
|
||||
print(f" Error loading prompt: {e}")
|
||||
|
||||
|
||||
async def create_custom_tool():
|
||||
"""Example of creating a custom tool function."""
|
||||
|
||||
# Define a tool function with metadata
|
||||
async def add_numbers(a: float, b: float) -> tuple[str, None]:
|
||||
"""Add two numbers together."""
|
||||
result = a + b
|
||||
return f"The sum of {a} and {b} is {result}", None
|
||||
|
||||
# Add metadata to the function
|
||||
add_numbers.__name__ = "add_numbers"
|
||||
add_numbers.__doc__ = (
|
||||
"Add two numbers together and return the result."
|
||||
)
|
||||
add_numbers.schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number", "description": "First number"},
|
||||
"b": {"type": "number", "description": "Second number"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
}
|
||||
|
||||
# Use the tool
|
||||
result, _ = await add_numbers(a=5, b=7)
|
||||
print(f"\nCustom tool result: {result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run both examples
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(main())
|
||||
loop.run_until_complete(create_custom_tool())
|
Loading…
Reference in new issue