fix(mcp): implement proper async handling for MCP integration

pull/819/head
Pavan Kumar 3 months ago committed by ascender1729
parent f8d422fbd2
commit ff19580da5

@ -23,4 +23,4 @@ def divide(a: int, b: int) -> float:
if __name__ == "__main__":
    print("Starting Mock Math Server on port 8000...")
    # Fix: the diff residue left both the old and the new call in place, which
    # would attempt to start the server twice. Keep only the corrected call,
    # binding to all interfaces so the SSE endpoint is reachable from outside
    # localhost (matches the FastMCP API: transport, host, port).
    mcp.run(transport="sse", host="0.0.0.0", port=8000)

@ -727,69 +727,6 @@ class Agent:
tool.__name__: tool for tool in self.tools
}
# def mcp_execution_flow(self, response: any):
# """
# Executes the MCP (Model Context Protocol) flow based on the provided response.
# This method takes a response, converts it from a string to a dictionary format,
# and checks for the presence of a tool name or a name in the response. If either
# is found, it retrieves the tool name and proceeds to call the batch_mcp_flow
# function to execute the corresponding tool actions.
# Args:
# response (any): The response to be processed, which can be in string format
# that represents a dictionary.
# Returns:
# The output from the batch_mcp_flow function, which contains the results of
# the tool execution. If an error occurs during processing, it logs the error
# and returns None.
# Raises:
# Exception: Logs any exceptions that occur during the execution flow.
# """
# try:
# response = str_to_dict(response)
# tool_output = batch_mcp_flow(
# self.mcp_servers,
# function_call=response,
# )
# return tool_output
# except Exception as e:
# logger.error(f"Error in mcp_execution_flow: {e}")
# return None
# def mcp_tool_handling(self):
# """
# Handles the retrieval of tool schemas from the MCP servers.
# This method iterates over the list of MCP servers, retrieves the tool schema
# for each server using the mcp_flow_get_tool_schema function, and compiles
# these schemas into a list. The resulting list is stored in the
# tools_list_dictionary attribute.
# Returns:
# list: A list of tool schemas retrieved from the MCP servers. If an error
# occurs during the retrieval process, it logs the error and returns None.
# Raises:
# Exception: Logs any exceptions that occur during the tool handling process.
# """
# try:
# self.tools_list_dictionary = []
# for mcp_server in self.mcp_servers:
# tool_schema = mcp_flow_get_tool_schema(mcp_server)
# self.tools_list_dictionary.append(tool_schema)
# print(self.tools_list_dictionary)
# return self.tools_list_dictionary
# except Exception as e:
# logger.error(f"Error in mcp_tool_handling: {e}")
# return None
def setup_config(self):
# The max_loops will be set dynamically if the dynamic_loop
if self.dynamic_loops is True:
@ -1936,7 +1873,7 @@ class Agent:
self.retry_interval = retry_interval
def reset(self):
    """Reset the agent by clearing its short-term memory.

    The diff artifact left a second, garbled docstring line
    (``\"\"\"Reset the agent\"\"\"Reset the agent\"\"\"``) which is a syntax
    error; this restores a single well-formed docstring. Behavior is
    unchanged: short-term memory is dropped and nothing is returned.
    """
    self.short_memory = None
def ingest_docs(self, docs: List[str], *args, **kwargs):
@ -2775,14 +2712,16 @@ class Agent:
role="Output Cleaner",
content=response,
)
def mcp_execution_flow(self, response: str) -> str:
    """
    Forward the JSON tool-call coming from the LLM to all MCP servers
    listed in ``self.mcp_servers``.

    The rendered diff interleaved the removed implementation (which took a
    ``payload`` dict and called ``asyncio.run`` directly) with the added one
    (which takes the raw JSON string and delegates to the blocking
    ``batch_mcp_flow`` wrapper); this is the reconstructed new version.

    Args:
        response: JSON string emitted by the LLM, e.g. ``{"tool_name": ...}``.

    Returns:
        The stringified result from the MCP server (or a list of results if
        several servers responded), or an ``"[MCP-error] ..."`` message on
        failure — errors are reported as a string rather than raised so the
        agent loop keeps running.
    """
    try:
        payload = json.loads(response)  # {"tool_name": ...}
        # batch_mcp_flow is a blocking wrapper (it runs the event loop
        # internally), so results is already a list of strings.
        results = batch_mcp_flow(self.mcp_servers, payload)
        return any_to_str(results[0] if len(results) == 1 else results)
    except Exception as err:
        logger.error(f"MCP flow failed: {err}")
        return f"[MCP-error] {err}"

@ -346,22 +346,28 @@ async def mcp_flow(
raise
async def _call_one_server(
    param: MCPServerSseParams, payload: dict[str, Any]
) -> Any:
    """Make a call to a single MCP server with proper async context management.

    The ``async with`` block guarantees the SSE connection is opened and
    closed around the call; ``cache_tools_list=True`` avoids re-fetching the
    tool schema on every request.
    """
    async with MCPServerSse(param, cache_tools_list=True) as srv:
        res = await srv.call_tool(payload)
    try:
        return res.model_dump()  # For fast-mcp >= 0.2 (pydantic result)
    except AttributeError:
        return res  # Plain dict or string


def batch_mcp_flow(
    params: List[MCPServerSseParams], payload: dict[str, Any]
) -> List[Any]:
    """Blocking helper that fans out ``payload`` to all MCP servers in ``params``.

    Synchronous wrapper for the Agent to use: ``asyncio.run`` creates and
    closes a fresh event loop, so callers must not already be inside one.
    """
    return asyncio.run(_batch(params, payload))


async def _batch(
    params: List[MCPServerSseParams], payload: dict[str, Any]
) -> List[Any]:
    """Fan out to all MCP servers concurrently and gather their results.

    ``return_exceptions=True`` keeps one failing server from cancelling the
    others; failed calls are dropped from the result list.
    """
    coros = [_call_one_server(p, payload) for p in params]
    results = await asyncio.gather(*coros, return_exceptions=True)
    # Filter out exceptions and normalize every result to a string.
    return [any_to_str(r) for r in results if not isinstance(r, Exception)]
from mcp import (

Loading…
Cancel
Save