Update de_hallucination_swarm.py

pull/1034/head
CI-DEV 2 months ago committed by GitHub
parent 3080c42a77
commit 95a6d2f87b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -2,6 +2,7 @@ from typing import List, Dict, Any, Optional
import time import time
from loguru import logger from loguru import logger
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
from swarms.security import SwarmShieldIntegration, ShieldConfig
# Prompt templates for different agent roles # Prompt templates for different agent roles
GENERATOR_PROMPT = """ GENERATOR_PROMPT = """
@ -110,6 +111,9 @@ class DeHallucinationSwarm:
iterations: int = 2, iterations: int = 2,
system_prompt: str = GENERATOR_PROMPT, system_prompt: str = GENERATOR_PROMPT,
store_intermediate_results: bool = True, store_intermediate_results: bool = True,
shield_config: Optional[ShieldConfig] = None,
enable_security: bool = True,
security_level: str = "standard",
): ):
""" """
Initialize the DeHallucinationSwarm with configurable agents. Initialize the DeHallucinationSwarm with configurable agents.
@ -118,6 +122,9 @@ class DeHallucinationSwarm:
model_names: List of model names for generator, critic, refiner, and validator model_names: List of model names for generator, critic, refiner, and validator
iterations: Number of criticism-refinement cycles to perform iterations: Number of criticism-refinement cycles to perform
store_intermediate_results: Whether to store all intermediate outputs store_intermediate_results: Whether to store all intermediate outputs
shield_config (ShieldConfig, optional): Security configuration for SwarmShield integration. Defaults to None.
enable_security (bool, optional): Whether to enable SwarmShield security features. Defaults to True.
security_level (str, optional): Pre-defined security level. Options: "basic", "standard", "enhanced", "maximum". Defaults to "standard".
""" """
self.name = name self.name = name
self.description = description self.description = description
@ -126,6 +133,9 @@ class DeHallucinationSwarm:
self.system_prompt = system_prompt self.system_prompt = system_prompt
self.history = [] self.history = []
# Initialize SwarmShield integration
self._initialize_swarm_shield(shield_config, enable_security, security_level)
# Initialize all agents # Initialize all agents
self.generator = Agent( self.generator = Agent(
agent_name="Generator", agent_name="Generator",
@ -155,6 +165,92 @@ class DeHallucinationSwarm:
model_name=model_names[3], model_name=model_names[3],
) )
def _initialize_swarm_shield(
self,
shield_config: Optional[ShieldConfig] = None,
enable_security: bool = True,
security_level: str = "standard"
) -> None:
"""Initialize SwarmShield integration for security features."""
self.enable_security = enable_security
self.security_level = security_level
if enable_security:
if shield_config is None:
shield_config = ShieldConfig.get_security_level(security_level)
self.swarm_shield = SwarmShieldIntegration(shield_config)
logger.info(f"SwarmShield initialized with {security_level} security level")
else:
self.swarm_shield = None
logger.info("SwarmShield security disabled")
# Security methods
def validate_task_with_shield(self, task: str) -> str:
"""Validate and sanitize task input using SwarmShield."""
if self.swarm_shield:
return self.swarm_shield.validate_and_protect_input(task)
return task
def validate_agent_config_with_shield(self, agent_config: dict) -> dict:
"""Validate agent configuration using SwarmShield."""
if self.swarm_shield:
return self.swarm_shield.validate_and_protect_input(str(agent_config))
return agent_config
def process_agent_communication_with_shield(self, message: str, agent_name: str) -> str:
"""Process agent communication through SwarmShield security."""
if self.swarm_shield:
return self.swarm_shield.process_agent_communication(message, agent_name)
return message
def check_rate_limit_with_shield(self, agent_name: str) -> bool:
"""Check rate limits for an agent using SwarmShield."""
if self.swarm_shield:
return self.swarm_shield.check_rate_limit(agent_name)
return True
def add_secure_message(self, message: str, agent_name: str) -> None:
"""Add a message to secure conversation history."""
if self.swarm_shield:
self.swarm_shield.add_secure_message(message, agent_name)
def get_secure_messages(self) -> List[dict]:
"""Get secure conversation messages."""
if self.swarm_shield:
return self.swarm_shield.get_secure_messages()
return []
def get_security_stats(self) -> dict:
"""Get security statistics and metrics."""
if self.swarm_shield:
return self.swarm_shield.get_security_stats()
return {"security_enabled": False}
def update_shield_config(self, new_config: ShieldConfig) -> None:
"""Update SwarmShield configuration."""
if self.swarm_shield:
self.swarm_shield.update_config(new_config)
logger.info("SwarmShield configuration updated")
def enable_security(self) -> None:
"""Enable SwarmShield security features."""
if not self.swarm_shield:
self._initialize_swarm_shield(enable_security=True, security_level=self.security_level)
logger.info("SwarmShield security enabled")
def disable_security(self) -> None:
"""Disable SwarmShield security features."""
self.swarm_shield = None
self.enable_security = False
logger.info("SwarmShield security disabled")
def cleanup_security(self) -> None:
"""Clean up SwarmShield resources."""
if self.swarm_shield:
self.swarm_shield.cleanup()
logger.info("SwarmShield resources cleaned up")
def _log_step( def _log_step(
self, self,
step_name: str, step_name: str,

Loading…
Cancel
Save