fix(mcp): resolve client initialization and update server configuration in mcp integration

pull/819/head
Pavan Kumar 3 months ago committed by ascender1729
parent d46da9c8bb
commit 925709de6e

@ -5,10 +5,11 @@ from loguru import logger
# Create FastMCP instance with SSE transport # Create FastMCP instance with SSE transport
mcp = FastMCP( mcp = FastMCP(
host="0.0.0.0", host="0.0.0.0",
port=8000, port=8000,
transport="sse",
require_session_id=False, require_session_id=False,
transport="sse" # Explicitly specify SSE transport timeout=30.0
) )
@mcp.tool() @mcp.tool()

@ -43,6 +43,7 @@ async def _execute_mcp_tool(
method: Literal["stdio", "sse"] = "sse", method: Literal["stdio", "sse"] = "sse",
parameters: Dict[Any, Any] = None, parameters: Dict[Any, Any] = None,
output_type: Literal["str", "dict"] = "str", output_type: Literal["str", "dict"] = "str",
timeout: float = 30.0,
*args, *args,
**kwargs, **kwargs,
) -> Dict[Any, Any]: ) -> Dict[Any, Any]:

@ -297,8 +297,10 @@ class MCPServerSse:
"""Connect to the MCP server with proper locking.""" """Connect to the MCP server with proper locking."""
async with self._connection_lock: async with self._connection_lock:
if not self.client: if not self.client:
self.client = ClientSession() transport = await self.create_streams()
await self.client.connect(self.create_streams()) read_stream, write_stream = transport
self.client = ClientSession(read_stream=read_stream, write_stream=write_stream)
await self.client.initialize()
def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]: def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]:
return sse_client( return sse_client(
@ -312,7 +314,7 @@ class MCPServerSse:
"""Parse input while preserving original format.""" """Parse input while preserving original format."""
if isinstance(payload, dict): if isinstance(payload, dict):
return payload return payload
if isinstance(payload, str): if isinstance(payload, str):
try: try:
# Try to parse as JSON # Try to parse as JSON
@ -321,37 +323,37 @@ class MCPServerSse:
except json.JSONDecodeError: except json.JSONDecodeError:
# Check if it's a math operation # Check if it's a math operation
import re import re
# Pattern matching for basic math operations # Pattern matching for basic math operations
add_pattern = r"(?i)(?:what\s+is\s+)?(\d+)\s*(?:plus|\+)\s*(\d+)" add_pattern = r"(?i)(?:what\s+is\s+)?(\d+)\s*(?:plus|\+)\s*(\d+)"
mult_pattern = r"(?i)(?:multiply|times|\*)\s*(\d+)\s*(?:and|by)?\s*(\d+)" mult_pattern = r"(?i)(?:multiply|times|\*)\s*(\d+)\s*(?:and|by)?\s*(\d+)"
div_pattern = r"(?i)(?:divide)\s*(\d+)\s*(?:by)\s*(\d+)" div_pattern = r"(?i)(?:divide)\s*(\d+)\s*(?:by)\s*(\d+)"
# Check for addition # Check for addition
if match := re.search(add_pattern, payload): if match := re.search(add_pattern, payload):
a, b = map(int, match.groups()) a, b = map(int, match.groups())
return {"tool_name": "add", "a": a, "b": b} return {"tool_name": "add", "a": a, "b": b}
# Check for multiplication # Check for multiplication
if match := re.search(mult_pattern, payload): if match := re.search(mult_pattern, payload):
a, b = map(int, match.groups()) a, b = map(int, match.groups())
return {"tool_name": "multiply", "a": a, "b": b} return {"tool_name": "multiply", "a": a, "b": b}
# Check for division # Check for division
if match := re.search(div_pattern, payload): if match := re.search(div_pattern, payload):
a, b = map(int, match.groups()) a, b = map(int, match.groups())
return {"tool_name": "divide", "a": a, "b": b} return {"tool_name": "divide", "a": a, "b": b}
# Default to text input if no pattern matches # Default to text input if no pattern matches
return {"text": payload} return {"text": payload}
return {"text": str(payload)} return {"text": str(payload)}
def _format_output(self, result: Any, original_input: Any) -> str: def _format_output(self, result: Any, original_input: Any) -> str:
"""Format output based on input type and result.""" """Format output based on input type and result."""
if not self.preserve_format: if not self.preserve_format:
return str(result) return str(result)
try: try:
if isinstance(result, (int, float)): if isinstance(result, (int, float)):
# For numeric results, format based on operation # For numeric results, format based on operation
@ -376,30 +378,30 @@ class MCPServerSse:
"""Call a tool on the MCP server with support for various input formats.""" """Call a tool on the MCP server with support for various input formats."""
if not self.client: if not self.client:
raise RuntimeError("Not connected to MCP server") raise RuntimeError("Not connected to MCP server")
# Store original input for formatting # Store original input for formatting
original_input = payload original_input = payload
# Parse input # Parse input
parsed_payload = self._parse_input(payload) parsed_payload = self._parse_input(payload)
# Add message to history # Add message to history
self.messages.append({ self.messages.append({
"role": "user", "role": "user",
"content": str(payload), "content": str(payload),
"parsed": parsed_payload "parsed": parsed_payload
}) })
try: try:
result = await self.client.call_tool(parsed_payload) result = await self.client.call_tool(parsed_payload)
formatted_result = self._format_output(result, original_input) formatted_result = self._format_output(result, original_input)
self.messages.append({ self.messages.append({
"role": "assistant", "role": "assistant",
"content": formatted_result, "content": formatted_result,
"raw_result": result "raw_result": result
}) })
return formatted_result return formatted_result
except Exception as e: except Exception as e:
error_msg = f"Error calling tool: {str(e)}" error_msg = f"Error calling tool: {str(e)}"
@ -502,6 +504,4 @@ async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any] | str
return [any_to_str(r) for r in results if not isinstance(r, Exception)] return [any_to_str(r) for r in results if not isinstance(r, Exception)]
except Exception as e: except Exception as e:
logger.error(f"Error in batch processing: {e}") logger.error(f"Error in batch processing: {e}")
return [] return []
Loading…
Cancel
Save