You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/swarms/structs/csv_to_agent.py

284 lines
9.2 KiB

import csv
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    TypedDict,
)

from swarms.structs.agent import Agent
class ModelName(str, Enum):
    """Closed set of model identifiers accepted for swarms agents.

    Each member's value is the identifier string used in configuration
    files (e.g. the ``model_name`` CSV column).
    """

    GPT4O = "gpt-4o"
    GPT4O_MINI = "gpt-4o-mini"
    GPT4 = "gpt-4"
    GPT35_TURBO = "gpt-3.5-turbo"
    CLAUDE = "claude-v1"
    CLAUDE2 = "claude-2"

    @classmethod
    def get_model_names(cls) -> List[str]:
        """Return the string value of every member, in declaration order."""
        return [entry.value for entry in cls]

    @classmethod
    def is_valid_model(cls, model_name: str) -> bool:
        """Return True when *model_name* equals one of the member values."""
        return any(model_name == known for known in cls.get_model_names())
class AgentConfigDict(TypedDict):
    """Normalised agent configuration returned by AgentValidator.validate_config.

    The keys match the CSV columns listed in ``AgentLoader.headers`` and are
    forwarded to the ``Agent`` constructor by ``AgentLoader._create_agent``
    (two of them under different keyword names, noted below).
    """

    agent_name: str
    system_prompt: str
    model_name: str  # Using str instead of ModelName for flexibility; the value is still checked against ModelName values
    max_loops: int
    autosave: bool
    dashboard: bool
    verbose: bool
    dynamic_temperature: bool  # forwarded to Agent as dynamic_temperature_enabled
    saved_state_path: str
    user_name: str
    retry_attempts: int
    context_length: int
    return_step_meta: bool
    output_type: str
    streaming: bool  # forwarded to Agent as streaming_on
@dataclass(eq=False)
class AgentValidationError(Exception):
    """Raised when an agent configuration fails validation.

    Attributes:
        message: Human-readable description of the problem.
        field: Which configuration field (or error category) the failure
            refers to.
        value: The offending value as it was received.

    ``eq=False`` keeps normal exception semantics (identity equality and
    hashing). A plain ``@dataclass`` would generate a field-wise ``__eq__``
    and set ``__hash__ = None``, making instances unhashable, which breaks any
    code that stores exceptions in sets or uses them as dict keys.
    """

    message: str
    field: str
    value: Any

    def __str__(self) -> str:
        return f"Validation error in field '{self.field}': {self.message}. Got value: {self.value}"
class AgentValidator:
    """Validate raw agent configuration mappings and coerce their values.

    Inputs come from ``csv.DictReader`` rows (string values, with ``None`` for
    missing trailing cells) or from parsed JSON (native types, possibly
    ``null``). Both are normalised into an ``AgentConfigDict``.
    """

    @staticmethod
    def _raw(config: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``config[key]``, or *default* when the key is absent or None.

        ``None`` counts as "not provided": csv.DictReader uses it for missing
        trailing cells and JSON uses it for ``null``. Passing it on would make
        ``int(None)`` raise TypeError and ``str(None)`` produce ``"None"``.
        """
        value = config.get(key)
        return default if value is None else value

    @staticmethod
    def _as_str(config: Dict[str, Any], key: str, default: str) -> str:
        """Coerce ``config[key]`` to str, falling back to *default*."""
        return str(AgentValidator._raw(config, key, default))

    @staticmethod
    def _as_bool(config: Dict[str, Any], key: str, default: bool) -> bool:
        """Coerce ``config[key]`` to bool.

        Only values whose ``str()`` is case-insensitively ``"true"`` (e.g.
        ``True``, ``"True"``, ``"true"``) map to True. Any other provided value
        maps to False.
        """
        return str(AgentValidator._raw(config, key, default)).lower() == "true"

    @staticmethod
    def _as_int(config: Dict[str, Any], key: str, default: int) -> int:
        """Coerce ``config[key]`` to int, falling back to *default*.

        Raises:
            AgentValidationError: if the value cannot be converted. ``field``
                is set to *key* so the caller knows which column is wrong.
        """
        value = AgentValidator._raw(config, key, default)
        try:
            return int(value)
        except (TypeError, ValueError) as e:
            raise AgentValidationError(
                f"Expected an integer ({e})", key, value
            ) from e

    @staticmethod
    def validate_config(config: Dict[str, Any]) -> AgentConfigDict:
        """Validate *config* and return a fully populated ``AgentConfigDict``.

        ``model_name`` is required and must be one of the ``ModelName`` values.
        Every other key is optional. Absent or ``None`` values fall back to the
        defaults used below.

        Raises:
            AgentValidationError: if *config* is not a mapping, ``model_name``
                is missing or unknown, or an integer field is not convertible.
        """
        from collections.abc import Mapping

        if not isinstance(config, Mapping):
            raise AgentValidationError(
                "Expected a mapping of configuration fields",
                "config",
                config,
            )

        raw_model = config.get("model_name")
        if raw_model is None:
            raise AgentValidationError(
                "Missing required field", "model_name", raw_model
            )
        model_name = str(raw_model)
        if not ModelName.is_valid_model(model_name):
            valid_models = ModelName.get_model_names()
            raise AgentValidationError(
                f"Invalid model name. Must be one of: {', '.join(valid_models)}",
                "model_name",
                model_name,
            )

        as_str = AgentValidator._as_str
        as_int = AgentValidator._as_int
        as_bool = AgentValidator._as_bool
        validated_config: AgentConfigDict = {
            "agent_name": as_str(config, "agent_name", ""),
            "system_prompt": as_str(config, "system_prompt", ""),
            "model_name": model_name,
            "max_loops": as_int(config, "max_loops", 1),
            "autosave": as_bool(config, "autosave", True),
            "dashboard": as_bool(config, "dashboard", False),
            "verbose": as_bool(config, "verbose", True),
            "dynamic_temperature": as_bool(
                config, "dynamic_temperature", True
            ),
            "saved_state_path": as_str(config, "saved_state_path", ""),
            "user_name": as_str(config, "user_name", "default_user"),
            "retry_attempts": as_int(config, "retry_attempts", 3),
            "context_length": as_int(config, "context_length", 200000),
            "return_step_meta": as_bool(config, "return_step_meta", False),
            "output_type": as_str(config, "output_type", "string"),
            "streaming": as_bool(config, "streaming", False),
        }
        return validated_config
@dataclass
class AgentLoader:
    """Create and load swarms ``Agent`` instances from a CSV (or JSON) file.

    Attributes:
        csv_path: Path of the agents CSV file. A ``str`` (or other path-like)
            is converted to ``Path`` after construction. JSON loading reads the
            same path with its suffix replaced by ``.json``.

    ``@dataclass`` generates ``__init__(csv_path)`` and triggers
    ``__post_init__``. Without it the class cannot be constructed with a path.
    """

    csv_path: Path

    def __post_init__(self) -> None:
        """Convert ``csv_path`` to a ``Path`` object if necessary."""
        if not isinstance(self.csv_path, Path):
            self.csv_path = Path(self.csv_path)

    @property
    def headers(self) -> List[str]:
        """CSV column names, in the order they are written."""
        return [
            "agent_name",
            "system_prompt",
            "model_name",
            "max_loops",
            "autosave",
            "dashboard",
            "verbose",
            "dynamic_temperature",
            "saved_state_path",
            "user_name",
            "retry_attempts",
            "context_length",
            "return_step_meta",
            "output_type",
            "streaming",
        ]

    def create_agent_csv(self, agents: List[Dict[str, Any]]) -> None:
        """Validate *agents* and write them to ``csv_path``, overwriting it.

        Validation is all-or-nothing. The first invalid agent is reported and
        its error is re-raised before the file is opened.

        Raises:
            AgentValidationError: if any agent configuration is invalid.
        """
        validated_agents: List[AgentConfigDict] = []
        for agent in agents:
            try:
                validated_agents.append(
                    AgentValidator.validate_config(agent)
                )
            except AgentValidationError as e:
                # Guard the lookup: a non-mapping entry has no .get().
                name = (
                    agent.get("agent_name", "unknown")
                    if hasattr(agent, "get")
                    else "unknown"
                )
                print(f"Validation error for agent {name}: {e}")
                raise

        # newline="" is required by the csv module. An explicit utf-8 keeps
        # non-ASCII prompts independent of the platform default encoding.
        with open(self.csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=self.headers)
            writer.writeheader()
            writer.writerows(validated_agents)
        print(
            f"Created CSV with {len(validated_agents)} agents at {self.csv_path}"
        )

    def load_agents(self, file_type: str = "csv") -> List[Agent]:
        """Load agents from ``csv_path`` ("csv") or its ``.json`` sibling ("json").

        Invalid configurations are skipped with a printed message.

        Raises:
            FileNotFoundError: if the source file does not exist.
            ValueError: if *file_type* is neither "csv" nor "json", or the
                JSON file does not contain a list.
        """
        if file_type == "csv":
            if not self.csv_path.exists():
                raise FileNotFoundError(
                    f"CSV file not found at {self.csv_path}"
                )
            return self._load_agents_from_csv()
        elif file_type == "json":
            return self._load_agents_from_json()
        else:
            raise ValueError(
                "Unsupported file type. Use 'csv' or 'json'."
            )

    def _build_agents(
        self, configs: Iterable[Dict[str, Any]]
    ) -> List[Agent]:
        """Validate each raw config and create an Agent, skipping invalid ones."""
        agents: List[Agent] = []
        for config in configs:
            try:
                validated_config = AgentValidator.validate_config(config)
            except AgentValidationError as e:
                print(f"Skipping invalid agent configuration: {e}")
                continue
            agents.append(self._create_agent(validated_config))
        return agents

    def _load_agents_from_csv(self) -> List[Agent]:
        """Load agents from the CSV file at ``csv_path``."""
        # newline="" lets the csv module handle line breaks embedded in quoted
        # fields (e.g. multi-line system prompts). utf-8-sig also strips a
        # leading BOM, which would otherwise corrupt the first header name.
        # The reader is lazy, so all rows are consumed inside the `with`.
        with open(
            self.csv_path, "r", newline="", encoding="utf-8-sig"
        ) as f:
            agents = self._build_agents(csv.DictReader(f))
        print(f"Loaded {len(agents)} agents from {self.csv_path}")
        return agents

    def _load_agents_from_json(self) -> List[Agent]:
        """Load agents from ``csv_path`` with its suffix replaced by ``.json``.

        Raises:
            FileNotFoundError: if the JSON file does not exist.
            ValueError: if the JSON document is not a list of configurations.
        """
        import json

        json_path = self.csv_path.with_suffix(".json")
        if not json_path.exists():
            raise FileNotFoundError(f"JSON file not found at {json_path}")
        with open(json_path, "r", encoding="utf-8") as f:
            agents_data = json.load(f)
        if not isinstance(agents_data, list):
            raise ValueError(
                f"Expected a JSON array of agent configurations in {json_path}"
            )
        agents = self._build_agents(agents_data)
        print(f"Loaded {len(agents)} agents from {json_path}")
        return agents

    def _create_agent(
        self, validated_config: AgentConfigDict
    ) -> Agent:
        """Create an Agent instance from a validated configuration."""
        return Agent(
            agent_name=validated_config["agent_name"],
            system_prompt=validated_config["system_prompt"],
            model_name=validated_config["model_name"],
            max_loops=validated_config["max_loops"],
            autosave=validated_config["autosave"],
            dashboard=validated_config["dashboard"],
            verbose=validated_config["verbose"],
            dynamic_temperature_enabled=validated_config[
                "dynamic_temperature"
            ],
            saved_state_path=validated_config["saved_state_path"],
            user_name=validated_config["user_name"],
            retry_attempts=validated_config["retry_attempts"],
            context_length=validated_config["context_length"],
            return_step_meta=validated_config["return_step_meta"],
            output_type=validated_config["output_type"],
            streaming_on=validated_config["streaming"],
        )