fix(mcp): implement proper async handling for MCP integration

pull/819/head
Pavan Kumar 3 months ago committed by ascender1729
parent f8d422fbd2
commit ff19580da5

@ -23,4 +23,4 @@ def divide(a: int, b: int) -> float:
if __name__ == "__main__": if __name__ == "__main__":
print("Starting Mock Math Server on port 8000...") print("Starting Mock Math Server on port 8000...")
# Fix the parameters to match the FastMCP API # Fix the parameters to match the FastMCP API
mcp.run(transport="sse", port=8000) mcp.run(transport="sse", host="0.0.0.0", port=8000)

@ -727,69 +727,6 @@ class Agent:
tool.__name__: tool for tool in self.tools tool.__name__: tool for tool in self.tools
} }
# def mcp_execution_flow(self, response: any):
# """
# Executes the MCP (Model Context Protocol) flow based on the provided response.
# This method takes a response, converts it from a string to a dictionary format,
# and checks for the presence of a tool name or a name in the response. If either
# is found, it retrieves the tool name and proceeds to call the batch_mcp_flow
# function to execute the corresponding tool actions.
# Args:
# response (any): The response to be processed, which can be in string format
# that represents a dictionary.
# Returns:
# The output from the batch_mcp_flow function, which contains the results of
# the tool execution. If an error occurs during processing, it logs the error
# and returns None.
# Raises:
# Exception: Logs any exceptions that occur during the execution flow.
# """
# try:
# response = str_to_dict(response)
# tool_output = batch_mcp_flow(
# self.mcp_servers,
# function_call=response,
# )
# return tool_output
# except Exception as e:
# logger.error(f"Error in mcp_execution_flow: {e}")
# return None
# def mcp_tool_handling(self):
# """
# Handles the retrieval of tool schemas from the MCP servers.
# This method iterates over the list of MCP servers, retrieves the tool schema
# for each server using the mcp_flow_get_tool_schema function, and compiles
# these schemas into a list. The resulting list is stored in the
# tools_list_dictionary attribute.
# Returns:
# list: A list of tool schemas retrieved from the MCP servers. If an error
# occurs during the retrieval process, it logs the error and returns None.
# Raises:
# Exception: Logs any exceptions that occur during the tool handling process.
# """
# try:
# self.tools_list_dictionary = []
# for mcp_server in self.mcp_servers:
# tool_schema = mcp_flow_get_tool_schema(mcp_server)
# self.tools_list_dictionary.append(tool_schema)
# print(self.tools_list_dictionary)
# return self.tools_list_dictionary
# except Exception as e:
# logger.error(f"Error in mcp_tool_handling: {e}")
# return None
def setup_config(self): def setup_config(self):
# The max_loops will be set dynamically if the dynamic_loop # The max_loops will be set dynamically if the dynamic_loop
if self.dynamic_loops is True: if self.dynamic_loops is True:
@ -1936,7 +1873,7 @@ class Agent:
self.retry_interval = retry_interval self.retry_interval = retry_interval
def reset(self): def reset(self):
"""Reset the agent""" """Reset the agent"""Reset the agent"""
self.short_memory = None self.short_memory = None
def ingest_docs(self, docs: List[str], *args, **kwargs): def ingest_docs(self, docs: List[str], *args, **kwargs):
@ -2775,14 +2712,16 @@ class Agent:
role="Output Cleaner", role="Output Cleaner",
content=response, content=response,
) )
def mcp_execution_flow(self, payload: dict) -> str | None: def mcp_execution_flow(self, response: str) -> str:
"""Forward the tool-call dict to every MCP server in self.mcp_servers""" """
Forward the JSON tool-call coming from the LLM to all MCP servers
listed in self.mcp_servers.
"""
try: try:
# Use asyncio.run which handles creating and closing the event loop payload = json.loads(response) # {"tool_name": ...}
result = asyncio.run(batch_mcp_flow(self.mcp_servers, [payload])) results = batch_mcp_flow(self.mcp_servers, payload)
if not result: # batch_mcp_flow already blocks, so results is a list[str]
return "No result returned from MCP server" return any_to_str(results[0] if len(results) == 1 else results)
return any_to_str(result)
except Exception as err: except Exception as err:
logger.error(f"MCP flow failed: {err}") logger.error(f"MCP flow failed: {err}")
return f"[MCP-error] {err}" return f"[MCP-error] {err}"

@ -346,22 +346,28 @@ async def mcp_flow(
raise raise
async def batch_mcp_flow( # Helper function to call one MCP server
params: List[MCPServerSseParams], async def _call_one_server(param: MCPServerSseParams, payload: dict[str, Any]) -> Any:
function_call: List[dict[str, Any]] = [], """Make a call to a single MCP server with proper async context management."""
) -> List[Any]: # Updated return type to List[Any] async with MCPServerSse(param, cache_tools_list=True) as srv:
async def process_param(param): res = await srv.call_tool(payload)
try: try:
async with MCPServerSse(param) as server: return res.model_dump() # For fast-mcp ≥0.2
return await call_tool_fast(server, function_call[0]) except AttributeError:
except IndexError: return res # Plain dict or string
return None # Handle case where function_call is empty
except Exception as e: # Synchronous wrapper for the Agent to use
logger.error(f"Error processing parameter: {param}, error: {e}") def batch_mcp_flow(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]:
return None """Blocking helper that fans out to all MCP servers in params."""
return asyncio.run(_batch(params, payload))
results = await asyncio.gather(*(process_param(param) for param in params))
return [any_to_str(r) for r in results if r is not None] # Async implementation of batch processing
async def _batch(params: List[MCPServerSseParams], payload: dict[str, Any]) -> List[Any]:
"""Fan out to all MCP servers asynchronously and gather results."""
coros = [_call_one_server(p, payload) for p in params]
results = await asyncio.gather(*coros, return_exceptions=True)
# Filter out exceptions and convert to strings
return [any_to_str(r) for r in results if not isinstance(r, Exception)]
from mcp import ( from mcp import (

Loading…
Cancel
Save