From 007eb5c011416def763356257903a8e229becb94 Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Mon, 3 Mar 2025 12:56:05 -0800 Subject: [PATCH] api client update --- swarms/structs/__init__.py | 17 +- swarms/structs/octotools.py | 275 ++++++--- swarms/structs/swarms_api.py | 1 - swarms_api_examples/swarms_api.py | 910 ------------------------------ 4 files changed, 201 insertions(+), 1002 deletions(-) delete mode 100644 swarms_api_examples/swarms_api.py diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py index 547422a4..951dc203 100644 --- a/swarms/structs/__init__.py +++ b/swarms/structs/__init__.py @@ -73,15 +73,13 @@ from swarms.structs.swarming_architectures import ( star_swarm, ) from swarms.structs.swarms_api import ( - SwarmsAPIClient, - SwarmRequest, - SwarmAuthenticationError, + AgentInput, SwarmAPIError, + SwarmAuthenticationError, + SwarmRequest, + SwarmsAPIClient, SwarmValidationError, - AgentInput, ) -from swarms.structs.talk_hier import TalkHier, AgentRole, CommunicationEvent -from swarms.structs.octotools import OctoToolsSwarm, Tool, ToolType, get_default_tools __all__ = [ "Agent", @@ -148,13 +146,6 @@ __all__ = [ "MultiAgentRouter", "MemeAgentGenerator", "ModelRouter", - "OctoToolsSwarm", - "Tool", - "ToolType", - "get_default_tools", - "TalkHier", - "AgentRole", - "CommunicationEvent", "SwarmsAPIClient", "SwarmRequest", "SwarmAuthenticationError", diff --git a/swarms/structs/octotools.py b/swarms/structs/octotools.py index a6c38e2b..26c5678b 100644 --- a/swarms/structs/octotools.py +++ b/swarms/structs/octotools.py @@ -5,7 +5,6 @@ Implements the OctoTools framework using swarms. import json import logging -import os import re from dataclasses import dataclass from enum import Enum @@ -16,6 +15,7 @@ import math # Import the math module from dotenv import load_dotenv from swarms import Agent from swarms.structs.conversation import Conversation + # from exa_search import exa_search as web_search_execute @@ -59,7 +59,9 @@ class Tool: try: return self.execute_func(**kwargs) except Exception as e: - logger.error(f"Error executing tool {self.name}: {str(e)}") + logger.error( + f"Error executing tool {self.name}: {str(e)}" + ) return {"error": str(e)} @@ -92,9 +94,15 @@ class OctoToolsSwarm: """Initialize the OctoToolsSwarm system.""" self.model_name = model_name self.max_iterations = max_iterations - self.base_path = Path(base_path) if base_path else Path("./octotools_states") + self.base_path = ( + Path(base_path) + if base_path + else Path("./octotools_states") + ) self.base_path.mkdir(exist_ok=True) - self.tools = {tool.name: tool for tool in tools} # Store tools in a dictionary + self.tools = { + tool.name: tool for tool in tools + } # Store tools in a dictionary # Initialize agents self._init_agents() @@ -110,9 +118,9 @@ class OctoToolsSwarm: agent_name="OctoTools-Planner", system_prompt=self._get_planner_prompt(), model_name=self.model_name, - max_loops=3, + max_loops=3, saved_state_path=str(self.base_path / "planner.json"), - verbose=True, + verbose=True, ) # Verifier agent @@ -120,7 +128,7 @@ class OctoToolsSwarm: agent_name="OctoTools-Verifier", system_prompt=self._get_verifier_prompt(), model_name=self.model_name, - max_loops=1, + max_loops=1, saved_state_path=str(self.base_path / "verifier.json"), verbose=True, ) @@ -130,7 +138,7 @@ class OctoToolsSwarm: agent_name="OctoTools-Summarizer", system_prompt=self._get_summarizer_prompt(), model_name=self.model_name, - max_loops=1, + max_loops=1, saved_state_path=str(self.base_path / "summarizer.json"), verbose=True, ) @@ -337,22 +345,24 @@ class OctoToolsSwarm: try: return json.loads(json_str) except json.JSONDecodeError: - logger.warning(f"JSONDecodeError: Attempting to extract JSON from: {json_str}") + logger.warning( + f"JSONDecodeError: Attempting to extract JSON from: {json_str}" + ) try: # More robust JSON extraction with recursive descent def extract_json(s): stack = [] start = -1 for i, c in enumerate(s): - if c == '{': + if c == "{": if not stack: start = i stack.append(c) - elif c == '}': + elif c == "}": if stack: stack.pop() if not stack and start != -1: - return s[start:i+1] + return s[start : i + 1] return None extracted_json = extract_json(json_str) @@ -360,13 +370,23 @@ class OctoToolsSwarm: logger.info(f"Extracted JSON: {extracted_json}") return json.loads(extracted_json) else: - logger.error("Failed to extract JSON using recursive descent.") - return {"error": "Failed to parse JSON", "content": json_str} + logger.error( + "Failed to extract JSON using recursive descent." + ) + return { + "error": "Failed to parse JSON", + "content": json_str, + } except Exception as e: logger.exception(f"Error during JSON extraction: {e}") - return {"error": "Failed to parse JSON", "content": json_str} + return { + "error": "Failed to parse JSON", + "content": json_str, + } - def _execute_tool(self, tool_name: str, context: Dict[str, Any]) -> Dict[str, Any]: + def _execute_tool( + self, tool_name: str, context: Dict[str, Any] + ) -> Dict[str, Any]: """Executes a tool based on its name and provided context.""" if tool_name not in self.tools: return {"error": f"Tool '{tool_name}' not found."} @@ -376,20 +396,30 @@ class OctoToolsSwarm: # For Python Calculator tool, handle object counts from Object Detector if tool_name == "Python_Calculator_Tool": # Check for object detector results - object_detector_result = context.get("Object_Detector_Tool_result") - if object_detector_result and isinstance(object_detector_result, list): + object_detector_result = context.get( + "Object_Detector_Tool_result" + ) + if object_detector_result and isinstance( + object_detector_result, list + ): # Calculate the number of objects num_objects = len(object_detector_result) # If sub_goal doesn't already contain an expression, create one - if "sub_goal" in context and "Calculate the square root" in context["sub_goal"]: + if ( + "sub_goal" in context + and "Calculate the square root" + in context["sub_goal"] + ): context["expression"] = f"{num_objects}**0.5" elif "expression" not in context: # Default to square root if no expression is specified context["expression"] = f"{num_objects}**0.5" - + # Filter context: only pass expected inputs to the tool valid_inputs = { - k: v for k, v in context.items() if k in tool.metadata.get("input_types", {}) + k: v + for k, v in context.items() + if k in tool.metadata.get("input_types", {}) } result = tool.execute(**valid_inputs) return {"result": result} @@ -397,7 +427,9 @@ class OctoToolsSwarm: logger.exception(f"Error executing tool {tool_name}: {e}") return {"error": str(e)} - def _run_agent(self, agent: Agent, input_prompt: str) -> Dict[str, Any]: + def _run_agent( + self, agent: Agent, input_prompt: str + ) -> Dict[str, Any]: """Runs a swarms agent, handling output and JSON parsing.""" try: # Construct the full input, including the system prompt @@ -406,10 +438,12 @@ class OctoToolsSwarm: # Run the agent and capture the output agent_response = agent.run(full_input) - logger.info(f"DEBUG: Raw agent response: {agent_response}") + logger.info( + f"DEBUG: Raw agent response: {agent_response}" + ) # Extract the LLM's response (remove conversation history, etc.) - response_text = agent_response # Assuming direct return + response_text = agent_response # Assuming direct return # Try to parse the response as JSON parsed_response = self._safely_parse_json(response_text) @@ -417,10 +451,16 @@ class OctoToolsSwarm: return parsed_response except Exception as e: - logger.exception(f"Error running agent {agent.agent_name}: {e}") - return {"error": f"Agent {agent.agent_name} failed: {str(e)}"} + logger.exception( + f"Error running agent {agent.agent_name}: {e}" + ) + return { + "error": f"Agent {agent.agent_name} failed: {str(e)}" + } - def run(self, query: str, image: Optional[str] = None) -> Dict[str, Any]: + def run( + self, query: str, image: Optional[str] = None + ) -> Dict[str, Any]: """Execute the task through the multi-agent workflow.""" logger.info(f"Starting task: {query}") @@ -430,7 +470,9 @@ class OctoToolsSwarm: f"Analyze the following query and determine the necessary skills and" f" relevant tools: {query}" ) - query_analysis = self._run_agent(self.planner, planner_input) + query_analysis = self._run_agent( + self.planner, planner_input + ) if "error" in query_analysis: return { @@ -440,20 +482,27 @@ class OctoToolsSwarm: } self.memory.append( - {"step": 0, "component": "Query Analyzer", "result": query_analysis} + { + "step": 0, + "component": "Query Analyzer", + "result": query_analysis, + } ) self.conversation.add( - role=self.planner.agent_name, content=json.dumps(query_analysis) + role=self.planner.agent_name, + content=json.dumps(query_analysis), ) # Initialize context with the query and image (if provided) context = {"query": query} if image: context["image"] = image - + # Add available tools to context if "relevant_tools" in query_analysis: - context["available_tools"] = query_analysis["relevant_tools"] + context["available_tools"] = query_analysis[ + "relevant_tools" + ] else: # If no relevant tools specified, make all tools available context["available_tools"] = list(self.tools.keys()) @@ -462,48 +511,75 @@ class OctoToolsSwarm: # Step 2: Iterative Action-Execution Loop while step_count <= self.max_iterations: - logger.info(f"Starting iteration {step_count} of {self.max_iterations}") - + logger.info( + f"Starting iteration {step_count} of {self.max_iterations}" + ) + # Step 2a: Action Prediction (Low-Level Planning) action_planner_input = ( f"Current Context: {json.dumps(context)}\nAvailable Tools:" f" {', '.join(context.get('available_tools', list(self.tools.keys())))}\nPlan the" " next step." ) - action = self._run_agent(self.planner, action_planner_input) + action = self._run_agent( + self.planner, action_planner_input + ) if "error" in action: - logger.error(f"Error in action prediction: {action['error']}") + logger.error( + f"Error in action prediction: {action['error']}" + ) return { "error": f"Planner action prediction failed: {action['error']}", "trajectory": self.memory, - "conversation": self.conversation.return_history_as_string() + "conversation": self.conversation.return_history_as_string(), } self.memory.append( - {"step": step_count, "component": "Action Predictor", "result": action} + { + "step": step_count, + "component": "Action Predictor", + "result": action, + } + ) + self.conversation.add( + role=self.planner.agent_name, + content=json.dumps(action), ) - self.conversation.add(role=self.planner.agent_name, content=json.dumps(action)) # Input Validation for Action (Relaxed) - if not isinstance(action, dict) or "tool_name" not in action or "sub_goal" not in action: + if ( + not isinstance(action, dict) + or "tool_name" not in action + or "sub_goal" not in action + ): error_msg = ( "Action prediction did not return required fields (tool_name," " sub_goal) or was not a dictionary." ) logger.error(error_msg) self.memory.append( - {"step": step_count, "component": "Error", "result": error_msg} + { + "step": step_count, + "component": "Error", + "result": error_msg, + } ) break # Step 2b: Execute Tool tool_execution_context = { **context, - **action.get("context", {}), # Add any additional context - "sub_goal": action["sub_goal"], # Pass sub_goal to tool + **action.get( + "context", {} + ), # Add any additional context + "sub_goal": action[ + "sub_goal" + ], # Pass sub_goal to tool } - - tool_result = self._execute_tool(action["tool_name"], tool_execution_context) - + + tool_result = self._execute_tool( + action["tool_name"], tool_execution_context + ) + self.memory.append( { "step": step_count, @@ -514,16 +590,22 @@ class OctoToolsSwarm: # Step 2c: Context Update - Store result with a descriptive key if "result" in tool_result: - context[f"{action['tool_name']}_result"] = tool_result["result"] + context[f"{action['tool_name']}_result"] = ( + tool_result["result"] + ) if "error" in tool_result: - context[f"{action['tool_name']}_error"] = tool_result["error"] + context[f"{action['tool_name']}_error"] = ( + tool_result["error"] + ) # Step 2d: Context Verification verifier_input = ( f"Current Context: {json.dumps(context)}\nMemory:" f" {json.dumps(self.memory)}\nQuery: {query}" ) - verification = self._run_agent(self.verifier, verifier_input) + verification = self._run_agent( + self.verifier, verifier_input + ) if "error" in verification: return { "error": f"Verifier failed: {verification['error']}", @@ -538,22 +620,31 @@ class OctoToolsSwarm: "result": verification, } ) - self.conversation.add(role=self.verifier.agent_name, content=json.dumps(verification)) + self.conversation.add( + role=self.verifier.agent_name, + content=json.dumps(verification), + ) # Check for stop signal from Verifier if verification.get("stop_signal") is True: - logger.info("Received stop signal from verifier. Stopping iterations.") + logger.info( + "Received stop signal from verifier. Stopping iterations." + ) break # Safety mechanism - if we've executed the same tool multiple times same_tool_count = sum( - 1 for m in self.memory - if m.get("component") == "Action Predictor" - and m.get("result", {}).get("tool_name") == action.get("tool_name") + 1 + for m in self.memory + if m.get("component") == "Action Predictor" + and m.get("result", {}).get("tool_name") + == action.get("tool_name") ) - + if same_tool_count > 3: - logger.warning(f"Tool {action.get('tool_name')} used more than 3 times. Forcing stop.") + logger.warning( + f"Tool {action.get('tool_name')} used more than 3 times. Forcing stop." + ) break step_count += 1 @@ -561,23 +652,32 @@ class OctoToolsSwarm: # Step 3: Solution Summarization summarizer_input = f"Complete Trajectory: {json.dumps(self.memory)}\nOriginal Query: {query}" - summarization = self._run_agent(self.summarizer, summarizer_input) + summarization = self._run_agent( + self.summarizer, summarizer_input + ) if "error" in summarization: return { "error": f"Summarizer failed: {summarization['error']}", "trajectory": self.memory, - "conversation": self.conversation.return_history_as_string() + "conversation": self.conversation.return_history_as_string(), } - self.conversation.add(role=self.summarizer.agent_name, content=json.dumps(summarization)) + self.conversation.add( + role=self.summarizer.agent_name, + content=json.dumps(summarization), + ) return { - "final_answer": summarization.get("final_answer", "No answer found."), + "final_answer": summarization.get( + "final_answer", "No answer found." + ), "trajectory": self.memory, "conversation": self.conversation.return_history_as_string(), } except Exception as e: - logger.exception(f"Unexpected error in run method: {e}") # More detailed + logger.exception( + f"Unexpected error in run method: {e}" + ) # More detailed return { "error": str(e), "trajectory": self.memory, @@ -590,7 +690,9 @@ class OctoToolsSwarm: try: agent.save_state() except Exception as e: - logger.error(f"Error saving state for {agent.agent_name}: {str(e)}") + logger.error( + f"Error saving state for {agent.agent_name}: {str(e)}" + ) def load_state(self) -> None: """Load the saved state of all agents.""" @@ -598,24 +700,39 @@ class OctoToolsSwarm: try: agent.load_state() except Exception as e: - logger.error(f"Error loading state for {agent.agent_name}: {str(e)}") + logger.error( + f"Error loading state for {agent.agent_name}: {str(e)}" + ) # --- Example Usage --- # Define dummy tool functions (replace with actual implementations) -def image_captioner_execute(image: str, prompt: str = "Describe the image", **kwargs) -> str: +def image_captioner_execute( + image: str, prompt: str = "Describe the image", **kwargs +) -> str: """Dummy image captioner.""" - print(f"image_captioner_execute called with image: {image}, prompt: {prompt}") + print( + f"image_captioner_execute called with image: {image}, prompt: {prompt}" + ) return f"Caption for {image}: A descriptive caption (dummy)." # Simplified -def object_detector_execute(image: str, labels: List[str] = [], **kwargs) -> List[str]: +def object_detector_execute( + image: str, labels: List[str] = [], **kwargs +) -> List[str]: """Dummy object detector, handles missing labels gracefully.""" - print(f"object_detector_execute called with image: {image}, labels: {labels}") + print( + f"object_detector_execute called with image: {image}, labels: {labels}" + ) if not labels: - return ["object1", "object2", "object3", "object4"] # Return default objects if no labels + return [ + "object1", + "object2", + "object3", + "object4", + ] # Return default objects if no labels return [f"Detected {label}" for label in labels] # Simplified @@ -631,7 +748,9 @@ def python_calculator_execute(expression: str, **kwargs) -> str: try: # Safely evaluate only simple expressions involving numbers and basic operations if re.match(r"^[0-9+\-*/().\s]+$", expression): - result = eval(expression, {"__builtins__": {}, "math": math}) + result = eval( + expression, {"__builtins__": {}, "math": math} + ) return f"Result of {expression} is {result}" else: return "Error: Invalid expression for calculator." @@ -670,7 +789,7 @@ def get_default_tools() -> List[Tool]: name="Web_Search_Tool", description="Performs a web search.", metadata={ - "input_types": {"query": "str"}, + "input_types": {"query": "str"}, "output_type": "str", "limitations": "May not find specific or niche information.", "best_practices": "Use specific and descriptive keywords for better results.", @@ -682,44 +801,44 @@ def get_default_tools() -> List[Tool]: name="Python_Calculator_Tool", description="Evaluates a Python expression.", metadata={ - "input_types": {"expression": "str"}, + "input_types": {"expression": "str"}, "output_type": "str", "limitations": "Cannot handle complex mathematical functions or libraries.", "best_practices": "Use for basic arithmetic and simple calculations.", }, execute_func=python_calculator_execute, ) - + return [image_captioner, object_detector, web_search, calculator] # Only execute the example when this script is run directly # if __name__ == "__main__": # print("Running OctoToolsSwarm example...") - + # # Create an OctoToolsSwarm agent with default tools # tools = get_default_tools() # agent = OctoToolsSwarm(tools=tools) # # Example query # query = "What is the square root of the number of objects in this image?" - + # # Create a dummy image file for testing if it doesn't exist # image_path = "example.png" # if not os.path.exists(image_path): # with open(image_path, "w") as f: # f.write("Dummy image content") # print(f"Created dummy image file: {image_path}") - + # # Run the agent # result = agent.run(query, image=image_path) # # Display results # print("\n=== FINAL ANSWER ===") # print(result["final_answer"]) - + # print("\n=== TRAJECTORY SUMMARY ===") # for step in result["trajectory"]: # print(f"Step {step.get('step', 'N/A')}: {step.get('component', 'Unknown')}") - -# print("\nOctoToolsSwarm example completed.") \ No newline at end of file + +# print("\nOctoToolsSwarm example completed.") diff --git a/swarms/structs/swarms_api.py b/swarms/structs/swarms_api.py index 2024c9ae..c402714d 100644 --- a/swarms/structs/swarms_api.py +++ b/swarms/structs/swarms_api.py @@ -26,7 +26,6 @@ class AgentInput(BaseModel): system_prompt: Optional[str] = Field( None, description="The initial prompt or instructions given to the agent, up to 500 characters.", - max_length=500, ) model_name: Optional[str] = Field( "gpt-4o", diff --git a/swarms_api_examples/swarms_api.py b/swarms_api_examples/swarms_api.py deleted file mode 100644 index 3278ee65..00000000 --- a/swarms_api_examples/swarms_api.py +++ /dev/null @@ -1,910 +0,0 @@ - -import os -from collections import defaultdict -from datetime import datetime -from decimal import Decimal -from functools import lru_cache -from threading import Thread -from time import sleep, time -from typing import Any, Dict, List, Optional, Union - -import pytz -import supabase -from dotenv import load_dotenv -from fastapi import ( - Depends, - FastAPI, - Header, - HTTPException, - Request, - status, -) -from fastapi.middleware.cors import CORSMiddleware -from loguru import logger -from pydantic import BaseModel, Field -from swarms import Agent, SwarmRouter, SwarmType -from swarms.utils.litellm_tokenizer import count_tokens -import asyncio - -load_dotenv() - -# Define rate limit parameters -RATE_LIMIT = 100 # Max requests -TIME_WINDOW = 60 # Time window in seconds - -# In-memory store for tracking requests -request_counts = defaultdict(lambda: {"count": 0, "start_time": time()}) - -# In-memory store for scheduled jobs -scheduled_jobs: Dict[str, Dict] = {} - - -def rate_limiter(request: Request): - client_ip = request.client.host - current_time = time() - client_data = request_counts[client_ip] - - # Reset count if time window has passed - if current_time - client_data["start_time"] > TIME_WINDOW: - client_data["count"] = 0 - client_data["start_time"] = current_time - - # Increment request count - client_data["count"] += 1 - - # Check if rate limit is exceeded - if client_data["count"] > RATE_LIMIT: - raise HTTPException( - status_code=429, detail="Rate limit exceeded. Please try again later." - ) - - -class AgentSpec(BaseModel): - agent_name: Optional[str] = Field(None, description="Agent Name", max_length=100) - description: Optional[str] = Field(None, description="Description", max_length=500) - system_prompt: Optional[str] = Field( - None, description="System Prompt", max_length=500 - ) - model_name: Optional[str] = Field( - "gpt-4o", description="Model Name", max_length=500 - ) - auto_generate_prompt: Optional[bool] = Field( - False, description="Auto Generate Prompt" - ) - max_tokens: Optional[int] = Field(None, description="Max Tokens") - temperature: Optional[float] = Field(0.5, description="Temperature") - role: Optional[str] = Field("worker", description="Role") - max_loops: Optional[int] = Field(1, description="Max Loops") - - -# class ExternalAgent(BaseModel): -# base_url: str = Field(..., description="Base URL") -# parameters: Dict[str, Any] = Field(..., description="Parameters") -# headers: Dict[str, Any] = Field(..., description="Headers") - - -class ScheduleSpec(BaseModel): - scheduled_time: datetime = Field(..., description="When to execute the swarm (UTC)") - timezone: Optional[str] = Field( - "UTC", description="Timezone for the scheduled time" - ) - - -class SwarmSpec(BaseModel): - name: Optional[str] = Field(None, description="Swarm Name", max_length=100) - description: Optional[str] = Field(None, description="Description") - agents: Optional[Union[List[AgentSpec], Any]] = Field(None, description="Agents") - max_loops: Optional[int] = Field(None, description="Max Loops") - swarm_type: Optional[SwarmType] = Field(None, description="Swarm Type") - rearrange_flow: Optional[str] = Field(None, description="Flow") - task: Optional[str] = Field(None, description="Task") - img: Optional[str] = Field(None, description="Img") - return_history: Optional[bool] = Field(True, description="Return History") - rules: Optional[str] = Field(None, description="Rules") - schedule: Optional[ScheduleSpec] = Field(None, description="Scheduling information") - - -class ScheduledJob(Thread): - def __init__( - self, job_id: str, scheduled_time: datetime, swarm: SwarmSpec, api_key: str - ): - super().__init__() - self.job_id = job_id - self.scheduled_time = scheduled_time - self.swarm = swarm - self.api_key = api_key - self.daemon = True # Allow the thread to be terminated when main program exits - self.cancelled = False - - def run(self): - while not self.cancelled: - now = datetime.now(pytz.UTC) - if now >= self.scheduled_time: - try: - # Execute the swarm - asyncio.run(run_swarm_completion(self.swarm, self.api_key)) - except Exception as e: - logger.error( - f"Error executing scheduled swarm {self.job_id}: {str(e)}" - ) - finally: - # Remove the job from scheduled_jobs after execution - scheduled_jobs.pop(self.job_id, None) - break - sleep(1) # Check every second - - -def get_supabase_client(): - supabase_url = os.getenv("SUPABASE_URL") - supabase_key = os.getenv("SUPABASE_KEY") - return supabase.create_client(supabase_url, supabase_key) - - -@lru_cache(maxsize=1000) -def check_api_key(api_key: str) -> bool: - supabase_client = get_supabase_client() - response = ( - supabase_client.table("swarms_cloud_api_keys") - .select("*") - .eq("key", api_key) - .execute() - ) - return bool(response.data) - - -# class ExternalAgent: -# def __init__(self, base_url: str, parameters: Dict[str, Any], headers: Dict[str, Any]): -# self.base_url = base_url -# self.parameters = parameters -# self.headers = headers - -# def run(self, task: str) -> Dict[str, Any]: -# response = requests.post(self.base_url, json=self.parameters, headers=self.headers) -# return response.json() - - -@lru_cache(maxsize=1000) -def get_user_id_from_api_key(api_key: str) -> str: - """ - Maps an API key to its associated user ID. - - Args: - api_key (str): The API key to look up - - Returns: - str: The user ID associated with the API key - - Raises: - ValueError: If the API key is invalid or not found - """ - supabase_client = get_supabase_client() - response = ( - supabase_client.table("swarms_cloud_api_keys") - .select("user_id") - .eq("key", api_key) - .execute() - ) - if not response.data: - raise ValueError("Invalid API key") - return response.data[0]["user_id"] - - -def verify_api_key(x_api_key: str = Header(...)) -> None: - """ - Dependency to verify the API key. - """ - if not check_api_key(x_api_key): - raise HTTPException(status_code=403, detail="Invalid API Key") - - -async def get_api_key_logs(api_key: str) -> List[Dict[str, Any]]: - """ - Retrieve all API request logs for a specific API key. - - Args: - api_key: The API key to query logs for - - Returns: - List[Dict[str, Any]]: List of log entries for the API key - """ - try: - supabase_client = get_supabase_client() - - # Query swarms_api_logs table for entries matching the API key - response = ( - supabase_client.table("swarms_api_logs") - .select("*") - .eq("api_key", api_key) - .execute() - ) - return response.data - - except Exception as e: - logger.error(f"Error retrieving API logs: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve API logs: {str(e)}", - ) - - -def create_swarm(swarm_spec: SwarmSpec) -> SwarmRouter: - try: - # Validate swarm_spec - if not swarm_spec.agents: - raise ValueError("Swarm specification must include at least one agent.") - - agents = [] - for agent_spec in swarm_spec.agents: - try: - # Handle both dict and AgentSpec objects - if isinstance(agent_spec, dict): - # Convert dict to AgentSpec - agent_spec = AgentSpec(**agent_spec) - - # Validate agent_spec fields - if not agent_spec.agent_name: - raise ValueError("Agent name is required.") - if not agent_spec.model_name: - raise ValueError("Model name is required.") - - # Create the agent - agent = Agent( - agent_name=agent_spec.agent_name, - description=agent_spec.description, - system_prompt=agent_spec.system_prompt, - model_name=agent_spec.model_name, - auto_generate_prompt=agent_spec.auto_generate_prompt, - max_tokens=agent_spec.max_tokens, - temperature=agent_spec.temperature, - role=agent_spec.role, - max_loops=agent_spec.max_loops, - ) - agents.append(agent) - logger.info( - "Successfully created agent: {}", - agent_spec.agent_name, - ) - except ValueError as ve: - logger.error( - "Validation error for agent {}: {}", - getattr(agent_spec, 'agent_name', 'unknown'), - str(ve), - ) - raise - except Exception as agent_error: - logger.error( - "Error creating agent {}: {}", - getattr(agent_spec, 'agent_name', 'unknown'), - str(agent_error), - ) - raise - - if not agents: - raise ValueError( - "No valid agents could be created from the swarm specification." - ) - - # Create and configure the swarm - swarm = SwarmRouter( - name=swarm_spec.name, - description=swarm_spec.description, - agents=agents, - max_loops=swarm_spec.max_loops, - swarm_type=swarm_spec.swarm_type, - output_type="dict", - return_entire_history=False, - rules=swarm_spec.rules, - rearrange_flow=swarm_spec.rearrange_flow, - ) - - # Run the swarm task - output = swarm.run(task=swarm_spec.task) - return output - except Exception as e: - logger.error("Error creating swarm: {}", str(e)) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create swarm: {str(e)}", - ) - - -# Add this function after your get_supabase_client() function -async def log_api_request(api_key: str, data: Dict[str, Any]) -> None: - """ - Log API request data to Supabase swarms_api_logs table. - - Args: - api_key: The API key used for the request - data: Dictionary containing request data to log - """ - try: - supabase_client = get_supabase_client() - - # Create log entry - log_entry = { - "api_key": api_key, - "data": data, - } - - # Insert into swarms_api_logs table - response = supabase_client.table("swarms_api_logs").insert(log_entry).execute() - - if not response.data: - logger.error("Failed to log API request") - - except Exception as e: - logger.error(f"Error logging API request: {str(e)}") - - -async def run_swarm_completion( - swarm: SwarmSpec, x_api_key: str = None -) -> Dict[str, Any]: - """ - Run a swarm with the specified task. - """ - try: - swarm_name = swarm.name - - agents = swarm.agents - - await log_api_request(x_api_key, swarm.model_dump()) - - # Log start of swarm execution - logger.info(f"Starting swarm {swarm_name} with {len(agents)} agents") - start_time = time() - - # Create and run the swarm - logger.debug(f"Creating swarm object for {swarm_name}") - result = create_swarm(swarm) - logger.debug(f"Running swarm task: {swarm.task}") - - # Calculate execution time - execution_time = time() - start_time - logger.info( - f"Swarm {swarm_name} executed in {round(execution_time, 2)} seconds" - ) - - # Calculate costs - logger.debug(f"Calculating costs for swarm {swarm_name}") - cost_info = calculate_swarm_cost( - agents=agents, - input_text=swarm.task, - agent_outputs=result, - execution_time=execution_time, - ) - logger.info(f"Cost calculation completed for swarm {swarm_name}: {cost_info}") - - # Deduct credits based on calculated cost - logger.debug( - f"Deducting credits for swarm {swarm_name} with cost {cost_info['total_cost']}" - ) - - deduct_credits( - x_api_key, - cost_info["total_cost"], - f"swarm_execution_{swarm_name}", - ) - - # Format the response - response = { - "status": "success", - "swarm_name": swarm_name, - "description": swarm.description, - "swarm_type": swarm.swarm_type, - "task": swarm.task, - "output": result, - "metadata": { - "max_loops": swarm.max_loops, - "num_agents": len(agents), - "execution_time_seconds": round(execution_time, 2), - "completion_time": time(), - "billing_info": cost_info, - }, - } - logger.info(response) - await log_api_request(x_api_key, response) - - return response - - except HTTPException as http_exc: - logger.error("HTTPException occurred: {}", http_exc.detail) - raise - except Exception as e: - logger.error("Error running swarm {}: {}", swarm_name, str(e)) - logger.exception(e) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to run swarm: {str(e)}", - ) - - -def deduct_credits(api_key: str, amount: float, product_name: str) -> None: - """ - Deducts the specified amount of credits for the user identified by api_key, - preferring to use free_credit before using regular credit, and logs the transaction. - """ - supabase_client = get_supabase_client() - user_id = get_user_id_from_api_key(api_key) - - # 1. Retrieve the user's credit record - response = ( - supabase_client.table("swarms_cloud_users_credits") - .select("*") - .eq("user_id", user_id) - .execute() - ) - if not response.data: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User credits record not found.", - ) - - record = response.data[0] - # Use Decimal for precise arithmetic - available_credit = Decimal(record["credit"]) - free_credit = Decimal(record.get("free_credit", "0")) - deduction = Decimal(str(amount)) - - print( - f"Available credit: {available_credit}, Free credit: {free_credit}, Deduction: {deduction}" - ) - - # 2. Verify sufficient total credits are available - if (available_credit + free_credit) < deduction: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail="Insufficient credits.", - ) - - # 3. Log the transaction - log_response = ( - supabase_client.table("swarms_cloud_services") - .insert( - { - "user_id": user_id, - "api_key": api_key, - "charge_credit": int( - deduction - ), # Assuming credits are stored as integers - "product_name": product_name, - } - ) - .execute() - ) - if not log_response.data: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to log the credit transaction.", - ) - - # 4. Deduct credits: use free_credit first, then deduct the remainder from available_credit - if free_credit >= deduction: - free_credit -= deduction - else: - remainder = deduction - free_credit - free_credit = Decimal("0") - available_credit -= remainder - - update_response = ( - supabase_client.table("swarms_cloud_users_credits") - .update( - { - "credit": str(available_credit), - "free_credit": str(free_credit), - } - ) - .eq("user_id", user_id) - .execute() - ) - if not update_response.data: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to update credits.", - ) - - -def calculate_swarm_cost( - agents: List[Agent], - input_text: str, - execution_time: float, - agent_outputs: Union[List[Dict[str, str]], str] = None, # Update agent_outputs type -) -> Dict[str, Any]: - """ - Calculate the cost of running a swarm based on agents, tokens, and execution time. - Includes system prompts, agent memory, and scaled output costs. - - Args: - agents: List of agents used in the swarm - input_text: The input task/prompt text - execution_time: Time taken to execute in seconds - agent_outputs: List of output texts from each agent or a list of dictionaries - - Returns: - Dict containing cost breakdown and total cost - """ - # Base costs per unit (these could be moved to environment variables) - COST_PER_AGENT = 0.01 # Base cost per agent - COST_PER_1M_INPUT_TOKENS = 2.00 # Cost per 1M input tokens - COST_PER_1M_OUTPUT_TOKENS = 6.00 # Cost per 1M output tokens - - # Get current time in California timezone - california_tz = pytz.timezone("America/Los_Angeles") - current_time = datetime.now(california_tz) - is_night_time = current_time.hour >= 20 or current_time.hour < 6 # 8 PM to 6 AM - - try: - # Calculate input tokens for task - task_tokens = count_tokens(input_text) - - # Calculate total input tokens including system prompts and memory for each agent - total_input_tokens = 0 - total_output_tokens = 0 - per_agent_tokens = {} - - for i, agent in enumerate(agents): - agent_input_tokens = task_tokens # Base task tokens - - # Add system prompt tokens if present - if agent.system_prompt: - agent_input_tokens += count_tokens(agent.system_prompt) - - # Add memory tokens if available - try: - memory = agent.short_memory.return_history_as_string() - if memory: - memory_tokens = count_tokens(str(memory)) - agent_input_tokens += memory_tokens - except Exception as e: - logger.warning( - f"Could not get memory for agent {agent.agent_name}: {str(e)}" - ) - - # Calculate actual output tokens if available, otherwise estimate - if agent_outputs: - if isinstance(agent_outputs, list): - # Sum tokens for each dictionary's content - agent_output_tokens = sum( - count_tokens(message["content"]) for message in agent_outputs - ) - elif isinstance(agent_outputs, str): - agent_output_tokens = count_tokens(agent_outputs) - else: - agent_output_tokens = int( - agent_input_tokens * 2.5 - ) # Estimated output tokens - else: - agent_output_tokens = int( - agent_input_tokens * 2.5 - ) # Estimated output tokens - - # Store per-agent token counts - per_agent_tokens[agent.agent_name] = { - "input_tokens": agent_input_tokens, - "output_tokens": agent_output_tokens, - "total_tokens": agent_input_tokens + agent_output_tokens, - } - - # Add to totals - total_input_tokens += agent_input_tokens - total_output_tokens += agent_output_tokens - - # Calculate costs (convert to millions of tokens) - agent_cost = len(agents) * COST_PER_AGENT - input_token_cost = ( - (total_input_tokens / 1_000_000) * COST_PER_1M_INPUT_TOKENS * len(agents) - ) - output_token_cost = ( - (total_output_tokens / 1_000_000) * COST_PER_1M_OUTPUT_TOKENS * len(agents) - ) - - # Apply discount during California night time hours - if is_night_time: - input_token_cost *= 0.25 # 75% discount - output_token_cost *= 0.25 # 75% discount - - # Calculate total cost - total_cost = agent_cost + input_token_cost + output_token_cost - - output = { - "cost_breakdown": { - "agent_cost": round(agent_cost, 6), - "input_token_cost": round(input_token_cost, 6), - "output_token_cost": round(output_token_cost, 6), - "token_counts": { - "total_input_tokens": total_input_tokens, - "total_output_tokens": total_output_tokens, - "total_tokens": total_input_tokens + total_output_tokens, - "per_agent": per_agent_tokens, - }, - "num_agents": len(agents), - "execution_time_seconds": round(execution_time, 2), - }, - "total_cost": round(total_cost, 6), - } - - return output - - except Exception as e: - logger.error(f"Error calculating swarm cost: {str(e)}") - raise ValueError(f"Failed to calculate swarm cost: {str(e)}") - - -# --- FastAPI Application Setup --- - -app = FastAPI( - title="Swarm Agent API", - description="API for managing and executing Python agents in the cloud without Docker/Kubernetes.", - version="1.0.0", - debug=True, -) - -# Enable CORS (adjust origins as needed) -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # In production, restrict this to specific domains - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -@app.get("/", dependencies=[Depends(rate_limiter)]) -def root(): - return { - "status": "Welcome to the Swarm API. Check out the docs at https://docs.swarms.world" - } - - -@app.get("/health", dependencies=[Depends(rate_limiter)]) -def health(): - return {"status": "ok"} - - -@app.post( - "/v1/swarm/completions", - dependencies=[ - Depends(verify_api_key), - Depends(rate_limiter), - ], -) -async def run_swarm(swarm: SwarmSpec, x_api_key=Header(...)) -> Dict[str, Any]: - """ - Run a swarm with the specified task. - """ - return await run_swarm_completion(swarm, x_api_key) - - -@app.post( - "/v1/swarm/batch/completions", - dependencies=[ - Depends(verify_api_key), - Depends(rate_limiter), - ], -) -async def run_batch_completions( - swarms: List[SwarmSpec], x_api_key=Header(...) -) -> List[Dict[str, Any]]: - """ - Run a batch of swarms with the specified tasks. - """ - results = [] - for swarm in swarms: - try: - # Call the existing run_swarm function for each swarm - result = await run_swarm_completion(swarm, x_api_key) - results.append(result) - except HTTPException as http_exc: - logger.error("HTTPException occurred: {}", http_exc.detail) - results.append( - { - "status": "error", - "swarm_name": swarm.name, - "detail": http_exc.detail, - } - ) - except Exception as e: - logger.error("Error running swarm {}: {}", swarm.name, str(e)) - logger.exception(e) - results.append( - { - "status": "error", - "swarm_name": swarm.name, - "detail": f"Failed to run swarm: {str(e)}", - } - ) - - return results - - -# Add this new endpoint -@app.get( - "/v1/swarm/logs", - dependencies=[ - Depends(verify_api_key), - Depends(rate_limiter), - ], -) -async def get_logs(x_api_key: str = Header(...)) -> Dict[str, Any]: - """ - Get all API request logs for the provided API key. - """ - try: - logs = await get_api_key_logs(x_api_key) - return {"status": "success", "count": len(logs), "logs": logs} - except Exception as e: - logger.error(f"Error in get_logs endpoint: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) - - -# @app.post("/v1/swarm/cost-prediction") -# async def cost_prediction(swarm: SwarmSpec) -> Dict[str, Any]: -# """ -# Predict the cost of running a swarm. -# """ -# return {"status": "success", "cost": calculate_swarm_cost(swarm)}) - - -@app.post( - "/v1/swarm/schedule", - dependencies=[ - Depends(verify_api_key), - Depends(rate_limiter), - ], -) -async def schedule_swarm( - swarm: SwarmSpec, x_api_key: str = Header(...) -) -> Dict[str, Any]: - """ - Schedule a swarm to run at a specific time. - """ - if not swarm.schedule: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Schedule information is required", - ) - - try: - # Generate a unique job ID - job_id = f"swarm_{swarm.name}_{int(time())}" - - # Create and start the scheduled job - job = ScheduledJob( - job_id=job_id, - scheduled_time=swarm.schedule.scheduled_time, - swarm=swarm, - api_key=x_api_key, - ) - job.start() - - # Store the job information - scheduled_jobs[job_id] = { - "job": job, - "swarm_name": swarm.name, - "scheduled_time": swarm.schedule.scheduled_time, - "timezone": swarm.schedule.timezone, - } - - # Log the scheduling - await log_api_request( - x_api_key, - { - "action": "schedule_swarm", - "swarm_name": swarm.name, - "scheduled_time": swarm.schedule.scheduled_time.isoformat(), - "job_id": job_id, - }, - ) - - return { - "status": "success", - "message": "Swarm scheduled successfully", - "job_id": job_id, - "scheduled_time": swarm.schedule.scheduled_time, - "timezone": swarm.schedule.timezone, - } - - except Exception as e: - logger.error(f"Error scheduling swarm: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to schedule swarm: {str(e)}", - ) - - -@app.get( - "/v1/swarm/schedule", - dependencies=[ - Depends(verify_api_key), - Depends(rate_limiter), - ], -) -async def get_scheduled_jobs(x_api_key: str = Header(...)) -> Dict[str, Any]: - """ - Get all scheduled swarm jobs. - """ - try: - jobs_list = [] - current_time = datetime.now(pytz.UTC) - - # Clean up completed jobs - completed_jobs = [ - job_id - for job_id, job_info in scheduled_jobs.items() - if current_time >= job_info["scheduled_time"] - ] - for job_id in completed_jobs: - scheduled_jobs.pop(job_id, None) - - # Get active jobs - for job_id, job_info in scheduled_jobs.items(): - jobs_list.append( - { - "job_id": job_id, - "swarm_name": job_info["swarm_name"], - "scheduled_time": job_info["scheduled_time"].isoformat(), - "timezone": job_info["timezone"], - } - ) - - return {"status": "success", "scheduled_jobs": jobs_list} - - except Exception as e: - logger.error(f"Error retrieving scheduled jobs: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to retrieve scheduled jobs: {str(e)}", - ) - - -@app.delete( - "/v1/swarm/schedule/{job_id}", - dependencies=[ - Depends(verify_api_key), - Depends(rate_limiter), - ], -) -async def cancel_scheduled_job( - job_id: str, x_api_key: str = Header(...) -) -> Dict[str, Any]: - """ - Cancel a scheduled swarm job. - """ - try: - if job_id not in scheduled_jobs: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Scheduled job not found" - ) - - # Cancel and remove the job - job_info = scheduled_jobs[job_id] - job_info["job"].cancelled = True - scheduled_jobs.pop(job_id) - - await log_api_request( - x_api_key, {"action": "cancel_scheduled_job", "job_id": job_id} - ) - - return { - "status": "success", - "message": "Scheduled job cancelled successfully", - "job_id": job_id, - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error cancelling scheduled job: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to cancel scheduled job: {str(e)}", - ) - - -# --- Main Entrypoint --- - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8080, workers=os.cpu_count()) \ No newline at end of file