Update mcp_unified_client.py

pull/1005/head
CI-DEV 2 months ago committed by GitHub
parent be7862d7a4
commit faf6f2226d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -567,6 +567,432 @@ def execute_tool_call_streaming_unified(
return call_tool_streaming_sync(config, tool_name, arguments) return call_tool_streaming_sync(config, tool_name, arguments)
# Advanced functionality for Agent class integration
def call_tool_streaming_sync_advanced(
response: Any,
server_path: Optional[str] = None,
connection: Optional[MCPConnection] = None,
config: Optional[UnifiedTransportConfig] = None
) -> List[Dict[str, Any]]:
"""
Advanced function that matches the Agent class expectations.
Handles complex response parsing and multiple tool execution.
Args:
response: The response from the LLM (may contain tool calls)
server_path: MCP server path/URL
connection: MCP connection object
config: Transport configuration
Returns:
List of streaming tool execution results
"""
try:
# Determine the configuration to use
if config is not None:
transport_config = config
elif connection is not None:
transport_config = UnifiedTransportConfig(
transport_type=connection.transport or "auto",
url=connection.url,
headers=connection.headers,
timeout=connection.timeout or 30,
authorization_token=connection.authorization_token,
auto_detect=True,
enable_streaming=True
)
elif server_path is not None:
transport_config = UnifiedTransportConfig(
url=server_path,
transport_type="auto",
auto_detect=True,
enable_streaming=True
)
else:
raise ValueError("Either server_path, connection, or config must be provided")
# Extract tool calls from response if it's a string
if isinstance(response, str):
tool_calls = _extract_tool_calls_from_response_advanced(response)
else:
tool_calls = [{"name": "default_tool", "arguments": {}}]
# Execute each tool call with streaming
all_results = []
client = MCPUnifiedClient(transport_config)
for tool_call in tool_calls:
tool_name = tool_call.get("name", "default_tool")
arguments = tool_call.get("arguments", {})
try:
results = client.call_tool_streaming_sync(tool_name, arguments)
all_results.extend(results)
except Exception as e:
logger.error(f"Error calling tool {tool_name}: {e}")
# Add error result
all_results.append({
"error": str(e),
"tool_name": tool_name,
"arguments": arguments
})
return all_results
except Exception as e:
logger.error(f"Error in call_tool_streaming_sync_advanced: {e}")
return [{"error": str(e)}]
def _extract_tool_calls_from_response_advanced(response: str) -> List[Dict[str, Any]]:
"""
Advanced tool call extraction with comprehensive parsing capabilities.
Args:
response: The response string from the LLM
Returns:
List of tool call dictionaries
"""
import re
import json
tool_calls = []
try:
# Try to find JSON tool calls in code blocks
json_match = re.search(r'```json\s*(\{.*?\})\s*```', response, re.DOTALL)
if json_match:
try:
tool_data = json.loads(json_match.group(1))
# Check for tool_uses format (OpenAI format)
if "tool_uses" in tool_data and tool_data["tool_uses"]:
for tool_call in tool_data["tool_uses"]:
if "recipient_name" in tool_call:
tool_name = tool_call["recipient_name"]
arguments = tool_call.get("parameters", {})
tool_calls.append({
"name": tool_name,
"arguments": arguments
})
# Check for direct tool call format
elif "name" in tool_data and "arguments" in tool_data:
tool_calls.append({
"name": tool_data["name"],
"arguments": tool_data["arguments"]
})
# Check for function_calls format
elif "function_calls" in tool_data and tool_data["function_calls"]:
for tool_call in tool_data["function_calls"]:
if "name" in tool_call and "arguments" in tool_call:
tool_calls.append({
"name": tool_call["name"],
"arguments": tool_call["arguments"]
})
except json.JSONDecodeError:
pass
# Try to find JSON tool calls without code blocks
if not tool_calls:
json_patterns = [
r'\{[^{}]*"name"[^{}]*"arguments"[^{}]*\}',
r'\{[^{}]*"tool_uses"[^{}]*\}',
r'\{[^{}]*"function_calls"[^{}]*\}'
]
for pattern in json_patterns:
matches = re.findall(pattern, response, re.DOTALL)
for match in matches:
try:
tool_data = json.loads(match)
# Check for tool_uses format
if "tool_uses" in tool_data and tool_data["tool_uses"]:
for tool_call in tool_data["tool_uses"]:
if "recipient_name" in tool_call:
tool_calls.append({
"name": tool_call["recipient_name"],
"arguments": tool_call.get("parameters", {})
})
# Check for direct tool call format
elif "name" in tool_data and "arguments" in tool_data:
tool_calls.append({
"name": tool_data["name"],
"arguments": tool_data["arguments"]
})
# Check for function_calls format
elif "function_calls" in tool_data and tool_data["function_calls"]:
for tool_call in tool_data["function_calls"]:
if "name" in tool_call and "arguments" in tool_call:
tool_calls.append({
"name": tool_call["name"],
"arguments": tool_call["arguments"]
})
except json.JSONDecodeError:
continue
# If no JSON found, try to extract from text using pattern matching
if not tool_calls:
response_lower = response.lower()
# Look for mathematical expressions
if "calculate" in response_lower or "compute" in response_lower or "math" in response_lower:
# Extract mathematical expression
expr_patterns = [
r'(\d+\s*[\+\-\*\/\^]\s*\d+)',
r'calculate\s+(.+?)(?:\n|\.|$)',
r'compute\s+(.+?)(?:\n|\.|$)'
]
for pattern in expr_patterns:
expr_match = re.search(pattern, response, re.IGNORECASE)
if expr_match:
expression = expr_match.group(1).strip()
tool_calls.append({
"name": "calculate",
"arguments": {"expression": expression}
})
break
# Default calculation if no expression found
if not any("calculate" in tc.get("name", "") for tc in tool_calls):
tool_calls.append({
"name": "calculate",
"arguments": {"expression": "2+2"}
})
# Look for search operations
elif "search" in response_lower or "find" in response_lower or "look up" in response_lower:
# Extract search query
search_patterns = [
r'search\s+for\s+(.+?)(?:\n|\.|$)',
r'find\s+(.+?)(?:\n|\.|$)',
r'look up\s+(.+?)(?:\n|\.|$)'
]
for pattern in search_patterns:
search_match = re.search(pattern, response, re.IGNORECASE)
if search_match:
query = search_match.group(1).strip()
tool_calls.append({
"name": "search",
"arguments": {"query": query}
})
break
# Default search if no query found
if not any("search" in tc.get("name", "") for tc in tool_calls):
tool_calls.append({
"name": "search",
"arguments": {"query": response.strip()}
})
# Look for file operations
elif "read" in response_lower or "file" in response_lower or "open" in response_lower:
# Extract file path
file_patterns = [
r'read\s+(.+?)(?:\n|\.|$)',
r'open\s+(.+?)(?:\n|\.|$)',
r'file\s+(.+?)(?:\n|\.|$)'
]
for pattern in file_patterns:
file_match = re.search(pattern, response, re.IGNORECASE)
if file_match:
file_path = file_match.group(1).strip()
tool_calls.append({
"name": "read_file",
"arguments": {"file_path": file_path}
})
break
# Look for web operations
elif "web" in response_lower or "url" in response_lower or "http" in response_lower:
# Extract URL
url_patterns = [
r'https?://[^\s]+',
r'www\.[^\s]+',
r'url\s+(.+?)(?:\n|\.|$)'
]
for pattern in url_patterns:
url_match = re.search(pattern, response, re.IGNORECASE)
if url_match:
url = url_match.group(0) if pattern.startswith('http') else url_match.group(1).strip()
tool_calls.append({
"name": "fetch_url",
"arguments": {"url": url}
})
break
# Default tool call if no specific patterns found
else:
tool_calls.append({
"name": "default_tool",
"arguments": {"input": response.strip()}
})
except Exception as e:
logger.error(f"Error extracting tool calls: {e}")
# Return default tool call
tool_calls.append({
"name": "default_tool",
"arguments": {"input": response.strip()}
})
return tool_calls
# Advanced multiple server functionality
async def execute_tools_on_multiple_servers_unified(
server_configs: List[Union[UnifiedTransportConfig, MCPConnection, str]],
tool_name: str,
arguments: Dict[str, Any],
max_concurrent: int = 3
) -> List[Dict[str, Any]]:
"""
Execute the same tool on multiple MCP servers concurrently.
Args:
server_configs: List of server configurations
tool_name: Name of the tool to call
arguments: Tool arguments
max_concurrent: Maximum concurrent executions
Returns:
List of results from all servers
"""
semaphore = asyncio.Semaphore(max_concurrent)
async def execute_on_single_server(config):
async with semaphore:
try:
client = MCPUnifiedClient(config)
result = await client.call_tool(tool_name, arguments)
return {
"success": True,
"server": str(config),
"result": result
}
except Exception as e:
logger.error(f"Error executing tool on server {config}: {e}")
return {
"success": False,
"server": str(config),
"error": str(e)
}
tasks = [execute_on_single_server(config) for config in server_configs]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle exceptions
final_results = []
for result in results:
if isinstance(result, Exception):
final_results.append({
"success": False,
"error": str(result)
})
else:
final_results.append(result)
return final_results
def execute_tools_on_multiple_servers_unified_sync(
server_configs: List[Union[UnifiedTransportConfig, MCPConnection, str]],
tool_name: str,
arguments: Dict[str, Any],
max_concurrent: int = 3
) -> List[Dict[str, Any]]:
"""
Synchronous wrapper for execute_tools_on_multiple_servers_unified.
Args:
server_configs: List of server configurations
tool_name: Name of the tool to call
arguments: Tool arguments
max_concurrent: Maximum concurrent executions
Returns:
List of results from all servers
"""
with get_or_create_event_loop() as loop:
try:
return loop.run_until_complete(
execute_tools_on_multiple_servers_unified(
server_configs=server_configs,
tool_name=tool_name,
arguments=arguments,
max_concurrent=max_concurrent
)
)
except Exception as e:
logger.error(f"Error in execute_tools_on_multiple_servers_unified_sync: {e}")
return [{"success": False, "error": str(e)}]
# Advanced streaming with multiple servers
async def execute_tools_streaming_on_multiple_servers_unified(
server_configs: List[Union[UnifiedTransportConfig, MCPConnection, str]],
tool_name: str,
arguments: Dict[str, Any],
max_concurrent: int = 3
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Execute tools with streaming on multiple servers concurrently.
Args:
server_configs: List of server configurations
tool_name: Name of the tool to call
arguments: Tool arguments
max_concurrent: Maximum concurrent executions
Yields:
Streaming results from all servers
"""
semaphore = asyncio.Semaphore(max_concurrent)
async def execute_streaming_on_single_server(config):
async with semaphore:
try:
client = MCPUnifiedClient(config)
async for result in client.call_tool_streaming(tool_name, arguments):
yield {
"success": True,
"server": str(config),
"result": result,
"streaming": True
}
except Exception as e:
logger.error(f"Error executing streaming tool on server {config}: {e}")
yield {
"success": False,
"server": str(config),
"error": str(e),
"streaming": False
}
# Create tasks for all servers
tasks = [execute_streaming_on_single_server(config) for config in server_configs]
# Use asyncio.as_completed to yield results as they arrive
async def gather_streaming_results():
async for coro in asyncio.as_completed(tasks):
async for result in coro:
yield result
async for result in gather_streaming_results():
yield result
# Configuration factory functions # Configuration factory functions
def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig: def create_stdio_config(command: List[str], **kwargs) -> UnifiedTransportConfig:
""" """
@ -726,7 +1152,11 @@ __all__ = [
"HTTPX_AVAILABLE", "HTTPX_AVAILABLE",
"MCP_AVAILABLE", "MCP_AVAILABLE",
"call_tool_streaming_sync", "call_tool_streaming_sync",
"call_tool_streaming_sync_advanced",
"execute_tool_call_streaming_unified", "execute_tool_call_streaming_unified",
"execute_tools_on_multiple_servers_unified",
"execute_tools_on_multiple_servers_unified_sync",
"execute_tools_streaming_on_multiple_servers_unified",
] ]

Loading…
Cancel
Save