|
|
|
@ -35,7 +35,7 @@ from loguru import logger
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
# Import existing MCP functionality
|
|
|
|
|
from swarms.schemas.mcp_schemas import MCPConnection
|
|
|
|
|
from swarms.schemas.mcp_schemas import MCPConnection, UnifiedTransportConfig
|
|
|
|
|
from swarms.tools.mcp_client_call import (
|
|
|
|
|
MCPConnectionError,
|
|
|
|
|
MCPExecutionError,
|
|
|
|
@ -74,79 +74,6 @@ except ImportError:
|
|
|
|
|
HTTPX_AVAILABLE = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnifiedTransportConfig(BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
Unified configuration for MCP transport types.
|
|
|
|
|
|
|
|
|
|
This extends the existing MCPConnection schema with additional
|
|
|
|
|
transport-specific options and auto-detection capabilities.
|
|
|
|
|
Includes streaming support for real-time communication.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# Transport type - can be auto-detected
|
|
|
|
|
transport_type: Literal["stdio", "http", "streamable_http", "sse", "auto"] = Field(
|
|
|
|
|
default="auto",
|
|
|
|
|
description="The transport type to use. 'auto' enables auto-detection."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Connection details
|
|
|
|
|
url: Optional[str] = Field(
|
|
|
|
|
default=None,
|
|
|
|
|
description="URL for HTTP-based transports or stdio command path"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# STDIO specific
|
|
|
|
|
command: Optional[List[str]] = Field(
|
|
|
|
|
default=None,
|
|
|
|
|
description="Command and arguments for stdio transport"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# HTTP specific
|
|
|
|
|
headers: Optional[Dict[str, str]] = Field(
|
|
|
|
|
default=None,
|
|
|
|
|
description="HTTP headers for HTTP-based transports"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Common settings
|
|
|
|
|
timeout: int = Field(
|
|
|
|
|
default=30,
|
|
|
|
|
description="Timeout in seconds"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
authorization_token: Optional[str] = Field(
|
|
|
|
|
default=None,
|
|
|
|
|
description="Authentication token for accessing the MCP server"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Auto-detection settings
|
|
|
|
|
auto_detect: bool = Field(
|
|
|
|
|
default=True,
|
|
|
|
|
description="Whether to auto-detect transport type from URL"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Fallback settings
|
|
|
|
|
fallback_transport: Literal["stdio", "http", "streamable_http", "sse"] = Field(
|
|
|
|
|
default="sse",
|
|
|
|
|
description="Fallback transport if auto-detection fails"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Streaming settings
|
|
|
|
|
enable_streaming: bool = Field(
|
|
|
|
|
default=True,
|
|
|
|
|
description="Whether to enable streaming support"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
streaming_timeout: Optional[int] = Field(
|
|
|
|
|
default=None,
|
|
|
|
|
description="Timeout for streaming operations"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
streaming_callback: Optional[Callable[[str], None]] = Field(
|
|
|
|
|
default=None,
|
|
|
|
|
description="Optional callback function for streaming chunks"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MCPUnifiedClient:
|
|
|
|
|
"""
|
|
|
|
|
Unified MCP client that supports multiple transport types.
|
|
|
|
@ -621,6 +548,168 @@ async def aexecute_tool_call_streaming_unified(
|
|
|
|
|
yield result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Function that matches the Agent class expectations
|
|
|
|
|
def call_tool_streaming_sync(
|
|
|
|
|
response: Any,
|
|
|
|
|
server_path: Optional[str] = None,
|
|
|
|
|
connection: Optional[MCPConnection] = None,
|
|
|
|
|
config: Optional[UnifiedTransportConfig] = None
|
|
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
|
|
"""
|
|
|
|
|
Call a tool with streaming support - matches Agent class expectations.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
response: The response from the LLM (may contain tool calls)
|
|
|
|
|
server_path: MCP server path/URL
|
|
|
|
|
connection: MCP connection object
|
|
|
|
|
config: Transport configuration
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List of streaming tool execution results
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
# Determine the configuration to use
|
|
|
|
|
if config is not None:
|
|
|
|
|
transport_config = config
|
|
|
|
|
elif connection is not None:
|
|
|
|
|
transport_config = UnifiedTransportConfig(
|
|
|
|
|
transport_type=connection.transport or "auto",
|
|
|
|
|
url=connection.url,
|
|
|
|
|
headers=connection.headers,
|
|
|
|
|
timeout=connection.timeout or 30,
|
|
|
|
|
authorization_token=connection.authorization_token,
|
|
|
|
|
auto_detect=True,
|
|
|
|
|
enable_streaming=True
|
|
|
|
|
)
|
|
|
|
|
elif server_path is not None:
|
|
|
|
|
transport_config = UnifiedTransportConfig(
|
|
|
|
|
url=server_path,
|
|
|
|
|
transport_type="auto",
|
|
|
|
|
auto_detect=True,
|
|
|
|
|
enable_streaming=True
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Either server_path, connection, or config must be provided")
|
|
|
|
|
|
|
|
|
|
# Extract tool calls from response if it's a string
|
|
|
|
|
if isinstance(response, str):
|
|
|
|
|
tool_calls = _extract_tool_calls_from_response(response)
|
|
|
|
|
else:
|
|
|
|
|
tool_calls = [{"name": "default_tool", "arguments": {}}]
|
|
|
|
|
|
|
|
|
|
# Execute each tool call with streaming
|
|
|
|
|
all_results = []
|
|
|
|
|
client = MCPUnifiedClient(transport_config)
|
|
|
|
|
|
|
|
|
|
for tool_call in tool_calls:
|
|
|
|
|
tool_name = tool_call.get("name", "default_tool")
|
|
|
|
|
arguments = tool_call.get("arguments", {})
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
results = client.call_tool_streaming_sync(tool_name, arguments)
|
|
|
|
|
all_results.extend(results)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error calling tool {tool_name}: {e}")
|
|
|
|
|
# Add error result
|
|
|
|
|
all_results.append({
|
|
|
|
|
"error": str(e),
|
|
|
|
|
"tool_name": tool_name,
|
|
|
|
|
"arguments": arguments
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return all_results
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error in call_tool_streaming_sync: {e}")
|
|
|
|
|
return [{"error": str(e)}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_tool_calls_from_response(response: str) -> List[Dict[str, Any]]:
|
|
|
|
|
"""
|
|
|
|
|
Extract tool calls from LLM response.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
response: The response string from the LLM
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List of tool call dictionaries
|
|
|
|
|
"""
|
|
|
|
|
import re
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
tool_calls = []
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# Try to find JSON tool calls
|
|
|
|
|
json_match = re.search(r'```json\s*(\{.*?\})\s*```', response, re.DOTALL)
|
|
|
|
|
if json_match:
|
|
|
|
|
try:
|
|
|
|
|
tool_data = json.loads(json_match.group(1))
|
|
|
|
|
|
|
|
|
|
# Check for tool_uses format
|
|
|
|
|
if "tool_uses" in tool_data and tool_data["tool_uses"]:
|
|
|
|
|
for tool_call in tool_data["tool_uses"]:
|
|
|
|
|
if "recipient_name" in tool_call:
|
|
|
|
|
tool_name = tool_call["recipient_name"]
|
|
|
|
|
arguments = tool_call.get("parameters", {})
|
|
|
|
|
tool_calls.append({
|
|
|
|
|
"name": tool_name,
|
|
|
|
|
"arguments": arguments
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
# Check for direct tool call format
|
|
|
|
|
elif "name" in tool_data and "arguments" in tool_data:
|
|
|
|
|
tool_calls.append({
|
|
|
|
|
"name": tool_data["name"],
|
|
|
|
|
"arguments": tool_data["arguments"]
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
# If no JSON found, try to extract from text
|
|
|
|
|
if not tool_calls:
|
|
|
|
|
# Look for common tool patterns
|
|
|
|
|
response_lower = response.lower()
|
|
|
|
|
|
|
|
|
|
if "calculate" in response_lower or "compute" in response_lower:
|
|
|
|
|
# Extract mathematical expression
|
|
|
|
|
expr_match = re.search(r'(\d+\s*[\+\-\*\/]\s*\d+)', response)
|
|
|
|
|
if expr_match:
|
|
|
|
|
tool_calls.append({
|
|
|
|
|
"name": "calculate",
|
|
|
|
|
"arguments": {"expression": expr_match.group(1)}
|
|
|
|
|
})
|
|
|
|
|
else:
|
|
|
|
|
tool_calls.append({
|
|
|
|
|
"name": "calculate",
|
|
|
|
|
"arguments": {"expression": "2+2"}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
elif "search" in response_lower or "find" in response_lower:
|
|
|
|
|
tool_calls.append({
|
|
|
|
|
"name": "search",
|
|
|
|
|
"arguments": {"query": response.strip()}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# Default tool call
|
|
|
|
|
tool_calls.append({
|
|
|
|
|
"name": "default_tool",
|
|
|
|
|
"arguments": {"input": response.strip()}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error extracting tool calls: {e}")
|
|
|
|
|
# Return default tool call
|
|
|
|
|
tool_calls.append({
|
|
|
|
|
"name": "default_tool",
|
|
|
|
|
"arguments": {"input": response.strip()}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return tool_calls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Helper functions for creating configurations
|
|
|
|
|
def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig:
|
|
|
|
|
"""
|
|
|
|
|