Update model_router.py

pull/1034/head
CI-DEV 2 months ago committed by GitHub
parent 1cbbd4874a
commit 975f038c63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,13 +1,14 @@
import asyncio import asyncio
import os import os
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Optional from typing import Optional, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from swarms.utils.function_caller_model import OpenAIFunctionCaller from swarms.utils.function_caller_model import OpenAIFunctionCaller
from swarms.utils.any_to_str import any_to_str from swarms.utils.any_to_str import any_to_str
from swarms.utils.formatter import formatter from swarms.utils.formatter import formatter
from swarms.utils.litellm_wrapper import LiteLLM from swarms.utils.litellm_wrapper import LiteLLM
from swarms.security import SwarmShieldIntegration, ShieldConfig
model_recommendations = { model_recommendations = {
"gpt-4o": { "gpt-4o": {
@ -185,6 +186,9 @@ class ModelRouter:
max_workers: int = 10, max_workers: int = 10,
api_key: str = None, api_key: str = None,
max_loops: int = 1, max_loops: int = 1,
shield_config: Optional[ShieldConfig] = None,
enable_security: bool = True,
security_level: str = "standard",
*args, *args,
**kwargs, **kwargs,
): ):
@ -196,6 +200,11 @@ class ModelRouter:
max_tokens (int): Maximum output tokens max_tokens (int): Maximum output tokens
temperature (float): Model temperature parameter temperature (float): Model temperature parameter
max_workers (int): Max concurrent workers max_workers (int): Max concurrent workers
api_key (str): API key for the model caller
max_loops (int): Maximum number of execution loops
shield_config (ShieldConfig, optional): Security configuration for SwarmShield integration. Defaults to None.
enable_security (bool, optional): Whether to enable SwarmShield security features. Defaults to True.
security_level (str, optional): Pre-defined security level. Options: "basic", "standard", "enhanced", "maximum". Defaults to "standard".
*args: Additional positional arguments *args: Additional positional arguments
**kwargs: Additional keyword arguments **kwargs: Additional keyword arguments
""" """
@ -207,6 +216,9 @@ class ModelRouter:
self.model_output = ModelOutput self.model_output = ModelOutput
self.max_loops = max_loops self.max_loops = max_loops
# Initialize SwarmShield integration
self._initialize_swarm_shield(shield_config, enable_security, security_level)
if self.max_workers == "auto": if self.max_workers == "auto":
self.max_workers = os.cpu_count() self.max_workers = os.cpu_count()
@ -221,6 +233,86 @@ class ModelRouter:
f"Failed to initialize ModelRouter: {str(e)}" f"Failed to initialize ModelRouter: {str(e)}"
) )
def _initialize_swarm_shield(
self,
shield_config: Optional[ShieldConfig] = None,
enable_security: bool = True,
security_level: str = "standard"
) -> None:
"""Initialize SwarmShield integration for security features."""
self.enable_security = enable_security
self.security_level = security_level
if enable_security:
if shield_config is None:
shield_config = ShieldConfig.get_security_level(security_level)
self.swarm_shield = SwarmShieldIntegration(shield_config)
else:
self.swarm_shield = None
# Security methods
def validate_task_with_shield(self, task: str) -> str:
"""Validate and sanitize task input using SwarmShield."""
if self.swarm_shield:
return self.swarm_shield.validate_and_protect_input(task)
return task
def validate_agent_config_with_shield(self, agent_config: dict) -> dict:
"""Validate agent configuration using SwarmShield."""
if self.swarm_shield:
return self.swarm_shield.validate_and_protect_input(str(agent_config))
return agent_config
def process_agent_communication_with_shield(self, message: str, agent_name: str) -> str:
"""Process agent communication through SwarmShield security."""
if self.swarm_shield:
return self.swarm_shield.process_agent_communication(message, agent_name)
return message
def check_rate_limit_with_shield(self, agent_name: str) -> bool:
"""Check rate limits for an agent using SwarmShield."""
if self.swarm_shield:
return self.swarm_shield.check_rate_limit(agent_name)
return True
def add_secure_message(self, message: str, agent_name: str) -> None:
"""Add a message to secure conversation history."""
if self.swarm_shield:
self.swarm_shield.add_secure_message(message, agent_name)
def get_secure_messages(self) -> List[dict]:
"""Get secure conversation messages."""
if self.swarm_shield:
return self.swarm_shield.get_secure_messages()
return []
def get_security_stats(self) -> dict:
"""Get security statistics and metrics."""
if self.swarm_shield:
return self.swarm_shield.get_security_stats()
return {"security_enabled": False}
def update_shield_config(self, new_config: ShieldConfig) -> None:
"""Update SwarmShield configuration."""
if self.swarm_shield:
self.swarm_shield.update_config(new_config)
def enable_security(self) -> None:
"""Enable SwarmShield security features."""
if not self.swarm_shield:
self._initialize_swarm_shield(enable_security=True, security_level=self.security_level)
def disable_security(self) -> None:
"""Disable SwarmShield security features."""
self.swarm_shield = None
self.enable_security = False
def cleanup_security(self) -> None:
"""Clean up SwarmShield resources."""
if self.swarm_shield:
self.swarm_shield.cleanup()
def step(self, task: str): def step(self, task: str):
""" """
Run a single task through the model router. Run a single task through the model router.

Loading…
Cancel
Save