|
|
|
@ -284,12 +284,16 @@ class MCPServerSse:
|
|
|
|
|
def __init__(self, params: MCPServerSseParams):
|
|
|
|
|
self.params = params
|
|
|
|
|
self.client: Optional[ClientSession] = None
|
|
|
|
|
self._connection_lock = asyncio.Lock()
|
|
|
|
|
self.messages = [] # Store messages instead of using conversation
|
|
|
|
|
self.preserve_format = True # Flag to preserve original formatting
|
|
|
|
|
|
|
|
|
|
async def connect(self):
|
|
|
|
|
"""Connect to the MCP server."""
|
|
|
|
|
if not self.client:
|
|
|
|
|
self.client = ClientSession()
|
|
|
|
|
await self.client.connect(self.create_streams())
|
|
|
|
|
"""Connect to the MCP server with proper locking."""
|
|
|
|
|
async with self._connection_lock:
|
|
|
|
|
if not self.client:
|
|
|
|
|
self.client = ClientSession()
|
|
|
|
|
await self.client.connect(self.create_streams())
|
|
|
|
|
|
|
|
|
|
def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]:
|
|
|
|
|
return sse_client(
|
|
|
|
@ -299,45 +303,159 @@ class MCPServerSse:
|
|
|
|
|
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def call_tool(self, payload: dict[str, Any]):
|
|
|
|
|
"""Call a tool on the MCP server."""
|
|
|
|
|
def _parse_input(self, payload: Any) -> dict:
|
|
|
|
|
"""Parse input while preserving original format."""
|
|
|
|
|
if isinstance(payload, dict):
|
|
|
|
|
return payload
|
|
|
|
|
|
|
|
|
|
if isinstance(payload, str):
|
|
|
|
|
try:
|
|
|
|
|
# Try to parse as JSON
|
|
|
|
|
import json
|
|
|
|
|
return json.loads(payload)
|
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
|
# Check if it's a math operation
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
# Pattern matching for basic math operations
|
|
|
|
|
add_pattern = r"(?i)(?:what\s+is\s+)?(\d+)\s*(?:plus|\+)\s*(\d+)"
|
|
|
|
|
mult_pattern = r"(?i)(?:multiply|times|\*)\s*(\d+)\s*(?:and|by)?\s*(\d+)"
|
|
|
|
|
div_pattern = r"(?i)(?:divide)\s*(\d+)\s*(?:by)\s*(\d+)"
|
|
|
|
|
|
|
|
|
|
# Check for addition
|
|
|
|
|
if match := re.search(add_pattern, payload):
|
|
|
|
|
a, b = map(int, match.groups())
|
|
|
|
|
return {"tool_name": "add", "a": a, "b": b}
|
|
|
|
|
|
|
|
|
|
# Check for multiplication
|
|
|
|
|
if match := re.search(mult_pattern, payload):
|
|
|
|
|
a, b = map(int, match.groups())
|
|
|
|
|
return {"tool_name": "multiply", "a": a, "b": b}
|
|
|
|
|
|
|
|
|
|
# Check for division
|
|
|
|
|
if match := re.search(div_pattern, payload):
|
|
|
|
|
a, b = map(int, match.groups())
|
|
|
|
|
return {"tool_name": "divide", "a": a, "b": b}
|
|
|
|
|
|
|
|
|
|
# Default to text input if no pattern matches
|
|
|
|
|
return {"text": payload}
|
|
|
|
|
|
|
|
|
|
return {"text": str(payload)}
|
|
|
|
|
|
|
|
|
|
def _format_output(self, result: Any, original_input: Any) -> str:
|
|
|
|
|
"""Format output based on input type and result."""
|
|
|
|
|
if not self.preserve_format:
|
|
|
|
|
return str(result)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if isinstance(result, (int, float)):
|
|
|
|
|
# For numeric results, format based on operation
|
|
|
|
|
if isinstance(original_input, dict):
|
|
|
|
|
tool_name = original_input.get("tool_name", "")
|
|
|
|
|
if tool_name == "add":
|
|
|
|
|
return f"{original_input['a']} + {original_input['b']} = {result}"
|
|
|
|
|
elif tool_name == "multiply":
|
|
|
|
|
return f"{original_input['a']} * {original_input['b']} = {result}"
|
|
|
|
|
elif tool_name == "divide":
|
|
|
|
|
return f"{original_input['a']} / {original_input['b']} = {result}"
|
|
|
|
|
return str(result)
|
|
|
|
|
elif isinstance(result, dict):
|
|
|
|
|
return json.dumps(result, indent=2)
|
|
|
|
|
else:
|
|
|
|
|
return str(result)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error formatting output: {e}")
|
|
|
|
|
return str(result)
|
|
|
|
|
|
|
|
|
|
async def call_tool(self, payload: Any) -> Any:
|
|
|
|
|
"""Call a tool on the MCP server with support for various input formats."""
|
|
|
|
|
if not self.client:
|
|
|
|
|
raise RuntimeError("Not connected to MCP server")
|
|
|
|
|
return await self.client.call_tool(payload)
|
|
|
|
|
|
|
|
|
|
async def cleanup(self):
|
|
|
|
|
"""Clean up the connection."""
|
|
|
|
|
if self.client:
|
|
|
|
|
await self.client.close()
|
|
|
|
|
self.client = None
|
|
|
|
|
# Store original input for formatting
|
|
|
|
|
original_input = payload
|
|
|
|
|
|
|
|
|
|
# Parse input
|
|
|
|
|
parsed_payload = self._parse_input(payload)
|
|
|
|
|
|
|
|
|
|
async def list_tools(self) -> list[Any]: # Added for compatibility
|
|
|
|
|
# Add message to history
|
|
|
|
|
self.messages.append({
|
|
|
|
|
"role": "user",
|
|
|
|
|
"content": str(payload),
|
|
|
|
|
"parsed": parsed_payload
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
result = await self.client.call_tool(parsed_payload)
|
|
|
|
|
formatted_result = self._format_output(result, original_input)
|
|
|
|
|
|
|
|
|
|
self.messages.append({
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
"content": formatted_result,
|
|
|
|
|
"raw_result": result
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return formatted_result
|
|
|
|
|
except Exception as e:
|
|
|
|
|
error_msg = f"Error calling tool: {str(e)}"
|
|
|
|
|
self.messages.append({
|
|
|
|
|
"role": "error",
|
|
|
|
|
"content": error_msg,
|
|
|
|
|
"original_input": payload
|
|
|
|
|
})
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
async def cleanup(self):
|
|
|
|
|
"""Clean up the connection with proper locking."""
|
|
|
|
|
async with self._connection_lock:
|
|
|
|
|
if self.client:
|
|
|
|
|
await self.client.close()
|
|
|
|
|
self.client = None
|
|
|
|
|
|
|
|
|
|
async def list_tools(self) -> list[Any]:
|
|
|
|
|
"""List available tools with proper error handling."""
|
|
|
|
|
if not self.client:
|
|
|
|
|
raise RuntimeError("Not connected to MCP server")
|
|
|
|
|
return await self.client.list_tools()
|
|
|
|
|
try:
|
|
|
|
|
return await self.client.list_tools()
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error listing tools: {e}")
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any]):
|
|
|
|
|
async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any] | str):
|
|
|
|
|
"""
|
|
|
|
|
Convenience wrapper that opens → calls → closes in one shot.
|
|
|
|
|
Convenience wrapper that opens → calls → closes in one shot with proper error handling.
|
|
|
|
|
"""
|
|
|
|
|
await server.connect()
|
|
|
|
|
result = await server.call_tool(payload)
|
|
|
|
|
await server.cleanup()
|
|
|
|
|
return result.model_dump() if hasattr(result, "model_dump") else result
|
|
|
|
|
try:
|
|
|
|
|
await server.connect()
|
|
|
|
|
result = await server.call_tool(payload)
|
|
|
|
|
return result.model_dump() if hasattr(result, "model_dump") else result
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error in call_tool_fast: {e}")
|
|
|
|
|
raise
|
|
|
|
|
finally:
|
|
|
|
|
await server.cleanup()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def mcp_flow_get_tool_schema(
|
|
|
|
|
params: MCPServerSseParams,
|
|
|
|
|
) -> Any: # Updated return type to Any
|
|
|
|
|
async with MCPServerSse(params) as server:
|
|
|
|
|
return (await server.list_tools()).model_dump()
|
|
|
|
|
) -> Any:
|
|
|
|
|
"""Get tool schema with proper error handling."""
|
|
|
|
|
try:
|
|
|
|
|
async with MCPServerSse(params) as server:
|
|
|
|
|
tools = await server.list_tools()
|
|
|
|
|
return tools.model_dump() if hasattr(tools, "model_dump") else tools
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error getting tool schema: {e}")
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def mcp_flow(
|
|
|
|
|
params: MCPServerSseParams,
|
|
|
|
|
function_call: dict[str, Any],
|
|
|
|
|
) -> Any: # Updated return type to Any
|
|
|
|
|
function_call: dict[str, Any] | str,
|
|
|
|
|
) -> Any:
|
|
|
|
|
"""Execute MCP flow with proper error handling."""
|
|
|
|
|
try:
|
|
|
|
|
async with MCPServerSse(params) as server:
|
|
|
|
|
return await call_tool_fast(server, function_call)
|
|
|
|
@ -346,28 +464,40 @@ async def mcp_flow(
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Helper function to call one MCP server
|
|
|
|
|
async def _call_one_server(param: MCPServerSseParams, payload: dict[str, Any]) -> Any:
|
|
|
|
|
async def _call_one_server(param: MCPServerSseParams, payload: dict[str, Any] | str) -> Any:
|
|
|
|
|
"""Make a call to a single MCP server with proper async context management."""
|
|
|
|
|
async with MCPServerSse(param, cache_tools_list=True) as srv:
|
|
|
|
|
res = await srv.call_tool(payload)
|
|
|
|
|
try:
|
|
|
|
|
return res.model_dump() # For fast-mcp ≥0.2
|
|
|
|
|
except AttributeError:
|
|
|
|
|
return res # Plain dict or string
|
|
|
|
|
try:
|
|
|
|
|
server = MCPServerSse(param)
|
|
|
|
|
await server.connect()
|
|
|
|
|
result = await server.call_tool(payload)
|
|
|
|
|
return result
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error calling server: {e}")
|
|
|
|
|
raise
|
|
|
|
|
finally:
|
|
|
|
|
if 'server' in locals():
|
|
|
|
|
await server.cleanup()
|
|
|
|
|
|
|
|
|
|
# Synchronous wrapper for the Agent to use
|
|
|
|
|
def batch_mcp_flow(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]:
|
|
|
|
|
|
|
|
|
|
def batch_mcp_flow(params: List[MCPServerSseParams], payload: dict[str, Any] | str) -> List[Any]:
|
|
|
|
|
"""Blocking helper that fans out to all MCP servers in params."""
|
|
|
|
|
return asyncio.run(_batch(params, payload))
|
|
|
|
|
try:
|
|
|
|
|
return asyncio.run(_batch(params, payload))
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error in batch_mcp_flow: {e}")
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Async implementation of batch processing
|
|
|
|
|
async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]:
|
|
|
|
|
async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any] | str) -> List[Any]:
|
|
|
|
|
"""Fan out to all MCP servers asynchronously and gather results."""
|
|
|
|
|
coros = [_call_one_server(p, payload) for p in params]
|
|
|
|
|
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
|
|
|
# Filter out exceptions and convert to strings
|
|
|
|
|
return [any_to_str(r) for r in results if not isinstance(r, Exception)]
|
|
|
|
|
try:
|
|
|
|
|
coros = [_call_one_server(p, payload) for p in params]
|
|
|
|
|
results = await asyncio.gather(*coros, return_exceptions=True)
|
|
|
|
|
# Filter out exceptions and convert to strings
|
|
|
|
|
return [any_to_str(r) for r in results if not isinstance(r, Exception)]
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error in batch processing: {e}")
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from mcp import (
|
|
|
|
|