From ff19580da56432a095c7778b0e824c22566b3464 Mon Sep 17 00:00:00 2001 From: Pavan Kumar <66913595+ascender1729@users.noreply.github.com> Date: Sun, 20 Apr 2025 11:41:19 +0000 Subject: [PATCH] fix(mcp): implement proper async handling for MCP integration --- examples/mcp_example/mock_math_server.py | 2 +- swarms/structs/agent.py | 83 ++++-------------------- swarms/tools/mcp_integration.py | 36 +++++----- 3 files changed, 33 insertions(+), 88 deletions(-) diff --git a/examples/mcp_example/mock_math_server.py b/examples/mcp_example/mock_math_server.py index 9de67735..6b6a575d 100644 --- a/examples/mcp_example/mock_math_server.py +++ b/examples/mcp_example/mock_math_server.py @@ -23,4 +23,4 @@ def divide(a: int, b: int) -> float: if __name__ == "__main__": print("Starting Mock Math Server on port 8000...") # Fix the parameters to match the FastMCP API - mcp.run(transport="sse", port=8000) + mcp.run(transport="sse", host="0.0.0.0", port=8000) diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index cf6f921a..3a2062d2 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -727,69 +727,6 @@ class Agent: tool.__name__: tool for tool in self.tools } - # def mcp_execution_flow(self, response: any): - # """ - # Executes the MCP (Model Context Protocol) flow based on the provided response. - - # This method takes a response, converts it from a string to a dictionary format, - # and checks for the presence of a tool name or a name in the response. If either - # is found, it retrieves the tool name and proceeds to call the batch_mcp_flow - # function to execute the corresponding tool actions. - - # Args: - # response (any): The response to be processed, which can be in string format - # that represents a dictionary. - - # Returns: - # The output from the batch_mcp_flow function, which contains the results of - # the tool execution. If an error occurs during processing, it logs the error - # and returns None. 
- - # Raises: # Exception: Logs any exceptions that occur during the execution flow. # """ # try: # response = str_to_dict(response) - # tool_output = batch_mcp_flow( # self.mcp_servers, # function_call=response, # ) - # return tool_output # except Exception as e: # logger.error(f"Error in mcp_execution_flow: {e}") # return None - - # def mcp_tool_handling(self): # """ # Handles the retrieval of tool schemas from the MCP servers. - # This method iterates over the list of MCP servers, retrieves the tool schema # for each server using the mcp_flow_get_tool_schema function, and compiles # these schemas into a list. The resulting list is stored in the # tools_list_dictionary attribute. - # Returns: # list: A list of tool schemas retrieved from the MCP servers. If an error # occurs during the retrieval process, it logs the error and returns None. - # Raises: # Exception: Logs any exceptions that occur during the tool handling process. # """ # try: # self.tools_list_dictionary = [] - # for mcp_server in self.mcp_servers: # tool_schema = mcp_flow_get_tool_schema(mcp_server) # self.tools_list_dictionary.append(tool_schema) - # print(self.tools_list_dictionary) # return self.tools_list_dictionary # except Exception as e: # logger.error(f"Error in mcp_tool_handling: {e}") # return None - def setup_config(self): # The max_loops will be set dynamically if the dynamic_loop if self.dynamic_loops is True: @@ -905,7 +842,7 @@ class Agent: # Randomly change the temperature attribute of self.llm object self.llm.temperature = random.uniform(0.0, 1.0) else: - # Use a default temperature + # Use a default temperature self.llm.temperature = 0.5 except Exception as error: logger.error( @@ -1936,7 +1873,7 @@ class Agent: self.retry_interval = retry_interval def reset(self): - """Reset the agent""" + """Reset the agent.""" self.short_memory = None def ingest_docs(self, docs: List[str], *args, **kwargs): @@ 
-2775,14 +2712,16 @@ class Agent: role="Output Cleaner", content=response, ) - def mcp_execution_flow(self, payload: dict) -> str | None: - """Forward the tool-call dict to every MCP server in self.mcp_servers""" + def mcp_execution_flow(self, response: str) -> str: + """ + Forward the JSON tool-call coming from the LLM to all MCP servers + listed in self.mcp_servers. + """ try: - # Use asyncio.run which handles creating and closing the event loop - result = asyncio.run(batch_mcp_flow(self.mcp_servers, [payload])) - if not result: - return "No result returned from MCP server" - return any_to_str(result) + payload = json.loads(response) # {"tool_name": ...} + results = batch_mcp_flow(self.mcp_servers, payload) + # batch_mcp_flow already blocks, so results is a list[str] + return any_to_str(results[0] if len(results) == 1 else results) except Exception as err: logger.error(f"MCP flow failed: {err}") return f"[MCP-error] {err}" \ No newline at end of file diff --git a/swarms/tools/mcp_integration.py b/swarms/tools/mcp_integration.py index 1df684ad..2af01f86 100644 --- a/swarms/tools/mcp_integration.py +++ b/swarms/tools/mcp_integration.py @@ -346,22 +346,28 @@ async def mcp_flow( raise -async def batch_mcp_flow( - params: List[MCPServerSseParams], - function_call: List[dict[str, Any]] = [], -) -> List[Any]: # Updated return type to List[Any] - async def process_param(param): +# Helper function to call one MCP server +async def _call_one_server(param: MCPServerSseParams, payload: dict[str, Any]) -> Any: + """Make a call to a single MCP server with proper async context management.""" + async with MCPServerSse(param, cache_tools_list=True) as srv: + res = await srv.call_tool(payload) try: - async with MCPServerSse(param) as server: - return await call_tool_fast(server, function_call[0]) - except IndexError: - return None # Handle case where function_call is empty - except Exception as e: - logger.error(f"Error processing parameter: {param}, error: {e}") - return None - - 
results = await asyncio.gather(*(process_param(param) for param in params)) - return [any_to_str(r) for r in results if r is not None] + return res.model_dump() # For fast-mcp ≥0.2 + except AttributeError: + return res # Plain dict or string + +# Synchronous wrapper for the Agent to use +def batch_mcp_flow(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]: + """Blocking helper that fans out to all MCP servers in params.""" + return asyncio.run(_batch(params, payload)) + +# Async implementation of batch processing +async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]: + """Fan out to all MCP servers asynchronously and gather results.""" + coros = [_call_one_server(p, payload) for p in params] + results = await asyncio.gather(*coros, return_exceptions=True) + # Filter out exceptions and convert to strings + return [any_to_str(r) for r in results if not isinstance(r, Exception)] from mcp import (