Update MCP integration: Clean up server and client code, improve error handling, and simplify prompt

pull/819/head
ascender1729 3 months ago
parent b398753c72
commit eb9d337b45

@ -1,26 +1,49 @@
from fastmcp import FastMCP from fastmcp import FastMCP
from typing import Dict, Any
import asyncio
from loguru import logger
mcp = FastMCP("Math-Mock-Server") # Create FastMCP instance with SSE transport
mcp = FastMCP(
host="0.0.0.0",
port=8000,
require_session_id=False,
transport="sse" # Explicitly specify SSE transport
)
@mcp.tool() @mcp.tool()
def add(a: int, b: int) -> int: def add(a: int, b: int) -> int:
"""Add two numbers together""" """Add two numbers."""
return a + b return a + b
@mcp.tool() @mcp.tool()
def multiply(a: int, b: int) -> int: def multiply(a: int, b: int) -> int:
"""Multiply two numbers together""" """Multiply two numbers."""
return a * b return a * b
@mcp.tool() @mcp.tool()
def divide(a: int, b: int) -> float: def divide(a: int, b: int) -> float:
"""Divide two numbers""" """Divide two numbers."""
if b == 0: if b == 0:
return {"error": "Cannot divide by zero"} raise ValueError("Cannot divide by zero")
return a / b return a / b
async def run_server():
"""Run the server with proper error handling."""
try:
logger.info("Starting math server on http://0.0.0.0:8000")
await mcp.run_async()
except Exception as e:
logger.error(f"Server error: {e}")
raise
finally:
await mcp.cleanup()
if __name__ == "__main__": if __name__ == "__main__":
print("Starting Mock Math Server on port 8000...") try:
# FastMCP expects transport_kwargs as separate parameters asyncio.run(run_server())
mcp.run(transport="sse", host="0.0.0.0", port=8000) except KeyboardInterrupt:
logger.info("Server stopped by user")
except Exception as e:
logger.error(f"Fatal error: {e}")
raise

@ -284,12 +284,16 @@ class MCPServerSse:
def __init__(self, params: MCPServerSseParams): def __init__(self, params: MCPServerSseParams):
self.params = params self.params = params
self.client: Optional[ClientSession] = None self.client: Optional[ClientSession] = None
self._connection_lock = asyncio.Lock()
self.messages = [] # Store messages instead of using conversation
self.preserve_format = True # Flag to preserve original formatting
async def connect(self): async def connect(self):
"""Connect to the MCP server.""" """Connect to the MCP server with proper locking."""
if not self.client: async with self._connection_lock:
self.client = ClientSession() if not self.client:
await self.client.connect(self.create_streams()) self.client = ClientSession()
await self.client.connect(self.create_streams())
def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]: def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]:
return sse_client( return sse_client(
@ -299,45 +303,159 @@ class MCPServerSse:
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5), sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
) )
async def call_tool(self, payload: dict[str, Any]): def _parse_input(self, payload: Any) -> dict:
"""Call a tool on the MCP server.""" """Parse input while preserving original format."""
if isinstance(payload, dict):
return payload
if isinstance(payload, str):
try:
# Try to parse as JSON
import json
return json.loads(payload)
except json.JSONDecodeError:
# Check if it's a math operation
import re
# Pattern matching for basic math operations
add_pattern = r"(?i)(?:what\s+is\s+)?(\d+)\s*(?:plus|\+)\s*(\d+)"
mult_pattern = r"(?i)(?:multiply|times|\*)\s*(\d+)\s*(?:and|by)?\s*(\d+)"
div_pattern = r"(?i)(?:divide)\s*(\d+)\s*(?:by)\s*(\d+)"
# Check for addition
if match := re.search(add_pattern, payload):
a, b = map(int, match.groups())
return {"tool_name": "add", "a": a, "b": b}
# Check for multiplication
if match := re.search(mult_pattern, payload):
a, b = map(int, match.groups())
return {"tool_name": "multiply", "a": a, "b": b}
# Check for division
if match := re.search(div_pattern, payload):
a, b = map(int, match.groups())
return {"tool_name": "divide", "a": a, "b": b}
# Default to text input if no pattern matches
return {"text": payload}
return {"text": str(payload)}
def _format_output(self, result: Any, original_input: Any) -> str:
"""Format output based on input type and result."""
if not self.preserve_format:
return str(result)
try:
if isinstance(result, (int, float)):
# For numeric results, format based on operation
if isinstance(original_input, dict):
tool_name = original_input.get("tool_name", "")
if tool_name == "add":
return f"{original_input['a']} + {original_input['b']} = {result}"
elif tool_name == "multiply":
return f"{original_input['a']} * {original_input['b']} = {result}"
elif tool_name == "divide":
return f"{original_input['a']} / {original_input['b']} = {result}"
return str(result)
elif isinstance(result, dict):
return json.dumps(result, indent=2)
else:
return str(result)
except Exception as e:
logger.error(f"Error formatting output: {e}")
return str(result)
async def call_tool(self, payload: Any) -> Any:
"""Call a tool on the MCP server with support for various input formats."""
if not self.client: if not self.client:
raise RuntimeError("Not connected to MCP server") raise RuntimeError("Not connected to MCP server")
return await self.client.call_tool(payload)
async def cleanup(self): # Store original input for formatting
"""Clean up the connection.""" original_input = payload
if self.client:
await self.client.close() # Parse input
self.client = None parsed_payload = self._parse_input(payload)
async def list_tools(self) -> list[Any]: # Added for compatibility # Add message to history
self.messages.append({
"role": "user",
"content": str(payload),
"parsed": parsed_payload
})
try:
result = await self.client.call_tool(parsed_payload)
formatted_result = self._format_output(result, original_input)
self.messages.append({
"role": "assistant",
"content": formatted_result,
"raw_result": result
})
return formatted_result
except Exception as e:
error_msg = f"Error calling tool: {str(e)}"
self.messages.append({
"role": "error",
"content": error_msg,
"original_input": payload
})
raise
async def cleanup(self):
"""Clean up the connection with proper locking."""
async with self._connection_lock:
if self.client:
await self.client.close()
self.client = None
async def list_tools(self) -> list[Any]:
"""List available tools with proper error handling."""
if not self.client: if not self.client:
raise RuntimeError("Not connected to MCP server") raise RuntimeError("Not connected to MCP server")
return await self.client.list_tools() try:
return await self.client.list_tools()
except Exception as e:
logger.error(f"Error listing tools: {e}")
return []
async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any]): async def call_tool_fast(server: MCPServerSse, payload: dict[str, Any] | str):
""" """
Convenience wrapper that opens calls closes in one shot. Convenience wrapper that opens calls closes in one shot with proper error handling.
""" """
await server.connect() try:
result = await server.call_tool(payload) await server.connect()
await server.cleanup() result = await server.call_tool(payload)
return result.model_dump() if hasattr(result, "model_dump") else result return result.model_dump() if hasattr(result, "model_dump") else result
except Exception as e:
logger.error(f"Error in call_tool_fast: {e}")
raise
finally:
await server.cleanup()
async def mcp_flow_get_tool_schema( async def mcp_flow_get_tool_schema(
params: MCPServerSseParams, params: MCPServerSseParams,
) -> Any: # Updated return type to Any ) -> Any:
async with MCPServerSse(params) as server: """Get tool schema with proper error handling."""
return (await server.list_tools()).model_dump() try:
async with MCPServerSse(params) as server:
tools = await server.list_tools()
return tools.model_dump() if hasattr(tools, "model_dump") else tools
except Exception as e:
logger.error(f"Error getting tool schema: {e}")
raise
async def mcp_flow( async def mcp_flow(
params: MCPServerSseParams, params: MCPServerSseParams,
function_call: dict[str, Any], function_call: dict[str, Any] | str,
) -> Any: # Updated return type to Any ) -> Any:
"""Execute MCP flow with proper error handling."""
try: try:
async with MCPServerSse(params) as server: async with MCPServerSse(params) as server:
return await call_tool_fast(server, function_call) return await call_tool_fast(server, function_call)
@ -346,28 +464,40 @@ async def mcp_flow(
raise raise
# Helper function to call one MCP server async def _call_one_server(param: MCPServerSseParams, payload: dict[str, Any] | str) -> Any:
async def _call_one_server(param: MCPServerSseParams, payload: dict[str, Any]) -> Any:
"""Make a call to a single MCP server with proper async context management.""" """Make a call to a single MCP server with proper async context management."""
async with MCPServerSse(param, cache_tools_list=True) as srv: try:
res = await srv.call_tool(payload) server = MCPServerSse(param)
try: await server.connect()
return res.model_dump() # For fast-mcp ≥0.2 result = await server.call_tool(payload)
except AttributeError: return result
return res # Plain dict or string except Exception as e:
logger.error(f"Error calling server: {e}")
raise
finally:
if 'server' in locals():
await server.cleanup()
# Synchronous wrapper for the Agent to use
def batch_mcp_flow(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]: def batch_mcp_flow(params: List[MCPServerSseParams], payload: dict[str, Any] | str) -> List[Any]:
"""Blocking helper that fans out to all MCP servers in params.""" """Blocking helper that fans out to all MCP servers in params."""
return asyncio.run(_batch(params, payload)) try:
return asyncio.run(_batch(params, payload))
except Exception as e:
logger.error(f"Error in batch_mcp_flow: {e}")
return []
# Async implementation of batch processing async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any] | str) -> List[Any]:
async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]:
"""Fan out to all MCP servers asynchronously and gather results.""" """Fan out to all MCP servers asynchronously and gather results."""
coros = [_call_one_server(p, payload) for p in params] try:
results = await asyncio.gather(*coros, return_exceptions=True) coros = [_call_one_server(p, payload) for p in params]
# Filter out exceptions and convert to strings results = await asyncio.gather(*coros, return_exceptions=True)
return [any_to_str(r) for r in results if not isinstance(r, Exception)] # Filter out exceptions and convert to strings
return [any_to_str(r) for r in results if not isinstance(r, Exception)]
except Exception as e:
logger.error(f"Error in batch processing: {e}")
return []
from mcp import ( from mcp import (

Loading…
Cancel
Save