You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/swarms/structs/agent_security.py

319 lines
10 KiB

import base64
import json
import uuid
from datetime import datetime
from dataclasses import dataclass
from typing import Optional, Union, Dict, List
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
@dataclass
class EncryptedMessage:
    """
    Envelope for one encrypted agent-to-agent message.

    Only ``encrypted_content`` is ciphertext; every other field is plaintext
    metadata used to route the message and to find the session whose key
    encrypted it.
    """

    # Id of the agent that produced the message.
    sender_id: str
    # Id of the agent the message is addressed to.
    receiver_id: str
    # Ciphertext of the message body.
    encrypted_content: bytes
    # Creation time as a POSIX timestamp in seconds.
    timestamp: float
    # Unique identifier of this message.
    message_id: str
    # Identifier of the session whose shared key encrypted the content.
    session_id: str
class EncryptionSession:
    """
    Record of one encrypted communication session between agents.

    Args:
        session_id: Unique identifier of the session.
        agent_ids: Ids of the participating agents.
        encrypted_keys: Maps each participant id to the shared session key,
            encrypted with that participant's public key.
        created_at: Moment the session was created.
    """

    def __init__(
        self,
        session_id: str,
        agent_ids: List[str],
        encrypted_keys: Dict[str, bytes],
        created_at: datetime,
    ):
        # The arguments are kept verbatim (no copies) as public attributes.
        self.session_id, self.agent_ids = session_id, agent_ids
        self.encrypted_keys, self.created_at = encrypted_keys, created_at
class AgentEncryption:
    """
    Handles encryption for agent data both at rest and in transit.

    Supports symmetric encryption (Fernet) for data at rest, asymmetric
    RSA-OAEP encryption for data in transit, and secure multi-agent
    communication through a shared Fernet session key that is wrapped with
    each participant's RSA public key.

    Payload conventions shared by every encrypt/decrypt pair:
        * dicts are serialized as JSON, bytes are used as-is, and anything
          else is converted with ``str()`` and UTF-8 encoded;
        * on decryption, text holding a JSON object is returned as a dict
          and anything else is returned as a str.  A str that happens to be
          a JSON object therefore also comes back as a dict.

    Args:
        agent_id: Identifier of this agent; a random UUID4 string when
            omitted.  It is also part of the PBKDF2 salt, so a passphrase
            only re-derives the same at-rest key for the same agent_id.
        encryption_key: Optional passphrase for data at rest.  When given,
            the Fernet key is derived from it with PBKDF2-HMAC-SHA256;
            otherwise a random Fernet key is generated for this instance
            (this class does not persist it).
        enable_transit_encryption: Encrypt in ``encrypt_for_transit`` (an
            RSA key pair is generated).
        enable_rest_encryption: Encrypt in ``encrypt_at_rest``.
        enable_multi_agent: Enable agent registration and encrypted
            sessions (an RSA key pair is generated).
    """

    def __init__(
        self,
        agent_id: Optional[str] = None,
        encryption_key: Optional[str] = None,
        enable_transit_encryption: bool = False,
        enable_rest_encryption: bool = False,
        enable_multi_agent: bool = False,
    ):
        self.agent_id = agent_id or str(uuid.uuid4())
        self.enable_transit_encryption = enable_transit_encryption
        self.enable_rest_encryption = enable_rest_encryption
        self.enable_multi_agent = enable_multi_agent

        # Multi-agent communication storage
        self.sessions: Dict[str, EncryptionSession] = {}
        self.known_agents: Dict[str, "AgentEncryption"] = {}

        if enable_rest_encryption:
            if encryption_key:
                # NOTE: the salt and iteration count determine the derived
                # key; changing either makes data encrypted by earlier
                # instances undecryptable.
                kdf = PBKDF2HMAC(
                    algorithm=hashes.SHA256(),
                    length=32,
                    salt=f"agent_{self.agent_id}".encode(),  # Unique salt per agent
                    iterations=100000,
                )
                self.encryption_key = base64.urlsafe_b64encode(
                    kdf.derive(encryption_key.encode())
                )
            else:
                self.encryption_key = Fernet.generate_key()
            self.cipher_suite = Fernet(self.encryption_key)

        if enable_transit_encryption or enable_multi_agent:
            # RSA key pair for transit encryption and session-key wrapping
            self.private_key = rsa.generate_private_key(
                public_exponent=65537, key_size=2048
            )
            self.public_key = self.private_key.public_key()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _oaep_padding() -> "padding.OAEP":
        """Return the RSA-OAEP (MGF1/SHA-256) padding used by every RSA call."""
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )

    @staticmethod
    def _to_bytes(data: Union[str, dict, bytes]) -> bytes:
        """Serialize a payload: bytes as-is, dicts as JSON, anything else via str()."""
        if isinstance(data, bytes):
            return data
        if isinstance(data, dict):
            return json.dumps(data).encode()
        return str(data).encode()

    @staticmethod
    def _from_text(text: Union[str, bytes]) -> Union[str, dict]:
        """
        Decode a payload: a JSON object becomes a dict, anything else a str.

        Raises:
            UnicodeDecodeError: If ``text`` is bytes that are not valid UTF-8.
        """
        if isinstance(text, bytes):
            text = text.decode()
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return text
        # Only JSON objects are parsed: a plain string such as "42", "true"
        # or "null" must not come back as an int/bool/None.
        return parsed if isinstance(parsed, dict) else text

    def _require_multi_agent(self) -> None:
        """Raise ValueError unless multi-agent support is enabled."""
        if not self.enable_multi_agent:
            raise ValueError("Multi-agent support is not enabled")

    # ------------------------------------------------------------------
    # Multi-agent communication
    # ------------------------------------------------------------------

    def register_agent(
        self, agent_id: str, agent_encryption: "AgentEncryption"
    ) -> None:
        """
        Register another agent for secure communication.

        ``agent_id`` should equal ``agent_encryption.agent_id``: sessions
        wrap keys under the registered id, while the other agent looks up
        its key by its own ``agent_id``.

        Raises:
            ValueError: If multi-agent support is not enabled.
        """
        self._require_multi_agent()
        self.known_agents[agent_id] = agent_encryption

    def create_session(self, agent_ids: List[str]) -> str:
        """
        Create a new encrypted session between multiple agents.

        A fresh Fernet session key is generated and wrapped (RSA-OAEP) with
        each participant's public key.  The session is stored on this agent
        and on every registered participant, so that receivers can decrypt
        messages sent within it.

        Args:
            agent_ids: Participants; each must be this agent's own id or an
                id previously passed to ``register_agent``.

        Returns:
            The new session id.

        Raises:
            ValueError: If multi-agent support is disabled or a participant
                is not registered.
        """
        self._require_multi_agent()

        # Validate every participant before doing any RSA work.
        for agent_id in agent_ids:
            if agent_id != self.agent_id and agent_id not in self.known_agents:
                raise ValueError(f"Agent {agent_id} not registered")

        session_id = str(uuid.uuid4())
        session_key = Fernet.generate_key()

        # One copy of the session key per participant, readable only with
        # that participant's private key.
        encrypted_keys: Dict[str, bytes] = {}
        for agent_id in agent_ids:
            agent_public_key = (
                self.public_key
                if agent_id == self.agent_id
                else self.known_agents[agent_id].public_key
            )
            encrypted_keys[agent_id] = agent_public_key.encrypt(
                session_key, self._oaep_padding()
            )

        session = EncryptionSession(
            session_id=session_id,
            agent_ids=list(agent_ids),
            encrypted_keys=encrypted_keys,
            created_at=datetime.now(),
        )
        self.sessions[session_id] = session
        # Share the session with the other participants; without this their
        # decrypt_message() cannot find the session.
        for agent_id in agent_ids:
            if agent_id != self.agent_id:
                self.known_agents[agent_id].sessions[session_id] = session
        return session_id

    def encrypt_message(
        self,
        content: Union[str, dict],
        receiver_id: str,
        session_id: str,
    ) -> EncryptedMessage:
        """
        Encrypt a message for another agent within a session.

        Args:
            content: Message text, or a dict that is serialized as JSON.
            receiver_id: Id of the recipient; must be a session participant.
            session_id: Id of a session known to this agent.

        Returns:
            An ``EncryptedMessage`` whose content is a Fernet token under
            the session key.

        Raises:
            ValueError: If multi-agent support is disabled, the session is
                unknown, or the sender/receiver is not a participant.
        """
        self._require_multi_agent()
        session = self.sessions.get(session_id)
        if session is None:
            raise ValueError("Invalid session ID")
        if (
            self.agent_id not in session.agent_ids
            or receiver_id not in session.agent_ids
        ):
            raise ValueError("Sender or receiver not in session")

        if isinstance(content, dict):
            content = json.dumps(content)

        session_key = self.decrypt_session_key(
            session.encrypted_keys[self.agent_id]
        )
        encrypted_content = Fernet(session_key).encrypt(content.encode())

        return EncryptedMessage(
            sender_id=self.agent_id,
            receiver_id=receiver_id,
            encrypted_content=encrypted_content,
            timestamp=datetime.now().timestamp(),
            message_id=str(uuid.uuid4()),
            session_id=session_id,
        )

    def decrypt_message(
        self, message: EncryptedMessage
    ) -> Union[str, dict]:
        """
        Decrypt a message addressed to this agent.

        Returns:
            A dict if the plaintext is a JSON object, otherwise the string.

        Raises:
            ValueError: If multi-agent support is disabled, the session is
                unknown, the message is addressed to another agent, or the
                sender/receiver is not a participant in the session.
            cryptography.fernet.InvalidToken: If the content was not
                produced with this session's key or was tampered with.
        """
        self._require_multi_agent()
        session = self.sessions.get(message.session_id)
        if session is None:
            raise ValueError("Invalid session ID")
        if self.agent_id != message.receiver_id:
            raise ValueError("Message not intended for this agent")
        # Report non-participants as ValueError (as encrypt_message does)
        # instead of letting the key lookup fail with a KeyError.
        if (
            self.agent_id not in session.encrypted_keys
            or message.sender_id not in session.agent_ids
        ):
            raise ValueError("Sender or receiver not in session")

        session_key = self.decrypt_session_key(
            session.encrypted_keys[self.agent_id]
        )
        plaintext = Fernet(session_key).decrypt(message.encrypted_content)
        return self._from_text(plaintext)

    def decrypt_session_key(self, encrypted_key: bytes) -> bytes:
        """Unwrap an RSA-OAEP encrypted session key with this agent's private key."""
        return self.private_key.decrypt(encrypted_key, self._oaep_padding())

    # ------------------------------------------------------------------
    # Data at rest
    # ------------------------------------------------------------------

    def encrypt_at_rest(self, data: Union[str, dict, bytes]) -> bytes:
        """
        Encrypt data for storage.

        Dicts are serialized as JSON in both modes, so ``decrypt_at_rest``
        round-trips them whether or not encryption is enabled.

        Returns:
            A Fernet token when rest encryption is enabled, otherwise the
            serialized plaintext bytes.
        """
        payload = self._to_bytes(data)
        if not self.enable_rest_encryption:
            return payload
        return self.cipher_suite.encrypt(payload)

    def decrypt_at_rest(
        self, encrypted_data: bytes
    ) -> Union[str, dict]:
        """
        Decrypt stored data produced by ``encrypt_at_rest``.

        Returns:
            A dict if the plaintext is a JSON object, otherwise the string.

        Raises:
            cryptography.fernet.InvalidToken: If rest encryption is enabled
                and the data was not encrypted with this instance's key.
        """
        plaintext = (
            self.cipher_suite.decrypt(encrypted_data)
            if self.enable_rest_encryption
            else encrypted_data
        )
        return self._from_text(plaintext)

    # ------------------------------------------------------------------
    # Data in transit
    # ------------------------------------------------------------------

    def encrypt_for_transit(self, data: Union[str, dict]) -> bytes:
        """
        Encrypt data for transmission with this agent's public key.

        Payloads that fit in one RSA-OAEP block (190 bytes for the 2048-bit
        key generated in ``__init__``) are encrypted directly, producing
        exactly ``key_size // 8`` bytes.  Larger payloads, which RSA-OAEP
        cannot encrypt, use hybrid encryption: a one-off Fernet key wrapped
        with RSA-OAEP, followed by the Fernet token of the payload.

        Returns:
            The ciphertext, or the serialized plaintext bytes when transit
            encryption is disabled.
        """
        payload = self._to_bytes(data)
        if not self.enable_transit_encryption:
            return payload

        block_size = self.public_key.key_size // 8
        # OAEP overhead is 2 * hash_length + 2 bytes.
        max_direct = block_size - 2 * hashes.SHA256.digest_size - 2
        if len(payload) <= max_direct:
            return self.public_key.encrypt(payload, self._oaep_padding())

        data_key = Fernet.generate_key()
        wrapped_key = self.public_key.encrypt(data_key, self._oaep_padding())
        return wrapped_key + Fernet(data_key).encrypt(payload)

    def decrypt_from_transit(
        self, data: Union[bytes, str]
    ) -> Union[str, dict]:
        """
        Decrypt received data, handling both encrypted and unencrypted inputs.

        Bytes of exactly one RSA block are decrypted with RSA-OAEP; longer
        bytes are treated as hybrid ciphertext from ``encrypt_for_transit``.
        Input that is a str, is shorter than one block, or does not decrypt
        with this agent's private key is returned as plaintext.

        Returns:
            A dict if the (decrypted) text is a JSON object, otherwise the
            string.
        """
        if not self.enable_transit_encryption or not isinstance(data, bytes):
            return self._from_text(data)

        block_size = self.private_key.key_size // 8
        if len(data) < block_size:
            return self._from_text(data)

        # Local import: only needed to recognise an invalid hybrid token.
        from cryptography.fernet import InvalidToken

        try:
            if len(data) == block_size:
                plaintext = self.private_key.decrypt(
                    data, self._oaep_padding()
                )
            else:
                data_key = self.private_key.decrypt(
                    data[:block_size], self._oaep_padding()
                )
                plaintext = Fernet(data_key).decrypt(data[block_size:])
        except (ValueError, InvalidToken):
            # Not encrypted for this agent: treat the input as plaintext.
            return self._from_text(data)
        return self._from_text(plaintext)

    def get_public_key_pem(self) -> bytes:
        """
        Return the public key in PEM (SubjectPublicKeyInfo) format for sharing.

        Returns ``b""`` when neither transit encryption nor multi-agent
        support is enabled, because no key pair was generated.
        """
        if not (self.enable_transit_encryption or self.enable_multi_agent):
            return b""
        return self.public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )