Update mcp_unified_client.py

pull/1005/head
CI-DEV 2 months ago committed by GitHub
parent 08ce3469f2
commit dffcba52b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -35,7 +35,7 @@ from loguru import logger
from pydantic import BaseModel, Field
# Import existing MCP functionality
from swarms.schemas.mcp_schemas import MCPConnection
from swarms.schemas.mcp_schemas import MCPConnection, UnifiedTransportConfig
from swarms.tools.mcp_client_call import (
MCPConnectionError,
MCPExecutionError,
@ -74,79 +74,6 @@ except ImportError:
HTTPX_AVAILABLE = False
class UnifiedTransportConfig(BaseModel):
"""
Unified configuration for MCP transport types.
This extends the existing MCPConnection schema with additional
transport-specific options and auto-detection capabilities.
Includes streaming support for real-time communication.
"""
# Transport type - can be auto-detected
transport_type: Literal["stdio", "http", "streamable_http", "sse", "auto"] = Field(
default="auto",
description="The transport type to use. 'auto' enables auto-detection."
)
# Connection details
url: Optional[str] = Field(
default=None,
description="URL for HTTP-based transports or stdio command path"
)
# STDIO specific
command: Optional[List[str]] = Field(
default=None,
description="Command and arguments for stdio transport"
)
# HTTP specific
headers: Optional[Dict[str, str]] = Field(
default=None,
description="HTTP headers for HTTP-based transports"
)
# Common settings
timeout: int = Field(
default=30,
description="Timeout in seconds"
)
authorization_token: Optional[str] = Field(
default=None,
description="Authentication token for accessing the MCP server"
)
# Auto-detection settings
auto_detect: bool = Field(
default=True,
description="Whether to auto-detect transport type from URL"
)
# Fallback settings
fallback_transport: Literal["stdio", "http", "streamable_http", "sse"] = Field(
default="sse",
description="Fallback transport if auto-detection fails"
)
# Streaming settings
enable_streaming: bool = Field(
default=True,
description="Whether to enable streaming support"
)
streaming_timeout: Optional[int] = Field(
default=None,
description="Timeout for streaming operations"
)
streaming_callback: Optional[Callable[[str], None]] = Field(
default=None,
description="Optional callback function for streaming chunks"
)
class MCPUnifiedClient:
"""
Unified MCP client that supports multiple transport types.
@ -621,6 +548,168 @@ async def aexecute_tool_call_streaming_unified(
yield result
# Function that matches the Agent class expectations
def call_tool_streaming_sync(
response: Any,
server_path: Optional[str] = None,
connection: Optional[MCPConnection] = None,
config: Optional[UnifiedTransportConfig] = None
) -> List[Dict[str, Any]]:
"""
Call a tool with streaming support - matches Agent class expectations.
Args:
response: The response from the LLM (may contain tool calls)
server_path: MCP server path/URL
connection: MCP connection object
config: Transport configuration
Returns:
List of streaming tool execution results
"""
try:
# Determine the configuration to use
if config is not None:
transport_config = config
elif connection is not None:
transport_config = UnifiedTransportConfig(
transport_type=connection.transport or "auto",
url=connection.url,
headers=connection.headers,
timeout=connection.timeout or 30,
authorization_token=connection.authorization_token,
auto_detect=True,
enable_streaming=True
)
elif server_path is not None:
transport_config = UnifiedTransportConfig(
url=server_path,
transport_type="auto",
auto_detect=True,
enable_streaming=True
)
else:
raise ValueError("Either server_path, connection, or config must be provided")
# Extract tool calls from response if it's a string
if isinstance(response, str):
tool_calls = _extract_tool_calls_from_response(response)
else:
tool_calls = [{"name": "default_tool", "arguments": {}}]
# Execute each tool call with streaming
all_results = []
client = MCPUnifiedClient(transport_config)
for tool_call in tool_calls:
tool_name = tool_call.get("name", "default_tool")
arguments = tool_call.get("arguments", {})
try:
results = client.call_tool_streaming_sync(tool_name, arguments)
all_results.extend(results)
except Exception as e:
logger.error(f"Error calling tool {tool_name}: {e}")
# Add error result
all_results.append({
"error": str(e),
"tool_name": tool_name,
"arguments": arguments
})
return all_results
except Exception as e:
logger.error(f"Error in call_tool_streaming_sync: {e}")
return [{"error": str(e)}]
def _extract_tool_calls_from_response(response: str) -> List[Dict[str, Any]]:
"""
Extract tool calls from LLM response.
Args:
response: The response string from the LLM
Returns:
List of tool call dictionaries
"""
import re
import json
tool_calls = []
try:
# Try to find JSON tool calls
json_match = re.search(r'```json\s*(\{.*?\})\s*```', response, re.DOTALL)
if json_match:
try:
tool_data = json.loads(json_match.group(1))
# Check for tool_uses format
if "tool_uses" in tool_data and tool_data["tool_uses"]:
for tool_call in tool_data["tool_uses"]:
if "recipient_name" in tool_call:
tool_name = tool_call["recipient_name"]
arguments = tool_call.get("parameters", {})
tool_calls.append({
"name": tool_name,
"arguments": arguments
})
# Check for direct tool call format
elif "name" in tool_data and "arguments" in tool_data:
tool_calls.append({
"name": tool_data["name"],
"arguments": tool_data["arguments"]
})
except json.JSONDecodeError:
pass
# If no JSON found, try to extract from text
if not tool_calls:
# Look for common tool patterns
response_lower = response.lower()
if "calculate" in response_lower or "compute" in response_lower:
# Extract mathematical expression
expr_match = re.search(r'(\d+\s*[\+\-\*\/]\s*\d+)', response)
if expr_match:
tool_calls.append({
"name": "calculate",
"arguments": {"expression": expr_match.group(1)}
})
else:
tool_calls.append({
"name": "calculate",
"arguments": {"expression": "2+2"}
})
elif "search" in response_lower or "find" in response_lower:
tool_calls.append({
"name": "search",
"arguments": {"query": response.strip()}
})
else:
# Default tool call
tool_calls.append({
"name": "default_tool",
"arguments": {"input": response.strip()}
})
except Exception as e:
logger.error(f"Error extracting tool calls: {e}")
# Return default tool call
tool_calls.append({
"name": "default_tool",
"arguments": {"input": response.strip()}
})
return tool_calls
# Helper functions for creating configurations
def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig:
"""

Loading…
Cancel
Save