commit df31f6f2db
parent b03379280d
@@ -0,0 +1,29 @@
"""
Security module for Swarms framework.

This module provides enterprise-grade security features including:
- SwarmShield integration for encrypted communications
- Input validation and sanitization
- Output filtering and safety checks
- Rate limiting and abuse prevention
- Audit logging and compliance features
"""

from .swarm_shield import SwarmShield, EncryptionStrength
from .shield_config import ShieldConfig
from .input_validator import InputValidator
from .output_filter import OutputFilter
from .safety_checker import SafetyChecker
from .rate_limiter import RateLimiter
from .swarm_shield_integration import SwarmShieldIntegration

__all__ = [
    "SwarmShield",
    "EncryptionStrength",
    "ShieldConfig",
    "InputValidator",
    "OutputFilter",
    "SafetyChecker",
    "RateLimiter",
    "SwarmShieldIntegration",
]
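A minimal usage sketch of the public interface re-exported above; the `swarms.security` import path is an assumption inferred from the relative imports in this diff, and the config values are illustrative:

    from swarms.security import ShieldConfig, InputValidator, OutputFilter

    # Build a preset configuration and hand its sub-configs to the components.
    config = ShieldConfig.create_standard_config()
    validator = InputValidator(config.get_validation_config())
    out_filter = OutputFilter(config.get_filtering_config())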
@@ -0,0 +1,259 @@
"""
Input validation and sanitization for Swarms framework.

This module provides comprehensive input validation, sanitization,
and security checks for all swarm inputs.
"""

import re
import html
from typing import List, Optional, Dict, Any, Tuple
from urllib.parse import urlparse
from datetime import datetime

from loguru import logger

from swarms.utils.loguru_logger import initialize_logger

# Initialize logger for input validation
validation_logger = initialize_logger(log_folder="input_validation")


class InputValidator:
    """
    Input validation and sanitization for swarm security

    Features:
    - Input length validation
    - Pattern-based blocking
    - XSS prevention
    - SQL injection prevention
    - URL validation
    - Content type validation
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize input validator with configuration

        Args:
            config: Validation configuration dictionary
        """
        self.enabled = config.get("enabled", True)
        self.max_length = config.get("max_length", 10000)
        self.blocked_patterns = config.get("blocked_patterns", [])
        self.allowed_domains = config.get("allowed_domains", [])

        # Compile regex patterns for performance
        self._compiled_patterns = [
            re.compile(pattern, re.IGNORECASE) for pattern in self.blocked_patterns
        ]

        # Common malicious patterns
        self._malicious_patterns = [
            re.compile(r"<script.*?>.*?</script>", re.IGNORECASE),
            re.compile(r"javascript:", re.IGNORECASE),
            re.compile(r"data:text/html", re.IGNORECASE),
            re.compile(r"vbscript:", re.IGNORECASE),
            re.compile(r"on\w+\s*=", re.IGNORECASE),
            re.compile(r"<iframe.*?>.*?</iframe>", re.IGNORECASE),
            re.compile(r"<object.*?>.*?</object>", re.IGNORECASE),
            re.compile(r"<embed.*?>", re.IGNORECASE),
            re.compile(r"<link.*?>", re.IGNORECASE),
            re.compile(r"<meta.*?>", re.IGNORECASE),
        ]

        # SQL injection patterns
        self._sql_patterns = [
            re.compile(r"(\b(union|select|insert|update|delete|drop|create|alter)\b)", re.IGNORECASE),
            re.compile(r"(--|#|/\*|\*/)", re.IGNORECASE),
            re.compile(r"(\b(exec|execute|xp_|sp_)\b)", re.IGNORECASE),
        ]

        validation_logger.info("InputValidator initialized")

    def validate_input(self, input_data: str, input_type: str = "text") -> Tuple[bool, str, Optional[str]]:
        """
        Validate and sanitize input data

        Args:
            input_data: Input data to validate
            input_type: Type of input (text, url, code, etc.)

        Returns:
            Tuple of (is_valid, sanitized_data, error_message)
        """
        if not self.enabled:
            return True, input_data, None

        try:
            # Basic type validation
            if not isinstance(input_data, str):
                return False, "", "Input must be a string"

            # Length validation
            if len(input_data) > self.max_length:
                return False, "", f"Input exceeds maximum length of {self.max_length} characters"

            # Empty input check
            if not input_data.strip():
                return False, "", "Input cannot be empty"

            # Sanitize the input
            sanitized = self._sanitize_input(input_data)

            # Check for blocked patterns
            if self._check_blocked_patterns(sanitized):
                return False, "", "Input contains blocked patterns"

            # Check for malicious patterns
            if self._check_malicious_patterns(sanitized):
                return False, "", "Input contains potentially malicious content"

            # Type-specific validation
            if input_type == "url":
                if not self._validate_url(sanitized):
                    return False, "", "Invalid URL format"

            elif input_type == "code":
                if not self._validate_code(sanitized):
                    return False, "", "Invalid code content"

            elif input_type == "json":
                if not self._validate_json(sanitized):
                    return False, "", "Invalid JSON format"

            validation_logger.debug(f"Input validation passed for type: {input_type}")
            return True, sanitized, None

        except Exception as e:
            validation_logger.error(f"Input validation error: {e}")
            return False, "", f"Validation error: {str(e)}"

    def _sanitize_input(self, input_data: str) -> str:
        """Sanitize input data to prevent XSS and other attacks"""
        # HTML escape
        sanitized = html.escape(input_data)

        # Remove null bytes
        sanitized = sanitized.replace('\x00', '')

        # Normalize whitespace
        sanitized = ' '.join(sanitized.split())

        return sanitized

    def _check_blocked_patterns(self, input_data: str) -> bool:
        """Check if input contains blocked patterns"""
        for pattern in self._compiled_patterns:
            if pattern.search(input_data):
                validation_logger.warning(f"Blocked pattern detected: {pattern.pattern}")
                return True
        return False

    def _check_malicious_patterns(self, input_data: str) -> bool:
        """Check if input contains malicious patterns"""
        for pattern in self._malicious_patterns:
            if pattern.search(input_data):
                validation_logger.warning(f"Malicious pattern detected: {pattern.pattern}")
                return True
        return False

    def _validate_url(self, url: str) -> bool:
        """Validate URL format and domain"""
        try:
            parsed = urlparse(url)

            # Check if it's a valid URL
            if not all([parsed.scheme, parsed.netloc]):
                return False

            # Check allowed domains if specified
            if self.allowed_domains:
                domain = parsed.netloc.lower()
                if not any(allowed in domain for allowed in self.allowed_domains):
                    validation_logger.warning(f"Domain not allowed: {domain}")
                    return False

            return True

        except Exception:
            return False

    def _validate_code(self, code: str) -> bool:
        """Validate code content for safety"""
        # Check for SQL injection patterns
        for pattern in self._sql_patterns:
            if pattern.search(code):
                validation_logger.warning(f"SQL injection pattern detected: {pattern.pattern}")
                return False

        # Check for dangerous system calls
        dangerous_calls = [
            'os.system', 'subprocess.call', 'eval(', 'exec(',
            '__import__', 'globals()', 'locals()'
        ]

        for call in dangerous_calls:
            if call in code:
                validation_logger.warning(f"Dangerous call detected: {call}")
                return False

        return True

    def _validate_json(self, json_str: str) -> bool:
        """Validate JSON format"""
        try:
            import json
            json.loads(json_str)
            return True
        except (json.JSONDecodeError, ValueError):
            return False

    def validate_task(self, task: str) -> Tuple[bool, str, Optional[str]]:
        """Validate swarm task input"""
        return self.validate_input(task, "text")

    def validate_agent_name(self, agent_name: str) -> Tuple[bool, str, Optional[str]]:
        """Validate agent name input"""
        # Additional validation for agent names
        if not re.match(r'^[a-zA-Z0-9_-]+$', agent_name):
            return False, "", "Agent name can only contain letters, numbers, underscores, and hyphens"

        if len(agent_name) < 1 or len(agent_name) > 50:
            return False, "", "Agent name must be between 1 and 50 characters"

        return self.validate_input(agent_name, "text")

    def validate_message(self, message: str) -> Tuple[bool, str, Optional[str]]:
        """Validate message input"""
        return self.validate_input(message, "text")

    def validate_config(self, config: Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Optional[str]]:
        """Validate configuration input"""
        try:
            # Convert config to string for validation
            import json
            config_str = json.dumps(config)

            is_valid, sanitized, error = self.validate_input(config_str, "json")
            if not is_valid:
                return False, {}, error

            # Parse back to dict
            validated_config = json.loads(sanitized)
            return True, validated_config, None

        except Exception as e:
            return False, {}, f"Configuration validation error: {str(e)}"

    def get_validation_stats(self) -> Dict[str, Any]:
        """Get validation statistics"""
        return {
            "enabled": self.enabled,
            "max_length": self.max_length,
            "blocked_patterns_count": len(self.blocked_patterns),
            "allowed_domains_count": len(self.allowed_domains),
            "malicious_patterns_count": len(self._malicious_patterns),
            "sql_patterns_count": len(self._sql_patterns),
        }
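A quick usage sketch for the validator above, exercising the (is_valid, sanitized, error) return contract; the config values and sample strings are illustrative only:

    validator = InputValidator({
        "enabled": True,
        "max_length": 1000,
        "blocked_patterns": [r"DROP\s+TABLE"],
        "allowed_domains": ["example.com"],
    })

    ok, clean, err = validator.validate_input("Summarize the quarterly report")
    assert ok and err is None

    ok, _, err = validator.validate_input("javascript:alert(1)")
    assert not ok  # rejected by the built-in malicious patterns

    ok, _, err = validator.validate_input("https://example.com/docs", "url")
    assert ok  # netloc matches an entry in allowed_domains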
@@ -0,0 +1,285 @@
"""
Output filtering and sanitization for Swarms framework.

This module provides comprehensive output filtering, sanitization,
and sensitive data protection for all swarm outputs.
"""

import re
import json
from typing import List, Optional, Dict, Any, Tuple, Union
from datetime import datetime

from loguru import logger

from swarms.utils.loguru_logger import initialize_logger

# Initialize logger for output filtering
filter_logger = initialize_logger(log_folder="output_filtering")


class OutputFilter:
    """
    Output filtering and sanitization for swarm security

    Features:
    - Sensitive data filtering
    - Output sanitization
    - Content type filtering
    - PII protection
    - Malicious content detection
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize output filter with configuration

        Args:
            config: Filtering configuration dictionary
        """
        self.enabled = config.get("enabled", True)
        self.filter_sensitive = config.get("filter_sensitive", True)
        self.sensitive_patterns = config.get("sensitive_patterns", [])

        # Compile regex patterns for performance
        self._compiled_patterns = [
            re.compile(pattern, re.IGNORECASE) for pattern in self.sensitive_patterns
        ]

        # Default sensitive data patterns
        self._default_patterns = [
            re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),  # SSN
            re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b"),  # Credit card
            re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"),  # Email
            re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b"),  # IP address
            re.compile(r"\b\d{3}[\s-]?\d{3}[\s-]?\d{4}\b"),  # Phone number
            re.compile(r"\b[A-Z]{2}\d{2}[A-Z0-9]{10,30}\b"),  # IBAN
            re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{1,3}\b"),  # Extended CC
        ]

        # Malicious content patterns
        self._malicious_patterns = [
            re.compile(r"<script.*?>.*?</script>", re.IGNORECASE),
            re.compile(r"javascript:", re.IGNORECASE),
            re.compile(r"data:text/html", re.IGNORECASE),
            re.compile(r"vbscript:", re.IGNORECASE),
            re.compile(r"on\w+\s*=", re.IGNORECASE),
            re.compile(r"<iframe.*?>.*?</iframe>", re.IGNORECASE),
            re.compile(r"<object.*?>.*?</object>", re.IGNORECASE),
            re.compile(r"<embed.*?>", re.IGNORECASE),
        ]

        # API key patterns
        self._api_key_patterns = [
            re.compile(r"sk-[a-zA-Z0-9]{32,}"),  # OpenAI API key
            re.compile(r"pk_[a-zA-Z0-9]{32,}"),  # OpenAI API key (public)
            re.compile(r"[a-zA-Z0-9]{32,}"),  # Generic API key
        ]

        filter_logger.info("OutputFilter initialized")

    def filter_output(self, output_data: Union[str, Dict, List], output_type: str = "text") -> Tuple[bool, Union[str, Dict, List], Optional[str]]:
        """
        Filter and sanitize output data

        Args:
            output_data: Output data to filter
            output_type: Type of output (text, json, dict, etc.)

        Returns:
            Tuple of (is_safe, filtered_data, warning_message)
        """
        if not self.enabled:
            return True, output_data, None

        try:
            # Convert to string for processing
            if isinstance(output_data, (dict, list)):
                output_str = json.dumps(output_data, ensure_ascii=False)
            else:
                output_str = str(output_data)

            # Check for malicious content
            if self._check_malicious_content(output_str):
                return False, "", "Output contains potentially malicious content"

            # Filter sensitive data
            if self.filter_sensitive:
                filtered_str = self._filter_sensitive_data(output_str)
            else:
                filtered_str = output_str

            # Convert back to original type if needed
            if isinstance(output_data, (dict, list)) and output_type in ["json", "dict"]:
                try:
                    filtered_data = json.loads(filtered_str)
                except json.JSONDecodeError:
                    filtered_data = filtered_str
            else:
                filtered_data = filtered_str

            # Check if any sensitive data was filtered
            warning = None
            if filtered_str != output_str:
                warning = "Sensitive data was filtered from output"

            filter_logger.debug(f"Output filtering completed for type: {output_type}")
            return True, filtered_data, warning

        except Exception as e:
            filter_logger.error(f"Output filtering error: {e}")
            return False, "", f"Filtering error: {str(e)}"

    def _check_malicious_content(self, content: str) -> bool:
        """Check if content contains malicious patterns"""
        for pattern in self._malicious_patterns:
            if pattern.search(content):
                filter_logger.warning(f"Malicious content detected: {pattern.pattern}")
                return True
        return False

    def _filter_sensitive_data(self, content: str) -> str:
        """Filter sensitive data from content"""
        filtered_content = content

        # Filter custom sensitive patterns
        for pattern in self._compiled_patterns:
            filtered_content = pattern.sub("[SENSITIVE_DATA]", filtered_content)

        # Filter default sensitive patterns
        for pattern in self._default_patterns:
            filtered_content = pattern.sub("[SENSITIVE_DATA]", filtered_content)

        # Filter API keys
        for pattern in self._api_key_patterns:
            filtered_content = pattern.sub("[API_KEY]", filtered_content)

        return filtered_content

    def filter_agent_response(self, response: str, agent_name: str) -> Tuple[bool, str, Optional[str]]:
        """Filter agent response output"""
        return self.filter_output(response, "text")

    def filter_swarm_output(self, output: Union[str, Dict, List]) -> Tuple[bool, Union[str, Dict, List], Optional[str]]:
        """Filter swarm output"""
        return self.filter_output(output, "json")

    def filter_conversation_history(self, history: List[Dict]) -> Tuple[bool, List[Dict], Optional[str]]:
        """Filter conversation history"""
        try:
            filtered_history = []
            warnings = []

            for message in history:
                # Filter message content
                is_safe, filtered_content, warning = self.filter_output(
                    message.get("content", ""), "text"
                )

                if not is_safe:
                    return False, [], "Conversation history contains unsafe content"

                # Create filtered message
                filtered_message = message.copy()
                filtered_message["content"] = filtered_content

                if warning:
                    warnings.append(warning)

                filtered_history.append(filtered_message)

            warning_msg = "; ".join(set(warnings)) if warnings else None
            return True, filtered_history, warning_msg

        except Exception as e:
            filter_logger.error(f"Conversation history filtering error: {e}")
            return False, [], f"History filtering error: {str(e)}"

    def filter_config_output(self, config: Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Optional[str]]:
        """Filter configuration output"""
        try:
            # Create a copy to avoid modifying original
            filtered_config = config.copy()

            # Filter sensitive config fields
            sensitive_fields = [
                "api_key", "secret", "password", "token", "key",
                "credential", "auth", "private", "secret_key"
            ]

            warnings = []
            for field in sensitive_fields:
                if field in filtered_config:
                    if isinstance(filtered_config[field], str):
                        filtered_config[field] = "[SENSITIVE_CONFIG]"
                        warnings.append(f"Sensitive config field '{field}' was filtered")

            warning_msg = "; ".join(warnings) if warnings else None
            return True, filtered_config, warning_msg

        except Exception as e:
            filter_logger.error(f"Config filtering error: {e}")
            return False, {}, f"Config filtering error: {str(e)}"

    def sanitize_for_logging(self, data: Union[str, Dict, List]) -> str:
        """Sanitize data for logging purposes"""
        try:
            if isinstance(data, (dict, list)):
                data_str = json.dumps(data, ensure_ascii=False)
            else:
                data_str = str(data)

            # Apply aggressive filtering for logs
            sanitized = self._filter_sensitive_data(data_str)

            # Truncate if too long
            if len(sanitized) > 1000:
                sanitized = sanitized[:1000] + "... [TRUNCATED]"

            return sanitized

        except Exception as e:
            filter_logger.error(f"Log sanitization error: {e}")
            return "[SANITIZATION_ERROR]"

    def add_custom_pattern(self, pattern: str, description: str = "") -> None:
        """Add custom sensitive data pattern"""
        try:
            compiled_pattern = re.compile(pattern, re.IGNORECASE)
            self._compiled_patterns.append(compiled_pattern)
            self.sensitive_patterns.append(pattern)

            filter_logger.info(f"Added custom pattern: {pattern} ({description})")

        except re.error as e:
            filter_logger.error(f"Invalid regex pattern: {pattern} - {e}")

    def remove_pattern(self, pattern: str) -> bool:
        """Remove sensitive data pattern"""
        try:
            if pattern in self.sensitive_patterns:
                self.sensitive_patterns.remove(pattern)

                # Recompile patterns
                self._compiled_patterns = [
                    re.compile(p, re.IGNORECASE) for p in self.sensitive_patterns
                ]

                filter_logger.info(f"Removed pattern: {pattern}")
                return True
            return False

        except Exception as e:
            filter_logger.error(f"Error removing pattern: {e}")
            return False

    def get_filter_stats(self) -> Dict[str, Any]:
        """Get filtering statistics"""
        return {
            "enabled": self.enabled,
            "filter_sensitive": self.filter_sensitive,
            "sensitive_patterns_count": len(self.sensitive_patterns),
            "malicious_patterns_count": len(self._malicious_patterns),
            "api_key_patterns_count": len(self._api_key_patterns),
            "default_patterns_count": len(self._default_patterns),
        }
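A short usage sketch for the filter above, showing PII being masked in the filtered copy; the sample text is illustrative only:

    out_filter = OutputFilter({"enabled": True, "filter_sensitive": True})
    ok, filtered, warning = out_filter.filter_output(
        "Contact alice@example.com, card 4111 1111 1111 1111"
    )
    assert ok
    assert "[SENSITIVE_DATA]" in filtered  # email and card number were masked
    assert warning  # "Sensitive data was filtered from output"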
@@ -0,0 +1,323 @@
"""
Rate limiting and abuse prevention for Swarms framework.

This module provides comprehensive rate limiting, request tracking,
and abuse prevention for all swarm operations.
"""

import time
import threading
from typing import Dict, List, Optional, Tuple, Any
from collections import defaultdict, deque
from datetime import datetime, timedelta

from loguru import logger

from swarms.utils.loguru_logger import initialize_logger

# Initialize logger for rate limiting
rate_logger = initialize_logger(log_folder="rate_limiting")


class RateLimiter:
    """
    Rate limiting and abuse prevention for swarm security

    Features:
    - Per-agent rate limiting
    - Token-based limiting
    - Request tracking
    - Abuse detection
    - Automatic blocking
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize rate limiter with configuration

        Args:
            config: Rate limiting configuration dictionary
        """
        self.enabled = config.get("enabled", True)
        self.max_requests_per_minute = config.get("max_requests_per_minute", 60)
        self.max_tokens_per_request = config.get("max_tokens_per_request", 10000)
        self.window = config.get("window", 60)  # seconds

        # Request tracking
        self._request_history: Dict[str, deque] = defaultdict(lambda: deque())
        self._token_usage: Dict[str, int] = defaultdict(int)
        self._blocked_agents: Dict[str, float] = {}

        # Thread safety (reentrant lock: get_all_stats calls get_agent_stats
        # while already holding the lock, which would deadlock a plain Lock)
        self._lock = threading.RLock()

        # Cleanup thread
        self._cleanup_thread = None
        self._stop_cleanup = False

        if self.enabled:
            self._start_cleanup_thread()

        rate_logger.info(f"RateLimiter initialized: {self.max_requests_per_minute} req/min, {self.max_tokens_per_request} tokens/req")

    def check_rate_limit(self, agent_name: str, request_size: int = 1) -> Tuple[bool, Optional[str]]:
        """
        Check if request is within rate limits

        Args:
            agent_name: Name of the agent making the request
            request_size: Size of the request (tokens or complexity)

        Returns:
            Tuple of (is_allowed, error_message)
        """
        if not self.enabled:
            return True, None

        try:
            with self._lock:
                current_time = time.time()

                # Check if agent is blocked
                if agent_name in self._blocked_agents:
                    block_until = self._blocked_agents[agent_name]
                    if current_time < block_until:
                        remaining = block_until - current_time
                        return False, f"Agent {agent_name} is blocked for {remaining:.1f} more seconds"
                    else:
                        # Unblock agent
                        del self._blocked_agents[agent_name]

                # Check token limit
                if request_size > self.max_tokens_per_request:
                    return False, f"Request size {request_size} exceeds token limit {self.max_tokens_per_request}"

                # Get agent's request history
                history = self._request_history[agent_name]

                # Remove old requests outside the window
                cutoff_time = current_time - self.window
                while history and history[0] < cutoff_time:
                    history.popleft()

                # Check request count limit
                if len(history) >= self.max_requests_per_minute:
                    # Block agent temporarily
                    block_duration = min(300, self.window * 2)  # Max 5 minutes
                    self._blocked_agents[agent_name] = current_time + block_duration

                    rate_logger.warning(f"Agent {agent_name} rate limit exceeded, blocked for {block_duration}s")
                    return False, f"Rate limit exceeded. Agent blocked for {block_duration} seconds"

                # Add current request
                history.append(current_time)

                # Update token usage
                self._token_usage[agent_name] += request_size

                rate_logger.debug(f"Rate limit check passed for {agent_name}")
                return True, None

        except Exception as e:
            rate_logger.error(f"Rate limit check error: {e}")
            return False, f"Rate limit check error: {str(e)}"

    def check_agent_limit(self, agent_name: str) -> Tuple[bool, Optional[str]]:
        """Check agent-specific rate limits"""
        return self.check_rate_limit(agent_name, 1)

    def check_token_limit(self, agent_name: str, token_count: int) -> Tuple[bool, Optional[str]]:
        """Check token-based rate limits"""
        return self.check_rate_limit(agent_name, token_count)

    def track_request(self, agent_name: str, request_type: str = "general", metadata: Dict[str, Any] = None) -> None:
        """
        Track a request for monitoring purposes

        Args:
            agent_name: Name of the agent
            request_type: Type of request
            metadata: Additional request metadata
        """
        if not self.enabled:
            return

        try:
            with self._lock:
                current_time = time.time()

                # Store request metadata
                request_info = {
                    "timestamp": current_time,
                    "type": request_type,
                    "metadata": metadata or {}
                }

                # Add to history (we only store timestamps for rate limiting)
                self._request_history[agent_name].append(current_time)

                rate_logger.debug(f"Tracked request: {agent_name} - {request_type}")

        except Exception as e:
            rate_logger.error(f"Request tracking error: {e}")

    def get_agent_stats(self, agent_name: str) -> Dict[str, Any]:
        """Get rate limiting statistics for an agent"""
        try:
            with self._lock:
                current_time = time.time()
                history = self._request_history[agent_name]

                # Clean old requests
                cutoff_time = current_time - self.window
                while history and history[0] < cutoff_time:
                    history.popleft()

                # Calculate statistics
                recent_requests = len(history)
                total_tokens = self._token_usage.get(agent_name, 0)
                is_blocked = agent_name in self._blocked_agents
                block_remaining = 0

                if is_blocked:
                    block_remaining = self._blocked_agents[agent_name] - current_time

                return {
                    "agent_name": agent_name,
                    "recent_requests": recent_requests,
                    "max_requests": self.max_requests_per_minute,
                    "total_tokens": total_tokens,
                    "max_tokens_per_request": self.max_tokens_per_request,
                    "is_blocked": is_blocked,
                    "block_remaining": max(0, block_remaining),
                    "window_seconds": self.window,
                }

        except Exception as e:
            rate_logger.error(f"Error getting agent stats: {e}")
            return {}

    def get_all_stats(self) -> Dict[str, Any]:
        """Get rate limiting statistics for all agents"""
        try:
            with self._lock:
                stats = {
                    "enabled": self.enabled,
                    "max_requests_per_minute": self.max_requests_per_minute,
                    "max_tokens_per_request": self.max_tokens_per_request,
                    "window_seconds": self.window,
                    "total_agents": len(self._request_history),
                    "blocked_agents": len(self._blocked_agents),
                    "agents": {}
                }

                for agent_name in self._request_history:
                    stats["agents"][agent_name] = self.get_agent_stats(agent_name)

                return stats

        except Exception as e:
            rate_logger.error(f"Error getting all stats: {e}")
            return {}

    def reset_agent(self, agent_name: str) -> bool:
        """Reset rate limiting for an agent"""
        try:
            with self._lock:
                if agent_name in self._request_history:
                    self._request_history[agent_name].clear()

                if agent_name in self._token_usage:
                    self._token_usage[agent_name] = 0

                if agent_name in self._blocked_agents:
                    del self._blocked_agents[agent_name]

                rate_logger.info(f"Reset rate limiting for agent: {agent_name}")
                return True

        except Exception as e:
            rate_logger.error(f"Error resetting agent: {e}")
            return False

    def block_agent(self, agent_name: str, duration: int = 300) -> bool:
        """Manually block an agent"""
        try:
            with self._lock:
                current_time = time.time()
                self._blocked_agents[agent_name] = current_time + duration

                rate_logger.warning(f"Manually blocked agent {agent_name} for {duration}s")
                return True

        except Exception as e:
            rate_logger.error(f"Error blocking agent: {e}")
            return False

    def unblock_agent(self, agent_name: str) -> bool:
        """Unblock an agent"""
        try:
            with self._lock:
                if agent_name in self._blocked_agents:
                    del self._blocked_agents[agent_name]
                    rate_logger.info(f"Unblocked agent: {agent_name}")
                    return True
                return False

        except Exception as e:
            rate_logger.error(f"Error unblocking agent: {e}")
            return False

    def _start_cleanup_thread(self) -> None:
        """Start background cleanup thread"""
        def cleanup_worker():
            while not self._stop_cleanup:
                try:
                    time.sleep(60)  # Cleanup every minute
                    self._cleanup_old_data()
                except Exception as e:
                    rate_logger.error(f"Cleanup thread error: {e}")

        self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
        self._cleanup_thread.start()

    def _cleanup_old_data(self) -> None:
        """Clean up old rate limiting data"""
        try:
            with self._lock:
                current_time = time.time()
                cutoff_time = current_time - (self.window * 2)  # Keep 2x window

                # Clean request history
                for agent_name in list(self._request_history.keys()):
                    history = self._request_history[agent_name]
                    while history and history[0] < cutoff_time:
                        history.popleft()

                    # Remove empty histories
                    if not history:
                        del self._request_history[agent_name]

                # Clean blocked agents
                for agent_name in list(self._blocked_agents.keys()):
                    if self._blocked_agents[agent_name] < current_time:
                        del self._blocked_agents[agent_name]

                # Reset token usage periodically
                if current_time % 3600 < 60:  # Reset every hour
                    self._token_usage.clear()

        except Exception as e:
            rate_logger.error(f"Cleanup error: {e}")

    def stop(self) -> None:
        """Stop the rate limiter and cleanup thread"""
        self._stop_cleanup = True
        if self._cleanup_thread:
            self._cleanup_thread.join(timeout=5)

        rate_logger.info("RateLimiter stopped")

    def __del__(self):
        """Cleanup on destruction"""
        self.stop()
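A usage sketch for the limiter above, with a deliberately small request budget so the block path is reached; the agent name and config values are illustrative only:

    limiter = RateLimiter({
        "enabled": True,
        "max_requests_per_minute": 2,
        "max_tokens_per_request": 500,
        "window": 60,
    })

    allowed, err = limiter.check_rate_limit("researcher", request_size=100)
    assert allowed
    allowed, err = limiter.check_rate_limit("researcher", request_size=100)
    assert allowed
    allowed, err = limiter.check_rate_limit("researcher", request_size=100)
    assert not allowed  # third request in the window trips the limit and blocks the agent

    limiter.reset_agent("researcher")
    limiter.stop()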
@@ -0,0 +1,249 @@
"""
Safety checking and content filtering for Swarms framework.

This module provides comprehensive safety checks, content filtering,
and ethical AI compliance for all swarm operations.
"""

import re
from typing import List, Optional, Dict, Any, Tuple
from enum import Enum

from loguru import logger

from swarms.utils.loguru_logger import initialize_logger
from swarms.prompts.safety_prompt import SAFETY_PROMPT

# Initialize logger for safety checking
safety_logger = initialize_logger(log_folder="safety_checking")


class ContentLevel(Enum):
    """Content filtering levels"""

    LOW = "low"  # Minimal filtering
    MODERATE = "moderate"  # Standard filtering
    HIGH = "high"  # Aggressive filtering


class SafetyChecker:
    """
    Safety checking and content filtering for swarm security

    Features:
    - Content safety assessment
    - Ethical AI compliance
    - Harmful content detection
    - Bias detection
    - Safety prompt integration
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize safety checker with configuration

        Args:
            config: Safety configuration dictionary
        """
        self.enabled = config.get("enabled", True)
        self.safety_prompt = config.get("safety_prompt", True)
        self.filter_level = ContentLevel(config.get("filter_level", "moderate"))

        # Harmful content patterns
        self._harmful_patterns = [
            re.compile(r"\b(kill|murder|suicide|bomb|explosive|weapon)\b", re.IGNORECASE),
            re.compile(r"\b(hack|crack|steal|fraud|scam|phishing)\b", re.IGNORECASE),
            re.compile(r"\b(drug|illegal|contraband|smuggle)\b", re.IGNORECASE),
            re.compile(r"\b(hate|racist|sexist|discriminate)\b", re.IGNORECASE),
            re.compile(r"\b(terrorist|extremist|radical)\b", re.IGNORECASE),
        ]

        # Bias detection patterns
        self._bias_patterns = [
            re.compile(r"\b(all|every|always|never|none)\b", re.IGNORECASE),
            re.compile(r"\b(men are|women are|blacks are|whites are)\b", re.IGNORECASE),
            re.compile(r"\b(stereotypical|typical|usual)\b", re.IGNORECASE),
        ]

        # Age-inappropriate content
        self._age_inappropriate = [
            re.compile(r"\b(sex|porn|adult|explicit)\b", re.IGNORECASE),
            re.compile(r"\b(violence|gore|blood|death)\b", re.IGNORECASE),
        ]

        # Handle filter_level safely
        filter_level_str = (
            self.filter_level.value
            if hasattr(self.filter_level, 'value')
            else str(self.filter_level)
        )
        safety_logger.info(f"SafetyChecker initialized with level: {filter_level_str}")

    def check_safety(self, content: str, content_type: str = "text") -> Tuple[bool, str, Optional[str]]:
        """
        Check content for safety and ethical compliance

        Args:
            content: Content to check
            content_type: Type of content (text, code, config, etc.)

        Returns:
            Tuple of (is_safe, sanitized_content, warning_message)
        """
        if not self.enabled:
            return True, content, None

        try:
            # Basic type validation
            if not isinstance(content, str):
                return False, "", "Content must be a string"

            # Check for harmful content
            if self._check_harmful_content(content):
                return False, "", "Content contains potentially harmful material"

            # Check for bias
            if self._check_bias(content):
                return False, "", "Content contains potentially biased language"

            # Check age appropriateness
            if self._check_age_appropriate(content):
                return False, "", "Content may not be age-appropriate"

            # Apply content filtering based on level
            sanitized = self._filter_content(content)

            # Check if content was modified
            warning = None
            if sanitized != content:
                # Handle filter_level safely
                filter_level_str = (
                    self.filter_level.value
                    if hasattr(self.filter_level, 'value')
                    else str(self.filter_level)
                )
                warning = f"Content was filtered for {filter_level_str} safety level"

            safety_logger.debug(f"Safety check passed for type: {content_type}")
            return True, sanitized, warning

        except Exception as e:
            safety_logger.error(f"Safety check error: {e}")
            return False, "", f"Safety check error: {str(e)}"

    def _check_harmful_content(self, content: str) -> bool:
        """Check for harmful content patterns"""
        for pattern in self._harmful_patterns:
            if pattern.search(content):
                safety_logger.warning(f"Harmful content detected: {pattern.pattern}")
                return True
        return False

    def _check_bias(self, content: str) -> bool:
        """Check for biased language patterns"""
        for pattern in self._bias_patterns:
            if pattern.search(content):
                safety_logger.warning(f"Bias detected: {pattern.pattern}")
                return True
        return False

    def _check_age_appropriate(self, content: str) -> bool:
        """Check for age-inappropriate content"""
        for pattern in self._age_inappropriate:
            if pattern.search(content):
                safety_logger.warning(f"Age-inappropriate content detected: {pattern.pattern}")
                return True
        return False

    def _filter_content(self, content: str) -> str:
        """Filter content based on safety level"""
        filtered_content = content

        if self.filter_level == ContentLevel.HIGH:
            # Aggressive filtering
            for pattern in self._harmful_patterns + self._bias_patterns:
                filtered_content = pattern.sub("[FILTERED]", filtered_content)

        elif self.filter_level == ContentLevel.MODERATE:
            # Moderate filtering - only filter obvious harmful content
            for pattern in self._harmful_patterns:
                filtered_content = pattern.sub("[FILTERED]", filtered_content)

        # LOW level doesn't filter content, only detects

        return filtered_content

    def check_agent_safety(self, agent_name: str, system_prompt: str) -> Tuple[bool, str, Optional[str]]:
        """Check agent system prompt for safety"""
        return self.check_safety(system_prompt, "agent_prompt")

    def check_task_safety(self, task: str) -> Tuple[bool, str, Optional[str]]:
        """Check task description for safety"""
        return self.check_safety(task, "task")

    def check_response_safety(self, response: str, agent_name: str) -> Tuple[bool, str, Optional[str]]:
        """Check agent response for safety"""
        return self.check_safety(response, "response")

    def check_config_safety(self, config: Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Optional[str]]:
        """Check configuration for safety"""
        try:
            import json
            config_str = json.dumps(config)

            is_safe, sanitized, error = self.check_safety(config_str, "config")
            if not is_safe:
                return False, {}, error

            # Parse back to dict
            safe_config = json.loads(sanitized)
            return True, safe_config, None

        except Exception as e:
            return False, {}, f"Config safety check error: {str(e)}"

    def get_safety_prompt(self) -> str:
        """Get safety prompt for integration"""
        if self.safety_prompt:
            return SAFETY_PROMPT
        return ""

    def add_harmful_pattern(self, pattern: str, description: str = "") -> None:
        """Add custom harmful content pattern"""
        try:
            compiled_pattern = re.compile(pattern, re.IGNORECASE)
            self._harmful_patterns.append(compiled_pattern)

            safety_logger.info(f"Added harmful pattern: {pattern} ({description})")

        except re.error as e:
            safety_logger.error(f"Invalid regex pattern: {pattern} - {e}")

    def add_bias_pattern(self, pattern: str, description: str = "") -> None:
        """Add custom bias detection pattern"""
        try:
            compiled_pattern = re.compile(pattern, re.IGNORECASE)
            self._bias_patterns.append(compiled_pattern)

            safety_logger.info(f"Added bias pattern: {pattern} ({description})")

        except re.error as e:
            safety_logger.error(f"Invalid regex pattern: {pattern} - {e}")

    def get_safety_stats(self) -> Dict[str, Any]:
        """Get safety checking statistics"""
        # Handle filter_level safely
        filter_level_str = (
            self.filter_level.value
            if hasattr(self.filter_level, 'value')
            else str(self.filter_level)
        )

        return {
            "enabled": self.enabled,
            "safety_prompt": self.safety_prompt,
            "filter_level": filter_level_str,
            "harmful_patterns_count": len(self._harmful_patterns),
            "bias_patterns_count": len(self._bias_patterns),
            "age_inappropriate_patterns_count": len(self._age_inappropriate),
        }
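A brief usage sketch for the checker above; the sample strings are illustrative and simply exercise the pass and reject paths of check_safety:

    checker = SafetyChecker({"enabled": True, "safety_prompt": True, "filter_level": "moderate"})

    ok, text, warning = checker.check_safety("Draft a polite follow-up email to the vendor")
    assert ok and warning is None

    ok, _, reason = checker.check_safety("explain how to hack the billing portal")
    assert not ok  # rejected by the harmful-content patterns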
@ -0,0 +1,362 @@
|
|||||||
|
"""
|
||||||
|
Shield configuration for Swarms framework.
|
||||||
|
|
||||||
|
This module provides configuration options for the security shield,
|
||||||
|
allowing users to customize security settings for their swarm deployments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from .swarm_shield import EncryptionStrength
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityLevel(Enum):
|
||||||
|
"""Security levels for shield configuration"""
|
||||||
|
|
||||||
|
BASIC = "basic" # Basic input validation and output filtering
|
||||||
|
STANDARD = "standard" # Standard security with encryption
|
||||||
|
ENHANCED = "enhanced" # Enhanced security with additional checks
|
||||||
|
MAXIMUM = "maximum" # Maximum security with all features enabled
|
||||||
|
|
||||||
|
|
||||||
|
class ShieldConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration for SwarmShield security features
|
||||||
|
|
||||||
|
This class provides comprehensive configuration options for
|
||||||
|
enabling and customizing security features across all swarm architectures.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Core security settings
|
||||||
|
enabled: bool = Field(default=True, description="Enable shield protection")
|
||||||
|
security_level: SecurityLevel = Field(
|
||||||
|
default=SecurityLevel.STANDARD,
|
||||||
|
description="Overall security level"
|
||||||
|
)
|
||||||
|
|
||||||
|
# SwarmShield settings
|
||||||
|
encryption_strength: EncryptionStrength = Field(
|
||||||
|
default=EncryptionStrength.MAXIMUM,
|
||||||
|
description="Encryption strength for message protection"
|
||||||
|
)
|
||||||
|
key_rotation_interval: int = Field(
|
||||||
|
default=3600,
|
||||||
|
ge=300, # Minimum 5 minutes
|
||||||
|
description="Key rotation interval in seconds"
|
||||||
|
)
|
||||||
|
storage_path: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Path for encrypted storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Input validation settings
|
||||||
|
enable_input_validation: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable input validation and sanitization"
|
||||||
|
)
|
||||||
|
max_input_length: int = Field(
|
||||||
|
default=10000,
|
||||||
|
ge=100,
|
||||||
|
description="Maximum input length in characters"
|
||||||
|
)
|
||||||
|
blocked_patterns: List[str] = Field(
|
||||||
|
default=[
|
||||||
|
r"<script.*?>.*?</script>", # XSS prevention
|
||||||
|
r"javascript:", # JavaScript injection
|
||||||
|
r"data:text/html", # HTML injection
|
||||||
|
r"vbscript:", # VBScript injection
|
||||||
|
r"on\w+\s*=", # Event handler injection
|
||||||
|
],
|
||||||
|
description="Regex patterns to block in inputs"
|
||||||
|
)
|
||||||
|
allowed_domains: List[str] = Field(
|
||||||
|
default=[],
|
||||||
|
description="Allowed domains for external requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output filtering settings
|
||||||
|
enable_output_filtering: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable output filtering and sanitization"
|
||||||
|
)
|
||||||
|
filter_sensitive_data: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Filter sensitive data from outputs"
|
||||||
|
)
|
||||||
|
sensitive_patterns: List[str] = Field(
|
||||||
|
default=[
|
||||||
|
r"\b\d{3}-\d{2}-\d{4}\b", # SSN
|
||||||
|
r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b", # Credit card
|
||||||
|
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", # Email
|
||||||
|
r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b", # IP address
|
||||||
|
],
|
||||||
|
description="Regex patterns for sensitive data"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Safety checking settings
|
||||||
|
enable_safety_checks: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable safety and content filtering"
|
||||||
|
)
|
||||||
|
safety_prompt_on: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable safety prompt integration"
|
||||||
|
)
|
||||||
|
content_filter_level: str = Field(
|
||||||
|
default="moderate",
|
||||||
|
description="Content filtering level (low, moderate, high)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rate limiting settings
|
||||||
|
enable_rate_limiting: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable rate limiting and abuse prevention"
|
||||||
|
)
|
||||||
|
max_requests_per_minute: int = Field(
|
||||||
|
default=60,
|
||||||
|
ge=1,
|
||||||
|
description="Maximum requests per minute per agent"
|
||||||
|
)
|
||||||
|
max_tokens_per_request: int = Field(
|
||||||
|
default=10000,
|
||||||
|
ge=100,
|
||||||
|
description="Maximum tokens per request"
|
||||||
|
)
|
||||||
|
rate_limit_window: int = Field(
|
||||||
|
default=60,
|
||||||
|
ge=10,
|
||||||
|
description="Rate limit window in seconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audit and logging settings
|
||||||
|
enable_audit_logging: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable comprehensive audit logging"
|
||||||
|
)
|
||||||
|
log_security_events: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Log security-related events"
|
||||||
|
)
|
||||||
|
log_input_output: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Log input and output data (use with caution)"
|
||||||
|
)
|
||||||
|
audit_retention_days: int = Field(
|
||||||
|
default=90,
|
||||||
|
ge=1,
|
||||||
|
description="Audit log retention period in days"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance settings
|
||||||
|
enable_caching: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable security result caching"
|
||||||
|
)
|
||||||
|
cache_ttl: int = Field(
|
||||||
|
default=300,
|
||||||
|
ge=60,
|
||||||
|
description="Cache TTL in seconds"
|
||||||
|
)
|
||||||
|
max_cache_size: int = Field(
|
||||||
|
default=1000,
|
||||||
|
ge=100,
|
||||||
|
description="Maximum cache entries"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Integration settings
|
||||||
|
integrate_with_conversation: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Integrate with conversation management"
|
||||||
|
)
|
||||||
|
protect_agent_communications: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Protect inter-agent communications"
|
||||||
|
)
|
||||||
|
encrypt_storage: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Encrypt persistent storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Custom settings
|
||||||
|
custom_rules: Dict[str, Any] = Field(
|
||||||
|
default={},
|
||||||
|
description="Custom security rules and configurations"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compatibility fields for examples
|
||||||
|
encryption_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable encryption (alias for enabled)"
|
||||||
|
)
|
||||||
|
input_validation_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable input validation (alias for enable_input_validation)"
|
||||||
|
)
|
||||||
|
output_filtering_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable output filtering (alias for enable_output_filtering)"
|
||||||
|
)
|
||||||
|
rate_limiting_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable rate limiting (alias for enable_rate_limiting)"
|
||||||
|
)
|
||||||
|
safety_checking_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable safety checking (alias for enable_safety_checks)"
|
||||||
|
)
|
||||||
|
block_suspicious_content: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Block suspicious content patterns"
|
||||||
|
)
|
||||||
|
custom_blocked_patterns: List[str] = Field(
|
||||||
|
default=[],
|
||||||
|
description="Custom patterns to block in addition to default ones"
|
||||||
|
)
|
||||||
|
safety_threshold: float = Field(
|
||||||
|
default=0.7,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Safety threshold for content filtering"
|
||||||
|
)
|
||||||
|
bias_detection_enabled: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Enable bias detection in content"
|
||||||
|
)
|
||||||
|
content_moderation_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable content moderation"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Pydantic configuration"""
|
||||||
|
use_enum_values = True
|
||||||
|
validate_assignment = True
|
||||||
|
|
||||||
|
def get_encryption_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get encryption configuration for SwarmShield"""
|
||||||
|
return {
|
||||||
|
"encryption_strength": self.encryption_strength,
|
||||||
|
"key_rotation_interval": self.key_rotation_interval,
|
||||||
|
"storage_path": self.storage_path,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_validation_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get input validation configuration"""
|
||||||
|
return {
|
||||||
|
"enabled": self.enable_input_validation,
|
||||||
|
"max_length": self.max_input_length,
|
||||||
|
"blocked_patterns": self.blocked_patterns,
|
||||||
|
"allowed_domains": self.allowed_domains,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_filtering_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get output filtering configuration"""
|
||||||
|
return {
|
||||||
|
"enabled": self.enable_output_filtering,
|
||||||
|
"filter_sensitive": self.filter_sensitive_data,
|
||||||
|
"sensitive_patterns": self.sensitive_patterns,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_safety_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get safety checking configuration"""
|
||||||
|
return {
|
||||||
|
"enabled": self.enable_safety_checks,
|
||||||
|
"safety_prompt": self.safety_prompt_on,
|
||||||
|
"filter_level": self.content_filter_level,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_rate_limit_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get rate limiting configuration"""
|
||||||
|
return {
|
||||||
|
"enabled": self.enable_rate_limiting,
|
||||||
|
"max_requests_per_minute": self.max_requests_per_minute,
|
||||||
|
"max_tokens_per_request": self.max_tokens_per_request,
|
||||||
|
"window": self.rate_limit_window,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_audit_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get audit logging configuration"""
|
||||||
|
return {
|
||||||
|
"enabled": self.enable_audit_logging,
|
||||||
|
"log_security": self.log_security_events,
|
||||||
|
"log_io": self.log_input_output,
|
||||||
|
"retention_days": self.audit_retention_days,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_performance_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get performance configuration"""
|
||||||
|
return {
|
||||||
|
"enable_caching": self.enable_caching,
|
||||||
|
"cache_ttl": self.cache_ttl,
|
||||||
|
"max_cache_size": self.max_cache_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_integration_config(self) -> Dict[str, Any]:
|
||||||
|
"""Get integration configuration"""
|
||||||
|
return {
|
||||||
|
"conversation": self.integrate_with_conversation,
|
||||||
|
"agent_communications": self.protect_agent_communications,
|
||||||
|
"encrypt_storage": self.encrypt_storage,
|
||||||
|
}
|
||||||
|
|
||||||
|
    @classmethod
    def create_basic_config(cls) -> "ShieldConfig":
        """Create a basic security configuration"""
        return cls(
            security_level=SecurityLevel.BASIC,
            encryption_strength=EncryptionStrength.STANDARD,
            enable_safety_checks=False,
            enable_rate_limiting=False,
            enable_audit_logging=False,
        )

    @classmethod
    def create_standard_config(cls) -> "ShieldConfig":
        """Create a standard security configuration"""
        return cls(
            security_level=SecurityLevel.STANDARD,
            encryption_strength=EncryptionStrength.ENHANCED,
        )

    @classmethod
    def create_enhanced_config(cls) -> "ShieldConfig":
        """Create an enhanced security configuration"""
        return cls(
            security_level=SecurityLevel.ENHANCED,
            encryption_strength=EncryptionStrength.MAXIMUM,
            max_requests_per_minute=30,
            content_filter_level="high",
            log_input_output=True,
        )

    @classmethod
    def create_maximum_config(cls) -> "ShieldConfig":
        """Create a maximum security configuration"""
        return cls(
            security_level=SecurityLevel.MAXIMUM,
            encryption_strength=EncryptionStrength.MAXIMUM,
            key_rotation_interval=1800,  # 30 minutes
            max_requests_per_minute=20,
            content_filter_level="high",
            log_input_output=True,
            audit_retention_days=365,
            max_cache_size=500,
        )

    @classmethod
    def get_security_level(cls, level: str) -> "ShieldConfig":
        """Get a security configuration for the specified level"""
        level = level.lower()

        if level == "basic":
            return cls.create_basic_config()
        elif level == "standard":
            return cls.create_standard_config()
        elif level == "enhanced":
            return cls.create_enhanced_config()
        elif level == "maximum":
            return cls.create_maximum_config()
        else:
            raise ValueError(
                f"Unknown security level: {level}. "
                "Must be one of: basic, standard, enhanced, maximum"
            )
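The four factory methods above are presets over the same field set, and get_security_level is just a string dispatcher over them. A minimal usage sketch (editor's illustration, not part of the diff; it assumes the new package is importable as swarms.security, as the __init__.py in this PR suggests):

    from swarms.security import ShieldConfig

    # Pick a preset by name; anything else raises ValueError.
    config = ShieldConfig.get_security_level("enhanced")

    # Each component later reads only its own slice of the configuration.
    print(config.get_rate_limit_config())  # enhanced preset: 30 requests/minute
    print(config.get_safety_config())      # enhanced preset: filter_level="high"
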
@@ -0,0 +1,737 @@
"""
SwarmShield integration for Swarms framework.

This module provides enterprise-grade security for swarm communications,
including encryption, conversation management, and audit capabilities.
"""

from __future__ import annotations

import base64
import hashlib
import hmac
import json
import os
import secrets
import threading
import time
import uuid
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import (
    Cipher,
    algorithms,
    modes,
)
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from loguru import logger

from swarms.utils.loguru_logger import initialize_logger

# Initialize logger for security module
security_logger = initialize_logger(log_folder="security")


class EncryptionStrength(Enum):
    """Encryption strength levels for SwarmShield"""

    STANDARD = "standard"  # AES-256
    ENHANCED = "enhanced"  # AES-256 + SHA-512
    MAXIMUM = "maximum"  # AES-256 + SHA-512 + HMAC


class SwarmShield:
    """
    SwarmShield: Advanced security system for swarm agents

    Features:
    - Multi-layer message encryption
    - Secure conversation storage
    - Automatic key rotation
    - Message integrity verification
    - Integration with Swarms framework
    """

    def __init__(
        self,
        encryption_strength: EncryptionStrength = EncryptionStrength.MAXIMUM,
        key_rotation_interval: int = 3600,  # 1 hour
        storage_path: Optional[str] = None,
        enable_logging: bool = True,
    ):
        """
        Initialize SwarmShield with security settings

        Args:
            encryption_strength: Level of encryption to use
            key_rotation_interval: Key rotation interval in seconds
            storage_path: Path for encrypted storage
            enable_logging: Enable security logging
        """
        self.encryption_strength = encryption_strength
        self.key_rotation_interval = key_rotation_interval
        self.enable_logging = enable_logging

        # Set storage path within swarms framework
        if storage_path:
            self.storage_path = Path(storage_path)
        else:
            self.storage_path = Path("swarms_security_storage")

        # Initialize storage and locks
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self._conv_lock = threading.Lock()
        self._conversations: Dict[str, Dict] = {}

        # Initialize security components
        self._initialize_security()
        self._load_conversations()

        if self.enable_logging:
            # Handle encryption_strength safely (could be enum or string)
            encryption_str = (
                encryption_strength.value
                if hasattr(encryption_strength, 'value')
                else str(encryption_strength)
            )
            security_logger.info(
                f"SwarmShield initialized with {encryption_str} encryption"
            )

    def _initialize_security(self) -> None:
        """Set up encryption keys and components"""
        try:
            # Generate master key and salt
            self.master_key = secrets.token_bytes(32)
            self.salt = os.urandom(16)

            # Generate initial keys (the PBKDF2 derivation itself lives in
            # _rotate_keys, because PBKDF2HMAC instances are single-use)
            self._rotate_keys()
            self.hmac_key = secrets.token_bytes(32)

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Security initialization failed: {e}")
            raise

    def _rotate_keys(self) -> None:
        """Perform security key rotation"""
        try:
            # A PBKDF2HMAC object can only derive once, so build a fresh KDF
            # for every rotation instead of reusing a stored instance.
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA512(),
                length=32,
                salt=self.salt,
                iterations=600000,
                backend=default_backend(),
            )
            self.encryption_key = kdf.derive(self.master_key)
            self.iv = os.urandom(16)
            self.last_rotation = time.time()
            if self.enable_logging:
                security_logger.debug("Security keys rotated successfully")
        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Key rotation failed: {e}")
            raise

    def _check_rotation(self) -> None:
        """Check and perform key rotation if needed"""
        if (
            time.time() - self.last_rotation
            >= self.key_rotation_interval
        ):
            self._rotate_keys()
    def _save_conversation(self, conversation_id: str) -> None:
        """Save conversation to encrypted storage"""
        try:
            if conversation_id not in self._conversations:
                return

            # Encrypt conversation data
            json_data = json.dumps(
                self._conversations[conversation_id]
            ).encode()
            cipher = Cipher(
                algorithms.AES(self.encryption_key),
                modes.GCM(self.iv),
                backend=default_backend(),
            )
            encryptor = cipher.encryptor()
            encrypted_data = (
                encryptor.update(json_data) + encryptor.finalize()
            )

            # Combine encrypted data with authentication tag
            combined_data = encrypted_data + encryptor.tag

            # Save atomically using temporary file
            conv_path = self.storage_path / f"{conversation_id}.conv"
            temp_path = conv_path.with_suffix(".tmp")

            with open(temp_path, "wb") as f:
                f.write(combined_data)
            temp_path.replace(conv_path)

            if self.enable_logging:
                security_logger.debug(f"Saved conversation {conversation_id}")

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to save conversation: {e}")
            raise

    def _load_conversations(self) -> None:
        """Load existing conversations from storage"""
        try:
            for file_path in self.storage_path.glob("*.conv"):
                try:
                    with open(file_path, "rb") as f:
                        combined_data = f.read()
                    conversation_id = file_path.stem

                    # Split combined data into encrypted data and authentication tag
                    if len(combined_data) < 16:  # Minimum size for GCM tag
                        continue

                    encrypted_data = combined_data[:-16]
                    auth_tag = combined_data[-16:]

                    # Decrypt conversation data
                    cipher = Cipher(
                        algorithms.AES(self.encryption_key),
                        modes.GCM(self.iv, auth_tag),
                        backend=default_backend(),
                    )
                    decryptor = cipher.decryptor()
                    json_data = (
                        decryptor.update(encrypted_data)
                        + decryptor.finalize()
                    )

                    self._conversations[conversation_id] = json.loads(
                        json_data
                    )
                    if self.enable_logging:
                        security_logger.debug(
                            f"Loaded conversation {conversation_id}"
                        )

                except Exception as e:
                    if self.enable_logging:
                        security_logger.error(
                            f"Failed to load conversation {file_path}: {e}"
                        )
                    continue

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to load conversations: {e}")
            raise

    def protect_message(self, agent_name: str, message: str) -> str:
        """
        Encrypt a message with multiple security layers

        Args:
            agent_name: Name of the sending agent
            message: Message to encrypt

        Returns:
            Encrypted message string
        """
        try:
            self._check_rotation()

            # Validate inputs
            if not isinstance(agent_name, str) or not isinstance(
                message, str
            ):
                raise ValueError(
                    "Agent name and message must be strings"
                )
            if not agent_name.strip() or not message.strip():
                raise ValueError(
                    "Agent name and message cannot be empty"
                )

            # Generate message ID and timestamp
            message_id = secrets.token_hex(16)
            timestamp = datetime.now(timezone.utc).isoformat()

            # Encrypt message content
            message_bytes = message.encode()
            cipher = Cipher(
                algorithms.AES(self.encryption_key),
                modes.GCM(self.iv),
                backend=default_backend(),
            )
            encryptor = cipher.encryptor()
            ciphertext = (
                encryptor.update(message_bytes) + encryptor.finalize()
            )

            # Calculate message hash
            message_hash = hashlib.sha512(message_bytes).hexdigest()

            # Generate HMAC if maximum security
            hmac_signature = None
            if self.encryption_strength == EncryptionStrength.MAXIMUM:
                h = hmac.new(
                    self.hmac_key, ciphertext, hashlib.sha512
                )
                hmac_signature = h.digest()

            # Create secure package
            secure_package = {
                "id": message_id,
                "time": timestamp,
                "agent": agent_name,
                "cipher": base64.b64encode(ciphertext).decode(),
                "tag": base64.b64encode(encryptor.tag).decode(),
                "hash": message_hash,
                "hmac": (
                    base64.b64encode(hmac_signature).decode()
                    if hmac_signature
                    else None
                ),
            }

            return base64.b64encode(
                json.dumps(secure_package).encode()
            ).decode()

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to protect message: {e}")
            raise

    def retrieve_message(self, encrypted_str: str) -> Tuple[str, str]:
        """
        Decrypt and verify a message

        Args:
            encrypted_str: Encrypted message string

        Returns:
            Tuple of (agent_name, message)
        """
        try:
            # Decode secure package
            secure_package = json.loads(
                base64.b64decode(encrypted_str)
            )

            # Get components
            ciphertext = base64.b64decode(secure_package["cipher"])
            tag = base64.b64decode(secure_package["tag"])

            # Verify HMAC if present
            if secure_package["hmac"]:
                hmac_signature = base64.b64decode(
                    secure_package["hmac"]
                )
                h = hmac.new(
                    self.hmac_key, ciphertext, hashlib.sha512
                )
                if not hmac.compare_digest(
                    hmac_signature, h.digest()
                ):
                    raise ValueError("HMAC verification failed")

            # Decrypt message
            cipher = Cipher(
                algorithms.AES(self.encryption_key),
                modes.GCM(self.iv, tag),
                backend=default_backend(),
            )
            decryptor = cipher.decryptor()
            decrypted_data = (
                decryptor.update(ciphertext) + decryptor.finalize()
            )

            # Verify hash
            if (
                hashlib.sha512(decrypted_data).hexdigest()
                != secure_package["hash"]
            ):
                raise ValueError("Message hash verification failed")

            return secure_package["agent"], decrypted_data.decode()

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to retrieve message: {e}")
            raise
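
    # --- Editor's illustrative sketch (not part of this file) ---------------
    # protect_message() and retrieve_message() form a symmetric pair: the
    # output of one is the only valid input of the other, and both must run
    # against the same SwarmShield instance, since keys live only in memory.
    # A minimal round trip, assuming default settings:
    #
    #     shield = SwarmShield()
    #     token = shield.protect_message("agent-1", "rendezvous at 14:00")
    #     agent, text = shield.retrieve_message(token)
    #     assert (agent, text) == ("agent-1", "rendezvous at 14:00")
    #
    # Note that a key rotation between the two calls would make the token
    # undecryptable, because retrieve_message() always uses the current
    # key and IV.
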
    def create_conversation(self, name: str = "") -> str:
        """Create a new secure conversation"""
        conversation_id = str(uuid.uuid4())
        with self._conv_lock:
            self._conversations[conversation_id] = {
                "id": conversation_id,
                "name": name,
                "created_at": datetime.now(timezone.utc).isoformat(),
                "messages": [],
            }
        self._save_conversation(conversation_id)
        if self.enable_logging:
            security_logger.info(f"Created conversation {conversation_id}")
        return conversation_id

    def add_message(
        self, conversation_id: str, agent_name: str, message: str
    ) -> None:
        """
        Add an encrypted message to a conversation

        Args:
            conversation_id: Target conversation ID
            agent_name: Name of the sending agent
            message: Message content
        """
        try:
            # Encrypt message
            encrypted = self.protect_message(agent_name, message)

            # Add to conversation
            with self._conv_lock:
                if conversation_id not in self._conversations:
                    raise ValueError(
                        f"Invalid conversation ID: {conversation_id}"
                    )

                self._conversations[conversation_id][
                    "messages"
                ].append(
                    {
                        "timestamp": datetime.now(
                            timezone.utc
                        ).isoformat(),
                        "data": encrypted,
                    }
                )

            # Save changes
            self._save_conversation(conversation_id)

            if self.enable_logging:
                security_logger.info(
                    f"Added message to conversation {conversation_id}"
                )

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to add message: {e}")
            raise

    def get_messages(
        self, conversation_id: str
    ) -> List[Tuple[str, str, datetime]]:
        """
        Get decrypted messages from a conversation

        Args:
            conversation_id: Target conversation ID

        Returns:
            List of (agent_name, message, timestamp) tuples
        """
        try:
            with self._conv_lock:
                if conversation_id not in self._conversations:
                    raise ValueError(
                        f"Invalid conversation ID: {conversation_id}"
                    )

                history = []
                for msg in self._conversations[conversation_id][
                    "messages"
                ]:
                    agent_name, message = self.retrieve_message(
                        msg["data"]
                    )
                    timestamp = datetime.fromisoformat(
                        msg["timestamp"]
                    )
                    history.append((agent_name, message, timestamp))

                return history

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to get messages: {e}")
            raise

    def get_conversation_summary(self, conversation_id: str) -> Dict:
        """
        Get summary statistics for a conversation

        Args:
            conversation_id: Target conversation ID

        Returns:
            Dictionary with conversation statistics
        """
        try:
            with self._conv_lock:
                if conversation_id not in self._conversations:
                    raise ValueError(
                        f"Invalid conversation ID: {conversation_id}"
                    )

                conv = self._conversations[conversation_id]
                messages = conv["messages"]

                # Get unique agents
                agents = set()
                for msg in messages:
                    agent_name, _ = self.retrieve_message(msg["data"])
                    agents.add(agent_name)

                return {
                    "id": conversation_id,
                    "name": conv["name"],
                    "created_at": conv["created_at"],
                    "message_count": len(messages),
                    "agents": list(agents),
                    "last_message": (
                        messages[-1]["timestamp"] if messages else None
                    ),
                }

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to get summary: {e}")
            raise

    def delete_conversation(self, conversation_id: str) -> None:
        """
        Delete a conversation and its storage

        Args:
            conversation_id: Target conversation ID
        """
        try:
            with self._conv_lock:
                if conversation_id not in self._conversations:
                    raise ValueError(
                        f"Invalid conversation ID: {conversation_id}"
                    )

                # Remove from memory
                del self._conversations[conversation_id]

                # Remove from storage
                conv_path = self.storage_path / f"{conversation_id}.conv"
                if conv_path.exists():
                    conv_path.unlink()

            if self.enable_logging:
                security_logger.info(f"Deleted conversation {conversation_id}")

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to delete conversation: {e}")
            raise
    def export_conversation(
        self, conversation_id: str, format: str = "json", path: str = None
    ) -> str:
        """
        Export a conversation to a file

        Args:
            conversation_id: Target conversation ID
            format: Export format (json, txt)
            path: Output file path

        Returns:
            Path to exported file
        """
        try:
            messages = self.get_messages(conversation_id)
            summary = self.get_conversation_summary(conversation_id)

            if not path:
                path = f"conversation_{conversation_id}.{format}"

            if format.lower() == "json":
                export_data = {
                    "summary": summary,
                    "messages": [
                        {
                            "agent": agent,
                            "message": message,
                            "timestamp": timestamp.isoformat(),
                        }
                        for agent, message, timestamp in messages
                    ],
                }
                with open(path, "w") as f:
                    json.dump(export_data, f, indent=2)

            elif format.lower() == "txt":
                with open(path, "w") as f:
                    f.write(f"Conversation: {summary['name']}\n")
                    f.write(f"Created: {summary['created_at']}\n")
                    f.write(f"Messages: {summary['message_count']}\n")
                    f.write(f"Agents: {', '.join(summary['agents'])}\n\n")

                    for agent, message, timestamp in messages:
                        f.write(f"[{timestamp}] {agent}: {message}\n")

            else:
                raise ValueError(f"Unsupported format: {format}")

            if self.enable_logging:
                security_logger.info(f"Exported conversation to {path}")

            return path

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to export conversation: {e}")
            raise

    def backup_conversations(self, backup_path: str = None) -> str:
        """
        Create a backup of all conversations

        Args:
            backup_path: Backup directory path

        Returns:
            Path to backup directory
        """
        try:
            if not backup_path:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                backup_path = f"swarm_shield_backup_{timestamp}"

            backup_dir = Path(backup_path)
            backup_dir.mkdir(parents=True, exist_ok=True)

            # Export all conversations
            for conversation_id in self._conversations:
                self.export_conversation(
                    conversation_id,
                    format="json",
                    path=str(backup_dir / f"{conversation_id}.json"),
                )

            if self.enable_logging:
                security_logger.info(f"Backup created at {backup_path}")

            return backup_path

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to create backup: {e}")
            raise

    def get_agent_stats(self, agent_name: str) -> Dict:
        """
        Get statistics for a specific agent

        Args:
            agent_name: Name of the agent

        Returns:
            Dictionary with agent statistics
        """
        try:
            total_messages = 0
            conversations = set()

            for conversation_id, conv in self._conversations.items():
                for msg in conv["messages"]:
                    msg_agent, _ = self.retrieve_message(msg["data"])
                    if msg_agent == agent_name:
                        total_messages += 1
                        conversations.add(conversation_id)

            return {
                "agent_name": agent_name,
                "total_messages": total_messages,
                "conversations": len(conversations),
                "conversation_ids": list(conversations),
            }

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to get agent stats: {e}")
            raise

    def query_conversations(
        self,
        agent_name: str = None,
        text: str = None,
        start_date: datetime = None,
        end_date: datetime = None,
        limit: int = 100,
    ) -> List[Dict]:
        """
        Search conversations with filters

        Args:
            agent_name: Filter by agent name
            text: Search for text in messages
            start_date: Filter by start date
            end_date: Filter by end date
            limit: Maximum results to return

        Returns:
            List of matching conversation summaries
        """
        try:
            results = []

            for conversation_id, conv in self._conversations.items():
                # Check date range
                conv_date = datetime.fromisoformat(conv["created_at"])
                if start_date and conv_date < start_date:
                    continue
                if end_date and conv_date > end_date:
                    continue

                # Check agent filter
                if agent_name:
                    conv_agents = set()
                    for msg in conv["messages"]:
                        msg_agent, _ = self.retrieve_message(msg["data"])
                        conv_agents.add(msg_agent)
                    if agent_name not in conv_agents:
                        continue

                # Check text filter
                if text:
                    text_found = False
                    for msg in conv["messages"]:
                        _, message = self.retrieve_message(msg["data"])
                        if text.lower() in message.lower():
                            text_found = True
                            break
                    if not text_found:
                        continue

                # Add to results
                summary = self.get_conversation_summary(conversation_id)
                results.append(summary)

                if len(results) >= limit:
                    break

            return results

        except Exception as e:
            if self.enable_logging:
                security_logger.error(f"Failed to query conversations: {e}")
            raise
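Taken together, the conversation APIs above give a full encrypt, store, read, export cycle. A short usage sketch (editor's illustration, not part of the diff; everything is keyed to a single in-memory SwarmShield instance, and the import path assumes the swarms.security package added in this PR):

    from swarms.security import SwarmShield, EncryptionStrength

    shield = SwarmShield(encryption_strength=EncryptionStrength.MAXIMUM)

    # Every message is encrypted by protect_message() before it is stored.
    conv_id = shield.create_conversation(name="mission-briefing")
    shield.add_message(conv_id, "planner", "Survey the target area first.")
    shield.add_message(conv_id, "scout", "Acknowledged, moving out.")

    # Reads decrypt on the fly and yield (agent, message, timestamp) tuples.
    for agent, message, ts in shield.get_messages(conv_id):
        print(f"{ts:%H:%M} {agent}: {message}")

    # Plaintext export for auditing; "json" and "txt" are the two formats.
    shield.export_conversation(conv_id, format="json", path="briefing.json")
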
@@ -0,0 +1,318 @@
"""
SwarmShield integration for Swarms framework.

This module provides the main integration point for all security features,
combining input validation, output filtering, safety checking, rate limiting,
and encryption into a unified security shield.
"""

from typing import Dict, Any, Optional, Tuple, Union, List
from datetime import datetime

from loguru import logger

from swarms.utils.loguru_logger import initialize_logger
from .shield_config import ShieldConfig
from .swarm_shield import SwarmShield, EncryptionStrength
from .input_validator import InputValidator
from .output_filter import OutputFilter
from .safety_checker import SafetyChecker
from .rate_limiter import RateLimiter

# Initialize logger for shield integration
shield_logger = initialize_logger(log_folder="shield_integration")


class SwarmShieldIntegration:
    """
    Main SwarmShield integration class

    This class provides a unified interface for all security features:
    - Input validation and sanitization
    - Output filtering and sensitive data protection
    - Safety checking and ethical AI compliance
    - Rate limiting and abuse prevention
    - Encrypted communication and storage
    """

    def __init__(self, config: Optional[ShieldConfig] = None):
        """
        Initialize SwarmShield integration

        Args:
            config: Shield configuration. If None, uses default configuration.
        """
        self.config = config or ShieldConfig()

        # Initialize security components
        self._initialize_components()

        # Handle security_level safely (could be enum or string)
        security_level_str = (
            self.config.security_level.value
            if hasattr(self.config.security_level, 'value')
            else str(self.config.security_level)
        )
        shield_logger.info(
            f"SwarmShield integration initialized with {security_level_str} security"
        )

    def _initialize_components(self) -> None:
        """Initialize all security components"""
        try:
            # Initialize SwarmShield for encryption
            if self.config.integrate_with_conversation:
                self.swarm_shield = SwarmShield(
                    **self.config.get_encryption_config(),
                    enable_logging=self.config.enable_audit_logging,
                )
            else:
                self.swarm_shield = None

            # Initialize input validator
            self.input_validator = InputValidator(
                self.config.get_validation_config()
            )

            # Initialize output filter
            self.output_filter = OutputFilter(
                self.config.get_filtering_config()
            )

            # Initialize safety checker
            self.safety_checker = SafetyChecker(
                self.config.get_safety_config()
            )

            # Initialize rate limiter
            self.rate_limiter = RateLimiter(
                self.config.get_rate_limit_config()
            )

        except Exception as e:
            shield_logger.error(f"Failed to initialize security components: {e}")
            raise

    def validate_and_protect_input(
        self, input_data: str, agent_name: str, input_type: str = "text"
    ) -> Tuple[bool, str, Optional[str]]:
        """
        Validate and protect input data

        Args:
            input_data: Input data to validate
            agent_name: Name of the agent
            input_type: Type of input

        Returns:
            Tuple of (is_valid, protected_data, error_message)
        """
        try:
            # Check rate limits
            is_allowed, error = self.rate_limiter.check_agent_limit(agent_name)
            if not is_allowed:
                return False, "", error

            # Validate input
            is_valid, sanitized, error = self.input_validator.validate_input(
                input_data, input_type
            )
            if not is_valid:
                return False, "", error

            # Check safety
            is_safe, safe_content, error = self.safety_checker.check_safety(
                sanitized, input_type
            )
            if not is_safe:
                return False, "", error

            # Protect with encryption if enabled
            if self.swarm_shield and self.config.protect_agent_communications:
                try:
                    protected = self.swarm_shield.protect_message(
                        agent_name, safe_content
                    )
                    return True, protected, None
                except Exception as e:
                    shield_logger.error(f"Encryption failed: {e}")
                    return False, "", f"Encryption error: {str(e)}"

            return True, safe_content, None

        except Exception as e:
            shield_logger.error(f"Input validation error: {e}")
            return False, "", f"Validation error: {str(e)}"
    def filter_and_protect_output(
        self,
        output_data: Union[str, Dict, List],
        agent_name: str,
        output_type: str = "text",
    ) -> Tuple[bool, Union[str, Dict, List], Optional[str]]:
        """
        Filter and protect output data

        Args:
            output_data: Output data to filter
            agent_name: Name of the agent
            output_type: Type of output

        Returns:
            Tuple of (is_safe, filtered_data, warning_message)
        """
        try:
            # Filter output
            is_safe, filtered, warning = self.output_filter.filter_output(
                output_data, output_type
            )
            if not is_safe:
                return False, "", "Output contains unsafe content"

            # Check safety
            if isinstance(filtered, str):
                is_safe, safe_content, error = self.safety_checker.check_safety(
                    filtered, output_type
                )
                if not is_safe:
                    return False, "", error
                filtered = safe_content

            # Track request
            self.rate_limiter.track_request(agent_name, f"output_{output_type}")

            return True, filtered, warning

        except Exception as e:
            shield_logger.error(f"Output filtering error: {e}")
            return False, "", f"Filtering error: {str(e)}"

    def process_agent_communication(
        self, agent_name: str, message: str, direction: str = "outbound"
    ) -> Tuple[bool, str, Optional[str]]:
        """
        Process agent communication (inbound or outbound)

        Args:
            agent_name: Name of the agent
            message: Message content
            direction: "inbound" or "outbound"

        Returns:
            Tuple of (is_valid, processed_message, error_message)
        """
        try:
            if direction == "inbound":
                return self.validate_and_protect_input(message, agent_name, "text")
            elif direction == "outbound":
                return self.filter_and_protect_output(message, agent_name, "text")
            else:
                return False, "", f"Invalid direction: {direction}"

        except Exception as e:
            shield_logger.error(f"Communication processing error: {e}")
            return False, "", f"Processing error: {str(e)}"

    def validate_task(self, task: str, agent_name: str) -> Tuple[bool, str, Optional[str]]:
        """Validate swarm task"""
        return self.validate_and_protect_input(task, agent_name, "task")

    def validate_agent_config(
        self, config: Dict[str, Any], agent_name: str
    ) -> Tuple[bool, Dict[str, Any], Optional[str]]:
        """Validate agent configuration"""
        try:
            # Validate config structure
            is_valid, validated_config, error = self.input_validator.validate_config(
                config
            )
            if not is_valid:
                return False, {}, error

            # Check safety
            is_safe, safe_config, error = self.safety_checker.check_config_safety(
                validated_config
            )
            if not is_safe:
                return False, {}, error

            # Filter sensitive data
            is_safe, filtered_config, warning = self.output_filter.filter_config_output(
                safe_config
            )
            if not is_safe:
                return False, {}, "Configuration contains unsafe content"

            return True, filtered_config, warning

        except Exception as e:
            shield_logger.error(f"Config validation error: {e}")
            return False, {}, f"Config validation error: {str(e)}"

    def create_secure_conversation(self, name: str = "") -> Optional[str]:
        """Create a secure conversation if encryption is enabled"""
        if self.swarm_shield and self.config.integrate_with_conversation:
            try:
                return self.swarm_shield.create_conversation(name)
            except Exception as e:
                shield_logger.error(f"Failed to create secure conversation: {e}")
                return None
        return None

    def add_secure_message(
        self, conversation_id: str, agent_name: str, message: str
    ) -> bool:
        """Add a message to secure conversation"""
        if self.swarm_shield and self.config.integrate_with_conversation:
            try:
                self.swarm_shield.add_message(conversation_id, agent_name, message)
                return True
            except Exception as e:
                shield_logger.error(f"Failed to add secure message: {e}")
                return False
        return False

    def get_secure_messages(
        self, conversation_id: str
    ) -> List[Tuple[str, str, datetime]]:
        """Get messages from secure conversation"""
        if self.swarm_shield and self.config.integrate_with_conversation:
            try:
                return self.swarm_shield.get_messages(conversation_id)
            except Exception as e:
                shield_logger.error(f"Failed to get secure messages: {e}")
                return []
        return []

    def check_rate_limit(
        self, agent_name: str, request_size: int = 1
    ) -> Tuple[bool, Optional[str]]:
        """Check rate limits for an agent"""
        return self.rate_limiter.check_rate_limit(agent_name, request_size)

    def get_security_stats(self) -> Dict[str, Any]:
        """Get comprehensive security statistics"""
        stats = {
            "security_enabled": self.config.enabled,
            "input_validations": getattr(self.input_validator, 'validation_count', 0),
            "rate_limit_checks": getattr(self.rate_limiter, 'check_count', 0),
            "blocked_requests": getattr(self.rate_limiter, 'blocked_count', 0),
            "filtered_outputs": getattr(self.output_filter, 'filter_count', 0),
            "violations": getattr(self.safety_checker, 'violation_count', 0),
            "encryption_enabled": self.swarm_shield is not None,
        }

        # Handle security_level safely
        if hasattr(self.config.security_level, 'value'):
            stats["security_level"] = self.config.security_level.value
        else:
            stats["security_level"] = str(self.config.security_level)

        # Handle encryption_strength safely
        if self.swarm_shield and hasattr(self.swarm_shield.encryption_strength, 'value'):
            stats["encryption_strength"] = self.swarm_shield.encryption_strength.value
        elif self.swarm_shield:
            stats["encryption_strength"] = str(self.swarm_shield.encryption_strength)
        else:
            stats["encryption_strength"] = "none"

        return stats

    def update_config(self, new_config: ShieldConfig) -> bool:
        """Update shield configuration"""
        try:
            self.config = new_config
            self._initialize_components()
            shield_logger.info("Shield configuration updated")
            return True
        except Exception as e:
            shield_logger.error(f"Failed to update configuration: {e}")
            return False

    def enable_security(self) -> None:
        """Enable all security features"""
        self.config.enabled = True
        shield_logger.info("Security features enabled")

    def disable_security(self) -> None:
        """Disable all security features"""
        self.config.enabled = False
        shield_logger.info("Security features disabled")

    def cleanup(self) -> None:
        """Cleanup resources"""
        try:
            if hasattr(self, 'rate_limiter'):
                self.rate_limiter.stop()
            shield_logger.info("SwarmShield integration cleaned up")
        except Exception as e:
            shield_logger.error(f"Cleanup error: {e}")

    def __del__(self):
        """Cleanup on destruction"""
        self.cleanup()
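As a rough end-to-end sketch of how the integration class is meant to be driven (editor's illustration, not part of the diff; it assumes the package is importable as swarms.security and relies on the validator, filter, safety checker, and rate limiter added elsewhere in this PR):

    from swarms.security import ShieldConfig, SwarmShieldIntegration

    shield = SwarmShieldIntegration(ShieldConfig.get_security_level("standard"))

    # Inbound: rate limit -> validate -> safety check -> optional encryption.
    ok, protected, err = shield.validate_and_protect_input(
        "Summarize the quarterly report", agent_name="analyst-1"
    )
    if not ok:
        raise RuntimeError(f"Input rejected: {err}")

    # Outbound: filter sensitive data -> safety check -> track the request.
    ok, filtered, warning = shield.filter_and_protect_output(
        "Revenue grew 12% quarter over quarter", agent_name="analyst-1"
    )

    print(shield.get_security_stats())
    shield.cleanup()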