You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
351 lines
12 KiB
351 lines
12 KiB
import concurrent.futures
|
|
import csv
|
|
import json
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
List,
|
|
TypedDict,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
import yaml
|
|
from litellm import model_list
|
|
from tqdm import tqdm
|
|
|
|
from swarms.schemas.swarms_api_schemas import AgentSpec
|
|
from swarms.structs.agent import Agent
|
|
|
|
# Generic placeholder for an agent configuration: constrained (via ``bound``)
# to an AgentSpec or a plain ``Dict[str, Any]`` of option names to values.
AgentConfigType = TypeVar(
    "AgentConfigType",
    bound=Union[AgentSpec, Dict[str, Any]],
)
|
|
|
|
|
|
class ModelName(str, Enum):
    """Model identifiers accepted for swarms agents.

    Every member is also a ``str``, so it compares equal to its raw value.
    """

    # NOTE(review): this member is named GPT4O but carries the value
    # "gpt-4.1" -- confirm the mismatch is intentional.
    GPT4O = "gpt-4.1"
    GPT4O_MINI = "gpt-4o-mini"
    GPT4 = "gpt-4"
    GPT35_TURBO = "gpt-3.5-turbo"
    CLAUDE = "claude-v1"
    CLAUDE2 = "claude-2"

    @classmethod
    def get_model_names(cls) -> List[str]:
        """Return the raw string value of each member, in definition order."""
        names: List[str] = []
        for member in cls:
            names.append(member.value)
        return names

    @classmethod
    def is_valid_model(cls, model_name: str) -> bool:
        """Report whether ``model_name`` equals the value of some member."""
        known_names = cls.get_model_names()
        return model_name in known_names
|
|
|
|
|
|
class FileType(str, Enum):
    """File formats in which agent configurations can be stored.

    Member values are lower-case format names without a leading dot.
    """

    CSV = "csv"
    JSON = "json"
    YAML = "yaml"
|
|
|
|
|
|
class AgentConfigDict(TypedDict):
    """Typed shape of a single, fully converted agent configuration record.

    All keys are required; values have already been coerced to the
    annotated Python types.
    """

    agent_name: str
    system_prompt: str
    model_name: str
    max_loops: int
    autosave: bool
    dashboard: bool
    verbose: bool
    # Forwarded to Agent as ``dynamic_temperature_enabled``.
    dynamic_temperature: bool
    saved_state_path: str
    user_name: str
    retry_attempts: int
    context_length: int
    return_step_meta: bool
    output_type: str
    # Forwarded to Agent as ``streaming_on``.
    streaming: bool
|
|
|
|
|
|
@dataclass
class AgentValidationError(Exception):
    """Custom exception for agent validation errors.

    The three fields are dataclass fields, so the error can be constructed
    positionally or by keyword.
    """

    # Human-readable description of the problem.
    message: str
    # Name of the offending field.
    field: str
    # The value that was rejected.
    value: Any

    def __post_init__(self) -> None:
        """Mirror the fields into ``Exception.args``.

        The dataclass-generated ``__init__`` never calls
        ``Exception.__init__``. Without this hook, keyword construction
        leaves ``args`` empty, and ``copy``/``pickle`` (which rebuild the
        exception from ``args``) fail with a missing-argument ``TypeError``.
        """
        super().__init__(self.message, self.field, self.value)

    def __str__(self) -> str:
        """Return a one-line description naming the field and the bad value."""
        return f"Validation error in field '{self.field}': {self.message}. Got value: {self.value}"
|
|
|
|
|
|
class AgentValidator:
    """Validates agent configuration data.

    All methods are static; the class only acts as a namespace.
    """

    @staticmethod
    def _as_flag(raw: Any) -> bool:
        """Return True only when ``str(raw)`` equals "true", ignoring case."""
        return str(raw).lower() == "true"

    @staticmethod
    def _ensure_known_model(model_name: str) -> None:
        """Raise AgentValidationError unless ``model_name`` matches litellm's ``model_list``.

        The check is skipped when ``model_list`` is not a non-empty list, or
        when its first entry is neither a string nor a dict.
        """
        if not (isinstance(model_list, list) and len(model_list) > 0):
            return

        first_entry = model_list[0]
        if isinstance(first_entry, str):
            # Loose match: a substring relation in either direction counts.
            # NOTE(review): an empty model name is a substring of every entry
            # and therefore passes this check -- confirm that is intended.
            is_known = any(
                model_name in model or model in model_name
                for model in model_list
            )
        elif isinstance(first_entry, dict):
            # Fallback for litellm versions that expose a list of dicts.
            is_known = any(
                model_name in model.get("model_name", "")
                for model in model_list
            )
        else:
            return

        if not is_known:
            raise AgentValidationError(
                "Invalid model name. Must be one of the supported litellm models",
                "model_name",
                model_name,
            )

    @staticmethod
    def validate_config(
        config: Union[AgentSpec, Dict[str, Any]],
    ) -> AgentConfigDict:
        """Validate and convert agent configuration from either AgentSpec or Dict.

        Args:
            config: An AgentSpec (converted with ``model_dump()``) or a
                mapping of option names to raw values. ``model_name`` is the
                only required key; every other option has a default.

        Returns:
            A dict with every AgentConfigDict key present and coerced to its
            annotated type.

        Raises:
            AgentValidationError: If the model name is not recognised, a
                required key is missing, or a value cannot be converted.
        """
        try:
            # Convert AgentSpec to dict if needed
            if isinstance(config, AgentSpec):
                config = config.model_dump()

            model_name = str(config["model_name"])
            AgentValidator._ensure_known_model(model_name)

            as_flag = AgentValidator._as_flag
            validated_config: AgentConfigDict = {
                "agent_name": str(config.get("agent_name", "")),
                "system_prompt": str(config.get("system_prompt", "")),
                "model_name": model_name,
                "max_loops": int(config.get("max_loops", 1)),
                "autosave": as_flag(config.get("autosave", True)),
                "dashboard": as_flag(config.get("dashboard", False)),
                "verbose": as_flag(config.get("verbose", True)),
                "dynamic_temperature": as_flag(
                    config.get("dynamic_temperature", True)
                ),
                "saved_state_path": str(config.get("saved_state_path", "")),
                "user_name": str(config.get("user_name", "default_user")),
                "retry_attempts": int(config.get("retry_attempts", 3)),
                "context_length": int(config.get("context_length", 200000)),
                "return_step_meta": as_flag(
                    config.get("return_step_meta", False)
                ),
                "output_type": str(config.get("output_type", "string")),
                "streaming": as_flag(config.get("streaming", False)),
            }
            return validated_config

        except (ValueError, KeyError, TypeError) as e:
            # TypeError covers inputs such as ``int(None)`` (a JSON null) or
            # a config that is not a mapping; without it those records would
            # escape as raw TypeErrors instead of validation errors.
            raise AgentValidationError(
                str(e), str(e.__class__.__name__), str(config)
            ) from e
|
|
|
|
|
|
class CSVAgentLoader:
    """Create and load agent configuration files with validation.

    Despite the name, the loader handles CSV, JSON and YAML files; the
    format is chosen from the file extension. Agents are instantiated in
    parallel on a thread pool.
    """

    def __init__(
        self, file_path: Union[str, Path], max_workers: int = 10
    ):
        """Store the target file path and the thread-pool size.

        Args:
            file_path: Location of the agent file (string or ``Path``).
            max_workers: Maximum number of worker threads used by
                ``load_agents``.
        """
        self.file_path = Path(file_path)
        self.max_workers = max_workers

    @property
    def file_type(self) -> FileType:
        """Determine the file type based on extension.

        Raises:
            ValueError: If the extension is not .csv, .json, .yaml or .yml.
        """
        ext = self.file_path.suffix.lower()
        if ext == ".csv":
            return FileType.CSV
        if ext == ".json":
            return FileType.JSON
        if ext in (".yaml", ".yml"):
            return FileType.YAML
        raise ValueError(f"Unsupported file type: {ext}")

    @staticmethod
    def _agent_label(agent: Any) -> Any:
        """Return the agent's name for messages, or "unknown" if unavailable."""
        getter = getattr(agent, "get", None)
        if callable(getter):
            return getter("agent_name", "unknown")
        # Objects without ``.get`` (e.g. an AgentSpec) expose attributes.
        return getattr(agent, "agent_name", "unknown")

    def create_agent_file(
        self, agents: List[Union[AgentSpec, Dict[str, Any]]]
    ) -> None:
        """Create a file with validated agent configurations.

        Args:
            agents: Configurations to validate and write.

        Raises:
            AgentValidationError: If any configuration is invalid; nothing
                is written in that case.
            ValueError: If the file extension is unsupported.
        """
        validated_agents = []
        for agent in agents:
            try:
                validated_agents.append(
                    AgentValidator.validate_config(agent)
                )
            except AgentValidationError as e:
                # Look the name up defensively so that reporting the failure
                # cannot itself raise and hide the validation error.
                print(
                    f"Validation error for agent {self._agent_label(agent)}: {e}"
                )
                raise

        file_type = self.file_type
        if file_type == FileType.CSV:
            self._write_csv(validated_agents)
        elif file_type == FileType.JSON:
            self._write_json(validated_agents)
        elif file_type == FileType.YAML:
            self._write_yaml(validated_agents)

        print(
            f"Created {file_type.value} file with {len(validated_agents)} agents at {self.file_path}"
        )

    def load_agents(self) -> List[Agent]:
        """Load and create agents from file with validation and parallel processing.

        Invalid configurations are skipped, and any other error raised while
        building an agent is printed and skipped. The returned list follows
        the order of the records in the file.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file extension is unsupported.
        """
        if not self.file_path.exists():
            raise FileNotFoundError(
                f"File not found at {self.file_path}"
            )

        file_type = self.file_type
        if file_type == FileType.CSV:
            agents_data = self._read_csv()
        elif file_type == FileType.JSON:
            agents_data = self._read_json()
        elif file_type == FileType.YAML:
            agents_data = self._read_yaml()

        agents: List[Agent] = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_workers
        ) as executor:
            futures = [
                executor.submit(self._process_agent, agent_data)
                for agent_data in agents_data
            ]

            # Advance the progress bar as work finishes, in any order.
            for _ in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                desc="Loading agents",
            ):
                pass

            # Collect in submission order so that the result is
            # deterministic and mirrors the file.
            for future in futures:
                try:
                    agent = future.result()
                except Exception as e:
                    print(f"Error processing agent: {e}")
                    continue
                if agent:
                    agents.append(agent)

        print(f"Loaded {len(agents)} agents from {self.file_path}")
        return agents

    def _process_agent(
        self, agent_data: Union[AgentSpec, Dict[str, Any]]
    ) -> Union[Agent, None]:
        """Validate one configuration and build its Agent.

        Returns:
            The created Agent, or None when validation fails.
        """
        try:
            validated_config = AgentValidator.validate_config(
                agent_data
            )
            return self._create_agent(validated_config)
        except AgentValidationError as e:
            print(f"Skipping invalid agent configuration: {e}")
            return None

    def _create_agent(
        self, validated_config: AgentConfigDict
    ) -> Agent:
        """Create an Agent instance from validated configuration."""
        return Agent(
            agent_name=validated_config["agent_name"],
            system_prompt=validated_config["system_prompt"],
            model_name=validated_config["model_name"],
            max_loops=validated_config["max_loops"],
            autosave=validated_config["autosave"],
            dashboard=validated_config["dashboard"],
            verbose=validated_config["verbose"],
            dynamic_temperature_enabled=validated_config[
                "dynamic_temperature"
            ],
            saved_state_path=validated_config["saved_state_path"],
            user_name=validated_config["user_name"],
            retry_attempts=validated_config["retry_attempts"],
            context_length=validated_config["context_length"],
            return_step_meta=validated_config["return_step_meta"],
            output_type=validated_config["output_type"],
            streaming_on=validated_config["streaming"],
        )

    def _write_csv(self, agents: List[Dict[str, Any]]) -> None:
        """Write agents to CSV file.

        The header comes from the first record. With no records there is no
        first row, so the header falls back to the AgentConfigDict keys
        instead of raising IndexError after the file has been truncated.
        """
        if agents:
            fieldnames = list(agents[0].keys())
        else:
            fieldnames = list(AgentConfigDict.__annotations__)

        with open(self.file_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(agents)

    def _write_json(self, agents: List[Dict[str, Any]]) -> None:
        """Write agents to JSON file"""
        with open(self.file_path, "w") as f:
            json.dump(agents, f, indent=2)

    def _write_yaml(self, agents: List[Dict[str, Any]]) -> None:
        """Write agents to YAML file"""
        with open(self.file_path, "w") as f:
            yaml.dump(agents, f, default_flow_style=False)

    def _read_csv(self) -> List[Dict[str, Any]]:
        """Read agents from CSV file.

        The file is opened with ``newline=""`` as the csv module requires;
        otherwise line breaks embedded in quoted fields (multi-line system
        prompts) are altered on the way in.
        """
        with open(self.file_path, "r", newline="") as f:
            return list(csv.DictReader(f))

    def _read_json(self) -> List[Dict[str, Any]]:
        """Read agents from JSON file; a ``null`` document yields an empty list."""
        with open(self.file_path, "r") as f:
            return json.load(f) or []

    def _read_yaml(self) -> List[Dict[str, Any]]:
        """Read agents from YAML file; an empty document yields an empty list."""
        with open(self.file_path, "r") as f:
            return yaml.safe_load(f) or []
|