[FEATS][CSVTOAgents] [Various Bug Fixes]

pull/700/merge
Kye Gomez 4 days ago
parent 704f91fdea
commit 5ecdc87f4a

@@ -3,7 +3,6 @@ import uuid
 from datetime import datetime
 from typing import Dict, List, Optional
-import numpy as np
 from qdrant_client import QdrantClient
 from qdrant_client.http import models
 from qdrant_client.http.models import Distance, VectorParams
 
@@ -18,7 +17,7 @@ class QdrantMemory:
         collection_name: str = "agent_memories",
         vector_size: int = 1536,  # Default size for Claude embeddings
         url: Optional[str] = None,
-        api_key: Optional[str] = None
+        api_key: Optional[str] = None,
     ):
         """Initialize Qdrant memory system.
 
@@ -35,7 +34,9 @@ class QdrantMemory:
         if url and api_key:
             self.client = QdrantClient(url=url, api_key=api_key)
         else:
-            self.client = QdrantClient(":memory:")  # Local in-memory storage
+            self.client = QdrantClient(
+                ":memory:"
+            )  # Local in-memory storage
 
         # Create collection if it doesn't exist
         self._create_collection()
@@ -43,22 +44,23 @@ class QdrantMemory:
     def _create_collection(self):
         """Create the Qdrant collection if it doesn't exist."""
         collections = self.client.get_collections().collections
-        exists = any(col.name == self.collection_name for col in collections)
+        exists = any(
+            col.name == self.collection_name for col in collections
+        )
 
         if not exists:
             self.client.create_collection(
                 collection_name=self.collection_name,
                 vectors_config=VectorParams(
-                    size=self.vector_size,
-                    distance=Distance.COSINE
-                )
+                    size=self.vector_size, distance=Distance.COSINE
+                ),
             )
 
     def add(
         self,
         text: str,
         embedding: List[float],
-        metadata: Optional[Dict] = None
+        metadata: Optional[Dict] = None,
     ) -> str:
         """Add a memory to the store.
 
@@ -75,21 +77,18 @@ class QdrantMemory:
         # Add timestamp and generate ID
         memory_id = str(uuid.uuid4())
-        metadata.update({
-            "timestamp": datetime.utcnow().isoformat(),
-            "text": text
-        })
+        metadata.update(
+            {"timestamp": datetime.utcnow().isoformat(), "text": text}
+        )
 
         # Store the point
         self.client.upsert(
             collection_name=self.collection_name,
             points=[
                 models.PointStruct(
-                    id=memory_id,
-                    payload=metadata,
-                    vector=embedding
+                    id=memory_id, payload=metadata, vector=embedding
                 )
-            ]
+            ],
         )
 
         return memory_id
@@ -98,7 +97,7 @@ class QdrantMemory:
         self,
         query_embedding: List[float],
         limit: int = 5,
-        score_threshold: float = 0.7
+        score_threshold: float = 0.7,
     ) -> List[Dict]:
         """Query memories based on vector similarity.
 
@@ -114,7 +113,7 @@ class QdrantMemory:
             collection_name=self.collection_name,
             query_vector=query_embedding,
             limit=limit,
-            score_threshold=score_threshold
+            score_threshold=score_threshold,
         )
 
         memories = []
@@ -129,9 +128,7 @@ class QdrantMemory:
         """Delete a specific memory by ID."""
         self.client.delete(
             collection_name=self.collection_name,
-            points_selector=models.PointIdsList(
-                points=[memory_id]
-            )
+            points_selector=models.PointIdsList(points=[memory_id]),
         )
 
     def clear(self):
@@ -139,6 +136,7 @@ class QdrantMemory:
         self.client.delete_collection(self.collection_name)
         self._create_collection()
 
+
 # # Example usage
 # if __name__ == "__main__":
 #     # Initialize memory
@@ -168,7 +166,9 @@ class QdrantMemory:
 # Initialize memory with optional cloud configuration
 memory = QdrantMemory(
     url=os.getenv("QDRANT_URL"),  # Optional: For cloud deployment
-    api_key=os.getenv("QDRANT_API_KEY")  # Optional: For cloud deployment
+    api_key=os.getenv(
+        "QDRANT_API_KEY"
+    ),  # Optional: For cloud deployment
 )
 
 # Model
@@ -184,4 +184,6 @@ agent = Agent(
 )
 
 # Run a query
-agent.run("What are the components of a startup's stock incentive equity plan?")
+agent.run(
+    "What are the components of a startup's stock incentive equity plan?"
+)
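A note on the memory API exercised above: QdrantMemory keeps the raw text in the point payload next to its embedding and retrieves memories by cosine similarity. A minimal round trip, assuming the add()/query() signatures shown in this diff, looks like the sketch below; the constant vector is only a stand-in for a real embedding from your model provider.

# Hypothetical usage sketch, not part of the commit.
memory = QdrantMemory(collection_name="agent_memories", vector_size=1536)

embedding = [0.1] * 1536  # placeholder; use a real 1536-dim embedding in practice
memory_id = memory.add(
    text="Q4 revenue grew 12% quarter over quarter.",
    embedding=embedding,
    metadata={"source": "finance_report"},
)

# Returns a List[Dict] of stored payloads whose cosine score clears the threshold.
matches = memory.query(query_embedding=embedding, limit=5, score_threshold=0.7)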

@@ -14,39 +14,71 @@ from prometheus_client import Counter, Histogram, start_http_server
 from pydantic import BaseModel, Field, validator
 from tenacity import retry, stop_after_attempt, wait_exponential
 
-from swarms.prompts.finance_agent_sys_prompt import FINANCIAL_AGENT_SYS_PROMPT
+from swarms.prompts.finance_agent_sys_prompt import (
+    FINANCIAL_AGENT_SYS_PROMPT,
+)
 from swarms.structs.agent import Agent
 
 # Enhanced metrics
-TASK_COUNTER = Counter('swarm_tasks_total', 'Total number of tasks processed')
-TASK_LATENCY = Histogram('swarm_task_duration_seconds', 'Task processing duration')
-TASK_FAILURES = Counter('swarm_task_failures_total', 'Total number of task failures')
-AGENT_ERRORS = Counter('swarm_agent_errors_total', 'Total number of agent errors')
+TASK_COUNTER = Counter(
+    "swarm_tasks_total", "Total number of tasks processed"
+)
+TASK_LATENCY = Histogram(
+    "swarm_task_duration_seconds", "Task processing duration"
+)
+TASK_FAILURES = Counter(
+    "swarm_task_failures_total", "Total number of task failures"
+)
+AGENT_ERRORS = Counter(
+    "swarm_agent_errors_total", "Total number of agent errors"
+)
 
 # Define types using Literal
 TaskStatus = Literal["pending", "processing", "completed", "failed"]
 TaskPriority = Literal["low", "medium", "high", "critical"]
 
+
 class SecurityConfig(BaseModel):
     """Security configuration for the swarm"""
-    encryption_key: str = Field(..., description="Encryption key for sensitive data")
-    tls_cert_path: Optional[str] = Field(None, description="Path to TLS certificate")
-    tls_key_path: Optional[str] = Field(None, description="Path to TLS private key")
-    auth_token: Optional[str] = Field(None, description="Authentication token")
-    max_message_size: int = Field(default=1048576, description="Maximum message size in bytes")
-    rate_limit: int = Field(default=100, description="Maximum tasks per minute")
+
+    encryption_key: str = Field(
+        ..., description="Encryption key for sensitive data"
+    )
+    tls_cert_path: Optional[str] = Field(
+        None, description="Path to TLS certificate"
+    )
+    tls_key_path: Optional[str] = Field(
+        None, description="Path to TLS private key"
+    )
+    auth_token: Optional[str] = Field(
+        None, description="Authentication token"
+    )
+    max_message_size: int = Field(
+        default=1048576, description="Maximum message size in bytes"
+    )
+    rate_limit: int = Field(
+        default=100, description="Maximum tasks per minute"
+    )
 
-    @validator('encryption_key')
+    @validator("encryption_key")
     def validate_encryption_key(cls, v):
         if len(v) < 32:
-            raise ValueError("Encryption key must be at least 32 bytes long")
+            raise ValueError(
+                "Encryption key must be at least 32 bytes long"
+            )
         return v
 
+
 class Task(BaseModel):
     """Enhanced task model with additional metadata and validation"""
-    task_id: str = Field(..., description="Unique identifier for the task")
-    description: str = Field(..., description="Task description or instructions")
+
+    task_id: str = Field(
+        ..., description="Unique identifier for the task"
+    )
+    description: str = Field(
+        ..., description="Task description or instructions"
+    )
     output_type: Literal["string", "json", "file"] = Field("string")
     status: TaskStatus = Field(default="pending")
     priority: TaskPriority = Field(default="medium")
@@ -56,20 +88,19 @@ class Task(BaseModel):
     retry_count: int = Field(default=0)
     metadata: Dict[str, Any] = Field(default_factory=dict)
 
-    @validator('task_id')
+    @validator("task_id")
     def validate_task_id(cls, v):
         if not v.strip():
             raise ValueError("task_id cannot be empty")
         return v
 
     class Config:
-        json_encoders = {
-            datetime: lambda v: v.isoformat()
-        }
+        json_encoders = {datetime: lambda v: v.isoformat()}
 
 
 class TaskResult(BaseModel):
     """Model for task execution results"""
+
     task_id: str
     status: TaskStatus
     result: Any
@@ -121,14 +152,22 @@ class SecurePulsarSwarm:
         self.task_timeout = task_timeout
 
         # Initialize encryption
-        self.cipher_suite = Fernet(security_config.encryption_key.encode())
+        self.cipher_suite = Fernet(
+            security_config.encryption_key.encode()
+        )
 
         # Setup metrics server
         start_http_server(metrics_port)
 
         # Initialize Pulsar client with security settings
         client_config = {
-            "authentication": None if not security_config.auth_token else pulsar.AuthenticationToken(security_config.auth_token),
+            "authentication": (
+                None
+                if not security_config.auth_token
+                else pulsar.AuthenticationToken(
+                    security_config.auth_token
+                )
+            ),
             "operation_timeout_seconds": 30,
             "connection_timeout_seconds": 30,
             "use_tls": bool(security_config.tls_cert_path),
@@ -145,7 +184,9 @@ class SecurePulsarSwarm:
         self.last_execution_time = time.time()
         self.execution_count = 0
 
-        logger.info(f"Secure Pulsar Swarm '{self.name}' initialized with enhanced security features")
+        logger.info(
+            f"Secure Pulsar Swarm '{self.name}' initialized with enhanced security features"
+        )
 
     def _create_producer(self):
         """Create a secure producer with retry logic"""
@@ -155,7 +196,7 @@ class SecurePulsarSwarm:
             compression_type=pulsar.CompressionType.LZ4,
             block_if_queue_full=True,
             batching_enabled=True,
-            batching_max_publish_delay_ms=10
+            batching_max_publish_delay_ms=10,
         )
 
     def _create_consumer(self):
@@ -166,7 +207,7 @@ class SecurePulsarSwarm:
             consumer_type=pulsar.ConsumerType.Shared,
             message_listener=None,
             receiver_queue_size=1000,
-            max_total_receiver_queue_size_across_partitions=50000
+            max_total_receiver_queue_size_across_partitions=50000,
         )
 
     def _encrypt_message(self, data: str) -> bytes:
@@ -177,14 +218,19 @@ class SecurePulsarSwarm:
         """Decrypt message data"""
         return self.cipher_suite.decrypt(data).decode()
 
-    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+    )
     def publish_task(self, task: Task) -> None:
         """Publish a task with enhanced security and reliability"""
         try:
             # Validate message size
             task_data = task.json()
             if len(task_data) > self.security_config.max_message_size:
-                raise ValueError("Task data exceeds maximum message size")
+                raise ValueError(
+                    "Task data exceeds maximum message size"
+                )
 
             # Rate limiting
             current_time = time.time()
@@ -192,7 +238,10 @@ class SecurePulsarSwarm:
                 self.execution_count = 0
                 self.last_execution_time = current_time
 
-            if self.execution_count >= self.security_config.rate_limit:
+            if (
+                self.execution_count
+                >= self.security_config.rate_limit
+            ):
                 raise ValueError("Rate limit exceeded")
 
             # Encrypt and publish
@@ -200,11 +249,15 @@ class SecurePulsarSwarm:
             message_id = self.producer.send(encrypted_data)
             self.execution_count += 1
 
-            logger.info(f"Task {task.task_id} published successfully with message ID {message_id}")
+            logger.info(
+                f"Task {task.task_id} published successfully with message ID {message_id}"
+            )
 
         except Exception as e:
             TASK_FAILURES.inc()
-            logger.error(f"Error publishing task {task.task_id}: {str(e)}")
+            logger.error(
+                f"Error publishing task {task.task_id}: {str(e)}"
+            )
             raise
 
     async def _process_task(self, task: Task) -> TaskResult:
@@ -219,7 +272,9 @@ class SecurePulsarSwarm:
             self.agents.append(agent)
 
             # Execute task with timeout
-            future = self.executor.submit(agent.run, task.description)
+            future = self.executor.submit(
+                agent.run, task.description
+            )
             result = future.result(timeout=self.task_timeout)
 
             # Handle different output types
@@ -239,8 +294,9 @@ class SecurePulsarSwarm:
                 task_id=task.task_id,
                 status="completed",
                 result=result,
-                execution_time=time.time() - task.started_at.timestamp(),
-                agent_id=agent.agent_name
+                execution_time=time.time()
+                - task.started_at.timestamp(),
+                agent_id=agent.agent_name,
             )
 
         except TimeoutError:
@@ -253,14 +309,17 @@ class SecurePulsarSwarm:
                 status="failed",
                 result=None,
                 error_message=error_msg,
-                execution_time=time.time() - task.started_at.timestamp(),
-                agent_id=agent.agent_name
+                execution_time=time.time()
+                - task.started_at.timestamp(),
+                agent_id=agent.agent_name,
             )
 
         except Exception as e:
             TASK_FAILURES.inc()
             AGENT_ERRORS.inc()
-            error_msg = f"Error processing task {task.task_id}: {str(e)}"
+            error_msg = (
+                f"Error processing task {task.task_id}: {str(e)}"
+            )
             logger.error(error_msg)
             task.status = "failed"
             return TaskResult(
@@ -268,8 +327,9 @@ class SecurePulsarSwarm:
                 status="failed",
                 result=None,
                 error_message=error_msg,
-                execution_time=time.time() - task.started_at.timestamp(),
-                agent_id=agent.agent_name
+                execution_time=time.time()
+                - task.started_at.timestamp(),
+                agent_id=agent.agent_name,
            )
 
     async def consume_tasks(self):
@@ -281,7 +341,9 @@ class SecurePulsarSwarm:
             try:
                 # Circuit breaker pattern
                 if consecutive_failures >= 5:
-                    logger.warning(f"Circuit breaker triggered. Waiting {backoff_time} seconds")
+                    logger.warning(
+                        f"Circuit breaker triggered. Waiting {backoff_time} seconds"
+                    )
                     await asyncio.sleep(backoff_time)
                     backoff_time = min(backoff_time * 2, 60)
                     continue
@@ -291,7 +353,9 @@ class SecurePulsarSwarm:
                 try:
                     # Decrypt and process message
-                    decrypted_data = self._decrypt_message(message.data())
+                    decrypted_data = self._decrypt_message(
+                        message.data()
+                    )
                     task_data = json.loads(decrypted_data)
                     task = Task(**task_data)
 
@@ -306,13 +370,21 @@ class SecurePulsarSwarm:
                     else:
                         if task.retry_count < self.retry_attempts:
                             task.retry_count += 1
-                            await self.consumer.negative_acknowledge(message)
+                            await self.consumer.negative_acknowledge(
+                                message
+                            )
                         else:
-                            await self.consumer.acknowledge_async(message)
-                            logger.error(f"Task {task.task_id} failed after {self.retry_attempts} attempts")
+                            await self.consumer.acknowledge_async(
+                                message
+                            )
+                            logger.error(
+                                f"Task {task.task_id} failed after {self.retry_attempts} attempts"
+                            )
 
                 except Exception as e:
-                    logger.error(f"Error processing message: {str(e)}")
+                    logger.error(
+                        f"Error processing message: {str(e)}"
+                    )
                     await self.consumer.negative_acknowledge(message)
                     consecutive_failures += 1
 
@@ -336,6 +408,7 @@ class SecurePulsarSwarm:
         except Exception as e:
             logger.error(f"Error during cleanup: {str(e)}")
 
+
 if __name__ == "__main__":
     # Example usage with security configuration
     security_config = SecurityConfig(
@@ -344,7 +417,7 @@ if __name__ == "__main__":
        tls_key_path="/path/to/key.pem",
        auth_token="your-auth-token",
        max_message_size=1048576,
-       rate_limit=100
+       rate_limit=100,
    )

    # Agent factory function
@@ -383,7 +456,7 @@ if __name__ == "__main__":
         max_workers=5,
         retry_attempts=3,
         task_timeout=300,
-        metrics_port=8000
+        metrics_port=8000,
     ) as swarm:
         # Example task
         task = Task(
@@ -391,7 +464,10 @@ if __name__ == "__main__":
             description="Analyze Q4 financial reports",
             output_type="json",
             priority="high",
-            metadata={"department": "finance", "requester": "john.doe@company.com"}
+            metadata={
+                "department": "finance",
+                "requester": "john.doe@company.com",
+            },
         )
 
         # Run the swarm
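On the SecurityConfig used in the example block above: the validator only checks that encryption_key is at least 32 characters long, but the swarm also feeds that same key straight into Fernet(), which requires a urlsafe base64-encoded 32-byte key. A minimal sketch of producing a key that satisfies both, assuming the cryptography package that provides Fernet:

# Hypothetical setup sketch, not part of the commit.
from cryptography.fernet import Fernet

key = Fernet.generate_key().decode()  # 44-char urlsafe base64 encoding of 32 random bytes
security_config = SecurityConfig(encryption_key=key)  # passes the >=32-character validator
# The swarm later builds Fernet(security_config.encryption_key.encode()) from this key.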

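The metric objects reformatted near the top of this file are plain prometheus_client primitives. A short, illustrative sketch of how counters and histograms like these are typically driven (the handle_task function below is hypothetical, not the swarm's actual call site):

from prometheus_client import Counter, Histogram, start_http_server

TASK_COUNTER = Counter("swarm_tasks_total", "Total number of tasks processed")
TASK_LATENCY = Histogram("swarm_task_duration_seconds", "Task processing duration")

start_http_server(8000)  # metrics are then served at http://localhost:8000/metrics


def handle_task(task):  # hypothetical worker
    TASK_COUNTER.inc()  # count every task seen
    with TASK_LATENCY.time():  # record wall-clock duration into the histogram
        ...  # process the task here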
@@ -571,9 +571,7 @@ class Agent:
         )
 
         # Telemetry Processor to log agent data
-        threading.Thread(
-            target=log_agent_data(self.to_dict())
-        ).start()
+        log_agent_data(self.to_dict())
 
         if self.llm is None and self.model_name is not None:
             self.llm = self.llm_handling()
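The telemetry change above is a real bug fix, not just a style edit: threading.Thread(target=log_agent_data(self.to_dict())) calls log_agent_data immediately on the current thread and hands its return value (normally None) to Thread as the target, so the spawned thread does nothing. The sketch below contrasts the patterns; the commit keeps the call synchronous, which is the simplest correct form.

# Illustrative sketch with a stand-in function, not the library's real telemetry call.
import threading


def log_agent_data(payload):
    print("logged:", payload)


data = {"agent_name": "example"}

# Old, buggy pattern: log_agent_data runs here, synchronously, and the thread
# is started with target=None, so it exits without doing anything.
threading.Thread(target=log_agent_data(data)).start()

# Correct background pattern: pass the callable and its arguments separately.
threading.Thread(target=log_agent_data, args=(data,)).start()

# What the new code does: a plain synchronous call.
log_agent_data(data)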

@@ -1,2 +0,0 @@
-agent_name,system_prompt,model_name,max_loops,autosave,dashboard,verbose,dynamic_temperature,saved_state_path,user_name,retry_attempts,context_length,return_step_meta,output_type,streaming
-Financial-Analysis-Agent,You are a financial expert...,gpt-4o-mini,1,True,False,True,True,finance_agent.json,swarms_corp,3,200000,False,string,False

@@ -1,7 +1,11 @@
-from typing import List, Dict, Optional, TypedDict, Literal, Union, Any
+from typing import (
+    List,
+    Dict,
+    TypedDict,
+    Any,
+)
 from dataclasses import dataclass
 import csv
-import os
 from pathlib import Path
 import logging
 from enum import Enum
@@ -11,8 +15,10 @@ from swarms import Agent
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
 class ModelName(str, Enum):
     """Valid model names for swarms agents"""
+
     GPT4O = "gpt-4o"
     GPT4O_MINI = "gpt-4o-mini"
     GPT4 = "gpt-4"
@@ -30,8 +36,10 @@ class ModelName(str, Enum):
         """Check if model name is valid"""
         return model_name in cls.get_model_names()
 
+
 class AgentConfigDict(TypedDict):
     """TypedDict for agent configuration"""
+
     agent_name: str
     system_prompt: str
     model_name: str  # Using str instead of ModelName for flexibility
@@ -48,9 +56,11 @@ class AgentConfigDict(TypedDict):
     output_type: str
     streaming: bool
 
+
 @dataclass
 class AgentValidationError(Exception):
     """Custom exception for agent validation errors"""
+
     message: str
     field: str
     value: Any
@@ -58,6 +68,7 @@ class AgentValidationError(Exception):
     def __str__(self) -> str:
         return f"Validation error in field '{self.field}': {self.message}. Got value: {self.value}"
 
+
 class AgentValidator:
     """Validates agent configuration data"""
 
@@ -66,46 +77,75 @@ class AgentValidator:
         """Validate and convert agent configuration"""
         try:
             # Validate model name
-            model_name = str(config['model_name'])
+            model_name = str(config["model_name"])
             if not ModelName.is_valid_model(model_name):
                 valid_models = ModelName.get_model_names()
                 raise AgentValidationError(
                     f"Invalid model name. Must be one of: {', '.join(valid_models)}",
                     "model_name",
-                    model_name
+                    model_name,
                 )
 
             # Convert types with error handling
             validated_config: AgentConfigDict = {
-                'agent_name': str(config.get('agent_name', '')),
-                'system_prompt': str(config.get('system_prompt', '')),
-                'model_name': model_name,
-                'max_loops': int(config.get('max_loops', 1)),
-                'autosave': bool(str(config.get('autosave', True)).lower() == 'true'),
-                'dashboard': bool(str(config.get('dashboard', False)).lower() == 'true'),
-                'verbose': bool(str(config.get('verbose', True)).lower() == 'true'),
-                'dynamic_temperature': bool(str(config.get('dynamic_temperature', True)).lower() == 'true'),
-                'saved_state_path': str(config.get('saved_state_path', '')),
-                'user_name': str(config.get('user_name', 'default_user')),
-                'retry_attempts': int(config.get('retry_attempts', 3)),
-                'context_length': int(config.get('context_length', 200000)),
-                'return_step_meta': bool(str(config.get('return_step_meta', False)).lower() == 'true'),
-                'output_type': str(config.get('output_type', 'string')),
-                'streaming': bool(str(config.get('streaming', False)).lower() == 'true')
+                "agent_name": str(config.get("agent_name", "")),
+                "system_prompt": str(config.get("system_prompt", "")),
+                "model_name": model_name,
+                "max_loops": int(config.get("max_loops", 1)),
+                "autosave": bool(
+                    str(config.get("autosave", True)).lower()
+                    == "true"
+                ),
+                "dashboard": bool(
+                    str(config.get("dashboard", False)).lower()
+                    == "true"
+                ),
+                "verbose": bool(
+                    str(config.get("verbose", True)).lower() == "true"
+                ),
+                "dynamic_temperature": bool(
+                    str(
+                        config.get("dynamic_temperature", True)
+                    ).lower()
+                    == "true"
+                ),
+                "saved_state_path": str(
+                    config.get("saved_state_path", "")
+                ),
+                "user_name": str(
+                    config.get("user_name", "default_user")
+                ),
+                "retry_attempts": int(
+                    config.get("retry_attempts", 3)
+                ),
+                "context_length": int(
+                    config.get("context_length", 200000)
+                ),
+                "return_step_meta": bool(
+                    str(config.get("return_step_meta", False)).lower()
+                    == "true"
+                ),
+                "output_type": str(
+                    config.get("output_type", "string")
+                ),
+                "streaming": bool(
+                    str(config.get("streaming", False)).lower()
+                    == "true"
+                ),
             }
 
             return validated_config
 
         except (ValueError, KeyError) as e:
             raise AgentValidationError(
-                str(e),
-                str(e.__class__.__name__),
-                str(config)
+                str(e), str(e.__class__.__name__), str(config)
             )
 
+
 @dataclass
 class AgentCSV:
     """Class to manage agents through CSV with type safety"""
+
     csv_path: Path
 
     def __post_init__(self) -> None:
@@ -117,10 +157,21 @@ class AgentCSV:
     def headers(self) -> List[str]:
         """CSV headers for agent configuration"""
         return [
-            "agent_name", "system_prompt", "model_name", "max_loops",
-            "autosave", "dashboard", "verbose", "dynamic_temperature",
-            "saved_state_path", "user_name", "retry_attempts", "context_length",
-            "return_step_meta", "output_type", "streaming"
+            "agent_name",
+            "system_prompt",
+            "model_name",
+            "max_loops",
+            "autosave",
+            "dashboard",
+            "verbose",
+            "dynamic_temperature",
+            "saved_state_path",
+            "user_name",
+            "retry_attempts",
+            "context_length",
+            "return_step_meta",
+            "output_type",
+            "streaming",
         ]
 
     def create_agent_csv(self, agents: List[Dict[str, Any]]) -> None:
@@ -128,65 +179,96 @@ class AgentCSV:
         validated_agents = []
         for agent in agents:
             try:
-                validated_config = AgentValidator.validate_config(agent)
+                validated_config = AgentValidator.validate_config(
+                    agent
+                )
                 validated_agents.append(validated_config)
             except AgentValidationError as e:
-                logger.error(f"Validation error for agent {agent.get('agent_name', 'unknown')}: {e}")
+                logger.error(
+                    f"Validation error for agent {agent.get('agent_name', 'unknown')}: {e}"
+                )
                 raise
 
-        with open(self.csv_path, 'w', newline='') as f:
+        with open(self.csv_path, "w", newline="") as f:
             writer = csv.DictWriter(f, fieldnames=self.headers)
             writer.writeheader()
             writer.writerows(validated_agents)
 
-        logger.info(f"Created CSV with {len(validated_agents)} agents at {self.csv_path}")
+        logger.info(
+            f"Created CSV with {len(validated_agents)} agents at {self.csv_path}"
+        )
 
     def load_agents(self) -> List[Agent]:
         """Load and create agents from CSV with validation"""
         if not self.csv_path.exists():
-            raise FileNotFoundError(f"CSV file not found at {self.csv_path}")
+            raise FileNotFoundError(
+                f"CSV file not found at {self.csv_path}"
+            )
 
         agents: List[Agent] = []
-        with open(self.csv_path, 'r') as f:
+        with open(self.csv_path, "r") as f:
             reader = csv.DictReader(f)
             for row in reader:
                 try:
-                    validated_config = AgentValidator.validate_config(row)
+                    validated_config = AgentValidator.validate_config(
+                        row
+                    )
                     agent = Agent(
-                        agent_name=validated_config['agent_name'],
-                        system_prompt=validated_config['system_prompt'],
-                        model_name=validated_config['model_name'],
-                        max_loops=validated_config['max_loops'],
-                        autosave=validated_config['autosave'],
-                        dashboard=validated_config['dashboard'],
-                        verbose=validated_config['verbose'],
-                        dynamic_temperature_enabled=validated_config['dynamic_temperature'],
-                        saved_state_path=validated_config['saved_state_path'],
-                        user_name=validated_config['user_name'],
-                        retry_attempts=validated_config['retry_attempts'],
-                        context_length=validated_config['context_length'],
-                        return_step_meta=validated_config['return_step_meta'],
-                        output_type=validated_config['output_type'],
-                        streaming_on=validated_config['streaming']
+                        agent_name=validated_config["agent_name"],
+                        system_prompt=validated_config[
+                            "system_prompt"
+                        ],
+                        model_name=validated_config["model_name"],
+                        max_loops=validated_config["max_loops"],
+                        autosave=validated_config["autosave"],
+                        dashboard=validated_config["dashboard"],
+                        verbose=validated_config["verbose"],
+                        dynamic_temperature_enabled=validated_config[
+                            "dynamic_temperature"
+                        ],
+                        saved_state_path=validated_config[
+                            "saved_state_path"
+                        ],
+                        user_name=validated_config["user_name"],
+                        retry_attempts=validated_config[
+                            "retry_attempts"
+                        ],
+                        context_length=validated_config[
+                            "context_length"
+                        ],
+                        return_step_meta=validated_config[
+                            "return_step_meta"
+                        ],
+                        output_type=validated_config["output_type"],
+                        streaming_on=validated_config["streaming"],
                     )
                     agents.append(agent)
                 except AgentValidationError as e:
-                    logger.error(f"Skipping invalid agent configuration: {e}")
+                    logger.error(
+                        f"Skipping invalid agent configuration: {e}"
+                    )
                     continue
 
-        logger.info(f"Loaded {len(agents)} agents from {self.csv_path}")
+        logger.info(
+            f"Loaded {len(agents)} agents from {self.csv_path}"
+        )
         return agents
 
     def add_agent(self, agent_config: Dict[str, Any]) -> None:
         """Add a new validated agent configuration to CSV"""
-        validated_config = AgentValidator.validate_config(agent_config)
+        validated_config = AgentValidator.validate_config(
+            agent_config
+        )
 
-        with open(self.csv_path, 'a', newline='') as f:
+        with open(self.csv_path, "a", newline="") as f:
             writer = csv.DictWriter(f, fieldnames=self.headers)
             writer.writerow(validated_config)
 
-        logger.info(f"Added new agent {validated_config['agent_name']} to {self.csv_path}")
+        logger.info(
+            f"Added new agent {validated_config['agent_name']} to {self.csv_path}"
+        )
+
 
 # Example usage
 if __name__ == "__main__":
@@ -207,7 +289,7 @@ if __name__ == "__main__":
             "context_length": 200000,
             "return_step_meta": False,
             "output_type": "string",
-            "streaming": False
+            "streaming": False,
         }
     ]
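Taken together, the pieces in this file form a CSV round trip: validate plain config dicts, write them with create_agent_csv, then rebuild Agent objects with load_agents. A short sketch, assuming the `agents` list of config dicts from the example block above is in scope:

# Hypothetical usage sketch, not part of the commit.
from pathlib import Path

manager = AgentCSV(csv_path=Path("agents.csv"))
manager.create_agent_csv(agents)  # validates each config dict, then writes the CSV
loaded = manager.load_agents()  # re-validates every row and constructs Agent objects

for agent in loaded:
    print(agent.agent_name)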

@@ -11,7 +11,7 @@ Todo:
 import os
 import subprocess
 import uuid
-from datetime import UTC, datetime
+from datetime import datetime
 from typing import List, Literal, Optional
 
 from loguru import logger
@@ -241,7 +241,7 @@ class MultiAgentRouter:
             dict: A dictionary containing the routing result, including the selected agent, reasoning, and response.
         """
         try:
-            start_time = datetime.now(UTC)
+            start_time = datetime.now()
 
             # Get boss decision using function calling
             boss_response = self.function_caller.get_completion(task)
@@ -259,7 +259,7 @@ class MultiAgentRouter:
             final_task = boss_response.modified_task or task
 
             # Execute the task with the selected agent if enabled
-            execution_start = datetime.now(UTC)
+            execution_start = datetime.now()
             agent_response = None
             execution_time = 0
 
@@ -267,20 +267,18 @@ class MultiAgentRouter:
                 # Use the agent's run method directly
                 agent_response = selected_agent.run(final_task)
                 execution_time = (
-                    datetime.now(UTC) - execution_start
+                    datetime.now() - execution_start
                 ).total_seconds()
             else:
                 logger.info(
                     "Task execution skipped (execute_task=False)"
                 )
 
-            total_time = (
-                datetime.now(UTC) - start_time
-            ).total_seconds()
+            total_time = (datetime.now() - start_time).total_seconds()
 
             result = {
                 "id": str(uuid.uuid4()),
-                "timestamp": datetime.now(UTC).isoformat(),
+                "timestamp": datetime.now().isoformat(),
                 "task": {
                     "original": task,
                     "modified": (
