parent bb69f8e696
commit 704f91fdea

@@ -0,0 +1,235 @@
from typing import List, Dict, TypedDict, Any
from dataclasses import dataclass
import csv
from pathlib import Path
import logging
from enum import Enum

from swarms import Agent

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelName(str, Enum):
    """Valid model names for swarms agents"""
    GPT4O = "gpt-4o"
    GPT4O_MINI = "gpt-4o-mini"
    GPT4 = "gpt-4"
    GPT35_TURBO = "gpt-3.5-turbo"
    CLAUDE = "claude-v1"
    CLAUDE2 = "claude-2"

    @classmethod
    def get_model_names(cls) -> List[str]:
        """Get list of valid model names"""
        return [model.value for model in cls]

    @classmethod
    def is_valid_model(cls, model_name: str) -> bool:
        """Check if model name is valid"""
        return model_name in cls.get_model_names()
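

# Example (illustrative):
#
#     ModelName.is_valid_model("gpt-4o")       # True
#     ModelName.is_valid_model("not-a-model")  # False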


class AgentConfigDict(TypedDict):
    """TypedDict for agent configuration"""
    agent_name: str
    system_prompt: str
    model_name: str  # Using str instead of ModelName for flexibility
    max_loops: int
    autosave: bool
    dashboard: bool
    verbose: bool
    dynamic_temperature: bool
    saved_state_path: str
    user_name: str
    retry_attempts: int
    context_length: int
    return_step_meta: bool
    output_type: str
    streaming: bool


@dataclass
class AgentValidationError(Exception):
    """Custom exception for agent validation errors"""
    message: str
    field: str
    value: Any

    def __str__(self) -> str:
        return f"Validation error in field '{self.field}': {self.message}. Got value: {self.value}"


class AgentValidator:
    """Validates agent configuration data"""

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> AgentConfigDict:
        """Validate and convert agent configuration"""
        try:
            # Validate the model name against the known set
            model_name = str(config['model_name'])
            if not ModelName.is_valid_model(model_name):
                valid_models = ModelName.get_model_names()
                raise AgentValidationError(
                    f"Invalid model name. Must be one of: {', '.join(valid_models)}",
                    "model_name",
                    model_name,
                )

            # Convert types with error handling; boolean fields accept both
            # real booleans and the strings "true"/"false" read back from CSV
            validated_config: AgentConfigDict = {
                'agent_name': str(config.get('agent_name', '')),
                'system_prompt': str(config.get('system_prompt', '')),
                'model_name': model_name,
                'max_loops': int(config.get('max_loops', 1)),
                'autosave': str(config.get('autosave', True)).lower() == 'true',
                'dashboard': str(config.get('dashboard', False)).lower() == 'true',
                'verbose': str(config.get('verbose', True)).lower() == 'true',
                'dynamic_temperature': str(config.get('dynamic_temperature', True)).lower() == 'true',
                'saved_state_path': str(config.get('saved_state_path', '')),
                'user_name': str(config.get('user_name', 'default_user')),
                'retry_attempts': int(config.get('retry_attempts', 3)),
                'context_length': int(config.get('context_length', 200000)),
                'return_step_meta': str(config.get('return_step_meta', False)).lower() == 'true',
                'output_type': str(config.get('output_type', 'string')),
                'streaming': str(config.get('streaming', False)).lower() == 'true',
            }

            return validated_config

        except (ValueError, KeyError) as e:
            raise AgentValidationError(
                str(e),
                e.__class__.__name__,
                str(config),
            )
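

# Example (illustrative): an unknown model name surfaces as an
# AgentValidationError naming the offending field and value.
#
#     try:
#         AgentValidator.validate_config({"model_name": "not-a-model"})
#     except AgentValidationError as e:
#         print(e)  # Validation error in field 'model_name': ...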


@dataclass
class AgentCSV:
    """Class to manage agents through CSV with type safety"""
    csv_path: Path

    def __post_init__(self) -> None:
        """Convert string path to Path object if necessary"""
        if isinstance(self.csv_path, str):
            self.csv_path = Path(self.csv_path)

    @property
    def headers(self) -> List[str]:
        """CSV headers for agent configuration"""
        return [
            "agent_name", "system_prompt", "model_name", "max_loops",
            "autosave", "dashboard", "verbose", "dynamic_temperature",
            "saved_state_path", "user_name", "retry_attempts", "context_length",
            "return_step_meta", "output_type", "streaming",
        ]

    def create_agent_csv(self, agents: List[Dict[str, Any]]) -> None:
        """Create a CSV file with validated agent configurations"""
        validated_agents = []
        for agent in agents:
            try:
                validated_config = AgentValidator.validate_config(agent)
                validated_agents.append(validated_config)
            except AgentValidationError as e:
                logger.error(f"Validation error for agent {agent.get('agent_name', 'unknown')}: {e}")
                raise

        with open(self.csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.headers)
            writer.writeheader()
            writer.writerows(validated_agents)

        logger.info(f"Created CSV with {len(validated_agents)} agents at {self.csv_path}")

    def load_agents(self) -> List[Agent]:
        """Load and create agents from CSV with validation"""
        if not self.csv_path.exists():
            raise FileNotFoundError(f"CSV file not found at {self.csv_path}")

        agents: List[Agent] = []
        with open(self.csv_path, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                try:
                    validated_config = AgentValidator.validate_config(row)

                    agent = Agent(
                        agent_name=validated_config['agent_name'],
                        system_prompt=validated_config['system_prompt'],
                        model_name=validated_config['model_name'],
                        max_loops=validated_config['max_loops'],
                        autosave=validated_config['autosave'],
                        dashboard=validated_config['dashboard'],
                        verbose=validated_config['verbose'],
                        dynamic_temperature_enabled=validated_config['dynamic_temperature'],
                        saved_state_path=validated_config['saved_state_path'],
                        user_name=validated_config['user_name'],
                        retry_attempts=validated_config['retry_attempts'],
                        context_length=validated_config['context_length'],
                        return_step_meta=validated_config['return_step_meta'],
                        output_type=validated_config['output_type'],
                        streaming_on=validated_config['streaming'],
                    )
                    agents.append(agent)
                except AgentValidationError as e:
                    logger.error(f"Skipping invalid agent configuration: {e}")
                    continue

        logger.info(f"Loaded {len(agents)} agents from {self.csv_path}")
        return agents

    def add_agent(self, agent_config: Dict[str, Any]) -> None:
        """Add a new validated agent configuration to CSV"""
        validated_config = AgentValidator.validate_config(agent_config)

        with open(self.csv_path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.headers)
            writer.writerow(validated_config)

        logger.info(f"Added new agent {validated_config['agent_name']} to {self.csv_path}")


# Example usage
if __name__ == "__main__":
    # Example agent configurations
    agent_configs = [
        {
            "agent_name": "Financial-Analysis-Agent",
            "system_prompt": "You are a financial expert...",
            "model_name": "gpt-4o-mini",
            "max_loops": 1,
            "autosave": True,
            "dashboard": False,
            "verbose": True,
            "dynamic_temperature": True,
            "saved_state_path": "finance_agent.json",
            "user_name": "swarms_corp",
            "retry_attempts": 3,
            "context_length": 200000,
            "return_step_meta": False,
            "output_type": "string",
            "streaming": False,
        }
    ]

    try:
        # Initialize CSV manager
        csv_manager = AgentCSV(Path("agents.csv"))

        # Create CSV with initial agents
        csv_manager.create_agent_csv(agent_configs)

        # Load agents from CSV
        agents = csv_manager.load_agents()

        # Use an agent
        if agents:
            financial_agent = agents[0]
            response = financial_agent.run(
                "How can I establish a ROTH IRA to buy stocks and get a tax break?"
            )
            print(response)

    except AgentValidationError as e:
        logger.error(f"Validation error: {e}")
    except Exception as e:
        logger.error(f"Unexpected error: {e}")

@@ -0,0 +1,399 @@
import asyncio
import json
import secrets
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

import pulsar
from cryptography.fernet import Fernet
from loguru import logger
from prometheus_client import Counter, Histogram, start_http_server
from pydantic import BaseModel, Field, validator
from tenacity import retry, stop_after_attempt, wait_exponential

from swarms.prompts.finance_agent_sys_prompt import FINANCIAL_AGENT_SYS_PROMPT
from swarms.structs.agent import Agent


# Enhanced metrics
TASK_COUNTER = Counter('swarm_tasks_total', 'Total number of tasks processed')
TASK_LATENCY = Histogram('swarm_task_duration_seconds', 'Task processing duration')
TASK_FAILURES = Counter('swarm_task_failures_total', 'Total number of task failures')
AGENT_ERRORS = Counter('swarm_agent_errors_total', 'Total number of agent errors')

# Define types using Literal
TaskStatus = Literal["pending", "processing", "completed", "failed"]
TaskPriority = Literal["low", "medium", "high", "critical"]


class SecurityConfig(BaseModel):
    """Security configuration for the swarm"""
    encryption_key: str = Field(..., description="Encryption key for sensitive data")
    tls_cert_path: Optional[str] = Field(None, description="Path to TLS certificate")
    tls_key_path: Optional[str] = Field(None, description="Path to TLS private key")
    auth_token: Optional[str] = Field(None, description="Authentication token")
    max_message_size: int = Field(default=1048576, description="Maximum message size in bytes")
    rate_limit: int = Field(default=100, description="Maximum tasks per minute")

    @validator('encryption_key')
    def validate_encryption_key(cls, v):
        if len(v) < 32:
            raise ValueError("Encryption key must be at least 32 bytes long")
        return v
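

# Example (illustrative): Fernet, used below for message encryption,
# expects a 32-byte url-safe base64 key; Fernet.generate_key() produces
# one that also passes the length check above.
#
#     from cryptography.fernet import Fernet
#     config = SecurityConfig(encryption_key=Fernet.generate_key().decode())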


class Task(BaseModel):
    """Enhanced task model with additional metadata and validation"""
    task_id: str = Field(..., description="Unique identifier for the task")
    description: str = Field(..., description="Task description or instructions")
    output_type: Literal["string", "json", "file"] = Field("string")
    status: TaskStatus = Field(default="pending")
    priority: TaskPriority = Field(default="medium")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    retry_count: int = Field(default=0)
    metadata: Dict[str, Any] = Field(default_factory=dict)

    @validator('task_id')
    def validate_task_id(cls, v):
        if not v.strip():
            raise ValueError("task_id cannot be empty")
        return v

    class Config:
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
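

# Example (illustrative): tasks round-trip through JSON for transport;
# json_encoders above serializes datetimes as ISO-8601 strings, which
# pydantic parses back into datetime objects on reconstruction.
#
#     t = Task(task_id="t-1", description="demo")
#     Task(**json.loads(t.json()))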


class TaskResult(BaseModel):
    """Model for task execution results"""
    task_id: str
    status: TaskStatus
    result: Any
    error_message: Optional[str] = None
    execution_time: float
    agent_id: str


@contextmanager
def task_timing():
    """Context manager for timing task execution"""
    start_time = time.time()
    try:
        yield
    finally:
        duration = time.time() - start_time
        TASK_LATENCY.observe(duration)
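

# Example (illustrative): any block wrapped in task_timing() records its
# wall-clock duration in the TASK_LATENCY histogram, even if it raises.
#
#     with task_timing():
#         time.sleep(0.1)  # stands in for real work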


class SecurePulsarSwarm:
    """
    Enhanced secure, scalable swarm system with improved reliability and security features.
    """

    def __init__(
        self,
        name: str,
        description: str,
        agents: List[Any],
        pulsar_url: str,
        subscription_name: str,
        topic_name: str,
        security_config: SecurityConfig,
        max_workers: int = 5,
        retry_attempts: int = 3,
        task_timeout: int = 300,
        metrics_port: int = 8000,
    ):
        """Initialize the enhanced Pulsar Swarm"""
        self.name = name
        self.description = description
        self.agents = agents
        self.pulsar_url = pulsar_url
        self.subscription_name = subscription_name
        self.topic_name = topic_name
        self.security_config = security_config
        self.max_workers = max_workers
        self.retry_attempts = retry_attempts
        self.task_timeout = task_timeout

        # Initialize encryption
        self.cipher_suite = Fernet(security_config.encryption_key.encode())

        # Setup metrics server
        start_http_server(metrics_port)

        # Initialize Pulsar client with security settings
        client_config = {
            "authentication": None if not security_config.auth_token else pulsar.AuthenticationToken(security_config.auth_token),
            "operation_timeout_seconds": 30,
            "connection_timeout_seconds": 30,
            "use_tls": bool(security_config.tls_cert_path),
            "tls_trust_certs_file_path": security_config.tls_cert_path,
            "tls_allow_insecure_connection": False,
        }

        self.client = pulsar.Client(self.pulsar_url, **client_config)
        self.producer = self._create_producer()
        self.consumer = self._create_consumer()
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

        # Initialize rate limiting
        self.last_execution_time = time.time()
        self.execution_count = 0

        logger.info(f"Secure Pulsar Swarm '{self.name}' initialized with enhanced security features")

    def _create_producer(self):
        """Create a secure producer with retry logic"""
        return self.client.create_producer(
            self.topic_name,
            max_pending_messages=1000,
            compression_type=pulsar.CompressionType.LZ4,
            block_if_queue_full=True,
            batching_enabled=True,
            batching_max_publish_delay_ms=10
        )

    def _create_consumer(self):
        """Create a secure consumer with retry logic"""
        return self.client.subscribe(
            self.topic_name,
            subscription_name=self.subscription_name,
            consumer_type=pulsar.ConsumerType.Shared,
            message_listener=None,
            receiver_queue_size=1000,
            max_total_receiver_queue_size_across_partitions=50000
        )

    def _encrypt_message(self, data: str) -> bytes:
        """Encrypt message data"""
        return self.cipher_suite.encrypt(data.encode())

    def _decrypt_message(self, data: bytes) -> str:
        """Decrypt message data"""
        return self.cipher_suite.decrypt(data).decode()
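
    # Example (illustrative): Fernet is symmetric, so a round trip through
    # these helpers recovers the original payload on any node that shares
    # the same encryption_key.
    #
    #     token = self._encrypt_message('{"task_id": "t-1"}')
    #     assert self._decrypt_message(token) == '{"task_id": "t-1"}'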

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    def publish_task(self, task: Task) -> None:
        """Publish a task with enhanced security and reliability"""
        try:
            # Validate message size
            task_data = task.json()
            if len(task_data) > self.security_config.max_message_size:
                raise ValueError("Task data exceeds maximum message size")

            # Rate limiting: reset the counter once the 60-second window elapses
            current_time = time.time()
            if current_time - self.last_execution_time >= 60:
                self.execution_count = 0
                self.last_execution_time = current_time

            if self.execution_count >= self.security_config.rate_limit:
                raise ValueError("Rate limit exceeded")

            # Encrypt and publish
            encrypted_data = self._encrypt_message(task_data)
            message_id = self.producer.send(encrypted_data)

            self.execution_count += 1
            logger.info(f"Task {task.task_id} published successfully with message ID {message_id}")

        except Exception as e:
            TASK_FAILURES.inc()
            logger.error(f"Error publishing task {task.task_id}: {str(e)}")
            raise
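
    # Note: the tenacity decorator above makes up to 3 attempts in total,
    # with exponential backoff bounded between 4 and 10 seconds, before the
    # final exception propagates to the caller.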

    async def _process_task(self, task: Task) -> TaskResult:
        """Process a task with comprehensive error handling and monitoring"""
        task.status = "processing"
        task.started_at = datetime.utcnow()

        # Select agent using round-robin (outside the try block so the
        # error handlers below can always report which agent was used)
        agent = self.agents.pop(0)
        self.agents.append(agent)

        with task_timing():
            try:
                # Execute task with timeout
                future = self.executor.submit(agent.run, task.description)
                result = future.result(timeout=self.task_timeout)

                # Handle different output types
                if task.output_type == "json":
                    result = json.loads(result)
                elif task.output_type == "file":
                    file_path = f"output_{task.task_id}_{int(time.time())}.txt"
                    with open(file_path, "w") as f:
                        f.write(result)
                    result = {"file_path": file_path}

                task.status = "completed"
                task.completed_at = datetime.utcnow()
                TASK_COUNTER.inc()

                return TaskResult(
                    task_id=task.task_id,
                    status="completed",
                    result=result,
                    execution_time=time.time() - task.started_at.timestamp(),
                    agent_id=agent.agent_name
                )

            except TimeoutError:
                TASK_FAILURES.inc()
                error_msg = f"Task {task.task_id} timed out after {self.task_timeout} seconds"
                logger.error(error_msg)
                task.status = "failed"
                return TaskResult(
                    task_id=task.task_id,
                    status="failed",
                    result=None,
                    error_message=error_msg,
                    execution_time=time.time() - task.started_at.timestamp(),
                    agent_id=agent.agent_name
                )

            except Exception as e:
                TASK_FAILURES.inc()
                AGENT_ERRORS.inc()
                error_msg = f"Error processing task {task.task_id}: {str(e)}"
                logger.error(error_msg)
                task.status = "failed"
                return TaskResult(
                    task_id=task.task_id,
                    status="failed",
                    result=None,
                    error_message=error_msg,
                    execution_time=time.time() - task.started_at.timestamp(),
                    agent_id=agent.agent_name
                )

    async def consume_tasks(self):
        """Enhanced task consumption with circuit breaker and backoff"""
        consecutive_failures = 0
        backoff_time = 1

        while True:
            try:
                # Circuit breaker pattern
                if consecutive_failures >= 5:
                    logger.warning(f"Circuit breaker triggered. Waiting {backoff_time} seconds")
                    await asyncio.sleep(backoff_time)
                    backoff_time = min(backoff_time * 2, 60)
                    continue

                # Receive message with timeout
                message = await self.consumer.receive_async()

                try:
                    # Decrypt and process message
                    decrypted_data = self._decrypt_message(message.data())
                    task_data = json.loads(decrypted_data)
                    task = Task(**task_data)

                    # Process task
                    result = await self._process_task(task)

                    # Handle result
                    if result.status == "completed":
                        await self.consumer.acknowledge_async(message)
                        consecutive_failures = 0
                        backoff_time = 1
                    else:
                        if task.retry_count < self.retry_attempts:
                            task.retry_count += 1
                            await self.consumer.negative_acknowledge(message)
                        else:
                            await self.consumer.acknowledge_async(message)
                            logger.error(f"Task {task.task_id} failed after {self.retry_attempts} attempts")

                except Exception as e:
                    logger.error(f"Error processing message: {str(e)}")
                    await self.consumer.negative_acknowledge(message)
                    consecutive_failures += 1

            except Exception as e:
                logger.error(f"Error in consume_tasks: {str(e)}")
                consecutive_failures += 1
                await asyncio.sleep(1)

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit with proper cleanup"""
        try:
            self.producer.flush()
            self.producer.close()
            self.consumer.close()
            self.client.close()
            self.executor.shutdown(wait=True)
        except Exception as e:
            logger.error(f"Error during cleanup: {str(e)}")


if __name__ == "__main__":
    # Example usage with security configuration.
    # Fernet requires a 32-byte url-safe base64 key, so generate one with
    # Fernet.generate_key() rather than an arbitrary random string.
    security_config = SecurityConfig(
        encryption_key=Fernet.generate_key().decode(),
        tls_cert_path="/path/to/cert.pem",
        tls_key_path="/path/to/key.pem",
        auth_token="your-auth-token",
        max_message_size=1048576,
        rate_limit=100
    )

    # Agent factory function
    def create_financial_agent() -> Agent:
        """Factory function to create a financial analysis agent."""
        return Agent(
            agent_name="Financial-Analysis-Agent",
            system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
            model_name="gpt-4o-mini",
            max_loops=1,
            autosave=True,
            dashboard=False,
            verbose=True,
            dynamic_temperature_enabled=True,
            saved_state_path="finance_agent.json",
            user_name="swarms_corp",
            retry_attempts=1,
            context_length=200000,
            return_step_meta=False,
            output_type="string",
            streaming_on=False,
        )

    # Initialize three worker agents
    agents = [create_financial_agent() for _ in range(3)]

    # Initialize the secure swarm
    with SecurePulsarSwarm(
        name="Secure Financial Swarm",
        description="Production-grade financial analysis swarm",
        agents=agents,
        pulsar_url="pulsar+ssl://localhost:6651",
        subscription_name="secure_financial_subscription",
        topic_name="secure_financial_tasks",
        security_config=security_config,
        max_workers=5,
        retry_attempts=3,
        task_timeout=300,
        metrics_port=8000
    ) as swarm:
        # Example task
        task = Task(
            task_id=secrets.token_urlsafe(16),
            description="Analyze Q4 financial reports",
            output_type="json",
            priority="high",
            metadata={"department": "finance", "requester": "john.doe@company.com"}
        )

        # Publish the task, then run the consumer loop
        swarm.publish_task(task)
        asyncio.run(swarm.consume_tasks())

@@ -0,0 +1,187 @@
import os
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams
from swarm_models import Anthropic

from swarms import Agent


class QdrantMemory:
    def __init__(
        self,
        collection_name: str = "agent_memories",
        vector_size: int = 1536,  # Default size for Claude embeddings
        url: Optional[str] = None,
        api_key: Optional[str] = None
    ):
        """Initialize Qdrant memory system.

        Args:
            collection_name: Name of the Qdrant collection to use
            vector_size: Dimension of the embedding vectors
            url: Optional Qdrant server URL (defaults to local)
            api_key: Optional Qdrant API key for cloud deployment
        """
        self.collection_name = collection_name
        self.vector_size = vector_size

        # Initialize Qdrant client
        if url and api_key:
            self.client = QdrantClient(url=url, api_key=api_key)
        else:
            self.client = QdrantClient(":memory:")  # Local in-memory storage

        # Create collection if it doesn't exist
        self._create_collection()

    def _create_collection(self):
        """Create the Qdrant collection if it doesn't exist."""
        collections = self.client.get_collections().collections
        exists = any(col.name == self.collection_name for col in collections)

        if not exists:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE
                )
            )

    def add(
        self,
        text: str,
        embedding: List[float],
        metadata: Optional[Dict] = None
    ) -> str:
        """Add a memory to the store.

        Args:
            text: The text content of the memory
            embedding: Vector embedding of the text
            metadata: Optional metadata to store with the memory

        Returns:
            str: ID of the stored memory
        """
        if metadata is None:
            metadata = {}

        # Add timestamp and generate ID
        memory_id = str(uuid.uuid4())
        metadata.update({
            "timestamp": datetime.utcnow().isoformat(),
            "text": text
        })

        # Store the point
        self.client.upsert(
            collection_name=self.collection_name,
            points=[
                models.PointStruct(
                    id=memory_id,
                    payload=metadata,
                    vector=embedding
                )
            ]
        )

        return memory_id
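
    # Example (illustrative): the stored payload always carries "text" and
    # an ISO-8601 "timestamp" alongside any caller-supplied metadata.
    #
    #     mid = memory.add("note", embedding=[0.0] * 1536, metadata={"k": "v"})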

    def query(
        self,
        query_embedding: List[float],
        limit: int = 5,
        score_threshold: float = 0.7
    ) -> List[Dict]:
        """Query memories based on vector similarity.

        Args:
            query_embedding: Vector embedding of the query
            limit: Maximum number of results to return
            score_threshold: Minimum similarity score threshold

        Returns:
            List of matching memories with their metadata
        """
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
            score_threshold=score_threshold
        )

        memories = []
        for res in results:
            memory = res.payload
            memory["similarity_score"] = res.score
            memories.append(memory)

        return memories
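
    # Example (illustrative): each returned dict is the stored payload
    # (including "text" and "timestamp") plus the "similarity_score" set above.
    #
    #     hits = memory.query(query_embedding=[0.0] * 1536, limit=3)
    #     for hit in hits:
    #         print(hit["text"], hit["similarity_score"])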

    def delete(self, memory_id: str):
        """Delete a specific memory by ID."""
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=models.PointIdsList(
                points=[memory_id]
            )
        )

    def clear(self):
        """Clear all memories from the collection."""
        self.client.delete_collection(self.collection_name)
        self._create_collection()


# # Example usage
# if __name__ == "__main__":
#     # Initialize memory
#     memory = QdrantMemory()

#     # Example embedding (would normally come from an embedding model)
#     example_embedding = np.random.rand(1536).tolist()

#     # Add a memory
#     memory_id = memory.add(
#         text="Important financial analysis about startup equity.",
#         embedding=example_embedding,
#         metadata={"category": "finance", "importance": "high"}
#     )

#     # Query memories
#     results = memory.query(
#         query_embedding=example_embedding,
#         limit=5
#     )

#     print(f"Found {len(results)} relevant memories")
#     for result in results:
#         print(f"Memory: {result['text']}")
#         print(f"Similarity: {result['similarity_score']:.2f}")

# Initialize memory with optional cloud configuration
memory = QdrantMemory(
    url=os.getenv("QDRANT_URL"),  # Optional: For cloud deployment
    api_key=os.getenv("QDRANT_API_KEY")  # Optional: For cloud deployment
)

# Model
model = Anthropic(anthropic_api_key=os.getenv("ANTHROPIC_API_KEY"))

# Initialize the agent with Qdrant memory
agent = Agent(
    agent_name="Financial-Analysis-Agent",
    system_prompt="Agent system prompt here",
    agent_description="Agent performs financial analysis.",
    llm=model,
    long_term_memory=memory,
)

# Run a query
agent.run("What are the components of a startup's stock incentive equity plan?")