From a912702f90af5192d984b3938e493ef867b4badd Mon Sep 17 00:00:00 2001
From: harshalmore31 <86048671+harshalmore31@users.noreply.github.com>
Date: Sat, 8 Feb 2025 00:46:49 +0530
Subject: [PATCH 1/2] add token counting functionality to message and step
 schemas

---
 swarms/schemas/base_schemas.py       |  29 +++++
 swarms/schemas/base_swarm_schemas.py | 152 +++++++++++++++++++++++++++
 swarms/schemas/output_schemas.py     |  90 ++++++++++++++++
 3 files changed, 271 insertions(+)
 create mode 100644 swarms/schemas/base_swarm_schemas.py
 create mode 100644 swarms/schemas/output_schemas.py

diff --git a/swarms/schemas/base_schemas.py b/swarms/schemas/base_schemas.py
index 2669f8d4..640acf22 100644
--- a/swarms/schemas/base_schemas.py
+++ b/swarms/schemas/base_schemas.py
@@ -3,6 +3,7 @@ import time
 from typing import List, Literal, Optional, Union
 
 from pydantic import BaseModel, Field
+from swarms.utils.litellm_tokenizer import count_tokens
 
 
 class ModelCard(BaseModel):
@@ -49,6 +50,18 @@ class ChatMessageInput(BaseModel):
     )
     content: Union[str, List[ContentItem]]
 
+    def count_tokens(self, model: str = "gpt-4o") -> int:
+        """Count tokens in the message content"""
+        if isinstance(self.content, str):
+            return count_tokens(self.content, model)
+        elif isinstance(self.content, list):
+            total = 0
+            for item in self.content:
+                if isinstance(item, TextContent):
+                    total += count_tokens(item.text, model)
+            return total
+        return 0
+
 
 class ChatMessageResponse(BaseModel):
     role: str = Field(
@@ -92,6 +105,22 @@ class UsageInfo(BaseModel):
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
 
+    @classmethod
+    def calculate_usage(
+        cls,
+        messages: List[ChatMessageInput],
+        completion: Optional[str] = None,
+        model: str = "gpt-4o"
+    ) -> "UsageInfo":
+        """Calculate token usage for messages and completion"""
+        prompt_tokens = sum(msg.count_tokens(model) for msg in messages)
+        completion_tokens = count_tokens(completion, model) if completion else 0
+        return cls(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+
 
 class ChatCompletionResponse(BaseModel):
     model: str
diff --git a/swarms/schemas/base_swarm_schemas.py b/swarms/schemas/base_swarm_schemas.py
new file mode 100644
index 00000000..da9f7c31
--- /dev/null
+++ b/swarms/schemas/base_swarm_schemas.py
@@ -0,0 +1,152 @@
+from typing import Any, Dict, List, Optional, Union
+from pydantic import BaseModel, Field, validator
+import uuid
+import time
+
+class AgentInputConfig(BaseModel):
+    """
+    Configuration for an agent. This can be further customized
+    per agent type if needed.
+    """
+    agent_name: str = Field(..., description="Name of the agent")
+    agent_type: str = Field(..., description="Type of agent (e.g. 'llm', 'tool', 'memory')")
+    model_name: Optional[str] = Field(None, description="Name of the model to use")
+    temperature: float = Field(0.7, description="Temperature for model sampling")
+    max_tokens: int = Field(4096, description="Maximum tokens for model response")
+    system_prompt: Optional[str] = Field(None, description="System prompt for the agent")
+    tools: Optional[List[str]] = Field(None, description="List of tool names available to agent")
+    memory_type: Optional[str] = Field(None, description="Type of memory to use")
+    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional agent metadata")
+
+class BaseSwarmSchema(BaseModel):
+    """
+    Base schema for all swarm types.
+    """
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    name: str
+    description: str
+    agents: List[AgentInputConfig]  # Using AgentInputConfig
+    max_loops: int = 1
+    swarm_type: str  # e.g., "SequentialWorkflow", "ConcurrentWorkflow", etc.
+    created_at: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
+    config: Dict[str, Any] = Field(default_factory=dict)  # Flexible config
+    
+    # Additional fields
+    timeout: Optional[int] = Field(None, description="Timeout in seconds for swarm execution")
+    error_handling: str = Field("stop", description="Error handling strategy: 'stop', 'continue', or 'retry'")
+    max_retries: int = Field(3, description="Maximum number of retry attempts")
+    logging_level: str = Field("info", description="Logging level for the swarm")
+    metrics_enabled: bool = Field(True, description="Whether to collect metrics")
+    tags: List[str] = Field(default_factory=list, description="Tags for categorizing swarms")
+    
+    @validator("swarm_type")
+    def validate_swarm_type(cls, v):
+        """Validates the swarm type is one of the allowed types"""
+        allowed_types = [
+            "SequentialWorkflow",
+            "ConcurrentWorkflow", 
+            "AgentRearrange",
+            "MixtureOfAgents",
+            "SpreadSheetSwarm",
+            "AutoSwarm",
+            "HierarchicalSwarm",
+            "FeedbackSwarm"
+        ]
+        if v not in allowed_types:
+            raise ValueError(f"Swarm type must be one of: {allowed_types}")
+        return v
+
+    @validator("config")
+    def validate_config(cls, v, values):
+        """
+        Validates the 'config' dictionary based on the 'swarm_type'.
+        """
+        swarm_type = values.get("swarm_type")
+        
+        # Common validation for all swarm types
+        if not isinstance(v, dict):
+            raise ValueError("Config must be a dictionary")
+
+        # Type-specific validation
+        if swarm_type == "SequentialWorkflow":
+            if "flow" not in v:
+                raise ValueError("SequentialWorkflow requires a 'flow' configuration.")
+            if not isinstance(v["flow"], list):
+                raise ValueError("Flow configuration must be a list")
+
+        elif swarm_type == "ConcurrentWorkflow":
+            if "max_workers" not in v:
+                raise ValueError("ConcurrentWorkflow requires a 'max_workers' configuration.")
+            if not isinstance(v["max_workers"], int) or v["max_workers"] < 1:
+                raise ValueError("max_workers must be a positive integer")
+
+        elif swarm_type == "AgentRearrange":
+            if "flow" not in v:
+                raise ValueError("AgentRearrange requires a 'flow' configuration.")
+            if not isinstance(v["flow"], list):
+                raise ValueError("Flow configuration must be a list")
+
+        elif swarm_type == "MixtureOfAgents":
+            if "aggregator_agent" not in v:
+                raise ValueError("MixtureOfAgents requires an 'aggregator_agent' configuration.")
+            if "voting_method" not in v:
+                v["voting_method"] = "majority"  # Set default voting method
+
+        elif swarm_type == "SpreadSheetSwarm":
+            if "save_file_path" not in v:
+                raise ValueError("SpreadSheetSwarm requires a 'save_file_path' configuration.")
+            if not isinstance(v["save_file_path"], str):
+                raise ValueError("save_file_path must be a string")
+
+        elif swarm_type == "AutoSwarm":
+            if "optimization_metric" not in v:
+                v["optimization_metric"] = "performance"  # Set default metric
+            if "adaptation_strategy" not in v:
+                v["adaptation_strategy"] = "dynamic"  # Set default strategy
+
+        elif swarm_type == "HierarchicalSwarm":
+            if "hierarchy_levels" not in v:
+                raise ValueError("HierarchicalSwarm requires 'hierarchy_levels' configuration.")
+            if not isinstance(v["hierarchy_levels"], int) or v["hierarchy_levels"] < 1:
+                raise ValueError("hierarchy_levels must be a positive integer")
+
+        elif swarm_type == "FeedbackSwarm":
+            if "feedback_collection" not in v:
+                v["feedback_collection"] = "continuous"  # Set default collection method
+            if "feedback_integration" not in v:
+                v["feedback_integration"] = "weighted"  # Set default integration method
+
+        return v
+
+    @validator("error_handling")
+    def validate_error_handling(cls, v):
+        """Validates error handling strategy"""
+        allowed_strategies = ["stop", "continue", "retry"]
+        if v not in allowed_strategies:
+            raise ValueError(f"Error handling must be one of: {allowed_strategies}")
+        return v
+
+    @validator("logging_level")
+    def validate_logging_level(cls, v):
+        """Validates logging level"""
+        allowed_levels = ["debug", "info", "warning", "error", "critical"]
+        if v.lower() not in allowed_levels:
+            raise ValueError(f"Logging level must be one of: {allowed_levels}")
+        return v.lower()
+
+    def get_agent_by_name(self, name: str) -> Optional[AgentInputConfig]:
+        """Helper method to get agent config by name"""
+        for agent in self.agents:
+            if agent.agent_name == name:
+                return agent
+        return None
+
+    def add_tag(self, tag: str) -> None:
+        """Helper method to add a tag"""
+        if tag not in self.tags:
+            self.tags.append(tag)
+
+    def remove_tag(self, tag: str) -> None:
+        """Helper method to remove a tag"""
+        if tag in self.tags:
+            self.tags.remove(tag)
\ No newline at end of file
diff --git a/swarms/schemas/output_schemas.py b/swarms/schemas/output_schemas.py
new file mode 100644
index 00000000..afac3afb
--- /dev/null
+++ b/swarms/schemas/output_schemas.py
@@ -0,0 +1,90 @@
+from typing import Any, Dict, List, Optional
+from pydantic import BaseModel, Field
+import time
+import uuid
+from swarms.utils.litellm_tokenizer import count_tokens
+
+class Step(BaseModel):
+    """
+    Represents a single step in an agent's task execution.
+    """
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    name: str = Field(..., description="Name of the agent")
+    task: Optional[str] = Field(None, description="Task given to the agent at this step")
+    input: Optional[str] = Field(None, description="Input provided to the agent at this step")
+    output: Optional[str] = Field(None, description="Output generated by the agent at this step")
+    error: Optional[str] = Field(None, description="Error message if any error occurred during the step")
+    start_time: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
+    end_time: Optional[str] = Field(None, description="End time of the step")
+    runtime: Optional[float] = Field(None, description="Runtime of the step in seconds")
+    tokens_used: Optional[int] = Field(None, description="Number of tokens used in this step")
+    cost: Optional[float] = Field(None, description="Cost of the step")
+    metadata: Optional[Dict[str, Any]] = Field(
+        None, description="Additional metadata about the step"
+    )
+
+    def calculate_tokens(self, model: str = "gpt-4o") -> int:
+        """Calculate total tokens used in this step"""
+        total = 0
+        if self.input:
+            total += count_tokens(self.input, model)
+        if self.output:
+            total += count_tokens(self.output, model)
+        self.tokens_used = total
+        return total
+
+class AgentTaskOutput(BaseModel):
+    """
+    Represents the output of an agent's execution.
+    """
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    agent_name: str = Field(..., description="Name of the agent")
+    task: str = Field(..., description="The task agent was asked to perform")
+    steps: List[Step] = Field(..., description="List of steps taken by the agent")
+    start_time: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
+    end_time: Optional[str] = Field(None, description="End time of the agent's execution")
+    total_tokens: Optional[int] = Field(None, description="Total tokens used by the agent")
+    cost: Optional[float] = Field(None, description="Total cost of the agent execution")
+    # Add any other fields from ManySteps that are relevant, like full_history
+
+    def calculate_total_tokens(self, model: str = "gpt-4o") -> int:
+        """Calculate total tokens across all steps"""
+        total = 0
+        for step in self.steps:
+            total += step.calculate_tokens(model)
+        self.total_tokens = total
+        return total
+
+class OutputSchema(BaseModel):
+    """
+    Unified output schema for all swarm types.
+    """
+    swarm_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    swarm_type: str = Field(..., description="Type of the swarm")
+    task: str = Field(..., description="The task given to the swarm")
+    agent_outputs: List[AgentTaskOutput] = Field(..., description="List of agent outputs")
+    timestamp: str = Field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
+    swarm_specific_output: Optional[Dict] = Field(None, description="Additional data specific to the swarm type")
+
+class SwarmOutputFormatter:
+    """
+    Formatter class to transform raw swarm output into the OutputSchema format.
+    """
+
+    @staticmethod
+    def format_output(
+        swarm_id: str,
+        swarm_type: str,
+        task: str,
+        agent_outputs: List[AgentTaskOutput],
+        swarm_specific_output: Optional[Dict] = None,
+    ) -> str:
+        """Formats the output into a standardized JSON string."""
+        output = OutputSchema(
+            swarm_id=swarm_id,
+            swarm_type=swarm_type,
+            task=task,
+            agent_outputs=agent_outputs,
+            swarm_specific_output=swarm_specific_output,
+        )
+        return output.model_dump_json(indent=4)
\ No newline at end of file

From eb5eca7c80a0ac95a70cd459822dce35c2c49b46 Mon Sep 17 00:00:00 2001
From: harshalmore31 <harshalmore2468@gmail.com>
Date: Sat, 8 Feb 2025 00:57:56 +0530
Subject: [PATCH 2/2] Updated __init__.py

---
 swarms/schemas/__init__.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/swarms/schemas/__init__.py b/swarms/schemas/__init__.py
index f81ae400..b193483f 100644
--- a/swarms/schemas/__init__.py
+++ b/swarms/schemas/__init__.py
@@ -1,10 +1,12 @@
 from swarms.schemas.agent_step_schemas import Step, ManySteps
-
 from swarms.schemas.agent_input_schema import AgentSchema
-
+from swarms.schemas.base_swarm_schemas import BaseSwarmSchema
+from swarms.schemas.output_schemas import OutputSchema
 
 __all__ = [
     "Step",
     "ManySteps",
     "AgentSchema",
+    "BaseSwarmSchema",
+    "OutputSchema",
 ]