parent
f61ada7928
commit
616c5757b0
@ -0,0 +1,397 @@
|
|||||||
|
from swarms import Agent
|
||||||
|
from swarms.tools.mcp_integration import MCPServerSseParams, MCPServerSse, mcp_flow_get_tool_schema
|
||||||
|
from loguru import logger
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import httpx
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Configure logging for more detailed output
|
||||||
|
logger.remove()
|
||||||
|
logger.add(sys.stdout,
|
||||||
|
level="DEBUG",
|
||||||
|
format="{time} | {level} | {module}:{function}:{line} - {message}")
|
||||||
|
|
||||||
|
# Relaxed prompt that doesn't enforce strict JSON formatting
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Create server parameters
|
||||||
|
def get_server_params():
|
||||||
|
"""Get the MCP server connection parameters."""
|
||||||
|
return MCPServerSseParams(
|
||||||
|
url=
|
||||||
|
"http://127.0.0.1:8000", # Use 127.0.0.1 instead of localhost/0.0.0.0
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "text/event-stream"
|
||||||
|
},
|
||||||
|
timeout=15.0, # Longer timeout
|
||||||
|
sse_read_timeout=60.0 # Longer read timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_math_system():
|
||||||
|
"""Initialize the math agent with MCP server configuration."""
|
||||||
|
# Create the agent with the MCP server configuration
|
||||||
|
math_agent = Agent(agent_name="Math Assistant",
|
||||||
|
agent_description="Friendly math calculator",
|
||||||
|
system_prompt=MATH_AGENT_PROMPT,
|
||||||
|
max_loops=1,
|
||||||
|
mcp_servers=[get_server_params()],
|
||||||
|
model_name="gpt-3.5-turbo",
|
||||||
|
verbose=True)
|
||||||
|
|
||||||
|
return math_agent
|
||||||
|
|
||||||
|
|
||||||
|
# Function to get list of available tools from the server
|
||||||
|
async def get_tools_list():
|
||||||
|
"""Fetch and format the list of available tools from the server."""
|
||||||
|
try:
|
||||||
|
server_params = get_server_params()
|
||||||
|
tools = await mcp_flow_get_tool_schema(server_params)
|
||||||
|
|
||||||
|
if not tools:
|
||||||
|
return "No tools are currently available on the server."
|
||||||
|
|
||||||
|
# Format the tools information
|
||||||
|
tools_info = "Available tools:\n"
|
||||||
|
for tool in tools:
|
||||||
|
tools_info += f"\n- {tool.name}: {tool.description or 'No description'}\n"
|
||||||
|
if tool.parameters and hasattr(tool.parameters, 'properties'):
|
||||||
|
tools_info += " Parameters:\n"
|
||||||
|
for param_name, param_info in tool.parameters.properties.items(
|
||||||
|
):
|
||||||
|
param_type = param_info.get('type', 'unknown')
|
||||||
|
param_desc = param_info.get('description',
|
||||||
|
'No description')
|
||||||
|
tools_info += f" - {param_name} ({param_type}): {param_desc}\n"
|
||||||
|
|
||||||
|
return tools_info
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get tools list: {e}")
|
||||||
|
return f"Error retrieving tools list: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
# Function to test server connection
|
||||||
|
def test_server_connection():
|
||||||
|
"""Test if the server is reachable and responsive."""
|
||||||
|
try:
|
||||||
|
# Create a short-lived connection to check server
|
||||||
|
server = MCPServerSse(get_server_params())
|
||||||
|
|
||||||
|
# Try connecting (this is synchronous)
|
||||||
|
asyncio.run(server.connect())
|
||||||
|
asyncio.run(server.cleanup())
|
||||||
|
logger.info("✅ Server connection test successful")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Server connection test failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Manual math operation handler as ultimate fallback
|
||||||
|
def manual_math(query):
|
||||||
|
"""Parse and solve a math problem without using the server."""
|
||||||
|
query = query.lower()
|
||||||
|
|
||||||
|
# Check if user is asking for available tools/functions
|
||||||
|
if "list" in query and ("tools" in query or "functions" in query
|
||||||
|
or "operations" in query):
|
||||||
|
return """
|
||||||
|
Available tools:
|
||||||
|
1. add - Add two numbers together (e.g., "add 3 and 4")
|
||||||
|
2. multiply - Multiply two numbers together (e.g., "multiply 5 and 6")
|
||||||
|
3. divide - Divide the first number by the second (e.g., "divide 10 by 2")
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
if "add" in query or "plus" in query or "sum" in query:
|
||||||
|
# Extract numbers using a simple approach
|
||||||
|
numbers = [int(s) for s in query.split() if s.isdigit()]
|
||||||
|
if len(numbers) >= 2:
|
||||||
|
result = numbers[0] + numbers[1]
|
||||||
|
return f"The sum of {numbers[0]} and {numbers[1]} is {result}"
|
||||||
|
|
||||||
|
elif "multiply" in query or "times" in query or "product" in query:
|
||||||
|
numbers = [int(s) for s in query.split() if s.isdigit()]
|
||||||
|
if len(numbers) >= 2:
|
||||||
|
result = numbers[0] * numbers[1]
|
||||||
|
return f"The product of {numbers[0]} and {numbers[1]} is {result}"
|
||||||
|
|
||||||
|
elif "divide" in query or "quotient" in query:
|
||||||
|
numbers = [int(s) for s in query.split() if s.isdigit()]
|
||||||
|
if len(numbers) >= 2:
|
||||||
|
if numbers[1] == 0:
|
||||||
|
return "Cannot divide by zero"
|
||||||
|
result = numbers[0] / numbers[1]
|
||||||
|
return f"{numbers[0]} divided by {numbers[1]} is {result}"
|
||||||
|
|
||||||
|
return "I couldn't parse your math request. Try something like 'add 3 and 4'."
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Manual math error: {e}")
|
||||||
|
return f"Error performing calculation: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
try:
|
||||||
|
logger.info("Initializing math system...")
|
||||||
|
|
||||||
|
# Test server connection first
|
||||||
|
server_available = test_server_connection()
|
||||||
|
|
||||||
|
if server_available:
|
||||||
|
math_agent = initialize_math_system()
|
||||||
|
print("\nMath Calculator Ready! (Server connection successful)")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"\nServer connection failed - using fallback calculator mode")
|
||||||
|
math_agent = None
|
||||||
|
|
||||||
|
print("Ask me any math question!")
|
||||||
|
print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'")
|
||||||
|
print("Type 'list tools' to see available operations")
|
||||||
|
print("Type 'exit' to quit\n")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("What would you like to calculate? ").strip()
|
||||||
|
if not query:
|
||||||
|
continue
|
||||||
|
if query.lower() == 'exit':
|
||||||
|
break
|
||||||
|
|
||||||
|
# Handle special commands
|
||||||
|
if query.lower() in ('list tools', 'show tools',
|
||||||
|
'available tools', 'what tools'):
|
||||||
|
if server_available:
|
||||||
|
# Get tools list from server
|
||||||
|
tools_info = asyncio.run(get_tools_list())
|
||||||
|
print(f"\n{tools_info}\n")
|
||||||
|
else:
|
||||||
|
# Use manual fallback
|
||||||
|
print(manual_math("list tools"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"Processing query: {query}")
|
||||||
|
|
||||||
|
# First try the agent if available
|
||||||
|
if math_agent and server_available:
|
||||||
|
try:
|
||||||
|
result = math_agent.run(query)
|
||||||
|
print(f"\nResult: {result}\n")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Agent error: {e}")
|
||||||
|
print("Agent encountered an error, trying fallback...")
|
||||||
|
|
||||||
|
# If agent fails or isn't available, use manual calculator
|
||||||
|
result = manual_math(query)
|
||||||
|
print(f"\nCalculation result: {result}\n")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing query: {e}")
|
||||||
|
print(f"Sorry, there was an error: {str(e)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"System initialization error: {e}")
|
||||||
|
print(f"Failed to start the math system: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main() "from fastmcp import FastMCP
|
||||||
|
from loguru import logger
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Create the MCP server with detailed debugging
|
||||||
|
mcp = FastMCP(
|
||||||
|
host="0.0.0.0", # Bind to all interfaces
|
||||||
|
port=8000,
|
||||||
|
transport="sse",
|
||||||
|
require_session_id=False,
|
||||||
|
cors_allowed_origins=["*"], # Allow connections from any origin
|
||||||
|
debug=True # Enable debug mode for more verbose output
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Add a more flexible parsing approach
|
||||||
|
def parse_input(input_str):
|
||||||
|
"""Parse input that could be JSON or natural language."""
|
||||||
|
try:
|
||||||
|
# First try to parse as JSON
|
||||||
|
return json.loads(input_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If not JSON, try to parse natural language
|
||||||
|
input_lower = input_str.lower()
|
||||||
|
|
||||||
|
# Parse for addition
|
||||||
|
if "add" in input_lower or "plus" in input_lower or "sum" in input_lower:
|
||||||
|
# Extract numbers - very simple approach
|
||||||
|
numbers = [int(s) for s in input_lower.split() if s.isdigit()]
|
||||||
|
if len(numbers) >= 2:
|
||||||
|
return {"a": numbers[0], "b": numbers[1]}
|
||||||
|
|
||||||
|
# Parse for multiplication
|
||||||
|
if "multiply" in input_lower or "times" in input_lower or "product" in input_lower:
|
||||||
|
numbers = [int(s) for s in input_lower.split() if s.isdigit()]
|
||||||
|
if len(numbers) >= 2:
|
||||||
|
return {"a": numbers[0], "b": numbers[1]}
|
||||||
|
|
||||||
|
# Parse for division
|
||||||
|
if "divide" in input_lower or "quotient" in input_lower:
|
||||||
|
numbers = [int(s) for s in input_lower.split() if s.isdigit()]
|
||||||
|
if len(numbers) >= 2:
|
||||||
|
return {"a": numbers[0], "b": numbers[1]}
|
||||||
|
|
||||||
|
# Could not parse successfully
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Define tools with more flexible input handling
|
||||||
|
@mcp.tool()
|
||||||
|
def add(input_str=None, a=None, b=None):
|
||||||
|
"""Add two numbers. Can accept JSON parameters or natural language.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_str (str, optional): Natural language input to parse
|
||||||
|
a (int, optional): First number if provided directly
|
||||||
|
b (int, optional): Second number if provided directly
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A message containing the sum
|
||||||
|
"""
|
||||||
|
logger.info(f"Add tool called with input_str={input_str}, a={a}, b={b}")
|
||||||
|
|
||||||
|
# If we got a natural language string instead of parameters
|
||||||
|
if input_str and not (a is not None and b is not None):
|
||||||
|
parsed = parse_input(input_str)
|
||||||
|
if parsed:
|
||||||
|
a = parsed.get("a")
|
||||||
|
b = parsed.get("b")
|
||||||
|
|
||||||
|
# Validate we have what we need
|
||||||
|
if a is None or b is None:
|
||||||
|
return "Sorry, I couldn't understand the numbers to add"
|
||||||
|
|
||||||
|
try:
|
||||||
|
a = int(a)
|
||||||
|
b = int(b)
|
||||||
|
result = a + b
|
||||||
|
return f"The sum of {a} and {b} is {result}"
|
||||||
|
except ValueError:
|
||||||
|
return "Please provide valid numbers for addition"
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
def multiply(input_str=None, a=None, b=None):
|
||||||
|
"""Multiply two numbers. Can accept JSON parameters or natural language.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_str (str, optional): Natural language input to parse
|
||||||
|
a (int, optional): First number if provided directly
|
||||||
|
b (int, optional): Second number if provided directly
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A message containing the product
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Multiply tool called with input_str={input_str}, a={a}, b={b}")
|
||||||
|
|
||||||
|
# If we got a natural language string instead of parameters
|
||||||
|
if input_str and not (a is not None and b is not None):
|
||||||
|
parsed = parse_input(input_str)
|
||||||
|
if parsed:
|
||||||
|
a = parsed.get("a")
|
||||||
|
b = parsed.get("b")
|
||||||
|
|
||||||
|
# Validate we have what we need
|
||||||
|
if a is None or b is None:
|
||||||
|
return "Sorry, I couldn't understand the numbers to multiply"
|
||||||
|
|
||||||
|
try:
|
||||||
|
a = int(a)
|
||||||
|
b = int(b)
|
||||||
|
result = a * b
|
||||||
|
return f"The product of {a} and {b} is {result}"
|
||||||
|
except ValueError:
|
||||||
|
return "Please provide valid numbers for multiplication"
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
def divide(input_str=None, a=None, b=None):
|
||||||
|
"""Divide two numbers. Can accept JSON parameters or natural language.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_str (str, optional): Natural language input to parse
|
||||||
|
a (int, optional): Numerator if provided directly
|
||||||
|
b (int, optional): Denominator if provided directly
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A message containing the division result or an error message
|
||||||
|
"""
|
||||||
|
logger.info(f"Divide tool called with input_str={input_str}, a={a}, b={b}")
|
||||||
|
|
||||||
|
# If we got a natural language string instead of parameters
|
||||||
|
if input_str and not (a is not None and b is not None):
|
||||||
|
parsed = parse_input(input_str)
|
||||||
|
if parsed:
|
||||||
|
a = parsed.get("a")
|
||||||
|
b = parsed.get("b")
|
||||||
|
|
||||||
|
# Validate we have what we need
|
||||||
|
if a is None or b is None:
|
||||||
|
return "Sorry, I couldn't understand the numbers to divide"
|
||||||
|
|
||||||
|
try:
|
||||||
|
a = int(a)
|
||||||
|
b = int(b)
|
||||||
|
|
||||||
|
if b == 0:
|
||||||
|
logger.warning("Division by zero attempted")
|
||||||
|
return "Cannot divide by zero"
|
||||||
|
|
||||||
|
result = a / b
|
||||||
|
return f"{a} divided by {b} is {result}"
|
||||||
|
except ValueError:
|
||||||
|
return "Please provide valid numbers for division"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
logger.info("Starting math server on http://0.0.0.0:8000")
|
||||||
|
print("Math MCP Server is running. Press Ctrl+C to stop.")
|
||||||
|
print(
|
||||||
|
"Server is configured to accept both JSON and natural language input"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a small delay to ensure logging is complete before the server starts
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
# Run the MCP server
|
||||||
|
mcp.run()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Server shutdown requested")
|
||||||
|
print("\nShutting down server...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Server error: {e}")
|
||||||
|
raise
|
||||||
|
" server is runnig poeroperly "2025-04-20 17:35:01.251 | INFO | __main__:<module>:161 - Starting math server on http://0.0.0.0:8000
|
||||||
|
Math MCP Server is running. Press Ctrl+C to stop.
|
||||||
|
Server is configured to accept both JSON and natural language input
|
||||||
|
[04/20/25 17:35:01] INFO Starting server "FastMCP"... " butwhy im getting these errore "2025-04-20T17:35:04.174629+0000 | INFO | mcp_client:main:159 - Initializing math system...
|
||||||
|
2025-04-20T17:35:04.203591+0000 | ERROR | mcp_integration:connect:89 - Error initializing MCP server: unhandled errors in a TaskGroup (1 sub-exception)
|
||||||
|
2025-04-20T17:35:04.204437+0000 | ERROR | mcp_client:test_server_connection:110 - ❌ Server connection test failed: unhandled errors in a TaskGroup (1 sub-exception)
|
||||||
|
|
||||||
|
Server connection failed - using fallback calculator mode
|
||||||
|
Ask me any math question!
|
||||||
|
Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'
|
||||||
|
Type 'list tools' to see available operations
|
||||||
|
Type 'exit' to quit
|
||||||
|
|
||||||
|
What would you like to calculate? "
|
File diff suppressed because it is too large
Load Diff
@ -1,320 +1,311 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Literal, Union
|
from typing import Any, Dict, List, Optional, Literal, Union
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
|
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.types import CallToolResult, JSONRPCMessage
|
from mcp.types import CallToolResult, JSONRPCMessage
|
||||||
|
|
||||||
from swarms.utils.any_to_str import any_to_str
|
from swarms.utils.any_to_str import any_to_str
|
||||||
|
|
||||||
|
|
||||||
class MCPServer(abc.ABC):
|
class MCPServer(abc.ABC):
|
||||||
"""Base class for Model Context Protocol servers."""
|
"""Base class for Model Context Protocol servers."""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
"""Establish connection to the MCP server."""
|
"""Establish connection to the MCP server."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Human-readable server name."""
|
"""Human-readable server name."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def cleanup(self) -> None:
|
async def cleanup(self) -> None:
|
||||||
"""Clean up resources and close connection."""
|
"""Clean up resources and close connection."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def list_tools(self) -> List[MCPTool]:
|
async def list_tools(self) -> List[MCPTool]:
|
||||||
"""List available MCP tools on the server."""
|
"""List available MCP tools on the server."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def call_tool(
|
async def call_tool(self, tool_name: str,
|
||||||
self, tool_name: str, arguments: Dict[str, Any] | None
|
arguments: Dict[str, Any] | None) -> CallToolResult:
|
||||||
) -> CallToolResult:
|
"""Invoke a tool by name with provided arguments."""
|
||||||
"""Invoke a tool by name with provided arguments."""
|
pass
|
||||||
pass
|
|
||||||
|
|
||||||
|
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
||||||
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
"""Mixin providing ClientSession-based MCP communication."""
|
||||||
"""Mixin providing ClientSession-based MCP communication."""
|
|
||||||
|
def __init__(self, cache_tools_list: bool = False):
|
||||||
def __init__(self, cache_tools_list: bool = False):
|
self.session: Optional[ClientSession] = None
|
||||||
self.session: Optional[ClientSession] = None
|
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
self._cleanup_lock = asyncio.Lock()
|
||||||
self._cleanup_lock = asyncio.Lock()
|
self.cache_tools_list = cache_tools_list
|
||||||
self.cache_tools_list = cache_tools_list
|
self._cache_dirty = True
|
||||||
self._cache_dirty = True
|
self._tools_list: Optional[List[MCPTool]] = None
|
||||||
self._tools_list: Optional[List[MCPTool]] = None
|
|
||||||
|
@abc.abstractmethod
|
||||||
@abc.abstractmethod
|
def create_streams(
|
||||||
def create_streams(
|
self
|
||||||
self
|
) -> AbstractAsyncContextManager[tuple[
|
||||||
) -> AbstractAsyncContextManager[
|
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
tuple[
|
MemoryObjectSendStream[JSONRPCMessage],
|
||||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
]]:
|
||||||
MemoryObjectSendStream[JSONRPCMessage],
|
"""Supply the read/write streams for the MCP transport."""
|
||||||
]
|
pass
|
||||||
]:
|
|
||||||
"""Supply the read/write streams for the MCP transport."""
|
async def __aenter__(self) -> MCPServer:
|
||||||
pass
|
await self.connect()
|
||||||
|
return self # type: ignore
|
||||||
async def __aenter__(self) -> MCPServer:
|
|
||||||
await self.connect()
|
async def __aexit__(self, exc_type, exc_value, tb) -> None:
|
||||||
return self # type: ignore
|
await self.cleanup()
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, tb) -> None:
|
async def connect(self) -> None:
|
||||||
await self.cleanup()
|
"""Initialize transport and ClientSession."""
|
||||||
|
try:
|
||||||
async def connect(self) -> None:
|
transport = await self.exit_stack.enter_async_context(
|
||||||
"""Initialize transport and ClientSession."""
|
self.create_streams())
|
||||||
try:
|
read, write = transport
|
||||||
transport = await self.exit_stack.enter_async_context(
|
session = await self.exit_stack.enter_async_context(
|
||||||
self.create_streams()
|
ClientSession(read, write))
|
||||||
)
|
await session.initialize()
|
||||||
read, write = transport
|
self.session = session
|
||||||
session = await self.exit_stack.enter_async_context(
|
except Exception as e:
|
||||||
ClientSession(read, write)
|
logger.error(f"Error initializing MCP server: {e}")
|
||||||
)
|
await self.cleanup()
|
||||||
await session.initialize()
|
raise
|
||||||
self.session = session
|
|
||||||
except Exception as e:
|
async def cleanup(self) -> None:
|
||||||
logger.error(f"Error initializing MCP server: {e}")
|
"""Close session and transport."""
|
||||||
await self.cleanup()
|
async with self._cleanup_lock:
|
||||||
raise
|
try:
|
||||||
|
await self.exit_stack.aclose()
|
||||||
async def cleanup(self) -> None:
|
except Exception as e:
|
||||||
"""Close session and transport."""
|
logger.error(f"Error during cleanup: {e}")
|
||||||
async with self._cleanup_lock:
|
finally:
|
||||||
try:
|
self.session = None
|
||||||
await self.exit_stack.aclose()
|
|
||||||
except Exception as e:
|
async def list_tools(self) -> List[MCPTool]:
|
||||||
logger.error(f"Error during cleanup: {e}")
|
if not self.session:
|
||||||
finally:
|
raise RuntimeError("Server not connected. Call connect() first.")
|
||||||
self.session = None
|
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
|
||||||
|
return self._tools_list
|
||||||
async def list_tools(self) -> List[MCPTool]:
|
self._cache_dirty = False
|
||||||
if not self.session:
|
self._tools_list = (await self.session.list_tools()).tools
|
||||||
raise RuntimeError("Server not connected. Call connect() first.")
|
return self._tools_list # type: ignore
|
||||||
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
|
|
||||||
return self._tools_list
|
async def call_tool(
|
||||||
self._cache_dirty = False
|
self,
|
||||||
self._tools_list = (await self.session.list_tools()).tools
|
tool_name: str | None = None,
|
||||||
return self._tools_list # type: ignore
|
arguments: Dict[str, Any] | None = None) -> CallToolResult:
|
||||||
|
if not arguments:
|
||||||
async def call_tool(
|
raise ValueError("Arguments dict is required to call a tool")
|
||||||
self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None
|
name = tool_name or arguments.get("tool_name") or arguments.get("name")
|
||||||
) -> CallToolResult:
|
if not name:
|
||||||
if not arguments:
|
raise ValueError("Tool name missing in arguments")
|
||||||
raise ValueError("Arguments dict is required to call a tool")
|
if not self.session:
|
||||||
name = tool_name or arguments.get("tool_name") or arguments.get("name")
|
raise RuntimeError("Server not connected. Call connect() first.")
|
||||||
if not name:
|
return await self.session.call_tool(name, arguments)
|
||||||
raise ValueError("Tool name missing in arguments")
|
|
||||||
if not self.session:
|
|
||||||
raise RuntimeError("Server not connected. Call connect() first.")
|
class MCPServerStdioParams(TypedDict):
|
||||||
return await self.session.call_tool(name, arguments)
|
"""Configuration for stdio transport."""
|
||||||
|
command: str
|
||||||
|
args: NotRequired[List[str]]
|
||||||
class MCPServerStdioParams(TypedDict):
|
env: NotRequired[Dict[str, str]]
|
||||||
"""Configuration for stdio transport."""
|
cwd: NotRequired[str | Path]
|
||||||
command: str
|
encoding: NotRequired[str]
|
||||||
args: NotRequired[List[str]]
|
encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
|
||||||
env: NotRequired[Dict[str, str]]
|
|
||||||
cwd: NotRequired[str | Path]
|
|
||||||
encoding: NotRequired[str]
|
class MCPServerStdio(_MCPServerWithClientSession):
|
||||||
encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
|
"""MCP server over stdio transport."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
class MCPServerStdio(_MCPServerWithClientSession):
|
self,
|
||||||
"""MCP server over stdio transport."""
|
params: MCPServerStdioParams,
|
||||||
|
cache_tools_list: bool = False,
|
||||||
def __init__(
|
name: Optional[str] = None,
|
||||||
self,
|
):
|
||||||
params: MCPServerStdioParams,
|
super().__init__(cache_tools_list)
|
||||||
cache_tools_list: bool = False,
|
self.params = StdioServerParameters(
|
||||||
name: Optional[str] = None,
|
command=params["command"],
|
||||||
):
|
args=params.get("args", []),
|
||||||
super().__init__(cache_tools_list)
|
env=params.get("env"),
|
||||||
self.params = StdioServerParameters(
|
cwd=params.get("cwd"),
|
||||||
command=params["command"],
|
encoding=params.get("encoding", "utf-8"),
|
||||||
args=params.get("args", []),
|
encoding_error_handler=params.get("encoding_error_handler",
|
||||||
env=params.get("env"),
|
"strict"),
|
||||||
cwd=params.get("cwd"),
|
)
|
||||||
encoding=params.get("encoding", "utf-8"),
|
self._name = name or f"stdio:{self.params.command}"
|
||||||
encoding_error_handler=params.get("encoding_error_handler", "strict"),
|
|
||||||
)
|
def create_streams(
|
||||||
self._name = name or f"stdio:{self.params.command}"
|
self
|
||||||
|
) -> AbstractAsyncContextManager[tuple[
|
||||||
def create_streams(self) -> AbstractAsyncContextManager[
|
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
tuple[
|
MemoryObjectSendStream[JSONRPCMessage],
|
||||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
]]:
|
||||||
MemoryObjectSendStream[JSONRPCMessage],
|
return stdio_client(self.params)
|
||||||
]
|
|
||||||
]:
|
@property
|
||||||
return stdio_client(self.params)
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self._name
|
class MCPServerSseParams(TypedDict):
|
||||||
|
"""Configuration for HTTP+SSE transport."""
|
||||||
|
url: str
|
||||||
class MCPServerSseParams(TypedDict):
|
headers: NotRequired[Dict[str, str]]
|
||||||
"""Configuration for HTTP+SSE transport."""
|
timeout: NotRequired[float]
|
||||||
url: str
|
sse_read_timeout: NotRequired[float]
|
||||||
headers: NotRequired[Dict[str, str]]
|
|
||||||
timeout: NotRequired[float]
|
|
||||||
sse_read_timeout: NotRequired[float]
|
class MCPServerSse(_MCPServerWithClientSession):
|
||||||
|
"""MCP server over HTTP with SSE transport."""
|
||||||
|
|
||||||
class MCPServerSse(_MCPServerWithClientSession):
|
def __init__(
|
||||||
"""MCP server over HTTP with SSE transport."""
|
self,
|
||||||
|
params: MCPServerSseParams,
|
||||||
def __init__(
|
cache_tools_list: bool = False,
|
||||||
self,
|
name: Optional[str] = None,
|
||||||
params: MCPServerSseParams,
|
):
|
||||||
cache_tools_list: bool = False,
|
super().__init__(cache_tools_list)
|
||||||
name: Optional[str] = None,
|
self.params = params
|
||||||
):
|
self._name = name or f"sse:{params['url']}"
|
||||||
super().__init__(cache_tools_list)
|
|
||||||
self.params = params
|
def create_streams(
|
||||||
self._name = name or f"sse:{params['url']}"
|
self
|
||||||
|
) -> AbstractAsyncContextManager[tuple[
|
||||||
def create_streams(self) -> AbstractAsyncContextManager[
|
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
tuple[
|
MemoryObjectSendStream[JSONRPCMessage],
|
||||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
]]:
|
||||||
MemoryObjectSendStream[JSONRPCMessage],
|
return sse_client(
|
||||||
]
|
url=self.params["url"],
|
||||||
]:
|
headers=self.params.get("headers"),
|
||||||
return sse_client(
|
timeout=self.params.get("timeout", 5),
|
||||||
url=self.params["url"],
|
sse_read_timeout=self.params.get("sse_read_timeout", 300),
|
||||||
headers=self.params.get("headers"),
|
)
|
||||||
timeout=self.params.get("timeout", 5),
|
|
||||||
sse_read_timeout=self.params.get("sse_read_timeout", 300),
|
@property
|
||||||
)
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self._name
|
async def call_tool_fast(server: MCPServerSse,
|
||||||
|
payload: Dict[str, Any] | str) -> Any:
|
||||||
|
"""Async function to call a tool on a server with proper cleanup."""
|
||||||
async def call_tool_fast(
|
try:
|
||||||
server: MCPServerSse, payload: Dict[str, Any] | str
|
await server.connect()
|
||||||
) -> Any:
|
arguments = payload if isinstance(payload, dict) else None
|
||||||
"""Async function to call a tool on a server with proper cleanup."""
|
result = await server.call_tool(arguments=arguments)
|
||||||
try:
|
return result
|
||||||
await server.connect()
|
finally:
|
||||||
arguments = payload if isinstance(payload, dict) else None
|
await server.cleanup()
|
||||||
result = await server.call_tool(arguments=arguments)
|
|
||||||
return result
|
|
||||||
finally:
|
async def mcp_flow_get_tool_schema(params: MCPServerSseParams, ) -> Any:
|
||||||
await server.cleanup()
|
"""Async function to get tool schema from MCP server."""
|
||||||
|
async with MCPServerSse(params) as server:
|
||||||
|
tools = await server.list_tools()
|
||||||
async def mcp_flow_get_tool_schema(
|
return tools
|
||||||
params: MCPServerSseParams,
|
|
||||||
) -> Any:
|
|
||||||
"""Async function to get tool schema from MCP server."""
|
async def mcp_flow(
|
||||||
async with MCPServerSse(params) as server:
|
params: MCPServerSseParams,
|
||||||
tools = await server.list_tools()
|
function_call: Dict[str, Any] | str,
|
||||||
return tools
|
) -> Any:
|
||||||
|
"""Async function to call a tool with given parameters."""
|
||||||
|
async with MCPServerSse(params) as server:
|
||||||
async def mcp_flow(
|
return await call_tool_fast(server, function_call)
|
||||||
params: MCPServerSseParams,
|
|
||||||
function_call: Dict[str, Any] | str,
|
|
||||||
) -> Any:
|
async def _call_one_server(params: MCPServerSseParams,
|
||||||
"""Async function to call a tool with given parameters."""
|
payload: Dict[str, Any] | str) -> Any:
|
||||||
async with MCPServerSse(params) as server:
|
"""Helper function to call a single MCP server."""
|
||||||
return await call_tool_fast(server, function_call)
|
server = MCPServerSse(params)
|
||||||
|
try:
|
||||||
|
await server.connect()
|
||||||
async def _call_one_server(
|
arguments = payload if isinstance(payload, dict) else None
|
||||||
params: MCPServerSseParams, payload: Dict[str, Any] | str
|
return await server.call_tool(arguments=arguments)
|
||||||
) -> Any:
|
finally:
|
||||||
"""Helper function to call a single MCP server."""
|
await server.cleanup()
|
||||||
server = MCPServerSse(params)
|
|
||||||
try:
|
|
||||||
await server.connect()
|
async def abatch_mcp_flow(params: List[MCPServerSseParams],
|
||||||
arguments = payload if isinstance(payload, dict) else None
|
payload: Dict[str, Any] | str) -> List[Any]:
|
||||||
return await server.call_tool(arguments=arguments)
|
"""Async function to execute a batch of MCP calls concurrently.
|
||||||
finally:
|
|
||||||
await server.cleanup()
|
Args:
|
||||||
|
params (List[MCPServerSseParams]): List of MCP server configurations
|
||||||
|
payload (Dict[str, Any] | str): The payload to send to each server
|
||||||
async def abatch_mcp_flow(
|
|
||||||
params: List[MCPServerSseParams], payload: Dict[str, Any] | str
|
Returns:
|
||||||
) -> List[Any]:
|
List[Any]: Results from all MCP servers
|
||||||
"""Async function to execute a batch of MCP calls concurrently.
|
"""
|
||||||
|
if not params:
|
||||||
Args:
|
logger.warning("No MCP servers provided for batch operation")
|
||||||
params (List[MCPServerSseParams]): List of MCP server configurations
|
return []
|
||||||
payload (Dict[str, Any] | str): The payload to send to each server
|
|
||||||
|
try:
|
||||||
Returns:
|
return await asyncio.gather(
|
||||||
List[Any]: Results from all MCP servers
|
*[_call_one_server(p, payload) for p in params])
|
||||||
"""
|
except Exception as e:
|
||||||
if not params:
|
logger.error(f"Error in abatch_mcp_flow: {e}")
|
||||||
logger.warning("No MCP servers provided for batch operation")
|
# Return partial results if any were successful
|
||||||
return []
|
return [f"Error in batch operation: {str(e)}"]
|
||||||
|
|
||||||
try:
|
|
||||||
return await asyncio.gather(*[_call_one_server(p, payload) for p in params])
|
def batch_mcp_flow(params: List[MCPServerSseParams],
|
||||||
except Exception as e:
|
payload: Dict[str, Any] | str) -> List[Any]:
|
||||||
logger.error(f"Error in abatch_mcp_flow: {e}")
|
"""Sync wrapper for batch MCP operations.
|
||||||
# Return partial results if any were successful
|
|
||||||
return [f"Error in batch operation: {str(e)}"]
|
This creates a new event loop if needed to run the async batch operation.
|
||||||
|
ONLY use this when not already in an async context.
|
||||||
|
|
||||||
def batch_mcp_flow(
|
Args:
|
||||||
params: List[MCPServerSseParams], payload: Dict[str, Any] | str
|
params (List[MCPServerSseParams]): List of MCP server configurations
|
||||||
) -> List[Any]:
|
payload (Dict[str, Any] | str): The payload to send to each server
|
||||||
"""Sync wrapper for batch MCP operations.
|
|
||||||
|
Returns:
|
||||||
This creates a new event loop if needed to run the async batch operation.
|
List[Any]: Results from all MCP servers
|
||||||
ONLY use this when not already in an async context.
|
"""
|
||||||
|
if not params:
|
||||||
Args:
|
logger.warning("No MCP servers provided for batch operation")
|
||||||
params (List[MCPServerSseParams]): List of MCP server configurations
|
return []
|
||||||
payload (Dict[str, Any] | str): The payload to send to each server
|
|
||||||
|
try:
|
||||||
Returns:
|
# Check if we're already in an event loop
|
||||||
List[Any]: Results from all MCP servers
|
try:
|
||||||
"""
|
loop = asyncio.get_event_loop()
|
||||||
if not params:
|
except RuntimeError:
|
||||||
logger.warning("No MCP servers provided for batch operation")
|
# No event loop exists, create one
|
||||||
return []
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
try:
|
|
||||||
# Check if we're already in an event loop
|
if loop.is_running():
|
||||||
try:
|
# We're already in an async context, can't use asyncio.run
|
||||||
loop = asyncio.get_event_loop()
|
# Use a future to bridge sync-async gap
|
||||||
except RuntimeError:
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
# No event loop exists, create one
|
abatch_mcp_flow(params, payload), loop)
|
||||||
loop = asyncio.new_event_loop()
|
return future.result(timeout=30) # Add timeout to prevent hanging
|
||||||
asyncio.set_event_loop(loop)
|
else:
|
||||||
|
# We're not in an async context, safe to use loop.run_until_complete
|
||||||
if loop.is_running():
|
return loop.run_until_complete(abatch_mcp_flow(params, payload))
|
||||||
# We're already in an async context, can't use asyncio.run
|
except Exception as e:
|
||||||
# Use a future to bridge sync-async gap
|
logger.error(f"Error in batch_mcp_flow: {e}")
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
return [f"Error in batch operation: {str(e)}"]
|
||||||
abatch_mcp_flow(params, payload), loop
|
|
||||||
)
|
|
||||||
return future.result(timeout=30) # Add timeout to prevent hanging
|
|
||||||
else:
|
|
||||||
# We're not in an async context, safe to use loop.run_until_complete
|
|
||||||
return loop.run_until_complete(abatch_mcp_flow(params, payload))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in batch_mcp_flow: {e}")
|
|
||||||
return [f"Error in batch operation: {str(e)}"]
|
|
||||||
|
Loading…
Reference in new issue