diff --git a/swarms/security/__init__.py b/swarms/security/__init__.py
new file mode 100644
index 00000000..d5f3c85d
--- /dev/null
+++ b/swarms/security/__init__.py
@@ -0,0 +1,29 @@
+"""
+Security module for Swarms framework.
+
+This module provides enterprise-grade security features including:
+- SwarmShield integration for encrypted communications
+- Input validation and sanitization
+- Output filtering and safety checks
+- Rate limiting and abuse prevention
+- Audit logging and compliance features
+"""
+
+from .swarm_shield import SwarmShield, EncryptionStrength
+from .shield_config import ShieldConfig
+from .input_validator import InputValidator
+from .output_filter import OutputFilter
+from .safety_checker import SafetyChecker
+from .rate_limiter import RateLimiter
+from .swarm_shield_integration import SwarmShieldIntegration
+
+__all__ = [
+    "SwarmShield",
+    "EncryptionStrength",
+    "ShieldConfig",
+    "InputValidator",
+    "OutputFilter",
+    "SafetyChecker",
+    "RateLimiter",
+    "SwarmShieldIntegration",
+]
\ No newline at end of file
diff --git a/swarms/security/input_validator.py b/swarms/security/input_validator.py
new file mode 100644
index 00000000..51b1cf23
--- /dev/null
+++ b/swarms/security/input_validator.py
@@ -0,0 +1,259 @@
+"""
+Input validation and sanitization for Swarms framework.
+
+This module provides comprehensive input validation, sanitization,
+and security checks for all swarm inputs.
+"""
+
+import re
+import html
+from typing import List, Optional, Dict, Any, Tuple
+from urllib.parse import urlparse
+from datetime import datetime
+
+from loguru import logger
+
+from swarms.utils.loguru_logger import initialize_logger
+
+# Initialize logger for input validation
+validation_logger = initialize_logger(log_folder="input_validation")
+
+
+class InputValidator:
+    """
+    Input validation and sanitization for swarm security
+
+    Features:
+    - Input length validation
+    - Pattern-based blocking
+    - XSS prevention
+    - SQL injection prevention
+    - URL validation
+    - Content type validation
+    """
+
+    def __init__(self, config: Dict[str, Any]):
+        """
+        Initialize input validator with configuration
+
+        Args:
+            config: Validation configuration dictionary
+        """
+        self.enabled = config.get("enabled", True)
+        self.max_length = config.get("max_length", 10000)
+        self.blocked_patterns = config.get("blocked_patterns", [])
+        self.allowed_domains = config.get("allowed_domains", [])
+
+        # Compile regex patterns for performance
+        self._compiled_patterns = [
+            re.compile(pattern, re.IGNORECASE) for pattern in self.blocked_patterns
+        ]
+
+        # Common malicious patterns (HTML tag injection, script URLs, event handlers)
+        self._malicious_patterns = [
+            re.compile(r"<script.*?>.*?</script>", re.IGNORECASE),
+            re.compile(r"javascript:", re.IGNORECASE),
+            re.compile(r"data:text/html", re.IGNORECASE),
+            re.compile(r"vbscript:", re.IGNORECASE),
+            re.compile(r"on\w+\s*=", re.IGNORECASE),
+            re.compile(r"<iframe.*?>.*?</iframe>", re.IGNORECASE),
+            re.compile(r"<object.*?>.*?</object>", re.IGNORECASE),
+            re.compile(r"<embed.*?>", re.IGNORECASE),
+            re.compile(r"<link.*?>", re.IGNORECASE),
+            re.compile(r"<meta.*?>", re.IGNORECASE),
+        ]
+
+        # SQL injection patterns
+        self._sql_patterns = [
+            re.compile(r"(\b(union|select|insert|update|delete|drop|create|alter)\b)", re.IGNORECASE),
+            re.compile(r"(--|#|/\*|\*/)", re.IGNORECASE),
+            re.compile(r"(\b(exec|execute|xp_|sp_)\b)", re.IGNORECASE),
+        ]
+
+        validation_logger.info("InputValidator initialized")
+
+    def validate_input(self, input_data: str, input_type: str = "text") -> Tuple[bool, str, Optional[str]]:
+        """
+        Validate and sanitize input data
+
+        Args:
+            input_data: Input data to validate
input_type: Type of input (text, url, code, etc.) + + Returns: + Tuple of (is_valid, sanitized_data, error_message) + """ + if not self.enabled: + return True, input_data, None + + try: + # Basic type validation + if not isinstance(input_data, str): + return False, "", "Input must be a string" + + # Length validation + if len(input_data) > self.max_length: + return False, "", f"Input exceeds maximum length of {self.max_length} characters" + + # Empty input check + if not input_data.strip(): + return False, "", "Input cannot be empty" + + # Sanitize the input + sanitized = self._sanitize_input(input_data) + + # Check for blocked patterns + if self._check_blocked_patterns(sanitized): + return False, "", "Input contains blocked patterns" + + # Check for malicious patterns + if self._check_malicious_patterns(sanitized): + return False, "", "Input contains potentially malicious content" + + # Type-specific validation + if input_type == "url": + if not self._validate_url(sanitized): + return False, "", "Invalid URL format" + + elif input_type == "code": + if not self._validate_code(sanitized): + return False, "", "Invalid code content" + + elif input_type == "json": + if not self._validate_json(sanitized): + return False, "", "Invalid JSON format" + + validation_logger.debug(f"Input validation passed for type: {input_type}") + return True, sanitized, None + + except Exception as e: + validation_logger.error(f"Input validation error: {e}") + return False, "", f"Validation error: {str(e)}" + + def _sanitize_input(self, input_data: str) -> str: + """Sanitize input data to prevent XSS and other attacks""" + # HTML escape + sanitized = html.escape(input_data) + + # Remove null bytes + sanitized = sanitized.replace('\x00', '') + + # Normalize whitespace + sanitized = ' '.join(sanitized.split()) + + return sanitized + + def _check_blocked_patterns(self, input_data: str) -> bool: + """Check if input contains blocked patterns""" + for pattern in self._compiled_patterns: + if pattern.search(input_data): + validation_logger.warning(f"Blocked pattern detected: {pattern.pattern}") + return True + return False + + def _check_malicious_patterns(self, input_data: str) -> bool: + """Check if input contains malicious patterns""" + for pattern in self._malicious_patterns: + if pattern.search(input_data): + validation_logger.warning(f"Malicious pattern detected: {pattern.pattern}") + return True + return False + + def _validate_url(self, url: str) -> bool: + """Validate URL format and domain""" + try: + parsed = urlparse(url) + + # Check if it's a valid URL + if not all([parsed.scheme, parsed.netloc]): + return False + + # Check allowed domains if specified + if self.allowed_domains: + domain = parsed.netloc.lower() + if not any(allowed in domain for allowed in self.allowed_domains): + validation_logger.warning(f"Domain not allowed: {domain}") + return False + + return True + + except Exception: + return False + + def _validate_code(self, code: str) -> bool: + """Validate code content for safety""" + # Check for SQL injection patterns + for pattern in self._sql_patterns: + if pattern.search(code): + validation_logger.warning(f"SQL injection pattern detected: {pattern.pattern}") + return False + + # Check for dangerous system calls + dangerous_calls = [ + 'os.system', 'subprocess.call', 'eval(', 'exec(', + '__import__', 'globals()', 'locals()' + ] + + for call in dangerous_calls: + if call in code: + validation_logger.warning(f"Dangerous call detected: {call}") + return False + + return True + + def 
_validate_json(self, json_str: str) -> bool: + """Validate JSON format""" + try: + import json + json.loads(json_str) + return True + except (json.JSONDecodeError, ValueError): + return False + + def validate_task(self, task: str) -> Tuple[bool, str, Optional[str]]: + """Validate swarm task input""" + return self.validate_input(task, "text") + + def validate_agent_name(self, agent_name: str) -> Tuple[bool, str, Optional[str]]: + """Validate agent name input""" + # Additional validation for agent names + if not re.match(r'^[a-zA-Z0-9_-]+$', agent_name): + return False, "", "Agent name can only contain letters, numbers, underscores, and hyphens" + + if len(agent_name) < 1 or len(agent_name) > 50: + return False, "", "Agent name must be between 1 and 50 characters" + + return self.validate_input(agent_name, "text") + + def validate_message(self, message: str) -> Tuple[bool, str, Optional[str]]: + """Validate message input""" + return self.validate_input(message, "text") + + def validate_config(self, config: Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Optional[str]]: + """Validate configuration input""" + try: + # Convert config to string for validation + import json + config_str = json.dumps(config) + + is_valid, sanitized, error = self.validate_input(config_str, "json") + if not is_valid: + return False, {}, error + + # Parse back to dict + validated_config = json.loads(sanitized) + return True, validated_config, None + + except Exception as e: + return False, {}, f"Configuration validation error: {str(e)}" + + def get_validation_stats(self) -> Dict[str, Any]: + """Get validation statistics""" + return { + "enabled": self.enabled, + "max_length": self.max_length, + "blocked_patterns_count": len(self.blocked_patterns), + "allowed_domains_count": len(self.allowed_domains), + "malicious_patterns_count": len(self._malicious_patterns), + "sql_patterns_count": len(self._sql_patterns), + } \ No newline at end of file diff --git a/swarms/security/output_filter.py b/swarms/security/output_filter.py new file mode 100644 index 00000000..a2989d51 --- /dev/null +++ b/swarms/security/output_filter.py @@ -0,0 +1,285 @@ +""" +Output filtering and sanitization for Swarms framework. + +This module provides comprehensive output filtering, sanitization, +and sensitive data protection for all swarm outputs. 
+"""
+
+import re
+import json
+from typing import List, Optional, Dict, Any, Tuple, Union
+from datetime import datetime
+
+from loguru import logger
+
+from swarms.utils.loguru_logger import initialize_logger
+
+# Initialize logger for output filtering
+filter_logger = initialize_logger(log_folder="output_filtering")
+
+
+class OutputFilter:
+    """
+    Output filtering and sanitization for swarm security
+
+    Features:
+    - Sensitive data filtering
+    - Output sanitization
+    - Content type filtering
+    - PII protection
+    - Malicious content detection
+    """
+
+    def __init__(self, config: Dict[str, Any]):
+        """
+        Initialize output filter with configuration
+
+        Args:
+            config: Filtering configuration dictionary
+        """
+        self.enabled = config.get("enabled", True)
+        self.filter_sensitive = config.get("filter_sensitive", True)
+        self.sensitive_patterns = config.get("sensitive_patterns", [])
+
+        # Compile regex patterns for performance
+        self._compiled_patterns = [
+            re.compile(pattern, re.IGNORECASE) for pattern in self.sensitive_patterns
+        ]
+
+        # Default sensitive data patterns
+        self._default_patterns = [
+            re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),  # SSN
+            re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b"),  # Credit card
+            re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"),  # Email
+            re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b"),  # IP address
+            re.compile(r"\b\d{3}[\s-]?\d{3}[\s-]?\d{4}\b"),  # Phone number
+            re.compile(r"\b[A-Z]{2}\d{2}[A-Z0-9]{10,30}\b"),  # IBAN
+            re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{1,3}\b"),  # Extended CC
+        ]
+
+        # Malicious content patterns (HTML tag injection, script URLs, event handlers)
+        self._malicious_patterns = [
+            re.compile(r"<script.*?>.*?</script>", re.IGNORECASE),
+            re.compile(r"javascript:", re.IGNORECASE),
+            re.compile(r"data:text/html", re.IGNORECASE),
+            re.compile(r"vbscript:", re.IGNORECASE),
+            re.compile(r"on\w+\s*=", re.IGNORECASE),
+            re.compile(r"<iframe.*?>.*?</iframe>", re.IGNORECASE),
+            re.compile(r"<object.*?>.*?</object>", re.IGNORECASE),
+            re.compile(r"<embed.*?>", re.IGNORECASE),
+        ]
+
+        # API key patterns
+        self._api_key_patterns = [
+            re.compile(r"sk-[a-zA-Z0-9]{32,}"),  # OpenAI-style secret key
+            re.compile(r"pk_[a-zA-Z0-9]{32,}"),  # Publishable API key prefix
+            re.compile(r"[a-zA-Z0-9]{32,}"),  # Generic API key
+        ]
+
+        filter_logger.info("OutputFilter initialized")
+
+    def filter_output(self, output_data: Union[str, Dict, List], output_type: str = "text") -> Tuple[bool, Union[str, Dict, List], Optional[str]]:
+        """
+        Filter and sanitize output data
+
+        Args:
+            output_data: Output data to filter
+            output_type: Type of output (text, json, dict, etc.)
+ + Returns: + Tuple of (is_safe, filtered_data, warning_message) + """ + if not self.enabled: + return True, output_data, None + + try: + # Convert to string for processing + if isinstance(output_data, (dict, list)): + output_str = json.dumps(output_data, ensure_ascii=False) + else: + output_str = str(output_data) + + # Check for malicious content + if self._check_malicious_content(output_str): + return False, "", "Output contains potentially malicious content" + + # Filter sensitive data + if self.filter_sensitive: + filtered_str = self._filter_sensitive_data(output_str) + else: + filtered_str = output_str + + # Convert back to original type if needed + if isinstance(output_data, (dict, list)) and output_type in ["json", "dict"]: + try: + filtered_data = json.loads(filtered_str) + except json.JSONDecodeError: + filtered_data = filtered_str + else: + filtered_data = filtered_str + + # Check if any sensitive data was filtered + warning = None + if filtered_str != output_str: + warning = "Sensitive data was filtered from output" + + filter_logger.debug(f"Output filtering completed for type: {output_type}") + return True, filtered_data, warning + + except Exception as e: + filter_logger.error(f"Output filtering error: {e}") + return False, "", f"Filtering error: {str(e)}" + + def _check_malicious_content(self, content: str) -> bool: + """Check if content contains malicious patterns""" + for pattern in self._malicious_patterns: + if pattern.search(content): + filter_logger.warning(f"Malicious content detected: {pattern.pattern}") + return True + return False + + def _filter_sensitive_data(self, content: str) -> str: + """Filter sensitive data from content""" + filtered_content = content + + # Filter custom sensitive patterns + for pattern in self._compiled_patterns: + filtered_content = pattern.sub("[SENSITIVE_DATA]", filtered_content) + + # Filter default sensitive patterns + for pattern in self._default_patterns: + filtered_content = pattern.sub("[SENSITIVE_DATA]", filtered_content) + + # Filter API keys + for pattern in self._api_key_patterns: + filtered_content = pattern.sub("[API_KEY]", filtered_content) + + return filtered_content + + def filter_agent_response(self, response: str, agent_name: str) -> Tuple[bool, str, Optional[str]]: + """Filter agent response output""" + return self.filter_output(response, "text") + + def filter_swarm_output(self, output: Union[str, Dict, List]) -> Tuple[bool, Union[str, Dict, List], Optional[str]]: + """Filter swarm output""" + return self.filter_output(output, "json") + + def filter_conversation_history(self, history: List[Dict]) -> Tuple[bool, List[Dict], Optional[str]]: + """Filter conversation history""" + try: + filtered_history = [] + warnings = [] + + for message in history: + # Filter message content + is_safe, filtered_content, warning = self.filter_output( + message.get("content", ""), "text" + ) + + if not is_safe: + return False, [], "Conversation history contains unsafe content" + + # Create filtered message + filtered_message = message.copy() + filtered_message["content"] = filtered_content + + if warning: + warnings.append(warning) + + filtered_history.append(filtered_message) + + warning_msg = "; ".join(set(warnings)) if warnings else None + return True, filtered_history, warning_msg + + except Exception as e: + filter_logger.error(f"Conversation history filtering error: {e}") + return False, [], f"History filtering error: {str(e)}" + + def filter_config_output(self, config: Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Optional[str]]: 
+ """Filter configuration output""" + try: + # Create a copy to avoid modifying original + filtered_config = config.copy() + + # Filter sensitive config fields + sensitive_fields = [ + "api_key", "secret", "password", "token", "key", + "credential", "auth", "private", "secret_key" + ] + + warnings = [] + for field in sensitive_fields: + if field in filtered_config: + if isinstance(filtered_config[field], str): + filtered_config[field] = "[SENSITIVE_CONFIG]" + warnings.append(f"Sensitive config field '{field}' was filtered") + + warning_msg = "; ".join(warnings) if warnings else None + return True, filtered_config, warning_msg + + except Exception as e: + filter_logger.error(f"Config filtering error: {e}") + return False, {}, f"Config filtering error: {str(e)}" + + def sanitize_for_logging(self, data: Union[str, Dict, List]) -> str: + """Sanitize data for logging purposes""" + try: + if isinstance(data, (dict, list)): + data_str = json.dumps(data, ensure_ascii=False) + else: + data_str = str(data) + + # Apply aggressive filtering for logs + sanitized = self._filter_sensitive_data(data_str) + + # Truncate if too long + if len(sanitized) > 1000: + sanitized = sanitized[:1000] + "... [TRUNCATED]" + + return sanitized + + except Exception as e: + filter_logger.error(f"Log sanitization error: {e}") + return "[SANITIZATION_ERROR]" + + def add_custom_pattern(self, pattern: str, description: str = "") -> None: + """Add custom sensitive data pattern""" + try: + compiled_pattern = re.compile(pattern, re.IGNORECASE) + self._compiled_patterns.append(compiled_pattern) + self.sensitive_patterns.append(pattern) + + filter_logger.info(f"Added custom pattern: {pattern} ({description})") + + except re.error as e: + filter_logger.error(f"Invalid regex pattern: {pattern} - {e}") + + def remove_pattern(self, pattern: str) -> bool: + """Remove sensitive data pattern""" + try: + if pattern in self.sensitive_patterns: + self.sensitive_patterns.remove(pattern) + + # Recompile patterns + self._compiled_patterns = [ + re.compile(p, re.IGNORECASE) for p in self.sensitive_patterns + ] + + filter_logger.info(f"Removed pattern: {pattern}") + return True + return False + + except Exception as e: + filter_logger.error(f"Error removing pattern: {e}") + return False + + def get_filter_stats(self) -> Dict[str, Any]: + """Get filtering statistics""" + return { + "enabled": self.enabled, + "filter_sensitive": self.filter_sensitive, + "sensitive_patterns_count": len(self.sensitive_patterns), + "malicious_patterns_count": len(self._malicious_patterns), + "api_key_patterns_count": len(self._api_key_patterns), + "default_patterns_count": len(self._default_patterns), + } \ No newline at end of file diff --git a/swarms/security/rate_limiter.py b/swarms/security/rate_limiter.py new file mode 100644 index 00000000..2806332e --- /dev/null +++ b/swarms/security/rate_limiter.py @@ -0,0 +1,323 @@ +""" +Rate limiting and abuse prevention for Swarms framework. + +This module provides comprehensive rate limiting, request tracking, +and abuse prevention for all swarm operations. 
+""" + +import time +import threading +from typing import Dict, List, Optional, Tuple, Any +from collections import defaultdict, deque +from datetime import datetime, timedelta + +from loguru import logger + +from swarms.utils.loguru_logger import initialize_logger + +# Initialize logger for rate limiting +rate_logger = initialize_logger(log_folder="rate_limiting") + + +class RateLimiter: + """ + Rate limiting and abuse prevention for swarm security + + Features: + - Per-agent rate limiting + - Token-based limiting + - Request tracking + - Abuse detection + - Automatic blocking + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize rate limiter with configuration + + Args: + config: Rate limiting configuration dictionary + """ + self.enabled = config.get("enabled", True) + self.max_requests_per_minute = config.get("max_requests_per_minute", 60) + self.max_tokens_per_request = config.get("max_tokens_per_request", 10000) + self.window = config.get("window", 60) # seconds + + # Request tracking + self._request_history: Dict[str, deque] = defaultdict(lambda: deque()) + self._token_usage: Dict[str, int] = defaultdict(int) + self._blocked_agents: Dict[str, float] = {} + + # Thread safety + self._lock = threading.Lock() + + # Cleanup thread + self._cleanup_thread = None + self._stop_cleanup = False + + if self.enabled: + self._start_cleanup_thread() + + rate_logger.info(f"RateLimiter initialized: {self.max_requests_per_minute} req/min, {self.max_tokens_per_request} tokens/req") + + def check_rate_limit(self, agent_name: str, request_size: int = 1) -> Tuple[bool, Optional[str]]: + """ + Check if request is within rate limits + + Args: + agent_name: Name of the agent making the request + request_size: Size of the request (tokens or complexity) + + Returns: + Tuple of (is_allowed, error_message) + """ + if not self.enabled: + return True, None + + try: + with self._lock: + current_time = time.time() + + # Check if agent is blocked + if agent_name in self._blocked_agents: + block_until = self._blocked_agents[agent_name] + if current_time < block_until: + remaining = block_until - current_time + return False, f"Agent {agent_name} is blocked for {remaining:.1f} more seconds" + else: + # Unblock agent + del self._blocked_agents[agent_name] + + # Check token limit + if request_size > self.max_tokens_per_request: + return False, f"Request size {request_size} exceeds token limit {self.max_tokens_per_request}" + + # Get agent's request history + history = self._request_history[agent_name] + + # Remove old requests outside the window + cutoff_time = current_time - self.window + while history and history[0] < cutoff_time: + history.popleft() + + # Check request count limit + if len(history) >= self.max_requests_per_minute: + # Block agent temporarily + block_duration = min(300, self.window * 2) # Max 5 minutes + self._blocked_agents[agent_name] = current_time + block_duration + + rate_logger.warning(f"Agent {agent_name} rate limit exceeded, blocked for {block_duration}s") + return False, f"Rate limit exceeded. 
Agent blocked for {block_duration} seconds" + + # Add current request + history.append(current_time) + + # Update token usage + self._token_usage[agent_name] += request_size + + rate_logger.debug(f"Rate limit check passed for {agent_name}") + return True, None + + except Exception as e: + rate_logger.error(f"Rate limit check error: {e}") + return False, f"Rate limit check error: {str(e)}" + + def check_agent_limit(self, agent_name: str) -> Tuple[bool, Optional[str]]: + """Check agent-specific rate limits""" + return self.check_rate_limit(agent_name, 1) + + def check_token_limit(self, agent_name: str, token_count: int) -> Tuple[bool, Optional[str]]: + """Check token-based rate limits""" + return self.check_rate_limit(agent_name, token_count) + + def track_request(self, agent_name: str, request_type: str = "general", metadata: Dict[str, Any] = None) -> None: + """ + Track a request for monitoring purposes + + Args: + agent_name: Name of the agent + request_type: Type of request + metadata: Additional request metadata + """ + if not self.enabled: + return + + try: + with self._lock: + current_time = time.time() + + # Store request metadata + request_info = { + "timestamp": current_time, + "type": request_type, + "metadata": metadata or {} + } + + # Add to history (we only store timestamps for rate limiting) + self._request_history[agent_name].append(current_time) + + rate_logger.debug(f"Tracked request: {agent_name} - {request_type}") + + except Exception as e: + rate_logger.error(f"Request tracking error: {e}") + + def get_agent_stats(self, agent_name: str) -> Dict[str, Any]: + """Get rate limiting statistics for an agent""" + try: + with self._lock: + current_time = time.time() + history = self._request_history[agent_name] + + # Clean old requests + cutoff_time = current_time - self.window + while history and history[0] < cutoff_time: + history.popleft() + + # Calculate statistics + recent_requests = len(history) + total_tokens = self._token_usage.get(agent_name, 0) + is_blocked = agent_name in self._blocked_agents + block_remaining = 0 + + if is_blocked: + block_remaining = self._blocked_agents[agent_name] - current_time + + return { + "agent_name": agent_name, + "recent_requests": recent_requests, + "max_requests": self.max_requests_per_minute, + "total_tokens": total_tokens, + "max_tokens_per_request": self.max_tokens_per_request, + "is_blocked": is_blocked, + "block_remaining": max(0, block_remaining), + "window_seconds": self.window, + } + + except Exception as e: + rate_logger.error(f"Error getting agent stats: {e}") + return {} + + def get_all_stats(self) -> Dict[str, Any]: + """Get rate limiting statistics for all agents""" + try: + with self._lock: + stats = { + "enabled": self.enabled, + "max_requests_per_minute": self.max_requests_per_minute, + "max_tokens_per_request": self.max_tokens_per_request, + "window_seconds": self.window, + "total_agents": len(self._request_history), + "blocked_agents": len(self._blocked_agents), + "agents": {} + } + + for agent_name in self._request_history: + stats["agents"][agent_name] = self.get_agent_stats(agent_name) + + return stats + + except Exception as e: + rate_logger.error(f"Error getting all stats: {e}") + return {} + + def reset_agent(self, agent_name: str) -> bool: + """Reset rate limiting for an agent""" + try: + with self._lock: + if agent_name in self._request_history: + self._request_history[agent_name].clear() + + if agent_name in self._token_usage: + self._token_usage[agent_name] = 0 + + if agent_name in self._blocked_agents: + del 
self._blocked_agents[agent_name] + + rate_logger.info(f"Reset rate limiting for agent: {agent_name}") + return True + + except Exception as e: + rate_logger.error(f"Error resetting agent: {e}") + return False + + def block_agent(self, agent_name: str, duration: int = 300) -> bool: + """Manually block an agent""" + try: + with self._lock: + current_time = time.time() + self._blocked_agents[agent_name] = current_time + duration + + rate_logger.warning(f"Manually blocked agent {agent_name} for {duration}s") + return True + + except Exception as e: + rate_logger.error(f"Error blocking agent: {e}") + return False + + def unblock_agent(self, agent_name: str) -> bool: + """Unblock an agent""" + try: + with self._lock: + if agent_name in self._blocked_agents: + del self._blocked_agents[agent_name] + rate_logger.info(f"Unblocked agent: {agent_name}") + return True + return False + + except Exception as e: + rate_logger.error(f"Error unblocking agent: {e}") + return False + + def _start_cleanup_thread(self) -> None: + """Start background cleanup thread""" + def cleanup_worker(): + while not self._stop_cleanup: + try: + time.sleep(60) # Cleanup every minute + self._cleanup_old_data() + except Exception as e: + rate_logger.error(f"Cleanup thread error: {e}") + + self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) + self._cleanup_thread.start() + + def _cleanup_old_data(self) -> None: + """Clean up old rate limiting data""" + try: + with self._lock: + current_time = time.time() + cutoff_time = current_time - (self.window * 2) # Keep 2x window + + # Clean request history + for agent_name in list(self._request_history.keys()): + history = self._request_history[agent_name] + while history and history[0] < cutoff_time: + history.popleft() + + # Remove empty histories + if not history: + del self._request_history[agent_name] + + # Clean blocked agents + for agent_name in list(self._blocked_agents.keys()): + if self._blocked_agents[agent_name] < current_time: + del self._blocked_agents[agent_name] + + # Reset token usage periodically + if current_time % 3600 < 60: # Reset every hour + self._token_usage.clear() + + except Exception as e: + rate_logger.error(f"Cleanup error: {e}") + + def stop(self) -> None: + """Stop the rate limiter and cleanup thread""" + self._stop_cleanup = True + if self._cleanup_thread: + self._cleanup_thread.join(timeout=5) + + rate_logger.info("RateLimiter stopped") + + def __del__(self): + """Cleanup on destruction""" + self.stop() \ No newline at end of file diff --git a/swarms/security/safety_checker.py b/swarms/security/safety_checker.py new file mode 100644 index 00000000..d31fd6e2 --- /dev/null +++ b/swarms/security/safety_checker.py @@ -0,0 +1,249 @@ +""" +Safety checking and content filtering for Swarms framework. + +This module provides comprehensive safety checks, content filtering, +and ethical AI compliance for all swarm operations. 
+""" + +import re +from typing import List, Optional, Dict, Any, Tuple +from enum import Enum + +from loguru import logger + +from swarms.utils.loguru_logger import initialize_logger +from swarms.prompts.safety_prompt import SAFETY_PROMPT + +# Initialize logger for safety checking +safety_logger = initialize_logger(log_folder="safety_checking") + + +class ContentLevel(Enum): + """Content filtering levels""" + + LOW = "low" # Minimal filtering + MODERATE = "moderate" # Standard filtering + HIGH = "high" # Aggressive filtering + + +class SafetyChecker: + """ + Safety checking and content filtering for swarm security + + Features: + - Content safety assessment + - Ethical AI compliance + - Harmful content detection + - Bias detection + - Safety prompt integration + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize safety checker with configuration + + Args: + config: Safety configuration dictionary + """ + self.enabled = config.get("enabled", True) + self.safety_prompt = config.get("safety_prompt", True) + self.filter_level = ContentLevel(config.get("filter_level", "moderate")) + + # Harmful content patterns + self._harmful_patterns = [ + re.compile(r"\b(kill|murder|suicide|bomb|explosive|weapon)\b", re.IGNORECASE), + re.compile(r"\b(hack|crack|steal|fraud|scam|phishing)\b", re.IGNORECASE), + re.compile(r"\b(drug|illegal|contraband|smuggle)\b", re.IGNORECASE), + re.compile(r"\b(hate|racist|sexist|discriminate)\b", re.IGNORECASE), + re.compile(r"\b(terrorist|extremist|radical)\b", re.IGNORECASE), + ] + + # Bias detection patterns + self._bias_patterns = [ + re.compile(r"\b(all|every|always|never|none)\b", re.IGNORECASE), + re.compile(r"\b(men are|women are|blacks are|whites are)\b", re.IGNORECASE), + re.compile(r"\b(stereotypical|typical|usual)\b", re.IGNORECASE), + ] + + # Age-inappropriate content + self._age_inappropriate = [ + re.compile(r"\b(sex|porn|adult|explicit)\b", re.IGNORECASE), + re.compile(r"\b(violence|gore|blood|death)\b", re.IGNORECASE), + ] + + # Handle filter_level safely + filter_level_str = ( + self.filter_level.value + if hasattr(self.filter_level, 'value') + else str(self.filter_level) + ) + safety_logger.info(f"SafetyChecker initialized with level: {filter_level_str}") + + def check_safety(self, content: str, content_type: str = "text") -> Tuple[bool, str, Optional[str]]: + """ + Check content for safety and ethical compliance + + Args: + content: Content to check + content_type: Type of content (text, code, config, etc.) 
+ + Returns: + Tuple of (is_safe, sanitized_content, warning_message) + """ + if not self.enabled: + return True, content, None + + try: + # Basic type validation + if not isinstance(content, str): + return False, "", "Content must be a string" + + # Check for harmful content + if self._check_harmful_content(content): + return False, "", "Content contains potentially harmful material" + + # Check for bias + if self._check_bias(content): + return False, "", "Content contains potentially biased language" + + # Check age appropriateness + if self._check_age_appropriate(content): + return False, "", "Content may not be age-appropriate" + + # Apply content filtering based on level + sanitized = self._filter_content(content) + + # Check if content was modified + warning = None + if sanitized != content: + # Handle filter_level safely + filter_level_str = ( + self.filter_level.value + if hasattr(self.filter_level, 'value') + else str(self.filter_level) + ) + warning = f"Content was filtered for {filter_level_str} safety level" + + safety_logger.debug(f"Safety check passed for type: {content_type}") + return True, sanitized, warning + + except Exception as e: + safety_logger.error(f"Safety check error: {e}") + return False, "", f"Safety check error: {str(e)}" + + def _check_harmful_content(self, content: str) -> bool: + """Check for harmful content patterns""" + for pattern in self._harmful_patterns: + if pattern.search(content): + safety_logger.warning(f"Harmful content detected: {pattern.pattern}") + return True + return False + + def _check_bias(self, content: str) -> bool: + """Check for biased language patterns""" + for pattern in self._bias_patterns: + if pattern.search(content): + safety_logger.warning(f"Bias detected: {pattern.pattern}") + return True + return False + + def _check_age_appropriate(self, content: str) -> bool: + """Check for age-inappropriate content""" + for pattern in self._age_inappropriate: + if pattern.search(content): + safety_logger.warning(f"Age-inappropriate content detected: {pattern.pattern}") + return True + return False + + def _filter_content(self, content: str) -> str: + """Filter content based on safety level""" + filtered_content = content + + if self.filter_level == ContentLevel.HIGH: + # Aggressive filtering + for pattern in self._harmful_patterns + self._bias_patterns: + filtered_content = pattern.sub("[FILTERED]", filtered_content) + + elif self.filter_level == ContentLevel.MODERATE: + # Moderate filtering - only filter obvious harmful content + for pattern in self._harmful_patterns: + filtered_content = pattern.sub("[FILTERED]", filtered_content) + + # LOW level doesn't filter content, only detects + + return filtered_content + + def check_agent_safety(self, agent_name: str, system_prompt: str) -> Tuple[bool, str, Optional[str]]: + """Check agent system prompt for safety""" + return self.check_safety(system_prompt, "agent_prompt") + + def check_task_safety(self, task: str) -> Tuple[bool, str, Optional[str]]: + """Check task description for safety""" + return self.check_safety(task, "task") + + def check_response_safety(self, response: str, agent_name: str) -> Tuple[bool, str, Optional[str]]: + """Check agent response for safety""" + return self.check_safety(response, "response") + + def check_config_safety(self, config: Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Optional[str]]: + """Check configuration for safety""" + try: + import json + config_str = json.dumps(config) + + is_safe, sanitized, error = self.check_safety(config_str, "config") + if not 
is_safe: + return False, {}, error + + # Parse back to dict + safe_config = json.loads(sanitized) + return True, safe_config, None + + except Exception as e: + return False, {}, f"Config safety check error: {str(e)}" + + def get_safety_prompt(self) -> str: + """Get safety prompt for integration""" + if self.safety_prompt: + return SAFETY_PROMPT + return "" + + def add_harmful_pattern(self, pattern: str, description: str = "") -> None: + """Add custom harmful content pattern""" + try: + compiled_pattern = re.compile(pattern, re.IGNORECASE) + self._harmful_patterns.append(compiled_pattern) + + safety_logger.info(f"Added harmful pattern: {pattern} ({description})") + + except re.error as e: + safety_logger.error(f"Invalid regex pattern: {pattern} - {e}") + + def add_bias_pattern(self, pattern: str, description: str = "") -> None: + """Add custom bias detection pattern""" + try: + compiled_pattern = re.compile(pattern, re.IGNORECASE) + self._bias_patterns.append(compiled_pattern) + + safety_logger.info(f"Added bias pattern: {pattern} ({description})") + + except re.error as e: + safety_logger.error(f"Invalid regex pattern: {pattern} - {e}") + + def get_safety_stats(self) -> Dict[str, Any]: + """Get safety checking statistics""" + # Handle filter_level safely + filter_level_str = ( + self.filter_level.value + if hasattr(self.filter_level, 'value') + else str(self.filter_level) + ) + + return { + "enabled": self.enabled, + "safety_prompt": self.safety_prompt, + "filter_level": filter_level_str, + "harmful_patterns_count": len(self._harmful_patterns), + "bias_patterns_count": len(self._bias_patterns), + "age_inappropriate_patterns_count": len(self._age_inappropriate), + } \ No newline at end of file diff --git a/swarms/security/shield_config.py b/swarms/security/shield_config.py new file mode 100644 index 00000000..2d90513d --- /dev/null +++ b/swarms/security/shield_config.py @@ -0,0 +1,362 @@ +""" +Shield configuration for Swarms framework. + +This module provides configuration options for the security shield, +allowing users to customize security settings for their swarm deployments. +""" + +from typing import List, Optional, Dict, Any +from pydantic import BaseModel, Field +from enum import Enum + +from .swarm_shield import EncryptionStrength + + +class SecurityLevel(Enum): + """Security levels for shield configuration""" + + BASIC = "basic" # Basic input validation and output filtering + STANDARD = "standard" # Standard security with encryption + ENHANCED = "enhanced" # Enhanced security with additional checks + MAXIMUM = "maximum" # Maximum security with all features enabled + + +class ShieldConfig(BaseModel): + """ + Configuration for SwarmShield security features + + This class provides comprehensive configuration options for + enabling and customizing security features across all swarm architectures. 
+    """
+
+    # Core security settings
+    enabled: bool = Field(default=True, description="Enable shield protection")
+    security_level: SecurityLevel = Field(
+        default=SecurityLevel.STANDARD,
+        description="Overall security level"
+    )
+
+    # SwarmShield settings
+    encryption_strength: EncryptionStrength = Field(
+        default=EncryptionStrength.MAXIMUM,
+        description="Encryption strength for message protection"
+    )
+    key_rotation_interval: int = Field(
+        default=3600,
+        ge=300,  # Minimum 5 minutes
+        description="Key rotation interval in seconds"
+    )
+    storage_path: Optional[str] = Field(
+        default=None,
+        description="Path for encrypted storage"
+    )
+
+    # Input validation settings
+    enable_input_validation: bool = Field(
+        default=True,
+        description="Enable input validation and sanitization"
+    )
+    max_input_length: int = Field(
+        default=10000,
+        ge=100,
+        description="Maximum input length in characters"
+    )
+    blocked_patterns: List[str] = Field(
+        default=[
+            r"<script.*?>.*?</script>",  # XSS prevention
+            r"javascript:",  # JavaScript injection
+            r"data:text/html",  # HTML injection
+            r"vbscript:",  # VBScript injection
+            r"on\w+\s*=",  # Event handler injection
+        ],
+        description="Regex patterns to block in inputs"
+    )
+    allowed_domains: List[str] = Field(
+        default=[],
+        description="Allowed domains for external requests"
+    )
+
+    # Output filtering settings
+    enable_output_filtering: bool = Field(
+        default=True,
+        description="Enable output filtering and sanitization"
+    )
+    filter_sensitive_data: bool = Field(
+        default=True,
+        description="Filter sensitive data from outputs"
+    )
+    sensitive_patterns: List[str] = Field(
+        default=[
+            r"\b\d{3}-\d{2}-\d{4}\b",  # SSN
+            r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b",  # Credit card
+            r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",  # Email
+            r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",  # IP address
+        ],
+        description="Regex patterns for sensitive data"
+    )
+
+    # Safety checking settings
+    enable_safety_checks: bool = Field(
+        default=True,
+        description="Enable safety and content filtering"
+    )
+    safety_prompt_on: bool = Field(
+        default=True,
+        description="Enable safety prompt integration"
+    )
+    content_filter_level: str = Field(
+        default="moderate",
+        description="Content filtering level (low, moderate, high)"
+    )
+
+    # Rate limiting settings
+    enable_rate_limiting: bool = Field(
+        default=True,
+        description="Enable rate limiting and abuse prevention"
+    )
+    max_requests_per_minute: int = Field(
+        default=60,
+        ge=1,
+        description="Maximum requests per minute per agent"
+    )
+    max_tokens_per_request: int = Field(
+        default=10000,
+        ge=100,
+        description="Maximum tokens per request"
+    )
+    rate_limit_window: int = Field(
+        default=60,
+        ge=10,
+        description="Rate limit window in seconds"
+    )
+
+    # Audit and logging settings
+    enable_audit_logging: bool = Field(
+        default=True,
+        description="Enable comprehensive audit logging"
+    )
+    log_security_events: bool = Field(
+        default=True,
+        description="Log security-related events"
+    )
+    log_input_output: bool = Field(
+        default=False,
+        description="Log input and output data (use with caution)"
+    )
+    audit_retention_days: int = Field(
+        default=90,
+        ge=1,
+        description="Audit log retention period in days"
+    )
+
+    # Performance settings
+    enable_caching: bool = Field(
+        default=True,
+        description="Enable security result caching"
+    )
+    cache_ttl: int = Field(
+        default=300,
+        ge=60,
+        description="Cache TTL in seconds"
+    )
+    max_cache_size: int = Field(
+        default=1000,
+        ge=100,
+        description="Maximum cache entries"
+    )
+ + # Integration settings + integrate_with_conversation: bool = Field( + default=True, + description="Integrate with conversation management" + ) + protect_agent_communications: bool = Field( + default=True, + description="Protect inter-agent communications" + ) + encrypt_storage: bool = Field( + default=True, + description="Encrypt persistent storage" + ) + + # Custom settings + custom_rules: Dict[str, Any] = Field( + default={}, + description="Custom security rules and configurations" + ) + + # Compatibility fields for examples + encryption_enabled: bool = Field( + default=True, + description="Enable encryption (alias for enabled)" + ) + input_validation_enabled: bool = Field( + default=True, + description="Enable input validation (alias for enable_input_validation)" + ) + output_filtering_enabled: bool = Field( + default=True, + description="Enable output filtering (alias for enable_output_filtering)" + ) + rate_limiting_enabled: bool = Field( + default=True, + description="Enable rate limiting (alias for enable_rate_limiting)" + ) + safety_checking_enabled: bool = Field( + default=True, + description="Enable safety checking (alias for enable_safety_checks)" + ) + block_suspicious_content: bool = Field( + default=True, + description="Block suspicious content patterns" + ) + custom_blocked_patterns: List[str] = Field( + default=[], + description="Custom patterns to block in addition to default ones" + ) + safety_threshold: float = Field( + default=0.7, + ge=0.0, + le=1.0, + description="Safety threshold for content filtering" + ) + bias_detection_enabled: bool = Field( + default=False, + description="Enable bias detection in content" + ) + content_moderation_enabled: bool = Field( + default=True, + description="Enable content moderation" + ) + + class Config: + """Pydantic configuration""" + use_enum_values = True + validate_assignment = True + + def get_encryption_config(self) -> Dict[str, Any]: + """Get encryption configuration for SwarmShield""" + return { + "encryption_strength": self.encryption_strength, + "key_rotation_interval": self.key_rotation_interval, + "storage_path": self.storage_path, + } + + def get_validation_config(self) -> Dict[str, Any]: + """Get input validation configuration""" + return { + "enabled": self.enable_input_validation, + "max_length": self.max_input_length, + "blocked_patterns": self.blocked_patterns, + "allowed_domains": self.allowed_domains, + } + + def get_filtering_config(self) -> Dict[str, Any]: + """Get output filtering configuration""" + return { + "enabled": self.enable_output_filtering, + "filter_sensitive": self.filter_sensitive_data, + "sensitive_patterns": self.sensitive_patterns, + } + + def get_safety_config(self) -> Dict[str, Any]: + """Get safety checking configuration""" + return { + "enabled": self.enable_safety_checks, + "safety_prompt": self.safety_prompt_on, + "filter_level": self.content_filter_level, + } + + def get_rate_limit_config(self) -> Dict[str, Any]: + """Get rate limiting configuration""" + return { + "enabled": self.enable_rate_limiting, + "max_requests_per_minute": self.max_requests_per_minute, + "max_tokens_per_request": self.max_tokens_per_request, + "window": self.rate_limit_window, + } + + def get_audit_config(self) -> Dict[str, Any]: + """Get audit logging configuration""" + return { + "enabled": self.enable_audit_logging, + "log_security": self.log_security_events, + "log_io": self.log_input_output, + "retention_days": self.audit_retention_days, + } + + def get_performance_config(self) -> Dict[str, Any]: + """Get 
performance configuration""" + return { + "enable_caching": self.enable_caching, + "cache_ttl": self.cache_ttl, + "max_cache_size": self.max_cache_size, + } + + def get_integration_config(self) -> Dict[str, Any]: + """Get integration configuration""" + return { + "conversation": self.integrate_with_conversation, + "agent_communications": self.protect_agent_communications, + "encrypt_storage": self.encrypt_storage, + } + + @classmethod + def create_basic_config(cls) -> "ShieldConfig": + """Create a basic security configuration""" + return cls( + security_level=SecurityLevel.BASIC, + encryption_strength=EncryptionStrength.STANDARD, + enable_safety_checks=False, + enable_rate_limiting=False, + enable_audit_logging=False, + ) + + @classmethod + def create_standard_config(cls) -> "ShieldConfig": + """Create a standard security configuration""" + return cls( + security_level=SecurityLevel.STANDARD, + encryption_strength=EncryptionStrength.ENHANCED, + ) + + @classmethod + def create_enhanced_config(cls) -> "ShieldConfig": + """Create an enhanced security configuration""" + return cls( + security_level=SecurityLevel.ENHANCED, + encryption_strength=EncryptionStrength.MAXIMUM, + max_requests_per_minute=30, + content_filter_level="high", + log_input_output=True, + ) + + @classmethod + def create_maximum_config(cls) -> "ShieldConfig": + """Create a maximum security configuration""" + return cls( + security_level=SecurityLevel.MAXIMUM, + encryption_strength=EncryptionStrength.MAXIMUM, + key_rotation_interval=1800, # 30 minutes + max_requests_per_minute=20, + content_filter_level="high", + log_input_output=True, + audit_retention_days=365, + max_cache_size=500, + ) + + @classmethod + def get_security_level(cls, level: str) -> "ShieldConfig": + """Get a security configuration for the specified level""" + level = level.lower() + + if level == "basic": + return cls.create_basic_config() + elif level == "standard": + return cls.create_standard_config() + elif level == "enhanced": + return cls.create_enhanced_config() + elif level == "maximum": + return cls.create_maximum_config() + else: + raise ValueError(f"Unknown security level: {level}. Must be one of: basic, standard, enhanced, maximum") \ No newline at end of file diff --git a/swarms/security/swarm_shield.py b/swarms/security/swarm_shield.py new file mode 100644 index 00000000..5478fafa --- /dev/null +++ b/swarms/security/swarm_shield.py @@ -0,0 +1,737 @@ +""" +SwarmShield integration for Swarms framework. + +This module provides enterprise-grade security for swarm communications, +including encryption, conversation management, and audit capabilities. 
+""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import json +import os +import secrets +import threading +import time +import uuid +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.ciphers import ( + Cipher, + algorithms, + modes, +) +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from loguru import logger + +from swarms.utils.loguru_logger import initialize_logger + +# Initialize logger for security module +security_logger = initialize_logger(log_folder="security") + + +class EncryptionStrength(Enum): + """Encryption strength levels for SwarmShield""" + + STANDARD = "standard" # AES-256 + ENHANCED = "enhanced" # AES-256 + SHA-512 + MAXIMUM = "maximum" # AES-256 + SHA-512 + HMAC + + +class SwarmShield: + """ + SwarmShield: Advanced security system for swarm agents + + Features: + - Multi-layer message encryption + - Secure conversation storage + - Automatic key rotation + - Message integrity verification + - Integration with Swarms framework + """ + + def __init__( + self, + encryption_strength: EncryptionStrength = EncryptionStrength.MAXIMUM, + key_rotation_interval: int = 3600, # 1 hour + storage_path: Optional[str] = None, + enable_logging: bool = True, + ): + """ + Initialize SwarmShield with security settings + + Args: + encryption_strength: Level of encryption to use + key_rotation_interval: Key rotation interval in seconds + storage_path: Path for encrypted storage + enable_logging: Enable security logging + """ + self.encryption_strength = encryption_strength + self.key_rotation_interval = key_rotation_interval + self.enable_logging = enable_logging + + # Set storage path within swarms framework + if storage_path: + self.storage_path = Path(storage_path) + else: + self.storage_path = Path("swarms_security_storage") + + # Initialize storage and locks + self.storage_path.mkdir(parents=True, exist_ok=True) + self._conv_lock = threading.Lock() + self._conversations: Dict[str, List[Dict]] = {} + + # Initialize security components + self._initialize_security() + self._load_conversations() + + if self.enable_logging: + # Handle encryption_strength safely (could be enum or string) + encryption_str = ( + encryption_strength.value + if hasattr(encryption_strength, 'value') + else str(encryption_strength) + ) + security_logger.info( + f"SwarmShield initialized with {encryption_str} encryption" + ) + + def _initialize_security(self) -> None: + """Set up encryption keys and components""" + try: + # Generate master key and salt + self.master_key = secrets.token_bytes(32) + self.salt = os.urandom(16) + + # Initialize key derivation + self.kdf = PBKDF2HMAC( + algorithm=hashes.SHA512(), + length=32, + salt=self.salt, + iterations=600000, + backend=default_backend(), + ) + + # Generate initial keys + self._rotate_keys() + self.hmac_key = secrets.token_bytes(32) + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Security initialization failed: {e}") + raise + + def _rotate_keys(self) -> None: + """Perform security key rotation""" + try: + self.encryption_key = self.kdf.derive(self.master_key) + self.iv = os.urandom(16) + self.last_rotation = time.time() + if self.enable_logging: + security_logger.debug("Security keys rotated successfully") + except Exception as e: + if 
self.enable_logging: + security_logger.error(f"Key rotation failed: {e}") + raise + + def _check_rotation(self) -> None: + """Check and perform key rotation if needed""" + if ( + time.time() - self.last_rotation + >= self.key_rotation_interval + ): + self._rotate_keys() + + def _save_conversation(self, conversation_id: str) -> None: + """Save conversation to encrypted storage""" + try: + if conversation_id not in self._conversations: + return + + # Encrypt conversation data + json_data = json.dumps( + self._conversations[conversation_id] + ).encode() + cipher = Cipher( + algorithms.AES(self.encryption_key), + modes.GCM(self.iv), + backend=default_backend(), + ) + encryptor = cipher.encryptor() + encrypted_data = ( + encryptor.update(json_data) + encryptor.finalize() + ) + + # Combine encrypted data with authentication tag + combined_data = encrypted_data + encryptor.tag + + # Save atomically using temporary file + conv_path = self.storage_path / f"{conversation_id}.conv" + temp_path = conv_path.with_suffix(".tmp") + + with open(temp_path, "wb") as f: + f.write(combined_data) + temp_path.replace(conv_path) + + if self.enable_logging: + security_logger.debug(f"Saved conversation {conversation_id}") + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to save conversation: {e}") + raise + + def _load_conversations(self) -> None: + """Load existing conversations from storage""" + try: + for file_path in self.storage_path.glob("*.conv"): + try: + with open(file_path, "rb") as f: + combined_data = f.read() + conversation_id = file_path.stem + + # Split combined data into encrypted data and authentication tag + if len(combined_data) < 16: # Minimum size for GCM tag + continue + + encrypted_data = combined_data[:-16] + auth_tag = combined_data[-16:] + + # Decrypt conversation data + cipher = Cipher( + algorithms.AES(self.encryption_key), + modes.GCM(self.iv, auth_tag), + backend=default_backend(), + ) + decryptor = cipher.decryptor() + json_data = ( + decryptor.update(encrypted_data) + + decryptor.finalize() + ) + + self._conversations[conversation_id] = json.loads( + json_data + ) + if self.enable_logging: + security_logger.debug( + f"Loaded conversation {conversation_id}" + ) + + except Exception as e: + if self.enable_logging: + security_logger.error( + f"Failed to load conversation {file_path}: {e}" + ) + continue + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to load conversations: {e}") + raise + + def protect_message(self, agent_name: str, message: str) -> str: + """ + Encrypt a message with multiple security layers + + Args: + agent_name: Name of the sending agent + message: Message to encrypt + + Returns: + Encrypted message string + """ + try: + self._check_rotation() + + # Validate inputs + if not isinstance(agent_name, str) or not isinstance( + message, str + ): + raise ValueError( + "Agent name and message must be strings" + ) + if not agent_name.strip() or not message.strip(): + raise ValueError( + "Agent name and message cannot be empty" + ) + + # Generate message ID and timestamp + message_id = secrets.token_hex(16) + timestamp = datetime.now(timezone.utc).isoformat() + + # Encrypt message content + message_bytes = message.encode() + cipher = Cipher( + algorithms.AES(self.encryption_key), + modes.GCM(self.iv), + backend=default_backend(), + ) + encryptor = cipher.encryptor() + ciphertext = ( + encryptor.update(message_bytes) + encryptor.finalize() + ) + + # Calculate message hash + message_hash = 
hashlib.sha512(message_bytes).hexdigest() + + # Generate HMAC if maximum security + hmac_signature = None + if self.encryption_strength == EncryptionStrength.MAXIMUM: + h = hmac.new( + self.hmac_key, ciphertext, hashlib.sha512 + ) + hmac_signature = h.digest() + + # Create secure package + secure_package = { + "id": message_id, + "time": timestamp, + "agent": agent_name, + "cipher": base64.b64encode(ciphertext).decode(), + "tag": base64.b64encode(encryptor.tag).decode(), + "hash": message_hash, + "hmac": ( + base64.b64encode(hmac_signature).decode() + if hmac_signature + else None + ), + } + + return base64.b64encode( + json.dumps(secure_package).encode() + ).decode() + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to protect message: {e}") + raise + + def retrieve_message(self, encrypted_str: str) -> Tuple[str, str]: + """ + Decrypt and verify a message + + Args: + encrypted_str: Encrypted message string + + Returns: + Tuple of (agent_name, message) + """ + try: + # Decode secure package + secure_package = json.loads( + base64.b64decode(encrypted_str) + ) + + # Get components + ciphertext = base64.b64decode(secure_package["cipher"]) + tag = base64.b64decode(secure_package["tag"]) + + # Verify HMAC if present + if secure_package["hmac"]: + hmac_signature = base64.b64decode( + secure_package["hmac"] + ) + h = hmac.new( + self.hmac_key, ciphertext, hashlib.sha512 + ) + if not hmac.compare_digest( + hmac_signature, h.digest() + ): + raise ValueError("HMAC verification failed") + + # Decrypt message + cipher = Cipher( + algorithms.AES(self.encryption_key), + modes.GCM(self.iv, tag), + backend=default_backend(), + ) + decryptor = cipher.decryptor() + decrypted_data = ( + decryptor.update(ciphertext) + decryptor.finalize() + ) + + # Verify hash + if ( + hashlib.sha512(decrypted_data).hexdigest() + != secure_package["hash"] + ): + raise ValueError("Message hash verification failed") + + return secure_package["agent"], decrypted_data.decode() + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to retrieve message: {e}") + raise + + def create_conversation(self, name: str = "") -> str: + """Create a new secure conversation""" + conversation_id = str(uuid.uuid4()) + with self._conv_lock: + self._conversations[conversation_id] = { + "id": conversation_id, + "name": name, + "created_at": datetime.now(timezone.utc).isoformat(), + "messages": [], + } + self._save_conversation(conversation_id) + if self.enable_logging: + security_logger.info(f"Created conversation {conversation_id}") + return conversation_id + + def add_message( + self, conversation_id: str, agent_name: str, message: str + ) -> None: + """ + Add an encrypted message to a conversation + + Args: + conversation_id: Target conversation ID + agent_name: Name of the sending agent + message: Message content + """ + try: + # Encrypt message + encrypted = self.protect_message(agent_name, message) + + # Add to conversation + with self._conv_lock: + if conversation_id not in self._conversations: + raise ValueError( + f"Invalid conversation ID: {conversation_id}" + ) + + self._conversations[conversation_id][ + "messages" + ].append( + { + "timestamp": datetime.now( + timezone.utc + ).isoformat(), + "data": encrypted, + } + ) + + # Save changes + self._save_conversation(conversation_id) + + if self.enable_logging: + security_logger.info( + f"Added message to conversation {conversation_id}" + ) + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed 
to add message: {e}") + raise + + def get_messages( + self, conversation_id: str + ) -> List[Tuple[str, str, datetime]]: + """ + Get decrypted messages from a conversation + + Args: + conversation_id: Target conversation ID + + Returns: + List of (agent_name, message, timestamp) tuples + """ + try: + with self._conv_lock: + if conversation_id not in self._conversations: + raise ValueError( + f"Invalid conversation ID: {conversation_id}" + ) + + history = [] + for msg in self._conversations[conversation_id][ + "messages" + ]: + agent_name, message = self.retrieve_message( + msg["data"] + ) + timestamp = datetime.fromisoformat( + msg["timestamp"] + ) + history.append((agent_name, message, timestamp)) + + return history + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to get messages: {e}") + raise + + def get_conversation_summary(self, conversation_id: str) -> Dict: + """ + Get summary statistics for a conversation + + Args: + conversation_id: Target conversation ID + + Returns: + Dictionary with conversation statistics + """ + try: + with self._conv_lock: + if conversation_id not in self._conversations: + raise ValueError( + f"Invalid conversation ID: {conversation_id}" + ) + + conv = self._conversations[conversation_id] + messages = conv["messages"] + + # Get unique agents + agents = set() + for msg in messages: + agent_name, _ = self.retrieve_message(msg["data"]) + agents.add(agent_name) + + return { + "id": conversation_id, + "name": conv["name"], + "created_at": conv["created_at"], + "message_count": len(messages), + "agents": list(agents), + "last_message": ( + messages[-1]["timestamp"] if messages else None + ), + } + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to get summary: {e}") + raise + + def delete_conversation(self, conversation_id: str) -> None: + """ + Delete a conversation and its storage + + Args: + conversation_id: Target conversation ID + """ + try: + with self._conv_lock: + if conversation_id not in self._conversations: + raise ValueError( + f"Invalid conversation ID: {conversation_id}" + ) + + # Remove from memory + del self._conversations[conversation_id] + + # Remove from storage + conv_path = self.storage_path / f"{conversation_id}.conv" + if conv_path.exists(): + conv_path.unlink() + + if self.enable_logging: + security_logger.info(f"Deleted conversation {conversation_id}") + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to delete conversation: {e}") + raise + + def export_conversation( + self, conversation_id: str, format: str = "json", path: str = None + ) -> str: + """ + Export a conversation to a file + + Args: + conversation_id: Target conversation ID + format: Export format (json, txt) + path: Output file path + + Returns: + Path to exported file + """ + try: + messages = self.get_messages(conversation_id) + summary = self.get_conversation_summary(conversation_id) + + if not path: + path = f"conversation_{conversation_id}.{format}" + + if format.lower() == "json": + export_data = { + "summary": summary, + "messages": [ + { + "agent": agent, + "message": message, + "timestamp": timestamp.isoformat(), + } + for agent, message, timestamp in messages + ], + } + with open(path, "w") as f: + json.dump(export_data, f, indent=2) + + elif format.lower() == "txt": + with open(path, "w") as f: + f.write(f"Conversation: {summary['name']}\n") + f.write(f"Created: {summary['created_at']}\n") + f.write(f"Messages: {summary['message_count']}\n") + 
f.write(f"Agents: {', '.join(summary['agents'])}\n\n") + + for agent, message, timestamp in messages: + f.write(f"[{timestamp}] {agent}: {message}\n") + + else: + raise ValueError(f"Unsupported format: {format}") + + if self.enable_logging: + security_logger.info(f"Exported conversation to {path}") + + return path + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to export conversation: {e}") + raise + + def backup_conversations(self, backup_path: str = None) -> str: + """ + Create a backup of all conversations + + Args: + backup_path: Backup directory path + + Returns: + Path to backup directory + """ + try: + if not backup_path: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_path = f"swarm_shield_backup_{timestamp}" + + backup_dir = Path(backup_path) + backup_dir.mkdir(parents=True, exist_ok=True) + + # Export all conversations + for conversation_id in self._conversations: + self.export_conversation( + conversation_id, + format="json", + path=str(backup_dir / f"{conversation_id}.json"), + ) + + if self.enable_logging: + security_logger.info(f"Backup created at {backup_path}") + + return backup_path + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to create backup: {e}") + raise + + def get_agent_stats(self, agent_name: str) -> Dict: + """ + Get statistics for a specific agent + + Args: + agent_name: Name of the agent + + Returns: + Dictionary with agent statistics + """ + try: + total_messages = 0 + conversations = set() + + for conversation_id, conv in self._conversations.items(): + for msg in conv["messages"]: + msg_agent, _ = self.retrieve_message(msg["data"]) + if msg_agent == agent_name: + total_messages += 1 + conversations.add(conversation_id) + + return { + "agent_name": agent_name, + "total_messages": total_messages, + "conversations": len(conversations), + "conversation_ids": list(conversations), + } + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to get agent stats: {e}") + raise + + def query_conversations( + self, + agent_name: str = None, + text: str = None, + start_date: datetime = None, + end_date: datetime = None, + limit: int = 100, + ) -> List[Dict]: + """ + Search conversations with filters + + Args: + agent_name: Filter by agent name + text: Search for text in messages + start_date: Filter by start date + end_date: Filter by end date + limit: Maximum results to return + + Returns: + List of matching conversation summaries + """ + try: + results = [] + + for conversation_id, conv in self._conversations.items(): + # Check date range + conv_date = datetime.fromisoformat(conv["created_at"]) + if start_date and conv_date < start_date: + continue + if end_date and conv_date > end_date: + continue + + # Check agent filter + if agent_name: + conv_agents = set() + for msg in conv["messages"]: + msg_agent, _ = self.retrieve_message(msg["data"]) + conv_agents.add(msg_agent) + if agent_name not in conv_agents: + continue + + # Check text filter + if text: + text_found = False + for msg in conv["messages"]: + _, message = self.retrieve_message(msg["data"]) + if text.lower() in message.lower(): + text_found = True + break + if not text_found: + continue + + # Add to results + summary = self.get_conversation_summary(conversation_id) + results.append(summary) + + if len(results) >= limit: + break + + return results + + except Exception as e: + if self.enable_logging: + security_logger.error(f"Failed to query conversations: {e}") + raise \ No newline at end 
of file diff --git a/swarms/security/swarm_shield_integration.py b/swarms/security/swarm_shield_integration.py new file mode 100644 index 00000000..dc7ec57f --- /dev/null +++ b/swarms/security/swarm_shield_integration.py @@ -0,0 +1,318 @@ +""" +SwarmShield integration for Swarms framework. + +This module provides the main integration point for all security features, +combining input validation, output filtering, safety checking, rate limiting, +and encryption into a unified security shield. +""" + +from typing import Dict, Any, Optional, Tuple, Union, List +from datetime import datetime + +from loguru import logger + +from swarms.utils.loguru_logger import initialize_logger +from .shield_config import ShieldConfig +from .swarm_shield import SwarmShield, EncryptionStrength +from .input_validator import InputValidator +from .output_filter import OutputFilter +from .safety_checker import SafetyChecker +from .rate_limiter import RateLimiter + +# Initialize logger for shield integration +shield_logger = initialize_logger(log_folder="shield_integration") + + +class SwarmShieldIntegration: + """ + Main SwarmShield integration class + + This class provides a unified interface for all security features: + - Input validation and sanitization + - Output filtering and sensitive data protection + - Safety checking and ethical AI compliance + - Rate limiting and abuse prevention + - Encrypted communication and storage + """ + + def __init__(self, config: Optional[ShieldConfig] = None): + """ + Initialize SwarmShield integration + + Args: + config: Shield configuration. If None, uses default configuration. + """ + self.config = config or ShieldConfig() + + # Initialize security components + self._initialize_components() + + # Handle security_level safely (could be enum or string) + security_level_str = ( + self.config.security_level.value + if hasattr(self.config.security_level, 'value') + else str(self.config.security_level) + ) + shield_logger.info(f"SwarmShield integration initialized with {security_level_str} security") + + def _initialize_components(self) -> None: + """Initialize all security components""" + try: + # Initialize SwarmShield for encryption + if self.config.integrate_with_conversation: + self.swarm_shield = SwarmShield( + **self.config.get_encryption_config(), + enable_logging=self.config.enable_audit_logging + ) + else: + self.swarm_shield = None + + # Initialize input validator + self.input_validator = InputValidator( + self.config.get_validation_config() + ) + + # Initialize output filter + self.output_filter = OutputFilter( + self.config.get_filtering_config() + ) + + # Initialize safety checker + self.safety_checker = SafetyChecker( + self.config.get_safety_config() + ) + + # Initialize rate limiter + self.rate_limiter = RateLimiter( + self.config.get_rate_limit_config() + ) + + except Exception as e: + shield_logger.error(f"Failed to initialize security components: {e}") + raise + + def validate_and_protect_input(self, input_data: str, agent_name: str, input_type: str = "text") -> Tuple[bool, str, Optional[str]]: + """ + Validate and protect input data + + Args: + input_data: Input data to validate + agent_name: Name of the agent + input_type: Type of input + + Returns: + Tuple of (is_valid, protected_data, error_message) + """ + try: + # Check rate limits + is_allowed, error = self.rate_limiter.check_agent_limit(agent_name) + if not is_allowed: + return False, "", error + + # Validate input + is_valid, sanitized, error = self.input_validator.validate_input(input_data, 
input_type) + if not is_valid: + return False, "", error + + # Check safety + is_safe, safe_content, error = self.safety_checker.check_safety(sanitized, input_type) + if not is_safe: + return False, "", error + + # Protect with encryption if enabled + if self.swarm_shield and self.config.protect_agent_communications: + try: + protected = self.swarm_shield.protect_message(agent_name, safe_content) + return True, protected, None + except Exception as e: + shield_logger.error(f"Encryption failed: {e}") + return False, "", f"Encryption error: {str(e)}" + + return True, safe_content, None + + except Exception as e: + shield_logger.error(f"Input validation error: {e}") + return False, "", f"Validation error: {str(e)}" + + def filter_and_protect_output(self, output_data: Union[str, Dict, List], agent_name: str, output_type: str = "text") -> Tuple[bool, Union[str, Dict, List], Optional[str]]: + """ + Filter and protect output data + + Args: + output_data: Output data to filter + agent_name: Name of the agent + output_type: Type of output + + Returns: + Tuple of (is_safe, filtered_data, warning_message) + """ + try: + # Filter output + is_safe, filtered, warning = self.output_filter.filter_output(output_data, output_type) + if not is_safe: + return False, "", "Output contains unsafe content" + + # Check safety + if isinstance(filtered, str): + is_safe, safe_content, error = self.safety_checker.check_safety(filtered, output_type) + if not is_safe: + return False, "", error + filtered = safe_content + + # Track request + self.rate_limiter.track_request(agent_name, f"output_{output_type}") + + return True, filtered, warning + + except Exception as e: + shield_logger.error(f"Output filtering error: {e}") + return False, "", f"Filtering error: {str(e)}" + + def process_agent_communication(self, agent_name: str, message: str, direction: str = "outbound") -> Tuple[bool, str, Optional[str]]: + """ + Process agent communication (inbound or outbound) + + Args: + agent_name: Name of the agent + message: Message content + direction: "inbound" or "outbound" + + Returns: + Tuple of (is_valid, processed_message, error_message) + """ + try: + if direction == "inbound": + return self.validate_and_protect_input(message, agent_name, "text") + elif direction == "outbound": + return self.filter_and_protect_output(message, agent_name, "text") + else: + return False, "", f"Invalid direction: {direction}" + + except Exception as e: + shield_logger.error(f"Communication processing error: {e}") + return False, "", f"Processing error: {str(e)}" + + def validate_task(self, task: str, agent_name: str) -> Tuple[bool, str, Optional[str]]: + """Validate swarm task""" + return self.validate_and_protect_input(task, agent_name, "task") + + def validate_agent_config(self, config: Dict[str, Any], agent_name: str) -> Tuple[bool, Dict[str, Any], Optional[str]]: + """Validate agent configuration""" + try: + # Validate config structure + is_valid, validated_config, error = self.input_validator.validate_config(config) + if not is_valid: + return False, {}, error + + # Check safety + is_safe, safe_config, error = self.safety_checker.check_config_safety(validated_config) + if not is_safe: + return False, {}, error + + # Filter sensitive data + is_safe, filtered_config, warning = self.output_filter.filter_config_output(safe_config) + if not is_safe: + return False, {}, "Configuration contains unsafe content" + + return True, filtered_config, warning + + except Exception as e: + shield_logger.error(f"Config validation error: {e}") + return 
False, {}, f"Config validation error: {str(e)}" + + def create_secure_conversation(self, name: str = "") -> Optional[str]: + """Create a secure conversation if encryption is enabled""" + if self.swarm_shield and self.config.integrate_with_conversation: + try: + return self.swarm_shield.create_conversation(name) + except Exception as e: + shield_logger.error(f"Failed to create secure conversation: {e}") + return None + return None + + def add_secure_message(self, conversation_id: str, agent_name: str, message: str) -> bool: + """Add a message to secure conversation""" + if self.swarm_shield and self.config.integrate_with_conversation: + try: + self.swarm_shield.add_message(conversation_id, agent_name, message) + return True + except Exception as e: + shield_logger.error(f"Failed to add secure message: {e}") + return False + return False + + def get_secure_messages(self, conversation_id: str) -> List[Tuple[str, str, datetime]]: + """Get messages from secure conversation""" + if self.swarm_shield and self.config.integrate_with_conversation: + try: + return self.swarm_shield.get_messages(conversation_id) + except Exception as e: + shield_logger.error(f"Failed to get secure messages: {e}") + return [] + return [] + + def check_rate_limit(self, agent_name: str, request_size: int = 1) -> Tuple[bool, Optional[str]]: + """Check rate limits for an agent""" + return self.rate_limiter.check_rate_limit(agent_name, request_size) + + def get_security_stats(self) -> Dict[str, Any]: + """Get comprehensive security statistics""" + stats = { + "security_enabled": self.config.enabled, + "input_validations": getattr(self.input_validator, 'validation_count', 0), + "rate_limit_checks": getattr(self.rate_limiter, 'check_count', 0), + "blocked_requests": getattr(self.rate_limiter, 'blocked_count', 0), + "filtered_outputs": getattr(self.output_filter, 'filter_count', 0), + "violations": getattr(self.safety_checker, 'violation_count', 0), + "encryption_enabled": self.swarm_shield is not None, + } + + # Handle security_level safely + if hasattr(self.config.security_level, 'value'): + stats["security_level"] = self.config.security_level.value + else: + stats["security_level"] = str(self.config.security_level) + + # Handle encryption_strength safely + if self.swarm_shield and hasattr(self.swarm_shield.encryption_strength, 'value'): + stats["encryption_strength"] = self.swarm_shield.encryption_strength.value + elif self.swarm_shield: + stats["encryption_strength"] = str(self.swarm_shield.encryption_strength) + else: + stats["encryption_strength"] = "none" + + return stats + + def update_config(self, new_config: ShieldConfig) -> bool: + """Update shield configuration""" + try: + self.config = new_config + self._initialize_components() + shield_logger.info("Shield configuration updated") + return True + except Exception as e: + shield_logger.error(f"Failed to update configuration: {e}") + return False + + def enable_security(self) -> None: + """Enable all security features""" + self.config.enabled = True + shield_logger.info("Security features enabled") + + def disable_security(self) -> None: + """Disable all security features""" + self.config.enabled = False + shield_logger.info("Security features disabled") + + def cleanup(self) -> None: + """Cleanup resources""" + try: + if hasattr(self, 'rate_limiter'): + self.rate_limiter.stop() + shield_logger.info("SwarmShield integration cleaned up") + except Exception as e: + shield_logger.error(f"Cleanup error: {e}") + + def __del__(self): + """Cleanup on destruction""" + 
self.cleanup() \ No newline at end of file
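Below are usage sketches for the APIs introduced in this diff. First, a minimal round trip through protect_message / retrieve_message. The SwarmShield constructor is not shown here, so the keyword arguments below are assumptions inferred from the attributes the class reads (encryption_strength, enable_logging); adjust them to the real __init__ signature.

from swarms.security.swarm_shield import SwarmShield, EncryptionStrength

# Assumed constructor arguments; the actual __init__ is defined earlier in the module.
shield = SwarmShield(
    encryption_strength=EncryptionStrength.MAXIMUM,  # MAXIMUM adds an HMAC-SHA512 signature
    enable_logging=True,
)

# Encrypt: AES-GCM ciphertext plus a SHA-512 content hash, packed into a
# base64-encoded JSON envelope (id, time, agent, cipher, tag, hash, hmac).
token = shield.protect_message("analyst-agent", "Quarterly revenue is up 12%.")

# Decrypt: verifies the HMAC (when present) and the SHA-512 hash before returning.
agent, plaintext = shield.retrieve_message(token)
assert agent == "analyst-agent" and plaintext == "Quarterly revenue is up 12%."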
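A conversation round trip with create_conversation / add_message / get_messages / get_conversation_summary. Conversations are persisted as encrypted .conv files under the shield's storage path; construction assumptions are the same as in the previous sketch.

from swarms.security.swarm_shield import SwarmShield

shield = SwarmShield(enable_logging=True)  # assumes defaults for the remaining parameters

conv_id = shield.create_conversation(name="pricing-review")
shield.add_message(conv_id, "analyst-agent", "Draft pricing model attached.")
shield.add_message(conv_id, "reviewer-agent", "Margins look thin in EMEA.")

# Messages are decrypted on read and returned as (agent_name, message, timestamp) tuples.
for agent, message, timestamp in shield.get_messages(conv_id):
    print(f"[{timestamp.isoformat()}] {agent}: {message}")

summary = shield.get_conversation_summary(conv_id)
print(summary["message_count"], summary["agents"], summary["last_message"])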
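Export, backup and deletion. export_conversation supports "json" and "txt"; backup_conversations writes every conversation as JSON into a (by default timestamped) directory. Same constructor assumptions as above.

from swarms.security.swarm_shield import SwarmShield

shield = SwarmShield(enable_logging=True)  # assumed defaults
conv_id = shield.create_conversation(name="pricing-review")
shield.add_message(conv_id, "analyst-agent", "Draft pricing model attached.")

json_path = shield.export_conversation(conv_id, format="json")
txt_path = shield.export_conversation(conv_id, format="txt", path="pricing_review.txt")

# Snapshot all conversations into a backup directory.
backup_dir = shield.backup_conversations()

# Remove the conversation from memory and delete its .conv file from storage.
shield.delete_conversation(conv_id)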
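Per-agent statistics and filtered search. query_conversations decrypts messages while filtering, so the agent, text and date filters all apply to plaintext; pass timezone-aware datetimes because created_at timestamps are stored in UTC. Constructor assumptions as above.

from datetime import datetime, timedelta, timezone

from swarms.security.swarm_shield import SwarmShield

shield = SwarmShield(enable_logging=True)  # assumed defaults

stats = shield.get_agent_stats("analyst-agent")
print(stats["total_messages"], stats["conversations"], stats["conversation_ids"])

recent = shield.query_conversations(
    agent_name="analyst-agent",
    text="pricing",
    start_date=datetime.now(timezone.utc) - timedelta(days=7),
    limit=10,
)
for summary in recent:
    print(summary["id"], summary["name"], summary["message_count"])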
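End-to-end use of SwarmShieldIntegration: inbound messages pass through rate limiting, input validation, safety checks and (when protect_agent_communications is enabled) encryption; outbound messages pass through output filtering and safety checks. ShieldConfig() with its defaults is shown; the individual fields live in shield_config.py and are not reproduced here.

from swarms.security.shield_config import ShieldConfig
from swarms.security.swarm_shield_integration import SwarmShieldIntegration

integration = SwarmShieldIntegration(config=ShieldConfig())

# Inbound: rate limit -> validation -> safety -> optional encryption.
# The returned value may be an encrypted envelope rather than plaintext.
ok, protected_task, err = integration.validate_and_protect_input(
    "Summarize the attached contract.", agent_name="legal-agent", input_type="text"
)
if not ok:
    raise ValueError(err)

# Outbound: filtering -> safety -> request tracking.
ok, filtered, warning = integration.filter_and_protect_output(
    "The contract auto-renews next January.", agent_name="legal-agent", output_type="text"
)

# Or route by direction with a single call.
ok, processed, err = integration.process_agent_communication(
    "legal-agent", "Summarize the attached contract.", direction="inbound"
)

integration.cleanup()  # stops the rate limiter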
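Secure-conversation helpers, rate-limit checks and runtime configuration through the integration layer. create_secure_conversation returns None when conversation integration or encryption is disabled, so the sketch guards on the result.

from swarms.security.shield_config import ShieldConfig
from swarms.security.swarm_shield_integration import SwarmShieldIntegration

integration = SwarmShieldIntegration(config=ShieldConfig())

conv_id = integration.create_secure_conversation(name="audit-trail")
if conv_id:
    integration.add_secure_message(conv_id, "auditor-agent", "Reviewed access logs.")
    history = integration.get_secure_messages(conv_id)  # [(agent, message, datetime), ...]

# Rate limiting and aggregate security counters.
allowed, err = integration.check_rate_limit("auditor-agent", request_size=1)
print(integration.get_security_stats())  # includes security_level and encryption_strength

# Toggle or replace configuration at runtime.
integration.disable_security()
integration.enable_security()
integration.update_config(ShieldConfig())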