Update base_swarm.py

pull/1034/head
CI-DEV 2 months ago committed by GitHub
parent f1cf87561c
commit c1cba2e05c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -2,6 +2,7 @@ import os
import asyncio
import json
import uuid
from datetime import datetime
from swarms.utils.file_processing import create_file_in_folder
from abc import ABC
from concurrent.futures import ThreadPoolExecutor, as_completed
@ -13,6 +14,7 @@ from typing import (
Optional,
Sequence,
Union,
Tuple,
)
import yaml
@ -23,6 +25,9 @@ from swarms.structs.omni_agent_types import AgentType
from pydantic import BaseModel
from swarms.utils.loguru_logger import initialize_logger
# Import SwarmShield components
from swarms.security import SwarmShieldIntegration, ShieldConfig
logger = initialize_logger(log_folder="base_swarm")
@ -33,7 +38,8 @@ class BaseSwarm(ABC):
Attributes:
agents (List[Agent]): A list of agents
max_loops (int): The maximum number of loops to run
shield_config (ShieldConfig): Security configuration for SwarmShield
shield (SwarmShieldIntegration): Security shield integration
Methods:
communicate: Communicate with the swarm through the orchestrator, protocols, and the universal communication layer
@ -95,6 +101,10 @@ class BaseSwarm(ABC):
collective_memory_system: Optional[Any] = False,
agent_ops_on: bool = False,
output_schema: Optional[BaseModel] = None,
# SwarmShield parameters
shield_config: Optional[ShieldConfig] = None,
enable_security: bool = True,
security_level: str = "standard",
*args,
**kwargs,
):
@ -184,6 +194,47 @@ class BaseSwarm(ABC):
agent.agent_name: agent for agent in self.agents
}
# Initialize SwarmShield security
self._initialize_swarm_shield(shield_config, enable_security, security_level)
def _initialize_swarm_shield(self, shield_config: Optional[ShieldConfig], enable_security: bool, security_level: str) -> None:
"""Initialize SwarmShield security integration"""
try:
# Set up shield configuration
if shield_config is None:
if security_level == "basic":
shield_config = ShieldConfig.create_basic_config()
elif security_level == "enhanced":
shield_config = ShieldConfig.create_enhanced_config()
elif security_level == "maximum":
shield_config = ShieldConfig.create_maximum_config()
else:
shield_config = ShieldConfig.create_standard_config()
# Override enabled state if specified
if not enable_security:
shield_config.enabled = False
# Initialize shield integration
self.shield_config = shield_config
self.shield = SwarmShieldIntegration(shield_config)
# Create secure conversation if enabled
if self.shield_config.integrate_with_conversation:
self.secure_conversation_id = self.shield.create_secure_conversation(f"swarm_{self.name or 'unnamed'}")
else:
self.secure_conversation_id = None
logger.info(f"SwarmShield initialized with {security_level} security level")
except Exception as e:
logger.error(f"Failed to initialize SwarmShield: {e}")
# Fallback to basic security
self.shield_config = ShieldConfig.create_basic_config()
self.shield_config.enabled = False
self.shield = None
self.secure_conversation_id = None
def communicate(self):
"""Communicate with the swarm through the orchestrator, protocols, and the universal communication layer"""
...
@ -791,3 +842,162 @@ class BaseSwarm(ABC):
Convert agents to a pandas DataFrame.
"""
...
# ==================== SwarmShield Security Methods ====================
def validate_task_with_shield(self, task: str, agent_name: str = "default") -> Tuple[bool, str, Optional[str]]:
"""Validate task with SwarmShield protection"""
if not self.shield or not self.shield_config.enabled:
return True, task, None
return self.shield.validate_task(task, agent_name)
def validate_agent_config_with_shield(self, agent_config: Dict[str, Any], agent_name: str = "default") -> Tuple[bool, Dict[str, Any], Optional[str]]:
"""Validate agent configuration with SwarmShield protection"""
if not self.shield or not self.shield_config.enabled:
return True, agent_config, None
return self.shield.validate_agent_config(agent_config, agent_name)
def process_agent_communication_with_shield(self, agent_name: str, message: str, direction: str = "outbound") -> tuple[bool, str, Optional[str]]:
"""
Process agent communication using SwarmShield security
Args:
agent_name: Name of the agent
message: Message content
direction: "inbound" or "outbound"
Returns:
Tuple of (is_valid, processed_message, error_message)
"""
if not hasattr(self, 'shield') or self.shield is None:
return True, message, None
try:
return self.shield.process_agent_communication(agent_name, message, direction)
except Exception as e:
logger.error(f"Communication processing error: {e}")
return False, "", f"Processing error: {str(e)}"
def check_rate_limit_with_shield(self, agent_name: str, request_size: int = 1) -> tuple[bool, Optional[str]]:
"""
Check rate limits using SwarmShield security
Args:
agent_name: Name of the agent
request_size: Size of the request
Returns:
Tuple of (is_allowed, error_message)
"""
if not hasattr(self, 'shield') or self.shield is None:
return True, None
try:
return self.shield.check_rate_limit(agent_name, request_size)
except Exception as e:
logger.error(f"Rate limit check error: {e}")
return False, f"Rate limit error: {str(e)}"
def add_secure_message(self, agent_name: str, message: str) -> bool:
"""
Add a message to secure conversation
Args:
agent_name: Name of the agent
message: Message content
Returns:
True if successful, False otherwise
"""
if not hasattr(self, 'shield') or self.shield is None or not hasattr(self, 'secure_conversation_id'):
return False
try:
if self.secure_conversation_id:
return self.shield.add_secure_message(self.secure_conversation_id, agent_name, message)
return False
except Exception as e:
logger.error(f"Failed to add secure message: {e}")
return False
def get_secure_messages(self) -> List[tuple[str, str, datetime]]:
"""
Get messages from secure conversation
Returns:
List of (agent_name, message, timestamp) tuples
"""
if not hasattr(self, 'shield') or self.shield is None or not hasattr(self, 'secure_conversation_id'):
return []
try:
if self.secure_conversation_id:
return self.shield.get_secure_messages(self.secure_conversation_id)
return []
except Exception as e:
logger.error(f"Failed to get secure messages: {e}")
return []
def get_security_stats(self) -> Dict[str, Any]:
"""
Get comprehensive security statistics
Returns:
Dictionary with security statistics
"""
if not hasattr(self, 'shield') or self.shield is None:
return {"error": "SwarmShield not initialized"}
try:
return self.shield.get_security_stats()
except Exception as e:
logger.error(f"Failed to get security stats: {e}")
return {"error": str(e)}
def update_shield_config(self, new_config: ShieldConfig) -> bool:
"""
Update SwarmShield configuration
Args:
new_config: New shield configuration
Returns:
True if successful, False otherwise
"""
if not hasattr(self, 'shield') or self.shield is None:
return False
try:
success = self.shield.update_config(new_config)
if success:
self.shield_config = new_config
logger.info("Shield configuration updated successfully")
return success
except Exception as e:
logger.error(f"Failed to update shield config: {e}")
return False
def enable_security(self) -> None:
"""Enable all security features"""
if hasattr(self, 'shield') and self.shield is not None:
self.shield.enable_security()
logger.info("Security features enabled")
def disable_security(self) -> None:
"""Disable all security features"""
if hasattr(self, 'shield') and self.shield is not None:
self.shield.disable_security()
logger.info("Security features disabled")
def cleanup_security(self) -> None:
"""Cleanup security resources"""
if hasattr(self, 'shield') and self.shield is not None:
self.shield.cleanup()
logger.info("Security resources cleaned up")
def __del__(self):
"""Cleanup on destruction"""
try:
self.cleanup_security()
except:
pass

Loading…
Cancel
Save