diff --git a/swarms/structs/base_swarm.py b/swarms/structs/base_swarm.py index 3d7b6e08..16f5fd45 100644 --- a/swarms/structs/base_swarm.py +++ b/swarms/structs/base_swarm.py @@ -2,6 +2,7 @@ import os import asyncio import json import uuid +from datetime import datetime from swarms.utils.file_processing import create_file_in_folder from abc import ABC from concurrent.futures import ThreadPoolExecutor, as_completed @@ -13,6 +14,7 @@ from typing import ( Optional, Sequence, Union, + Tuple, ) import yaml @@ -23,6 +25,9 @@ from swarms.structs.omni_agent_types import AgentType from pydantic import BaseModel from swarms.utils.loguru_logger import initialize_logger +# Import SwarmShield components +from swarms.security import SwarmShieldIntegration, ShieldConfig + logger = initialize_logger(log_folder="base_swarm") @@ -33,7 +38,8 @@ class BaseSwarm(ABC): Attributes: agents (List[Agent]): A list of agents max_loops (int): The maximum number of loops to run - + shield_config (ShieldConfig): Security configuration for SwarmShield + shield (SwarmShieldIntegration): Security shield integration Methods: communicate: Communicate with the swarm through the orchestrator, protocols, and the universal communication layer @@ -95,6 +101,10 @@ class BaseSwarm(ABC): collective_memory_system: Optional[Any] = False, agent_ops_on: bool = False, output_schema: Optional[BaseModel] = None, + # SwarmShield parameters + shield_config: Optional[ShieldConfig] = None, + enable_security: bool = True, + security_level: str = "standard", *args, **kwargs, ): @@ -184,6 +194,47 @@ class BaseSwarm(ABC): agent.agent_name: agent for agent in self.agents } + # Initialize SwarmShield security + self._initialize_swarm_shield(shield_config, enable_security, security_level) + + def _initialize_swarm_shield(self, shield_config: Optional[ShieldConfig], enable_security: bool, security_level: str) -> None: + """Initialize SwarmShield security integration""" + try: + # Set up shield configuration + if shield_config is None: + if security_level == "basic": + shield_config = ShieldConfig.create_basic_config() + elif security_level == "enhanced": + shield_config = ShieldConfig.create_enhanced_config() + elif security_level == "maximum": + shield_config = ShieldConfig.create_maximum_config() + else: + shield_config = ShieldConfig.create_standard_config() + + # Override enabled state if specified + if not enable_security: + shield_config.enabled = False + + # Initialize shield integration + self.shield_config = shield_config + self.shield = SwarmShieldIntegration(shield_config) + + # Create secure conversation if enabled + if self.shield_config.integrate_with_conversation: + self.secure_conversation_id = self.shield.create_secure_conversation(f"swarm_{self.name or 'unnamed'}") + else: + self.secure_conversation_id = None + + logger.info(f"SwarmShield initialized with {security_level} security level") + + except Exception as e: + logger.error(f"Failed to initialize SwarmShield: {e}") + # Fallback to basic security + self.shield_config = ShieldConfig.create_basic_config() + self.shield_config.enabled = False + self.shield = None + self.secure_conversation_id = None + def communicate(self): """Communicate with the swarm through the orchestrator, protocols, and the universal communication layer""" ... @@ -791,3 +842,162 @@ class BaseSwarm(ABC): Convert agents to a pandas DataFrame. """ ... + + # ==================== SwarmShield Security Methods ==================== + + def validate_task_with_shield(self, task: str, agent_name: str = "default") -> Tuple[bool, str, Optional[str]]: + """Validate task with SwarmShield protection""" + if not self.shield or not self.shield_config.enabled: + return True, task, None + return self.shield.validate_task(task, agent_name) + + def validate_agent_config_with_shield(self, agent_config: Dict[str, Any], agent_name: str = "default") -> Tuple[bool, Dict[str, Any], Optional[str]]: + """Validate agent configuration with SwarmShield protection""" + if not self.shield or not self.shield_config.enabled: + return True, agent_config, None + return self.shield.validate_agent_config(agent_config, agent_name) + + def process_agent_communication_with_shield(self, agent_name: str, message: str, direction: str = "outbound") -> tuple[bool, str, Optional[str]]: + """ + Process agent communication using SwarmShield security + + Args: + agent_name: Name of the agent + message: Message content + direction: "inbound" or "outbound" + + Returns: + Tuple of (is_valid, processed_message, error_message) + """ + if not hasattr(self, 'shield') or self.shield is None: + return True, message, None + + try: + return self.shield.process_agent_communication(agent_name, message, direction) + except Exception as e: + logger.error(f"Communication processing error: {e}") + return False, "", f"Processing error: {str(e)}" + + def check_rate_limit_with_shield(self, agent_name: str, request_size: int = 1) -> tuple[bool, Optional[str]]: + """ + Check rate limits using SwarmShield security + + Args: + agent_name: Name of the agent + request_size: Size of the request + + Returns: + Tuple of (is_allowed, error_message) + """ + if not hasattr(self, 'shield') or self.shield is None: + return True, None + + try: + return self.shield.check_rate_limit(agent_name, request_size) + except Exception as e: + logger.error(f"Rate limit check error: {e}") + return False, f"Rate limit error: {str(e)}" + + def add_secure_message(self, agent_name: str, message: str) -> bool: + """ + Add a message to secure conversation + + Args: + agent_name: Name of the agent + message: Message content + + Returns: + True if successful, False otherwise + """ + if not hasattr(self, 'shield') or self.shield is None or not hasattr(self, 'secure_conversation_id'): + return False + + try: + if self.secure_conversation_id: + return self.shield.add_secure_message(self.secure_conversation_id, agent_name, message) + return False + except Exception as e: + logger.error(f"Failed to add secure message: {e}") + return False + + def get_secure_messages(self) -> List[tuple[str, str, datetime]]: + """ + Get messages from secure conversation + + Returns: + List of (agent_name, message, timestamp) tuples + """ + if not hasattr(self, 'shield') or self.shield is None or not hasattr(self, 'secure_conversation_id'): + return [] + + try: + if self.secure_conversation_id: + return self.shield.get_secure_messages(self.secure_conversation_id) + return [] + except Exception as e: + logger.error(f"Failed to get secure messages: {e}") + return [] + + def get_security_stats(self) -> Dict[str, Any]: + """ + Get comprehensive security statistics + + Returns: + Dictionary with security statistics + """ + if not hasattr(self, 'shield') or self.shield is None: + return {"error": "SwarmShield not initialized"} + + try: + return self.shield.get_security_stats() + except Exception as e: + logger.error(f"Failed to get security stats: {e}") + return {"error": str(e)} + + def update_shield_config(self, new_config: ShieldConfig) -> bool: + """ + Update SwarmShield configuration + + Args: + new_config: New shield configuration + + Returns: + True if successful, False otherwise + """ + if not hasattr(self, 'shield') or self.shield is None: + return False + + try: + success = self.shield.update_config(new_config) + if success: + self.shield_config = new_config + logger.info("Shield configuration updated successfully") + return success + except Exception as e: + logger.error(f"Failed to update shield config: {e}") + return False + + def enable_security(self) -> None: + """Enable all security features""" + if hasattr(self, 'shield') and self.shield is not None: + self.shield.enable_security() + logger.info("Security features enabled") + + def disable_security(self) -> None: + """Disable all security features""" + if hasattr(self, 'shield') and self.shield is not None: + self.shield.disable_security() + logger.info("Security features disabled") + + def cleanup_security(self) -> None: + """Cleanup security resources""" + if hasattr(self, 'shield') and self.shield is not None: + self.shield.cleanup() + logger.info("Security resources cleaned up") + + def __del__(self): + """Cleanup on destruction""" + try: + self.cleanup_security() + except: + pass