diff --git a/swarms/structs/aop.py b/swarms/structs/aop.py index b95acb77..1fc32f35 100644 --- a/swarms/structs/aop.py +++ b/swarms/structs/aop.py @@ -7,7 +7,7 @@ import traceback from collections import deque from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Callable, Dict, List, Literal, Optional from uuid import uuid4 from loguru import logger @@ -20,6 +20,20 @@ from swarms.tools.mcp_client_tools import ( ) +# Middleware type definition +# A middleware function receives the tool execution context and can modify inputs/outputs. +# Middleware functions are called before tool execution and can modify params and context in-place. +# Args: +# tool_name: Name of the tool being executed +# params: Dictionary of tool parameters (task, img, imgs, correct_answer, max_retries) +# Can be modified in-place by the middleware +# context: Additional context dictionary (agent, config, etc.) +# Can be modified in-place by the middleware +# Returns: +# None (modifications are done in-place) +MiddlewareType = Callable[[str, Dict[str, Any], Dict[str, Any]], None] + + class TaskStatus(Enum): """Status of a task in the queue.""" @@ -558,12 +572,14 @@ class AOP: 4. Manage the MCP server lifecycle 5. Queue-based task execution for improved performance and reliability 6. Persistence mode with automatic restart and failsafe protection + 7. Custom middleware support for intercepting and modifying tool executions Attributes: mcp_server: The FastMCP server instance agents: Dictionary mapping tool names to agent instances tool_configs: Dictionary mapping tool names to their configurations task_queues: Dictionary mapping tool names to their task queues + middlewares: List of middleware functions to apply to tool executions server_name: Name of the MCP server queue_enabled: Whether queue-based execution is enabled persistence: Whether persistence mode is enabled @@ -573,6 +589,27 @@ class AOP: max_network_retries: Maximum number of network reconnection attempts network_retry_delay: Delay between network retry attempts in seconds network_timeout: Network connection timeout in seconds + + Example: + >>> from swarms import Agent, AOP + >>> + >>> # Define a middleware function + >>> def auth_middleware(tool_name: str, params: dict, context: dict) -> None: + ... # Add authentication logic here + ... if not context.get("authenticated"): + ... raise ValueError("Not authenticated") + ... # Modify params if needed + ... params["task"] = f"[AUTH] {params['task']}" + >>> + >>> # Create AOP with middleware + >>> aop = AOP( + ... server_name="MyServer", + ... middlewares=[auth_middleware] + ... ) + >>> + >>> # Add agents + >>> agent = Agent(model_name="gpt-4") + >>> aop.add_agent(agent) """ def __init__( @@ -600,6 +637,7 @@ class AOP: log_level: Literal[ "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" ] = "INFO", + middlewares: Optional[List[MiddlewareType]] = None, *args, **kwargs, ): @@ -628,6 +666,10 @@ class AOP: max_network_retries: Maximum number of network reconnection attempts network_retry_delay: Delay between network retry attempts in seconds network_timeout: Network connection timeout in seconds + middlewares: Optional list of middleware functions to apply to tool executions. + Each middleware receives (tool_name, params, context) and can modify + params and context in-place. Middlewares are applied in order before + each tool execution. """ self.server_name = server_name self.description = description @@ -663,6 +705,7 @@ class AOP: self.tool_configs: Dict[str, AgentToolConfig] = {} self.task_queues: Dict[str, TaskQueue] = {} self.transport = transport + self.middlewares: List[MiddlewareType] = middlewares or [] self.mcp_server = FastMCP( name=server_name, port=port, @@ -1020,6 +1063,47 @@ class AOP: "error": error_msg, } + # Prepare params and context for middleware + params = { + "task": task, + "img": img, + "imgs": imgs, + "correct_answer": correct_answer, + "max_retries": max_retries, + } + context = { + "agent": agent, + "config": config, + "tool_name": tool_name, + } + + # Apply middleware in order + for middleware in self.middlewares: + try: + middleware(tool_name, params, context) + except Exception as e: + # Middleware exceptions should stop execution + # This allows middleware to reject requests (e.g., auth failures) + error_msg = f"Middleware error for tool '{tool_name}': {str(e)}" + logger.warning(error_msg) + if config.traceback_enabled: + logger.debug( + f"Middleware traceback: {traceback.format_exc()}" + ) + # Return error response instead of continuing + return { + "result": "", + "success": False, + "error": error_msg, + } + + # Extract params after middleware processing + task = params.get("task", task) + img = params.get("img", img) + imgs = params.get("imgs", imgs) + correct_answer = params.get("correct_answer", correct_answer) + max_retries = params.get("max_retries", max_retries) + # Use queue-based execution if enabled if ( self.queue_enabled