parent
ea66e78154
commit
a612352abd
@ -0,0 +1,83 @@
|
||||
from swarms import Agent
|
||||
from swarms.tools.mcp_integration import MCPServerSseParams
|
||||
from loguru import logger
|
||||
|
||||
# Comprehensive math prompt that encourages proper JSON formatting
|
||||
MATH_AGENT_PROMPT = """
|
||||
You are a helpful math calculator assistant.
|
||||
|
||||
Your role is to understand natural language math requests and perform calculations.
|
||||
When asked to perform calculations:
|
||||
|
||||
1. Determine the operation (add, multiply, or divide)
|
||||
2. Extract the numbers from the request
|
||||
3. Use the appropriate math operation tool
|
||||
|
||||
FORMAT YOUR TOOL CALLS AS JSON with this format:
|
||||
{"tool_name": "add", "a": <first_number>, "b": <second_number>}
|
||||
or
|
||||
{"tool_name": "multiply", "a": <first_number>, "b": <second_number>}
|
||||
or
|
||||
{"tool_name": "divide", "a": <first_number>, "b": <second_number>}
|
||||
|
||||
Always respond with a tool call in JSON format first, followed by a brief explanation.
|
||||
"""
|
||||
|
||||
def initialize_math_system():
|
||||
"""Initialize the math agent with MCP server configuration."""
|
||||
# Configure the MCP server connection
|
||||
math_server = MCPServerSseParams(
|
||||
url="http://0.0.0.0:8000",
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=5.0,
|
||||
sse_read_timeout=30.0
|
||||
)
|
||||
|
||||
# Create the agent with the MCP server configuration
|
||||
math_agent = Agent(
|
||||
agent_name="Math Assistant",
|
||||
agent_description="Friendly math calculator",
|
||||
system_prompt=MATH_AGENT_PROMPT,
|
||||
max_loops=1,
|
||||
mcp_servers=[math_server], # Pass MCP server config as a list
|
||||
model_name="gpt-3.5-turbo",
|
||||
verbose=True # Enable verbose mode to see more details
|
||||
)
|
||||
|
||||
return math_agent
|
||||
|
||||
def main():
|
||||
try:
|
||||
logger.info("Initializing math system...")
|
||||
math_agent = initialize_math_system()
|
||||
|
||||
print("\nMath Calculator Ready!")
|
||||
print("Ask me any math question!")
|
||||
print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'")
|
||||
print("Type 'exit' to quit\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("What would you like to calculate? ").strip()
|
||||
if not query:
|
||||
continue
|
||||
if query.lower() == 'exit':
|
||||
break
|
||||
|
||||
logger.info(f"Processing query: {query}")
|
||||
result = math_agent.run(query)
|
||||
print(f"\nResult: {result}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing query: {e}")
|
||||
print(f"Sorry, there was an error: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"System initialization error: {e}")
|
||||
print(f"Failed to start the math system: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,53 +1,83 @@
|
||||
from swarms import Agent
|
||||
from swarms.tools.mcp_integration import MCPServerSseParams
|
||||
from loguru import logger
|
||||
|
||||
from swarms import Agent
|
||||
from swarms.tools.mcp_integration import MCPServerSseParams
|
||||
from swarms.prompts.agent_prompts import MATH_AGENT_PROMPT
|
||||
from loguru import logger
|
||||
|
||||
def initialize_math_system():
|
||||
"""Initialize the math agent with MCP server configuration."""
|
||||
math_server = MCPServerSseParams(
|
||||
url="http://0.0.0.0:8000",
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=5.0,
|
||||
sse_read_timeout=30.0
|
||||
)
|
||||
|
||||
math_agent = Agent(
|
||||
agent_name="Math Assistant",
|
||||
agent_description="Friendly math calculator",
|
||||
system_prompt=MATH_AGENT_PROMPT,
|
||||
max_loops=1,
|
||||
mcp_servers=[math_server],
|
||||
model_name="gpt-3.5-turbo"
|
||||
)
|
||||
|
||||
return math_agent
|
||||
|
||||
def main():
|
||||
math_agent = initialize_math_system()
|
||||
|
||||
print("\nMath Calculator Ready!")
|
||||
print("Ask me any math question!")
|
||||
print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'")
|
||||
print("Type 'exit' to quit\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("What would you like to calculate? ").strip()
|
||||
if not query:
|
||||
continue
|
||||
if query.lower() == 'exit':
|
||||
break
|
||||
|
||||
result = math_agent.run(query)
|
||||
print(f"\nResult: {result}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# Comprehensive math prompt that encourages proper JSON formatting
|
||||
MATH_AGENT_PROMPT = """
|
||||
You are a helpful math calculator assistant.
|
||||
|
||||
Your role is to understand natural language math requests and perform calculations.
|
||||
When asked to perform calculations:
|
||||
|
||||
1. Determine the operation (add, multiply, or divide)
|
||||
2. Extract the numbers from the request
|
||||
3. Use the appropriate math operation tool
|
||||
|
||||
FORMAT YOUR TOOL CALLS AS JSON with this format:
|
||||
{"tool_name": "add", "a": <first_number>, "b": <second_number>}
|
||||
or
|
||||
{"tool_name": "multiply", "a": <first_number>, "b": <second_number>}
|
||||
or
|
||||
{"tool_name": "divide", "a": <first_number>, "b": <second_number>}
|
||||
|
||||
Always respond with a tool call in JSON format first, followed by a brief explanation.
|
||||
"""
|
||||
|
||||
def initialize_math_system():
|
||||
"""Initialize the math agent with MCP server configuration."""
|
||||
# Configure the MCP server connection
|
||||
math_server = MCPServerSseParams(
|
||||
url="http://0.0.0.0:8000",
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=5.0,
|
||||
sse_read_timeout=30.0
|
||||
)
|
||||
|
||||
# Create the agent with the MCP server configuration
|
||||
math_agent = Agent(
|
||||
agent_name="Math Assistant",
|
||||
agent_description="Friendly math calculator",
|
||||
system_prompt=MATH_AGENT_PROMPT,
|
||||
max_loops=1,
|
||||
mcp_servers=[math_server], # Pass MCP server config as a list
|
||||
model_name="gpt-3.5-turbo",
|
||||
verbose=True # Enable verbose mode to see more details
|
||||
)
|
||||
|
||||
return math_agent
|
||||
|
||||
def main():
|
||||
try:
|
||||
logger.info("Initializing math system...")
|
||||
math_agent = initialize_math_system()
|
||||
|
||||
print("\nMath Calculator Ready!")
|
||||
print("Ask me any math question!")
|
||||
print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'")
|
||||
print("Type 'exit' to quit\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("What would you like to calculate? ").strip()
|
||||
if not query:
|
||||
continue
|
||||
if query.lower() == 'exit':
|
||||
break
|
||||
|
||||
logger.info(f"Processing query: {query}")
|
||||
result = math_agent.run(query)
|
||||
print(f"\nResult: {result}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing query: {e}")
|
||||
print(f"Sorry, there was an error: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"System initialization error: {e}")
|
||||
print(f"Failed to start the math system: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,38 +1,79 @@
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from loguru import logger
|
||||
import time
|
||||
|
||||
# Create the MCP server
|
||||
mcp = FastMCP(host="0.0.0.0",
|
||||
port=8000,
|
||||
transport="sse",
|
||||
require_session_id=False)
|
||||
|
||||
mcp = FastMCP(
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
transport="sse",
|
||||
require_session_id=False
|
||||
)
|
||||
|
||||
# Define tools with proper type hints and docstrings
|
||||
@mcp.tool()
|
||||
def add(a: int, b: int) -> str:
|
||||
"""Add two numbers."""
|
||||
"""Add two numbers.
|
||||
|
||||
Args:
|
||||
a (int): First number
|
||||
b (int): Second number
|
||||
|
||||
Returns:
|
||||
str: A message containing the sum
|
||||
"""
|
||||
logger.info(f"Adding {a} and {b}")
|
||||
result = a + b
|
||||
return f"The sum of {a} and {b} is {result}"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def multiply(a: int, b: int) -> str:
|
||||
"""Multiply two numbers."""
|
||||
"""Multiply two numbers.
|
||||
|
||||
Args:
|
||||
a (int): First number
|
||||
b (int): Second number
|
||||
|
||||
Returns:
|
||||
str: A message containing the product
|
||||
"""
|
||||
logger.info(f"Multiplying {a} and {b}")
|
||||
result = a * b
|
||||
return f"The product of {a} and {b} is {result}"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def divide(a: int, b: int) -> str:
|
||||
"""Divide two numbers."""
|
||||
"""Divide two numbers.
|
||||
|
||||
Args:
|
||||
a (int): Numerator
|
||||
b (int): Denominator
|
||||
|
||||
Returns:
|
||||
str: A message containing the division result or an error message
|
||||
"""
|
||||
logger.info(f"Dividing {a} by {b}")
|
||||
if b == 0:
|
||||
logger.warning("Division by zero attempted")
|
||||
return "Cannot divide by zero"
|
||||
result = a / b
|
||||
return f"{a} divided by {b} is {result}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
logger.info("Starting math server on http://0.0.0.0:8000")
|
||||
print("Math MCP Server is running. Press Ctrl+C to stop.")
|
||||
|
||||
# Add a small delay to ensure logging is complete before the server starts
|
||||
time.sleep(0.5)
|
||||
|
||||
# Run the MCP server
|
||||
mcp.run()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server shutdown requested")
|
||||
print("\nShutting down server...")
|
||||
except Exception as e:
|
||||
logger.error(f"Server error: {e}")
|
||||
raise
|
||||
|
@ -1,255 +1,320 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Literal
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from loguru import logger
|
||||
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.types import CallToolResult, JSONRPCMessage
|
||||
|
||||
from swarms.utils.any_to_str import any_to_str
|
||||
|
||||
|
||||
class MCPServer(abc.ABC):
|
||||
"""Base class for Model Context Protocol servers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection to the MCP server."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Human-readable server name."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up resources and close connection."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_tools(self) -> List[MCPTool]:
|
||||
"""List available MCP tools on the server."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def call_tool(
|
||||
self, tool_name: str, arguments: Dict[str, Any] | None
|
||||
) -> CallToolResult:
|
||||
"""Invoke a tool by name with provided arguments."""
|
||||
pass
|
||||
|
||||
|
||||
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||
"""Mixin providing ClientSession-based MCP communication."""
|
||||
|
||||
def __init__(self, cache_tools_list: bool = False):
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
self.cache_tools_list = cache_tools_list
|
||||
self._cache_dirty = True
|
||||
self._tools_list: Optional[List[MCPTool]] = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_streams(
|
||||
self
|
||||
) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
"""Supply the read/write streams for the MCP transport."""
|
||||
pass
|
||||
|
||||
async def __aenter__(self) -> MCPServer:
|
||||
await self.connect()
|
||||
return self # type: ignore
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, tb) -> None:
|
||||
await self.cleanup()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize transport and ClientSession."""
|
||||
try:
|
||||
transport = await self.exit_stack.enter_async_context(
|
||||
self.create_streams()
|
||||
)
|
||||
read, write = transport
|
||||
session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(read, write)
|
||||
)
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing MCP server: {e}")
|
||||
await self.cleanup()
|
||||
raise
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Close session and transport."""
|
||||
async with self._cleanup_lock:
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}")
|
||||
finally:
|
||||
self.session = None
|
||||
|
||||
async def list_tools(self) -> List[MCPTool]:
|
||||
if not self.session:
|
||||
raise RuntimeError("Server not connected. Call connect() first.")
|
||||
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
|
||||
return self._tools_list
|
||||
self._cache_dirty = False
|
||||
self._tools_list = (await self.session.list_tools()).tools
|
||||
return self._tools_list # type: ignore
|
||||
|
||||
async def call_tool(
|
||||
self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None
|
||||
) -> CallToolResult:
|
||||
if not arguments:
|
||||
raise ValueError("Arguments dict is required to call a tool")
|
||||
name = tool_name or arguments.get("tool_name") or arguments.get("name")
|
||||
if not name:
|
||||
raise ValueError("Tool name missing in arguments")
|
||||
if not self.session:
|
||||
raise RuntimeError("Server not connected. Call connect() first.")
|
||||
return await self.session.call_tool(name, arguments)
|
||||
|
||||
|
||||
class MCPServerStdioParams(TypedDict):
|
||||
"""Configuration for stdio transport."""
|
||||
command: str
|
||||
args: NotRequired[List[str]]
|
||||
env: NotRequired[Dict[str, str]]
|
||||
cwd: NotRequired[str | Path]
|
||||
encoding: NotRequired[str]
|
||||
encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
|
||||
|
||||
|
||||
class MCPServerStdio(_MCPServerWithClientSession):
|
||||
"""MCP server over stdio transport."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: MCPServerStdioParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(cache_tools_list)
|
||||
self.params = StdioServerParameters(
|
||||
command=params["command"],
|
||||
args=params.get("args", []),
|
||||
env=params.get("env"),
|
||||
cwd=params.get("cwd"),
|
||||
encoding=params.get("encoding", "utf-8"),
|
||||
encoding_error_handler=params.get("encoding_error_handler", "strict"),
|
||||
)
|
||||
self._name = name or f"stdio:{self.params.command}"
|
||||
|
||||
def create_streams(self) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
return stdio_client(self.params)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
|
||||
class MCPServerSseParams(TypedDict):
|
||||
"""Configuration for HTTP+SSE transport."""
|
||||
url: str
|
||||
headers: NotRequired[Dict[str, str]]
|
||||
timeout: NotRequired[float]
|
||||
sse_read_timeout: NotRequired[float]
|
||||
|
||||
|
||||
class MCPServerSse(_MCPServerWithClientSession):
|
||||
"""MCP server over HTTP with SSE transport."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: MCPServerSseParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(cache_tools_list)
|
||||
self.params = params
|
||||
self._name = name or f"sse:{params['url']}"
|
||||
|
||||
def create_streams(self) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
return sse_client(
|
||||
url=self.params["url"],
|
||||
headers=self.params.get("headers"),
|
||||
timeout=self.params.get("timeout", 5),
|
||||
sse_read_timeout=self.params.get("sse_read_timeout", 300),
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
|
||||
async def call_tool_fast(
|
||||
server: MCPServerSse, payload: Dict[str, Any] | str
|
||||
) -> Any:
|
||||
try:
|
||||
await server.connect()
|
||||
result = await server.call_tool(arguments=payload if isinstance(payload, dict) else None)
|
||||
return result
|
||||
finally:
|
||||
await server.cleanup()
|
||||
|
||||
|
||||
async def mcp_flow_get_tool_schema(
|
||||
params: MCPServerSseParams,
|
||||
) -> Any:
|
||||
async with MCPServerSse(params) as server:
|
||||
tools = await server.list_tools()
|
||||
return tools
|
||||
|
||||
|
||||
async def mcp_flow(
|
||||
params: MCPServerSseParams,
|
||||
function_call: Dict[str, Any] | str,
|
||||
) -> Any:
|
||||
async with MCPServerSse(params) as server:
|
||||
return await call_tool_fast(server, function_call)
|
||||
|
||||
|
||||
async def _call_one_server(
|
||||
params: MCPServerSseParams, payload: Dict[str, Any] | str
|
||||
) -> Any:
|
||||
server = MCPServerSse(params)
|
||||
try:
|
||||
await server.connect()
|
||||
return await server.call_tool(arguments=payload if isinstance(payload, dict) else None)
|
||||
finally:
|
||||
await server.cleanup()
|
||||
|
||||
|
||||
def batch_mcp_flow(
|
||||
params: List[MCPServerSseParams], payload: Dict[str, Any] | str
|
||||
) -> List[Any]:
|
||||
return asyncio.run(
|
||||
asyncio.gather(*[_call_one_server(p, payload) for p in params])
|
||||
)
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Literal, Union
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from loguru import logger
|
||||
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.types import CallToolResult, JSONRPCMessage
|
||||
|
||||
from swarms.utils.any_to_str import any_to_str
|
||||
|
||||
|
||||
class MCPServer(abc.ABC):
|
||||
"""Base class for Model Context Protocol servers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection to the MCP server."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Human-readable server name."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up resources and close connection."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_tools(self) -> List[MCPTool]:
|
||||
"""List available MCP tools on the server."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def call_tool(
|
||||
self, tool_name: str, arguments: Dict[str, Any] | None
|
||||
) -> CallToolResult:
|
||||
"""Invoke a tool by name with provided arguments."""
|
||||
pass
|
||||
|
||||
|
||||
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||
"""Mixin providing ClientSession-based MCP communication."""
|
||||
|
||||
def __init__(self, cache_tools_list: bool = False):
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
self.cache_tools_list = cache_tools_list
|
||||
self._cache_dirty = True
|
||||
self._tools_list: Optional[List[MCPTool]] = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_streams(
|
||||
self
|
||||
) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
"""Supply the read/write streams for the MCP transport."""
|
||||
pass
|
||||
|
||||
async def __aenter__(self) -> MCPServer:
|
||||
await self.connect()
|
||||
return self # type: ignore
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, tb) -> None:
|
||||
await self.cleanup()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize transport and ClientSession."""
|
||||
try:
|
||||
transport = await self.exit_stack.enter_async_context(
|
||||
self.create_streams()
|
||||
)
|
||||
read, write = transport
|
||||
session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(read, write)
|
||||
)
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing MCP server: {e}")
|
||||
await self.cleanup()
|
||||
raise
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Close session and transport."""
|
||||
async with self._cleanup_lock:
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}")
|
||||
finally:
|
||||
self.session = None
|
||||
|
||||
async def list_tools(self) -> List[MCPTool]:
|
||||
if not self.session:
|
||||
raise RuntimeError("Server not connected. Call connect() first.")
|
||||
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
|
||||
return self._tools_list
|
||||
self._cache_dirty = False
|
||||
self._tools_list = (await self.session.list_tools()).tools
|
||||
return self._tools_list # type: ignore
|
||||
|
||||
async def call_tool(
|
||||
self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None
|
||||
) -> CallToolResult:
|
||||
if not arguments:
|
||||
raise ValueError("Arguments dict is required to call a tool")
|
||||
name = tool_name or arguments.get("tool_name") or arguments.get("name")
|
||||
if not name:
|
||||
raise ValueError("Tool name missing in arguments")
|
||||
if not self.session:
|
||||
raise RuntimeError("Server not connected. Call connect() first.")
|
||||
return await self.session.call_tool(name, arguments)
|
||||
|
||||
|
||||
class MCPServerStdioParams(TypedDict):
|
||||
"""Configuration for stdio transport."""
|
||||
command: str
|
||||
args: NotRequired[List[str]]
|
||||
env: NotRequired[Dict[str, str]]
|
||||
cwd: NotRequired[str | Path]
|
||||
encoding: NotRequired[str]
|
||||
encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
|
||||
|
||||
|
||||
class MCPServerStdio(_MCPServerWithClientSession):
|
||||
"""MCP server over stdio transport."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: MCPServerStdioParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(cache_tools_list)
|
||||
self.params = StdioServerParameters(
|
||||
command=params["command"],
|
||||
args=params.get("args", []),
|
||||
env=params.get("env"),
|
||||
cwd=params.get("cwd"),
|
||||
encoding=params.get("encoding", "utf-8"),
|
||||
encoding_error_handler=params.get("encoding_error_handler", "strict"),
|
||||
)
|
||||
self._name = name or f"stdio:{self.params.command}"
|
||||
|
||||
def create_streams(self) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
return stdio_client(self.params)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
|
||||
class MCPServerSseParams(TypedDict):
|
||||
"""Configuration for HTTP+SSE transport."""
|
||||
url: str
|
||||
headers: NotRequired[Dict[str, str]]
|
||||
timeout: NotRequired[float]
|
||||
sse_read_timeout: NotRequired[float]
|
||||
|
||||
|
||||
class MCPServerSse(_MCPServerWithClientSession):
|
||||
"""MCP server over HTTP with SSE transport."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: MCPServerSseParams,
|
||||
cache_tools_list: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(cache_tools_list)
|
||||
self.params = params
|
||||
self._name = name or f"sse:{params['url']}"
|
||||
|
||||
def create_streams(self) -> AbstractAsyncContextManager[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage],
|
||||
]
|
||||
]:
|
||||
return sse_client(
|
||||
url=self.params["url"],
|
||||
headers=self.params.get("headers"),
|
||||
timeout=self.params.get("timeout", 5),
|
||||
sse_read_timeout=self.params.get("sse_read_timeout", 300),
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
|
||||
async def call_tool_fast(
|
||||
server: MCPServerSse, payload: Dict[str, Any] | str
|
||||
) -> Any:
|
||||
"""Async function to call a tool on a server with proper cleanup."""
|
||||
try:
|
||||
await server.connect()
|
||||
arguments = payload if isinstance(payload, dict) else None
|
||||
result = await server.call_tool(arguments=arguments)
|
||||
return result
|
||||
finally:
|
||||
await server.cleanup()
|
||||
|
||||
|
||||
async def mcp_flow_get_tool_schema(
|
||||
params: MCPServerSseParams,
|
||||
) -> Any:
|
||||
"""Async function to get tool schema from MCP server."""
|
||||
async with MCPServerSse(params) as server:
|
||||
tools = await server.list_tools()
|
||||
return tools
|
||||
|
||||
|
||||
async def mcp_flow(
|
||||
params: MCPServerSseParams,
|
||||
function_call: Dict[str, Any] | str,
|
||||
) -> Any:
|
||||
"""Async function to call a tool with given parameters."""
|
||||
async with MCPServerSse(params) as server:
|
||||
return await call_tool_fast(server, function_call)
|
||||
|
||||
|
||||
async def _call_one_server(
|
||||
params: MCPServerSseParams, payload: Dict[str, Any] | str
|
||||
) -> Any:
|
||||
"""Helper function to call a single MCP server."""
|
||||
server = MCPServerSse(params)
|
||||
try:
|
||||
await server.connect()
|
||||
arguments = payload if isinstance(payload, dict) else None
|
||||
return await server.call_tool(arguments=arguments)
|
||||
finally:
|
||||
await server.cleanup()
|
||||
|
||||
|
||||
async def abatch_mcp_flow(
|
||||
params: List[MCPServerSseParams], payload: Dict[str, Any] | str
|
||||
) -> List[Any]:
|
||||
"""Async function to execute a batch of MCP calls concurrently.
|
||||
|
||||
Args:
|
||||
params (List[MCPServerSseParams]): List of MCP server configurations
|
||||
payload (Dict[str, Any] | str): The payload to send to each server
|
||||
|
||||
Returns:
|
||||
List[Any]: Results from all MCP servers
|
||||
"""
|
||||
if not params:
|
||||
logger.warning("No MCP servers provided for batch operation")
|
||||
return []
|
||||
|
||||
try:
|
||||
return await asyncio.gather(*[_call_one_server(p, payload) for p in params])
|
||||
except Exception as e:
|
||||
logger.error(f"Error in abatch_mcp_flow: {e}")
|
||||
# Return partial results if any were successful
|
||||
return [f"Error in batch operation: {str(e)}"]
|
||||
|
||||
|
||||
def batch_mcp_flow(
|
||||
params: List[MCPServerSseParams], payload: Dict[str, Any] | str
|
||||
) -> List[Any]:
|
||||
"""Sync wrapper for batch MCP operations.
|
||||
|
||||
This creates a new event loop if needed to run the async batch operation.
|
||||
ONLY use this when not already in an async context.
|
||||
|
||||
Args:
|
||||
params (List[MCPServerSseParams]): List of MCP server configurations
|
||||
payload (Dict[str, Any] | str): The payload to send to each server
|
||||
|
||||
Returns:
|
||||
List[Any]: Results from all MCP servers
|
||||
"""
|
||||
if not params:
|
||||
logger.warning("No MCP servers provided for batch operation")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Check if we're already in an event loop
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
# No event loop exists, create one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if loop.is_running():
|
||||
# We're already in an async context, can't use asyncio.run
|
||||
# Use a future to bridge sync-async gap
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
abatch_mcp_flow(params, payload), loop
|
||||
)
|
||||
return future.result(timeout=30) # Add timeout to prevent hanging
|
||||
else:
|
||||
# We're not in an async context, safe to use loop.run_until_complete
|
||||
return loop.run_until_complete(abatch_mcp_flow(params, payload))
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch_mcp_flow: {e}")
|
||||
return [f"Error in batch operation: {str(e)}"]
|
Loading…
Reference in new issue