Add files via upload

pull/1034/head
CI-DEV committed 2 months ago via GitHub
parent b03379280d
commit df31f6f2db

@@ -0,0 +1,29 @@
"""
Security module for Swarms framework.
This module provides enterprise-grade security features including:
- SwarmShield integration for encrypted communications
- Input validation and sanitization
- Output filtering and safety checks
- Rate limiting and abuse prevention
- Audit logging and compliance features
"""
from .swarm_shield import SwarmShield, EncryptionStrength
from .shield_config import ShieldConfig
from .input_validator import InputValidator
from .output_filter import OutputFilter
from .safety_checker import SafetyChecker
from .rate_limiter import RateLimiter
from .swarm_shield_integration import SwarmShieldIntegration
__all__ = [
"SwarmShield",
"EncryptionStrength",
"ShieldConfig",
"InputValidator",
"OutputFilter",
"SafetyChecker",
"RateLimiter",
"SwarmShieldIntegration",
]
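A minimal usage sketch of the exports above, wiring the components from a single ShieldConfig. The swarms.security import path is an assumption; the actual package location is not shown in this diff.

# Hedged sketch: the package path "swarms.security" is assumed, not confirmed by this diff.
from swarms.security import (
    ShieldConfig,
    InputValidator,
    OutputFilter,
    SafetyChecker,
    RateLimiter,
)

config = ShieldConfig.create_standard_config()
validator = InputValidator(config.get_validation_config())
output_filter = OutputFilter(config.get_filtering_config())
safety_checker = SafetyChecker(config.get_safety_config())
rate_limiter = RateLimiter(config.get_rate_limit_config())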

@@ -0,0 +1,259 @@
"""
Input validation and sanitization for Swarms framework.
This module provides comprehensive input validation, sanitization,
and security checks for all swarm inputs.
"""
import re
import html
from typing import List, Optional, Dict, Any, Tuple
from urllib.parse import urlparse
from datetime import datetime
from loguru import logger
from swarms.utils.loguru_logger import initialize_logger
# Initialize logger for input validation
validation_logger = initialize_logger(log_folder="input_validation")
class InputValidator:
"""
Input validation and sanitization for swarm security
Features:
- Input length validation
- Pattern-based blocking
- XSS prevention
- SQL injection prevention
- URL validation
- Content type validation
"""
def __init__(self, config: Dict[str, Any]):
"""
Initialize input validator with configuration
Args:
config: Validation configuration dictionary
"""
self.enabled = config.get("enabled", True)
self.max_length = config.get("max_length", 10000)
self.blocked_patterns = config.get("blocked_patterns", [])
self.allowed_domains = config.get("allowed_domains", [])
# Compile regex patterns for performance
self._compiled_patterns = [
re.compile(pattern, re.IGNORECASE) for pattern in self.blocked_patterns
]
# Common malicious patterns
self._malicious_patterns = [
re.compile(r"<script.*?>.*?</script>", re.IGNORECASE),
re.compile(r"javascript:", re.IGNORECASE),
re.compile(r"data:text/html", re.IGNORECASE),
re.compile(r"vbscript:", re.IGNORECASE),
re.compile(r"on\w+\s*=", re.IGNORECASE),
re.compile(r"<iframe.*?>.*?</iframe>", re.IGNORECASE),
re.compile(r"<object.*?>.*?</object>", re.IGNORECASE),
re.compile(r"<embed.*?>", re.IGNORECASE),
re.compile(r"<link.*?>", re.IGNORECASE),
re.compile(r"<meta.*?>", re.IGNORECASE),
]
# SQL injection patterns
self._sql_patterns = [
re.compile(r"(\b(union|select|insert|update|delete|drop|create|alter)\b)", re.IGNORECASE),
re.compile(r"(--|#|/\*|\*/)", re.IGNORECASE),
re.compile(r"(\b(exec|execute|xp_|sp_)\b)", re.IGNORECASE),
]
validation_logger.info("InputValidator initialized")
def validate_input(self, input_data: str, input_type: str = "text") -> Tuple[bool, str, Optional[str]]:
"""
Validate and sanitize input data
Args:
input_data: Input data to validate
input_type: Type of input (text, url, code, etc.)
Returns:
Tuple of (is_valid, sanitized_data, error_message)
"""
if not self.enabled:
return True, input_data, None
try:
# Basic type validation
if not isinstance(input_data, str):
return False, "", "Input must be a string"
# Length validation
if len(input_data) > self.max_length:
return False, "", f"Input exceeds maximum length of {self.max_length} characters"
# Empty input check
if not input_data.strip():
return False, "", "Input cannot be empty"
# Sanitize the input
sanitized = self._sanitize_input(input_data)
# Check for blocked patterns
if self._check_blocked_patterns(sanitized):
return False, "", "Input contains blocked patterns"
# Check for malicious patterns
if self._check_malicious_patterns(sanitized):
return False, "", "Input contains potentially malicious content"
# Type-specific validation
if input_type == "url":
if not self._validate_url(sanitized):
return False, "", "Invalid URL format"
elif input_type == "code":
if not self._validate_code(sanitized):
return False, "", "Invalid code content"
elif input_type == "json":
if not self._validate_json(sanitized):
return False, "", "Invalid JSON format"
validation_logger.debug(f"Input validation passed for type: {input_type}")
return True, sanitized, None
except Exception as e:
validation_logger.error(f"Input validation error: {e}")
return False, "", f"Validation error: {str(e)}"
def _sanitize_input(self, input_data: str) -> str:
"""Sanitize input data to prevent XSS and other attacks"""
# HTML escape
sanitized = html.escape(input_data)
# Remove null bytes
sanitized = sanitized.replace('\x00', '')
# Normalize whitespace
sanitized = ' '.join(sanitized.split())
return sanitized
def _check_blocked_patterns(self, input_data: str) -> bool:
"""Check if input contains blocked patterns"""
for pattern in self._compiled_patterns:
if pattern.search(input_data):
validation_logger.warning(f"Blocked pattern detected: {pattern.pattern}")
return True
return False
def _check_malicious_patterns(self, input_data: str) -> bool:
"""Check if input contains malicious patterns"""
for pattern in self._malicious_patterns:
if pattern.search(input_data):
validation_logger.warning(f"Malicious pattern detected: {pattern.pattern}")
return True
return False
def _validate_url(self, url: str) -> bool:
"""Validate URL format and domain"""
try:
parsed = urlparse(url)
# Check if it's a valid URL
if not all([parsed.scheme, parsed.netloc]):
return False
# Check allowed domains if specified
if self.allowed_domains:
domain = parsed.netloc.lower()
if not any(allowed in domain for allowed in self.allowed_domains):
validation_logger.warning(f"Domain not allowed: {domain}")
return False
return True
except Exception:
return False
def _validate_code(self, code: str) -> bool:
"""Validate code content for safety"""
# Check for SQL injection patterns
for pattern in self._sql_patterns:
if pattern.search(code):
validation_logger.warning(f"SQL injection pattern detected: {pattern.pattern}")
return False
# Check for dangerous system calls
dangerous_calls = [
'os.system', 'subprocess.call', 'eval(', 'exec(',
'__import__', 'globals()', 'locals()'
]
for call in dangerous_calls:
if call in code:
validation_logger.warning(f"Dangerous call detected: {call}")
return False
return True
def _validate_json(self, json_str: str) -> bool:
"""Validate JSON format"""
try:
import json
json.loads(json_str)
return True
except (json.JSONDecodeError, ValueError):
return False
def validate_task(self, task: str) -> Tuple[bool, str, Optional[str]]:
"""Validate swarm task input"""
return self.validate_input(task, "text")
def validate_agent_name(self, agent_name: str) -> Tuple[bool, str, Optional[str]]:
"""Validate agent name input"""
# Additional validation for agent names
if not re.match(r'^[a-zA-Z0-9_-]+$', agent_name):
return False, "", "Agent name can only contain letters, numbers, underscores, and hyphens"
if len(agent_name) < 1 or len(agent_name) > 50:
return False, "", "Agent name must be between 1 and 50 characters"
return self.validate_input(agent_name, "text")
def validate_message(self, message: str) -> Tuple[bool, str, Optional[str]]:
"""Validate message input"""
return self.validate_input(message, "text")
def validate_config(self, config: Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Optional[str]]:
"""Validate configuration input"""
try:
# Convert config to string for validation
import json
config_str = json.dumps(config)
is_valid, sanitized, error = self.validate_input(config_str, "json")
if not is_valid:
return False, {}, error
# Parse back to dict
validated_config = json.loads(sanitized)
return True, validated_config, None
except Exception as e:
return False, {}, f"Configuration validation error: {str(e)}"
def get_validation_stats(self) -> Dict[str, Any]:
"""Get validation statistics"""
return {
"enabled": self.enabled,
"max_length": self.max_length,
"blocked_patterns_count": len(self.blocked_patterns),
"allowed_domains_count": len(self.allowed_domains),
"malicious_patterns_count": len(self._malicious_patterns),
"sql_patterns_count": len(self._sql_patterns),
}
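A short usage sketch for InputValidator, assuming the class above is importable; the configuration values and patterns are illustrative only.

# Hedged sketch of InputValidator usage; config values are illustrative.
validator = InputValidator({
    "enabled": True,
    "max_length": 2000,
    "blocked_patterns": [r"rm\s+-rf"],   # example pattern, not from the source
    "allowed_domains": ["example.com"],  # example domain, not from the source
})

is_valid, sanitized, error = validator.validate_input("<b>hello</b>", input_type="text")
# is_valid is True, sanitized is the HTML-escaped text, error is None

is_valid, sanitized, error = validator.validate_input("https://example.com/docs", input_type="url")
is_valid, sanitized, error = validator.validate_agent_name("research_agent-1")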

@@ -0,0 +1,285 @@
"""
Output filtering and sanitization for Swarms framework.
This module provides comprehensive output filtering, sanitization,
and sensitive data protection for all swarm outputs.
"""
import re
import json
from typing import List, Optional, Dict, Any, Tuple, Union
from datetime import datetime
from loguru import logger
from swarms.utils.loguru_logger import initialize_logger
# Initialize logger for output filtering
filter_logger = initialize_logger(log_folder="output_filtering")
class OutputFilter:
"""
Output filtering and sanitization for swarm security
Features:
- Sensitive data filtering
- Output sanitization
- Content type filtering
- PII protection
- Malicious content detection
"""
def __init__(self, config: Dict[str, Any]):
"""
Initialize output filter with configuration
Args:
config: Filtering configuration dictionary
"""
self.enabled = config.get("enabled", True)
self.filter_sensitive = config.get("filter_sensitive", True)
self.sensitive_patterns = config.get("sensitive_patterns", [])
# Compile regex patterns for performance
self._compiled_patterns = [
re.compile(pattern, re.IGNORECASE) for pattern in self.sensitive_patterns
]
# Default sensitive data patterns
self._default_patterns = [
re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), # SSN
re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b"), # Credit card
re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"), # Email
re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b"), # IP address
re.compile(r"\b\d{3}[\s-]?\d{3}[\s-]?\d{4}\b"), # Phone number
re.compile(r"\b[A-Z]{2}\d{2}[A-Z0-9]{10,30}\b"), # IBAN
re.compile(r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{1,3}\b"), # Extended CC
]
# Malicious content patterns
self._malicious_patterns = [
re.compile(r"<script.*?>.*?</script>", re.IGNORECASE),
re.compile(r"javascript:", re.IGNORECASE),
re.compile(r"data:text/html", re.IGNORECASE),
re.compile(r"vbscript:", re.IGNORECASE),
re.compile(r"on\w+\s*=", re.IGNORECASE),
re.compile(r"<iframe.*?>.*?</iframe>", re.IGNORECASE),
re.compile(r"<object.*?>.*?</object>", re.IGNORECASE),
re.compile(r"<embed.*?>", re.IGNORECASE),
]
# API key patterns
self._api_key_patterns = [
re.compile(r"sk-[a-zA-Z0-9]{32,}"), # OpenAI API key
re.compile(r"pk_[a-zA-Z0-9]{32,}"), # OpenAI API key (public)
re.compile(r"[a-zA-Z0-9]{32,}"), # Generic API key
]
filter_logger.info("OutputFilter initialized")
def filter_output(self, output_data: Union[str, Dict, List], output_type: str = "text") -> Tuple[bool, Union[str, Dict, List], Optional[str]]:
"""
Filter and sanitize output data
Args:
output_data: Output data to filter
output_type: Type of output (text, json, dict, etc.)
Returns:
Tuple of (is_safe, filtered_data, warning_message)
"""
if not self.enabled:
return True, output_data, None
try:
# Convert to string for processing
if isinstance(output_data, (dict, list)):
output_str = json.dumps(output_data, ensure_ascii=False)
else:
output_str = str(output_data)
# Check for malicious content
if self._check_malicious_content(output_str):
return False, "", "Output contains potentially malicious content"
# Filter sensitive data
if self.filter_sensitive:
filtered_str = self._filter_sensitive_data(output_str)
else:
filtered_str = output_str
# Convert back to original type if needed
if isinstance(output_data, (dict, list)) and output_type in ["json", "dict"]:
try:
filtered_data = json.loads(filtered_str)
except json.JSONDecodeError:
filtered_data = filtered_str
else:
filtered_data = filtered_str
# Check if any sensitive data was filtered
warning = None
if filtered_str != output_str:
warning = "Sensitive data was filtered from output"
filter_logger.debug(f"Output filtering completed for type: {output_type}")
return True, filtered_data, warning
except Exception as e:
filter_logger.error(f"Output filtering error: {e}")
return False, "", f"Filtering error: {str(e)}"
def _check_malicious_content(self, content: str) -> bool:
"""Check if content contains malicious patterns"""
for pattern in self._malicious_patterns:
if pattern.search(content):
filter_logger.warning(f"Malicious content detected: {pattern.pattern}")
return True
return False
def _filter_sensitive_data(self, content: str) -> str:
"""Filter sensitive data from content"""
filtered_content = content
# Filter custom sensitive patterns
for pattern in self._compiled_patterns:
filtered_content = pattern.sub("[SENSITIVE_DATA]", filtered_content)
# Filter default sensitive patterns
for pattern in self._default_patterns:
filtered_content = pattern.sub("[SENSITIVE_DATA]", filtered_content)
# Filter API keys
for pattern in self._api_key_patterns:
filtered_content = pattern.sub("[API_KEY]", filtered_content)
return filtered_content
def filter_agent_response(self, response: str, agent_name: str) -> Tuple[bool, str, Optional[str]]:
"""Filter agent response output"""
return self.filter_output(response, "text")
def filter_swarm_output(self, output: Union[str, Dict, List]) -> Tuple[bool, Union[str, Dict, List], Optional[str]]:
"""Filter swarm output"""
return self.filter_output(output, "json")
def filter_conversation_history(self, history: List[Dict]) -> Tuple[bool, List[Dict], Optional[str]]:
"""Filter conversation history"""
try:
filtered_history = []
warnings = []
for message in history:
# Filter message content
is_safe, filtered_content, warning = self.filter_output(
message.get("content", ""), "text"
)
if not is_safe:
return False, [], "Conversation history contains unsafe content"
# Create filtered message
filtered_message = message.copy()
filtered_message["content"] = filtered_content
if warning:
warnings.append(warning)
filtered_history.append(filtered_message)
warning_msg = "; ".join(set(warnings)) if warnings else None
return True, filtered_history, warning_msg
except Exception as e:
filter_logger.error(f"Conversation history filtering error: {e}")
return False, [], f"History filtering error: {str(e)}"
def filter_config_output(self, config: Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Optional[str]]:
"""Filter configuration output"""
try:
# Create a copy to avoid modifying original
filtered_config = config.copy()
# Filter sensitive config fields
sensitive_fields = [
"api_key", "secret", "password", "token", "key",
"credential", "auth", "private", "secret_key"
]
warnings = []
for field in sensitive_fields:
if field in filtered_config:
if isinstance(filtered_config[field], str):
filtered_config[field] = "[SENSITIVE_CONFIG]"
warnings.append(f"Sensitive config field '{field}' was filtered")
warning_msg = "; ".join(warnings) if warnings else None
return True, filtered_config, warning_msg
except Exception as e:
filter_logger.error(f"Config filtering error: {e}")
return False, {}, f"Config filtering error: {str(e)}"
def sanitize_for_logging(self, data: Union[str, Dict, List]) -> str:
"""Sanitize data for logging purposes"""
try:
if isinstance(data, (dict, list)):
data_str = json.dumps(data, ensure_ascii=False)
else:
data_str = str(data)
# Apply aggressive filtering for logs
sanitized = self._filter_sensitive_data(data_str)
# Truncate if too long
if len(sanitized) > 1000:
sanitized = sanitized[:1000] + "... [TRUNCATED]"
return sanitized
except Exception as e:
filter_logger.error(f"Log sanitization error: {e}")
return "[SANITIZATION_ERROR]"
def add_custom_pattern(self, pattern: str, description: str = "") -> None:
"""Add custom sensitive data pattern"""
try:
compiled_pattern = re.compile(pattern, re.IGNORECASE)
self._compiled_patterns.append(compiled_pattern)
self.sensitive_patterns.append(pattern)
filter_logger.info(f"Added custom pattern: {pattern} ({description})")
except re.error as e:
filter_logger.error(f"Invalid regex pattern: {pattern} - {e}")
def remove_pattern(self, pattern: str) -> bool:
"""Remove sensitive data pattern"""
try:
if pattern in self.sensitive_patterns:
self.sensitive_patterns.remove(pattern)
# Recompile patterns
self._compiled_patterns = [
re.compile(p, re.IGNORECASE) for p in self.sensitive_patterns
]
filter_logger.info(f"Removed pattern: {pattern}")
return True
return False
except Exception as e:
filter_logger.error(f"Error removing pattern: {e}")
return False
def get_filter_stats(self) -> Dict[str, Any]:
"""Get filtering statistics"""
return {
"enabled": self.enabled,
"filter_sensitive": self.filter_sensitive,
"sensitive_patterns_count": len(self.sensitive_patterns),
"malicious_patterns_count": len(self._malicious_patterns),
"api_key_patterns_count": len(self._api_key_patterns),
"default_patterns_count": len(self._default_patterns),
}
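A brief sketch of the OutputFilter above redacting sensitive data; the sample string is illustrative.

# Hedged sketch of OutputFilter usage; relies on the built-in default patterns.
output_filter = OutputFilter({
    "enabled": True,
    "filter_sensitive": True,
    "sensitive_patterns": [],
})

is_safe, filtered, warning = output_filter.filter_output(
    "Contact jane@example.com from 10.0.0.1", "text"
)
# filtered -> "Contact [SENSITIVE_DATA] from [SENSITIVE_DATA]"
# warning notes that sensitive data was filtered from the output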

@@ -0,0 +1,323 @@
"""
Rate limiting and abuse prevention for Swarms framework.
This module provides comprehensive rate limiting, request tracking,
and abuse prevention for all swarm operations.
"""
import time
import threading
from typing import Dict, List, Optional, Tuple, Any
from collections import defaultdict, deque
from datetime import datetime, timedelta
from loguru import logger
from swarms.utils.loguru_logger import initialize_logger
# Initialize logger for rate limiting
rate_logger = initialize_logger(log_folder="rate_limiting")
class RateLimiter:
"""
Rate limiting and abuse prevention for swarm security
Features:
- Per-agent rate limiting
- Token-based limiting
- Request tracking
- Abuse detection
- Automatic blocking
"""
def __init__(self, config: Dict[str, Any]):
"""
Initialize rate limiter with configuration
Args:
config: Rate limiting configuration dictionary
"""
self.enabled = config.get("enabled", True)
self.max_requests_per_minute = config.get("max_requests_per_minute", 60)
self.max_tokens_per_request = config.get("max_tokens_per_request", 10000)
self.window = config.get("window", 60) # seconds
# Request tracking
self._request_history: Dict[str, deque] = defaultdict(lambda: deque())
self._token_usage: Dict[str, int] = defaultdict(int)
self._blocked_agents: Dict[str, float] = {}
# Thread safety (re-entrant lock: get_all_stats calls get_agent_stats while holding it)
self._lock = threading.RLock()
# Cleanup thread
self._cleanup_thread = None
self._stop_cleanup = False
if self.enabled:
self._start_cleanup_thread()
rate_logger.info(f"RateLimiter initialized: {self.max_requests_per_minute} req/min, {self.max_tokens_per_request} tokens/req")
def check_rate_limit(self, agent_name: str, request_size: int = 1) -> Tuple[bool, Optional[str]]:
"""
Check if request is within rate limits
Args:
agent_name: Name of the agent making the request
request_size: Size of the request (tokens or complexity)
Returns:
Tuple of (is_allowed, error_message)
"""
if not self.enabled:
return True, None
try:
with self._lock:
current_time = time.time()
# Check if agent is blocked
if agent_name in self._blocked_agents:
block_until = self._blocked_agents[agent_name]
if current_time < block_until:
remaining = block_until - current_time
return False, f"Agent {agent_name} is blocked for {remaining:.1f} more seconds"
else:
# Unblock agent
del self._blocked_agents[agent_name]
# Check token limit
if request_size > self.max_tokens_per_request:
return False, f"Request size {request_size} exceeds token limit {self.max_tokens_per_request}"
# Get agent's request history
history = self._request_history[agent_name]
# Remove old requests outside the window
cutoff_time = current_time - self.window
while history and history[0] < cutoff_time:
history.popleft()
# Check request count limit
if len(history) >= self.max_requests_per_minute:
# Block agent temporarily
block_duration = min(300, self.window * 2) # Max 5 minutes
self._blocked_agents[agent_name] = current_time + block_duration
rate_logger.warning(f"Agent {agent_name} rate limit exceeded, blocked for {block_duration}s")
return False, f"Rate limit exceeded. Agent blocked for {block_duration} seconds"
# Add current request
history.append(current_time)
# Update token usage
self._token_usage[agent_name] += request_size
rate_logger.debug(f"Rate limit check passed for {agent_name}")
return True, None
except Exception as e:
rate_logger.error(f"Rate limit check error: {e}")
return False, f"Rate limit check error: {str(e)}"
def check_agent_limit(self, agent_name: str) -> Tuple[bool, Optional[str]]:
"""Check agent-specific rate limits"""
return self.check_rate_limit(agent_name, 1)
def check_token_limit(self, agent_name: str, token_count: int) -> Tuple[bool, Optional[str]]:
"""Check token-based rate limits"""
return self.check_rate_limit(agent_name, token_count)
def track_request(self, agent_name: str, request_type: str = "general", metadata: Optional[Dict[str, Any]] = None) -> None:
"""
Track a request for monitoring purposes
Args:
agent_name: Name of the agent
request_type: Type of request
metadata: Additional request metadata
"""
if not self.enabled:
return
try:
with self._lock:
current_time = time.time()
# Only the timestamp is kept for rate limiting; request_type appears in the
# debug log below and metadata is currently unused
self._request_history[agent_name].append(current_time)
rate_logger.debug(f"Tracked request: {agent_name} - {request_type}")
except Exception as e:
rate_logger.error(f"Request tracking error: {e}")
def get_agent_stats(self, agent_name: str) -> Dict[str, Any]:
"""Get rate limiting statistics for an agent"""
try:
with self._lock:
current_time = time.time()
history = self._request_history[agent_name]
# Clean old requests
cutoff_time = current_time - self.window
while history and history[0] < cutoff_time:
history.popleft()
# Calculate statistics
recent_requests = len(history)
total_tokens = self._token_usage.get(agent_name, 0)
is_blocked = agent_name in self._blocked_agents
block_remaining = 0
if is_blocked:
block_remaining = self._blocked_agents[agent_name] - current_time
return {
"agent_name": agent_name,
"recent_requests": recent_requests,
"max_requests": self.max_requests_per_minute,
"total_tokens": total_tokens,
"max_tokens_per_request": self.max_tokens_per_request,
"is_blocked": is_blocked,
"block_remaining": max(0, block_remaining),
"window_seconds": self.window,
}
except Exception as e:
rate_logger.error(f"Error getting agent stats: {e}")
return {}
def get_all_stats(self) -> Dict[str, Any]:
"""Get rate limiting statistics for all agents"""
try:
with self._lock:
stats = {
"enabled": self.enabled,
"max_requests_per_minute": self.max_requests_per_minute,
"max_tokens_per_request": self.max_tokens_per_request,
"window_seconds": self.window,
"total_agents": len(self._request_history),
"blocked_agents": len(self._blocked_agents),
"agents": {}
}
for agent_name in self._request_history:
stats["agents"][agent_name] = self.get_agent_stats(agent_name)
return stats
except Exception as e:
rate_logger.error(f"Error getting all stats: {e}")
return {}
def reset_agent(self, agent_name: str) -> bool:
"""Reset rate limiting for an agent"""
try:
with self._lock:
if agent_name in self._request_history:
self._request_history[agent_name].clear()
if agent_name in self._token_usage:
self._token_usage[agent_name] = 0
if agent_name in self._blocked_agents:
del self._blocked_agents[agent_name]
rate_logger.info(f"Reset rate limiting for agent: {agent_name}")
return True
except Exception as e:
rate_logger.error(f"Error resetting agent: {e}")
return False
def block_agent(self, agent_name: str, duration: int = 300) -> bool:
"""Manually block an agent"""
try:
with self._lock:
current_time = time.time()
self._blocked_agents[agent_name] = current_time + duration
rate_logger.warning(f"Manually blocked agent {agent_name} for {duration}s")
return True
except Exception as e:
rate_logger.error(f"Error blocking agent: {e}")
return False
def unblock_agent(self, agent_name: str) -> bool:
"""Unblock an agent"""
try:
with self._lock:
if agent_name in self._blocked_agents:
del self._blocked_agents[agent_name]
rate_logger.info(f"Unblocked agent: {agent_name}")
return True
return False
except Exception as e:
rate_logger.error(f"Error unblocking agent: {e}")
return False
def _start_cleanup_thread(self) -> None:
"""Start background cleanup thread"""
def cleanup_worker():
while not self._stop_cleanup:
try:
time.sleep(60) # Cleanup every minute
self._cleanup_old_data()
except Exception as e:
rate_logger.error(f"Cleanup thread error: {e}")
self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
self._cleanup_thread.start()
def _cleanup_old_data(self) -> None:
"""Clean up old rate limiting data"""
try:
with self._lock:
current_time = time.time()
cutoff_time = current_time - (self.window * 2) # Keep 2x window
# Clean request history
for agent_name in list(self._request_history.keys()):
history = self._request_history[agent_name]
while history and history[0] < cutoff_time:
history.popleft()
# Remove empty histories
if not history:
del self._request_history[agent_name]
# Clean blocked agents
for agent_name in list(self._blocked_agents.keys()):
if self._blocked_agents[agent_name] < current_time:
del self._blocked_agents[agent_name]
# Reset token usage roughly once per hour (the cleanup worker runs every
# minute, so this one-minute window is hit about once an hour)
if current_time % 3600 < 60:
self._token_usage.clear()
except Exception as e:
rate_logger.error(f"Cleanup error: {e}")
def stop(self) -> None:
"""Stop the rate limiter and cleanup thread"""
self._stop_cleanup = True
if self._cleanup_thread:
self._cleanup_thread.join(timeout=5)
rate_logger.info("RateLimiter stopped")
def __del__(self):
"""Cleanup on destruction"""
self.stop()

@@ -0,0 +1,249 @@
"""
Safety checking and content filtering for Swarms framework.
This module provides comprehensive safety checks, content filtering,
and ethical AI compliance for all swarm operations.
"""
import re
from typing import List, Optional, Dict, Any, Tuple
from enum import Enum
from loguru import logger
from swarms.utils.loguru_logger import initialize_logger
from swarms.prompts.safety_prompt import SAFETY_PROMPT
# Initialize logger for safety checking
safety_logger = initialize_logger(log_folder="safety_checking")
class ContentLevel(Enum):
"""Content filtering levels"""
LOW = "low" # Minimal filtering
MODERATE = "moderate" # Standard filtering
HIGH = "high" # Aggressive filtering
class SafetyChecker:
"""
Safety checking and content filtering for swarm security
Features:
- Content safety assessment
- Ethical AI compliance
- Harmful content detection
- Bias detection
- Safety prompt integration
"""
def __init__(self, config: Dict[str, Any]):
"""
Initialize safety checker with configuration
Args:
config: Safety configuration dictionary
"""
self.enabled = config.get("enabled", True)
self.safety_prompt = config.get("safety_prompt", True)
self.filter_level = ContentLevel(config.get("filter_level", "moderate"))
# Harmful content patterns (keyword heuristics; these may also match benign uses of the listed words)
self._harmful_patterns = [
re.compile(r"\b(kill|murder|suicide|bomb|explosive|weapon)\b", re.IGNORECASE),
re.compile(r"\b(hack|crack|steal|fraud|scam|phishing)\b", re.IGNORECASE),
re.compile(r"\b(drug|illegal|contraband|smuggle)\b", re.IGNORECASE),
re.compile(r"\b(hate|racist|sexist|discriminate)\b", re.IGNORECASE),
re.compile(r"\b(terrorist|extremist|radical)\b", re.IGNORECASE),
]
# Bias detection patterns (note: these absolutist-word heuristics are broad
# and will also flag benign text containing words such as "all" or "never")
self._bias_patterns = [
re.compile(r"\b(all|every|always|never|none)\b", re.IGNORECASE),
re.compile(r"\b(men are|women are|blacks are|whites are)\b", re.IGNORECASE),
re.compile(r"\b(stereotypical|typical|usual)\b", re.IGNORECASE),
]
# Age-inappropriate content
self._age_inappropriate = [
re.compile(r"\b(sex|porn|adult|explicit)\b", re.IGNORECASE),
re.compile(r"\b(violence|gore|blood|death)\b", re.IGNORECASE),
]
# Handle filter_level safely
filter_level_str = (
self.filter_level.value
if hasattr(self.filter_level, 'value')
else str(self.filter_level)
)
safety_logger.info(f"SafetyChecker initialized with level: {filter_level_str}")
def check_safety(self, content: str, content_type: str = "text") -> Tuple[bool, str, Optional[str]]:
"""
Check content for safety and ethical compliance
Args:
content: Content to check
content_type: Type of content (text, code, config, etc.)
Returns:
Tuple of (is_safe, sanitized_content, warning_message)
"""
if not self.enabled:
return True, content, None
try:
# Basic type validation
if not isinstance(content, str):
return False, "", "Content must be a string"
# Check for harmful content
if self._check_harmful_content(content):
return False, "", "Content contains potentially harmful material"
# Check for bias
if self._check_bias(content):
return False, "", "Content contains potentially biased language"
# Check age appropriateness
if self._check_age_appropriate(content):
return False, "", "Content may not be age-appropriate"
# Apply content filtering based on level
sanitized = self._filter_content(content)
# Check if content was modified
warning = None
if sanitized != content:
# Handle filter_level safely
filter_level_str = (
self.filter_level.value
if hasattr(self.filter_level, 'value')
else str(self.filter_level)
)
warning = f"Content was filtered for {filter_level_str} safety level"
safety_logger.debug(f"Safety check passed for type: {content_type}")
return True, sanitized, warning
except Exception as e:
safety_logger.error(f"Safety check error: {e}")
return False, "", f"Safety check error: {str(e)}"
def _check_harmful_content(self, content: str) -> bool:
"""Check for harmful content patterns"""
for pattern in self._harmful_patterns:
if pattern.search(content):
safety_logger.warning(f"Harmful content detected: {pattern.pattern}")
return True
return False
def _check_bias(self, content: str) -> bool:
"""Check for biased language patterns"""
for pattern in self._bias_patterns:
if pattern.search(content):
safety_logger.warning(f"Bias detected: {pattern.pattern}")
return True
return False
def _check_age_appropriate(self, content: str) -> bool:
"""Check for age-inappropriate content"""
for pattern in self._age_inappropriate:
if pattern.search(content):
safety_logger.warning(f"Age-inappropriate content detected: {pattern.pattern}")
return True
return False
def _filter_content(self, content: str) -> str:
"""Filter content based on safety level"""
filtered_content = content
if self.filter_level == ContentLevel.HIGH:
# Aggressive filtering
for pattern in self._harmful_patterns + self._bias_patterns:
filtered_content = pattern.sub("[FILTERED]", filtered_content)
elif self.filter_level == ContentLevel.MODERATE:
# Moderate filtering - only filter obvious harmful content
for pattern in self._harmful_patterns:
filtered_content = pattern.sub("[FILTERED]", filtered_content)
# LOW level doesn't filter content, only detects
return filtered_content
def check_agent_safety(self, agent_name: str, system_prompt: str) -> Tuple[bool, str, Optional[str]]:
"""Check agent system prompt for safety"""
return self.check_safety(system_prompt, "agent_prompt")
def check_task_safety(self, task: str) -> Tuple[bool, str, Optional[str]]:
"""Check task description for safety"""
return self.check_safety(task, "task")
def check_response_safety(self, response: str, agent_name: str) -> Tuple[bool, str, Optional[str]]:
"""Check agent response for safety"""
return self.check_safety(response, "response")
def check_config_safety(self, config: Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Optional[str]]:
"""Check configuration for safety"""
try:
import json
config_str = json.dumps(config)
is_safe, sanitized, error = self.check_safety(config_str, "config")
if not is_safe:
return False, {}, error
# Parse back to dict
safe_config = json.loads(sanitized)
return True, safe_config, None
except Exception as e:
return False, {}, f"Config safety check error: {str(e)}"
def get_safety_prompt(self) -> str:
"""Get safety prompt for integration"""
if self.safety_prompt:
return SAFETY_PROMPT
return ""
def add_harmful_pattern(self, pattern: str, description: str = "") -> None:
"""Add custom harmful content pattern"""
try:
compiled_pattern = re.compile(pattern, re.IGNORECASE)
self._harmful_patterns.append(compiled_pattern)
safety_logger.info(f"Added harmful pattern: {pattern} ({description})")
except re.error as e:
safety_logger.error(f"Invalid regex pattern: {pattern} - {e}")
def add_bias_pattern(self, pattern: str, description: str = "") -> None:
"""Add custom bias detection pattern"""
try:
compiled_pattern = re.compile(pattern, re.IGNORECASE)
self._bias_patterns.append(compiled_pattern)
safety_logger.info(f"Added bias pattern: {pattern} ({description})")
except re.error as e:
safety_logger.error(f"Invalid regex pattern: {pattern} - {e}")
def get_safety_stats(self) -> Dict[str, Any]:
"""Get safety checking statistics"""
# Handle filter_level safely
filter_level_str = (
self.filter_level.value
if hasattr(self.filter_level, 'value')
else str(self.filter_level)
)
return {
"enabled": self.enabled,
"safety_prompt": self.safety_prompt,
"filter_level": filter_level_str,
"harmful_patterns_count": len(self._harmful_patterns),
"bias_patterns_count": len(self._bias_patterns),
"age_inappropriate_patterns_count": len(self._age_inappropriate),
}

@@ -0,0 +1,362 @@
"""
Shield configuration for Swarms framework.
This module provides configuration options for the security shield,
allowing users to customize security settings for their swarm deployments.
"""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from enum import Enum
from .swarm_shield import EncryptionStrength
class SecurityLevel(Enum):
"""Security levels for shield configuration"""
BASIC = "basic" # Basic input validation and output filtering
STANDARD = "standard" # Standard security with encryption
ENHANCED = "enhanced" # Enhanced security with additional checks
MAXIMUM = "maximum" # Maximum security with all features enabled
class ShieldConfig(BaseModel):
"""
Configuration for SwarmShield security features
This class provides comprehensive configuration options for
enabling and customizing security features across all swarm architectures.
"""
# Core security settings
enabled: bool = Field(default=True, description="Enable shield protection")
security_level: SecurityLevel = Field(
default=SecurityLevel.STANDARD,
description="Overall security level"
)
# SwarmShield settings
encryption_strength: EncryptionStrength = Field(
default=EncryptionStrength.MAXIMUM,
description="Encryption strength for message protection"
)
key_rotation_interval: int = Field(
default=3600,
ge=300, # Minimum 5 minutes
description="Key rotation interval in seconds"
)
storage_path: Optional[str] = Field(
default=None,
description="Path for encrypted storage"
)
# Input validation settings
enable_input_validation: bool = Field(
default=True,
description="Enable input validation and sanitization"
)
max_input_length: int = Field(
default=10000,
ge=100,
description="Maximum input length in characters"
)
blocked_patterns: List[str] = Field(
default=[
r"<script.*?>.*?</script>", # XSS prevention
r"javascript:", # JavaScript injection
r"data:text/html", # HTML injection
r"vbscript:", # VBScript injection
r"on\w+\s*=", # Event handler injection
],
description="Regex patterns to block in inputs"
)
allowed_domains: List[str] = Field(
default=[],
description="Allowed domains for external requests"
)
# Output filtering settings
enable_output_filtering: bool = Field(
default=True,
description="Enable output filtering and sanitization"
)
filter_sensitive_data: bool = Field(
default=True,
description="Filter sensitive data from outputs"
)
sensitive_patterns: List[str] = Field(
default=[
r"\b\d{3}-\d{2}-\d{4}\b", # SSN
r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b", # Credit card
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", # Email
r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b", # IP address
],
description="Regex patterns for sensitive data"
)
# Safety checking settings
enable_safety_checks: bool = Field(
default=True,
description="Enable safety and content filtering"
)
safety_prompt_on: bool = Field(
default=True,
description="Enable safety prompt integration"
)
content_filter_level: str = Field(
default="moderate",
description="Content filtering level (low, moderate, high)"
)
# Rate limiting settings
enable_rate_limiting: bool = Field(
default=True,
description="Enable rate limiting and abuse prevention"
)
max_requests_per_minute: int = Field(
default=60,
ge=1,
description="Maximum requests per minute per agent"
)
max_tokens_per_request: int = Field(
default=10000,
ge=100,
description="Maximum tokens per request"
)
rate_limit_window: int = Field(
default=60,
ge=10,
description="Rate limit window in seconds"
)
# Audit and logging settings
enable_audit_logging: bool = Field(
default=True,
description="Enable comprehensive audit logging"
)
log_security_events: bool = Field(
default=True,
description="Log security-related events"
)
log_input_output: bool = Field(
default=False,
description="Log input and output data (use with caution)"
)
audit_retention_days: int = Field(
default=90,
ge=1,
description="Audit log retention period in days"
)
# Performance settings
enable_caching: bool = Field(
default=True,
description="Enable security result caching"
)
cache_ttl: int = Field(
default=300,
ge=60,
description="Cache TTL in seconds"
)
max_cache_size: int = Field(
default=1000,
ge=100,
description="Maximum cache entries"
)
# Integration settings
integrate_with_conversation: bool = Field(
default=True,
description="Integrate with conversation management"
)
protect_agent_communications: bool = Field(
default=True,
description="Protect inter-agent communications"
)
encrypt_storage: bool = Field(
default=True,
description="Encrypt persistent storage"
)
# Custom settings
custom_rules: Dict[str, Any] = Field(
default={},
description="Custom security rules and configurations"
)
# Compatibility fields for examples
encryption_enabled: bool = Field(
default=True,
description="Enable encryption (alias for enabled)"
)
input_validation_enabled: bool = Field(
default=True,
description="Enable input validation (alias for enable_input_validation)"
)
output_filtering_enabled: bool = Field(
default=True,
description="Enable output filtering (alias for enable_output_filtering)"
)
rate_limiting_enabled: bool = Field(
default=True,
description="Enable rate limiting (alias for enable_rate_limiting)"
)
safety_checking_enabled: bool = Field(
default=True,
description="Enable safety checking (alias for enable_safety_checks)"
)
block_suspicious_content: bool = Field(
default=True,
description="Block suspicious content patterns"
)
custom_blocked_patterns: List[str] = Field(
default=[],
description="Custom patterns to block in addition to default ones"
)
safety_threshold: float = Field(
default=0.7,
ge=0.0,
le=1.0,
description="Safety threshold for content filtering"
)
bias_detection_enabled: bool = Field(
default=False,
description="Enable bias detection in content"
)
content_moderation_enabled: bool = Field(
default=True,
description="Enable content moderation"
)
class Config:
"""Pydantic configuration"""
use_enum_values = True
validate_assignment = True
def get_encryption_config(self) -> Dict[str, Any]:
"""Get encryption configuration for SwarmShield"""
return {
"encryption_strength": self.encryption_strength,
"key_rotation_interval": self.key_rotation_interval,
"storage_path": self.storage_path,
}
def get_validation_config(self) -> Dict[str, Any]:
"""Get input validation configuration"""
return {
"enabled": self.enable_input_validation,
"max_length": self.max_input_length,
"blocked_patterns": self.blocked_patterns,
"allowed_domains": self.allowed_domains,
}
def get_filtering_config(self) -> Dict[str, Any]:
"""Get output filtering configuration"""
return {
"enabled": self.enable_output_filtering,
"filter_sensitive": self.filter_sensitive_data,
"sensitive_patterns": self.sensitive_patterns,
}
def get_safety_config(self) -> Dict[str, Any]:
"""Get safety checking configuration"""
return {
"enabled": self.enable_safety_checks,
"safety_prompt": self.safety_prompt_on,
"filter_level": self.content_filter_level,
}
def get_rate_limit_config(self) -> Dict[str, Any]:
"""Get rate limiting configuration"""
return {
"enabled": self.enable_rate_limiting,
"max_requests_per_minute": self.max_requests_per_minute,
"max_tokens_per_request": self.max_tokens_per_request,
"window": self.rate_limit_window,
}
def get_audit_config(self) -> Dict[str, Any]:
"""Get audit logging configuration"""
return {
"enabled": self.enable_audit_logging,
"log_security": self.log_security_events,
"log_io": self.log_input_output,
"retention_days": self.audit_retention_days,
}
def get_performance_config(self) -> Dict[str, Any]:
"""Get performance configuration"""
return {
"enable_caching": self.enable_caching,
"cache_ttl": self.cache_ttl,
"max_cache_size": self.max_cache_size,
}
def get_integration_config(self) -> Dict[str, Any]:
"""Get integration configuration"""
return {
"conversation": self.integrate_with_conversation,
"agent_communications": self.protect_agent_communications,
"encrypt_storage": self.encrypt_storage,
}
@classmethod
def create_basic_config(cls) -> "ShieldConfig":
"""Create a basic security configuration"""
return cls(
security_level=SecurityLevel.BASIC,
encryption_strength=EncryptionStrength.STANDARD,
enable_safety_checks=False,
enable_rate_limiting=False,
enable_audit_logging=False,
)
@classmethod
def create_standard_config(cls) -> "ShieldConfig":
"""Create a standard security configuration"""
return cls(
security_level=SecurityLevel.STANDARD,
encryption_strength=EncryptionStrength.ENHANCED,
)
@classmethod
def create_enhanced_config(cls) -> "ShieldConfig":
"""Create an enhanced security configuration"""
return cls(
security_level=SecurityLevel.ENHANCED,
encryption_strength=EncryptionStrength.MAXIMUM,
max_requests_per_minute=30,
content_filter_level="high",
log_input_output=True,
)
@classmethod
def create_maximum_config(cls) -> "ShieldConfig":
"""Create a maximum security configuration"""
return cls(
security_level=SecurityLevel.MAXIMUM,
encryption_strength=EncryptionStrength.MAXIMUM,
key_rotation_interval=1800, # 30 minutes
max_requests_per_minute=20,
content_filter_level="high",
log_input_output=True,
audit_retention_days=365,
max_cache_size=500,
)
@classmethod
def get_security_level(cls, level: str) -> "ShieldConfig":
"""Get a security configuration for the specified level"""
level = level.lower()
if level == "basic":
return cls.create_basic_config()
elif level == "standard":
return cls.create_standard_config()
elif level == "enhanced":
return cls.create_enhanced_config()
elif level == "maximum":
return cls.create_maximum_config()
else:
raise ValueError(f"Unknown security level: {level}. Must be one of: basic, standard, enhanced, maximum")
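A brief sketch of building a ShieldConfig with the helpers above and handing the per-component dictionaries to the other security classes, assuming those classes are importable from the same package.

# Hedged sketch of ShieldConfig usage; assumes the classes above are importable.
config = ShieldConfig.get_security_level("enhanced")

validator = InputValidator(config.get_validation_config())
output_filter = OutputFilter(config.get_filtering_config())
safety_checker = SafetyChecker(config.get_safety_config())
rate_limiter = RateLimiter(config.get_rate_limit_config())

# Individual fields can also be overridden directly:
custom = ShieldConfig(
    max_input_length=5000,
    content_filter_level="high",
    max_requests_per_minute=30,
)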

@@ -0,0 +1,737 @@
"""
SwarmShield integration for Swarms framework.
This module provides enterprise-grade security for swarm communications,
including encryption, conversation management, and audit capabilities.
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import os
import secrets
import threading
import time
import uuid
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import (
Cipher,
algorithms,
modes,
)
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from loguru import logger
from swarms.utils.loguru_logger import initialize_logger
# Initialize logger for security module
security_logger = initialize_logger(log_folder="security")
class EncryptionStrength(Enum):
"""Encryption strength levels for SwarmShield"""
STANDARD = "standard" # AES-256
ENHANCED = "enhanced" # AES-256 + SHA-512
MAXIMUM = "maximum" # AES-256 + SHA-512 + HMAC
class SwarmShield:
"""
SwarmShield: Advanced security system for swarm agents
Features:
- Multi-layer message encryption
- Secure conversation storage
- Automatic key rotation
- Message integrity verification
- Integration with Swarms framework
"""
def __init__(
self,
encryption_strength: EncryptionStrength = EncryptionStrength.MAXIMUM,
key_rotation_interval: int = 3600, # 1 hour
storage_path: Optional[str] = None,
enable_logging: bool = True,
):
"""
Initialize SwarmShield with security settings
Args:
encryption_strength: Level of encryption to use
key_rotation_interval: Key rotation interval in seconds
storage_path: Path for encrypted storage
enable_logging: Enable security logging
"""
self.encryption_strength = encryption_strength
self.key_rotation_interval = key_rotation_interval
self.enable_logging = enable_logging
# Set storage path within swarms framework
if storage_path:
self.storage_path = Path(storage_path)
else:
self.storage_path = Path("swarms_security_storage")
# Initialize storage and locks
self.storage_path.mkdir(parents=True, exist_ok=True)
self._conv_lock = threading.Lock()
self._conversations: Dict[str, Dict] = {}  # conversation_id -> conversation record
# Initialize security components
self._initialize_security()
self._load_conversations()
if self.enable_logging:
# Handle encryption_strength safely (could be enum or string)
encryption_str = (
encryption_strength.value
if hasattr(encryption_strength, 'value')
else str(encryption_strength)
)
security_logger.info(
f"SwarmShield initialized with {encryption_str} encryption"
)
def _initialize_security(self) -> None:
"""Set up encryption keys and components"""
try:
# Generate master key and salt (held in memory only and regenerated per
# instance, so conversations saved by a previous process cannot be
# decrypted after a restart)
self.master_key = secrets.token_bytes(32)
self.salt = os.urandom(16)
# Key derivation is performed inside _rotate_keys, because a PBKDF2HMAC
# instance can only be used for a single derive() call
# Generate initial keys
self._rotate_keys()
self.hmac_key = secrets.token_bytes(32)
except Exception as e:
if self.enable_logging:
security_logger.error(f"Security initialization failed: {e}")
raise
def _rotate_keys(self) -> None:
"""Perform security key rotation"""
try:
# A PBKDF2HMAC instance can only be used for one derive() call, so build a
# fresh KDF on every rotation; the master key and salt are unchanged, so the
# derived key stays stable and previously encrypted data remains readable
kdf = PBKDF2HMAC(
algorithm=hashes.SHA512(),
length=32,
salt=self.salt,
iterations=600000,
backend=default_backend(),
)
self.encryption_key = kdf.derive(self.master_key)
self.iv = os.urandom(16)
self.last_rotation = time.time()
if self.enable_logging:
security_logger.debug("Security keys rotated successfully")
except Exception as e:
if self.enable_logging:
security_logger.error(f"Key rotation failed: {e}")
raise
def _check_rotation(self) -> None:
"""Check and perform key rotation if needed"""
if (
time.time() - self.last_rotation
>= self.key_rotation_interval
):
self._rotate_keys()
def _save_conversation(self, conversation_id: str) -> None:
"""Save conversation to encrypted storage"""
try:
if conversation_id not in self._conversations:
return
# Encrypt conversation data
json_data = json.dumps(
self._conversations[conversation_id]
).encode()
# Use a fresh GCM nonce for every save to avoid nonce reuse across files and saves
nonce = os.urandom(12)
cipher = Cipher(
algorithms.AES(self.encryption_key),
modes.GCM(nonce),
backend=default_backend(),
)
encryptor = cipher.encryptor()
encrypted_data = (
encryptor.update(json_data) + encryptor.finalize()
)
# Combine nonce, encrypted data, and authentication tag
combined_data = nonce + encrypted_data + encryptor.tag
# Save atomically using temporary file
conv_path = self.storage_path / f"{conversation_id}.conv"
temp_path = conv_path.with_suffix(".tmp")
with open(temp_path, "wb") as f:
f.write(combined_data)
temp_path.replace(conv_path)
if self.enable_logging:
security_logger.debug(f"Saved conversation {conversation_id}")
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to save conversation: {e}")
raise
def _load_conversations(self) -> None:
"""Load existing conversations from storage"""
try:
for file_path in self.storage_path.glob("*.conv"):
try:
with open(file_path, "rb") as f:
combined_data = f.read()
conversation_id = file_path.stem
# Split combined data into nonce, encrypted data, and authentication tag
if len(combined_data) < 28: # 12-byte nonce + 16-byte GCM tag minimum
continue
nonce = combined_data[:12]
encrypted_data = combined_data[12:-16]
auth_tag = combined_data[-16:]
# Decrypt conversation data
cipher = Cipher(
algorithms.AES(self.encryption_key),
modes.GCM(nonce, auth_tag),
backend=default_backend(),
)
decryptor = cipher.decryptor()
json_data = (
decryptor.update(encrypted_data)
+ decryptor.finalize()
)
self._conversations[conversation_id] = json.loads(
json_data
)
if self.enable_logging:
security_logger.debug(
f"Loaded conversation {conversation_id}"
)
except Exception as e:
if self.enable_logging:
security_logger.error(
f"Failed to load conversation {file_path}: {e}"
)
continue
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to load conversations: {e}")
raise
def protect_message(self, agent_name: str, message: str) -> str:
"""
Encrypt a message with multiple security layers
Args:
agent_name: Name of the sending agent
message: Message to encrypt
Returns:
Encrypted message string
"""
try:
self._check_rotation()
# Validate inputs
if not isinstance(agent_name, str) or not isinstance(
message, str
):
raise ValueError(
"Agent name and message must be strings"
)
if not agent_name.strip() or not message.strip():
raise ValueError(
"Agent name and message cannot be empty"
)
# Generate message ID and timestamp
message_id = secrets.token_hex(16)
timestamp = datetime.now(timezone.utc).isoformat()
# Encrypt message content with a fresh GCM nonce for each message
# (reusing one IV for many GCM encryptions would undermine the
# confidentiality and integrity guarantees of the mode)
message_bytes = message.encode()
nonce = os.urandom(12)
cipher = Cipher(
algorithms.AES(self.encryption_key),
modes.GCM(nonce),
backend=default_backend(),
)
encryptor = cipher.encryptor()
ciphertext = (
encryptor.update(message_bytes) + encryptor.finalize()
)
# Calculate message hash
message_hash = hashlib.sha512(message_bytes).hexdigest()
# Generate HMAC if maximum security
hmac_signature = None
if self.encryption_strength == EncryptionStrength.MAXIMUM:
h = hmac.new(
self.hmac_key, ciphertext, hashlib.sha512
)
hmac_signature = h.digest()
# Create secure package
secure_package = {
"id": message_id,
"time": timestamp,
"agent": agent_name,
"cipher": base64.b64encode(ciphertext).decode(),
"tag": base64.b64encode(encryptor.tag).decode(),
"hash": message_hash,
"hmac": (
base64.b64encode(hmac_signature).decode()
if hmac_signature
else None
),
}
return base64.b64encode(
json.dumps(secure_package).encode()
).decode()
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to protect message: {e}")
raise
def retrieve_message(self, encrypted_str: str) -> Tuple[str, str]:
"""
Decrypt and verify a message
Args:
encrypted_str: Encrypted message string
Returns:
Tuple of (agent_name, message)
"""
try:
# Decode secure package
secure_package = json.loads(
base64.b64decode(encrypted_str)
)
# Get components
nonce = base64.b64decode(secure_package["nonce"])
ciphertext = base64.b64decode(secure_package["cipher"])
tag = base64.b64decode(secure_package["tag"])
# Verify HMAC if present
if secure_package["hmac"]:
hmac_signature = base64.b64decode(
secure_package["hmac"]
)
h = hmac.new(
self.hmac_key, ciphertext, hashlib.sha512
)
if not hmac.compare_digest(
hmac_signature, h.digest()
):
raise ValueError("HMAC verification failed")
# Decrypt message
cipher = Cipher(
algorithms.AES(self.encryption_key),
modes.GCM(nonce, tag),
backend=default_backend(),
)
decryptor = cipher.decryptor()
decrypted_data = (
decryptor.update(ciphertext) + decryptor.finalize()
)
# Verify hash
if (
hashlib.sha512(decrypted_data).hexdigest()
!= secure_package["hash"]
):
raise ValueError("Message hash verification failed")
return secure_package["agent"], decrypted_data.decode()
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to retrieve message: {e}")
raise
def create_conversation(self, name: str = "") -> str:
"""Create a new secure conversation"""
conversation_id = str(uuid.uuid4())
with self._conv_lock:
self._conversations[conversation_id] = {
"id": conversation_id,
"name": name,
"created_at": datetime.now(timezone.utc).isoformat(),
"messages": [],
}
self._save_conversation(conversation_id)
if self.enable_logging:
security_logger.info(f"Created conversation {conversation_id}")
return conversation_id
def add_message(
self, conversation_id: str, agent_name: str, message: str
) -> None:
"""
Add an encrypted message to a conversation
Args:
conversation_id: Target conversation ID
agent_name: Name of the sending agent
message: Message content
"""
try:
# Encrypt message
encrypted = self.protect_message(agent_name, message)
# Add to conversation
with self._conv_lock:
if conversation_id not in self._conversations:
raise ValueError(
f"Invalid conversation ID: {conversation_id}"
)
self._conversations[conversation_id][
"messages"
].append(
{
"timestamp": datetime.now(
timezone.utc
).isoformat(),
"data": encrypted,
}
)
# Save changes
self._save_conversation(conversation_id)
if self.enable_logging:
security_logger.info(
f"Added message to conversation {conversation_id}"
)
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to add message: {e}")
raise
def get_messages(
self, conversation_id: str
) -> List[Tuple[str, str, datetime]]:
"""
Get decrypted messages from a conversation
Args:
conversation_id: Target conversation ID
Returns:
List of (agent_name, message, timestamp) tuples
"""
try:
with self._conv_lock:
if conversation_id not in self._conversations:
raise ValueError(
f"Invalid conversation ID: {conversation_id}"
)
history = []
for msg in self._conversations[conversation_id][
"messages"
]:
agent_name, message = self.retrieve_message(
msg["data"]
)
timestamp = datetime.fromisoformat(
msg["timestamp"]
)
history.append((agent_name, message, timestamp))
return history
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to get messages: {e}")
raise
def get_conversation_summary(self, conversation_id: str) -> Dict:
"""
Get summary statistics for a conversation
Args:
conversation_id: Target conversation ID
Returns:
Dictionary with conversation statistics
"""
try:
with self._conv_lock:
if conversation_id not in self._conversations:
raise ValueError(
f"Invalid conversation ID: {conversation_id}"
)
conv = self._conversations[conversation_id]
messages = conv["messages"]
# Get unique agents
agents = set()
for msg in messages:
agent_name, _ = self.retrieve_message(msg["data"])
agents.add(agent_name)
return {
"id": conversation_id,
"name": conv["name"],
"created_at": conv["created_at"],
"message_count": len(messages),
"agents": list(agents),
"last_message": (
messages[-1]["timestamp"] if messages else None
),
}
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to get summary: {e}")
raise
def delete_conversation(self, conversation_id: str) -> None:
"""
Delete a conversation and its storage
Args:
conversation_id: Target conversation ID
"""
try:
with self._conv_lock:
if conversation_id not in self._conversations:
raise ValueError(
f"Invalid conversation ID: {conversation_id}"
)
# Remove from memory
del self._conversations[conversation_id]
# Remove from storage
conv_path = self.storage_path / f"{conversation_id}.conv"
if conv_path.exists():
conv_path.unlink()
if self.enable_logging:
security_logger.info(f"Deleted conversation {conversation_id}")
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to delete conversation: {e}")
raise
def export_conversation(
self, conversation_id: str, format: str = "json", path: str = None
) -> str:
"""
Export a conversation to a file
Args:
conversation_id: Target conversation ID
format: Export format (json, txt)
path: Output file path
Returns:
Path to exported file
"""
try:
messages = self.get_messages(conversation_id)
summary = self.get_conversation_summary(conversation_id)
if not path:
path = f"conversation_{conversation_id}.{format}"
if format.lower() == "json":
export_data = {
"summary": summary,
"messages": [
{
"agent": agent,
"message": message,
"timestamp": timestamp.isoformat(),
}
for agent, message, timestamp in messages
],
}
with open(path, "w") as f:
json.dump(export_data, f, indent=2)
elif format.lower() == "txt":
with open(path, "w") as f:
f.write(f"Conversation: {summary['name']}\n")
f.write(f"Created: {summary['created_at']}\n")
f.write(f"Messages: {summary['message_count']}\n")
f.write(f"Agents: {', '.join(summary['agents'])}\n\n")
for agent, message, timestamp in messages:
f.write(f"[{timestamp}] {agent}: {message}\n")
else:
raise ValueError(f"Unsupported format: {format}")
if self.enable_logging:
security_logger.info(f"Exported conversation to {path}")
return path
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to export conversation: {e}")
raise
def backup_conversations(self, backup_path: str = None) -> str:
"""
Create a backup of all conversations
Args:
backup_path: Backup directory path
Returns:
Path to backup directory
"""
try:
if not backup_path:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = f"swarm_shield_backup_{timestamp}"
backup_dir = Path(backup_path)
backup_dir.mkdir(parents=True, exist_ok=True)
# Export all conversations
for conversation_id in self._conversations:
self.export_conversation(
conversation_id,
format="json",
path=str(backup_dir / f"{conversation_id}.json"),
)
if self.enable_logging:
security_logger.info(f"Backup created at {backup_path}")
return backup_path
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to create backup: {e}")
raise
def get_agent_stats(self, agent_name: str) -> Dict:
"""
Get statistics for a specific agent
Args:
agent_name: Name of the agent
Returns:
Dictionary with agent statistics
"""
try:
total_messages = 0
conversations = set()
            # Iterate over a snapshot taken under the lock so the stats are
            # computed against a consistent view of the conversations
            with self._conv_lock:
                conversation_items = list(self._conversations.items())
            for conversation_id, conv in conversation_items:
for msg in conv["messages"]:
msg_agent, _ = self.retrieve_message(msg["data"])
if msg_agent == agent_name:
total_messages += 1
conversations.add(conversation_id)
return {
"agent_name": agent_name,
"total_messages": total_messages,
"conversations": len(conversations),
"conversation_ids": list(conversations),
}
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to get agent stats: {e}")
raise
def query_conversations(
self,
agent_name: str = None,
text: str = None,
start_date: datetime = None,
end_date: datetime = None,
limit: int = 100,
) -> List[Dict]:
"""
Search conversations with filters
Args:
agent_name: Filter by agent name
text: Search for text in messages
start_date: Filter by start date
end_date: Filter by end date
limit: Maximum results to return
Returns:
List of matching conversation summaries
"""
try:
results = []
            # Iterate over a snapshot taken under the lock; get_conversation_summary
            # below re-acquires the lock, so it must not be held during the loop
            with self._conv_lock:
                conversation_items = list(self._conversations.items())
            for conversation_id, conv in conversation_items:
# Check date range
conv_date = datetime.fromisoformat(conv["created_at"])
if start_date and conv_date < start_date:
continue
if end_date and conv_date > end_date:
continue
# Check agent filter
if agent_name:
conv_agents = set()
for msg in conv["messages"]:
msg_agent, _ = self.retrieve_message(msg["data"])
conv_agents.add(msg_agent)
if agent_name not in conv_agents:
continue
# Check text filter
if text:
text_found = False
for msg in conv["messages"]:
_, message = self.retrieve_message(msg["data"])
if text.lower() in message.lower():
text_found = True
break
if not text_found:
continue
# Add to results
summary = self.get_conversation_summary(conversation_id)
results.append(summary)
if len(results) >= limit:
break
return results
except Exception as e:
if self.enable_logging:
security_logger.error(f"Failed to query conversations: {e}")
raise
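# ----------------------------------------------------------------------
# Illustrative usage sketch for the conversation helpers above. The
# constructor arguments are assumed (shown as a placeholder) and "Agent-1"
# is a hypothetical agent name; adapt both to your deployment.
#
#   shield = SwarmShield()  # constructor options omitted / assumed defaults
#   conv_id = shield.create_conversation("support-session")
#   shield.add_message(conv_id, "Agent-1", "Hello, how can I help?")
#   print(shield.get_conversation_summary(conv_id))
#   shield.export_conversation(conv_id, format="json")
#   backup_dir = shield.backup_conversations()
#   recent = shield.query_conversations(agent_name="Agent-1", limit=10)
# ----------------------------------------------------------------------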

@ -0,0 +1,318 @@
"""
SwarmShield integration for Swarms framework.
This module provides the main integration point for all security features,
combining input validation, output filtering, safety checking, rate limiting,
and encryption into a unified security shield.
"""
from typing import Dict, Any, Optional, Tuple, Union, List
from datetime import datetime
from loguru import logger
from swarms.utils.loguru_logger import initialize_logger
from .shield_config import ShieldConfig
from .swarm_shield import SwarmShield, EncryptionStrength
from .input_validator import InputValidator
from .output_filter import OutputFilter
from .safety_checker import SafetyChecker
from .rate_limiter import RateLimiter
# Initialize logger for shield integration
shield_logger = initialize_logger(log_folder="shield_integration")
class SwarmShieldIntegration:
"""
Main SwarmShield integration class
This class provides a unified interface for all security features:
- Input validation and sanitization
- Output filtering and sensitive data protection
- Safety checking and ethical AI compliance
- Rate limiting and abuse prevention
- Encrypted communication and storage
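
    Example (illustrative sketch; assumes the default ShieldConfig is
    usable as-is and "agent-1" is a hypothetical agent name):
        shield = SwarmShieldIntegration()
        ok, protected, err = shield.validate_and_protect_input(
            "Summarize the report", agent_name="agent-1", input_type="task"
        )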
"""
def __init__(self, config: Optional[ShieldConfig] = None):
"""
Initialize SwarmShield integration
Args:
config: Shield configuration. If None, uses default configuration.
"""
self.config = config or ShieldConfig()
# Initialize security components
self._initialize_components()
# Handle security_level safely (could be enum or string)
security_level_str = (
self.config.security_level.value
if hasattr(self.config.security_level, 'value')
else str(self.config.security_level)
)
shield_logger.info(f"SwarmShield integration initialized with {security_level_str} security")
def _initialize_components(self) -> None:
"""Initialize all security components"""
try:
# Initialize SwarmShield for encryption
if self.config.integrate_with_conversation:
self.swarm_shield = SwarmShield(
**self.config.get_encryption_config(),
enable_logging=self.config.enable_audit_logging
)
else:
self.swarm_shield = None
# Initialize input validator
self.input_validator = InputValidator(
self.config.get_validation_config()
)
# Initialize output filter
self.output_filter = OutputFilter(
self.config.get_filtering_config()
)
# Initialize safety checker
self.safety_checker = SafetyChecker(
self.config.get_safety_config()
)
# Initialize rate limiter
self.rate_limiter = RateLimiter(
self.config.get_rate_limit_config()
)
except Exception as e:
shield_logger.error(f"Failed to initialize security components: {e}")
raise
    def validate_and_protect_input(
        self, input_data: str, agent_name: str, input_type: str = "text"
    ) -> Tuple[bool, str, Optional[str]]:
"""
Validate and protect input data
Args:
input_data: Input data to validate
agent_name: Name of the agent
input_type: Type of input
Returns:
Tuple of (is_valid, protected_data, error_message)
"""
try:
# Check rate limits
is_allowed, error = self.rate_limiter.check_agent_limit(agent_name)
if not is_allowed:
return False, "", error
# Validate input
is_valid, sanitized, error = self.input_validator.validate_input(input_data, input_type)
if not is_valid:
return False, "", error
# Check safety
is_safe, safe_content, error = self.safety_checker.check_safety(sanitized, input_type)
if not is_safe:
return False, "", error
# Protect with encryption if enabled
if self.swarm_shield and self.config.protect_agent_communications:
try:
protected = self.swarm_shield.protect_message(agent_name, safe_content)
return True, protected, None
except Exception as e:
shield_logger.error(f"Encryption failed: {e}")
return False, "", f"Encryption error: {str(e)}"
return True, safe_content, None
except Exception as e:
shield_logger.error(f"Input validation error: {e}")
return False, "", f"Validation error: {str(e)}"
    def filter_and_protect_output(
        self,
        output_data: Union[str, Dict, List],
        agent_name: str,
        output_type: str = "text",
    ) -> Tuple[bool, Union[str, Dict, List], Optional[str]]:
"""
Filter and protect output data
Args:
output_data: Output data to filter
agent_name: Name of the agent
output_type: Type of output
Returns:
Tuple of (is_safe, filtered_data, warning_message)
"""
try:
# Filter output
is_safe, filtered, warning = self.output_filter.filter_output(output_data, output_type)
if not is_safe:
return False, "", "Output contains unsafe content"
# Check safety
if isinstance(filtered, str):
is_safe, safe_content, error = self.safety_checker.check_safety(filtered, output_type)
if not is_safe:
return False, "", error
filtered = safe_content
# Track request
self.rate_limiter.track_request(agent_name, f"output_{output_type}")
return True, filtered, warning
except Exception as e:
shield_logger.error(f"Output filtering error: {e}")
return False, "", f"Filtering error: {str(e)}"
    def process_agent_communication(
        self, agent_name: str, message: str, direction: str = "outbound"
    ) -> Tuple[bool, str, Optional[str]]:
"""
Process agent communication (inbound or outbound)
Args:
agent_name: Name of the agent
message: Message content
direction: "inbound" or "outbound"
Returns:
Tuple of (is_valid, processed_message, error_message)
"""
try:
if direction == "inbound":
return self.validate_and_protect_input(message, agent_name, "text")
elif direction == "outbound":
return self.filter_and_protect_output(message, agent_name, "text")
else:
return False, "", f"Invalid direction: {direction}"
except Exception as e:
shield_logger.error(f"Communication processing error: {e}")
return False, "", f"Processing error: {str(e)}"
def validate_task(self, task: str, agent_name: str) -> Tuple[bool, str, Optional[str]]:
"""Validate swarm task"""
return self.validate_and_protect_input(task, agent_name, "task")
    def validate_agent_config(
        self, config: Dict[str, Any], agent_name: str
    ) -> Tuple[bool, Dict[str, Any], Optional[str]]:
"""Validate agent configuration"""
try:
# Validate config structure
is_valid, validated_config, error = self.input_validator.validate_config(config)
if not is_valid:
return False, {}, error
# Check safety
is_safe, safe_config, error = self.safety_checker.check_config_safety(validated_config)
if not is_safe:
return False, {}, error
# Filter sensitive data
is_safe, filtered_config, warning = self.output_filter.filter_config_output(safe_config)
if not is_safe:
return False, {}, "Configuration contains unsafe content"
return True, filtered_config, warning
except Exception as e:
shield_logger.error(f"Config validation error: {e}")
return False, {}, f"Config validation error: {str(e)}"
def create_secure_conversation(self, name: str = "") -> Optional[str]:
"""Create a secure conversation if encryption is enabled"""
if self.swarm_shield and self.config.integrate_with_conversation:
try:
return self.swarm_shield.create_conversation(name)
except Exception as e:
shield_logger.error(f"Failed to create secure conversation: {e}")
return None
return None
def add_secure_message(self, conversation_id: str, agent_name: str, message: str) -> bool:
"""Add a message to secure conversation"""
if self.swarm_shield and self.config.integrate_with_conversation:
try:
self.swarm_shield.add_message(conversation_id, agent_name, message)
return True
except Exception as e:
shield_logger.error(f"Failed to add secure message: {e}")
return False
return False
def get_secure_messages(self, conversation_id: str) -> List[Tuple[str, str, datetime]]:
"""Get messages from secure conversation"""
if self.swarm_shield and self.config.integrate_with_conversation:
try:
return self.swarm_shield.get_messages(conversation_id)
except Exception as e:
shield_logger.error(f"Failed to get secure messages: {e}")
return []
return []
def check_rate_limit(self, agent_name: str, request_size: int = 1) -> Tuple[bool, Optional[str]]:
"""Check rate limits for an agent"""
return self.rate_limiter.check_rate_limit(agent_name, request_size)
def get_security_stats(self) -> Dict[str, Any]:
"""Get comprehensive security statistics"""
stats = {
"security_enabled": self.config.enabled,
"input_validations": getattr(self.input_validator, 'validation_count', 0),
"rate_limit_checks": getattr(self.rate_limiter, 'check_count', 0),
"blocked_requests": getattr(self.rate_limiter, 'blocked_count', 0),
"filtered_outputs": getattr(self.output_filter, 'filter_count', 0),
"violations": getattr(self.safety_checker, 'violation_count', 0),
"encryption_enabled": self.swarm_shield is not None,
}
# Handle security_level safely
if hasattr(self.config.security_level, 'value'):
stats["security_level"] = self.config.security_level.value
else:
stats["security_level"] = str(self.config.security_level)
# Handle encryption_strength safely
if self.swarm_shield and hasattr(self.swarm_shield.encryption_strength, 'value'):
stats["encryption_strength"] = self.swarm_shield.encryption_strength.value
elif self.swarm_shield:
stats["encryption_strength"] = str(self.swarm_shield.encryption_strength)
else:
stats["encryption_strength"] = "none"
return stats
def update_config(self, new_config: ShieldConfig) -> bool:
"""Update shield configuration"""
try:
self.config = new_config
self._initialize_components()
shield_logger.info("Shield configuration updated")
return True
except Exception as e:
shield_logger.error(f"Failed to update configuration: {e}")
return False
def enable_security(self) -> None:
"""Enable all security features"""
self.config.enabled = True
shield_logger.info("Security features enabled")
def disable_security(self) -> None:
"""Disable all security features"""
self.config.enabled = False
shield_logger.info("Security features disabled")
def cleanup(self) -> None:
"""Cleanup resources"""
try:
if hasattr(self, 'rate_limiter'):
self.rate_limiter.stop()
shield_logger.info("SwarmShield integration cleaned up")
except Exception as e:
shield_logger.error(f"Cleanup error: {e}")
def __del__(self):
"""Cleanup on destruction"""
self.cleanup()
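# ----------------------------------------------------------------------
# Minimal end-to-end sketch, guarded so it only runs when the module is
# executed directly. It assumes the default ShieldConfig is usable out of
# the box (encryption and conversation integration enabled); "analyst" is
# a hypothetical agent name.
if __name__ == "__main__":
    shield = SwarmShieldIntegration()

    # Secure conversation round-trip (only available when encryption is on)
    conv_id = shield.create_secure_conversation("demo-session")
    if conv_id:
        shield.add_secure_message(conv_id, "analyst", "Draft the summary.")
        for agent, message, timestamp in shield.get_secure_messages(conv_id):
            shield_logger.info(f"[{timestamp}] {agent}: {message}")

    # Outbound filtering of an agent response
    ok, filtered, warning = shield.filter_and_protect_output(
        "The summary is ready.", agent_name="analyst", output_type="text"
    )

    # Aggregate counters, then release rate-limiter resources
    shield_logger.info(f"Security stats: {shield.get_security_stats()}")
    shield.cleanup()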