parent
bb69f8e696
commit
704f91fdea
@ -0,0 +1,235 @@
|
|||||||
|
from typing import List, Dict, Optional, TypedDict, Literal, Union, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from swarms import Agent
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ModelName(str, Enum):
|
||||||
|
"""Valid model names for swarms agents"""
|
||||||
|
GPT4O = "gpt-4o"
|
||||||
|
GPT4O_MINI = "gpt-4o-mini"
|
||||||
|
GPT4 = "gpt-4"
|
||||||
|
GPT35_TURBO = "gpt-3.5-turbo"
|
||||||
|
CLAUDE = "claude-v1"
|
||||||
|
CLAUDE2 = "claude-2"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_names(cls) -> List[str]:
|
||||||
|
"""Get list of valid model names"""
|
||||||
|
return [model.value for model in cls]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_valid_model(cls, model_name: str) -> bool:
|
||||||
|
"""Check if model name is valid"""
|
||||||
|
return model_name in cls.get_model_names()
|
||||||
|
|
||||||
|
class AgentConfigDict(TypedDict):
|
||||||
|
"""TypedDict for agent configuration"""
|
||||||
|
agent_name: str
|
||||||
|
system_prompt: str
|
||||||
|
model_name: str # Using str instead of ModelName for flexibility
|
||||||
|
max_loops: int
|
||||||
|
autosave: bool
|
||||||
|
dashboard: bool
|
||||||
|
verbose: bool
|
||||||
|
dynamic_temperature: bool
|
||||||
|
saved_state_path: str
|
||||||
|
user_name: str
|
||||||
|
retry_attempts: int
|
||||||
|
context_length: int
|
||||||
|
return_step_meta: bool
|
||||||
|
output_type: str
|
||||||
|
streaming: bool
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentValidationError(Exception):
|
||||||
|
"""Custom exception for agent validation errors"""
|
||||||
|
message: str
|
||||||
|
field: str
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Validation error in field '{self.field}': {self.message}. Got value: {self.value}"
|
||||||
|
|
||||||
|
class AgentValidator:
|
||||||
|
"""Validates agent configuration data"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_config(config: Dict[str, Any]) -> AgentConfigDict:
|
||||||
|
"""Validate and convert agent configuration"""
|
||||||
|
try:
|
||||||
|
# Validate model name
|
||||||
|
model_name = str(config['model_name'])
|
||||||
|
if not ModelName.is_valid_model(model_name):
|
||||||
|
valid_models = ModelName.get_model_names()
|
||||||
|
raise AgentValidationError(
|
||||||
|
f"Invalid model name. Must be one of: {', '.join(valid_models)}",
|
||||||
|
"model_name",
|
||||||
|
model_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert types with error handling
|
||||||
|
validated_config: AgentConfigDict = {
|
||||||
|
'agent_name': str(config.get('agent_name', '')),
|
||||||
|
'system_prompt': str(config.get('system_prompt', '')),
|
||||||
|
'model_name': model_name,
|
||||||
|
'max_loops': int(config.get('max_loops', 1)),
|
||||||
|
'autosave': bool(str(config.get('autosave', True)).lower() == 'true'),
|
||||||
|
'dashboard': bool(str(config.get('dashboard', False)).lower() == 'true'),
|
||||||
|
'verbose': bool(str(config.get('verbose', True)).lower() == 'true'),
|
||||||
|
'dynamic_temperature': bool(str(config.get('dynamic_temperature', True)).lower() == 'true'),
|
||||||
|
'saved_state_path': str(config.get('saved_state_path', '')),
|
||||||
|
'user_name': str(config.get('user_name', 'default_user')),
|
||||||
|
'retry_attempts': int(config.get('retry_attempts', 3)),
|
||||||
|
'context_length': int(config.get('context_length', 200000)),
|
||||||
|
'return_step_meta': bool(str(config.get('return_step_meta', False)).lower() == 'true'),
|
||||||
|
'output_type': str(config.get('output_type', 'string')),
|
||||||
|
'streaming': bool(str(config.get('streaming', False)).lower() == 'true')
|
||||||
|
}
|
||||||
|
|
||||||
|
return validated_config
|
||||||
|
|
||||||
|
except (ValueError, KeyError) as e:
|
||||||
|
raise AgentValidationError(
|
||||||
|
str(e),
|
||||||
|
str(e.__class__.__name__),
|
||||||
|
str(config)
|
||||||
|
)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentCSV:
|
||||||
|
"""Class to manage agents through CSV with type safety"""
|
||||||
|
csv_path: Path
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Convert string path to Path object if necessary"""
|
||||||
|
if isinstance(self.csv_path, str):
|
||||||
|
self.csv_path = Path(self.csv_path)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def headers(self) -> List[str]:
|
||||||
|
"""CSV headers for agent configuration"""
|
||||||
|
return [
|
||||||
|
"agent_name", "system_prompt", "model_name", "max_loops",
|
||||||
|
"autosave", "dashboard", "verbose", "dynamic_temperature",
|
||||||
|
"saved_state_path", "user_name", "retry_attempts", "context_length",
|
||||||
|
"return_step_meta", "output_type", "streaming"
|
||||||
|
]
|
||||||
|
|
||||||
|
def create_agent_csv(self, agents: List[Dict[str, Any]]) -> None:
|
||||||
|
"""Create a CSV file with validated agent configurations"""
|
||||||
|
validated_agents = []
|
||||||
|
for agent in agents:
|
||||||
|
try:
|
||||||
|
validated_config = AgentValidator.validate_config(agent)
|
||||||
|
validated_agents.append(validated_config)
|
||||||
|
except AgentValidationError as e:
|
||||||
|
logger.error(f"Validation error for agent {agent.get('agent_name', 'unknown')}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
with open(self.csv_path, 'w', newline='') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=self.headers)
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(validated_agents)
|
||||||
|
|
||||||
|
logger.info(f"Created CSV with {len(validated_agents)} agents at {self.csv_path}")
|
||||||
|
|
||||||
|
def load_agents(self) -> List[Agent]:
|
||||||
|
"""Load and create agents from CSV with validation"""
|
||||||
|
if not self.csv_path.exists():
|
||||||
|
raise FileNotFoundError(f"CSV file not found at {self.csv_path}")
|
||||||
|
|
||||||
|
agents: List[Agent] = []
|
||||||
|
with open(self.csv_path, 'r') as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
try:
|
||||||
|
validated_config = AgentValidator.validate_config(row)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
agent_name=validated_config['agent_name'],
|
||||||
|
system_prompt=validated_config['system_prompt'],
|
||||||
|
model_name=validated_config['model_name'],
|
||||||
|
max_loops=validated_config['max_loops'],
|
||||||
|
autosave=validated_config['autosave'],
|
||||||
|
dashboard=validated_config['dashboard'],
|
||||||
|
verbose=validated_config['verbose'],
|
||||||
|
dynamic_temperature_enabled=validated_config['dynamic_temperature'],
|
||||||
|
saved_state_path=validated_config['saved_state_path'],
|
||||||
|
user_name=validated_config['user_name'],
|
||||||
|
retry_attempts=validated_config['retry_attempts'],
|
||||||
|
context_length=validated_config['context_length'],
|
||||||
|
return_step_meta=validated_config['return_step_meta'],
|
||||||
|
output_type=validated_config['output_type'],
|
||||||
|
streaming_on=validated_config['streaming']
|
||||||
|
)
|
||||||
|
agents.append(agent)
|
||||||
|
except AgentValidationError as e:
|
||||||
|
logger.error(f"Skipping invalid agent configuration: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(agents)} agents from {self.csv_path}")
|
||||||
|
return agents
|
||||||
|
|
||||||
|
def add_agent(self, agent_config: Dict[str, Any]) -> None:
|
||||||
|
"""Add a new validated agent configuration to CSV"""
|
||||||
|
validated_config = AgentValidator.validate_config(agent_config)
|
||||||
|
|
||||||
|
with open(self.csv_path, 'a', newline='') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=self.headers)
|
||||||
|
writer.writerow(validated_config)
|
||||||
|
|
||||||
|
logger.info(f"Added new agent {validated_config['agent_name']} to {self.csv_path}")
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example agent configurations
|
||||||
|
agent_configs = [
|
||||||
|
{
|
||||||
|
"agent_name": "Financial-Analysis-Agent",
|
||||||
|
"system_prompt": "You are a financial expert...",
|
||||||
|
"model_name": "gpt-4o-mini", # Updated to correct model name
|
||||||
|
"max_loops": 1,
|
||||||
|
"autosave": True,
|
||||||
|
"dashboard": False,
|
||||||
|
"verbose": True,
|
||||||
|
"dynamic_temperature": True,
|
||||||
|
"saved_state_path": "finance_agent.json",
|
||||||
|
"user_name": "swarms_corp",
|
||||||
|
"retry_attempts": 3,
|
||||||
|
"context_length": 200000,
|
||||||
|
"return_step_meta": False,
|
||||||
|
"output_type": "string",
|
||||||
|
"streaming": False
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize CSV manager
|
||||||
|
csv_manager = AgentCSV(Path("agents.csv"))
|
||||||
|
|
||||||
|
# Create CSV with initial agents
|
||||||
|
csv_manager.create_agent_csv(agent_configs)
|
||||||
|
|
||||||
|
# Load agents from CSV
|
||||||
|
agents = csv_manager.load_agents()
|
||||||
|
|
||||||
|
# Use an agent
|
||||||
|
if agents:
|
||||||
|
financial_agent = agents[0]
|
||||||
|
response = financial_agent.run(
|
||||||
|
"How can I establish a ROTH IRA to buy stocks and get a tax break?"
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
except AgentValidationError as e:
|
||||||
|
logger.error(f"Validation error: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error: {e}")
|
@ -0,0 +1,399 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
import pulsar
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from loguru import logger
|
||||||
|
from prometheus_client import Counter, Histogram, start_http_server
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
|
|
||||||
|
from swarms.prompts.finance_agent_sys_prompt import FINANCIAL_AGENT_SYS_PROMPT
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
|
||||||
|
|
||||||
|
# Enhanced metrics
|
||||||
|
TASK_COUNTER = Counter('swarm_tasks_total', 'Total number of tasks processed')
|
||||||
|
TASK_LATENCY = Histogram('swarm_task_duration_seconds', 'Task processing duration')
|
||||||
|
TASK_FAILURES = Counter('swarm_task_failures_total', 'Total number of task failures')
|
||||||
|
AGENT_ERRORS = Counter('swarm_agent_errors_total', 'Total number of agent errors')
|
||||||
|
|
||||||
|
# Define types using Literal
|
||||||
|
TaskStatus = Literal["pending", "processing", "completed", "failed"]
|
||||||
|
TaskPriority = Literal["low", "medium", "high", "critical"]
|
||||||
|
|
||||||
|
class SecurityConfig(BaseModel):
|
||||||
|
"""Security configuration for the swarm"""
|
||||||
|
encryption_key: str = Field(..., description="Encryption key for sensitive data")
|
||||||
|
tls_cert_path: Optional[str] = Field(None, description="Path to TLS certificate")
|
||||||
|
tls_key_path: Optional[str] = Field(None, description="Path to TLS private key")
|
||||||
|
auth_token: Optional[str] = Field(None, description="Authentication token")
|
||||||
|
max_message_size: int = Field(default=1048576, description="Maximum message size in bytes")
|
||||||
|
rate_limit: int = Field(default=100, description="Maximum tasks per minute")
|
||||||
|
|
||||||
|
@validator('encryption_key')
|
||||||
|
def validate_encryption_key(cls, v):
|
||||||
|
if len(v) < 32:
|
||||||
|
raise ValueError("Encryption key must be at least 32 bytes long")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Task(BaseModel):
|
||||||
|
"""Enhanced task model with additional metadata and validation"""
|
||||||
|
task_id: str = Field(..., description="Unique identifier for the task")
|
||||||
|
description: str = Field(..., description="Task description or instructions")
|
||||||
|
output_type: Literal["string", "json", "file"] = Field("string")
|
||||||
|
status: TaskStatus = Field(default="pending")
|
||||||
|
priority: TaskPriority = Field(default="medium")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
started_at: Optional[datetime] = None
|
||||||
|
completed_at: Optional[datetime] = None
|
||||||
|
retry_count: int = Field(default=0)
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
@validator('task_id')
|
||||||
|
def validate_task_id(cls, v):
|
||||||
|
if not v.strip():
|
||||||
|
raise ValueError("task_id cannot be empty")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_encoders = {
|
||||||
|
datetime: lambda v: v.isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TaskResult(BaseModel):
|
||||||
|
"""Model for task execution results"""
|
||||||
|
task_id: str
|
||||||
|
status: TaskStatus
|
||||||
|
result: Any
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
execution_time: float
|
||||||
|
agent_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def task_timing():
|
||||||
|
"""Context manager for timing task execution"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
duration = time.time() - start_time
|
||||||
|
TASK_LATENCY.observe(duration)
|
||||||
|
|
||||||
|
|
||||||
|
class SecurePulsarSwarm:
|
||||||
|
"""
|
||||||
|
Enhanced secure, scalable swarm system with improved reliability and security features.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
agents: List[Any],
|
||||||
|
pulsar_url: str,
|
||||||
|
subscription_name: str,
|
||||||
|
topic_name: str,
|
||||||
|
security_config: SecurityConfig,
|
||||||
|
max_workers: int = 5,
|
||||||
|
retry_attempts: int = 3,
|
||||||
|
task_timeout: int = 300,
|
||||||
|
metrics_port: int = 8000,
|
||||||
|
):
|
||||||
|
"""Initialize the enhanced Pulsar Swarm"""
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.agents = agents
|
||||||
|
self.pulsar_url = pulsar_url
|
||||||
|
self.subscription_name = subscription_name
|
||||||
|
self.topic_name = topic_name
|
||||||
|
self.security_config = security_config
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self.retry_attempts = retry_attempts
|
||||||
|
self.task_timeout = task_timeout
|
||||||
|
|
||||||
|
# Initialize encryption
|
||||||
|
self.cipher_suite = Fernet(security_config.encryption_key.encode())
|
||||||
|
|
||||||
|
# Setup metrics server
|
||||||
|
start_http_server(metrics_port)
|
||||||
|
|
||||||
|
# Initialize Pulsar client with security settings
|
||||||
|
client_config = {
|
||||||
|
"authentication": None if not security_config.auth_token else pulsar.AuthenticationToken(security_config.auth_token),
|
||||||
|
"operation_timeout_seconds": 30,
|
||||||
|
"connection_timeout_seconds": 30,
|
||||||
|
"use_tls": bool(security_config.tls_cert_path),
|
||||||
|
"tls_trust_certs_file_path": security_config.tls_cert_path,
|
||||||
|
"tls_allow_insecure_connection": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.client = pulsar.Client(self.pulsar_url, **client_config)
|
||||||
|
self.producer = self._create_producer()
|
||||||
|
self.consumer = self._create_consumer()
|
||||||
|
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||||
|
|
||||||
|
# Initialize rate limiting
|
||||||
|
self.last_execution_time = time.time()
|
||||||
|
self.execution_count = 0
|
||||||
|
|
||||||
|
logger.info(f"Secure Pulsar Swarm '{self.name}' initialized with enhanced security features")
|
||||||
|
|
||||||
|
def _create_producer(self):
|
||||||
|
"""Create a secure producer with retry logic"""
|
||||||
|
return self.client.create_producer(
|
||||||
|
self.topic_name,
|
||||||
|
max_pending_messages=1000,
|
||||||
|
compression_type=pulsar.CompressionType.LZ4,
|
||||||
|
block_if_queue_full=True,
|
||||||
|
batching_enabled=True,
|
||||||
|
batching_max_publish_delay_ms=10
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_consumer(self):
|
||||||
|
"""Create a secure consumer with retry logic"""
|
||||||
|
return self.client.subscribe(
|
||||||
|
self.topic_name,
|
||||||
|
subscription_name=self.subscription_name,
|
||||||
|
consumer_type=pulsar.ConsumerType.Shared,
|
||||||
|
message_listener=None,
|
||||||
|
receiver_queue_size=1000,
|
||||||
|
max_total_receiver_queue_size_across_partitions=50000
|
||||||
|
)
|
||||||
|
|
||||||
|
def _encrypt_message(self, data: str) -> bytes:
|
||||||
|
"""Encrypt message data"""
|
||||||
|
return self.cipher_suite.encrypt(data.encode())
|
||||||
|
|
||||||
|
def _decrypt_message(self, data: bytes) -> str:
|
||||||
|
"""Decrypt message data"""
|
||||||
|
return self.cipher_suite.decrypt(data).decode()
|
||||||
|
|
||||||
|
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||||
|
def publish_task(self, task: Task) -> None:
|
||||||
|
"""Publish a task with enhanced security and reliability"""
|
||||||
|
try:
|
||||||
|
# Validate message size
|
||||||
|
task_data = task.json()
|
||||||
|
if len(task_data) > self.security_config.max_message_size:
|
||||||
|
raise ValueError("Task data exceeds maximum message size")
|
||||||
|
|
||||||
|
# Rate limiting
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - self.last_execution_time >= 60:
|
||||||
|
self.execution_count = 0
|
||||||
|
self.last_execution_time = current_time
|
||||||
|
|
||||||
|
if self.execution_count >= self.security_config.rate_limit:
|
||||||
|
raise ValueError("Rate limit exceeded")
|
||||||
|
|
||||||
|
# Encrypt and publish
|
||||||
|
encrypted_data = self._encrypt_message(task_data)
|
||||||
|
message_id = self.producer.send(encrypted_data)
|
||||||
|
|
||||||
|
self.execution_count += 1
|
||||||
|
logger.info(f"Task {task.task_id} published successfully with message ID {message_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
TASK_FAILURES.inc()
|
||||||
|
logger.error(f"Error publishing task {task.task_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _process_task(self, task: Task) -> TaskResult:
|
||||||
|
"""Process a task with comprehensive error handling and monitoring"""
|
||||||
|
task.status = "processing"
|
||||||
|
task.started_at = datetime.utcnow()
|
||||||
|
|
||||||
|
with task_timing():
|
||||||
|
try:
|
||||||
|
# Select agent using round-robin
|
||||||
|
agent = self.agents.pop(0)
|
||||||
|
self.agents.append(agent)
|
||||||
|
|
||||||
|
# Execute task with timeout
|
||||||
|
future = self.executor.submit(agent.run, task.description)
|
||||||
|
result = future.result(timeout=self.task_timeout)
|
||||||
|
|
||||||
|
# Handle different output types
|
||||||
|
if task.output_type == "json":
|
||||||
|
result = json.loads(result)
|
||||||
|
elif task.output_type == "file":
|
||||||
|
file_path = f"output_{task.task_id}_{int(time.time())}.txt"
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
f.write(result)
|
||||||
|
result = {"file_path": file_path}
|
||||||
|
|
||||||
|
task.status = "completed"
|
||||||
|
task.completed_at = datetime.utcnow()
|
||||||
|
TASK_COUNTER.inc()
|
||||||
|
|
||||||
|
return TaskResult(
|
||||||
|
task_id=task.task_id,
|
||||||
|
status="completed",
|
||||||
|
result=result,
|
||||||
|
execution_time=time.time() - task.started_at.timestamp(),
|
||||||
|
agent_id=agent.agent_name
|
||||||
|
)
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
TASK_FAILURES.inc()
|
||||||
|
error_msg = f"Task {task.task_id} timed out after {self.task_timeout} seconds"
|
||||||
|
logger.error(error_msg)
|
||||||
|
task.status = "failed"
|
||||||
|
return TaskResult(
|
||||||
|
task_id=task.task_id,
|
||||||
|
status="failed",
|
||||||
|
result=None,
|
||||||
|
error_message=error_msg,
|
||||||
|
execution_time=time.time() - task.started_at.timestamp(),
|
||||||
|
agent_id=agent.agent_name
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
TASK_FAILURES.inc()
|
||||||
|
AGENT_ERRORS.inc()
|
||||||
|
error_msg = f"Error processing task {task.task_id}: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
task.status = "failed"
|
||||||
|
return TaskResult(
|
||||||
|
task_id=task.task_id,
|
||||||
|
status="failed",
|
||||||
|
result=None,
|
||||||
|
error_message=error_msg,
|
||||||
|
execution_time=time.time() - task.started_at.timestamp(),
|
||||||
|
agent_id=agent.agent_name
|
||||||
|
)
|
||||||
|
|
||||||
|
async def consume_tasks(self):
|
||||||
|
"""Enhanced task consumption with circuit breaker and backoff"""
|
||||||
|
consecutive_failures = 0
|
||||||
|
backoff_time = 1
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Circuit breaker pattern
|
||||||
|
if consecutive_failures >= 5:
|
||||||
|
logger.warning(f"Circuit breaker triggered. Waiting {backoff_time} seconds")
|
||||||
|
await asyncio.sleep(backoff_time)
|
||||||
|
backoff_time = min(backoff_time * 2, 60)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Receive message with timeout
|
||||||
|
message = await self.consumer.receive_async()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Decrypt and process message
|
||||||
|
decrypted_data = self._decrypt_message(message.data())
|
||||||
|
task_data = json.loads(decrypted_data)
|
||||||
|
task = Task(**task_data)
|
||||||
|
|
||||||
|
# Process task
|
||||||
|
result = await self._process_task(task)
|
||||||
|
|
||||||
|
# Handle result
|
||||||
|
if result.status == "completed":
|
||||||
|
await self.consumer.acknowledge_async(message)
|
||||||
|
consecutive_failures = 0
|
||||||
|
backoff_time = 1
|
||||||
|
else:
|
||||||
|
if task.retry_count < self.retry_attempts:
|
||||||
|
task.retry_count += 1
|
||||||
|
await self.consumer.negative_acknowledge(message)
|
||||||
|
else:
|
||||||
|
await self.consumer.acknowledge_async(message)
|
||||||
|
logger.error(f"Task {task.task_id} failed after {self.retry_attempts} attempts")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing message: {str(e)}")
|
||||||
|
await self.consumer.negative_acknowledge(message)
|
||||||
|
consecutive_failures += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in consume_tasks: {str(e)}")
|
||||||
|
consecutive_failures += 1
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Context manager entry"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Context manager exit with proper cleanup"""
|
||||||
|
try:
|
||||||
|
self.producer.flush()
|
||||||
|
self.producer.close()
|
||||||
|
self.consumer.close()
|
||||||
|
self.client.close()
|
||||||
|
self.executor.shutdown(wait=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during cleanup: {str(e)}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example usage with security configuration
|
||||||
|
security_config = SecurityConfig(
|
||||||
|
encryption_key=secrets.token_urlsafe(32),
|
||||||
|
tls_cert_path="/path/to/cert.pem",
|
||||||
|
tls_key_path="/path/to/key.pem",
|
||||||
|
auth_token="your-auth-token",
|
||||||
|
max_message_size=1048576,
|
||||||
|
rate_limit=100
|
||||||
|
)
|
||||||
|
|
||||||
|
# Agent factory function
|
||||||
|
def create_financial_agent() -> Agent:
|
||||||
|
"""Factory function to create a financial analysis agent."""
|
||||||
|
return Agent(
|
||||||
|
agent_name="Financial-Analysis-Agent",
|
||||||
|
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
|
||||||
|
model_name="gpt-4o-mini",
|
||||||
|
max_loops=1,
|
||||||
|
autosave=True,
|
||||||
|
dashboard=False,
|
||||||
|
verbose=True,
|
||||||
|
dynamic_temperature_enabled=True,
|
||||||
|
saved_state_path="finance_agent.json",
|
||||||
|
user_name="swarms_corp",
|
||||||
|
retry_attempts=1,
|
||||||
|
context_length=200000,
|
||||||
|
return_step_meta=False,
|
||||||
|
output_type="string",
|
||||||
|
streaming_on=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize agents (implementation not shown)
|
||||||
|
agents = [create_financial_agent() for _ in range(3)]
|
||||||
|
|
||||||
|
# Initialize the secure swarm
|
||||||
|
with SecurePulsarSwarm(
|
||||||
|
name="Secure Financial Swarm",
|
||||||
|
description="Production-grade financial analysis swarm",
|
||||||
|
agents=agents,
|
||||||
|
pulsar_url="pulsar+ssl://localhost:6651",
|
||||||
|
subscription_name="secure_financial_subscription",
|
||||||
|
topic_name="secure_financial_tasks",
|
||||||
|
security_config=security_config,
|
||||||
|
max_workers=5,
|
||||||
|
retry_attempts=3,
|
||||||
|
task_timeout=300,
|
||||||
|
metrics_port=8000
|
||||||
|
) as swarm:
|
||||||
|
# Example task
|
||||||
|
task = Task(
|
||||||
|
task_id=secrets.token_urlsafe(16),
|
||||||
|
description="Analyze Q4 financial reports",
|
||||||
|
output_type="json",
|
||||||
|
priority="high",
|
||||||
|
metadata={"department": "finance", "requester": "john.doe@company.com"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the swarm
|
||||||
|
swarm.publish_task(task)
|
||||||
|
asyncio.run(swarm.consume_tasks())
|
@ -0,0 +1,187 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.http import models
|
||||||
|
from qdrant_client.http.models import Distance, VectorParams
|
||||||
|
from swarm_models import Anthropic
|
||||||
|
|
||||||
|
from swarms import Agent
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantMemory:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
collection_name: str = "agent_memories",
|
||||||
|
vector_size: int = 1536, # Default size for Claude embeddings
|
||||||
|
url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""Initialize Qdrant memory system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Name of the Qdrant collection to use
|
||||||
|
vector_size: Dimension of the embedding vectors
|
||||||
|
url: Optional Qdrant server URL (defaults to local)
|
||||||
|
api_key: Optional Qdrant API key for cloud deployment
|
||||||
|
"""
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.vector_size = vector_size
|
||||||
|
|
||||||
|
# Initialize Qdrant client
|
||||||
|
if url and api_key:
|
||||||
|
self.client = QdrantClient(url=url, api_key=api_key)
|
||||||
|
else:
|
||||||
|
self.client = QdrantClient(":memory:") # Local in-memory storage
|
||||||
|
|
||||||
|
# Create collection if it doesn't exist
|
||||||
|
self._create_collection()
|
||||||
|
|
||||||
|
def _create_collection(self):
|
||||||
|
"""Create the Qdrant collection if it doesn't exist."""
|
||||||
|
collections = self.client.get_collections().collections
|
||||||
|
exists = any(col.name == self.collection_name for col in collections)
|
||||||
|
|
||||||
|
if not exists:
|
||||||
|
self.client.create_collection(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
vectors_config=VectorParams(
|
||||||
|
size=self.vector_size,
|
||||||
|
distance=Distance.COSINE
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
embedding: List[float],
|
||||||
|
metadata: Optional[Dict] = None
|
||||||
|
) -> str:
|
||||||
|
"""Add a memory to the store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text content of the memory
|
||||||
|
embedding: Vector embedding of the text
|
||||||
|
metadata: Optional metadata to store with the memory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: ID of the stored memory
|
||||||
|
"""
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
# Add timestamp and generate ID
|
||||||
|
memory_id = str(uuid.uuid4())
|
||||||
|
metadata.update({
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"text": text
|
||||||
|
})
|
||||||
|
|
||||||
|
# Store the point
|
||||||
|
self.client.upsert(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
points=[
|
||||||
|
models.PointStruct(
|
||||||
|
id=memory_id,
|
||||||
|
payload=metadata,
|
||||||
|
vector=embedding
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return memory_id
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self,
|
||||||
|
query_embedding: List[float],
|
||||||
|
limit: int = 5,
|
||||||
|
score_threshold: float = 0.7
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""Query memories based on vector similarity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_embedding: Vector embedding of the query
|
||||||
|
limit: Maximum number of results to return
|
||||||
|
score_threshold: Minimum similarity score threshold
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memories with their metadata
|
||||||
|
"""
|
||||||
|
results = self.client.search(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
query_vector=query_embedding,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
memories = []
|
||||||
|
for res in results:
|
||||||
|
memory = res.payload
|
||||||
|
memory["similarity_score"] = res.score
|
||||||
|
memories.append(memory)
|
||||||
|
|
||||||
|
return memories
|
||||||
|
|
||||||
|
def delete(self, memory_id: str):
|
||||||
|
"""Delete a specific memory by ID."""
|
||||||
|
self.client.delete(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
points_selector=models.PointIdsList(
|
||||||
|
points=[memory_id]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clear all memories from the collection."""
|
||||||
|
self.client.delete_collection(self.collection_name)
|
||||||
|
self._create_collection()
|
||||||
|
|
||||||
|
# # Example usage
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# # Initialize memory
|
||||||
|
# memory = QdrantMemory()
|
||||||
|
|
||||||
|
# # Example embedding (would normally come from an embedding model)
|
||||||
|
# example_embedding = np.random.rand(1536).tolist()
|
||||||
|
|
||||||
|
# # Add a memory
|
||||||
|
# memory_id = memory.add(
|
||||||
|
# text="Important financial analysis about startup equity.",
|
||||||
|
# embedding=example_embedding,
|
||||||
|
# metadata={"category": "finance", "importance": "high"}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Query memories
|
||||||
|
# results = memory.query(
|
||||||
|
# query_embedding=example_embedding,
|
||||||
|
# limit=5
|
||||||
|
# )
|
||||||
|
|
||||||
|
# print(f"Found {len(results)} relevant memories")
|
||||||
|
# for result in results:
|
||||||
|
# print(f"Memory: {result['text']}")
|
||||||
|
# print(f"Similarity: {result['similarity_score']:.2f}")
|
||||||
|
|
||||||
|
# Initialize memory with optional cloud configuration
|
||||||
|
memory = QdrantMemory(
|
||||||
|
url=os.getenv("QDRANT_URL"), # Optional: For cloud deployment
|
||||||
|
api_key=os.getenv("QDRANT_API_KEY") # Optional: For cloud deployment
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model
|
||||||
|
model = Anthropic(anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"))
|
||||||
|
|
||||||
|
# Initialize the agent with Qdrant memory
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Financial-Analysis-Agent",
|
||||||
|
system_prompt="Agent system prompt here",
|
||||||
|
agent_description="Agent performs financial analysis.",
|
||||||
|
llm=model,
|
||||||
|
long_term_memory=memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run a query
|
||||||
|
agent.run("What are the components of a startup's stock incentive equity plan?")
|
|
Loading…
Reference in new issue