pull/983/merge
王祥宇 2 days ago committed by GitHub
commit 62177f486b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -30,6 +30,7 @@ from swarms.prompts.agent_system_prompts import AGENT_SYSTEM_PROMPT_3
from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
) )
from swarms.tools.mcp_client_call import aget_mcp_tools
from swarms.prompts.tools import tool_sop_prompt from swarms.prompts.tools import tool_sop_prompt
from swarms.schemas.agent_mcp_errors import ( from swarms.schemas.agent_mcp_errors import (
AgentMCPConnectionError, AgentMCPConnectionError,
@ -433,6 +434,7 @@ class Agent:
summarize_multiple_images: bool = False, summarize_multiple_images: bool = False,
tool_retry_attempts: int = 3, tool_retry_attempts: int = 3,
speed_mode: str = None, speed_mode: str = None,
lazy_init_mcp: bool = False,
*args, *args,
**kwargs, **kwargs,
): ):
@ -621,6 +623,37 @@ class Agent:
self.print_dashboard() self.print_dashboard()
self.reliability_check() self.reliability_check()
self.lazy_init_mcp = lazy_init_mcp
self._mcp_tools_loaded = False
@classmethod
async def create(cls, **kwargs):
"""
Asynchronously creates an Agent instance.
This is the preferred way to create an Agent that uses MCP tools
when running in an async context (like inside FastAPI, Quart, etc.)
Args:
**kwargs: All parameters accepted by Agent.__init__
Returns:
An initialized Agent instance with MCP tools loaded
"""
# 创建带有延迟初始化标志的实例
instance = cls(lazy_init_mcp=True, **kwargs)
# 异步加载 MCP 工具(如果配置了)
if exists(instance.mcp_url) or exists(instance.mcp_urls) or exists(instance.mcp_config):
await instance.async_init_mcp_tools()
# 完成初始化 LLM
if instance.llm is None:
# 使用异步转换方式运行同步函数
instance.llm = await asyncio.to_thread(instance.llm_handling)
return instance
def rag_setup_handling(self): def rag_setup_handling(self):
return AgentRAGHandler( return AgentRAGHandler(
@ -774,22 +807,21 @@ class Agent:
This function checks for either a single MCP URL or multiple MCP URLs and adds the available tools This function checks for either a single MCP URL or multiple MCP URLs and adds the available tools
to the agent's memory. The tools are listed in JSON format. to the agent's memory. The tools are listed in JSON format.
Raises:
Exception: If there's an error accessing the MCP tools
""" """
# 如果工具已经加载过且处于懒加载模式,直接返回已缓存的工具
if hasattr(self, '_mcp_tools_loaded') and self._mcp_tools_loaded and self.tools_list_dictionary is not None:
return self.tools_list_dictionary
try: try:
if exists(self.mcp_url): if exists(self.mcp_url):
tools = get_mcp_tools_sync(server_path=self.mcp_url) tools = get_mcp_tools_sync(server_path=self.mcp_url)
elif exists(self.mcp_config): elif exists(self.mcp_config):
tools = get_mcp_tools_sync(connection=self.mcp_config) tools = get_mcp_tools_sync(connection=self.mcp_config)
# logger.info(f"Tools: {tools}")
elif exists(self.mcp_urls): elif exists(self.mcp_urls):
tools = get_tools_for_multiple_mcp_servers( tools = get_tools_for_multiple_mcp_servers(
urls=self.mcp_urls, urls=self.mcp_urls,
output_type="str", output_type="str",
) )
# print(f"Tools: {tools} for {self.mcp_urls}")
else: else:
raise AgentMCPConnectionError( raise AgentMCPConnectionError(
"mcp_url must be either a string URL or MCPConnection object" "mcp_url must be either a string URL or MCPConnection object"
@ -799,18 +831,71 @@ class Agent:
exists(self.mcp_url) exists(self.mcp_url)
or exists(self.mcp_urls) or exists(self.mcp_urls)
or exists(self.mcp_config) or exists(self.mcp_config)
): ) and self.print_on is True:
if self.print_on is True:
self.pretty_print( self.pretty_print(
f"✨ [SYSTEM] Successfully integrated {len(tools)} MCP tools into agent: {self.agent_name} | Status: ONLINE | Time: {time.strftime('%H:%M:%S')}", f"✨ [SYSTEM] Successfully integrated {len(tools)} MCP tools into agent: {self.agent_name} | Status: ONLINE | Time: {time.strftime('%H:%M:%S')}",
loop_count=0, loop_count=0,
) )
# 标记工具已加载并保存
self._mcp_tools_loaded = True
self.tools_list_dictionary = tools
return tools return tools
except AgentMCPConnectionError as e: except AgentMCPConnectionError as e:
logger.error(f"Error in MCP connection: {e}") logger.error(f"Error in MCP connection: {e}")
raise e raise e
async def async_init_mcp_tools(self):
"""
Asynchronously initialize MCP tools.
This method should be used when the agent is created in an async context
to avoid event loop conflicts.
Returns:
The list of MCP tools
"""
# 如果工具已加载,直接返回
if hasattr(self, '_mcp_tools_loaded') and self._mcp_tools_loaded and self.tools_list_dictionary is not None:
return self.tools_list_dictionary
try:
if exists(self.mcp_url):
tools = await aget_mcp_tools(server_path=self.mcp_url, format="openai")
elif exists(self.mcp_config):
tools = await aget_mcp_tools(connection=self.mcp_config, format="openai")
elif exists(self.mcp_urls):
# 使用异步转换方式运行同步函数
tools = await asyncio.to_thread(
get_tools_for_multiple_mcp_servers,
urls=self.mcp_urls,
output_type="str"
)
else:
raise AgentMCPConnectionError(
"mcp_url must be either a string URL or MCPConnection object"
)
if (
exists(self.mcp_url)
or exists(self.mcp_urls)
or exists(self.mcp_config)
) and self.print_on is True:
# 使用异步转换方式运行同步函数
await asyncio.to_thread(
self.pretty_print,
f"✨ [SYSTEM] Successfully integrated {len(tools)} MCP tools into agent: {self.agent_name} | Status: ONLINE | Time: {time.strftime('%H:%M:%S')}",
loop_count=0
)
# 标记工具已加载并保存
self._mcp_tools_loaded = True
self.tools_list_dictionary = tools
return tools
except Exception as e:
logger.error(f"Error in async MCP tools initialization: {e}")
raise AgentMCPConnectionError(f"Failed to initialize MCP tools: {str(e)}")
def setup_config(self): def setup_config(self):
# The max_loops will be set dynamically if the dynamic_loop # The max_loops will be set dynamically if the dynamic_loop
if self.dynamic_loops is True: if self.dynamic_loops is True:
@ -1270,25 +1355,18 @@ class Agent:
""" """
Asynchronously runs the agent with the specified parameters. Asynchronously runs the agent with the specified parameters.
Args: Enhanced to support proper async initialization of MCP tools if needed.
task (Optional[str]): The task to be performed. Defaults to None.
img (Optional[str]): The image to be processed. Defaults to None.
is_last (bool): Indicates if this is the last task. Defaults to False.
device (str): The device to use for execution. Defaults to "cpu".
device_id (int): The ID of the GPU to use if device is set to "gpu". Defaults to 1.
all_cores (bool): If True, uses all available CPU cores. Defaults to True.
do_not_use_cluster_ops (bool): If True, does not use cluster operations. Defaults to True.
all_gpus (bool): If True, uses all available GPUs. Defaults to False.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
Any: The result of the asynchronous operation.
Raises:
Exception: If an error occurs during the asynchronous operation.
""" """
try: try:
# 如果需要且尚未加载 MCP 工具,先进行异步初始化
if (exists(self.mcp_url) or exists(self.mcp_urls) or exists(self.mcp_config)) and \
not (hasattr(self, '_mcp_tools_loaded') and self._mcp_tools_loaded):
await self.async_init_mcp_tools()
# 确保 LLM 已初始化并加载了工具
if self.llm is None:
self.llm = await asyncio.to_thread(self.llm_handling)
# 使用原来的方式调用同步 run 函数
return await asyncio.to_thread( return await asyncio.to_thread(
self.run, self.run,
task=task, task=task,
@ -1297,9 +1375,7 @@ class Agent:
**kwargs, **kwargs,
) )
except Exception as error: except Exception as error:
await self._handle_run_error( await self._handle_run_error(error)
error
) # Ensure this is also async if needed
def __call__( def __call__(
self, self,
@ -3233,4 +3309,6 @@ class Agent:
f"Agent '{self.agent_name}' encountered error during tool execution in loop {loop_count}: {str(e)}. " f"Agent '{self.agent_name}' encountered error during tool execution in loop {loop_count}: {str(e)}. "
f"Full traceback: {traceback.format_exc()}. " f"Full traceback: {traceback.format_exc()}. "
f"Attempting to retry tool execution with 3 attempts" f"Attempting to retry tool execution with 3 attempts"
) )

@ -12,9 +12,20 @@ from loguru import logger
from mcp import ClientSession from mcp import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
# Try to import nest_asyncio if available
try:
import nest_asyncio
HAS_NEST_ASYNCIO = True
logger.debug("nest_asyncio is available and will be used for nested event loops")
except ImportError:
HAS_NEST_ASYNCIO = False
logger.debug("nest_asyncio is not available, will use alternative methods for nested event loops")
try: try:
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
HAS_STREAMABLE_HTTP = True
except ImportError: except ImportError:
HAS_STREAMABLE_HTTP = False
logger.error( logger.error(
"streamablehttp_client is not available. Please ensure the MCP SDK is up to date with pip3 install -U mcp" "streamablehttp_client is not available. Please ensure the MCP SDK is up to date with pip3 install -U mcp"
) )
@ -28,7 +39,6 @@ from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params.function_definition import ( from openai.types.shared_params.function_definition import (
FunctionDefinition, FunctionDefinition,
) )
from swarms.schemas.mcp_schemas import ( from swarms.schemas.mcp_schemas import (
MCPConnection, MCPConnection,
) )
@ -38,37 +48,33 @@ from urllib.parse import urlparse
class MCPError(Exception): class MCPError(Exception):
"""Base exception for MCP related errors.""" """Base exception for MCP related errors."""
pass pass
class MCPConnectionError(MCPError): class MCPConnectionError(MCPError):
"""Raised when there are issues connecting to the MCP server.""" """Raised when there are issues connecting to the MCP server."""
pass pass
class MCPToolError(MCPError): class MCPToolError(MCPError):
"""Raised when there are issues with MCP tool operations.""" """Raised when there are issues with MCP tool operations."""
pass pass
class MCPValidationError(MCPError): class MCPValidationError(MCPError):
"""Raised when there are validation issues with MCP operations.""" """Raised when there are validation issues with MCP operations."""
pass pass
class MCPExecutionError(MCPError): class MCPExecutionError(MCPError):
"""Raised when there are issues executing MCP operations.""" """Raised when there are issues executing MCP operations."""
pass pass
######################################################## ########################################################
# List MCP Tool functions # List MCP Tool functions
######################################################## ########################################################
def transform_mcp_tool_to_openai_tool( def transform_mcp_tool_to_openai_tool(
mcp_tool: MCPTool, mcp_tool: MCPTool,
) -> ChatCompletionToolParam: ) -> ChatCompletionToolParam:
@ -118,7 +124,6 @@ async def load_mcp_tools(
# Call MCP Tool functions # Call MCP Tool functions
######################################################## ########################################################
async def call_mcp_tool( async def call_mcp_tool(
session: ClientSession, session: ClientSession,
call_tool_request_params: MCPCallToolRequestParams, call_tool_request_params: MCPCallToolRequestParams,
@ -203,7 +208,6 @@ def retry_with_backoff(retries=3, backoff_in_seconds=1):
Returns: Returns:
Decorated async function with retry logic. Decorated async function with retry logic.
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
@ -226,30 +230,59 @@ def retry_with_backoff(retries=3, backoff_in_seconds=1):
) )
await asyncio.sleep(sleep_time) await asyncio.sleep(sleep_time)
x += 1 x += 1
return wrapper return wrapper
return decorator return decorator
def _run_in_new_thread(func, *args, **kwargs):
"""Run a coroutine function in a new thread with its own event loop."""
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(_run_in_new_loop, func, *args, **kwargs)
return future.result()
def _run_in_new_loop(func, *args, **kwargs):
"""Run a coroutine function in a new event loop."""
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
return loop.run_until_complete(func(*args, **kwargs))
finally:
loop.close()
@contextlib.contextmanager @contextlib.contextmanager
def get_or_create_event_loop(): def get_or_create_event_loop():
""" """Context manager to handle event loop creation and cleanup with better handling of running loops."""
Context manager to handle event loop creation and cleanup.
Yields:
asyncio.AbstractEventLoop: The event loop to use.
Ensures the event loop is properly closed if created here.
"""
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop_was_running = loop.is_running()
# If loop is running and nest_asyncio is available, apply it
if loop_was_running and HAS_NEST_ASYNCIO:
nest_asyncio.apply(loop)
logger.debug("Applied nest_asyncio to running event loop")
created_new = False
# If loop is running and nest_asyncio is not available, create a new loop
elif loop_was_running:
logger.debug("Event loop is already running, creating new loop")
loop = asyncio.new_event_loop()
created_new = True
else:
created_new = False
except RuntimeError: except RuntimeError:
logger.debug("No event loop found, creating new one")
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
created_new = True
loop_was_running = False
try: try:
yield loop yield loop
finally: finally:
# Only close the loop if we created it and it's not the main event loop # Only close the loop if we created a new one and it's not running
if loop != asyncio.get_event_loop() and not loop.is_running(): if created_new and not loop.is_running():
if not loop.is_closed(): if not loop.is_closed():
loop.close() loop.close()
@ -304,7 +337,7 @@ def get_mcp_client(transport, url, headers=None, timeout=5, **kwargs):
f"Getting MCP client for transport '{transport}' and url '{url}'." f"Getting MCP client for transport '{transport}' and url '{url}'."
) )
if transport == "streamable_http": if transport == "streamable_http":
if streamablehttp_client is None: if not HAS_STREAMABLE_HTTP:
logger.error("streamablehttp_client is not available.") logger.error("streamablehttp_client is not available.")
raise ImportError( raise ImportError(
"streamablehttp_client is not available. Please ensure the MCP SDK is up to date." "streamablehttp_client is not available. Please ensure the MCP SDK is up to date."
@ -391,9 +424,11 @@ async def aget_mcp_tools(
server_path, server_path,
) )
url = server_path url = server_path
logger.info( logger.info(
f"Fetching MCP tools from server: {server_path} using transport: {transport}" f"Fetching MCP tools from server: {server_path} using transport: {transport}"
) )
try: try:
async with get_mcp_client( async with get_mcp_client(
transport, transport,
@ -433,13 +468,14 @@ def get_mcp_tools_sync(
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Synchronous version of get_mcp_tools that handles event loop management. Synchronous version of get_mcp_tools that handles event loop management.
Improved to handle cases where the event loop is already running.
Args: Args:
server_path (str): Path to the MCP server script. server_path (str): Path to the MCP server script.
format (str): Format to return tools in ('openai' or 'mcp'). format (str): Format to return tools in ('openai' or 'mcp').
connection (Optional[MCPConnection]): Optional connection object. connection (Optional[MCPConnection]): Optional connection object.
transport (Optional[str]): Transport type. If None, auto-detects. transport (Optional[str]): Transport type. If None, auto-detects.
Returns: Returns:
List[Dict[str, Any]]: List of available MCP tools in OpenAI format. List[Dict[str, Any]]: List of available MCP tools in requested format.
Raises: Raises:
MCPValidationError: If server_path is invalid. MCPValidationError: If server_path is invalid.
MCPConnectionError: If connection to server fails. MCPConnectionError: If connection to server fails.
@ -448,10 +484,49 @@ def get_mcp_tools_sync(
logger.info( logger.info(
f"get_mcp_tools_sync called for server_path: {server_path}" f"get_mcp_tools_sync called for server_path: {server_path}"
) )
if transport is None: if transport is None:
transport = auto_detect_transport(server_path) transport = auto_detect_transport(server_path)
with get_or_create_event_loop() as loop:
try: try:
# Check if we're in a running event loop
try:
loop = asyncio.get_event_loop()
loop_is_running = loop.is_running()
except RuntimeError:
loop_is_running = False
loop = None
# If loop is already running and nest_asyncio is available, use it
if loop_is_running and HAS_NEST_ASYNCIO:
logger.debug("Using nest_asyncio with running event loop")
nest_asyncio.apply(loop)
return loop.run_until_complete(
aget_mcp_tools(
server_path=server_path,
format=format,
connection=connection,
transport=transport,
*args,
**kwargs,
)
)
# If loop is running but nest_asyncio not available, use thread
elif loop_is_running:
logger.debug("Event loop is running, executing in separate thread")
return _run_in_new_thread(
aget_mcp_tools,
server_path=server_path,
format=format,
connection=connection,
transport=transport,
*args,
**kwargs,
)
# Standard case: no running loop or we're not in an event loop
else:
logger.debug("Using standard event loop management")
with get_or_create_event_loop() as loop:
return loop.run_until_complete( return loop.run_until_complete(
aget_mcp_tools( aget_mcp_tools(
server_path=server_path, server_path=server_path,
@ -520,11 +595,13 @@ def get_tools_for_multiple_mcp_servers(
f"get_tools_for_multiple_mcp_servers called for {len(urls)} urls." f"get_tools_for_multiple_mcp_servers called for {len(urls)} urls."
) )
tools = [] tools = []
(
max_workers = (
min(32, os.cpu_count() + 4) min(32, os.cpu_count() + 4)
if max_workers is None if max_workers is None
else max_workers else max_workers
) )
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
if exists(connections): if exists(connections):
future_to_url = { future_to_url = {
@ -548,6 +625,7 @@ def get_tools_for_multiple_mcp_servers(
): url ): url
for url in urls for url in urls
} }
for future in as_completed(future_to_url): for future in as_completed(future_to_url):
url = future_to_url[future] url = future_to_url[future]
try: try:
@ -560,6 +638,7 @@ def get_tools_for_multiple_mcp_servers(
raise MCPExecutionError( raise MCPExecutionError(
f"Failed to fetch tools from {url}: {str(e)}" f"Failed to fetch tools from {url}: {str(e)}"
) )
return tools return tools
@ -603,6 +682,7 @@ async def _execute_tool_call_simple(
"sse", "sse",
server_path, server_path,
) )
try: try:
async with get_mcp_client( async with get_mcp_client(
transport, transport,
@ -756,7 +836,7 @@ async def _create_server_tool_mapping_async(
urls: List[str], urls: List[str],
connections: List[MCPConnection] = None, connections: List[MCPConnection] = None,
format: str = "openai", format: str = "openai",
transport: str = "sse", transport: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]: ) -> Dict[str, Dict[str, Any]]:
""" """
Async version: Create a mapping of function names to server information for all MCP servers. Async version: Create a mapping of function names to server information for all MCP servers.
@ -764,7 +844,7 @@ async def _create_server_tool_mapping_async(
urls (List[str]): List of server URLs. urls (List[str]): List of server URLs.
connections (List[MCPConnection]): Optional list of MCPConnection objects. connections (List[MCPConnection]): Optional list of MCPConnection objects.
format (str): Format to fetch tools in. format (str): Format to fetch tools in.
transport (str): Transport type. transport (Optional[str]): Transport type. If None, auto-detects per URL.
Returns: Returns:
Dict[str, Dict[str, Any]]: Mapping of function names to server info. Dict[str, Dict[str, Any]]: Mapping of function names to server info.
""" """
@ -776,11 +856,16 @@ async def _create_server_tool_mapping_async(
else None else None
) )
try: try:
if transport is None:
transport_to_use = auto_detect_transport(url)
else:
transport_to_use = transport
tools = await aget_mcp_tools( tools = await aget_mcp_tools(
server_path=url, server_path=url,
connection=connection, connection=connection,
format=format, format=format,
transport=transport, transport=transport_to_use,
) )
for tool in tools: for tool in tools:
if isinstance(tool, dict) and "function" in tool: if isinstance(tool, dict) and "function" in tool:
@ -810,7 +895,7 @@ async def _execute_tool_on_server(
tool_call: Dict[str, Any], tool_call: Dict[str, Any],
server_info: Dict[str, Any], server_info: Dict[str, Any],
output_type: Literal["json", "dict", "str", "formatted"] = "str", output_type: Literal["json", "dict", "str", "formatted"] = "str",
transport: str = "sse", transport: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Execute a single tool call on a specific server. Execute a single tool call on a specific server.
@ -818,7 +903,7 @@ async def _execute_tool_on_server(
tool_call (Dict[str, Any]): The tool call to execute. tool_call (Dict[str, Any]): The tool call to execute.
server_info (Dict[str, Any]): Server information from the mapping. server_info (Dict[str, Any]): Server information from the mapping.
output_type (Literal): Output format type. output_type (Literal): Output format type.
transport (str): Transport type. transport (Optional[str]): Transport type. If None, auto-detects.
Returns: Returns:
Dict[str, Any]: Execution result with server metadata. Dict[str, Any]: Execution result with server metadata.
""" """
@ -861,7 +946,7 @@ async def execute_multiple_tools_on_multiple_mcp_servers(
connections: List[MCPConnection] = None, connections: List[MCPConnection] = None,
output_type: Literal["json", "dict", "str", "formatted"] = "str", output_type: Literal["json", "dict", "str", "formatted"] = "str",
max_concurrent: Optional[int] = None, max_concurrent: Optional[int] = None,
transport: str = "sse", transport: Optional[str] = None,
*args, *args,
**kwargs, **kwargs,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
@ -873,17 +958,19 @@ async def execute_multiple_tools_on_multiple_mcp_servers(
connections (List[MCPConnection]): Optional list of MCPConnection objects. connections (List[MCPConnection]): Optional list of MCPConnection objects.
output_type (Literal): Output format type. output_type (Literal): Output format type.
max_concurrent (Optional[int]): Max concurrent tasks. max_concurrent (Optional[int]): Max concurrent tasks.
transport (str): Transport type. transport (Optional[str]): Transport type. If None, auto-detects per URL.
Returns: Returns:
List[Dict[str, Any]]: List of execution results. List[Dict[str, Any]]: List of execution results.
""" """
if not responses: if not responses:
logger.warning("No responses provided for execution") logger.warning("No responses provided for execution")
return [] return []
if not urls: if not urls:
raise MCPValidationError("No server URLs provided") raise MCPValidationError("No server URLs provided")
logger.info( logger.info(
f"Creating tool mapping for {len(urls)} servers using transport: {transport}" f"Creating tool mapping for {len(urls)} servers"
) )
server_tool_mapping = await _create_server_tool_mapping_async( server_tool_mapping = await _create_server_tool_mapping_async(
urls=urls, urls=urls,
@ -891,17 +978,21 @@ async def execute_multiple_tools_on_multiple_mcp_servers(
format="openai", format="openai",
transport=transport, transport=transport,
) )
if not server_tool_mapping: if not server_tool_mapping:
raise MCPExecutionError( raise MCPExecutionError(
"No tools found on any of the provided servers" "No tools found on any of the provided servers"
) )
logger.info( logger.info(
f"Found {len(server_tool_mapping)} unique functions across all servers" f"Found {len(server_tool_mapping)} unique functions across all servers"
) )
all_tool_calls = [] all_tool_calls = []
logger.info( logger.info(
f"Processing {len(responses)} responses for tool call extraction" f"Processing {len(responses)} responses for tool call extraction"
) )
if len(responses) > 10 and all( if len(responses) > 10 and all(
isinstance(r, str) and len(r) == 1 for r in responses isinstance(r, str) and len(r) == 1 for r in responses
): ):
@ -936,6 +1027,7 @@ async def execute_multiple_tools_on_multiple_mcp_servers(
logger.warning( logger.warning(
f"Failed to reconstruct response from characters: {str(e)}" f"Failed to reconstruct response from characters: {str(e)}"
) )
for i, response in enumerate(responses): for i, response in enumerate(responses):
logger.debug( logger.debug(
f"Processing response {i}: {type(response)} - {response}" f"Processing response {i}: {type(response)} - {response}"
@ -951,6 +1043,7 @@ async def execute_multiple_tools_on_multiple_mcp_servers(
f"Failed to parse JSON response at index {i}: {response}" f"Failed to parse JSON response at index {i}: {response}"
) )
continue continue
if isinstance(response, dict): if isinstance(response, dict):
if "function" in response: if "function" in response:
logger.debug( logger.debug(
@ -1024,10 +1117,13 @@ async def execute_multiple_tools_on_multiple_mcp_servers(
f"Unsupported response type at index {i}: {type(response)}" f"Unsupported response type at index {i}: {type(response)}"
) )
continue continue
if not all_tool_calls: if not all_tool_calls:
logger.warning("No tool calls found in responses") logger.warning("No tool calls found in responses")
return [] return []
logger.info(f"Found {len(all_tool_calls)} tool calls to execute") logger.info(f"Found {len(all_tool_calls)} tool calls to execute")
max_concurrent = max_concurrent or len(all_tool_calls) max_concurrent = max_concurrent or len(all_tool_calls)
semaphore = asyncio.Semaphore(max_concurrent) semaphore = asyncio.Semaphore(max_concurrent)
@ -1048,6 +1144,7 @@ async def execute_multiple_tools_on_multiple_mcp_servers(
"error": f"Function '{function_name}' not available on any server", "error": f"Function '{function_name}' not available on any server",
"status": "not_found", "status": "not_found",
} }
server_info = server_tool_mapping[function_name] server_info = server_tool_mapping[function_name]
result = await _execute_tool_on_server( result = await _execute_tool_on_server(
tool_call=tool_call, tool_call=tool_call,
@ -1062,7 +1159,9 @@ async def execute_multiple_tools_on_multiple_mcp_servers(
execute_with_semaphore(tool_call_info) execute_with_semaphore(tool_call_info)
for tool_call_info in all_tool_calls for tool_call_info in all_tool_calls
] ]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
processed_results = [] processed_results = []
for i, result in enumerate(results): for i, result in enumerate(results):
if isinstance(result, Exception): if isinstance(result, Exception):
@ -1084,6 +1183,7 @@ async def execute_multiple_tools_on_multiple_mcp_servers(
) )
else: else:
processed_results.append(result) processed_results.append(result)
logger.info( logger.info(
f"Completed execution of {len(processed_results)} tool calls" f"Completed execution of {len(processed_results)} tool calls"
) )
@ -1096,24 +1196,66 @@ def execute_multiple_tools_on_multiple_mcp_servers_sync(
connections: List[MCPConnection] = None, connections: List[MCPConnection] = None,
output_type: Literal["json", "dict", "str", "formatted"] = "str", output_type: Literal["json", "dict", "str", "formatted"] = "str",
max_concurrent: Optional[int] = None, max_concurrent: Optional[int] = None,
transport: str = "sse", transport: Optional[str] = None,
*args, *args,
**kwargs, **kwargs,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Synchronous version of execute_multiple_tools_on_multiple_mcp_servers. Synchronous version of execute_multiple_tools_on_multiple_mcp_servers.
Modified to handle running event loops better.
Args: Args:
responses (List[Dict[str, Any]]): List of tool call requests. responses (List[Dict[str, Any]]): List of tool call requests.
urls (List[str]): List of server URLs. urls (List[str]): List of server URLs.
connections (List[MCPConnection]): Optional list of MCPConnection objects. connections (List[MCPConnection]): Optional list of MCPConnection objects.
output_type (Literal): Output format type. output_type (Literal): Output format type.
max_concurrent (Optional[int]): Max concurrent tasks. max_concurrent (Optional[int]): Max concurrent tasks.
transport (str): Transport type. transport (Optional[str]): Transport type. If None, auto-detects per URL.
Returns: Returns:
List[Dict[str, Any]]: List of execution results. List[Dict[str, Any]]: List of execution results.
""" """
with get_or_create_event_loop() as loop:
try: try:
# Check if we're in a running event loop
try:
loop = asyncio.get_event_loop()
loop_is_running = loop.is_running()
except RuntimeError:
loop_is_running = False
loop = None
# If loop is already running and nest_asyncio is available, use it
if loop_is_running and HAS_NEST_ASYNCIO:
logger.debug("Using nest_asyncio with running event loop for multiple tools")
nest_asyncio.apply(loop)
return loop.run_until_complete(
execute_multiple_tools_on_multiple_mcp_servers(
responses=responses,
urls=urls,
connections=connections,
output_type=output_type,
max_concurrent=max_concurrent,
transport=transport,
*args,
**kwargs,
)
)
# If loop is running but nest_asyncio not available, use thread
elif loop_is_running:
logger.debug("Event loop is running, executing multiple tools in separate thread")
return _run_in_new_thread(
execute_multiple_tools_on_multiple_mcp_servers,
responses=responses,
urls=urls,
connections=connections,
output_type=output_type,
max_concurrent=max_concurrent,
transport=transport,
*args,
**kwargs,
)
# Standard case: no running loop or we're not in an event loop
else:
logger.debug("Using standard event loop management for multiple tools")
with get_or_create_event_loop() as loop:
return loop.run_until_complete( return loop.run_until_complete(
execute_multiple_tools_on_multiple_mcp_servers( execute_multiple_tools_on_multiple_mcp_servers(
responses=responses, responses=responses,

Loading…
Cancel
Save