Update mcp_unified_client.py

pull/1005/head
CI-DEV 2 months ago committed by GitHub
parent 30f88aa988
commit 22459b3eba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -251,7 +251,7 @@ class MCPUnifiedClient:
except Exception as e: except Exception as e:
logger.error(f"HTTP read error: {e}") logger.error(f"HTTP read error: {e}")
raise MCPConnectionError(f"HTTP read failed: {e}") raise MCPConnectionError(f"HTTP read failed: {e}")
async def write(data): async def write(data):
# Implement HTTP write logic for MCP # Implement HTTP write logic for MCP
try: try:
@ -265,7 +265,7 @@ class MCPUnifiedClient:
except Exception as e: except Exception as e:
logger.error(f"HTTP write error: {e}") logger.error(f"HTTP write error: {e}")
raise MCPConnectionError(f"HTTP write failed: {e}") raise MCPConnectionError(f"HTTP write failed: {e}")
yield read, write yield read, write
async def get_tools(self, format: Literal["mcp", "openai"] = "openai") -> List[Dict[str, Any]]: async def get_tools(self, format: Literal["mcp", "openai"] = "openai") -> List[Dict[str, Any]]:
@ -507,13 +507,13 @@ async def aexecute_tool_call_unified(
return await client.call_tool(tool_name, arguments) return await client.call_tool(tool_name, arguments)
def execute_tool_call_streaming_unified( def call_tool_streaming_sync(
config: Union[UnifiedTransportConfig, MCPConnection, str], config: Union[UnifiedTransportConfig, MCPConnection, str],
tool_name: str, tool_name: str,
arguments: Dict[str, Any] arguments: Dict[str, Any]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Execute a tool call with streaming using the unified client. Call a tool with streaming support synchronously.
Args: Args:
config: Transport configuration config: Transport configuration
@ -527,13 +527,13 @@ def execute_tool_call_streaming_unified(
return client.call_tool_streaming_sync(tool_name, arguments) return client.call_tool_streaming_sync(tool_name, arguments)
async def aexecute_tool_call_streaming_unified( async def call_tool_streaming(
config: Union[UnifiedTransportConfig, MCPConnection, str], config: Union[UnifiedTransportConfig, MCPConnection, str],
tool_name: str, tool_name: str,
arguments: Dict[str, Any] arguments: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]: ) -> AsyncGenerator[Dict[str, Any], None]:
""" """
Async version of execute_tool_call_streaming_unified. Call a tool with streaming support asynchronously.
Args: Args:
config: Transport configuration config: Transport configuration
@ -548,175 +548,32 @@ async def aexecute_tool_call_streaming_unified(
yield result yield result
# Function that matches the Agent class expectations def execute_tool_call_streaming_unified(
def call_tool_streaming_sync( config: Union[UnifiedTransportConfig, MCPConnection, str],
response: Any, tool_name: str,
server_path: Optional[str] = None, arguments: Dict[str, Any]
connection: Optional[MCPConnection] = None,
config: Optional[UnifiedTransportConfig] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Call a tool with streaming support - matches Agent class expectations. Execute a tool call with streaming support using the unified client.
Args: Args:
response: The response from the LLM (may contain tool calls)
server_path: MCP server path/URL
connection: MCP connection object
config: Transport configuration config: Transport configuration
tool_name: Name of the tool to call
arguments: Tool arguments
Returns: Returns:
List of streaming tool execution results List of streaming tool execution results
""" """
try: return call_tool_streaming_sync(config, tool_name, arguments)
# Determine the configuration to use
if config is not None:
transport_config = config
elif connection is not None:
transport_config = UnifiedTransportConfig(
transport_type=connection.transport or "auto",
url=connection.url,
headers=connection.headers,
timeout=connection.timeout or 30,
authorization_token=connection.authorization_token,
auto_detect=True,
enable_streaming=True
)
elif server_path is not None:
transport_config = UnifiedTransportConfig(
url=server_path,
transport_type="auto",
auto_detect=True,
enable_streaming=True
)
else:
raise ValueError("Either server_path, connection, or config must be provided")
# Extract tool calls from response if it's a string
if isinstance(response, str):
tool_calls = _extract_tool_calls_from_response(response)
else:
tool_calls = [{"name": "default_tool", "arguments": {}}]
# Execute each tool call with streaming
all_results = []
client = MCPUnifiedClient(transport_config)
for tool_call in tool_calls:
tool_name = tool_call.get("name", "default_tool")
arguments = tool_call.get("arguments", {})
try:
results = client.call_tool_streaming_sync(tool_name, arguments)
all_results.extend(results)
except Exception as e:
logger.error(f"Error calling tool {tool_name}: {e}")
# Add error result
all_results.append({
"error": str(e),
"tool_name": tool_name,
"arguments": arguments
})
return all_results
except Exception as e:
logger.error(f"Error in call_tool_streaming_sync: {e}")
return [{"error": str(e)}]
def _extract_tool_calls_from_response(response: str) -> List[Dict[str, Any]]:
"""
Extract tool calls from LLM response.
Args:
response: The response string from the LLM
Returns:
List of tool call dictionaries
"""
import re
import json
tool_calls = []
try:
# Try to find JSON tool calls
json_match = re.search(r'```json\s*(\{.*?\})\s*```', response, re.DOTALL)
if json_match:
try:
tool_data = json.loads(json_match.group(1))
# Check for tool_uses format
if "tool_uses" in tool_data and tool_data["tool_uses"]:
for tool_call in tool_data["tool_uses"]:
if "recipient_name" in tool_call:
tool_name = tool_call["recipient_name"]
arguments = tool_call.get("parameters", {})
tool_calls.append({
"name": tool_name,
"arguments": arguments
})
# Check for direct tool call format
elif "name" in tool_data and "arguments" in tool_data:
tool_calls.append({
"name": tool_data["name"],
"arguments": tool_data["arguments"]
})
except json.JSONDecodeError:
pass
# If no JSON found, try to extract from text
if not tool_calls:
# Look for common tool patterns
response_lower = response.lower()
if "calculate" in response_lower or "compute" in response_lower:
# Extract mathematical expression
expr_match = re.search(r'(\d+\s*[\+\-\*\/]\s*\d+)', response)
if expr_match:
tool_calls.append({
"name": "calculate",
"arguments": {"expression": expr_match.group(1)}
})
else:
tool_calls.append({
"name": "calculate",
"arguments": {"expression": "2+2"}
})
elif "search" in response_lower or "find" in response_lower:
tool_calls.append({
"name": "search",
"arguments": {"query": response.strip()}
})
else:
# Default tool call
tool_calls.append({
"name": "default_tool",
"arguments": {"input": response.strip()}
})
except Exception as e:
logger.error(f"Error extracting tool calls: {e}")
# Return default tool call
tool_calls.append({
"name": "default_tool",
"arguments": {"input": response.strip()}
})
return tool_calls
# Helper functions for creating configurations # Configuration factory functions
def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig: def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig:
""" """
Create configuration for stdio transport. Create stdio transport configuration.
Args: Args:
command: Command and arguments to run command: Command to execute
**kwargs: Additional configuration options **kwargs: Additional configuration options
Returns: Returns:
@ -732,7 +589,7 @@ def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig:
def create_http_config(url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> UnifiedTransportConfig: def create_http_config(url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> UnifiedTransportConfig:
""" """
Create configuration for HTTP transport. Create HTTP transport configuration.
Args: Args:
url: Server URL url: Server URL
@ -753,7 +610,7 @@ def create_http_config(url: str, headers: Optional[Dict[str, str]] = None, **kwa
def create_streamable_http_config(url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> UnifiedTransportConfig: def create_streamable_http_config(url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> UnifiedTransportConfig:
""" """
Create configuration for streamable HTTP transport. Create streamable HTTP transport configuration.
Args: Args:
url: Server URL url: Server URL
@ -774,7 +631,7 @@ def create_streamable_http_config(url: str, headers: Optional[Dict[str, str]] =
def create_sse_config(url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> UnifiedTransportConfig: def create_sse_config(url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> UnifiedTransportConfig:
""" """
Create configuration for SSE transport. Create SSE transport configuration.
Args: Args:
url: Server URL url: Server URL

Loading…
Cancel
Save