# NOTE(review): this file arrived as a whitespace-mangled git email patch
# (subject: "Add token counting functionality using litellm tokenizer",
# touching swarms/schemas/base_schemas.py, swarms/schemas/base_swarm_schemas.py,
# swarms/schemas/output_schemas.py and swarms/utils/litellm_tokenizer.py).
# Below is the reconstructed payload of those four files, separated by
# banner comments.  Two fixes were applied during reconstruction:
#   1. The code calls the pydantic v2 API (`model_dump_json`) but used the
#      deprecated v1 `@validator` decorator; validators are migrated to
#      v2 `@field_validator` (with `ValidationInfo` for cross-field access).
#   2. `count_tokens` installed litellm via a bare `pip` subprocess with no
#      error checking; it now uses `sys.executable -m pip` with `check=True`.

# ---------------------------------------------------------------------------
# swarms/utils/litellm_tokenizer.py  (new file)
# ---------------------------------------------------------------------------
import subprocess
import sys


def count_tokens(text: str, model: str = "gpt-4o") -> int:
    """Count the number of tokens in *text* for the given model.

    Uses litellm's ``encode`` helper.  litellm is installed on demand the
    first time it is found missing, so a fresh environment does not
    hard-fail on import.

    Args:
        text: The text to tokenize.
        model: Model name passed to litellm to select the tokenizer.

    Returns:
        The number of tokens litellm's encoder produces for ``text``.

    Raises:
        subprocess.CalledProcessError: If the on-demand install of litellm
            fails.
    """
    try:
        from litellm import encode
    except ImportError:
        # Install into the *current* interpreter's environment and fail
        # loudly if pip errors out — a bare `pip` on PATH may belong to a
        # different Python, and an unchecked run would leave the second
        # import to fail with a confusing ImportError.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "litellm"],
            check=True,
        )
        from litellm import encode

    return len(encode(model=model, text=text))


# ---------------------------------------------------------------------------
# swarms/schemas/base_schemas.py  (additions only — the pre-existing fields
# of these classes are elided in the patch view and are NOT reproduced here)
# ---------------------------------------------------------------------------
# Added import at top of base_schemas.py:
#     from swarms.utils.litellm_tokenizer import count_tokens


class ChatMessageInput(BaseModel):
    # ... existing fields (role, content: Union[str, List[ContentItem]])
    # elided in the patch view ...

    def count_tokens(self, model: str = "gpt-4o") -> int:
        """Count tokens in this message's content.

        Plain-string content is tokenized directly.  Structured (list)
        content counts only ``TextContent`` items; other item types
        (e.g. images) contribute 0 tokens.
        """
        if isinstance(self.content, str):
            return count_tokens(self.content, model)
        if isinstance(self.content, list):
            return sum(
                count_tokens(item.text, model)
                for item in self.content
                if isinstance(item, TextContent)
            )
        return 0


class UsageInfo(BaseModel):
    # Token accounting for a chat completion exchange.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0

    @classmethod
    def calculate_usage(
        cls,
        messages: List[ChatMessageInput],
        completion: Optional[str] = None,
        model: str = "gpt-4o",
    ) -> "UsageInfo":
        """Calculate token usage for *messages* plus an optional completion.

        Args:
            messages: Input messages whose tokens count toward the prompt.
            completion: Optional completion text; ``None`` or empty counts
                as 0 completion tokens.
            model: Model name used to pick the tokenizer.

        Returns:
            A ``UsageInfo`` with prompt, completion and total token counts.
        """
        prompt_tokens = sum(msg.count_tokens(model) for msg in messages)
        completion_tokens = (
            count_tokens(completion, model) if completion else 0
        )
        return cls(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )


# ---------------------------------------------------------------------------
# swarms/schemas/base_swarm_schemas.py  (new file)
# ---------------------------------------------------------------------------
import time
import uuid
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, ValidationInfo, field_validator


class AgentInputConfig(BaseModel):
    """
    Configuration for an agent.  This can be further customized per agent
    type if needed.
    """

    agent_name: str = Field(..., description="Name of the agent")
    agent_type: str = Field(
        ..., description="Type of agent (e.g. 'llm', 'tool', 'memory')"
    )
    model_name: Optional[str] = Field(
        None, description="Name of the model to use"
    )
    temperature: float = Field(
        0.7, description="Temperature for model sampling"
    )
    max_tokens: int = Field(
        4096, description="Maximum tokens for model response"
    )
    system_prompt: Optional[str] = Field(
        None, description="System prompt for the agent"
    )
    tools: Optional[List[str]] = Field(
        None, description="List of tool names available to agent"
    )
    memory_type: Optional[str] = Field(
        None, description="Type of memory to use"
    )
    metadata: Dict[str, Any] = Field(
        default_factory=dict, description="Additional agent metadata"
    )


class BaseSwarmSchema(BaseModel):
    """
    Base schema for all swarm types.

    Validates that ``swarm_type`` is one of the supported types and that
    ``config`` carries the keys that particular swarm type requires
    (filling in defaults where the original code did so).
    """

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    description: str
    agents: List[AgentInputConfig]  # Using AgentInputConfig
    max_loops: int = 1
    swarm_type: str  # e.g., "SequentialWorkflow", "ConcurrentWorkflow", etc.
    created_at: str = Field(
        default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S")
    )
    config: Dict[str, Any] = Field(default_factory=dict)  # Flexible config

    # Additional fields
    timeout: Optional[int] = Field(
        None, description="Timeout in seconds for swarm execution"
    )
    error_handling: str = Field(
        "stop",
        description="Error handling strategy: 'stop', 'continue', or 'retry'",
    )
    max_retries: int = Field(
        3, description="Maximum number of retry attempts"
    )
    logging_level: str = Field(
        "info", description="Logging level for the swarm"
    )
    metrics_enabled: bool = Field(
        True, description="Whether to collect metrics"
    )
    tags: List[str] = Field(
        default_factory=list, description="Tags for categorizing swarms"
    )

    @field_validator("swarm_type")
    @classmethod
    def validate_swarm_type(cls, v: str) -> str:
        """Validates the swarm type is one of the allowed types."""
        allowed_types = [
            "SequentialWorkflow",
            "ConcurrentWorkflow",
            "AgentRearrange",
            "MixtureOfAgents",
            "SpreadSheetSwarm",
            "AutoSwarm",
            "HierarchicalSwarm",
            "FeedbackSwarm",
        ]
        if v not in allowed_types:
            raise ValueError(f"Swarm type must be one of: {allowed_types}")
        return v

    @field_validator("config")
    @classmethod
    def validate_config(
        cls, v: Dict[str, Any], info: ValidationInfo
    ) -> Dict[str, Any]:
        """
        Validates the 'config' dictionary based on the 'swarm_type'.

        ``swarm_type`` is declared before ``config``, so it is available in
        ``info.data`` by the time this validator runs.
        """
        swarm_type = info.data.get("swarm_type")

        # Common validation for all swarm types
        if not isinstance(v, dict):
            raise ValueError("Config must be a dictionary")

        # Type-specific validation
        if swarm_type == "SequentialWorkflow":
            if "flow" not in v:
                raise ValueError(
                    "SequentialWorkflow requires a 'flow' configuration."
                )
            if not isinstance(v["flow"], list):
                raise ValueError("Flow configuration must be a list")

        elif swarm_type == "ConcurrentWorkflow":
            if "max_workers" not in v:
                raise ValueError(
                    "ConcurrentWorkflow requires a 'max_workers' configuration."
                )
            if not isinstance(v["max_workers"], int) or v["max_workers"] < 1:
                raise ValueError("max_workers must be a positive integer")

        elif swarm_type == "AgentRearrange":
            if "flow" not in v:
                raise ValueError(
                    "AgentRearrange requires a 'flow' configuration."
                )
            if not isinstance(v["flow"], list):
                raise ValueError("Flow configuration must be a list")

        elif swarm_type == "MixtureOfAgents":
            if "aggregator_agent" not in v:
                raise ValueError(
                    "MixtureOfAgents requires an 'aggregator_agent' configuration."
                )
            if "voting_method" not in v:
                v["voting_method"] = "majority"  # Set default voting method

        elif swarm_type == "SpreadSheetSwarm":
            if "save_file_path" not in v:
                raise ValueError(
                    "SpreadSheetSwarm requires a 'save_file_path' configuration."
                )
            if not isinstance(v["save_file_path"], str):
                raise ValueError("save_file_path must be a string")

        elif swarm_type == "AutoSwarm":
            if "optimization_metric" not in v:
                v["optimization_metric"] = "performance"  # Set default metric
            if "adaptation_strategy" not in v:
                v["adaptation_strategy"] = "dynamic"  # Set default strategy

        elif swarm_type == "HierarchicalSwarm":
            if "hierarchy_levels" not in v:
                raise ValueError(
                    "HierarchicalSwarm requires 'hierarchy_levels' configuration."
                )
            if (
                not isinstance(v["hierarchy_levels"], int)
                or v["hierarchy_levels"] < 1
            ):
                raise ValueError("hierarchy_levels must be a positive integer")

        elif swarm_type == "FeedbackSwarm":
            if "feedback_collection" not in v:
                # Set default collection method
                v["feedback_collection"] = "continuous"
            if "feedback_integration" not in v:
                # Set default integration method
                v["feedback_integration"] = "weighted"

        return v

    @field_validator("error_handling")
    @classmethod
    def validate_error_handling(cls, v: str) -> str:
        """Validates error handling strategy."""
        allowed_strategies = ["stop", "continue", "retry"]
        if v not in allowed_strategies:
            raise ValueError(
                f"Error handling must be one of: {allowed_strategies}"
            )
        return v

    @field_validator("logging_level")
    @classmethod
    def validate_logging_level(cls, v: str) -> str:
        """Validates logging level (normalized to lowercase)."""
        allowed_levels = ["debug", "info", "warning", "error", "critical"]
        if v.lower() not in allowed_levels:
            raise ValueError(
                f"Logging level must be one of: {allowed_levels}"
            )
        return v.lower()

    def get_agent_by_name(self, name: str) -> Optional[AgentInputConfig]:
        """Helper method to get agent config by name; None if absent."""
        for agent in self.agents:
            if agent.agent_name == name:
                return agent
        return None

    def add_tag(self, tag: str) -> None:
        """Helper method to add a tag (no duplicates)."""
        if tag not in self.tags:
            self.tags.append(tag)

    def remove_tag(self, tag: str) -> None:
        """Helper method to remove a tag (no-op if absent)."""
        if tag in self.tags:
            self.tags.remove(tag)


# ---------------------------------------------------------------------------
# swarms/schemas/output_schemas.py  (new file)
# ---------------------------------------------------------------------------
# (imports in the actual file)
# from typing import Any, Dict, List, Optional
# from pydantic import BaseModel, Field
# import time
# import uuid
# from swarms.utils.litellm_tokenizer import count_tokens


class Step(BaseModel):
    """
    Represents a single step in an agent's task execution.
    """

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str = Field(..., description="Name of the agent")
    task: Optional[str] = Field(
        None, description="Task given to the agent at this step"
    )
    input: Optional[str] = Field(
        None, description="Input provided to the agent at this step"
    )
    output: Optional[str] = Field(
        None, description="Output generated by the agent at this step"
    )
    error: Optional[str] = Field(
        None, description="Error message if any error occurred during the step"
    )
    start_time: str = Field(
        default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S")
    )
    end_time: Optional[str] = Field(None, description="End time of the step")
    runtime: Optional[float] = Field(
        None, description="Runtime of the step in seconds"
    )
    tokens_used: Optional[int] = Field(
        None, description="Number of tokens used in this step"
    )
    cost: Optional[float] = Field(None, description="Cost of the step")
    metadata: Optional[Dict[str, Any]] = Field(
        None, description="Additional metadata about the step"
    )

    def calculate_tokens(self, model: str = "gpt-4o") -> int:
        """Calculate total tokens used in this step.

        Counts tokens of ``input`` and ``output`` (when present), stores the
        sum in ``self.tokens_used`` as a side effect, and returns it.
        """
        total = 0
        if self.input:
            total += count_tokens(self.input, model)
        if self.output:
            total += count_tokens(self.output, model)
        self.tokens_used = total
        return total


class AgentTaskOutput(BaseModel):
    """
    Represents the output of an agent's execution.
    """

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    agent_name: str = Field(..., description="Name of the agent")
    task: str = Field(..., description="The task agent was asked to perform")
    steps: List[Step] = Field(
        ..., description="List of steps taken by the agent"
    )
    start_time: str = Field(
        default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S")
    )
    end_time: Optional[str] = Field(
        None, description="End time of the agent's execution"
    )
    total_tokens: Optional[int] = Field(
        None, description="Total tokens used by the agent"
    )
    cost: Optional[float] = Field(
        None, description="Total cost of the agent execution"
    )
    # Add any other fields from ManySteps that are relevant, like full_history

    def calculate_total_tokens(self, model: str = "gpt-4o") -> int:
        """Calculate total tokens across all steps.

        Delegates to each step's ``calculate_tokens`` (which also updates
        the step's own ``tokens_used``), stores the sum in
        ``self.total_tokens``, and returns it.
        """
        total = 0
        for step in self.steps:
            total += step.calculate_tokens(model)
        self.total_tokens = total
        return total


class OutputSchema(BaseModel):
    """
    Unified output schema for all swarm types.
    """

    swarm_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    swarm_type: str = Field(..., description="Type of the swarm")
    task: str = Field(..., description="The task given to the swarm")
    agent_outputs: List[AgentTaskOutput] = Field(
        ..., description="List of agent outputs"
    )
    timestamp: str = Field(
        default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S")
    )
    swarm_specific_output: Optional[Dict] = Field(
        None, description="Additional data specific to the swarm type"
    )


class SwarmOutputFormatter:
    """
    Formatter class to transform raw swarm output into the OutputSchema
    format.
    """

    @staticmethod
    def format_output(
        swarm_id: str,
        swarm_type: str,
        task: str,
        agent_outputs: List[AgentTaskOutput],
        swarm_specific_output: Optional[Dict] = None,
    ) -> str:
        """Formats the output into a standardized JSON string."""
        output = OutputSchema(
            swarm_id=swarm_id,
            swarm_type=swarm_type,
            task=task,
            agent_outputs=agent_outputs,
            swarm_specific_output=swarm_specific_output,
        )
        # pydantic v2 serialization API; `indent` pretty-prints the JSON.
        return output.model_dump_json(indent=4)