@@ -14,39 +14,71 @@ from prometheus_client import Counter, Histogram, start_http_server
 from pydantic import BaseModel, Field, validator
 from tenacity import retry, stop_after_attempt, wait_exponential
 
-from swarms.prompts.finance_agent_sys_prompt import FINANCIAL_AGENT_SYS_PROMPT
+from swarms.prompts.finance_agent_sys_prompt import (
+    FINANCIAL_AGENT_SYS_PROMPT,
+)
 from swarms.structs.agent import Agent
 
 # Enhanced metrics
-TASK_COUNTER = Counter('swarm_tasks_total', 'Total number of tasks processed')
-TASK_LATENCY = Histogram('swarm_task_duration_seconds', 'Task processing duration')
-TASK_FAILURES = Counter('swarm_task_failures_total', 'Total number of task failures')
-AGENT_ERRORS = Counter('swarm_agent_errors_total', 'Total number of agent errors')
+TASK_COUNTER = Counter(
+    "swarm_tasks_total", "Total number of tasks processed"
+)
+TASK_LATENCY = Histogram(
+    "swarm_task_duration_seconds", "Task processing duration"
+)
+TASK_FAILURES = Counter(
+    "swarm_task_failures_total", "Total number of task failures"
+)
+AGENT_ERRORS = Counter(
+    "swarm_agent_errors_total", "Total number of agent errors"
+)
 
 # Define types using Literal
 TaskStatus = Literal["pending", "processing", "completed", "failed"]
 TaskPriority = Literal["low", "medium", "high", "critical"]
 
 
 class SecurityConfig(BaseModel):
     """Security configuration for the swarm"""
-    encryption_key: str = Field(..., description="Encryption key for sensitive data")
-    tls_cert_path: Optional[str] = Field(None, description="Path to TLS certificate")
-    tls_key_path: Optional[str] = Field(None, description="Path to TLS private key")
-    auth_token: Optional[str] = Field(None, description="Authentication token")
-    max_message_size: int = Field(default=1048576, description="Maximum message size in bytes")
-    rate_limit: int = Field(default=100, description="Maximum tasks per minute")
-
-    @validator('encryption_key')
+
+    encryption_key: str = Field(
+        ..., description="Encryption key for sensitive data"
+    )
+    tls_cert_path: Optional[str] = Field(
+        None, description="Path to TLS certificate"
+    )
+    tls_key_path: Optional[str] = Field(
+        None, description="Path to TLS private key"
+    )
+    auth_token: Optional[str] = Field(
+        None, description="Authentication token"
+    )
+    max_message_size: int = Field(
+        default=1048576, description="Maximum message size in bytes"
+    )
+    rate_limit: int = Field(
+        default=100, description="Maximum tasks per minute"
+    )
+
+    @validator("encryption_key")
     def validate_encryption_key(cls, v):
         if len(v) < 32:
-            raise ValueError("Encryption key must be at least 32 bytes long")
+            raise ValueError(
+                "Encryption key must be at least 32 bytes long"
+            )
         return v
 
 
 class Task(BaseModel):
     """Enhanced task model with additional metadata and validation"""
-    task_id: str = Field(..., description="Unique identifier for the task")
-    description: str = Field(..., description="Task description or instructions")
+
+    task_id: str = Field(
+        ..., description="Unique identifier for the task"
+    )
+    description: str = Field(
+        ..., description="Task description or instructions"
+    )
     output_type: Literal["string", "json", "file"] = Field("string")
     status: TaskStatus = Field(default="pending")
     priority: TaskPriority = Field(default="medium")
@@ -55,21 +87,20 @@ class Task(BaseModel):
     completed_at: Optional[datetime] = None
     retry_count: int = Field(default=0)
     metadata: Dict[str, Any] = Field(default_factory=dict)
 
-    @validator('task_id')
+    @validator("task_id")
     def validate_task_id(cls, v):
         if not v.strip():
             raise ValueError("task_id cannot be empty")
         return v
 
     class Config:
-        json_encoders = {
-            datetime: lambda v: v.isoformat()
-        }
+        json_encoders = {datetime: lambda v: v.isoformat()}
 
 
 class TaskResult(BaseModel):
     """Model for task execution results"""
+
     task_id: str
     status: TaskStatus
     result: Any
@@ -119,33 +150,43 @@ class SecurePulsarSwarm:
         self.max_workers = max_workers
         self.retry_attempts = retry_attempts
         self.task_timeout = task_timeout
 
         # Initialize encryption
-        self.cipher_suite = Fernet(security_config.encryption_key.encode())
+        self.cipher_suite = Fernet(
+            security_config.encryption_key.encode()
+        )
 
         # Setup metrics server
         start_http_server(metrics_port)
 
         # Initialize Pulsar client with security settings
         client_config = {
-            "authentication": None if not security_config.auth_token else pulsar.AuthenticationToken(security_config.auth_token),
+            "authentication": (
+                None
+                if not security_config.auth_token
+                else pulsar.AuthenticationToken(
+                    security_config.auth_token
+                )
+            ),
             "operation_timeout_seconds": 30,
             "connection_timeout_seconds": 30,
             "use_tls": bool(security_config.tls_cert_path),
             "tls_trust_certs_file_path": security_config.tls_cert_path,
             "tls_allow_insecure_connection": False,
         }
 
         self.client = pulsar.Client(self.pulsar_url, **client_config)
         self.producer = self._create_producer()
         self.consumer = self._create_consumer()
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
 
         # Initialize rate limiting
         self.last_execution_time = time.time()
         self.execution_count = 0
 
-        logger.info(f"Secure Pulsar Swarm '{self.name}' initialized with enhanced security features")
+        logger.info(
+            f"Secure Pulsar Swarm '{self.name}' initialized with enhanced security features"
+        )
 
     def _create_producer(self):
         """Create a secure producer with retry logic"""
@@ -155,7 +196,7 @@ class SecurePulsarSwarm:
             compression_type=pulsar.CompressionType.LZ4,
             block_if_queue_full=True,
             batching_enabled=True,
-            batching_max_publish_delay_ms=10
+            batching_max_publish_delay_ms=10,
         )
 
     def _create_consumer(self):
@@ -166,7 +207,7 @@ class SecurePulsarSwarm:
             consumer_type=pulsar.ConsumerType.Shared,
             message_listener=None,
             receiver_queue_size=1000,
-            max_total_receiver_queue_size_across_partitions=50000
+            max_total_receiver_queue_size_across_partitions=50000,
         )
 
     def _encrypt_message(self, data: str) -> bytes:
@@ -177,51 +218,65 @@ class SecurePulsarSwarm:
         """Decrypt message data"""
         return self.cipher_suite.decrypt(data).decode()
 
-    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+    )
     def publish_task(self, task: Task) -> None:
         """Publish a task with enhanced security and reliability"""
         try:
             # Validate message size
             task_data = task.json()
             if len(task_data) > self.security_config.max_message_size:
-                raise ValueError("Task data exceeds maximum message size")
+                raise ValueError(
+                    "Task data exceeds maximum message size"
+                )
 
             # Rate limiting
            current_time = time.time()
             if current_time - self.last_execution_time >= 60:
                 self.execution_count = 0
                 self.last_execution_time = current_time
 
-            if self.execution_count >= self.security_config.rate_limit:
+            if (
+                self.execution_count
+                >= self.security_config.rate_limit
+            ):
                 raise ValueError("Rate limit exceeded")
 
             # Encrypt and publish
             encrypted_data = self._encrypt_message(task_data)
             message_id = self.producer.send(encrypted_data)
 
             self.execution_count += 1
-            logger.info(f"Task {task.task_id} published successfully with message ID {message_id}")
+            logger.info(
+                f"Task {task.task_id} published successfully with message ID {message_id}"
+            )
 
         except Exception as e:
             TASK_FAILURES.inc()
-            logger.error(f"Error publishing task {task.task_id}: {str(e)}")
+            logger.error(
+                f"Error publishing task {task.task_id}: {str(e)}"
+            )
             raise
 
     async def _process_task(self, task: Task) -> TaskResult:
         """Process a task with comprehensive error handling and monitoring"""
         task.status = "processing"
         task.started_at = datetime.utcnow()
 
         with task_timing():
             try:
                 # Select agent using round-robin
                 agent = self.agents.pop(0)
                 self.agents.append(agent)
 
                 # Execute task with timeout
-                future = self.executor.submit(agent.run, task.description)
+                future = self.executor.submit(
+                    agent.run, task.description
+                )
                 result = future.result(timeout=self.task_timeout)
 
                 # Handle different output types
                 if task.output_type == "json":
                     result = json.loads(result)
@@ -230,19 +285,20 @@ class SecurePulsarSwarm:
                     with open(file_path, "w") as f:
                         f.write(result)
                     result = {"file_path": file_path}
 
                 task.status = "completed"
                 task.completed_at = datetime.utcnow()
                 TASK_COUNTER.inc()
 
                 return TaskResult(
                     task_id=task.task_id,
                     status="completed",
                     result=result,
-                    execution_time=time.time() - task.started_at.timestamp(),
-                    agent_id=agent.agent_name
+                    execution_time=time.time()
+                    - task.started_at.timestamp(),
+                    agent_id=agent.agent_name,
                 )
 
             except TimeoutError:
                 TASK_FAILURES.inc()
                 error_msg = f"Task {task.task_id} timed out after {self.task_timeout} seconds"
@@ -253,14 +309,17 @@ class SecurePulsarSwarm:
                     status="failed",
                     result=None,
                     error_message=error_msg,
-                    execution_time=time.time() - task.started_at.timestamp(),
-                    agent_id=agent.agent_name
+                    execution_time=time.time()
+                    - task.started_at.timestamp(),
+                    agent_id=agent.agent_name,
                 )
 
             except Exception as e:
                 TASK_FAILURES.inc()
                 AGENT_ERRORS.inc()
-                error_msg = f"Error processing task {task.task_id}: {str(e)}"
+                error_msg = (
+                    f"Error processing task {task.task_id}: {str(e)}"
+                )
                 logger.error(error_msg)
                 task.status = "failed"
                 return TaskResult(
@@ -268,36 +327,41 @@ class SecurePulsarSwarm:
                     status="failed",
                     result=None,
                     error_message=error_msg,
-                    execution_time=time.time() - task.started_at.timestamp(),
-                    agent_id=agent.agent_name
+                    execution_time=time.time()
+                    - task.started_at.timestamp(),
+                    agent_id=agent.agent_name,
                 )
 
     async def consume_tasks(self):
         """Enhanced task consumption with circuit breaker and backoff"""
         consecutive_failures = 0
         backoff_time = 1
 
         while True:
             try:
                 # Circuit breaker pattern
                 if consecutive_failures >= 5:
-                    logger.warning(f"Circuit breaker triggered. Waiting {backoff_time} seconds")
+                    logger.warning(
+                        f"Circuit breaker triggered. Waiting {backoff_time} seconds"
+                    )
                     await asyncio.sleep(backoff_time)
                     backoff_time = min(backoff_time * 2, 60)
                     continue
 
                 # Receive message with timeout
                 message = await self.consumer.receive_async()
 
                 try:
                     # Decrypt and process message
-                    decrypted_data = self._decrypt_message(message.data())
+                    decrypted_data = self._decrypt_message(
+                        message.data()
+                    )
                     task_data = json.loads(decrypted_data)
                     task = Task(**task_data)
 
                     # Process task
                     result = await self._process_task(task)
 
                     # Handle result
                     if result.status == "completed":
                         await self.consumer.acknowledge_async(message)
@@ -306,16 +370,24 @@ class SecurePulsarSwarm:
                     else:
                         if task.retry_count < self.retry_attempts:
                             task.retry_count += 1
-                            await self.consumer.negative_acknowledge(message)
+                            await self.consumer.negative_acknowledge(
+                                message
+                            )
                         else:
-                            await self.consumer.acknowledge_async(message)
-                            logger.error(f"Task {task.task_id} failed after {self.retry_attempts} attempts")
+                            await self.consumer.acknowledge_async(
+                                message
+                            )
+                            logger.error(
+                                f"Task {task.task_id} failed after {self.retry_attempts} attempts"
+                            )
 
                 except Exception as e:
-                    logger.error(f"Error processing message: {str(e)}")
+                    logger.error(
+                        f"Error processing message: {str(e)}"
+                    )
                     await self.consumer.negative_acknowledge(message)
                     consecutive_failures += 1
 
             except Exception as e:
                 logger.error(f"Error in consume_tasks: {str(e)}")
                 consecutive_failures += 1
@@ -336,6 +408,7 @@ class SecurePulsarSwarm:
         except Exception as e:
             logger.error(f"Error during cleanup: {str(e)}")
 
+
 if __name__ == "__main__":
     # Example usage with security configuration
     security_config = SecurityConfig(
@@ -344,10 +417,10 @@ if __name__ == "__main__":
         tls_key_path="/path/to/key.pem",
         auth_token="your-auth-token",
         max_message_size=1048576,
-        rate_limit=100
+        rate_limit=100,
     )
 
-    # Agent factory function
+    # Agent factory function
     def create_financial_agent() -> Agent:
         """Factory function to create a financial analysis agent."""
         return Agent(
@@ -383,7 +456,7 @@ if __name__ == "__main__":
         max_workers=5,
         retry_attempts=3,
         task_timeout=300,
-        metrics_port=8000
+        metrics_port=8000,
     ) as swarm:
         # Example task
         task = Task(
@@ -391,9 +464,12 @@ if __name__ == "__main__":
             description="Analyze Q4 financial reports",
             output_type="json",
             priority="high",
-            metadata={"department": "finance", "requester": "john.doe@company.com"}
+            metadata={
+                "department": "finance",
+                "requester": "john.doe@company.com",
+            },
         )
 
         # Run the swarm
         swarm.publish_task(task)
-        asyncio.run(swarm.consume_tasks())
+        asyncio.run(swarm.consume_tasks())