parent
fb494267eb
commit
a912702f90
@ -0,0 +1,152 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
|
||||||
|
class AgentInputConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration for an agent. This can be further customized
|
||||||
|
per agent type if needed.
|
||||||
|
"""
|
||||||
|
agent_name: str = Field(..., description="Name of the agent")
|
||||||
|
agent_type: str = Field(..., description="Type of agent (e.g. 'llm', 'tool', 'memory')")
|
||||||
|
model_name: Optional[str] = Field(None, description="Name of the model to use")
|
||||||
|
temperature: float = Field(0.7, description="Temperature for model sampling")
|
||||||
|
max_tokens: int = Field(4096, description="Maximum tokens for model response")
|
||||||
|
system_prompt: Optional[str] = Field(None, description="System prompt for the agent")
|
||||||
|
tools: Optional[List[str]] = Field(None, description="List of tool names available to agent")
|
||||||
|
memory_type: Optional[str] = Field(None, description="Type of memory to use")
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional agent metadata")
|
||||||
|
|
||||||
|
class BaseSwarmSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
Base schema for all swarm types.
|
||||||
|
"""
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
agents: List[AgentInputConfig] # Using AgentInputConfig
|
||||||
|
max_loops: int = 1
|
||||||
|
swarm_type: str # e.g., "SequentialWorkflow", "ConcurrentWorkflow", etc.
|
||||||
|
created_at: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
|
||||||
|
config: Dict[str, Any] = Field(default_factory=dict) # Flexible config
|
||||||
|
|
||||||
|
# Additional fields
|
||||||
|
timeout: Optional[int] = Field(None, description="Timeout in seconds for swarm execution")
|
||||||
|
error_handling: str = Field("stop", description="Error handling strategy: 'stop', 'continue', or 'retry'")
|
||||||
|
max_retries: int = Field(3, description="Maximum number of retry attempts")
|
||||||
|
logging_level: str = Field("info", description="Logging level for the swarm")
|
||||||
|
metrics_enabled: bool = Field(True, description="Whether to collect metrics")
|
||||||
|
tags: List[str] = Field(default_factory=list, description="Tags for categorizing swarms")
|
||||||
|
|
||||||
|
@validator("swarm_type")
|
||||||
|
def validate_swarm_type(cls, v):
|
||||||
|
"""Validates the swarm type is one of the allowed types"""
|
||||||
|
allowed_types = [
|
||||||
|
"SequentialWorkflow",
|
||||||
|
"ConcurrentWorkflow",
|
||||||
|
"AgentRearrange",
|
||||||
|
"MixtureOfAgents",
|
||||||
|
"SpreadSheetSwarm",
|
||||||
|
"AutoSwarm",
|
||||||
|
"HierarchicalSwarm",
|
||||||
|
"FeedbackSwarm"
|
||||||
|
]
|
||||||
|
if v not in allowed_types:
|
||||||
|
raise ValueError(f"Swarm type must be one of: {allowed_types}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("config")
|
||||||
|
def validate_config(cls, v, values):
|
||||||
|
"""
|
||||||
|
Validates the 'config' dictionary based on the 'swarm_type'.
|
||||||
|
"""
|
||||||
|
swarm_type = values.get("swarm_type")
|
||||||
|
|
||||||
|
# Common validation for all swarm types
|
||||||
|
if not isinstance(v, dict):
|
||||||
|
raise ValueError("Config must be a dictionary")
|
||||||
|
|
||||||
|
# Type-specific validation
|
||||||
|
if swarm_type == "SequentialWorkflow":
|
||||||
|
if "flow" not in v:
|
||||||
|
raise ValueError("SequentialWorkflow requires a 'flow' configuration.")
|
||||||
|
if not isinstance(v["flow"], list):
|
||||||
|
raise ValueError("Flow configuration must be a list")
|
||||||
|
|
||||||
|
elif swarm_type == "ConcurrentWorkflow":
|
||||||
|
if "max_workers" not in v:
|
||||||
|
raise ValueError("ConcurrentWorkflow requires a 'max_workers' configuration.")
|
||||||
|
if not isinstance(v["max_workers"], int) or v["max_workers"] < 1:
|
||||||
|
raise ValueError("max_workers must be a positive integer")
|
||||||
|
|
||||||
|
elif swarm_type == "AgentRearrange":
|
||||||
|
if "flow" not in v:
|
||||||
|
raise ValueError("AgentRearrange requires a 'flow' configuration.")
|
||||||
|
if not isinstance(v["flow"], list):
|
||||||
|
raise ValueError("Flow configuration must be a list")
|
||||||
|
|
||||||
|
elif swarm_type == "MixtureOfAgents":
|
||||||
|
if "aggregator_agent" not in v:
|
||||||
|
raise ValueError("MixtureOfAgents requires an 'aggregator_agent' configuration.")
|
||||||
|
if "voting_method" not in v:
|
||||||
|
v["voting_method"] = "majority" # Set default voting method
|
||||||
|
|
||||||
|
elif swarm_type == "SpreadSheetSwarm":
|
||||||
|
if "save_file_path" not in v:
|
||||||
|
raise ValueError("SpreadSheetSwarm requires a 'save_file_path' configuration.")
|
||||||
|
if not isinstance(v["save_file_path"], str):
|
||||||
|
raise ValueError("save_file_path must be a string")
|
||||||
|
|
||||||
|
elif swarm_type == "AutoSwarm":
|
||||||
|
if "optimization_metric" not in v:
|
||||||
|
v["optimization_metric"] = "performance" # Set default metric
|
||||||
|
if "adaptation_strategy" not in v:
|
||||||
|
v["adaptation_strategy"] = "dynamic" # Set default strategy
|
||||||
|
|
||||||
|
elif swarm_type == "HierarchicalSwarm":
|
||||||
|
if "hierarchy_levels" not in v:
|
||||||
|
raise ValueError("HierarchicalSwarm requires 'hierarchy_levels' configuration.")
|
||||||
|
if not isinstance(v["hierarchy_levels"], int) or v["hierarchy_levels"] < 1:
|
||||||
|
raise ValueError("hierarchy_levels must be a positive integer")
|
||||||
|
|
||||||
|
elif swarm_type == "FeedbackSwarm":
|
||||||
|
if "feedback_collection" not in v:
|
||||||
|
v["feedback_collection"] = "continuous" # Set default collection method
|
||||||
|
if "feedback_integration" not in v:
|
||||||
|
v["feedback_integration"] = "weighted" # Set default integration method
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("error_handling")
|
||||||
|
def validate_error_handling(cls, v):
|
||||||
|
"""Validates error handling strategy"""
|
||||||
|
allowed_strategies = ["stop", "continue", "retry"]
|
||||||
|
if v not in allowed_strategies:
|
||||||
|
raise ValueError(f"Error handling must be one of: {allowed_strategies}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("logging_level")
|
||||||
|
def validate_logging_level(cls, v):
|
||||||
|
"""Validates logging level"""
|
||||||
|
allowed_levels = ["debug", "info", "warning", "error", "critical"]
|
||||||
|
if v.lower() not in allowed_levels:
|
||||||
|
raise ValueError(f"Logging level must be one of: {allowed_levels}")
|
||||||
|
return v.lower()
|
||||||
|
|
||||||
|
def get_agent_by_name(self, name: str) -> Optional[AgentInputConfig]:
|
||||||
|
"""Helper method to get agent config by name"""
|
||||||
|
for agent in self.agents:
|
||||||
|
if agent.agent_name == name:
|
||||||
|
return agent
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add_tag(self, tag: str) -> None:
|
||||||
|
"""Helper method to add a tag"""
|
||||||
|
if tag not in self.tags:
|
||||||
|
self.tags.append(tag)
|
||||||
|
|
||||||
|
def remove_tag(self, tag: str) -> None:
|
||||||
|
"""Helper method to remove a tag"""
|
||||||
|
if tag in self.tags:
|
||||||
|
self.tags.remove(tag)
|
@ -0,0 +1,90 @@
|
|||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from swarms.utils.litellm_tokenizer import count_tokens
|
||||||
|
|
||||||
|
class Step(BaseModel):
|
||||||
|
"""
|
||||||
|
Represents a single step in an agent's task execution.
|
||||||
|
"""
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
name: str = Field(..., description="Name of the agent")
|
||||||
|
task: Optional[str] = Field(None, description="Task given to the agent at this step")
|
||||||
|
input: Optional[str] = Field(None, description="Input provided to the agent at this step")
|
||||||
|
output: Optional[str] = Field(None, description="Output generated by the agent at this step")
|
||||||
|
error: Optional[str] = Field(None, description="Error message if any error occurred during the step")
|
||||||
|
start_time: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
|
||||||
|
end_time: Optional[str] = Field(None, description="End time of the step")
|
||||||
|
runtime: Optional[float] = Field(None, description="Runtime of the step in seconds")
|
||||||
|
tokens_used: Optional[int] = Field(None, description="Number of tokens used in this step")
|
||||||
|
cost: Optional[float] = Field(None, description="Cost of the step")
|
||||||
|
metadata: Optional[Dict[str, Any]] = Field(
|
||||||
|
None, description="Additional metadata about the step"
|
||||||
|
)
|
||||||
|
|
||||||
|
def calculate_tokens(self, model: str = "gpt-4o") -> int:
|
||||||
|
"""Calculate total tokens used in this step"""
|
||||||
|
total = 0
|
||||||
|
if self.input:
|
||||||
|
total += count_tokens(self.input, model)
|
||||||
|
if self.output:
|
||||||
|
total += count_tokens(self.output, model)
|
||||||
|
self.tokens_used = total
|
||||||
|
return total
|
||||||
|
|
||||||
|
class AgentTaskOutput(BaseModel):
|
||||||
|
"""
|
||||||
|
Represents the output of an agent's execution.
|
||||||
|
"""
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
agent_name: str = Field(..., description="Name of the agent")
|
||||||
|
task: str = Field(..., description="The task agent was asked to perform")
|
||||||
|
steps: List[Step] = Field(..., description="List of steps taken by the agent")
|
||||||
|
start_time: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
|
||||||
|
end_time: Optional[str] = Field(None, description="End time of the agent's execution")
|
||||||
|
total_tokens: Optional[int] = Field(None, description="Total tokens used by the agent")
|
||||||
|
cost: Optional[float] = Field(None, description="Total cost of the agent execution")
|
||||||
|
# Add any other fields from ManySteps that are relevant, like full_history
|
||||||
|
|
||||||
|
def calculate_total_tokens(self, model: str = "gpt-4o") -> int:
|
||||||
|
"""Calculate total tokens across all steps"""
|
||||||
|
total = 0
|
||||||
|
for step in self.steps:
|
||||||
|
total += step.calculate_tokens(model)
|
||||||
|
self.total_tokens = total
|
||||||
|
return total
|
||||||
|
|
||||||
|
class OutputSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
Unified output schema for all swarm types.
|
||||||
|
"""
|
||||||
|
swarm_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
swarm_type: str = Field(..., description="Type of the swarm")
|
||||||
|
task: str = Field(..., description="The task given to the swarm")
|
||||||
|
agent_outputs: List[AgentTaskOutput] = Field(..., description="List of agent outputs")
|
||||||
|
timestamp: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
|
||||||
|
swarm_specific_output: Optional[Dict] = Field(None, description="Additional data specific to the swarm type")
|
||||||
|
|
||||||
|
class SwarmOutputFormatter:
|
||||||
|
"""
|
||||||
|
Formatter class to transform raw swarm output into the OutputSchema format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_output(
|
||||||
|
swarm_id: str,
|
||||||
|
swarm_type: str,
|
||||||
|
task: str,
|
||||||
|
agent_outputs: List[AgentTaskOutput],
|
||||||
|
swarm_specific_output: Optional[Dict] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Formats the output into a standardized JSON string."""
|
||||||
|
output = OutputSchema(
|
||||||
|
swarm_id=swarm_id,
|
||||||
|
swarm_type=swarm_type,
|
||||||
|
task=task,
|
||||||
|
agent_outputs=agent_outputs,
|
||||||
|
swarm_specific_output=swarm_specific_output,
|
||||||
|
)
|
||||||
|
return output.model_dump_json(indent=4)
|
Loading…
Reference in new issue