diff --git a/attached_assets/Pasted-from-swarms-import-Agent-from-swarms-tools-mcp-integration-import-MCPServerSseParams-from-loguru-i-1745167190295.txt b/attached_assets/Pasted-from-swarms-import-Agent-from-swarms-tools-mcp-integration-import-MCPServerSseParams-from-loguru-i-1745167190295.txt new file mode 100644 index 00000000..84142de0 --- /dev/null +++ b/attached_assets/Pasted-from-swarms-import-Agent-from-swarms-tools-mcp-integration-import-MCPServerSseParams-from-loguru-i-1745167190295.txt @@ -0,0 +1,83 @@ +from swarms import Agent +from swarms.tools.mcp_integration import MCPServerSseParams +from loguru import logger + +# Comprehensive math prompt that encourages proper JSON formatting +MATH_AGENT_PROMPT = """ +You are a helpful math calculator assistant. + +Your role is to understand natural language math requests and perform calculations. +When asked to perform calculations: + +1. Determine the operation (add, multiply, or divide) +2. Extract the numbers from the request +3. Use the appropriate math operation tool + +FORMAT YOUR TOOL CALLS AS JSON with this format: +{"tool_name": "add", "a": , "b": } +or +{"tool_name": "multiply", "a": , "b": } +or +{"tool_name": "divide", "a": , "b": } + +Always respond with a tool call in JSON format first, followed by a brief explanation. +""" + +def initialize_math_system(): + """Initialize the math agent with MCP server configuration.""" + # Configure the MCP server connection + math_server = MCPServerSseParams( + url="http://0.0.0.0:8000", + headers={"Content-Type": "application/json"}, + timeout=5.0, + sse_read_timeout=30.0 + ) + + # Create the agent with the MCP server configuration + math_agent = Agent( + agent_name="Math Assistant", + agent_description="Friendly math calculator", + system_prompt=MATH_AGENT_PROMPT, + max_loops=1, + mcp_servers=[math_server], # Pass MCP server config as a list + model_name="gpt-3.5-turbo", + verbose=True # Enable verbose mode to see more details + ) + + return math_agent + +def main(): + try: + logger.info("Initializing math system...") + math_agent = initialize_math_system() + + print("\nMath Calculator Ready!") + print("Ask me any math question!") + print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'") + print("Type 'exit' to quit\n") + + while True: + try: + query = input("What would you like to calculate? ").strip() + if not query: + continue + if query.lower() == 'exit': + break + + logger.info(f"Processing query: {query}") + result = math_agent.run(query) + print(f"\nResult: {result}\n") + + except KeyboardInterrupt: + print("\nGoodbye!") + break + except Exception as e: + logger.error(f"Error processing query: {e}") + print(f"Sorry, there was an error: {str(e)}") + + except Exception as e: + logger.error(f"System initialization error: {e}") + print(f"Failed to start the math system: {str(e)}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/mcp_example/mcp_client.py b/examples/mcp_example/mcp_client.py index 0279a39a..71b29cdc 100644 --- a/examples/mcp_example/mcp_client.py +++ b/examples/mcp_example/mcp_client.py @@ -1,53 +1,83 @@ + from swarms import Agent + from swarms.tools.mcp_integration import MCPServerSseParams + from loguru import logger -from swarms import Agent -from swarms.tools.mcp_integration import MCPServerSseParams -from swarms.prompts.agent_prompts import MATH_AGENT_PROMPT -from loguru import logger - -def initialize_math_system(): - """Initialize the math agent with MCP server configuration.""" - math_server = MCPServerSseParams( - url="http://0.0.0.0:8000", - headers={"Content-Type": "application/json"}, - timeout=5.0, - sse_read_timeout=30.0 - ) - - math_agent = Agent( - agent_name="Math Assistant", - agent_description="Friendly math calculator", - system_prompt=MATH_AGENT_PROMPT, - max_loops=1, - mcp_servers=[math_server], - model_name="gpt-3.5-turbo" - ) - - return math_agent - -def main(): - math_agent = initialize_math_system() - - print("\nMath Calculator Ready!") - print("Ask me any math question!") - print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'") - print("Type 'exit' to quit\n") - - while True: - try: - query = input("What would you like to calculate? ").strip() - if not query: - continue - if query.lower() == 'exit': - break - - result = math_agent.run(query) - print(f"\nResult: {result}\n") - - except KeyboardInterrupt: - print("\nGoodbye!") - break - except Exception as e: - logger.error(f"Error: {e}") - -if __name__ == "__main__": - main() + # Comprehensive math prompt that encourages proper JSON formatting + MATH_AGENT_PROMPT = """ + You are a helpful math calculator assistant. + + Your role is to understand natural language math requests and perform calculations. + When asked to perform calculations: + + 1. Determine the operation (add, multiply, or divide) + 2. Extract the numbers from the request + 3. Use the appropriate math operation tool + + FORMAT YOUR TOOL CALLS AS JSON with this format: + {"tool_name": "add", "a": , "b": } + or + {"tool_name": "multiply", "a": , "b": } + or + {"tool_name": "divide", "a": , "b": } + + Always respond with a tool call in JSON format first, followed by a brief explanation. + """ + + def initialize_math_system(): + """Initialize the math agent with MCP server configuration.""" + # Configure the MCP server connection + math_server = MCPServerSseParams( + url="http://0.0.0.0:8000", + headers={"Content-Type": "application/json"}, + timeout=5.0, + sse_read_timeout=30.0 + ) + + # Create the agent with the MCP server configuration + math_agent = Agent( + agent_name="Math Assistant", + agent_description="Friendly math calculator", + system_prompt=MATH_AGENT_PROMPT, + max_loops=1, + mcp_servers=[math_server], # Pass MCP server config as a list + model_name="gpt-3.5-turbo", + verbose=True # Enable verbose mode to see more details + ) + + return math_agent + + def main(): + try: + logger.info("Initializing math system...") + math_agent = initialize_math_system() + + print("\nMath Calculator Ready!") + print("Ask me any math question!") + print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'") + print("Type 'exit' to quit\n") + + while True: + try: + query = input("What would you like to calculate? ").strip() + if not query: + continue + if query.lower() == 'exit': + break + + logger.info(f"Processing query: {query}") + result = math_agent.run(query) + print(f"\nResult: {result}\n") + + except KeyboardInterrupt: + print("\nGoodbye!") + break + except Exception as e: + logger.error(f"Error processing query: {e}") + print(f"Sorry, there was an error: {str(e)}") + + except Exception as e: + logger.error(f"System initialization error: {e}") + print(f"Failed to start the math system: {str(e)}") + + if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/mcp_example/mock_math_server.py b/examples/mcp_example/mock_math_server.py index e2574d95..05ff56f0 100644 --- a/examples/mcp_example/mock_math_server.py +++ b/examples/mcp_example/mock_math_server.py @@ -1,38 +1,79 @@ - from fastmcp import FastMCP from loguru import logger +import time + +# Create the MCP server +mcp = FastMCP(host="0.0.0.0", + port=8000, + transport="sse", + require_session_id=False) -mcp = FastMCP( - host="0.0.0.0", - port=8000, - transport="sse", - require_session_id=False -) +# Define tools with proper type hints and docstrings @mcp.tool() def add(a: int, b: int) -> str: - """Add two numbers.""" + """Add two numbers. + + Args: + a (int): First number + b (int): Second number + + Returns: + str: A message containing the sum + """ + logger.info(f"Adding {a} and {b}") result = a + b return f"The sum of {a} and {b} is {result}" -@mcp.tool() + +@mcp.tool() def multiply(a: int, b: int) -> str: - """Multiply two numbers.""" + """Multiply two numbers. + + Args: + a (int): First number + b (int): Second number + + Returns: + str: A message containing the product + """ + logger.info(f"Multiplying {a} and {b}") result = a * b return f"The product of {a} and {b} is {result}" + @mcp.tool() def divide(a: int, b: int) -> str: - """Divide two numbers.""" + """Divide two numbers. + + Args: + a (int): Numerator + b (int): Denominator + + Returns: + str: A message containing the division result or an error message + """ + logger.info(f"Dividing {a} by {b}") if b == 0: + logger.warning("Division by zero attempted") return "Cannot divide by zero" result = a / b return f"{a} divided by {b} is {result}" + if __name__ == "__main__": try: logger.info("Starting math server on http://0.0.0.0:8000") + print("Math MCP Server is running. Press Ctrl+C to stop.") + + # Add a small delay to ensure logging is complete before the server starts + time.sleep(0.5) + + # Run the MCP server mcp.run() + except KeyboardInterrupt: + logger.info("Server shutdown requested") + print("\nShutting down server...") except Exception as e: logger.error(f"Server error: {e}") raise diff --git a/swarms/prompts/agent_prompts.py b/swarms/prompts/agent_prompts.py index e65ba009..5136b8e0 100644 --- a/swarms/prompts/agent_prompts.py +++ b/swarms/prompts/agent_prompts.py @@ -1,14 +1,13 @@ # Agent prompts for MCP testing and interactions -MATH_AGENT_PROMPT = '''You are a helpful math calculator assistant. +# Keeping the original format that already has JSON formatting +MATH_AGENT_PROMPT = """You are a helpful math calculator assistant. Your role is to understand natural language math requests and perform calculations. - When asked to perform calculations: 1. Determine the operation (add, multiply, or divide) 2. Extract the numbers from the request 3. Use the appropriate math operation tool - -Respond conversationally but be concise. +Format your tool calls as JSON with the tool_name and parameters. Example: User: "what is 5 plus 3?" @@ -17,7 +16,8 @@ You: Using the add operation for 5 and 3 User: "multiply 4 times 6" You: Using multiply for 4 and 6 -{"tool_name": "multiply", "a": 4, "b": 6}''' +{"tool_name": "multiply", "a": 4, "b": 6} +""" FINANCE_AGENT_PROMPT = """You are a financial analysis agent with access to stock market data services. Key responsibilities: @@ -28,42 +28,40 @@ Key responsibilities: Use the available MCP tools to fetch real market data rather than making assumptions.""" + def generate_agent_role_prompt(agent): """Generates the agent role prompt. Args: agent (str): The type of the agent. Returns: str: The agent role prompt. """ prompts = { - "Finance Agent": ( - "You are a seasoned finance analyst AI assistant. Your" - " primary goal is to compose comprehensive, astute," - " impartial, and methodically arranged financial reports" - " based on provided data and trends." - ), - "Travel Agent": ( - "You are a world-travelled AI tour guide assistant. Your" - " main purpose is to draft engaging, insightful," - " unbiased, and well-structured travel reports on given" - " locations, including history, attractions, and cultural" - " insights." - ), - "Academic Research Agent": ( - "You are an AI academic research assistant. Your primary" - " responsibility is to create thorough, academically" - " rigorous, unbiased, and systematically organized" - " reports on a given research topic, following the" - " standards of scholarly work." - ), - "Default Agent": ( - "You are an AI critical thinker research assistant. Your" - " sole purpose is to write well written, critically" - " acclaimed, objective and structured reports on given" - " text." - ), + "Finance Agent": + ("You are a seasoned finance analyst AI assistant. Your" + " primary goal is to compose comprehensive, astute," + " impartial, and methodically arranged financial reports" + " based on provided data and trends."), + "Travel Agent": + ("You are a world-travelled AI tour guide assistant. Your" + " main purpose is to draft engaging, insightful," + " unbiased, and well-structured travel reports on given" + " locations, including history, attractions, and cultural" + " insights."), + "Academic Research Agent": + ("You are an AI academic research assistant. Your primary" + " responsibility is to create thorough, academically" + " rigorous, unbiased, and systematically organized" + " reports on a given research topic, following the" + " standards of scholarly work."), + "Default Agent": + ("You are an AI critical thinker research assistant. Your" + " sole purpose is to write well written, critically" + " acclaimed, objective and structured reports on given" + " text."), } return prompts.get(agent, "No such agent") + def generate_report_prompt(question, research_summary): """Generates the report prompt for the given question and research summary. Args: question (str): The question to generate the report prompt for @@ -71,16 +69,15 @@ def generate_report_prompt(question, research_summary): Returns: str: The report prompt for the given question and research summary """ - return ( - f'"""{research_summary}""" Using the above information,' - f' answer the following question or topic: "{question}" in a' - " detailed report -- The report should focus on the answer" - " to the question, should be well structured, informative," - " in depth, with facts and numbers if available, a minimum" - " of 1,200 words and with markdown syntax and apa format." - " Write all source urls at the end of the report in apa" - " format" - ) + return (f'"""{research_summary}""" Using the above information,' + f' answer the following question or topic: "{question}" in a' + " detailed report -- The report should focus on the answer" + " to the question, should be well structured, informative," + " in depth, with facts and numbers if available, a minimum" + " of 1,200 words and with markdown syntax and apa format." + " Write all source urls at the end of the report in apa" + " format") + def generate_search_queries_prompt(question): """Generates the search queries prompt for the given question. @@ -88,12 +85,11 @@ def generate_search_queries_prompt(question): Returns: str: The search queries prompt for the given question """ - return ( - "Write 4 google search queries to search online that form an" - f' objective opinion from the following: "{question}"You must' - " respond with a list of strings in the following format:" - ' ["query 1", "query 2", "query 3", "query 4"]' - ) + return ("Write 4 google search queries to search online that form an" + f' objective opinion from the following: "{question}"You must' + " respond with a list of strings in the following format:" + ' ["query 1", "query 2", "query 3", "query 4"]') + def generate_resource_report_prompt(question, research_summary): """Generates the resource report prompt for the given question and research summary. @@ -105,19 +101,18 @@ def generate_resource_report_prompt(question, research_summary): Returns: str: The resource report prompt for the given question and research summary. """ - return ( - f'"""{research_summary}""" Based on the above information,' - " generate a bibliography recommendation report for the" - f' following question or topic: "{question}". The report' - " should provide a detailed analysis of each recommended" - " resource, explaining how each source can contribute to" - " finding answers to the research question. Focus on the" - " relevance, reliability, and significance of each source." - " Ensure that the report is well-structured, informative," - " in-depth, and follows Markdown syntax. Include relevant" - " facts, figures, and numbers whenever available. The report" - " should have a minimum length of 1,200 words." - ) + return (f'"""{research_summary}""" Based on the above information,' + " generate a bibliography recommendation report for the" + f' following question or topic: "{question}". The report' + " should provide a detailed analysis of each recommended" + " resource, explaining how each source can contribute to" + " finding answers to the research question. Focus on the" + " relevance, reliability, and significance of each source." + " Ensure that the report is well-structured, informative," + " in-depth, and follows Markdown syntax. Include relevant" + " facts, figures, and numbers whenever available. The report" + " should have a minimum length of 1,200 words.") + def generate_outline_report_prompt(question, research_summary): """Generates the outline report prompt for the given question and research summary. @@ -126,17 +121,16 @@ def generate_outline_report_prompt(question, research_summary): Returns: str: The outline report prompt for the given question and research summary """ - return ( - f'"""{research_summary}""" Using the above information,' - " generate an outline for a research report in Markdown" - f' syntax for the following question or topic: "{question}".' - " The outline should provide a well-structured framework for" - " the research report, including the main sections," - " subsections, and key points to be covered. The research" - " report should be detailed, informative, in-depth, and a" - " minimum of 1,200 words. Use appropriate Markdown syntax to" - " format the outline and ensure readability." - ) + return (f'"""{research_summary}""" Using the above information,' + " generate an outline for a research report in Markdown" + f' syntax for the following question or topic: "{question}".' + " The outline should provide a well-structured framework for" + " the research report, including the main sections," + " subsections, and key points to be covered. The research" + " report should be detailed, informative, in-depth, and a" + " minimum of 1,200 words. Use appropriate Markdown syntax to" + " format the outline and ensure readability.") + def generate_concepts_prompt(question, research_summary): """Generates the concepts prompt for the given question. @@ -145,15 +139,14 @@ def generate_concepts_prompt(question, research_summary): Returns: str: The concepts prompt for the given question """ - return ( - f'"""{research_summary}""" Using the above information,' - " generate a list of 5 main concepts to learn for a research" - f' report on the following question or topic: "{question}".' - " The outline should provide a well-structured frameworkYou" - " must respond with a list of strings in the following" - ' format: ["concepts 1", "concepts 2", "concepts 3",' - ' "concepts 4, concepts 5"]' - ) + return (f'"""{research_summary}""" Using the above information,' + " generate a list of 5 main concepts to learn for a research" + f' report on the following question or topic: "{question}".' + " The outline should provide a well-structured frameworkYou" + " must respond with a list of strings in the following" + ' format: ["concepts 1", "concepts 2", "concepts 3",' + ' "concepts 4, concepts 5"]') + def generate_lesson_prompt(concept): """ @@ -164,20 +157,19 @@ def generate_lesson_prompt(concept): str: The lesson prompt for the given concept. """ - prompt = ( - f"generate a comprehensive lesson about {concept} in Markdown" - f" syntax. This should include the definitionof {concept}," - " its historical background and development, its" - " applications or uses in differentfields, and notable" - f" events or facts related to {concept}." - ) + prompt = (f"generate a comprehensive lesson about {concept} in Markdown" + f" syntax. This should include the definitionof {concept}," + " its historical background and development, its" + " applications or uses in differentfields, and notable" + f" events or facts related to {concept}.") return prompt + def get_report_by_type(report_type): report_type_mapping = { "research_report": generate_report_prompt, "resource_report": generate_resource_report_prompt, "outline_report": generate_outline_report_prompt, } - return report_type_mapping[report_type] \ No newline at end of file + return report_type_mapping[report_type] diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index 5cceabc7..44c9a95a 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -2647,18 +2647,7 @@ class Agent: else: return str(response) - async def mcp_execution_flow(self, tool_call): - """Execute MCP tool call flow""" - try: - result = await execute_mcp_tool( - url=self.mcp_servers[0]["url"], - parameters=tool_call, - output_type="str", - ) - return result - except Exception as e: - logger.error(f"Error executing tool call: {e}") - return f"Error executing tool call: {e}" + def sentiment_and_evaluator(self, response: str): if self.evaluator: @@ -2688,4 +2677,137 @@ class Agent: self.short_memory.add( role="Output Cleaner", content=response, - ) \ No newline at end of file + ) + + async def amcp_execution_flow(self, response: str) -> str: + """Async implementation of MCP execution flow. + + Args: + response (str): The response from the LLM containing tool calls or natural language. + + Returns: + str: The result of executing the tool calls with preserved formatting. + """ + try: + # Try to parse as JSON first + try: + tool_calls = json.loads(response) + is_json = True + logger.debug(f"Successfully parsed response as JSON: {tool_calls}") + except json.JSONDecodeError: + # If not JSON, treat as natural language + tool_calls = [response] + is_json = False + logger.debug(f"Could not parse response as JSON, treating as natural language") + + # Execute tool calls against MCP servers + results = [] + errors = [] + + # Handle both single tool call and array of tool calls + if isinstance(tool_calls, dict): + tool_calls = [tool_calls] + + logger.debug(f"Executing {len(tool_calls)} tool calls against {len(self.mcp_servers)} MCP servers") + + for tool_call in tool_calls: + try: + # Import here to avoid circular imports + from swarms.tools.mcp_integration import abatch_mcp_flow + + logger.debug(f"Executing tool call: {tool_call}") + # Execute the tool call against all MCP servers + result = await abatch_mcp_flow(self.mcp_servers, tool_call) + + if result: + logger.debug(f"Got result from MCP servers: {result}") + results.extend(result) + # Add successful result to memory with context + self.short_memory.add( + role="assistant", + content=f"Tool execution result: {result}" + ) + else: + error_msg = "No result from tool execution" + errors.append(error_msg) + logger.debug(error_msg) + self.short_memory.add( + role="error", + content=error_msg + ) + + except Exception as e: + error_msg = f"Error executing tool call: {str(e)}" + errors.append(error_msg) + logger.error(error_msg) + self.short_memory.add( + role="error", + content=error_msg + ) + + # Format the final response + if results: + if len(results) == 1: + # For single results, return as is to preserve formatting + return results[0] + else: + # For multiple results, combine with context + formatted_results = [] + for i, result in enumerate(results, 1): + formatted_results.append(f"Result {i}: {result}") + return "\n".join(formatted_results) + elif errors: + if len(errors) == 1: + return errors[0] + else: + return "Multiple errors occurred:\n" + "\n".join(f"- {err}" for err in errors) + else: + return "No results or errors returned" + + except Exception as e: + error_msg = f"Error in MCP execution flow: {str(e)}" + logger.error(error_msg) + self.short_memory.add( + role="error", + content=error_msg + ) + return error_msg + + + def mcp_execution_flow(self, response: str) -> str: + """Synchronous wrapper for MCP execution flow. + + This method creates a new event loop if needed or uses the existing one + to run the async MCP execution flow. + + Args: + response (str): The response from the LLM containing tool calls or natural language. + + Returns: + str: The result of executing the tool calls with preserved formatting. + """ + try: + # Check if we're already in an event loop + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No event loop exists, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if loop.is_running(): + # We're in an async context, use run_coroutine_threadsafe + logger.debug("Using run_coroutine_threadsafe to execute MCP flow") + future = asyncio.run_coroutine_threadsafe( + self.amcp_execution_flow(response), loop + ) + return future.result(timeout=30) # Adding timeout to prevent hanging + else: + # We're not in an async context, use loop.run_until_complete + logger.debug("Using run_until_complete to execute MCP flow") + return loop.run_until_complete(self.amcp_execution_flow(response)) + + except Exception as e: + error_msg = f"Error in MCP execution flow wrapper: {str(e)}" + logger.error(error_msg) + return error_msg diff --git a/swarms/tools/mcp_integration.py b/swarms/tools/mcp_integration.py index 8878b0d0..0959f5f2 100644 --- a/swarms/tools/mcp_integration.py +++ b/swarms/tools/mcp_integration.py @@ -1,255 +1,320 @@ -from __future__ import annotations - -import abc -import asyncio -from contextlib import AbstractAsyncContextManager, AsyncExitStack -from pathlib import Path -from typing import Any, Dict, List, Optional, Literal -from typing_extensions import NotRequired, TypedDict - -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from loguru import logger -from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client -from mcp.client.sse import sse_client -from mcp.types import CallToolResult, JSONRPCMessage - -from swarms.utils.any_to_str import any_to_str - - -class MCPServer(abc.ABC): - """Base class for Model Context Protocol servers.""" - - @abc.abstractmethod - async def connect(self) -> None: - """Establish connection to the MCP server.""" - pass - - @property - @abc.abstractmethod - def name(self) -> str: - """Human-readable server name.""" - pass - - @abc.abstractmethod - async def cleanup(self) -> None: - """Clean up resources and close connection.""" - pass - - @abc.abstractmethod - async def list_tools(self) -> List[MCPTool]: - """List available MCP tools on the server.""" - pass - - @abc.abstractmethod - async def call_tool( - self, tool_name: str, arguments: Dict[str, Any] | None - ) -> CallToolResult: - """Invoke a tool by name with provided arguments.""" - pass - - -class _MCPServerWithClientSession(MCPServer, abc.ABC): - """Mixin providing ClientSession-based MCP communication.""" - - def __init__(self, cache_tools_list: bool = False): - self.session: Optional[ClientSession] = None - self.exit_stack: AsyncExitStack = AsyncExitStack() - self._cleanup_lock = asyncio.Lock() - self.cache_tools_list = cache_tools_list - self._cache_dirty = True - self._tools_list: Optional[List[MCPTool]] = None - - @abc.abstractmethod - def create_streams( - self - ) -> AbstractAsyncContextManager[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] - ]: - """Supply the read/write streams for the MCP transport.""" - pass - - async def __aenter__(self) -> MCPServer: - await self.connect() - return self # type: ignore - - async def __aexit__(self, exc_type, exc_value, tb) -> None: - await self.cleanup() - - async def connect(self) -> None: - """Initialize transport and ClientSession.""" - try: - transport = await self.exit_stack.enter_async_context( - self.create_streams() - ) - read, write = transport - session = await self.exit_stack.enter_async_context( - ClientSession(read, write) - ) - await session.initialize() - self.session = session - except Exception as e: - logger.error(f"Error initializing MCP server: {e}") - await self.cleanup() - raise - - async def cleanup(self) -> None: - """Close session and transport.""" - async with self._cleanup_lock: - try: - await self.exit_stack.aclose() - except Exception as e: - logger.error(f"Error during cleanup: {e}") - finally: - self.session = None - - async def list_tools(self) -> List[MCPTool]: - if not self.session: - raise RuntimeError("Server not connected. Call connect() first.") - if self.cache_tools_list and not self._cache_dirty and self._tools_list: - return self._tools_list - self._cache_dirty = False - self._tools_list = (await self.session.list_tools()).tools - return self._tools_list # type: ignore - - async def call_tool( - self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None - ) -> CallToolResult: - if not arguments: - raise ValueError("Arguments dict is required to call a tool") - name = tool_name or arguments.get("tool_name") or arguments.get("name") - if not name: - raise ValueError("Tool name missing in arguments") - if not self.session: - raise RuntimeError("Server not connected. Call connect() first.") - return await self.session.call_tool(name, arguments) - - -class MCPServerStdioParams(TypedDict): - """Configuration for stdio transport.""" - command: str - args: NotRequired[List[str]] - env: NotRequired[Dict[str, str]] - cwd: NotRequired[str | Path] - encoding: NotRequired[str] - encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]] - - -class MCPServerStdio(_MCPServerWithClientSession): - """MCP server over stdio transport.""" - - def __init__( - self, - params: MCPServerStdioParams, - cache_tools_list: bool = False, - name: Optional[str] = None, - ): - super().__init__(cache_tools_list) - self.params = StdioServerParameters( - command=params["command"], - args=params.get("args", []), - env=params.get("env"), - cwd=params.get("cwd"), - encoding=params.get("encoding", "utf-8"), - encoding_error_handler=params.get("encoding_error_handler", "strict"), - ) - self._name = name or f"stdio:{self.params.command}" - - def create_streams(self) -> AbstractAsyncContextManager[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] - ]: - return stdio_client(self.params) - - @property - def name(self) -> str: - return self._name - - -class MCPServerSseParams(TypedDict): - """Configuration for HTTP+SSE transport.""" - url: str - headers: NotRequired[Dict[str, str]] - timeout: NotRequired[float] - sse_read_timeout: NotRequired[float] - - -class MCPServerSse(_MCPServerWithClientSession): - """MCP server over HTTP with SSE transport.""" - - def __init__( - self, - params: MCPServerSseParams, - cache_tools_list: bool = False, - name: Optional[str] = None, - ): - super().__init__(cache_tools_list) - self.params = params - self._name = name or f"sse:{params['url']}" - - def create_streams(self) -> AbstractAsyncContextManager[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] - ]: - return sse_client( - url=self.params["url"], - headers=self.params.get("headers"), - timeout=self.params.get("timeout", 5), - sse_read_timeout=self.params.get("sse_read_timeout", 300), - ) - - @property - def name(self) -> str: - return self._name - - -async def call_tool_fast( - server: MCPServerSse, payload: Dict[str, Any] | str -) -> Any: - try: - await server.connect() - result = await server.call_tool(arguments=payload if isinstance(payload, dict) else None) - return result - finally: - await server.cleanup() - - -async def mcp_flow_get_tool_schema( - params: MCPServerSseParams, -) -> Any: - async with MCPServerSse(params) as server: - tools = await server.list_tools() - return tools - - -async def mcp_flow( - params: MCPServerSseParams, - function_call: Dict[str, Any] | str, -) -> Any: - async with MCPServerSse(params) as server: - return await call_tool_fast(server, function_call) - - -async def _call_one_server( - params: MCPServerSseParams, payload: Dict[str, Any] | str -) -> Any: - server = MCPServerSse(params) - try: - await server.connect() - return await server.call_tool(arguments=payload if isinstance(payload, dict) else None) - finally: - await server.cleanup() - - -def batch_mcp_flow( - params: List[MCPServerSseParams], payload: Dict[str, Any] | str -) -> List[Any]: - return asyncio.run( - asyncio.gather(*[_call_one_server(p, payload) for p in params]) - ) \ No newline at end of file + from __future__ import annotations + + import abc + import asyncio + from contextlib import AbstractAsyncContextManager, AsyncExitStack + from pathlib import Path + from typing import Any, Dict, List, Optional, Literal, Union + from typing_extensions import NotRequired, TypedDict + + from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + from loguru import logger + from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client + from mcp.client.sse import sse_client + from mcp.types import CallToolResult, JSONRPCMessage + + from swarms.utils.any_to_str import any_to_str + + + class MCPServer(abc.ABC): + """Base class for Model Context Protocol servers.""" + + @abc.abstractmethod + async def connect(self) -> None: + """Establish connection to the MCP server.""" + pass + + @property + @abc.abstractmethod + def name(self) -> str: + """Human-readable server name.""" + pass + + @abc.abstractmethod + async def cleanup(self) -> None: + """Clean up resources and close connection.""" + pass + + @abc.abstractmethod + async def list_tools(self) -> List[MCPTool]: + """List available MCP tools on the server.""" + pass + + @abc.abstractmethod + async def call_tool( + self, tool_name: str, arguments: Dict[str, Any] | None + ) -> CallToolResult: + """Invoke a tool by name with provided arguments.""" + pass + + + class _MCPServerWithClientSession(MCPServer, abc.ABC): + """Mixin providing ClientSession-based MCP communication.""" + + def __init__(self, cache_tools_list: bool = False): + self.session: Optional[ClientSession] = None + self.exit_stack: AsyncExitStack = AsyncExitStack() + self._cleanup_lock = asyncio.Lock() + self.cache_tools_list = cache_tools_list + self._cache_dirty = True + self._tools_list: Optional[List[MCPTool]] = None + + @abc.abstractmethod + def create_streams( + self + ) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ] + ]: + """Supply the read/write streams for the MCP transport.""" + pass + + async def __aenter__(self) -> MCPServer: + await self.connect() + return self # type: ignore + + async def __aexit__(self, exc_type, exc_value, tb) -> None: + await self.cleanup() + + async def connect(self) -> None: + """Initialize transport and ClientSession.""" + try: + transport = await self.exit_stack.enter_async_context( + self.create_streams() + ) + read, write = transport + session = await self.exit_stack.enter_async_context( + ClientSession(read, write) + ) + await session.initialize() + self.session = session + except Exception as e: + logger.error(f"Error initializing MCP server: {e}") + await self.cleanup() + raise + + async def cleanup(self) -> None: + """Close session and transport.""" + async with self._cleanup_lock: + try: + await self.exit_stack.aclose() + except Exception as e: + logger.error(f"Error during cleanup: {e}") + finally: + self.session = None + + async def list_tools(self) -> List[MCPTool]: + if not self.session: + raise RuntimeError("Server not connected. Call connect() first.") + if self.cache_tools_list and not self._cache_dirty and self._tools_list: + return self._tools_list + self._cache_dirty = False + self._tools_list = (await self.session.list_tools()).tools + return self._tools_list # type: ignore + + async def call_tool( + self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None + ) -> CallToolResult: + if not arguments: + raise ValueError("Arguments dict is required to call a tool") + name = tool_name or arguments.get("tool_name") or arguments.get("name") + if not name: + raise ValueError("Tool name missing in arguments") + if not self.session: + raise RuntimeError("Server not connected. Call connect() first.") + return await self.session.call_tool(name, arguments) + + + class MCPServerStdioParams(TypedDict): + """Configuration for stdio transport.""" + command: str + args: NotRequired[List[str]] + env: NotRequired[Dict[str, str]] + cwd: NotRequired[str | Path] + encoding: NotRequired[str] + encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]] + + + class MCPServerStdio(_MCPServerWithClientSession): + """MCP server over stdio transport.""" + + def __init__( + self, + params: MCPServerStdioParams, + cache_tools_list: bool = False, + name: Optional[str] = None, + ): + super().__init__(cache_tools_list) + self.params = StdioServerParameters( + command=params["command"], + args=params.get("args", []), + env=params.get("env"), + cwd=params.get("cwd"), + encoding=params.get("encoding", "utf-8"), + encoding_error_handler=params.get("encoding_error_handler", "strict"), + ) + self._name = name or f"stdio:{self.params.command}" + + def create_streams(self) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ] + ]: + return stdio_client(self.params) + + @property + def name(self) -> str: + return self._name + + + class MCPServerSseParams(TypedDict): + """Configuration for HTTP+SSE transport.""" + url: str + headers: NotRequired[Dict[str, str]] + timeout: NotRequired[float] + sse_read_timeout: NotRequired[float] + + + class MCPServerSse(_MCPServerWithClientSession): + """MCP server over HTTP with SSE transport.""" + + def __init__( + self, + params: MCPServerSseParams, + cache_tools_list: bool = False, + name: Optional[str] = None, + ): + super().__init__(cache_tools_list) + self.params = params + self._name = name or f"sse:{params['url']}" + + def create_streams(self) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ] + ]: + return sse_client( + url=self.params["url"], + headers=self.params.get("headers"), + timeout=self.params.get("timeout", 5), + sse_read_timeout=self.params.get("sse_read_timeout", 300), + ) + + @property + def name(self) -> str: + return self._name + + + async def call_tool_fast( + server: MCPServerSse, payload: Dict[str, Any] | str + ) -> Any: + """Async function to call a tool on a server with proper cleanup.""" + try: + await server.connect() + arguments = payload if isinstance(payload, dict) else None + result = await server.call_tool(arguments=arguments) + return result + finally: + await server.cleanup() + + + async def mcp_flow_get_tool_schema( + params: MCPServerSseParams, + ) -> Any: + """Async function to get tool schema from MCP server.""" + async with MCPServerSse(params) as server: + tools = await server.list_tools() + return tools + + + async def mcp_flow( + params: MCPServerSseParams, + function_call: Dict[str, Any] | str, + ) -> Any: + """Async function to call a tool with given parameters.""" + async with MCPServerSse(params) as server: + return await call_tool_fast(server, function_call) + + + async def _call_one_server( + params: MCPServerSseParams, payload: Dict[str, Any] | str + ) -> Any: + """Helper function to call a single MCP server.""" + server = MCPServerSse(params) + try: + await server.connect() + arguments = payload if isinstance(payload, dict) else None + return await server.call_tool(arguments=arguments) + finally: + await server.cleanup() + + + async def abatch_mcp_flow( + params: List[MCPServerSseParams], payload: Dict[str, Any] | str + ) -> List[Any]: + """Async function to execute a batch of MCP calls concurrently. + + Args: + params (List[MCPServerSseParams]): List of MCP server configurations + payload (Dict[str, Any] | str): The payload to send to each server + + Returns: + List[Any]: Results from all MCP servers + """ + if not params: + logger.warning("No MCP servers provided for batch operation") + return [] + + try: + return await asyncio.gather(*[_call_one_server(p, payload) for p in params]) + except Exception as e: + logger.error(f"Error in abatch_mcp_flow: {e}") + # Return partial results if any were successful + return [f"Error in batch operation: {str(e)}"] + + + def batch_mcp_flow( + params: List[MCPServerSseParams], payload: Dict[str, Any] | str + ) -> List[Any]: + """Sync wrapper for batch MCP operations. + + This creates a new event loop if needed to run the async batch operation. + ONLY use this when not already in an async context. + + Args: + params (List[MCPServerSseParams]): List of MCP server configurations + payload (Dict[str, Any] | str): The payload to send to each server + + Returns: + List[Any]: Results from all MCP servers + """ + if not params: + logger.warning("No MCP servers provided for batch operation") + return [] + + try: + # Check if we're already in an event loop + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No event loop exists, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if loop.is_running(): + # We're already in an async context, can't use asyncio.run + # Use a future to bridge sync-async gap + future = asyncio.run_coroutine_threadsafe( + abatch_mcp_flow(params, payload), loop + ) + return future.result(timeout=30) # Add timeout to prevent hanging + else: + # We're not in an async context, safe to use loop.run_until_complete + return loop.run_until_complete(abatch_mcp_flow(params, payload)) + except Exception as e: + logger.error(f"Error in batch_mcp_flow: {e}") + return [f"Error in batch operation: {str(e)}"] \ No newline at end of file