Update mcp_unified_client.py

pull/1005/head
CI-DEV 2 months ago committed by GitHub
parent 30f88aa988
commit 22459b3eba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -507,13 +507,13 @@ async def aexecute_tool_call_unified(
return await client.call_tool(tool_name, arguments)
def execute_tool_call_streaming_unified(
def call_tool_streaming_sync(
config: Union[UnifiedTransportConfig, MCPConnection, str],
tool_name: str,
arguments: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""
Execute a tool call with streaming using the unified client.
Call a tool with streaming support synchronously.
Args:
config: Transport configuration
@ -527,13 +527,13 @@ def execute_tool_call_streaming_unified(
return client.call_tool_streaming_sync(tool_name, arguments)
async def aexecute_tool_call_streaming_unified(
async def call_tool_streaming(
config: Union[UnifiedTransportConfig, MCPConnection, str],
tool_name: str,
arguments: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Async version of execute_tool_call_streaming_unified.
Call a tool with streaming support asynchronously.
Args:
config: Transport configuration
@ -548,175 +548,32 @@ async def aexecute_tool_call_streaming_unified(
yield result
# Function that matches the Agent class expectations
def call_tool_streaming_sync(
response: Any,
server_path: Optional[str] = None,
connection: Optional[MCPConnection] = None,
config: Optional[UnifiedTransportConfig] = None
def execute_tool_call_streaming_unified(
config: Union[UnifiedTransportConfig, MCPConnection, str],
tool_name: str,
arguments: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""
Call a tool with streaming support - matches Agent class expectations.
Execute a tool call with streaming support using the unified client.
Args:
response: The response from the LLM (may contain tool calls)
server_path: MCP server path/URL
connection: MCP connection object
config: Transport configuration
tool_name: Name of the tool to call
arguments: Tool arguments
Returns:
List of streaming tool execution results
"""
try:
# Determine the configuration to use
if config is not None:
transport_config = config
elif connection is not None:
transport_config = UnifiedTransportConfig(
transport_type=connection.transport or "auto",
url=connection.url,
headers=connection.headers,
timeout=connection.timeout or 30,
authorization_token=connection.authorization_token,
auto_detect=True,
enable_streaming=True
)
elif server_path is not None:
transport_config = UnifiedTransportConfig(
url=server_path,
transport_type="auto",
auto_detect=True,
enable_streaming=True
)
else:
raise ValueError("Either server_path, connection, or config must be provided")
# Extract tool calls from response if it's a string
if isinstance(response, str):
tool_calls = _extract_tool_calls_from_response(response)
else:
tool_calls = [{"name": "default_tool", "arguments": {}}]
# Execute each tool call with streaming
all_results = []
client = MCPUnifiedClient(transport_config)
for tool_call in tool_calls:
tool_name = tool_call.get("name", "default_tool")
arguments = tool_call.get("arguments", {})
try:
results = client.call_tool_streaming_sync(tool_name, arguments)
all_results.extend(results)
except Exception as e:
logger.error(f"Error calling tool {tool_name}: {e}")
# Add error result
all_results.append({
"error": str(e),
"tool_name": tool_name,
"arguments": arguments
})
return all_results
except Exception as e:
logger.error(f"Error in call_tool_streaming_sync: {e}")
return [{"error": str(e)}]
def _extract_tool_calls_from_response(response: str) -> List[Dict[str, Any]]:
"""
Extract tool calls from LLM response.
Args:
response: The response string from the LLM
Returns:
List of tool call dictionaries
"""
import re
import json
tool_calls = []
try:
# Try to find JSON tool calls
json_match = re.search(r'```json\s*(\{.*?\})\s*```', response, re.DOTALL)
if json_match:
try:
tool_data = json.loads(json_match.group(1))
# Check for tool_uses format
if "tool_uses" in tool_data and tool_data["tool_uses"]:
for tool_call in tool_data["tool_uses"]:
if "recipient_name" in tool_call:
tool_name = tool_call["recipient_name"]
arguments = tool_call.get("parameters", {})
tool_calls.append({
"name": tool_name,
"arguments": arguments
})
# Check for direct tool call format
elif "name" in tool_data and "arguments" in tool_data:
tool_calls.append({
"name": tool_data["name"],
"arguments": tool_data["arguments"]
})
except json.JSONDecodeError:
pass
# If no JSON found, try to extract from text
if not tool_calls:
# Look for common tool patterns
response_lower = response.lower()
if "calculate" in response_lower or "compute" in response_lower:
# Extract mathematical expression
expr_match = re.search(r'(\d+\s*[\+\-\*\/]\s*\d+)', response)
if expr_match:
tool_calls.append({
"name": "calculate",
"arguments": {"expression": expr_match.group(1)}
})
else:
tool_calls.append({
"name": "calculate",
"arguments": {"expression": "2+2"}
})
elif "search" in response_lower or "find" in response_lower:
tool_calls.append({
"name": "search",
"arguments": {"query": response.strip()}
})
else:
# Default tool call
tool_calls.append({
"name": "default_tool",
"arguments": {"input": response.strip()}
})
except Exception as e:
logger.error(f"Error extracting tool calls: {e}")
# Return default tool call
tool_calls.append({
"name": "default_tool",
"arguments": {"input": response.strip()}
})
return tool_calls
return call_tool_streaming_sync(config, tool_name, arguments)
# Helper functions for creating configurations
# Configuration factory functions
def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig:
"""
Create configuration for stdio transport.
Create stdio transport configuration.
Args:
command: Command and arguments to run
command: Command to execute
**kwargs: Additional configuration options
Returns:
@ -732,7 +589,7 @@ def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig:
def create_http_config(url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> UnifiedTransportConfig:
"""
Create configuration for HTTP transport.
Create HTTP transport configuration.
Args:
url: Server URL
@ -753,7 +610,7 @@ def create_http_config(url: str, headers: Optional[Dict[str, str]] = None, **kwa
def create_streamable_http_config(url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> UnifiedTransportConfig:
"""
Create configuration for streamable HTTP transport.
Create streamable HTTP transport configuration.
Args:
url: Server URL
@ -774,7 +631,7 @@ def create_streamable_http_config(url: str, headers: Optional[Dict[str, str]] =
def create_sse_config(url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> UnifiedTransportConfig:
"""
Create configuration for SSE transport.
Create SSE transport configuration.
Args:
url: Server URL

Loading…
Cancel
Save