Implement middleware support for tool executions

Added support for middleware functions to modify tool execution parameters and context.
pull/1183/head
CI-DEV 2 months ago committed by GitHub
parent 638e9e2ba2
commit 5b61979350
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -7,7 +7,7 @@ import traceback
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Callable, Dict, List, Literal, Optional
from uuid import uuid4
from loguru import logger
@ -20,6 +20,20 @@ from swarms.tools.mcp_client_tools import (
)
# Middleware type definition
# A middleware function receives the tool execution context and can modify inputs/outputs.
# Middleware functions are called before tool execution and can modify params and context in-place.
# Args:
# tool_name: Name of the tool being executed
# params: Dictionary of tool parameters (task, img, imgs, correct_answer, max_retries)
# Can be modified in-place by the middleware
# context: Additional context dictionary (agent, config, etc.)
# Can be modified in-place by the middleware
# Returns:
# None (modifications are done in-place)
MiddlewareType = Callable[[str, Dict[str, Any], Dict[str, Any]], None]
class TaskStatus(Enum):
"""Status of a task in the queue."""
@ -558,12 +572,14 @@ class AOP:
4. Manage the MCP server lifecycle
5. Queue-based task execution for improved performance and reliability
6. Persistence mode with automatic restart and failsafe protection
7. Custom middleware support for intercepting and modifying tool executions
Attributes:
mcp_server: The FastMCP server instance
agents: Dictionary mapping tool names to agent instances
tool_configs: Dictionary mapping tool names to their configurations
task_queues: Dictionary mapping tool names to their task queues
middlewares: List of middleware functions to apply to tool executions
server_name: Name of the MCP server
queue_enabled: Whether queue-based execution is enabled
persistence: Whether persistence mode is enabled
@ -573,6 +589,27 @@ class AOP:
max_network_retries: Maximum number of network reconnection attempts
network_retry_delay: Delay between network retry attempts in seconds
network_timeout: Network connection timeout in seconds
Example:
>>> from swarms import Agent, AOP
>>>
>>> # Define a middleware function
>>> def auth_middleware(tool_name: str, params: dict, context: dict) -> None:
... # Add authentication logic here
... if not context.get("authenticated"):
... raise ValueError("Not authenticated")
... # Modify params if needed
... params["task"] = f"[AUTH] {params['task']}"
>>>
>>> # Create AOP with middleware
>>> aop = AOP(
... server_name="MyServer",
... middlewares=[auth_middleware]
... )
>>>
>>> # Add agents
>>> agent = Agent(model_name="gpt-4")
>>> aop.add_agent(agent)
"""
def __init__(
@ -600,6 +637,7 @@ class AOP:
log_level: Literal[
"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
] = "INFO",
middlewares: Optional[List[MiddlewareType]] = None,
*args,
**kwargs,
):
@ -628,6 +666,10 @@ class AOP:
max_network_retries: Maximum number of network reconnection attempts
network_retry_delay: Delay between network retry attempts in seconds
network_timeout: Network connection timeout in seconds
middlewares: Optional list of middleware functions to apply to tool executions.
Each middleware receives (tool_name, params, context) and can modify
params and context in-place. Middlewares are applied in order before
each tool execution.
"""
self.server_name = server_name
self.description = description
@ -663,6 +705,7 @@ class AOP:
self.tool_configs: Dict[str, AgentToolConfig] = {}
self.task_queues: Dict[str, TaskQueue] = {}
self.transport = transport
self.middlewares: List[MiddlewareType] = middlewares or []
self.mcp_server = FastMCP(
name=server_name,
port=port,
@ -1020,6 +1063,47 @@ class AOP:
"error": error_msg,
}
# Prepare params and context for middleware
params = {
"task": task,
"img": img,
"imgs": imgs,
"correct_answer": correct_answer,
"max_retries": max_retries,
}
context = {
"agent": agent,
"config": config,
"tool_name": tool_name,
}
# Apply middleware in order
for middleware in self.middlewares:
try:
middleware(tool_name, params, context)
except Exception as e:
# Middleware exceptions should stop execution
# This allows middleware to reject requests (e.g., auth failures)
error_msg = f"Middleware error for tool '{tool_name}': {str(e)}"
logger.warning(error_msg)
if config.traceback_enabled:
logger.debug(
f"Middleware traceback: {traceback.format_exc()}"
)
# Return error response instead of continuing
return {
"result": "",
"success": False,
"error": error_msg,
}
# Extract params after middleware processing
task = params.get("task", task)
img = params.get("img", img)
imgs = params.get("imgs", imgs)
correct_answer = params.get("correct_answer", correct_answer)
max_retries = params.get("max_retries", max_retries)
# Use queue-based execution if enabled
if (
self.queue_enabled

Loading…
Cancel
Save