[FEATS][CSVTOAgents] [Various Bug Fixes]

pull/700/merge
Kye Gomez 3 days ago
parent 704f91fdea
commit 5ecdc87f4a

@ -3,7 +3,6 @@ import uuid
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional from typing import Dict, List, Optional
import numpy as np
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.http import models from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams from qdrant_client.http.models import Distance, VectorParams
@ -18,10 +17,10 @@ class QdrantMemory:
collection_name: str = "agent_memories", collection_name: str = "agent_memories",
vector_size: int = 1536, # Default size for Claude embeddings vector_size: int = 1536, # Default size for Claude embeddings
url: Optional[str] = None, url: Optional[str] = None,
api_key: Optional[str] = None api_key: Optional[str] = None,
): ):
"""Initialize Qdrant memory system. """Initialize Qdrant memory system.
Args: Args:
collection_name: Name of the Qdrant collection to use collection_name: Name of the Qdrant collection to use
vector_size: Dimension of the embedding vectors vector_size: Dimension of the embedding vectors
@ -30,83 +29,83 @@ class QdrantMemory:
""" """
self.collection_name = collection_name self.collection_name = collection_name
self.vector_size = vector_size self.vector_size = vector_size
# Initialize Qdrant client # Initialize Qdrant client
if url and api_key: if url and api_key:
self.client = QdrantClient(url=url, api_key=api_key) self.client = QdrantClient(url=url, api_key=api_key)
else: else:
self.client = QdrantClient(":memory:") # Local in-memory storage self.client = QdrantClient(
":memory:"
) # Local in-memory storage
# Create collection if it doesn't exist # Create collection if it doesn't exist
self._create_collection() self._create_collection()
def _create_collection(self): def _create_collection(self):
"""Create the Qdrant collection if it doesn't exist.""" """Create the Qdrant collection if it doesn't exist."""
collections = self.client.get_collections().collections collections = self.client.get_collections().collections
exists = any(col.name == self.collection_name for col in collections) exists = any(
col.name == self.collection_name for col in collections
)
if not exists: if not exists:
self.client.create_collection( self.client.create_collection(
collection_name=self.collection_name, collection_name=self.collection_name,
vectors_config=VectorParams( vectors_config=VectorParams(
size=self.vector_size, size=self.vector_size, distance=Distance.COSINE
distance=Distance.COSINE ),
)
) )
def add( def add(
self, self,
text: str, text: str,
embedding: List[float], embedding: List[float],
metadata: Optional[Dict] = None metadata: Optional[Dict] = None,
) -> str: ) -> str:
"""Add a memory to the store. """Add a memory to the store.
Args: Args:
text: The text content of the memory text: The text content of the memory
embedding: Vector embedding of the text embedding: Vector embedding of the text
metadata: Optional metadata to store with the memory metadata: Optional metadata to store with the memory
Returns: Returns:
str: ID of the stored memory str: ID of the stored memory
""" """
if metadata is None: if metadata is None:
metadata = {} metadata = {}
# Add timestamp and generate ID # Add timestamp and generate ID
memory_id = str(uuid.uuid4()) memory_id = str(uuid.uuid4())
metadata.update({ metadata.update(
"timestamp": datetime.utcnow().isoformat(), {"timestamp": datetime.utcnow().isoformat(), "text": text}
"text": text )
})
# Store the point # Store the point
self.client.upsert( self.client.upsert(
collection_name=self.collection_name, collection_name=self.collection_name,
points=[ points=[
models.PointStruct( models.PointStruct(
id=memory_id, id=memory_id, payload=metadata, vector=embedding
payload=metadata,
vector=embedding
) )
] ],
) )
return memory_id return memory_id
def query( def query(
self, self,
query_embedding: List[float], query_embedding: List[float],
limit: int = 5, limit: int = 5,
score_threshold: float = 0.7 score_threshold: float = 0.7,
) -> List[Dict]: ) -> List[Dict]:
"""Query memories based on vector similarity. """Query memories based on vector similarity.
Args: Args:
query_embedding: Vector embedding of the query query_embedding: Vector embedding of the query
limit: Maximum number of results to return limit: Maximum number of results to return
score_threshold: Minimum similarity score threshold score_threshold: Minimum similarity score threshold
Returns: Returns:
List of matching memories with their metadata List of matching memories with their metadata
""" """
@ -114,52 +113,51 @@ class QdrantMemory:
collection_name=self.collection_name, collection_name=self.collection_name,
query_vector=query_embedding, query_vector=query_embedding,
limit=limit, limit=limit,
score_threshold=score_threshold score_threshold=score_threshold,
) )
memories = [] memories = []
for res in results: for res in results:
memory = res.payload memory = res.payload
memory["similarity_score"] = res.score memory["similarity_score"] = res.score
memories.append(memory) memories.append(memory)
return memories return memories
def delete(self, memory_id: str): def delete(self, memory_id: str):
"""Delete a specific memory by ID.""" """Delete a specific memory by ID."""
self.client.delete( self.client.delete(
collection_name=self.collection_name, collection_name=self.collection_name,
points_selector=models.PointIdsList( points_selector=models.PointIdsList(points=[memory_id]),
points=[memory_id]
)
) )
def clear(self): def clear(self):
"""Clear all memories from the collection.""" """Clear all memories from the collection."""
self.client.delete_collection(self.collection_name) self.client.delete_collection(self.collection_name)
self._create_collection() self._create_collection()
# # Example usage # # Example usage
# if __name__ == "__main__": # if __name__ == "__main__":
# # Initialize memory # # Initialize memory
# memory = QdrantMemory() # memory = QdrantMemory()
# # Example embedding (would normally come from an embedding model) # # Example embedding (would normally come from an embedding model)
# example_embedding = np.random.rand(1536).tolist() # example_embedding = np.random.rand(1536).tolist()
# # Add a memory # # Add a memory
# memory_id = memory.add( # memory_id = memory.add(
# text="Important financial analysis about startup equity.", # text="Important financial analysis about startup equity.",
# embedding=example_embedding, # embedding=example_embedding,
# metadata={"category": "finance", "importance": "high"} # metadata={"category": "finance", "importance": "high"}
# ) # )
# # Query memories # # Query memories
# results = memory.query( # results = memory.query(
# query_embedding=example_embedding, # query_embedding=example_embedding,
# limit=5 # limit=5
# ) # )
# print(f"Found {len(results)} relevant memories") # print(f"Found {len(results)} relevant memories")
# for result in results: # for result in results:
# print(f"Memory: {result['text']}") # print(f"Memory: {result['text']}")
@ -168,7 +166,9 @@ class QdrantMemory:
# Initialize memory with optional cloud configuration # Initialize memory with optional cloud configuration
memory = QdrantMemory( memory = QdrantMemory(
url=os.getenv("QDRANT_URL"), # Optional: For cloud deployment url=os.getenv("QDRANT_URL"), # Optional: For cloud deployment
api_key=os.getenv("QDRANT_API_KEY") # Optional: For cloud deployment api_key=os.getenv(
"QDRANT_API_KEY"
), # Optional: For cloud deployment
) )
# Model # Model
@ -184,4 +184,6 @@ agent = Agent(
) )
# Run a query # Run a query
agent.run("What are the components of a startup's stock incentive equity plan?") agent.run(
"What are the components of a startup's stock incentive equity plan?"
)

@ -14,39 +14,71 @@ from prometheus_client import Counter, Histogram, start_http_server
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from tenacity import retry, stop_after_attempt, wait_exponential from tenacity import retry, stop_after_attempt, wait_exponential
from swarms.prompts.finance_agent_sys_prompt import FINANCIAL_AGENT_SYS_PROMPT from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
# Enhanced metrics # Enhanced metrics
TASK_COUNTER = Counter('swarm_tasks_total', 'Total number of tasks processed') TASK_COUNTER = Counter(
TASK_LATENCY = Histogram('swarm_task_duration_seconds', 'Task processing duration') "swarm_tasks_total", "Total number of tasks processed"
TASK_FAILURES = Counter('swarm_task_failures_total', 'Total number of task failures') )
AGENT_ERRORS = Counter('swarm_agent_errors_total', 'Total number of agent errors') TASK_LATENCY = Histogram(
"swarm_task_duration_seconds", "Task processing duration"
)
TASK_FAILURES = Counter(
"swarm_task_failures_total", "Total number of task failures"
)
AGENT_ERRORS = Counter(
"swarm_agent_errors_total", "Total number of agent errors"
)
# Define types using Literal # Define types using Literal
TaskStatus = Literal["pending", "processing", "completed", "failed"] TaskStatus = Literal["pending", "processing", "completed", "failed"]
TaskPriority = Literal["low", "medium", "high", "critical"] TaskPriority = Literal["low", "medium", "high", "critical"]
class SecurityConfig(BaseModel): class SecurityConfig(BaseModel):
"""Security configuration for the swarm""" """Security configuration for the swarm"""
encryption_key: str = Field(..., description="Encryption key for sensitive data")
tls_cert_path: Optional[str] = Field(None, description="Path to TLS certificate") encryption_key: str = Field(
tls_key_path: Optional[str] = Field(None, description="Path to TLS private key") ..., description="Encryption key for sensitive data"
auth_token: Optional[str] = Field(None, description="Authentication token") )
max_message_size: int = Field(default=1048576, description="Maximum message size in bytes") tls_cert_path: Optional[str] = Field(
rate_limit: int = Field(default=100, description="Maximum tasks per minute") None, description="Path to TLS certificate"
)
@validator('encryption_key') tls_key_path: Optional[str] = Field(
None, description="Path to TLS private key"
)
auth_token: Optional[str] = Field(
None, description="Authentication token"
)
max_message_size: int = Field(
default=1048576, description="Maximum message size in bytes"
)
rate_limit: int = Field(
default=100, description="Maximum tasks per minute"
)
@validator("encryption_key")
def validate_encryption_key(cls, v): def validate_encryption_key(cls, v):
if len(v) < 32: if len(v) < 32:
raise ValueError("Encryption key must be at least 32 bytes long") raise ValueError(
"Encryption key must be at least 32 bytes long"
)
return v return v
class Task(BaseModel): class Task(BaseModel):
"""Enhanced task model with additional metadata and validation""" """Enhanced task model with additional metadata and validation"""
task_id: str = Field(..., description="Unique identifier for the task")
description: str = Field(..., description="Task description or instructions") task_id: str = Field(
..., description="Unique identifier for the task"
)
description: str = Field(
..., description="Task description or instructions"
)
output_type: Literal["string", "json", "file"] = Field("string") output_type: Literal["string", "json", "file"] = Field("string")
status: TaskStatus = Field(default="pending") status: TaskStatus = Field(default="pending")
priority: TaskPriority = Field(default="medium") priority: TaskPriority = Field(default="medium")
@ -55,21 +87,20 @@ class Task(BaseModel):
completed_at: Optional[datetime] = None completed_at: Optional[datetime] = None
retry_count: int = Field(default=0) retry_count: int = Field(default=0)
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: Dict[str, Any] = Field(default_factory=dict)
@validator('task_id') @validator("task_id")
def validate_task_id(cls, v): def validate_task_id(cls, v):
if not v.strip(): if not v.strip():
raise ValueError("task_id cannot be empty") raise ValueError("task_id cannot be empty")
return v return v
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat()}
datetime: lambda v: v.isoformat()
}
class TaskResult(BaseModel): class TaskResult(BaseModel):
"""Model for task execution results""" """Model for task execution results"""
task_id: str task_id: str
status: TaskStatus status: TaskStatus
result: Any result: Any
@ -119,33 +150,43 @@ class SecurePulsarSwarm:
self.max_workers = max_workers self.max_workers = max_workers
self.retry_attempts = retry_attempts self.retry_attempts = retry_attempts
self.task_timeout = task_timeout self.task_timeout = task_timeout
# Initialize encryption # Initialize encryption
self.cipher_suite = Fernet(security_config.encryption_key.encode()) self.cipher_suite = Fernet(
security_config.encryption_key.encode()
)
# Setup metrics server # Setup metrics server
start_http_server(metrics_port) start_http_server(metrics_port)
# Initialize Pulsar client with security settings # Initialize Pulsar client with security settings
client_config = { client_config = {
"authentication": None if not security_config.auth_token else pulsar.AuthenticationToken(security_config.auth_token), "authentication": (
None
if not security_config.auth_token
else pulsar.AuthenticationToken(
security_config.auth_token
)
),
"operation_timeout_seconds": 30, "operation_timeout_seconds": 30,
"connection_timeout_seconds": 30, "connection_timeout_seconds": 30,
"use_tls": bool(security_config.tls_cert_path), "use_tls": bool(security_config.tls_cert_path),
"tls_trust_certs_file_path": security_config.tls_cert_path, "tls_trust_certs_file_path": security_config.tls_cert_path,
"tls_allow_insecure_connection": False, "tls_allow_insecure_connection": False,
} }
self.client = pulsar.Client(self.pulsar_url, **client_config) self.client = pulsar.Client(self.pulsar_url, **client_config)
self.producer = self._create_producer() self.producer = self._create_producer()
self.consumer = self._create_consumer() self.consumer = self._create_consumer()
self.executor = ThreadPoolExecutor(max_workers=max_workers) self.executor = ThreadPoolExecutor(max_workers=max_workers)
# Initialize rate limiting # Initialize rate limiting
self.last_execution_time = time.time() self.last_execution_time = time.time()
self.execution_count = 0 self.execution_count = 0
logger.info(f"Secure Pulsar Swarm '{self.name}' initialized with enhanced security features") logger.info(
f"Secure Pulsar Swarm '{self.name}' initialized with enhanced security features"
)
def _create_producer(self): def _create_producer(self):
"""Create a secure producer with retry logic""" """Create a secure producer with retry logic"""
@ -155,7 +196,7 @@ class SecurePulsarSwarm:
compression_type=pulsar.CompressionType.LZ4, compression_type=pulsar.CompressionType.LZ4,
block_if_queue_full=True, block_if_queue_full=True,
batching_enabled=True, batching_enabled=True,
batching_max_publish_delay_ms=10 batching_max_publish_delay_ms=10,
) )
def _create_consumer(self): def _create_consumer(self):
@ -166,7 +207,7 @@ class SecurePulsarSwarm:
consumer_type=pulsar.ConsumerType.Shared, consumer_type=pulsar.ConsumerType.Shared,
message_listener=None, message_listener=None,
receiver_queue_size=1000, receiver_queue_size=1000,
max_total_receiver_queue_size_across_partitions=50000 max_total_receiver_queue_size_across_partitions=50000,
) )
def _encrypt_message(self, data: str) -> bytes: def _encrypt_message(self, data: str) -> bytes:
@ -177,51 +218,65 @@ class SecurePulsarSwarm:
"""Decrypt message data""" """Decrypt message data"""
return self.cipher_suite.decrypt(data).decode() return self.cipher_suite.decrypt(data).decode()
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) @retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
)
def publish_task(self, task: Task) -> None: def publish_task(self, task: Task) -> None:
"""Publish a task with enhanced security and reliability""" """Publish a task with enhanced security and reliability"""
try: try:
# Validate message size # Validate message size
task_data = task.json() task_data = task.json()
if len(task_data) > self.security_config.max_message_size: if len(task_data) > self.security_config.max_message_size:
raise ValueError("Task data exceeds maximum message size") raise ValueError(
"Task data exceeds maximum message size"
)
# Rate limiting # Rate limiting
current_time = time.time() current_time = time.time()
if current_time - self.last_execution_time >= 60: if current_time - self.last_execution_time >= 60:
self.execution_count = 0 self.execution_count = 0
self.last_execution_time = current_time self.last_execution_time = current_time
if self.execution_count >= self.security_config.rate_limit: if (
self.execution_count
>= self.security_config.rate_limit
):
raise ValueError("Rate limit exceeded") raise ValueError("Rate limit exceeded")
# Encrypt and publish # Encrypt and publish
encrypted_data = self._encrypt_message(task_data) encrypted_data = self._encrypt_message(task_data)
message_id = self.producer.send(encrypted_data) message_id = self.producer.send(encrypted_data)
self.execution_count += 1 self.execution_count += 1
logger.info(f"Task {task.task_id} published successfully with message ID {message_id}") logger.info(
f"Task {task.task_id} published successfully with message ID {message_id}"
)
except Exception as e: except Exception as e:
TASK_FAILURES.inc() TASK_FAILURES.inc()
logger.error(f"Error publishing task {task.task_id}: {str(e)}") logger.error(
f"Error publishing task {task.task_id}: {str(e)}"
)
raise raise
async def _process_task(self, task: Task) -> TaskResult: async def _process_task(self, task: Task) -> TaskResult:
"""Process a task with comprehensive error handling and monitoring""" """Process a task with comprehensive error handling and monitoring"""
task.status = "processing" task.status = "processing"
task.started_at = datetime.utcnow() task.started_at = datetime.utcnow()
with task_timing(): with task_timing():
try: try:
# Select agent using round-robin # Select agent using round-robin
agent = self.agents.pop(0) agent = self.agents.pop(0)
self.agents.append(agent) self.agents.append(agent)
# Execute task with timeout # Execute task with timeout
future = self.executor.submit(agent.run, task.description) future = self.executor.submit(
agent.run, task.description
)
result = future.result(timeout=self.task_timeout) result = future.result(timeout=self.task_timeout)
# Handle different output types # Handle different output types
if task.output_type == "json": if task.output_type == "json":
result = json.loads(result) result = json.loads(result)
@ -230,19 +285,20 @@ class SecurePulsarSwarm:
with open(file_path, "w") as f: with open(file_path, "w") as f:
f.write(result) f.write(result)
result = {"file_path": file_path} result = {"file_path": file_path}
task.status = "completed" task.status = "completed"
task.completed_at = datetime.utcnow() task.completed_at = datetime.utcnow()
TASK_COUNTER.inc() TASK_COUNTER.inc()
return TaskResult( return TaskResult(
task_id=task.task_id, task_id=task.task_id,
status="completed", status="completed",
result=result, result=result,
execution_time=time.time() - task.started_at.timestamp(), execution_time=time.time()
agent_id=agent.agent_name - task.started_at.timestamp(),
agent_id=agent.agent_name,
) )
except TimeoutError: except TimeoutError:
TASK_FAILURES.inc() TASK_FAILURES.inc()
error_msg = f"Task {task.task_id} timed out after {self.task_timeout} seconds" error_msg = f"Task {task.task_id} timed out after {self.task_timeout} seconds"
@ -253,14 +309,17 @@ class SecurePulsarSwarm:
status="failed", status="failed",
result=None, result=None,
error_message=error_msg, error_message=error_msg,
execution_time=time.time() - task.started_at.timestamp(), execution_time=time.time()
agent_id=agent.agent_name - task.started_at.timestamp(),
agent_id=agent.agent_name,
) )
except Exception as e: except Exception as e:
TASK_FAILURES.inc() TASK_FAILURES.inc()
AGENT_ERRORS.inc() AGENT_ERRORS.inc()
error_msg = f"Error processing task {task.task_id}: {str(e)}" error_msg = (
f"Error processing task {task.task_id}: {str(e)}"
)
logger.error(error_msg) logger.error(error_msg)
task.status = "failed" task.status = "failed"
return TaskResult( return TaskResult(
@ -268,36 +327,41 @@ class SecurePulsarSwarm:
status="failed", status="failed",
result=None, result=None,
error_message=error_msg, error_message=error_msg,
execution_time=time.time() - task.started_at.timestamp(), execution_time=time.time()
agent_id=agent.agent_name - task.started_at.timestamp(),
agent_id=agent.agent_name,
) )
async def consume_tasks(self): async def consume_tasks(self):
"""Enhanced task consumption with circuit breaker and backoff""" """Enhanced task consumption with circuit breaker and backoff"""
consecutive_failures = 0 consecutive_failures = 0
backoff_time = 1 backoff_time = 1
while True: while True:
try: try:
# Circuit breaker pattern # Circuit breaker pattern
if consecutive_failures >= 5: if consecutive_failures >= 5:
logger.warning(f"Circuit breaker triggered. Waiting {backoff_time} seconds") logger.warning(
f"Circuit breaker triggered. Waiting {backoff_time} seconds"
)
await asyncio.sleep(backoff_time) await asyncio.sleep(backoff_time)
backoff_time = min(backoff_time * 2, 60) backoff_time = min(backoff_time * 2, 60)
continue continue
# Receive message with timeout # Receive message with timeout
message = await self.consumer.receive_async() message = await self.consumer.receive_async()
try: try:
# Decrypt and process message # Decrypt and process message
decrypted_data = self._decrypt_message(message.data()) decrypted_data = self._decrypt_message(
message.data()
)
task_data = json.loads(decrypted_data) task_data = json.loads(decrypted_data)
task = Task(**task_data) task = Task(**task_data)
# Process task # Process task
result = await self._process_task(task) result = await self._process_task(task)
# Handle result # Handle result
if result.status == "completed": if result.status == "completed":
await self.consumer.acknowledge_async(message) await self.consumer.acknowledge_async(message)
@ -306,16 +370,24 @@ class SecurePulsarSwarm:
else: else:
if task.retry_count < self.retry_attempts: if task.retry_count < self.retry_attempts:
task.retry_count += 1 task.retry_count += 1
await self.consumer.negative_acknowledge(message) await self.consumer.negative_acknowledge(
message
)
else: else:
await self.consumer.acknowledge_async(message) await self.consumer.acknowledge_async(
logger.error(f"Task {task.task_id} failed after {self.retry_attempts} attempts") message
)
logger.error(
f"Task {task.task_id} failed after {self.retry_attempts} attempts"
)
except Exception as e: except Exception as e:
logger.error(f"Error processing message: {str(e)}") logger.error(
f"Error processing message: {str(e)}"
)
await self.consumer.negative_acknowledge(message) await self.consumer.negative_acknowledge(message)
consecutive_failures += 1 consecutive_failures += 1
except Exception as e: except Exception as e:
logger.error(f"Error in consume_tasks: {str(e)}") logger.error(f"Error in consume_tasks: {str(e)}")
consecutive_failures += 1 consecutive_failures += 1
@ -336,6 +408,7 @@ class SecurePulsarSwarm:
except Exception as e: except Exception as e:
logger.error(f"Error during cleanup: {str(e)}") logger.error(f"Error during cleanup: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":
# Example usage with security configuration # Example usage with security configuration
security_config = SecurityConfig( security_config = SecurityConfig(
@ -344,10 +417,10 @@ if __name__ == "__main__":
tls_key_path="/path/to/key.pem", tls_key_path="/path/to/key.pem",
auth_token="your-auth-token", auth_token="your-auth-token",
max_message_size=1048576, max_message_size=1048576,
rate_limit=100 rate_limit=100,
) )
# Agent factory function # Agent factory function
def create_financial_agent() -> Agent: def create_financial_agent() -> Agent:
"""Factory function to create a financial analysis agent.""" """Factory function to create a financial analysis agent."""
return Agent( return Agent(
@ -383,7 +456,7 @@ if __name__ == "__main__":
max_workers=5, max_workers=5,
retry_attempts=3, retry_attempts=3,
task_timeout=300, task_timeout=300,
metrics_port=8000 metrics_port=8000,
) as swarm: ) as swarm:
# Example task # Example task
task = Task( task = Task(
@ -391,9 +464,12 @@ if __name__ == "__main__":
description="Analyze Q4 financial reports", description="Analyze Q4 financial reports",
output_type="json", output_type="json",
priority="high", priority="high",
metadata={"department": "finance", "requester": "john.doe@company.com"} metadata={
"department": "finance",
"requester": "john.doe@company.com",
},
) )
# Run the swarm # Run the swarm
swarm.publish_task(task) swarm.publish_task(task)
asyncio.run(swarm.consume_tasks()) asyncio.run(swarm.consume_tasks())

@ -571,9 +571,7 @@ class Agent:
) )
# Telemetry Processor to log agent data # Telemetry Processor to log agent data
threading.Thread( log_agent_data(self.to_dict())
target=log_agent_data(self.to_dict())
).start()
if self.llm is None and self.model_name is not None: if self.llm is None and self.model_name is not None:
self.llm = self.llm_handling() self.llm = self.llm_handling()

@ -1,2 +0,0 @@
agent_name,system_prompt,model_name,max_loops,autosave,dashboard,verbose,dynamic_temperature,saved_state_path,user_name,retry_attempts,context_length,return_step_meta,output_type,streaming
Financial-Analysis-Agent,You are a financial expert...,gpt-4o-mini,1,True,False,True,True,finance_agent.json,swarms_corp,3,200000,False,string,False
1 agent_name system_prompt model_name max_loops autosave dashboard verbose dynamic_temperature saved_state_path user_name retry_attempts context_length return_step_meta output_type streaming
2 Financial-Analysis-Agent You are a financial expert... gpt-4o-mini 1 True False True True finance_agent.json swarms_corp 3 200000 False string False

@ -1,7 +1,11 @@
from typing import List, Dict, Optional, TypedDict, Literal, Union, Any from typing import (
List,
Dict,
TypedDict,
Any,
)
from dataclasses import dataclass from dataclasses import dataclass
import csv import csv
import os
from pathlib import Path from pathlib import Path
import logging import logging
from enum import Enum from enum import Enum
@ -11,27 +15,31 @@ from swarms import Agent
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ModelName(str, Enum): class ModelName(str, Enum):
"""Valid model names for swarms agents""" """Valid model names for swarms agents"""
GPT4O = "gpt-4o" GPT4O = "gpt-4o"
GPT4O_MINI = "gpt-4o-mini" GPT4O_MINI = "gpt-4o-mini"
GPT4 = "gpt-4" GPT4 = "gpt-4"
GPT35_TURBO = "gpt-3.5-turbo" GPT35_TURBO = "gpt-3.5-turbo"
CLAUDE = "claude-v1" CLAUDE = "claude-v1"
CLAUDE2 = "claude-2" CLAUDE2 = "claude-2"
@classmethod @classmethod
def get_model_names(cls) -> List[str]: def get_model_names(cls) -> List[str]:
"""Get list of valid model names""" """Get list of valid model names"""
return [model.value for model in cls] return [model.value for model in cls]
@classmethod @classmethod
def is_valid_model(cls, model_name: str) -> bool: def is_valid_model(cls, model_name: str) -> bool:
"""Check if model name is valid""" """Check if model name is valid"""
return model_name in cls.get_model_names() return model_name in cls.get_model_names()
class AgentConfigDict(TypedDict): class AgentConfigDict(TypedDict):
"""TypedDict for agent configuration""" """TypedDict for agent configuration"""
agent_name: str agent_name: str
system_prompt: str system_prompt: str
model_name: str # Using str instead of ModelName for flexibility model_name: str # Using str instead of ModelName for flexibility
@ -48,9 +56,11 @@ class AgentConfigDict(TypedDict):
output_type: str output_type: str
streaming: bool streaming: bool
@dataclass @dataclass
class AgentValidationError(Exception): class AgentValidationError(Exception):
"""Custom exception for agent validation errors""" """Custom exception for agent validation errors"""
message: str message: str
field: str field: str
value: Any value: Any
@ -58,135 +68,207 @@ class AgentValidationError(Exception):
def __str__(self) -> str: def __str__(self) -> str:
return f"Validation error in field '{self.field}': {self.message}. Got value: {self.value}" return f"Validation error in field '{self.field}': {self.message}. Got value: {self.value}"
class AgentValidator: class AgentValidator:
"""Validates agent configuration data""" """Validates agent configuration data"""
@staticmethod @staticmethod
def validate_config(config: Dict[str, Any]) -> AgentConfigDict: def validate_config(config: Dict[str, Any]) -> AgentConfigDict:
"""Validate and convert agent configuration""" """Validate and convert agent configuration"""
try: try:
# Validate model name # Validate model name
model_name = str(config['model_name']) model_name = str(config["model_name"])
if not ModelName.is_valid_model(model_name): if not ModelName.is_valid_model(model_name):
valid_models = ModelName.get_model_names() valid_models = ModelName.get_model_names()
raise AgentValidationError( raise AgentValidationError(
f"Invalid model name. Must be one of: {', '.join(valid_models)}", f"Invalid model name. Must be one of: {', '.join(valid_models)}",
"model_name", "model_name",
model_name model_name,
) )
# Convert types with error handling # Convert types with error handling
validated_config: AgentConfigDict = { validated_config: AgentConfigDict = {
'agent_name': str(config.get('agent_name', '')), "agent_name": str(config.get("agent_name", "")),
'system_prompt': str(config.get('system_prompt', '')), "system_prompt": str(config.get("system_prompt", "")),
'model_name': model_name, "model_name": model_name,
'max_loops': int(config.get('max_loops', 1)), "max_loops": int(config.get("max_loops", 1)),
'autosave': bool(str(config.get('autosave', True)).lower() == 'true'), "autosave": bool(
'dashboard': bool(str(config.get('dashboard', False)).lower() == 'true'), str(config.get("autosave", True)).lower()
'verbose': bool(str(config.get('verbose', True)).lower() == 'true'), == "true"
'dynamic_temperature': bool(str(config.get('dynamic_temperature', True)).lower() == 'true'), ),
'saved_state_path': str(config.get('saved_state_path', '')), "dashboard": bool(
'user_name': str(config.get('user_name', 'default_user')), str(config.get("dashboard", False)).lower()
'retry_attempts': int(config.get('retry_attempts', 3)), == "true"
'context_length': int(config.get('context_length', 200000)), ),
'return_step_meta': bool(str(config.get('return_step_meta', False)).lower() == 'true'), "verbose": bool(
'output_type': str(config.get('output_type', 'string')), str(config.get("verbose", True)).lower() == "true"
'streaming': bool(str(config.get('streaming', False)).lower() == 'true') ),
"dynamic_temperature": bool(
str(
config.get("dynamic_temperature", True)
).lower()
== "true"
),
"saved_state_path": str(
config.get("saved_state_path", "")
),
"user_name": str(
config.get("user_name", "default_user")
),
"retry_attempts": int(
config.get("retry_attempts", 3)
),
"context_length": int(
config.get("context_length", 200000)
),
"return_step_meta": bool(
str(config.get("return_step_meta", False)).lower()
== "true"
),
"output_type": str(
config.get("output_type", "string")
),
"streaming": bool(
str(config.get("streaming", False)).lower()
== "true"
),
} }
return validated_config return validated_config
except (ValueError, KeyError) as e: except (ValueError, KeyError) as e:
raise AgentValidationError( raise AgentValidationError(
str(e), str(e), str(e.__class__.__name__), str(config)
str(e.__class__.__name__),
str(config)
) )
@dataclass @dataclass
class AgentCSV: class AgentCSV:
"""Class to manage agents through CSV with type safety""" """Class to manage agents through CSV with type safety"""
csv_path: Path csv_path: Path
def __post_init__(self) -> None: def __post_init__(self) -> None:
"""Convert string path to Path object if necessary""" """Convert string path to Path object if necessary"""
if isinstance(self.csv_path, str): if isinstance(self.csv_path, str):
self.csv_path = Path(self.csv_path) self.csv_path = Path(self.csv_path)
@property @property
def headers(self) -> List[str]: def headers(self) -> List[str]:
"""CSV headers for agent configuration""" """CSV headers for agent configuration"""
return [ return [
"agent_name", "system_prompt", "model_name", "max_loops", "agent_name",
"autosave", "dashboard", "verbose", "dynamic_temperature", "system_prompt",
"saved_state_path", "user_name", "retry_attempts", "context_length", "model_name",
"return_step_meta", "output_type", "streaming" "max_loops",
"autosave",
"dashboard",
"verbose",
"dynamic_temperature",
"saved_state_path",
"user_name",
"retry_attempts",
"context_length",
"return_step_meta",
"output_type",
"streaming",
] ]
def create_agent_csv(self, agents: List[Dict[str, Any]]) -> None: def create_agent_csv(self, agents: List[Dict[str, Any]]) -> None:
"""Create a CSV file with validated agent configurations""" """Create a CSV file with validated agent configurations"""
validated_agents = [] validated_agents = []
for agent in agents: for agent in agents:
try: try:
validated_config = AgentValidator.validate_config(agent) validated_config = AgentValidator.validate_config(
agent
)
validated_agents.append(validated_config) validated_agents.append(validated_config)
except AgentValidationError as e: except AgentValidationError as e:
logger.error(f"Validation error for agent {agent.get('agent_name', 'unknown')}: {e}") logger.error(
f"Validation error for agent {agent.get('agent_name', 'unknown')}: {e}"
)
raise raise
with open(self.csv_path, 'w', newline='') as f: with open(self.csv_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=self.headers) writer = csv.DictWriter(f, fieldnames=self.headers)
writer.writeheader() writer.writeheader()
writer.writerows(validated_agents) writer.writerows(validated_agents)
logger.info(f"Created CSV with {len(validated_agents)} agents at {self.csv_path}") logger.info(
f"Created CSV with {len(validated_agents)} agents at {self.csv_path}"
)
def load_agents(self) -> List[Agent]: def load_agents(self) -> List[Agent]:
"""Load and create agents from CSV with validation""" """Load and create agents from CSV with validation"""
if not self.csv_path.exists(): if not self.csv_path.exists():
raise FileNotFoundError(f"CSV file not found at {self.csv_path}") raise FileNotFoundError(
f"CSV file not found at {self.csv_path}"
)
agents: List[Agent] = [] agents: List[Agent] = []
with open(self.csv_path, 'r') as f: with open(self.csv_path, "r") as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
for row in reader: for row in reader:
try: try:
validated_config = AgentValidator.validate_config(row) validated_config = AgentValidator.validate_config(
row
)
agent = Agent( agent = Agent(
agent_name=validated_config['agent_name'], agent_name=validated_config["agent_name"],
system_prompt=validated_config['system_prompt'], system_prompt=validated_config[
model_name=validated_config['model_name'], "system_prompt"
max_loops=validated_config['max_loops'], ],
autosave=validated_config['autosave'], model_name=validated_config["model_name"],
dashboard=validated_config['dashboard'], max_loops=validated_config["max_loops"],
verbose=validated_config['verbose'], autosave=validated_config["autosave"],
dynamic_temperature_enabled=validated_config['dynamic_temperature'], dashboard=validated_config["dashboard"],
saved_state_path=validated_config['saved_state_path'], verbose=validated_config["verbose"],
user_name=validated_config['user_name'], dynamic_temperature_enabled=validated_config[
retry_attempts=validated_config['retry_attempts'], "dynamic_temperature"
context_length=validated_config['context_length'], ],
return_step_meta=validated_config['return_step_meta'], saved_state_path=validated_config[
output_type=validated_config['output_type'], "saved_state_path"
streaming_on=validated_config['streaming'] ],
user_name=validated_config["user_name"],
retry_attempts=validated_config[
"retry_attempts"
],
context_length=validated_config[
"context_length"
],
return_step_meta=validated_config[
"return_step_meta"
],
output_type=validated_config["output_type"],
streaming_on=validated_config["streaming"],
) )
agents.append(agent) agents.append(agent)
except AgentValidationError as e: except AgentValidationError as e:
logger.error(f"Skipping invalid agent configuration: {e}") logger.error(
f"Skipping invalid agent configuration: {e}"
)
continue continue
logger.info(f"Loaded {len(agents)} agents from {self.csv_path}") logger.info(
f"Loaded {len(agents)} agents from {self.csv_path}"
)
return agents return agents
def add_agent(self, agent_config: Dict[str, Any]) -> None: def add_agent(self, agent_config: Dict[str, Any]) -> None:
"""Add a new validated agent configuration to CSV""" """Add a new validated agent configuration to CSV"""
validated_config = AgentValidator.validate_config(agent_config) validated_config = AgentValidator.validate_config(
agent_config
with open(self.csv_path, 'a', newline='') as f: )
with open(self.csv_path, "a", newline="") as f:
writer = csv.DictWriter(f, fieldnames=self.headers) writer = csv.DictWriter(f, fieldnames=self.headers)
writer.writerow(validated_config) writer.writerow(validated_config)
logger.info(f"Added new agent {validated_config['agent_name']} to {self.csv_path}") logger.info(
f"Added new agent {validated_config['agent_name']} to {self.csv_path}"
)
# Example usage # Example usage
if __name__ == "__main__": if __name__ == "__main__":
@ -207,20 +289,20 @@ if __name__ == "__main__":
"context_length": 200000, "context_length": 200000,
"return_step_meta": False, "return_step_meta": False,
"output_type": "string", "output_type": "string",
"streaming": False "streaming": False,
} }
] ]
try: try:
# Initialize CSV manager # Initialize CSV manager
csv_manager = AgentCSV(Path("agents.csv")) csv_manager = AgentCSV(Path("agents.csv"))
# Create CSV with initial agents # Create CSV with initial agents
csv_manager.create_agent_csv(agent_configs) csv_manager.create_agent_csv(agent_configs)
# Load agents from CSV # Load agents from CSV
agents = csv_manager.load_agents() agents = csv_manager.load_agents()
# Use an agent # Use an agent
if agents: if agents:
financial_agent = agents[0] financial_agent = agents[0]
@ -228,8 +310,8 @@ if __name__ == "__main__":
"How can I establish a ROTH IRA to buy stocks and get a tax break?" "How can I establish a ROTH IRA to buy stocks and get a tax break?"
) )
print(response) print(response)
except AgentValidationError as e: except AgentValidationError as e:
logger.error(f"Validation error: {e}") logger.error(f"Validation error: {e}")
except Exception as e: except Exception as e:
logger.error(f"Unexpected error: {e}") logger.error(f"Unexpected error: {e}")

@ -11,7 +11,7 @@ Todo:
import os import os
import subprocess import subprocess
import uuid import uuid
from datetime import UTC, datetime from datetime import datetime
from typing import List, Literal, Optional from typing import List, Literal, Optional
from loguru import logger from loguru import logger
@ -241,7 +241,7 @@ class MultiAgentRouter:
dict: A dictionary containing the routing result, including the selected agent, reasoning, and response. dict: A dictionary containing the routing result, including the selected agent, reasoning, and response.
""" """
try: try:
start_time = datetime.now(UTC) start_time = datetime.now()
# Get boss decision using function calling # Get boss decision using function calling
boss_response = self.function_caller.get_completion(task) boss_response = self.function_caller.get_completion(task)
@ -259,7 +259,7 @@ class MultiAgentRouter:
final_task = boss_response.modified_task or task final_task = boss_response.modified_task or task
# Execute the task with the selected agent if enabled # Execute the task with the selected agent if enabled
execution_start = datetime.now(UTC) execution_start = datetime.now()
agent_response = None agent_response = None
execution_time = 0 execution_time = 0
@ -267,20 +267,18 @@ class MultiAgentRouter:
# Use the agent's run method directly # Use the agent's run method directly
agent_response = selected_agent.run(final_task) agent_response = selected_agent.run(final_task)
execution_time = ( execution_time = (
datetime.now(UTC) - execution_start datetime.now() - execution_start
).total_seconds() ).total_seconds()
else: else:
logger.info( logger.info(
"Task execution skipped (execute_task=False)" "Task execution skipped (execute_task=False)"
) )
total_time = ( total_time = (datetime.now() - start_time).total_seconds()
datetime.now(UTC) - start_time
).total_seconds()
result = { result = {
"id": str(uuid.uuid4()), "id": str(uuid.uuid4()),
"timestamp": datetime.now(UTC).isoformat(), "timestamp": datetime.now().isoformat(),
"task": { "task": {
"original": task, "original": task,
"modified": ( "modified": (

Loading…
Cancel
Save