|
|
|
|
@ -7,7 +7,7 @@ import traceback
|
|
|
|
|
from collections import deque
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from typing import Any, Dict, List, Literal, Optional
|
|
|
|
|
from typing import Any, Callable, Dict, List, Literal, Optional
|
|
|
|
|
from uuid import uuid4
|
|
|
|
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
@ -20,6 +20,20 @@ from swarms.tools.mcp_client_tools import (
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Middleware type definition
|
|
|
|
|
# A middleware function receives the tool execution context and can modify inputs/outputs.
|
|
|
|
|
# Middleware functions are called before tool execution and can modify params and context in-place.
|
|
|
|
|
# Args:
|
|
|
|
|
# tool_name: Name of the tool being executed
|
|
|
|
|
# params: Dictionary of tool parameters (task, img, imgs, correct_answer, max_retries)
|
|
|
|
|
# Can be modified in-place by the middleware
|
|
|
|
|
# context: Additional context dictionary (agent, config, etc.)
|
|
|
|
|
# Can be modified in-place by the middleware
|
|
|
|
|
# Returns:
|
|
|
|
|
# None (modifications are done in-place)
|
|
|
|
|
MiddlewareType = Callable[[str, Dict[str, Any], Dict[str, Any]], None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TaskStatus(Enum):
|
|
|
|
|
"""Status of a task in the queue."""
|
|
|
|
|
|
|
|
|
|
@ -558,12 +572,14 @@ class AOP:
|
|
|
|
|
4. Manage the MCP server lifecycle
|
|
|
|
|
5. Queue-based task execution for improved performance and reliability
|
|
|
|
|
6. Persistence mode with automatic restart and failsafe protection
|
|
|
|
|
7. Custom middleware support for intercepting and modifying tool executions
|
|
|
|
|
|
|
|
|
|
Attributes:
|
|
|
|
|
mcp_server: The FastMCP server instance
|
|
|
|
|
agents: Dictionary mapping tool names to agent instances
|
|
|
|
|
tool_configs: Dictionary mapping tool names to their configurations
|
|
|
|
|
task_queues: Dictionary mapping tool names to their task queues
|
|
|
|
|
middlewares: List of middleware functions to apply to tool executions
|
|
|
|
|
server_name: Name of the MCP server
|
|
|
|
|
queue_enabled: Whether queue-based execution is enabled
|
|
|
|
|
persistence: Whether persistence mode is enabled
|
|
|
|
|
@ -573,6 +589,27 @@ class AOP:
|
|
|
|
|
max_network_retries: Maximum number of network reconnection attempts
|
|
|
|
|
network_retry_delay: Delay between network retry attempts in seconds
|
|
|
|
|
network_timeout: Network connection timeout in seconds
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
>>> from swarms import Agent, AOP
|
|
|
|
|
>>>
|
|
|
|
|
>>> # Define a middleware function
|
|
|
|
|
>>> def auth_middleware(tool_name: str, params: dict, context: dict) -> None:
|
|
|
|
|
... # Add authentication logic here
|
|
|
|
|
... if not context.get("authenticated"):
|
|
|
|
|
... raise ValueError("Not authenticated")
|
|
|
|
|
... # Modify params if needed
|
|
|
|
|
... params["task"] = f"[AUTH] {params['task']}"
|
|
|
|
|
>>>
|
|
|
|
|
>>> # Create AOP with middleware
|
|
|
|
|
>>> aop = AOP(
|
|
|
|
|
... server_name="MyServer",
|
|
|
|
|
... middlewares=[auth_middleware]
|
|
|
|
|
... )
|
|
|
|
|
>>>
|
|
|
|
|
>>> # Add agents
|
|
|
|
|
>>> agent = Agent(model_name="gpt-4")
|
|
|
|
|
>>> aop.add_agent(agent)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
@ -600,6 +637,7 @@ class AOP:
|
|
|
|
|
log_level: Literal[
|
|
|
|
|
"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
|
|
|
|
|
] = "INFO",
|
|
|
|
|
middlewares: Optional[List[MiddlewareType]] = None,
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
@ -628,6 +666,10 @@ class AOP:
|
|
|
|
|
max_network_retries: Maximum number of network reconnection attempts
|
|
|
|
|
network_retry_delay: Delay between network retry attempts in seconds
|
|
|
|
|
network_timeout: Network connection timeout in seconds
|
|
|
|
|
middlewares: Optional list of middleware functions to apply to tool executions.
|
|
|
|
|
Each middleware receives (tool_name, params, context) and can modify
|
|
|
|
|
params and context in-place. Middlewares are applied in order before
|
|
|
|
|
each tool execution.
|
|
|
|
|
"""
|
|
|
|
|
self.server_name = server_name
|
|
|
|
|
self.description = description
|
|
|
|
|
@ -663,6 +705,7 @@ class AOP:
|
|
|
|
|
self.tool_configs: Dict[str, AgentToolConfig] = {}
|
|
|
|
|
self.task_queues: Dict[str, TaskQueue] = {}
|
|
|
|
|
self.transport = transport
|
|
|
|
|
self.middlewares: List[MiddlewareType] = middlewares or []
|
|
|
|
|
self.mcp_server = FastMCP(
|
|
|
|
|
name=server_name,
|
|
|
|
|
port=port,
|
|
|
|
|
@ -1020,6 +1063,47 @@ class AOP:
|
|
|
|
|
"error": error_msg,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Prepare params and context for middleware
|
|
|
|
|
params = {
|
|
|
|
|
"task": task,
|
|
|
|
|
"img": img,
|
|
|
|
|
"imgs": imgs,
|
|
|
|
|
"correct_answer": correct_answer,
|
|
|
|
|
"max_retries": max_retries,
|
|
|
|
|
}
|
|
|
|
|
context = {
|
|
|
|
|
"agent": agent,
|
|
|
|
|
"config": config,
|
|
|
|
|
"tool_name": tool_name,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Apply middleware in order
|
|
|
|
|
for middleware in self.middlewares:
|
|
|
|
|
try:
|
|
|
|
|
middleware(tool_name, params, context)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# Middleware exceptions should stop execution
|
|
|
|
|
# This allows middleware to reject requests (e.g., auth failures)
|
|
|
|
|
error_msg = f"Middleware error for tool '{tool_name}': {str(e)}"
|
|
|
|
|
logger.warning(error_msg)
|
|
|
|
|
if config.traceback_enabled:
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"Middleware traceback: {traceback.format_exc()}"
|
|
|
|
|
)
|
|
|
|
|
# Return error response instead of continuing
|
|
|
|
|
return {
|
|
|
|
|
"result": "",
|
|
|
|
|
"success": False,
|
|
|
|
|
"error": error_msg,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Extract params after middleware processing
|
|
|
|
|
task = params.get("task", task)
|
|
|
|
|
img = params.get("img", img)
|
|
|
|
|
imgs = params.get("imgs", imgs)
|
|
|
|
|
correct_answer = params.get("correct_answer", correct_answer)
|
|
|
|
|
max_retries = params.get("max_retries", max_retries)
|
|
|
|
|
|
|
|
|
|
# Use queue-based execution if enabled
|
|
|
|
|
if (
|
|
|
|
|
self.queue_enabled
|
|
|
|
|
|