diff --git a/.replit b/.replit index a11e85df..0cdbe71c 100644 --- a/.replit +++ b/.replit @@ -23,11 +23,11 @@ args = "python -m unittest tests/tools/test_mcp_integration.py -v" [[workflows.workflow]] name = "Run MCP Demo" author = 13983571 -mode = "sequential" +mode = "parallel" [[workflows.workflow.tasks]] task = "shell.exec" -args = "python examples/mcp_example/mock_math_server.py & " +args = "python examples/mcp_example/mock_math_server.py" [[workflows.workflow.tasks]] task = "shell.exec" diff --git a/attached_assets/Pasted-The-root-of-that-unhandled-errors-in-a-TaskGroup-1-sub-exception-is-simply-that-your-client-s-M-1745170772061.txt b/attached_assets/Pasted-The-root-of-that-unhandled-errors-in-a-TaskGroup-1-sub-exception-is-simply-that-your-client-s-M-1745170772061.txt new file mode 100644 index 00000000..f69de41a --- /dev/null +++ b/attached_assets/Pasted-The-root-of-that-unhandled-errors-in-a-TaskGroup-1-sub-exception-is-simply-that-your-client-s-M-1745170772061.txt @@ -0,0 +1,100 @@ +The root of that “unhandled errors in a TaskGroup (1 sub‑exception)” is simply that your client’s `MCPServerSse.connect()` is failing under the hood (most likely a connection/refused or path‐not‐found error) and AnyIO is wrapping it in a TaskGroup exception. You don’t see the real cause because it gets hidden by AnyIO’s TaskGroup. Here’s how to unmask it and fix it: + +--- + +## 1. Diagnose the real error +Wrap the connect call and print the underlying exception: + +```python +async def _test_connect(): + server = MCPServerSse(get_server_params()) + try: + await server.connect() + await server.cleanup() + return True + except Exception as e: + # Print the actual cause + import traceback; traceback.print_exc() + return False + +print(asyncio.run(_test_connect())) +``` + +You’ll probably see a **connection refused** or **404 on /sse** in the stack trace. + +--- + +## 2. Ensure client and server agree on your SSE endpoint +By default FastMCP serves its SSE stream at `/sse` and messages on `/messages`. If you only pass `url="http://127.0.0.1:8000"` the client will try whatever its default path is (often `/events` or `/stream`). You need to be explicit: + +```python +from swarms.tools.mcp_integration import MCPServerSseParams + +def get_server_params(): + return MCPServerSseParams( + url="http://127.0.0.1:8000", + sse_path="/sse", # <— tell it exactly where the SSE lives + messages_path="/messages", # <— if your server uses /messages for POSTs + headers={ + "Content-Type": "application/json", + "Accept": "text/event-stream", + }, + timeout=15.0, + sse_read_timeout=60.0, + require_session_id=False, # match your server’s require_session_id + ) +``` + +--- + +## 3. Don’t manually call `MCPServerSse` unless you need to +Your `test_server_connection()` can more reliably just do a raw HTTP(S) health‑check: + +```python +def test_server_connection(): + health_url = get_server_params().url + get_server_params().sse_path + try: + r = httpx.get(health_url, + headers={"Accept":"text/event-stream"}, + timeout=5.0) + if r.status_code == 200: + logger.info("✅ SSE endpoint is up") + return True + else: + logger.error(f"❌ Unexpected status {r.status_code}") + return False + except Exception as e: + logger.error(f"❌ Connection to SSE endpoint failed: {e}") + return False +``` + +That way you see immediately if the server is refusing connections or returning 404. + +--- + +## 4. Align your Agent configuration +Once you’ve verified the raw GET to `http://127.0.0.1:8000/sse` is 200, your Agent should work with exactly the same params: + +```python +math_agent = Agent( + agent_name="Math Assistant", + agent_description="Friendly math calculator", + system_prompt=MATH_AGENT_PROMPT, + max_loops=1, + model_name="gpt-3.5-turbo", + verbose=True, + mcp_servers=[ get_server_params() ] +) +``` + +Now when you do `math_agent.run("add 3 and 4")`, the SSE handshake will succeed and you’ll no longer see that TaskGroup error. + +--- + +### TL;DR +1. **Print the real exception** behind the TaskGroup to see “connection refused” or “404.” +2. **Explicitly set** `sse_path="/sse"` (and `messages_path`) in `MCPServerSseParams`. +3. **Health‑check** with a simple `httpx.get("…/sse")` instead of `server.connect()`. +4. Pass those same params straight into your `Agent`. + +Once your client is pointing at the exact SSE URL your FastMCP server is serving, the Agent will connect cleanly and you’ll be back to doing math instead of wrestling TaskGroup errors. \ No newline at end of file diff --git a/attached_assets/Pasted-from-swarms-import-Agent-from-swarms-tools-mcp-integration-import-MCPServerSseParams-MCPServerSse--1745170779273.txt b/attached_assets/Pasted-from-swarms-import-Agent-from-swarms-tools-mcp-integration-import-MCPServerSseParams-MCPServerSse--1745170779273.txt new file mode 100644 index 00000000..7ad49aac --- /dev/null +++ b/attached_assets/Pasted-from-swarms-import-Agent-from-swarms-tools-mcp-integration-import-MCPServerSseParams-MCPServerSse--1745170779273.txt @@ -0,0 +1,397 @@ +from swarms import Agent +from swarms.tools.mcp_integration import MCPServerSseParams, MCPServerSse, mcp_flow_get_tool_schema +from loguru import logger +import sys +import asyncio +import json +import httpx +import time + +# Configure logging for more detailed output +logger.remove() +logger.add(sys.stdout, + level="DEBUG", + format="{time} | {level} | {module}:{function}:{line} - {message}") + +# Relaxed prompt that doesn't enforce strict JSON formatting + + + +# Create server parameters +def get_server_params(): + """Get the MCP server connection parameters.""" + return MCPServerSseParams( + url= + "http://127.0.0.1:8000", # Use 127.0.0.1 instead of localhost/0.0.0.0 + headers={ + "Content-Type": "application/json", + "Accept": "text/event-stream" + }, + timeout=15.0, # Longer timeout + sse_read_timeout=60.0 # Longer read timeout + ) + + +def initialize_math_system(): + """Initialize the math agent with MCP server configuration.""" + # Create the agent with the MCP server configuration + math_agent = Agent(agent_name="Math Assistant", + agent_description="Friendly math calculator", + system_prompt=MATH_AGENT_PROMPT, + max_loops=1, + mcp_servers=[get_server_params()], + model_name="gpt-3.5-turbo", + verbose=True) + + return math_agent + + +# Function to get list of available tools from the server +async def get_tools_list(): + """Fetch and format the list of available tools from the server.""" + try: + server_params = get_server_params() + tools = await mcp_flow_get_tool_schema(server_params) + + if not tools: + return "No tools are currently available on the server." + + # Format the tools information + tools_info = "Available tools:\n" + for tool in tools: + tools_info += f"\n- {tool.name}: {tool.description or 'No description'}\n" + if tool.parameters and hasattr(tool.parameters, 'properties'): + tools_info += " Parameters:\n" + for param_name, param_info in tool.parameters.properties.items( + ): + param_type = param_info.get('type', 'unknown') + param_desc = param_info.get('description', + 'No description') + tools_info += f" - {param_name} ({param_type}): {param_desc}\n" + + return tools_info + except Exception as e: + logger.error(f"Failed to get tools list: {e}") + return f"Error retrieving tools list: {str(e)}" + + +# Function to test server connection +def test_server_connection(): + """Test if the server is reachable and responsive.""" + try: + # Create a short-lived connection to check server + server = MCPServerSse(get_server_params()) + + # Try connecting (this is synchronous) + asyncio.run(server.connect()) + asyncio.run(server.cleanup()) + logger.info("✅ Server connection test successful") + return True + except Exception as e: + logger.error(f"❌ Server connection test failed: {e}") + return False + + +# Manual math operation handler as ultimate fallback +def manual_math(query): + """Parse and solve a math problem without using the server.""" + query = query.lower() + + # Check if user is asking for available tools/functions + if "list" in query and ("tools" in query or "functions" in query + or "operations" in query): + return """ +Available tools: +1. add - Add two numbers together (e.g., "add 3 and 4") +2. multiply - Multiply two numbers together (e.g., "multiply 5 and 6") +3. divide - Divide the first number by the second (e.g., "divide 10 by 2") +""" + + try: + if "add" in query or "plus" in query or "sum" in query: + # Extract numbers using a simple approach + numbers = [int(s) for s in query.split() if s.isdigit()] + if len(numbers) >= 2: + result = numbers[0] + numbers[1] + return f"The sum of {numbers[0]} and {numbers[1]} is {result}" + + elif "multiply" in query or "times" in query or "product" in query: + numbers = [int(s) for s in query.split() if s.isdigit()] + if len(numbers) >= 2: + result = numbers[0] * numbers[1] + return f"The product of {numbers[0]} and {numbers[1]} is {result}" + + elif "divide" in query or "quotient" in query: + numbers = [int(s) for s in query.split() if s.isdigit()] + if len(numbers) >= 2: + if numbers[1] == 0: + return "Cannot divide by zero" + result = numbers[0] / numbers[1] + return f"{numbers[0]} divided by {numbers[1]} is {result}" + + return "I couldn't parse your math request. Try something like 'add 3 and 4'." + except Exception as e: + logger.error(f"Manual math error: {e}") + return f"Error performing calculation: {str(e)}" + + +def main(): + try: + logger.info("Initializing math system...") + + # Test server connection first + server_available = test_server_connection() + + if server_available: + math_agent = initialize_math_system() + print("\nMath Calculator Ready! (Server connection successful)") + else: + print( + "\nServer connection failed - using fallback calculator mode") + math_agent = None + + print("Ask me any math question!") + print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'") + print("Type 'list tools' to see available operations") + print("Type 'exit' to quit\n") + + while True: + try: + query = input("What would you like to calculate? ").strip() + if not query: + continue + if query.lower() == 'exit': + break + + # Handle special commands + if query.lower() in ('list tools', 'show tools', + 'available tools', 'what tools'): + if server_available: + # Get tools list from server + tools_info = asyncio.run(get_tools_list()) + print(f"\n{tools_info}\n") + else: + # Use manual fallback + print(manual_math("list tools")) + continue + + logger.info(f"Processing query: {query}") + + # First try the agent if available + if math_agent and server_available: + try: + result = math_agent.run(query) + print(f"\nResult: {result}\n") + continue + except Exception as e: + logger.error(f"Agent error: {e}") + print("Agent encountered an error, trying fallback...") + + # If agent fails or isn't available, use manual calculator + result = manual_math(query) + print(f"\nCalculation result: {result}\n") + + except KeyboardInterrupt: + print("\nGoodbye!") + break + except Exception as e: + logger.error(f"Error processing query: {e}") + print(f"Sorry, there was an error: {str(e)}") + + except Exception as e: + logger.error(f"System initialization error: {e}") + print(f"Failed to start the math system: {str(e)}") + + +if __name__ == "__main__": + main() "from fastmcp import FastMCP +from loguru import logger +import time +import json + +# Create the MCP server with detailed debugging +mcp = FastMCP( + host="0.0.0.0", # Bind to all interfaces + port=8000, + transport="sse", + require_session_id=False, + cors_allowed_origins=["*"], # Allow connections from any origin + debug=True # Enable debug mode for more verbose output +) + + +# Add a more flexible parsing approach +def parse_input(input_str): + """Parse input that could be JSON or natural language.""" + try: + # First try to parse as JSON + return json.loads(input_str) + except json.JSONDecodeError: + # If not JSON, try to parse natural language + input_lower = input_str.lower() + + # Parse for addition + if "add" in input_lower or "plus" in input_lower or "sum" in input_lower: + # Extract numbers - very simple approach + numbers = [int(s) for s in input_lower.split() if s.isdigit()] + if len(numbers) >= 2: + return {"a": numbers[0], "b": numbers[1]} + + # Parse for multiplication + if "multiply" in input_lower or "times" in input_lower or "product" in input_lower: + numbers = [int(s) for s in input_lower.split() if s.isdigit()] + if len(numbers) >= 2: + return {"a": numbers[0], "b": numbers[1]} + + # Parse for division + if "divide" in input_lower or "quotient" in input_lower: + numbers = [int(s) for s in input_lower.split() if s.isdigit()] + if len(numbers) >= 2: + return {"a": numbers[0], "b": numbers[1]} + + # Could not parse successfully + return None + + +# Define tools with more flexible input handling +@mcp.tool() +def add(input_str=None, a=None, b=None): + """Add two numbers. Can accept JSON parameters or natural language. + + Args: + input_str (str, optional): Natural language input to parse + a (int, optional): First number if provided directly + b (int, optional): Second number if provided directly + + Returns: + str: A message containing the sum + """ + logger.info(f"Add tool called with input_str={input_str}, a={a}, b={b}") + + # If we got a natural language string instead of parameters + if input_str and not (a is not None and b is not None): + parsed = parse_input(input_str) + if parsed: + a = parsed.get("a") + b = parsed.get("b") + + # Validate we have what we need + if a is None or b is None: + return "Sorry, I couldn't understand the numbers to add" + + try: + a = int(a) + b = int(b) + result = a + b + return f"The sum of {a} and {b} is {result}" + except ValueError: + return "Please provide valid numbers for addition" + + +@mcp.tool() +def multiply(input_str=None, a=None, b=None): + """Multiply two numbers. Can accept JSON parameters or natural language. + + Args: + input_str (str, optional): Natural language input to parse + a (int, optional): First number if provided directly + b (int, optional): Second number if provided directly + + Returns: + str: A message containing the product + """ + logger.info( + f"Multiply tool called with input_str={input_str}, a={a}, b={b}") + + # If we got a natural language string instead of parameters + if input_str and not (a is not None and b is not None): + parsed = parse_input(input_str) + if parsed: + a = parsed.get("a") + b = parsed.get("b") + + # Validate we have what we need + if a is None or b is None: + return "Sorry, I couldn't understand the numbers to multiply" + + try: + a = int(a) + b = int(b) + result = a * b + return f"The product of {a} and {b} is {result}" + except ValueError: + return "Please provide valid numbers for multiplication" + + +@mcp.tool() +def divide(input_str=None, a=None, b=None): + """Divide two numbers. Can accept JSON parameters or natural language. + + Args: + input_str (str, optional): Natural language input to parse + a (int, optional): Numerator if provided directly + b (int, optional): Denominator if provided directly + + Returns: + str: A message containing the division result or an error message + """ + logger.info(f"Divide tool called with input_str={input_str}, a={a}, b={b}") + + # If we got a natural language string instead of parameters + if input_str and not (a is not None and b is not None): + parsed = parse_input(input_str) + if parsed: + a = parsed.get("a") + b = parsed.get("b") + + # Validate we have what we need + if a is None or b is None: + return "Sorry, I couldn't understand the numbers to divide" + + try: + a = int(a) + b = int(b) + + if b == 0: + logger.warning("Division by zero attempted") + return "Cannot divide by zero" + + result = a / b + return f"{a} divided by {b} is {result}" + except ValueError: + return "Please provide valid numbers for division" + + +if __name__ == "__main__": + try: + logger.info("Starting math server on http://0.0.0.0:8000") + print("Math MCP Server is running. Press Ctrl+C to stop.") + print( + "Server is configured to accept both JSON and natural language input" + ) + + # Add a small delay to ensure logging is complete before the server starts + time.sleep(0.5) + + # Run the MCP server + mcp.run() + except KeyboardInterrupt: + logger.info("Server shutdown requested") + print("\nShutting down server...") + except Exception as e: + logger.error(f"Server error: {e}") + raise +" server is runnig poeroperly "2025-04-20 17:35:01.251 | INFO | __main__::161 - Starting math server on http://0.0.0.0:8000 +Math MCP Server is running. Press Ctrl+C to stop. +Server is configured to accept both JSON and natural language input +[04/20/25 17:35:01] INFO Starting server "FastMCP"... " butwhy im getting these errore "2025-04-20T17:35:04.174629+0000 | INFO | mcp_client:main:159 - Initializing math system... +2025-04-20T17:35:04.203591+0000 | ERROR | mcp_integration:connect:89 - Error initializing MCP server: unhandled errors in a TaskGroup (1 sub-exception) +2025-04-20T17:35:04.204437+0000 | ERROR | mcp_client:test_server_connection:110 - ❌ Server connection test failed: unhandled errors in a TaskGroup (1 sub-exception) + +Server connection failed - using fallback calculator mode +Ask me any math question! +Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?' +Type 'list tools' to see available operations +Type 'exit' to quit + +What would you like to calculate? " \ No newline at end of file diff --git a/examples/mcp_example/mcp_client.py b/examples/mcp_example/mcp_client.py index bf83d2d4..8c6ae8ae 100644 --- a/examples/mcp_example/mcp_client.py +++ b/examples/mcp_example/mcp_client.py @@ -1,60 +1,158 @@ - from swarms import Agent -from swarms.tools.mcp_integration import MCPServerSseParams +from swarms.tools.mcp_integration import MCPServerSseParams, MCPServerSse, mcp_flow_get_tool_schema from loguru import logger +import sys +import asyncio +import json +import httpx +import time + +# Configure logging for more detailed output +logger.remove() +logger.add(sys.stdout, + level="DEBUG", + format="{time} | {level} | {module}:{function}:{line} - {message}") + +# Relaxed prompt that doesn't enforce strict JSON formatting + + + +# Create server parameters +def get_server_params(): + """Get the MCP server connection parameters.""" + return MCPServerSseParams( + url= + "http://127.0.0.1:8000", # Use 127.0.0.1 instead of localhost/0.0.0.0 + headers={ + "Content-Type": "application/json", + "Accept": "text/event-stream" + }, + timeout=15.0, # Longer timeout + sse_read_timeout=60.0 # Longer read timeout + ) -# Comprehensive math prompt that encourages proper JSON formatting -MATH_AGENT_PROMPT = """ -You are a helpful math calculator assistant. -Your role is to understand natural language math requests and perform calculations. -When asked to perform calculations: +def initialize_math_system(): + """Initialize the math agent with MCP server configuration.""" + # Create the agent with the MCP server configuration + math_agent = Agent(agent_name="Math Assistant", + agent_description="Friendly math calculator", + system_prompt=MATH_AGENT_PROMPT, + max_loops=1, + mcp_servers=[get_server_params()], + model_name="gpt-3.5-turbo", + verbose=True) -1. Determine the operation (add, multiply, or divide) -2. Extract the numbers from the request -3. Use the appropriate math operation tool + return math_agent -FORMAT YOUR TOOL CALLS AS JSON with this format: -{"tool_name": "add", "a": , "b": } -or -{"tool_name": "multiply", "a": , "b": } -or -{"tool_name": "divide", "a": , "b": } -Always respond with a tool call in JSON format first, followed by a brief explanation. -""" +# Function to get list of available tools from the server +async def get_tools_list(): + """Fetch and format the list of available tools from the server.""" + try: + server_params = get_server_params() + tools = await mcp_flow_get_tool_schema(server_params) + + if not tools: + return "No tools are currently available on the server." + + # Format the tools information + tools_info = "Available tools:\n" + for tool in tools: + tools_info += f"\n- {tool.name}: {tool.description or 'No description'}\n" + if tool.parameters and hasattr(tool.parameters, 'properties'): + tools_info += " Parameters:\n" + for param_name, param_info in tool.parameters.properties.items( + ): + param_type = param_info.get('type', 'unknown') + param_desc = param_info.get('description', + 'No description') + tools_info += f" - {param_name} ({param_type}): {param_desc}\n" + + return tools_info + except Exception as e: + logger.error(f"Failed to get tools list: {e}") + return f"Error retrieving tools list: {str(e)}" -def initialize_math_system(): - """Initialize the math agent with MCP server configuration.""" - # Configure the MCP server connection - math_server = MCPServerSseParams( - url="http://0.0.0.0:8000", - headers={"Content-Type": "application/json"}, - timeout=5.0, - sse_read_timeout=30.0 - ) - # Create the agent with the MCP server configuration - math_agent = Agent( - agent_name="Math Assistant", - agent_description="Friendly math calculator", - system_prompt=MATH_AGENT_PROMPT, - max_loops=1, - mcp_servers=[math_server], # Pass MCP server config as a list - model_name="gpt-3.5-turbo", - verbose=True # Enable verbose mode to see more details - ) +# Function to test server connection +def test_server_connection(): + """Test if the server is reachable and responsive.""" + try: + # Create a short-lived connection to check server + server = MCPServerSse(get_server_params()) + + # Try connecting (this is synchronous) + asyncio.run(server.connect()) + asyncio.run(server.cleanup()) + logger.info("✅ Server connection test successful") + return True + except Exception as e: + logger.error(f"❌ Server connection test failed: {e}") + return False + + +# Manual math operation handler as ultimate fallback +def manual_math(query): + """Parse and solve a math problem without using the server.""" + query = query.lower() + + # Check if user is asking for available tools/functions + if "list" in query and ("tools" in query or "functions" in query + or "operations" in query): + return """ +Available tools: +1. add - Add two numbers together (e.g., "add 3 and 4") +2. multiply - Multiply two numbers together (e.g., "multiply 5 and 6") +3. divide - Divide the first number by the second (e.g., "divide 10 by 2") +""" + + try: + if "add" in query or "plus" in query or "sum" in query: + # Extract numbers using a simple approach + numbers = [int(s) for s in query.split() if s.isdigit()] + if len(numbers) >= 2: + result = numbers[0] + numbers[1] + return f"The sum of {numbers[0]} and {numbers[1]} is {result}" + + elif "multiply" in query or "times" in query or "product" in query: + numbers = [int(s) for s in query.split() if s.isdigit()] + if len(numbers) >= 2: + result = numbers[0] * numbers[1] + return f"The product of {numbers[0]} and {numbers[1]} is {result}" + + elif "divide" in query or "quotient" in query: + numbers = [int(s) for s in query.split() if s.isdigit()] + if len(numbers) >= 2: + if numbers[1] == 0: + return "Cannot divide by zero" + result = numbers[0] / numbers[1] + return f"{numbers[0]} divided by {numbers[1]} is {result}" + + return "I couldn't parse your math request. Try something like 'add 3 and 4'." + except Exception as e: + logger.error(f"Manual math error: {e}") + return f"Error performing calculation: {str(e)}" - return math_agent def main(): try: logger.info("Initializing math system...") - math_agent = initialize_math_system() - print("\nMath Calculator Ready!") + # Test server connection first + server_available = test_server_connection() + + if server_available: + math_agent = initialize_math_system() + print("\nMath Calculator Ready! (Server connection successful)") + else: + print( + "\nServer connection failed - using fallback calculator mode") + math_agent = None + print("Ask me any math question!") print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'") + print("Type 'list tools' to see available operations") print("Type 'exit' to quit\n") while True: @@ -65,9 +163,33 @@ def main(): if query.lower() == 'exit': break + # Handle special commands + if query.lower() in ('list tools', 'show tools', + 'available tools', 'what tools'): + if server_available: + # Get tools list from server + tools_info = asyncio.run(get_tools_list()) + print(f"\n{tools_info}\n") + else: + # Use manual fallback + print(manual_math("list tools")) + continue + logger.info(f"Processing query: {query}") - result = math_agent.run(query) - print(f"\nResult: {result}\n") + + # First try the agent if available + if math_agent and server_available: + try: + result = math_agent.run(query) + print(f"\nResult: {result}\n") + continue + except Exception as e: + logger.error(f"Agent error: {e}") + print("Agent encountered an error, trying fallback...") + + # If agent fails or isn't available, use manual calculator + result = manual_math(query) + print(f"\nCalculation result: {result}\n") except KeyboardInterrupt: print("\nGoodbye!") @@ -80,5 +202,6 @@ def main(): logger.error(f"System initialization error: {e}") print(f"Failed to start the math system: {str(e)}") + if __name__ == "__main__": main() diff --git a/examples/mcp_example/mock_math_server.py b/examples/mcp_example/mock_math_server.py index 05ff56f0..45c456a9 100644 --- a/examples/mcp_example/mock_math_server.py +++ b/examples/mcp_example/mock_math_server.py @@ -1,70 +1,168 @@ from fastmcp import FastMCP from loguru import logger import time - -# Create the MCP server -mcp = FastMCP(host="0.0.0.0", - port=8000, - transport="sse", - require_session_id=False) - - -# Define tools with proper type hints and docstrings +import json + +# Create the MCP server with detailed debugging +mcp = FastMCP( + host="0.0.0.0", # Bind to all interfaces + port=8000, + transport="sse", + require_session_id=False, + cors_allowed_origins=["*"], # Allow connections from any origin + debug=True # Enable debug mode for more verbose output +) + + +# Add a more flexible parsing approach +def parse_input(input_str): + """Parse input that could be JSON or natural language.""" + try: + # First try to parse as JSON + return json.loads(input_str) + except json.JSONDecodeError: + # If not JSON, try to parse natural language + input_lower = input_str.lower() + + # Parse for addition + if "add" in input_lower or "plus" in input_lower or "sum" in input_lower: + # Extract numbers - very simple approach + numbers = [int(s) for s in input_lower.split() if s.isdigit()] + if len(numbers) >= 2: + return {"a": numbers[0], "b": numbers[1]} + + # Parse for multiplication + if "multiply" in input_lower or "times" in input_lower or "product" in input_lower: + numbers = [int(s) for s in input_lower.split() if s.isdigit()] + if len(numbers) >= 2: + return {"a": numbers[0], "b": numbers[1]} + + # Parse for division + if "divide" in input_lower or "quotient" in input_lower: + numbers = [int(s) for s in input_lower.split() if s.isdigit()] + if len(numbers) >= 2: + return {"a": numbers[0], "b": numbers[1]} + + # Could not parse successfully + return None + + +# Define tools with more flexible input handling @mcp.tool() -def add(a: int, b: int) -> str: - """Add two numbers. +def add(input_str=None, a=None, b=None): + """Add two numbers. Can accept JSON parameters or natural language. Args: - a (int): First number - b (int): Second number + input_str (str, optional): Natural language input to parse + a (int, optional): First number if provided directly + b (int, optional): Second number if provided directly Returns: str: A message containing the sum """ - logger.info(f"Adding {a} and {b}") - result = a + b - return f"The sum of {a} and {b} is {result}" + logger.info(f"Add tool called with input_str={input_str}, a={a}, b={b}") + + # If we got a natural language string instead of parameters + if input_str and not (a is not None and b is not None): + parsed = parse_input(input_str) + if parsed: + a = parsed.get("a") + b = parsed.get("b") + + # Validate we have what we need + if a is None or b is None: + return "Sorry, I couldn't understand the numbers to add" + + try: + a = int(a) + b = int(b) + result = a + b + return f"The sum of {a} and {b} is {result}" + except ValueError: + return "Please provide valid numbers for addition" @mcp.tool() -def multiply(a: int, b: int) -> str: - """Multiply two numbers. +def multiply(input_str=None, a=None, b=None): + """Multiply two numbers. Can accept JSON parameters or natural language. Args: - a (int): First number - b (int): Second number + input_str (str, optional): Natural language input to parse + a (int, optional): First number if provided directly + b (int, optional): Second number if provided directly Returns: str: A message containing the product """ - logger.info(f"Multiplying {a} and {b}") - result = a * b - return f"The product of {a} and {b} is {result}" + logger.info( + f"Multiply tool called with input_str={input_str}, a={a}, b={b}") + + # If we got a natural language string instead of parameters + if input_str and not (a is not None and b is not None): + parsed = parse_input(input_str) + if parsed: + a = parsed.get("a") + b = parsed.get("b") + + # Validate we have what we need + if a is None or b is None: + return "Sorry, I couldn't understand the numbers to multiply" + + try: + a = int(a) + b = int(b) + result = a * b + return f"The product of {a} and {b} is {result}" + except ValueError: + return "Please provide valid numbers for multiplication" @mcp.tool() -def divide(a: int, b: int) -> str: - """Divide two numbers. +def divide(input_str=None, a=None, b=None): + """Divide two numbers. Can accept JSON parameters or natural language. Args: - a (int): Numerator - b (int): Denominator + input_str (str, optional): Natural language input to parse + a (int, optional): Numerator if provided directly + b (int, optional): Denominator if provided directly Returns: str: A message containing the division result or an error message """ - logger.info(f"Dividing {a} by {b}") - if b == 0: - logger.warning("Division by zero attempted") - return "Cannot divide by zero" - result = a / b - return f"{a} divided by {b} is {result}" + logger.info(f"Divide tool called with input_str={input_str}, a={a}, b={b}") + + # If we got a natural language string instead of parameters + if input_str and not (a is not None and b is not None): + parsed = parse_input(input_str) + if parsed: + a = parsed.get("a") + b = parsed.get("b") + + # Validate we have what we need + if a is None or b is None: + return "Sorry, I couldn't understand the numbers to divide" + + try: + a = int(a) + b = int(b) + + if b == 0: + logger.warning("Division by zero attempted") + return "Cannot divide by zero" + + result = a / b + return f"{a} divided by {b} is {result}" + except ValueError: + return "Please provide valid numbers for division" if __name__ == "__main__": try: logger.info("Starting math server on http://0.0.0.0:8000") print("Math MCP Server is running. Press Ctrl+C to stop.") + print( + "Server is configured to accept both JSON and natural language input" + ) # Add a small delay to ensure logging is complete before the server starts time.sleep(0.5) diff --git a/swarms/prompts/agent_prompts.py b/swarms/prompts/agent_prompts.py index 5136b8e0..ff379134 100644 --- a/swarms/prompts/agent_prompts.py +++ b/swarms/prompts/agent_prompts.py @@ -1,24 +1,25 @@ # Agent prompts for MCP testing and interactions -# Keeping the original format that already has JSON formatting -MATH_AGENT_PROMPT = """You are a helpful math calculator assistant. +MATH_AGENT_PROMPT = """ +You are a helpful math calculator assistant. + Your role is to understand natural language math requests and perform calculations. When asked to perform calculations: + 1. Determine the operation (add, multiply, or divide) 2. Extract the numbers from the request -3. Use the appropriate math operation tool -Format your tool calls as JSON with the tool_name and parameters. +3. Call the appropriate operation -Example: -User: "what is 5 plus 3?" -You: Using the add operation for 5 and 3 -{"tool_name": "add", "a": 5, "b": 3} +You can use these tools: +- add: Add two numbers together +- multiply: Multiply two numbers together +- divide: Divide the first number by the second number -User: "multiply 4 times 6" -You: Using multiply for 4 and 6 -{"tool_name": "multiply", "a": 4, "b": 6} -""" +If the user asks for a list of available tools or functions, tell them about the above operations. +Just tell me which operation to perform and what numbers to use in natural language. +No need for strict JSON formatting - I'll handle the tool calling for you. +""" FINANCE_AGENT_PROMPT = """You are a financial analysis agent with access to stock market data services. Key responsibilities: 1. Interpret financial queries and determine required data diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index 44c9a95a..af50cca5 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -29,8 +29,7 @@ from swarms.agents.ape_agent import auto_generate_prompt from swarms.artifacts.main_artifact import Artifact from swarms.prompts.agent_system_prompts import AGENT_SYSTEM_PROMPT_3 from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( - MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, -) + MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, ) from swarms.prompts.tools import tool_sop_prompt from swarms.schemas.agent_step_schemas import ManySteps, Step from swarms.schemas.base_schemas import ( @@ -87,7 +86,6 @@ def exists(val): # Agent output types ToolUsageType = Union[BaseModel, Dict[str, Any]] - # Agent Exceptions @@ -322,8 +320,7 @@ class Agent: stopping_func: Optional[Callable] = None, custom_loop_condition: Optional[Callable] = None, sentiment_threshold: Optional[ - float - ] = None, # Evaluate on output using an external model + float] = None, # Evaluate on output using an external model custom_exit_command: Optional[str] = "exit", sentiment_analyzer: Optional[Callable] = None, limit_tokens_from_string: Optional[Callable] = None, @@ -362,9 +359,8 @@ class Agent: use_cases: Optional[List[Dict[str, str]]] = None, step_pool: List[Step] = [], print_every_step: Optional[bool] = False, - time_created: Optional[str] = time.strftime( - "%Y-%m-%d %H:%M:%S", time.localtime() - ), + time_created: Optional[str] = time.strftime("%Y-%m-%d %H:%M:%S", + time.localtime()), agent_output: ManySteps = None, executor_workers: int = os.cpu_count(), data_memory: Optional[Callable] = None, @@ -451,9 +447,7 @@ class Agent: self.output_type = output_type self.function_calling_type = function_calling_type self.output_cleaner = output_cleaner - self.function_calling_format_type = ( - function_calling_format_type - ) + self.function_calling_format_type = (function_calling_format_type) self.list_base_models = list_base_models self.metadata_output_type = metadata_output_type self.state_save_file_type = state_save_file_type @@ -507,7 +501,8 @@ class Agent: self.role = role self.no_print = no_print self.tools_list_dictionary = tools_list_dictionary - self.mcp_servers = mcp_servers or [] # Initialize mcp_servers to an empty list if None + self.mcp_servers = mcp_servers or [ + ] # Initialize mcp_servers to an empty list if None self._cached_llm = ( None # Add this line to cache the LLM instance @@ -516,10 +511,7 @@ class Agent: "gpt-4o-mini" # Move default model name here ) - if ( - self.agent_name is not None - or self.agent_description is not None - ): + if (self.agent_name is not None or self.agent_description is not None): prompt = f"Your Name: {self.agent_name} \n\n Your Description: {self.agent_description} \n\n {system_prompt}" else: prompt = system_prompt @@ -539,9 +531,7 @@ class Agent: self.feedback = [] # Initialize the executor - self.executor = ThreadPoolExecutor( - max_workers=executor_workers - ) + self.executor = ThreadPoolExecutor(max_workers=executor_workers) self.init_handling() @@ -557,8 +547,7 @@ class Agent: (self.handle_tool_init, True), # Always run tool init ( self.handle_tool_schema_ops, - exists(self.tool_schema) - or exists(self.list_base_models), + exists(self.tool_schema) or exists(self.list_base_models), ), ( self.handle_sop_ops, @@ -567,14 +556,11 @@ class Agent: ] # Filter out tasks whose conditions are False - filtered_tasks = [ - task for task, condition in tasks if condition - ] + filtered_tasks = [task for task, condition in tasks if condition] # Execute all tasks concurrently - with concurrent.futures.ThreadPoolExecutor( - max_workers=os.cpu_count() * 4 - ) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count() * + 4) as executor: # Map tasks to futures and collect results results = {} future_to_task = { @@ -583,21 +569,15 @@ class Agent: } # Wait for each future to complete and collect results/exceptions - for future in concurrent.futures.as_completed( - future_to_task - ): + for future in concurrent.futures.as_completed(future_to_task): task_name = future_to_task[future] try: result = future.result() results[task_name] = result - logging.info( - f"Task {task_name} completed successfully" - ) + logging.info(f"Task {task_name} completed successfully") except Exception as e: results[task_name] = None - logging.error( - f"Task {task_name} failed with error: {e}" - ) + logging.error(f"Task {task_name} failed with error: {e}") # Run sequential operations after all concurrent tasks are done self.agent_output = self.agent_output_model() @@ -618,9 +598,7 @@ class Agent: max_loops=self.max_loops, steps=self.short_memory.to_dict(), full_history=self.short_memory.get_str(), - total_tokens=count_tokens( - text=self.short_memory.get_str() - ), + total_tokens=count_tokens(text=self.short_memory.get_str()), stopping_token=self.stopping_token, interactive=self.interactive, dynamic_temperature_enabled=self.dynamic_temperature_enabled, @@ -647,23 +625,17 @@ class Agent: } if self.llm_args is not None: - self._cached_llm = LiteLLM( - **{**common_args, **self.llm_args} - ) + self._cached_llm = LiteLLM(**{**common_args, **self.llm_args}) elif self.tools_list_dictionary is not None: self._cached_llm = LiteLLM( **common_args, tools_list_dictionary=self.tools_list_dictionary, tool_choice="auto", - parallel_tool_calls=len( - self.tools_list_dictionary - ) - > 1, + parallel_tool_calls=len(self.tools_list_dictionary) > 1, ) else: - self._cached_llm = LiteLLM( - **common_args, stream=self.streaming_on - ) + self._cached_llm = LiteLLM(**common_args, + stream=self.streaming_on) return self._cached_llm except AgentLLMInitializationError as e: @@ -674,12 +646,8 @@ class Agent: def handle_tool_init(self): # Initialize the tool struct - if ( - exists(self.tools) - or exists(self.list_base_models) - or exists(self.tool_schema) - or exists(self.mcp_servers) - ): + if (exists(self.tools) or exists(self.list_base_models) + or exists(self.tool_schema) or exists(self.mcp_servers)): self.tool_struct = BaseTool( tools=self.tools, @@ -692,28 +660,21 @@ class Agent: "Tools provided make sure the functions have documentation ++ type hints, otherwise tool execution won't be reliable." ) # Add the tool prompt to the memory - self.short_memory.add( - role="system", content=self.tool_system_prompt - ) + self.short_memory.add(role="system", + content=self.tool_system_prompt) # Log the tools - logger.info( - f"Tools provided: Accessing {len(self.tools)} tools" - ) + logger.info(f"Tools provided: Accessing {len(self.tools)} tools") # Transform the tools into an openai schema # self.convert_tool_into_openai_schema() # Transform the tools into an openai schema - tool_dict = ( - self.tool_struct.convert_tool_into_openai_schema() - ) + tool_dict = (self.tool_struct.convert_tool_into_openai_schema()) self.short_memory.add(role="system", content=tool_dict) # Now create a function calling map for every tools - self.function_map = { - tool.__name__: tool for tool in self.tools - } + self.function_map = {tool.__name__: tool for tool in self.tools} def setup_config(self): # The max_loops will be set dynamically if the dynamic_loop @@ -760,21 +721,16 @@ class Agent: logger.warning( "No agent details found. Usingtask as fallback for promptgeneration." ) - self.system_prompt = auto_generate_prompt( - task, self.llm - ) + self.system_prompt = auto_generate_prompt(task, self.llm) else: # Combine all available components combined_prompt = " ".join(components) logger.info( - f"Auto-generating prompt from: {', '.join(components)}" - ) + f"Auto-generating prompt from: {', '.join(components)}") self.system_prompt = auto_generate_prompt( - combined_prompt, self.llm - ) - self.short_memory.add( - role="system", content=self.system_prompt - ) + combined_prompt, self.llm) + self.short_memory.add(role="system", + content=self.system_prompt) logger.info("Auto-generated prompt successfully.") @@ -789,13 +745,9 @@ class Agent: def agent_initialization(self): try: - logger.info( - f"Initializing Autonomous Agent {self.agent_name}..." - ) + logger.info(f"Initializing Autonomous Agent {self.agent_name}...") self.check_parameters() - logger.info( - f"{self.agent_name} Initialized Successfully." - ) + logger.info(f"{self.agent_name} Initialized Successfully.") logger.info( f"Autonomous Agent {self.agent_name} Activated, all systems operational. Executing task..." ) @@ -814,9 +766,7 @@ class Agent: return self.stopping_condition(response) return False except Exception as error: - logger.error( - f"Error checking stopping condition: {error}" - ) + logger.error(f"Error checking stopping condition: {error}") def dynamic_temperature(self): """ @@ -837,11 +787,7 @@ class Agent: def print_dashboard(self): """Print dashboard""" - formatter.print_panel( - f"Initializing Agent: {self.agent_name}" - ) - - ) + formatter.print_panel(f"Initializing Agent: {self.agent_name}") data = self.to_dict() @@ -861,8 +807,7 @@ class Agent: Configuration: {data} ---------------------------------------- - """, - ) + """, ) # Check parameters def check_parameters(self): @@ -914,21 +859,14 @@ class Agent: try: # 1. Batch process initial setup setup_tasks = [ - lambda: self.check_if_no_prompt_then_autogenerate( - task - ), - lambda: self.short_memory.add( - role=self.user_name, content=task - ), - lambda: ( - self.plan(task) if self.plan_enabled else None - ), + lambda: self.check_if_no_prompt_then_autogenerate(task), + lambda: self.short_memory.add(role=self.user_name, + content=task), + lambda: (self.plan(task) if self.plan_enabled else None), ] # Execute setup tasks concurrently - with ThreadPoolExecutor( - max_workers=len(setup_tasks) - ) as executor: + with ThreadPoolExecutor(max_workers=len(setup_tasks)) as executor: executor.map(lambda f: f(), setup_tasks) # Set the loop count @@ -953,10 +891,7 @@ class Agent: f"Task Request for {self.agent_name}", ) - while ( - self.max_loops == "auto" - or loop_count < self.max_loops - ): + while (self.max_loops == "auto" or loop_count < self.max_loops): loop_count += 1 # self.short_memory.add( @@ -969,35 +904,25 @@ class Agent: self.dynamic_temperature() # Task prompt - task_prompt = ( - self.short_memory.return_history_as_string() - ) + task_prompt = (self.short_memory.return_history_as_string()) # Parameters attempt = 0 success = False while attempt < self.retry_attempts and not success: try: - if ( - self.long_term_memory is not None - and self.rag_every_loop is True - ): - logger.info( - "Querying RAG database for context..." - ) + if (self.long_term_memory is not None + and self.rag_every_loop is True): + logger.info("Querying RAG database for context...") self.memory_query(task_prompt) # Generate response using LLM - response_args = ( - (task_prompt, *args) - if img is None - else (task_prompt, img, *args) - ) + response_args = ((task_prompt, + *args) if img is None else + (task_prompt, img, *args)) # Call the LLM - response = self.call_llm( - *response_args, **kwargs - ) + response = self.call_llm(*response_args, **kwargs) # Convert to a str if the response is not a str response = self.parse_llm_output(response) @@ -1014,30 +939,27 @@ class Agent: # 9. Batch memory updates and prints update_tasks = [ - lambda: self.short_memory.add( - role=self.agent_name, content=response - ), - lambda: self.pretty_print( - response, loop_count - ), + lambda: self.short_memory.add(role=self.agent_name, + content=response), + lambda: self.pretty_print(response, loop_count), lambda: self.output_cleaner_op(response), ] with ThreadPoolExecutor( - max_workers=len(update_tasks) - ) as executor: + max_workers=len(update_tasks)) as executor: executor.map(lambda f: f(), update_tasks) # Check and execute tools (including MCP) - if self.tools is not None or hasattr(self, 'mcp_servers'): + if self.tools is not None or hasattr( + self, 'mcp_servers'): if self.tools: out = self.parse_and_execute_tools(response) - if hasattr(self, 'mcp_servers') and self.mcp_servers: + if hasattr(self, + 'mcp_servers') and self.mcp_servers: out = self.mcp_execution_flow(response) - self.short_memory.add( - role="Tool Executor", content=out - ) + self.short_memory.add(role="Tool Executor", + content=out) agent_print( f"{self.agent_name} - Tool Executor", @@ -1055,9 +977,8 @@ class Agent: self.streaming_on, ) - self.short_memory.add( - role=self.agent_name, content=out - ) + self.short_memory.add(role=self.agent_name, + content=out) self.sentiment_and_evaluator(response) @@ -1070,10 +991,8 @@ class Agent: if self.autosave is True: self.save() - logger.error( - f"Attempt {attempt+1}: Error generating" - f" response: {e}" - ) + logger.error(f"Attempt {attempt+1}: Error generating" + f" response: {e}") attempt += 1 if not success: @@ -1083,23 +1002,17 @@ class Agent: if self.autosave is True: self.save() - logger.error( - "Failed to generate a valid response after" - " retry attempts." - ) + logger.error("Failed to generate a valid response after" + " retry attempts.") break # Exit the loop if all retry attempts fail # Check stopping conditions - if ( - self.stopping_condition is not None - and self._check_stopping_condition(response) - ): + if (self.stopping_condition is not None + and self._check_stopping_condition(response)): logger.info("Stopping condition met.") break - elif ( - self.stopping_func is not None - and self.stopping_func(response) - ): + elif (self.stopping_func is not None + and self.stopping_func(response)): logger.info("Stopping function met.") break @@ -1108,21 +1021,15 @@ class Agent: user_input = input("You: ") # User-defined exit command - if ( - user_input.lower() - == self.custom_exit_command.lower() - ): + if (user_input.lower() == self.custom_exit_command.lower() + ): print("Exiting as per user request.") break - self.short_memory.add( - role="User", content=user_input - ) + self.short_memory.add(role="User", content=user_input) if self.loop_interval: - logger.info( - f"Sleeping for {self.loop_interval} seconds" - ) + logger.info(f"Sleeping for {self.loop_interval} seconds") time.sleep(self.loop_interval) if self.autosave is True: @@ -1141,14 +1048,11 @@ class Agent: lambda: self.save() if self.autosave else None, ] - with ThreadPoolExecutor( - max_workers=len(final_tasks) - ) as executor: + with ThreadPoolExecutor(max_workers=len(final_tasks)) as executor: executor.map(lambda f: f(), final_tasks) - return history_output_formatter( - self.short_memory, type=self.output_type - ) + return history_output_formatter(self.short_memory, + type=self.output_type) except Exception as error: self._handle_run_error(error) @@ -1170,7 +1074,7 @@ class Agent: def _handle_run_error(self, error: any): process_thread = threading.Thread( target=self.__handle_run_error, - args=(error,), + args=(error, ), daemon=True, ) process_thread.start() @@ -1219,8 +1123,7 @@ class Agent: ) except Exception as error: await self._handle_run_error( - error - ) # Ensure this is also async if needed + error) # Ensure this is also async if needed def __call__( self, @@ -1255,12 +1158,8 @@ class Agent: except Exception as error: self._handle_run_error(error) - def receive_message( - self, agent_name: str, task: str, *args, **kwargs - ): - return self.run( - task=f"From {agent_name}: {task}", *args, **kwargs - ) + def receive_message(self, agent_name: str, task: str, *args, **kwargs): + return self.run(task=f"From {agent_name}: {task}", *args, **kwargs) def dict_to_csv(self, data: dict) -> str: """ @@ -1311,8 +1210,7 @@ class Agent: except Exception as error: retries += 1 logger.error( - f"Attempt {retries}: Error executing tool: {error}" - ) + f"Attempt {retries}: Error executing tool: {error}") if retries == max_retries: raise error time.sleep(1) # Wait for a bit before retrying @@ -1328,9 +1226,7 @@ class Agent: """ logger.info(f"Adding memory: {message}") - return self.short_memory.add( - role=self.agent_name, content=message - ) + return self.short_memory.add(role=self.agent_name, content=message) def plan(self, task: str, *args, **kwargs) -> None: """ @@ -1347,9 +1243,7 @@ class Agent: logger.info(f"Plan: {plan}") # Add the plan to the memory - self.short_memory.add( - role=self.agent_name, content=str(plan) - ) + self.short_memory.add(role=self.agent_name, content=str(plan)) return None except Exception as error: @@ -1365,16 +1259,13 @@ class Agent: """ try: logger.info(f"Running concurrent task: {task}") - future = self.executor.submit( - self.run, task, *args, **kwargs - ) + future = self.executor.submit(self.run, task, *args, **kwargs) result = await asyncio.wrap_future(future) logger.info(f"Completed task: {result}") return result except Exception as error: logger.error( - f"Error running agent: {error} while running concurrently" - ) + f"Error running agent: {error} while running concurrently") def run_concurrent_tasks(self, tasks: List[str], *args, **kwargs): """ @@ -1386,9 +1277,7 @@ class Agent: try: logger.info(f"Running concurrent tasks: {tasks}") futures = [ - self.executor.submit( - self.run, task=task, *args, **kwargs - ) + self.executor.submit(self.run, task=task, *args, **kwargs) for task in tasks ] results = [future.result() for future in futures] @@ -1426,8 +1315,7 @@ class Agent: try: # Create a list of coroutines for each task coroutines = [ - self.arun(task=task, *args, **kwargs) - for task in tasks + self.arun(task=task, *args, **kwargs) for task in tasks ] # Use asyncio.gather to run them concurrently results = await asyncio.gather(*coroutines) @@ -1451,20 +1339,15 @@ class Agent: """ try: # Determine the save path - resolved_path = ( - file_path - or self.saved_state_path - or f"{self.agent_name}_state.json" - ) + resolved_path = (file_path or self.saved_state_path + or f"{self.agent_name}_state.json") # Ensure path has .json extension if not resolved_path.endswith(".json"): resolved_path += ".json" # Create full path including workspace directory - full_path = os.path.join( - self.workspace_dir, resolved_path - ) + full_path = os.path.join(self.workspace_dir, resolved_path) backup_path = full_path + ".backup" temp_path = full_path + ".temp" @@ -1489,25 +1372,19 @@ class Agent: try: os.remove(backup_path) except Exception as e: - logger.warning( - f"Could not remove backup file: {e}" - ) + logger.warning(f"Could not remove backup file: {e}") # Log saved state information if verbose if self.verbose: self._log_saved_state_info(full_path) - logger.info( - f"Successfully saved agent state to: {full_path}" - ) + logger.info(f"Successfully saved agent state to: {full_path}") # Handle additional component saves self._save_additional_components(full_path) except OSError as e: - logger.error( - f"Filesystem error while saving agent state: {e}" - ) + logger.error(f"Filesystem error while saving agent state: {e}") raise except Exception as e: logger.error(f"Unexpected error saving agent state: {e}") @@ -1517,40 +1394,25 @@ class Agent: """Save additional agent components like memory.""" try: # Save long term memory if it exists - if ( - hasattr(self, "long_term_memory") - and self.long_term_memory is not None - ): - memory_path = ( - f"{os.path.splitext(base_path)[0]}_memory.json" - ) + if (hasattr(self, "long_term_memory") + and self.long_term_memory is not None): + memory_path = (f"{os.path.splitext(base_path)[0]}_memory.json") try: self.long_term_memory.save(memory_path) - logger.info( - f"Saved long-term memory to: {memory_path}" - ) + logger.info(f"Saved long-term memory to: {memory_path}") except Exception as e: - logger.warning( - f"Could not save long-term memory: {e}" - ) + logger.warning(f"Could not save long-term memory: {e}") # Save memory manager if it exists - if ( - hasattr(self, "memory_manager") - and self.memory_manager is not None - ): + if (hasattr(self, "memory_manager") + and self.memory_manager is not None): manager_path = f"{os.path.splitext(base_path)[0]}_memory_manager.json" try: - self.memory_manager.save_memory_snapshot( - manager_path - ) + self.memory_manager.save_memory_snapshot(manager_path) logger.info( - f"Saved memory manager state to: {manager_path}" - ) + f"Saved memory manager state to: {manager_path}") except Exception as e: - logger.warning( - f"Could not save memory manager: {e}" - ) + logger.warning(f"Could not save memory manager: {e}") except Exception as e: logger.warning(f"Error saving additional components: {e}") @@ -1569,8 +1431,7 @@ class Agent: self.save() if self.verbose: logger.debug( - f"Autosaved agent state (interval: {interval}s)" - ) + f"Autosaved agent state (interval: {interval}s)") except Exception as e: logger.error(f"Autosave failed: {e}") time.sleep(interval) @@ -1597,9 +1458,7 @@ class Agent: """Cleanup method to be called on exit. Ensures final state is saved.""" try: if getattr(self, "autosave", False): - logger.info( - "Performing final autosave before exit..." - ) + logger.info("Performing final autosave before exit...") self.disable_autosave() self.save() except Exception as e: @@ -1621,22 +1480,11 @@ class Agent: try: # Resolve load path conditionally with a check for self.load_state_path resolved_path = ( - file_path - or self.load_state_path - or ( - f"{self.saved_state_path}.json" - if self.saved_state_path - else ( - f"{self.agent_name}.json" - if self.agent_name - else ( - f"{self.workspace_dir}/{self.agent_name}_state.json" - if self.workspace_dir and self.agent_name - else None - ) - ) - ) - ) + file_path or self.load_state_path or + (f"{self.saved_state_path}.json" if self.saved_state_path else + (f"{self.agent_name}.json" if self.agent_name else + (f"{self.workspace_dir}/{self.agent_name}_state.json" + if self.workspace_dir and self.agent_name else None)))) # Load state using SafeStateManager SafeStateManager.load_state(self, resolved_path) @@ -1661,10 +1509,8 @@ class Agent: """ try: # Reinitialize conversation if needed - if ( - not hasattr(self, "short_memory") - or self.short_memory is None - ): + if (not hasattr(self, "short_memory") + or self.short_memory is None): self.short_memory = Conversation( system_prompt=self.system_prompt, time_enabled=False, @@ -1674,9 +1520,7 @@ class Agent: # Reinitialize executor if needed if not hasattr(self, "executor") or self.executor is None: - self.executor = ThreadPoolExecutor( - max_workers=os.cpu_count() - ) + self.executor = ThreadPoolExecutor(max_workers=os.cpu_count()) # # Reinitialize tool structure if needed # if hasattr(self, 'tools') and (self.tools or getattr(self, 'list_base_models', None)): @@ -1697,19 +1541,13 @@ class Agent: preserved = SafeLoaderUtils.preserve_instances(self) logger.info(f"Saved agent state to: {file_path}") - logger.debug( - f"Saved {len(state_dict)} configuration values" - ) - logger.debug( - f"Preserved {len(preserved)} class instances" - ) + logger.debug(f"Saved {len(state_dict)} configuration values") + logger.debug(f"Preserved {len(preserved)} class instances") if self.verbose: logger.debug("Preserved instances:") for name, instance in preserved.items(): - logger.debug( - f" - {name}: {type(instance).__name__}" - ) + logger.debug(f" - {name}: {type(instance).__name__}") except Exception as e: logger.error(f"Error logging state info: {e}") @@ -1720,19 +1558,13 @@ class Agent: preserved = SafeLoaderUtils.preserve_instances(self) logger.info(f"Loaded agent state from: {file_path}") - logger.debug( - f"Loaded {len(state_dict)} configuration values" - ) - logger.debug( - f"Preserved {len(preserved)} class instances" - ) + logger.debug(f"Loaded {len(state_dict)} configuration values") + logger.debug(f"Preserved {len(preserved)} class instances") if self.verbose: logger.debug("Current class instances:") for name, instance in preserved.items(): - logger.debug( - f" - {name}: {type(instance).__name__}" - ) + logger.debug(f" - {name}: {type(instance).__name__}") except Exception as e: logger.error(f"Error logging state info: {e}") @@ -1811,9 +1643,7 @@ class Agent: Returns: str: The filtered response """ - logger.info( - f"Applying response filters to response: {response}" - ) + logger.info(f"Applying response filters to response: {response}") for word in self.response_filters: response = response.replace(word, "[FILTERED]") return response @@ -1884,9 +1714,7 @@ class Agent: for doc in docs: data = data_to_text(doc) - return self.short_memory.add( - role=self.user_name, content=data - ) + return self.short_memory.add(role=self.user_name, content=data) except Exception as error: logger.info(f"Error ingesting docs: {error}", "red") @@ -1899,9 +1727,7 @@ class Agent: try: logger.info(f"Ingesting pdf: {pdf}") text = pdf_to_text(pdf) - return self.short_memory.add( - role=self.user_name, content=text - ) + return self.short_memory.add(role=self.user_name, content=text) except Exception as error: logger.info(f"Error ingesting pdf: {error}", "red") @@ -1914,9 +1740,8 @@ class Agent: logger.info(f"Error receiving message: {error}") raise error - def send_agent_message( - self, agent_name: str, message: str, *args, **kwargs - ): + def send_agent_message(self, agent_name: str, message: str, *args, + **kwargs): """Send a message to the agent""" try: logger.info(f"Sending agent message: {message}") @@ -1987,13 +1812,9 @@ class Agent: all_text += f"\nContent from {file}:\n{text}\n" # Add the combined content to memory - return self.short_memory.add( - role=self.user_name, content=all_text - ) + return self.short_memory.add(role=self.user_name, content=all_text) except Exception as error: - logger.error( - f"Error getting docs from doc folders: {error}" - ) + logger.error(f"Error getting docs from doc folders: {error}") raise error def memory_query(self, task: str = None, *args, **kwargs) -> None: @@ -2003,12 +1824,10 @@ class Agent: formatter.print_panel(f"Querying RAG for: {task}") memory_retrieval = self.long_term_memory.query( - task, *args, **kwargs - ) + task, *args, **kwargs) memory_retrieval = ( - f"Documents Available: {str(memory_retrieval)}" - ) + f"Documents Available: {str(memory_retrieval)}") # # Count the tokens # memory_token_count = count_tokens( @@ -2047,17 +1866,13 @@ class Agent: print(f"Sentiment: {sentiment}") if sentiment > self.sentiment_threshold: - print( - f"Sentiment: {sentiment} is above" - " threshold:" - f" {self.sentiment_threshold}" - ) + print(f"Sentiment: {sentiment} is above" + " threshold:" + f" {self.sentiment_threshold}") elif sentiment < self.sentiment_threshold: - print( - f"Sentiment: {sentiment} is below" - " threshold:" - f" {self.sentiment_threshold}" - ) + print(f"Sentiment: {sentiment} is below" + " threshold:" + f" {self.sentiment_threshold}") self.short_memory.add( role=self.agent_name, @@ -2066,9 +1881,7 @@ class Agent: except Exception as e: print(f"Error occurred during sentiment analysis: {e}") - def stream_response( - self, response: str, delay: float = 0.001 - ) -> None: + def stream_response(self, response: str, delay: float = 0.001) -> None: """ Streams the response token by token. @@ -2101,19 +1914,16 @@ class Agent: # Log the amount of tokens left in the memory and in the task if self.tokenizer is not None: tokens_used = count_tokens( - self.short_memory.return_history_as_string() - ) + self.short_memory.return_history_as_string()) logger.info( - f"Tokens available: {self.context_length - tokens_used}" - ) + f"Tokens available: {self.context_length - tokens_used}") return tokens_used def tokens_checks(self): # Check the tokens available tokens_used = count_tokens( - self.short_memory.return_history_as_string() - ) + self.short_memory.return_history_as_string()) out = self.check_available_tokens() logger.info( @@ -2122,9 +1932,7 @@ class Agent: return out - def log_step_metadata( - self, loop: int, task: str, response: str - ) -> Step: + def log_step_metadata(self, loop: int, task: str, response: str) -> Step: """Log metadata for each step of agent execution.""" # Generate unique step ID step_id = f"step_{loop}_{uuid.uuid4().hex}" @@ -2134,7 +1942,7 @@ class Agent: # prompt_tokens = count_tokens(full_memory) # completion_tokens = count_tokens(response) # total_tokens = prompt_tokens + completion_tokens - total_tokens = (count_tokens(task) + count_tokens(response),) + total_tokens = (count_tokens(task) + count_tokens(response), ) # # Get memory responses # memory_responses = { @@ -2233,18 +2041,14 @@ class Agent: """Update tool usage information for a specific step.""" for step in self.agent_output.steps: if step.step_id == step_id: - step.response.tool_calls.append( - { - "tool": tool_name, - "arguments": tool_args, - "response": str(tool_response), - } - ) + step.response.tool_calls.append({ + "tool": tool_name, + "arguments": tool_args, + "response": str(tool_response), + }) break - def _serialize_callable( - self, attr_value: Callable - ) -> Dict[str, Any]: + def _serialize_callable(self, attr_value: Callable) -> Dict[str, Any]: """ Serializes callable attributes by extracting their name and docstring. @@ -2255,9 +2059,8 @@ class Agent: Dict[str, Any]: Dictionary with name and docstring of the callable. """ return { - "name": getattr( - attr_value, "__name__", type(attr_value).__name__ - ), + "name": getattr(attr_value, "__name__", + type(attr_value).__name__), "doc": getattr(attr_value, "__doc__", None), } @@ -2276,9 +2079,8 @@ class Agent: if callable(attr_value): return self._serialize_callable(attr_value) elif hasattr(attr_value, "to_dict"): - return ( - attr_value.to_dict() - ) # Recursive serialization for nested objects + return (attr_value.to_dict() + ) # Recursive serialization for nested objects else: json.dumps( attr_value @@ -2301,14 +2103,10 @@ class Agent: } def to_json(self, indent: int = 4, *args, **kwargs): - return json.dumps( - self.to_dict(), indent=indent, *args, **kwargs - ) + return json.dumps(self.to_dict(), indent=indent, *args, **kwargs) def to_yaml(self, indent: int = 4, *args, **kwargs): - return yaml.dump( - self.to_dict(), indent=indent, *args, **kwargs - ) + return yaml.dump(self.to_dict(), indent=indent, *args, **kwargs) def to_toml(self, *args, **kwargs): return toml.dumps(self.to_dict(), *args, **kwargs) @@ -2343,14 +2141,11 @@ class Agent: if exists(self.tool_schema): logger.info(f"Tool schema provided: {self.tool_schema}") - output = self.tool_struct.base_model_to_dict( - self.tool_schema, output_str=True - ) + output = self.tool_struct.base_model_to_dict(self.tool_schema, + output_str=True) # Add the tool schema to the short memory - self.short_memory.add( - role=self.agent_name, content=output - ) + self.short_memory.add(role=self.agent_name, content=output) # If multiple base models, then conver them. if exists(self.list_base_models): @@ -2359,13 +2154,10 @@ class Agent: ) schemas = self.tool_struct.multi_base_models_to_dict( - output_str=True - ) + output_str=True) # If the output is a string then add it to the memory - self.short_memory.add( - role=self.agent_name, content=schemas - ) + self.short_memory.add(role=self.agent_name, content=schemas) return None @@ -2411,14 +2203,10 @@ class Agent: # If the user inputs a list of strings for the sop then join them and set the sop if exists(self.sop_list): self.sop = "\n".join(self.sop_list) - self.short_memory.add( - role=self.user_name, content=self.sop - ) + self.short_memory.add(role=self.user_name, content=self.sop) if exists(self.sop): - self.short_memory.add( - role=self.user_name, content=self.sop - ) + self.short_memory.add(role=self.user_name, content=self.sop) logger.info("SOP Uploaded into the memory") @@ -2466,9 +2254,7 @@ class Agent: if scheduled_run_date: while datetime.now() < scheduled_run_date: - time.sleep( - 1 - ) # Sleep for a short period to avoid busy waiting + time.sleep(1) # Sleep for a short period to avoid busy waiting try: # If cluster ops disabled, run directly @@ -2489,9 +2275,8 @@ class Agent: except ValueError as e: self._handle_run_error(e) - def handle_artifacts( - self, text: str, file_output_path: str, file_extension: str - ) -> None: + def handle_artifacts(self, text: str, file_output_path: str, + file_extension: str) -> None: """Handle creating and saving artifacts with error handling.""" try: # Ensure file_extension starts with a dot @@ -2518,26 +2303,18 @@ class Agent: edit_count=0, ) - logger.info( - f"Saving artifact with extension: {file_extension}" - ) + logger.info(f"Saving artifact with extension: {file_extension}") artifact.save_as(file_extension) - logger.success( - f"Successfully saved artifact to {full_path}" - ) + logger.success(f"Successfully saved artifact to {full_path}") except ValueError as e: - logger.error( - f"Invalid input values for artifact: {str(e)}" - ) + logger.error(f"Invalid input values for artifact: {str(e)}") raise except IOError as e: logger.error(f"Error saving artifact to file: {str(e)}") raise except Exception as e: - logger.error( - f"Unexpected error handling artifact: {str(e)}" - ) + logger.error(f"Unexpected error handling artifact: {str(e)}") raise def showcase_config(self): @@ -2547,32 +2324,29 @@ class Agent: for key, value in config_dict.items(): if isinstance(value, list): # Format list as a comma-separated string - config_dict[key] = ", ".join( - str(item) for item in value - ) + config_dict[key] = ", ".join(str(item) for item in value) elif isinstance(value, dict): # Format dict as key-value pairs in a single string - config_dict[key] = ", ".join( - f"{k}: {v}" for k, v in value.items() - ) + config_dict[key] = ", ".join(f"{k}: {v}" + for k, v in value.items()) else: # Ensure any non-iterable value is a string config_dict[key] = str(value) - return formatter.print_table( - f"Agent: {self.agent_name} Configuration", config_dict - ) + return formatter.print_table(f"Agent: {self.agent_name} Configuration", + config_dict) - def talk_to( - self, agent: Any, task: str, img: str = None, *args, **kwargs - ) -> Any: + def talk_to(self, + agent: Any, + task: str, + img: str = None, + *args, + **kwargs) -> Any: """ Talk to another agent. """ # return agent.run(f"{agent.agent_name}: {task}", img, *args, **kwargs) - output = self.run( - f"{self.agent_name}: {task}", img, *args, **kwargs - ) + output = self.run(f"{self.agent_name}: {task}", img, *args, **kwargs) return agent.run( task=f"From {self.agent_name}: Message: {output}", @@ -2595,9 +2369,7 @@ class Agent: with ThreadPoolExecutor() as executor: # Create futures for each agent conversation futures = [ - executor.submit( - self.talk_to, agent, task, *args, **kwargs - ) + executor.submit(self.talk_to, agent, task, *args, **kwargs) for agent in agents ] @@ -2609,9 +2381,7 @@ class Agent: outputs.append(result) except Exception as e: logger.error(f"Error in agent communication: {e}") - outputs.append( - None - ) # or handle error case as needed + outputs.append(None) # or handle error case as needed return outputs @@ -2627,7 +2397,8 @@ class Agent: # self.stream_response(response) formatter.print_panel_token_by_token( f"{self.agent_name}: {response}", - title=f"Agent Name: {self.agent_name} [Max Loops: {loop_count}]", + title= + f"Agent Name: {self.agent_name} [Max Loops: {loop_count}]", ) else: # logger.info(f"Response: {response}") @@ -2647,14 +2418,13 @@ class Agent: else: return str(response) - - def sentiment_and_evaluator(self, response: str): if self.evaluator: logger.info("Evaluating response...") evaluated_response = self.evaluator(response) - print("Evaluated Response:" f" {evaluated_response}") + print("Evaluated Response:" + f" {evaluated_response}") self.short_memory.add( role="Evaluator", content=evaluated_response, @@ -2693,12 +2463,15 @@ class Agent: try: tool_calls = json.loads(response) is_json = True - logger.debug(f"Successfully parsed response as JSON: {tool_calls}") + logger.debug( + f"Successfully parsed response as JSON: {tool_calls}") except json.JSONDecodeError: # If not JSON, treat as natural language tool_calls = [response] is_json = False - logger.debug(f"Could not parse response as JSON, treating as natural language") + logger.debug( + f"Could not parse response as JSON, treating as natural language" + ) # Execute tool calls against MCP servers results = [] @@ -2708,7 +2481,9 @@ class Agent: if isinstance(tool_calls, dict): tool_calls = [tool_calls] - logger.debug(f"Executing {len(tool_calls)} tool calls against {len(self.mcp_servers)} MCP servers") + logger.debug( + f"Executing {len(tool_calls)} tool calls against {len(self.mcp_servers)} MCP servers" + ) for tool_call in tool_calls: try: @@ -2725,25 +2500,18 @@ class Agent: # Add successful result to memory with context self.short_memory.add( role="assistant", - content=f"Tool execution result: {result}" - ) + content=f"Tool execution result: {result}") else: error_msg = "No result from tool execution" errors.append(error_msg) logger.debug(error_msg) - self.short_memory.add( - role="error", - content=error_msg - ) + self.short_memory.add(role="error", content=error_msg) except Exception as e: error_msg = f"Error executing tool call: {str(e)}" errors.append(error_msg) logger.error(error_msg) - self.short_memory.add( - role="error", - content=error_msg - ) + self.short_memory.add(role="error", content=error_msg) # Format the final response if results: @@ -2760,20 +2528,17 @@ class Agent: if len(errors) == 1: return errors[0] else: - return "Multiple errors occurred:\n" + "\n".join(f"- {err}" for err in errors) + return "Multiple errors occurred:\n" + "\n".join( + f"- {err}" for err in errors) else: return "No results or errors returned" except Exception as e: error_msg = f"Error in MCP execution flow: {str(e)}" logger.error(error_msg) - self.short_memory.add( - role="error", - content=error_msg - ) + self.short_memory.add(role="error", content=error_msg) return error_msg - def mcp_execution_flow(self, response: str) -> str: """Synchronous wrapper for MCP execution flow. @@ -2797,15 +2562,17 @@ class Agent: if loop.is_running(): # We're in an async context, use run_coroutine_threadsafe - logger.debug("Using run_coroutine_threadsafe to execute MCP flow") + logger.debug( + "Using run_coroutine_threadsafe to execute MCP flow") future = asyncio.run_coroutine_threadsafe( - self.amcp_execution_flow(response), loop - ) - return future.result(timeout=30) # Adding timeout to prevent hanging + self.amcp_execution_flow(response), loop) + return future.result( + timeout=30) # Adding timeout to prevent hanging else: # We're not in an async context, use loop.run_until_complete logger.debug("Using run_until_complete to execute MCP flow") - return loop.run_until_complete(self.amcp_execution_flow(response)) + return loop.run_until_complete( + self.amcp_execution_flow(response)) except Exception as e: error_msg = f"Error in MCP execution flow wrapper: {str(e)}" diff --git a/swarms/tools/mcp_integration.py b/swarms/tools/mcp_integration.py index 0959f5f2..6dd25e91 100644 --- a/swarms/tools/mcp_integration.py +++ b/swarms/tools/mcp_integration.py @@ -1,320 +1,311 @@ - from __future__ import annotations - - import abc - import asyncio - from contextlib import AbstractAsyncContextManager, AsyncExitStack - from pathlib import Path - from typing import Any, Dict, List, Optional, Literal, Union - from typing_extensions import NotRequired, TypedDict - - from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - from loguru import logger - from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client - from mcp.client.sse import sse_client - from mcp.types import CallToolResult, JSONRPCMessage - - from swarms.utils.any_to_str import any_to_str - - - class MCPServer(abc.ABC): - """Base class for Model Context Protocol servers.""" - - @abc.abstractmethod - async def connect(self) -> None: - """Establish connection to the MCP server.""" - pass - - @property - @abc.abstractmethod - def name(self) -> str: - """Human-readable server name.""" - pass - - @abc.abstractmethod - async def cleanup(self) -> None: - """Clean up resources and close connection.""" - pass - - @abc.abstractmethod - async def list_tools(self) -> List[MCPTool]: - """List available MCP tools on the server.""" - pass - - @abc.abstractmethod - async def call_tool( - self, tool_name: str, arguments: Dict[str, Any] | None - ) -> CallToolResult: - """Invoke a tool by name with provided arguments.""" - pass - - - class _MCPServerWithClientSession(MCPServer, abc.ABC): - """Mixin providing ClientSession-based MCP communication.""" - - def __init__(self, cache_tools_list: bool = False): - self.session: Optional[ClientSession] = None - self.exit_stack: AsyncExitStack = AsyncExitStack() - self._cleanup_lock = asyncio.Lock() - self.cache_tools_list = cache_tools_list - self._cache_dirty = True - self._tools_list: Optional[List[MCPTool]] = None - - @abc.abstractmethod - def create_streams( - self - ) -> AbstractAsyncContextManager[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] - ]: - """Supply the read/write streams for the MCP transport.""" - pass - - async def __aenter__(self) -> MCPServer: - await self.connect() - return self # type: ignore - - async def __aexit__(self, exc_type, exc_value, tb) -> None: - await self.cleanup() - - async def connect(self) -> None: - """Initialize transport and ClientSession.""" - try: - transport = await self.exit_stack.enter_async_context( - self.create_streams() - ) - read, write = transport - session = await self.exit_stack.enter_async_context( - ClientSession(read, write) - ) - await session.initialize() - self.session = session - except Exception as e: - logger.error(f"Error initializing MCP server: {e}") - await self.cleanup() - raise - - async def cleanup(self) -> None: - """Close session and transport.""" - async with self._cleanup_lock: - try: - await self.exit_stack.aclose() - except Exception as e: - logger.error(f"Error during cleanup: {e}") - finally: - self.session = None - - async def list_tools(self) -> List[MCPTool]: - if not self.session: - raise RuntimeError("Server not connected. Call connect() first.") - if self.cache_tools_list and not self._cache_dirty and self._tools_list: - return self._tools_list - self._cache_dirty = False - self._tools_list = (await self.session.list_tools()).tools - return self._tools_list # type: ignore - - async def call_tool( - self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None - ) -> CallToolResult: - if not arguments: - raise ValueError("Arguments dict is required to call a tool") - name = tool_name or arguments.get("tool_name") or arguments.get("name") - if not name: - raise ValueError("Tool name missing in arguments") - if not self.session: - raise RuntimeError("Server not connected. Call connect() first.") - return await self.session.call_tool(name, arguments) - - - class MCPServerStdioParams(TypedDict): - """Configuration for stdio transport.""" - command: str - args: NotRequired[List[str]] - env: NotRequired[Dict[str, str]] - cwd: NotRequired[str | Path] - encoding: NotRequired[str] - encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]] - - - class MCPServerStdio(_MCPServerWithClientSession): - """MCP server over stdio transport.""" - - def __init__( - self, - params: MCPServerStdioParams, - cache_tools_list: bool = False, - name: Optional[str] = None, - ): - super().__init__(cache_tools_list) - self.params = StdioServerParameters( - command=params["command"], - args=params.get("args", []), - env=params.get("env"), - cwd=params.get("cwd"), - encoding=params.get("encoding", "utf-8"), - encoding_error_handler=params.get("encoding_error_handler", "strict"), - ) - self._name = name or f"stdio:{self.params.command}" - - def create_streams(self) -> AbstractAsyncContextManager[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] - ]: - return stdio_client(self.params) - - @property - def name(self) -> str: - return self._name - - - class MCPServerSseParams(TypedDict): - """Configuration for HTTP+SSE transport.""" - url: str - headers: NotRequired[Dict[str, str]] - timeout: NotRequired[float] - sse_read_timeout: NotRequired[float] - - - class MCPServerSse(_MCPServerWithClientSession): - """MCP server over HTTP with SSE transport.""" - - def __init__( - self, - params: MCPServerSseParams, - cache_tools_list: bool = False, - name: Optional[str] = None, - ): - super().__init__(cache_tools_list) - self.params = params - self._name = name or f"sse:{params['url']}" - - def create_streams(self) -> AbstractAsyncContextManager[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] - ]: - return sse_client( - url=self.params["url"], - headers=self.params.get("headers"), - timeout=self.params.get("timeout", 5), - sse_read_timeout=self.params.get("sse_read_timeout", 300), - ) - - @property - def name(self) -> str: - return self._name - - - async def call_tool_fast( - server: MCPServerSse, payload: Dict[str, Any] | str - ) -> Any: - """Async function to call a tool on a server with proper cleanup.""" - try: - await server.connect() - arguments = payload if isinstance(payload, dict) else None - result = await server.call_tool(arguments=arguments) - return result - finally: - await server.cleanup() - - - async def mcp_flow_get_tool_schema( - params: MCPServerSseParams, - ) -> Any: - """Async function to get tool schema from MCP server.""" - async with MCPServerSse(params) as server: - tools = await server.list_tools() - return tools - - - async def mcp_flow( - params: MCPServerSseParams, - function_call: Dict[str, Any] | str, - ) -> Any: - """Async function to call a tool with given parameters.""" - async with MCPServerSse(params) as server: - return await call_tool_fast(server, function_call) - - - async def _call_one_server( - params: MCPServerSseParams, payload: Dict[str, Any] | str - ) -> Any: - """Helper function to call a single MCP server.""" - server = MCPServerSse(params) - try: - await server.connect() - arguments = payload if isinstance(payload, dict) else None - return await server.call_tool(arguments=arguments) - finally: - await server.cleanup() - - - async def abatch_mcp_flow( - params: List[MCPServerSseParams], payload: Dict[str, Any] | str - ) -> List[Any]: - """Async function to execute a batch of MCP calls concurrently. - - Args: - params (List[MCPServerSseParams]): List of MCP server configurations - payload (Dict[str, Any] | str): The payload to send to each server - - Returns: - List[Any]: Results from all MCP servers - """ - if not params: - logger.warning("No MCP servers provided for batch operation") - return [] - - try: - return await asyncio.gather(*[_call_one_server(p, payload) for p in params]) - except Exception as e: - logger.error(f"Error in abatch_mcp_flow: {e}") - # Return partial results if any were successful - return [f"Error in batch operation: {str(e)}"] - - - def batch_mcp_flow( - params: List[MCPServerSseParams], payload: Dict[str, Any] | str - ) -> List[Any]: - """Sync wrapper for batch MCP operations. - - This creates a new event loop if needed to run the async batch operation. - ONLY use this when not already in an async context. - - Args: - params (List[MCPServerSseParams]): List of MCP server configurations - payload (Dict[str, Any] | str): The payload to send to each server - - Returns: - List[Any]: Results from all MCP servers - """ - if not params: - logger.warning("No MCP servers provided for batch operation") - return [] - - try: - # Check if we're already in an event loop - try: - loop = asyncio.get_event_loop() - except RuntimeError: - # No event loop exists, create one - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - if loop.is_running(): - # We're already in an async context, can't use asyncio.run - # Use a future to bridge sync-async gap - future = asyncio.run_coroutine_threadsafe( - abatch_mcp_flow(params, payload), loop - ) - return future.result(timeout=30) # Add timeout to prevent hanging - else: - # We're not in an async context, safe to use loop.run_until_complete - return loop.run_until_complete(abatch_mcp_flow(params, payload)) - except Exception as e: - logger.error(f"Error in batch_mcp_flow: {e}") - return [f"Error in batch operation: {str(e)}"] \ No newline at end of file +from __future__ import annotations + +import abc +import asyncio +from contextlib import AbstractAsyncContextManager, AsyncExitStack +from pathlib import Path +from typing import Any, Dict, List, Optional, Literal, Union +from typing_extensions import NotRequired, TypedDict + +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from loguru import logger +from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client +from mcp.client.sse import sse_client +from mcp.types import CallToolResult, JSONRPCMessage + +from swarms.utils.any_to_str import any_to_str + + +class MCPServer(abc.ABC): + """Base class for Model Context Protocol servers.""" + + @abc.abstractmethod + async def connect(self) -> None: + """Establish connection to the MCP server.""" + pass + + @property + @abc.abstractmethod + def name(self) -> str: + """Human-readable server name.""" + pass + + @abc.abstractmethod + async def cleanup(self) -> None: + """Clean up resources and close connection.""" + pass + + @abc.abstractmethod + async def list_tools(self) -> List[MCPTool]: + """List available MCP tools on the server.""" + pass + + @abc.abstractmethod + async def call_tool(self, tool_name: str, + arguments: Dict[str, Any] | None) -> CallToolResult: + """Invoke a tool by name with provided arguments.""" + pass + + +class _MCPServerWithClientSession(MCPServer, abc.ABC): + """Mixin providing ClientSession-based MCP communication.""" + + def __init__(self, cache_tools_list: bool = False): + self.session: Optional[ClientSession] = None + self.exit_stack: AsyncExitStack = AsyncExitStack() + self._cleanup_lock = asyncio.Lock() + self.cache_tools_list = cache_tools_list + self._cache_dirty = True + self._tools_list: Optional[List[MCPTool]] = None + + @abc.abstractmethod + def create_streams( + self + ) -> AbstractAsyncContextManager[tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ]]: + """Supply the read/write streams for the MCP transport.""" + pass + + async def __aenter__(self) -> MCPServer: + await self.connect() + return self # type: ignore + + async def __aexit__(self, exc_type, exc_value, tb) -> None: + await self.cleanup() + + async def connect(self) -> None: + """Initialize transport and ClientSession.""" + try: + transport = await self.exit_stack.enter_async_context( + self.create_streams()) + read, write = transport + session = await self.exit_stack.enter_async_context( + ClientSession(read, write)) + await session.initialize() + self.session = session + except Exception as e: + logger.error(f"Error initializing MCP server: {e}") + await self.cleanup() + raise + + async def cleanup(self) -> None: + """Close session and transport.""" + async with self._cleanup_lock: + try: + await self.exit_stack.aclose() + except Exception as e: + logger.error(f"Error during cleanup: {e}") + finally: + self.session = None + + async def list_tools(self) -> List[MCPTool]: + if not self.session: + raise RuntimeError("Server not connected. Call connect() first.") + if self.cache_tools_list and not self._cache_dirty and self._tools_list: + return self._tools_list + self._cache_dirty = False + self._tools_list = (await self.session.list_tools()).tools + return self._tools_list # type: ignore + + async def call_tool( + self, + tool_name: str | None = None, + arguments: Dict[str, Any] | None = None) -> CallToolResult: + if not arguments: + raise ValueError("Arguments dict is required to call a tool") + name = tool_name or arguments.get("tool_name") or arguments.get("name") + if not name: + raise ValueError("Tool name missing in arguments") + if not self.session: + raise RuntimeError("Server not connected. Call connect() first.") + return await self.session.call_tool(name, arguments) + + +class MCPServerStdioParams(TypedDict): + """Configuration for stdio transport.""" + command: str + args: NotRequired[List[str]] + env: NotRequired[Dict[str, str]] + cwd: NotRequired[str | Path] + encoding: NotRequired[str] + encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]] + + +class MCPServerStdio(_MCPServerWithClientSession): + """MCP server over stdio transport.""" + + def __init__( + self, + params: MCPServerStdioParams, + cache_tools_list: bool = False, + name: Optional[str] = None, + ): + super().__init__(cache_tools_list) + self.params = StdioServerParameters( + command=params["command"], + args=params.get("args", []), + env=params.get("env"), + cwd=params.get("cwd"), + encoding=params.get("encoding", "utf-8"), + encoding_error_handler=params.get("encoding_error_handler", + "strict"), + ) + self._name = name or f"stdio:{self.params.command}" + + def create_streams( + self + ) -> AbstractAsyncContextManager[tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ]]: + return stdio_client(self.params) + + @property + def name(self) -> str: + return self._name + + +class MCPServerSseParams(TypedDict): + """Configuration for HTTP+SSE transport.""" + url: str + headers: NotRequired[Dict[str, str]] + timeout: NotRequired[float] + sse_read_timeout: NotRequired[float] + + +class MCPServerSse(_MCPServerWithClientSession): + """MCP server over HTTP with SSE transport.""" + + def __init__( + self, + params: MCPServerSseParams, + cache_tools_list: bool = False, + name: Optional[str] = None, + ): + super().__init__(cache_tools_list) + self.params = params + self._name = name or f"sse:{params['url']}" + + def create_streams( + self + ) -> AbstractAsyncContextManager[tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ]]: + return sse_client( + url=self.params["url"], + headers=self.params.get("headers"), + timeout=self.params.get("timeout", 5), + sse_read_timeout=self.params.get("sse_read_timeout", 300), + ) + + @property + def name(self) -> str: + return self._name + + +async def call_tool_fast(server: MCPServerSse, + payload: Dict[str, Any] | str) -> Any: + """Async function to call a tool on a server with proper cleanup.""" + try: + await server.connect() + arguments = payload if isinstance(payload, dict) else None + result = await server.call_tool(arguments=arguments) + return result + finally: + await server.cleanup() + + +async def mcp_flow_get_tool_schema(params: MCPServerSseParams, ) -> Any: + """Async function to get tool schema from MCP server.""" + async with MCPServerSse(params) as server: + tools = await server.list_tools() + return tools + + +async def mcp_flow( + params: MCPServerSseParams, + function_call: Dict[str, Any] | str, +) -> Any: + """Async function to call a tool with given parameters.""" + async with MCPServerSse(params) as server: + return await call_tool_fast(server, function_call) + + +async def _call_one_server(params: MCPServerSseParams, + payload: Dict[str, Any] | str) -> Any: + """Helper function to call a single MCP server.""" + server = MCPServerSse(params) + try: + await server.connect() + arguments = payload if isinstance(payload, dict) else None + return await server.call_tool(arguments=arguments) + finally: + await server.cleanup() + + +async def abatch_mcp_flow(params: List[MCPServerSseParams], + payload: Dict[str, Any] | str) -> List[Any]: + """Async function to execute a batch of MCP calls concurrently. + + Args: + params (List[MCPServerSseParams]): List of MCP server configurations + payload (Dict[str, Any] | str): The payload to send to each server + + Returns: + List[Any]: Results from all MCP servers + """ + if not params: + logger.warning("No MCP servers provided for batch operation") + return [] + + try: + return await asyncio.gather( + *[_call_one_server(p, payload) for p in params]) + except Exception as e: + logger.error(f"Error in abatch_mcp_flow: {e}") + # Return partial results if any were successful + return [f"Error in batch operation: {str(e)}"] + + +def batch_mcp_flow(params: List[MCPServerSseParams], + payload: Dict[str, Any] | str) -> List[Any]: + """Sync wrapper for batch MCP operations. + + This creates a new event loop if needed to run the async batch operation. + ONLY use this when not already in an async context. + + Args: + params (List[MCPServerSseParams]): List of MCP server configurations + payload (Dict[str, Any] | str): The payload to send to each server + + Returns: + List[Any]: Results from all MCP servers + """ + if not params: + logger.warning("No MCP servers provided for batch operation") + return [] + + try: + # Check if we're already in an event loop + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No event loop exists, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if loop.is_running(): + # We're already in an async context, can't use asyncio.run + # Use a future to bridge sync-async gap + future = asyncio.run_coroutine_threadsafe( + abatch_mcp_flow(params, payload), loop) + return future.result(timeout=30) # Add timeout to prevent hanging + else: + # We're not in an async context, safe to use loop.run_until_complete + return loop.run_until_complete(abatch_mcp_flow(params, payload)) + except Exception as e: + logger.error(f"Error in batch_mcp_flow: {e}") + return [f"Error in batch operation: {str(e)}"]