|
|
|
@ -297,8 +297,10 @@ class MCPServerSse:
|
|
|
|
|
"""Connect to the MCP server with proper locking."""
|
|
|
|
|
async with self._connection_lock:
|
|
|
|
|
if not self.client:
|
|
|
|
|
self.client = ClientSession()
|
|
|
|
|
await self.client.connect(self.create_streams())
|
|
|
|
|
transport = await self.create_streams()
|
|
|
|
|
read_stream, write_stream = transport
|
|
|
|
|
self.client = ClientSession(read_stream=read_stream, write_stream=write_stream)
|
|
|
|
|
await self.client.initialize()
|
|
|
|
|
|
|
|
|
|
def create_streams(self, **kwargs) -> AbstractAsyncContextManager[Any]:
|
|
|
|
|
return sse_client(
|
|
|
|
@ -312,7 +314,7 @@ class MCPServerSse:
|
|
|
|
|
"""Parse input while preserving original format."""
|
|
|
|
|
if isinstance(payload, dict):
|
|
|
|
|
return payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(payload, str):
|
|
|
|
|
try:
|
|
|
|
|
# Try to parse as JSON
|
|
|
|
@ -321,37 +323,37 @@ class MCPServerSse:
|
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
|
# Check if it's a math operation
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pattern matching for basic math operations
|
|
|
|
|
add_pattern = r"(?i)(?:what\s+is\s+)?(\d+)\s*(?:plus|\+)\s*(\d+)"
|
|
|
|
|
mult_pattern = r"(?i)(?:multiply|times|\*)\s*(\d+)\s*(?:and|by)?\s*(\d+)"
|
|
|
|
|
div_pattern = r"(?i)(?:divide)\s*(\d+)\s*(?:by)\s*(\d+)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check for addition
|
|
|
|
|
if match := re.search(add_pattern, payload):
|
|
|
|
|
a, b = map(int, match.groups())
|
|
|
|
|
return {"tool_name": "add", "a": a, "b": b}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check for multiplication
|
|
|
|
|
if match := re.search(mult_pattern, payload):
|
|
|
|
|
a, b = map(int, match.groups())
|
|
|
|
|
return {"tool_name": "multiply", "a": a, "b": b}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check for division
|
|
|
|
|
if match := re.search(div_pattern, payload):
|
|
|
|
|
a, b = map(int, match.groups())
|
|
|
|
|
return {"tool_name": "divide", "a": a, "b": b}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default to text input if no pattern matches
|
|
|
|
|
return {"text": payload}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {"text": str(payload)}
|
|
|
|
|
|
|
|
|
|
def _format_output(self, result: Any, original_input: Any) -> str:
|
|
|
|
|
"""Format output based on input type and result."""
|
|
|
|
|
if not self.preserve_format:
|
|
|
|
|
return str(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if isinstance(result, (int, float)):
|
|
|
|
|
# For numeric results, format based on operation
|
|
|
|
@ -376,30 +378,30 @@ class MCPServerSse:
|
|
|
|
|
"""Call a tool on the MCP server with support for various input formats."""
|
|
|
|
|
if not self.client:
|
|
|
|
|
raise RuntimeError("Not connected to MCP server")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Store original input for formatting
|
|
|
|
|
original_input = payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Parse input
|
|
|
|
|
parsed_payload = self._parse_input(payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Add message to history
|
|
|
|
|
self.messages.append({
|
|
|
|
|
"role": "user",
|
|
|
|
|
"content": str(payload),
|
|
|
|
|
"parsed": parsed_payload
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
result = await self.client.call_tool(parsed_payload)
|
|
|
|
|
formatted_result = self._format_output(result, original_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.messages.append({
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
"content": formatted_result,
|
|
|
|
|
"raw_result": result
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return formatted_result
|
|
|
|
|
except Exception as e:
|
|
|
|
|
error_msg = f"Error calling tool: {str(e)}"
|
|
|
|
@ -502,6 +504,4 @@ async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any] | str
|
|
|
|
|
return [any_to_str(r) for r in results if not isinstance(r, Exception)]
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error in batch processing: {e}")
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return []
|