api client update

pull/791/head
Kye Gomez 2 months ago
parent 5477635441
commit 007eb5c011

@ -73,15 +73,13 @@ from swarms.structs.swarming_architectures import (
star_swarm, star_swarm,
) )
from swarms.structs.swarms_api import ( from swarms.structs.swarms_api import (
SwarmsAPIClient, AgentInput,
SwarmRequest,
SwarmAuthenticationError,
SwarmAPIError, SwarmAPIError,
SwarmAuthenticationError,
SwarmRequest,
SwarmsAPIClient,
SwarmValidationError, SwarmValidationError,
AgentInput,
) )
from swarms.structs.talk_hier import TalkHier, AgentRole, CommunicationEvent
from swarms.structs.octotools import OctoToolsSwarm, Tool, ToolType, get_default_tools
__all__ = [ __all__ = [
"Agent", "Agent",
@ -148,13 +146,6 @@ __all__ = [
"MultiAgentRouter", "MultiAgentRouter",
"MemeAgentGenerator", "MemeAgentGenerator",
"ModelRouter", "ModelRouter",
"OctoToolsSwarm",
"Tool",
"ToolType",
"get_default_tools",
"TalkHier",
"AgentRole",
"CommunicationEvent",
"SwarmsAPIClient", "SwarmsAPIClient",
"SwarmRequest", "SwarmRequest",
"SwarmAuthenticationError", "SwarmAuthenticationError",

@ -5,7 +5,6 @@ Implements the OctoTools framework using swarms.
import json import json
import logging import logging
import os
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
@ -16,6 +15,7 @@ import math # Import the math module
from dotenv import load_dotenv from dotenv import load_dotenv
from swarms import Agent from swarms import Agent
from swarms.structs.conversation import Conversation from swarms.structs.conversation import Conversation
# from exa_search import exa_search as web_search_execute # from exa_search import exa_search as web_search_execute
@ -59,7 +59,9 @@ class Tool:
try: try:
return self.execute_func(**kwargs) return self.execute_func(**kwargs)
except Exception as e: except Exception as e:
logger.error(f"Error executing tool {self.name}: {str(e)}") logger.error(
f"Error executing tool {self.name}: {str(e)}"
)
return {"error": str(e)} return {"error": str(e)}
@ -92,9 +94,15 @@ class OctoToolsSwarm:
"""Initialize the OctoToolsSwarm system.""" """Initialize the OctoToolsSwarm system."""
self.model_name = model_name self.model_name = model_name
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.base_path = Path(base_path) if base_path else Path("./octotools_states") self.base_path = (
Path(base_path)
if base_path
else Path("./octotools_states")
)
self.base_path.mkdir(exist_ok=True) self.base_path.mkdir(exist_ok=True)
self.tools = {tool.name: tool for tool in tools} # Store tools in a dictionary self.tools = {
tool.name: tool for tool in tools
} # Store tools in a dictionary
# Initialize agents # Initialize agents
self._init_agents() self._init_agents()
@ -337,22 +345,24 @@ class OctoToolsSwarm:
try: try:
return json.loads(json_str) return json.loads(json_str)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"JSONDecodeError: Attempting to extract JSON from: {json_str}") logger.warning(
f"JSONDecodeError: Attempting to extract JSON from: {json_str}"
)
try: try:
# More robust JSON extraction with recursive descent # More robust JSON extraction with recursive descent
def extract_json(s): def extract_json(s):
stack = [] stack = []
start = -1 start = -1
for i, c in enumerate(s): for i, c in enumerate(s):
if c == '{': if c == "{":
if not stack: if not stack:
start = i start = i
stack.append(c) stack.append(c)
elif c == '}': elif c == "}":
if stack: if stack:
stack.pop() stack.pop()
if not stack and start != -1: if not stack and start != -1:
return s[start:i+1] return s[start : i + 1]
return None return None
extracted_json = extract_json(json_str) extracted_json = extract_json(json_str)
@ -360,13 +370,23 @@ class OctoToolsSwarm:
logger.info(f"Extracted JSON: {extracted_json}") logger.info(f"Extracted JSON: {extracted_json}")
return json.loads(extracted_json) return json.loads(extracted_json)
else: else:
logger.error("Failed to extract JSON using recursive descent.") logger.error(
return {"error": "Failed to parse JSON", "content": json_str} "Failed to extract JSON using recursive descent."
)
return {
"error": "Failed to parse JSON",
"content": json_str,
}
except Exception as e: except Exception as e:
logger.exception(f"Error during JSON extraction: {e}") logger.exception(f"Error during JSON extraction: {e}")
return {"error": "Failed to parse JSON", "content": json_str} return {
"error": "Failed to parse JSON",
"content": json_str,
}
def _execute_tool(self, tool_name: str, context: Dict[str, Any]) -> Dict[str, Any]: def _execute_tool(
self, tool_name: str, context: Dict[str, Any]
) -> Dict[str, Any]:
"""Executes a tool based on its name and provided context.""" """Executes a tool based on its name and provided context."""
if tool_name not in self.tools: if tool_name not in self.tools:
return {"error": f"Tool '{tool_name}' not found."} return {"error": f"Tool '{tool_name}' not found."}
@ -376,12 +396,20 @@ class OctoToolsSwarm:
# For Python Calculator tool, handle object counts from Object Detector # For Python Calculator tool, handle object counts from Object Detector
if tool_name == "Python_Calculator_Tool": if tool_name == "Python_Calculator_Tool":
# Check for object detector results # Check for object detector results
object_detector_result = context.get("Object_Detector_Tool_result") object_detector_result = context.get(
if object_detector_result and isinstance(object_detector_result, list): "Object_Detector_Tool_result"
)
if object_detector_result and isinstance(
object_detector_result, list
):
# Calculate the number of objects # Calculate the number of objects
num_objects = len(object_detector_result) num_objects = len(object_detector_result)
# If sub_goal doesn't already contain an expression, create one # If sub_goal doesn't already contain an expression, create one
if "sub_goal" in context and "Calculate the square root" in context["sub_goal"]: if (
"sub_goal" in context
and "Calculate the square root"
in context["sub_goal"]
):
context["expression"] = f"{num_objects}**0.5" context["expression"] = f"{num_objects}**0.5"
elif "expression" not in context: elif "expression" not in context:
# Default to square root if no expression is specified # Default to square root if no expression is specified
@ -389,7 +417,9 @@ class OctoToolsSwarm:
# Filter context: only pass expected inputs to the tool # Filter context: only pass expected inputs to the tool
valid_inputs = { valid_inputs = {
k: v for k, v in context.items() if k in tool.metadata.get("input_types", {}) k: v
for k, v in context.items()
if k in tool.metadata.get("input_types", {})
} }
result = tool.execute(**valid_inputs) result = tool.execute(**valid_inputs)
return {"result": result} return {"result": result}
@ -397,7 +427,9 @@ class OctoToolsSwarm:
logger.exception(f"Error executing tool {tool_name}: {e}") logger.exception(f"Error executing tool {tool_name}: {e}")
return {"error": str(e)} return {"error": str(e)}
def _run_agent(self, agent: Agent, input_prompt: str) -> Dict[str, Any]: def _run_agent(
self, agent: Agent, input_prompt: str
) -> Dict[str, Any]:
"""Runs a swarms agent, handling output and JSON parsing.""" """Runs a swarms agent, handling output and JSON parsing."""
try: try:
# Construct the full input, including the system prompt # Construct the full input, including the system prompt
@ -406,10 +438,12 @@ class OctoToolsSwarm:
# Run the agent and capture the output # Run the agent and capture the output
agent_response = agent.run(full_input) agent_response = agent.run(full_input)
logger.info(f"DEBUG: Raw agent response: {agent_response}") logger.info(
f"DEBUG: Raw agent response: {agent_response}"
)
# Extract the LLM's response (remove conversation history, etc.) # Extract the LLM's response (remove conversation history, etc.)
response_text = agent_response # Assuming direct return response_text = agent_response # Assuming direct return
# Try to parse the response as JSON # Try to parse the response as JSON
parsed_response = self._safely_parse_json(response_text) parsed_response = self._safely_parse_json(response_text)
@ -417,10 +451,16 @@ class OctoToolsSwarm:
return parsed_response return parsed_response
except Exception as e: except Exception as e:
logger.exception(f"Error running agent {agent.agent_name}: {e}") logger.exception(
return {"error": f"Agent {agent.agent_name} failed: {str(e)}"} f"Error running agent {agent.agent_name}: {e}"
)
return {
"error": f"Agent {agent.agent_name} failed: {str(e)}"
}
def run(self, query: str, image: Optional[str] = None) -> Dict[str, Any]: def run(
self, query: str, image: Optional[str] = None
) -> Dict[str, Any]:
"""Execute the task through the multi-agent workflow.""" """Execute the task through the multi-agent workflow."""
logger.info(f"Starting task: {query}") logger.info(f"Starting task: {query}")
@ -430,7 +470,9 @@ class OctoToolsSwarm:
f"Analyze the following query and determine the necessary skills and" f"Analyze the following query and determine the necessary skills and"
f" relevant tools: {query}" f" relevant tools: {query}"
) )
query_analysis = self._run_agent(self.planner, planner_input) query_analysis = self._run_agent(
self.planner, planner_input
)
if "error" in query_analysis: if "error" in query_analysis:
return { return {
@ -440,10 +482,15 @@ class OctoToolsSwarm:
} }
self.memory.append( self.memory.append(
{"step": 0, "component": "Query Analyzer", "result": query_analysis} {
"step": 0,
"component": "Query Analyzer",
"result": query_analysis,
}
) )
self.conversation.add( self.conversation.add(
role=self.planner.agent_name, content=json.dumps(query_analysis) role=self.planner.agent_name,
content=json.dumps(query_analysis),
) )
# Initialize context with the query and image (if provided) # Initialize context with the query and image (if provided)
@ -453,7 +500,9 @@ class OctoToolsSwarm:
# Add available tools to context # Add available tools to context
if "relevant_tools" in query_analysis: if "relevant_tools" in query_analysis:
context["available_tools"] = query_analysis["relevant_tools"] context["available_tools"] = query_analysis[
"relevant_tools"
]
else: else:
# If no relevant tools specified, make all tools available # If no relevant tools specified, make all tools available
context["available_tools"] = list(self.tools.keys()) context["available_tools"] = list(self.tools.keys())
@ -462,7 +511,9 @@ class OctoToolsSwarm:
# Step 2: Iterative Action-Execution Loop # Step 2: Iterative Action-Execution Loop
while step_count <= self.max_iterations: while step_count <= self.max_iterations:
logger.info(f"Starting iteration {step_count} of {self.max_iterations}") logger.info(
f"Starting iteration {step_count} of {self.max_iterations}"
)
# Step 2a: Action Prediction (Low-Level Planning) # Step 2a: Action Prediction (Low-Level Planning)
action_planner_input = ( action_planner_input = (
@ -470,39 +521,64 @@ class OctoToolsSwarm:
f" {', '.join(context.get('available_tools', list(self.tools.keys())))}\nPlan the" f" {', '.join(context.get('available_tools', list(self.tools.keys())))}\nPlan the"
" next step." " next step."
) )
action = self._run_agent(self.planner, action_planner_input) action = self._run_agent(
self.planner, action_planner_input
)
if "error" in action: if "error" in action:
logger.error(f"Error in action prediction: {action['error']}") logger.error(
f"Error in action prediction: {action['error']}"
)
return { return {
"error": f"Planner action prediction failed: {action['error']}", "error": f"Planner action prediction failed: {action['error']}",
"trajectory": self.memory, "trajectory": self.memory,
"conversation": self.conversation.return_history_as_string() "conversation": self.conversation.return_history_as_string(),
} }
self.memory.append( self.memory.append(
{"step": step_count, "component": "Action Predictor", "result": action} {
"step": step_count,
"component": "Action Predictor",
"result": action,
}
)
self.conversation.add(
role=self.planner.agent_name,
content=json.dumps(action),
) )
self.conversation.add(role=self.planner.agent_name, content=json.dumps(action))
# Input Validation for Action (Relaxed) # Input Validation for Action (Relaxed)
if not isinstance(action, dict) or "tool_name" not in action or "sub_goal" not in action: if (
not isinstance(action, dict)
or "tool_name" not in action
or "sub_goal" not in action
):
error_msg = ( error_msg = (
"Action prediction did not return required fields (tool_name," "Action prediction did not return required fields (tool_name,"
" sub_goal) or was not a dictionary." " sub_goal) or was not a dictionary."
) )
logger.error(error_msg) logger.error(error_msg)
self.memory.append( self.memory.append(
{"step": step_count, "component": "Error", "result": error_msg} {
"step": step_count,
"component": "Error",
"result": error_msg,
}
) )
break break
# Step 2b: Execute Tool # Step 2b: Execute Tool
tool_execution_context = { tool_execution_context = {
**context, **context,
**action.get("context", {}), # Add any additional context **action.get(
"sub_goal": action["sub_goal"], # Pass sub_goal to tool "context", {}
), # Add any additional context
"sub_goal": action[
"sub_goal"
], # Pass sub_goal to tool
} }
tool_result = self._execute_tool(action["tool_name"], tool_execution_context) tool_result = self._execute_tool(
action["tool_name"], tool_execution_context
)
self.memory.append( self.memory.append(
{ {
@ -514,16 +590,22 @@ class OctoToolsSwarm:
# Step 2c: Context Update - Store result with a descriptive key # Step 2c: Context Update - Store result with a descriptive key
if "result" in tool_result: if "result" in tool_result:
context[f"{action['tool_name']}_result"] = tool_result["result"] context[f"{action['tool_name']}_result"] = (
tool_result["result"]
)
if "error" in tool_result: if "error" in tool_result:
context[f"{action['tool_name']}_error"] = tool_result["error"] context[f"{action['tool_name']}_error"] = (
tool_result["error"]
)
# Step 2d: Context Verification # Step 2d: Context Verification
verifier_input = ( verifier_input = (
f"Current Context: {json.dumps(context)}\nMemory:" f"Current Context: {json.dumps(context)}\nMemory:"
f" {json.dumps(self.memory)}\nQuery: {query}" f" {json.dumps(self.memory)}\nQuery: {query}"
) )
verification = self._run_agent(self.verifier, verifier_input) verification = self._run_agent(
self.verifier, verifier_input
)
if "error" in verification: if "error" in verification:
return { return {
"error": f"Verifier failed: {verification['error']}", "error": f"Verifier failed: {verification['error']}",
@ -538,22 +620,31 @@ class OctoToolsSwarm:
"result": verification, "result": verification,
} }
) )
self.conversation.add(role=self.verifier.agent_name, content=json.dumps(verification)) self.conversation.add(
role=self.verifier.agent_name,
content=json.dumps(verification),
)
# Check for stop signal from Verifier # Check for stop signal from Verifier
if verification.get("stop_signal") is True: if verification.get("stop_signal") is True:
logger.info("Received stop signal from verifier. Stopping iterations.") logger.info(
"Received stop signal from verifier. Stopping iterations."
)
break break
# Safety mechanism - if we've executed the same tool multiple times # Safety mechanism - if we've executed the same tool multiple times
same_tool_count = sum( same_tool_count = sum(
1 for m in self.memory 1
for m in self.memory
if m.get("component") == "Action Predictor" if m.get("component") == "Action Predictor"
and m.get("result", {}).get("tool_name") == action.get("tool_name") and m.get("result", {}).get("tool_name")
== action.get("tool_name")
) )
if same_tool_count > 3: if same_tool_count > 3:
logger.warning(f"Tool {action.get('tool_name')} used more than 3 times. Forcing stop.") logger.warning(
f"Tool {action.get('tool_name')} used more than 3 times. Forcing stop."
)
break break
step_count += 1 step_count += 1
@ -561,23 +652,32 @@ class OctoToolsSwarm:
# Step 3: Solution Summarization # Step 3: Solution Summarization
summarizer_input = f"Complete Trajectory: {json.dumps(self.memory)}\nOriginal Query: {query}" summarizer_input = f"Complete Trajectory: {json.dumps(self.memory)}\nOriginal Query: {query}"
summarization = self._run_agent(self.summarizer, summarizer_input) summarization = self._run_agent(
self.summarizer, summarizer_input
)
if "error" in summarization: if "error" in summarization:
return { return {
"error": f"Summarizer failed: {summarization['error']}", "error": f"Summarizer failed: {summarization['error']}",
"trajectory": self.memory, "trajectory": self.memory,
"conversation": self.conversation.return_history_as_string() "conversation": self.conversation.return_history_as_string(),
} }
self.conversation.add(role=self.summarizer.agent_name, content=json.dumps(summarization)) self.conversation.add(
role=self.summarizer.agent_name,
content=json.dumps(summarization),
)
return { return {
"final_answer": summarization.get("final_answer", "No answer found."), "final_answer": summarization.get(
"final_answer", "No answer found."
),
"trajectory": self.memory, "trajectory": self.memory,
"conversation": self.conversation.return_history_as_string(), "conversation": self.conversation.return_history_as_string(),
} }
except Exception as e: except Exception as e:
logger.exception(f"Unexpected error in run method: {e}") # More detailed logger.exception(
f"Unexpected error in run method: {e}"
) # More detailed
return { return {
"error": str(e), "error": str(e),
"trajectory": self.memory, "trajectory": self.memory,
@ -590,7 +690,9 @@ class OctoToolsSwarm:
try: try:
agent.save_state() agent.save_state()
except Exception as e: except Exception as e:
logger.error(f"Error saving state for {agent.agent_name}: {str(e)}") logger.error(
f"Error saving state for {agent.agent_name}: {str(e)}"
)
def load_state(self) -> None: def load_state(self) -> None:
"""Load the saved state of all agents.""" """Load the saved state of all agents."""
@ -598,24 +700,39 @@ class OctoToolsSwarm:
try: try:
agent.load_state() agent.load_state()
except Exception as e: except Exception as e:
logger.error(f"Error loading state for {agent.agent_name}: {str(e)}") logger.error(
f"Error loading state for {agent.agent_name}: {str(e)}"
)
# --- Example Usage --- # --- Example Usage ---
# Define dummy tool functions (replace with actual implementations) # Define dummy tool functions (replace with actual implementations)
def image_captioner_execute(image: str, prompt: str = "Describe the image", **kwargs) -> str: def image_captioner_execute(
image: str, prompt: str = "Describe the image", **kwargs
) -> str:
"""Dummy image captioner.""" """Dummy image captioner."""
print(f"image_captioner_execute called with image: {image}, prompt: {prompt}") print(
f"image_captioner_execute called with image: {image}, prompt: {prompt}"
)
return f"Caption for {image}: A descriptive caption (dummy)." # Simplified return f"Caption for {image}: A descriptive caption (dummy)." # Simplified
def object_detector_execute(image: str, labels: List[str] = [], **kwargs) -> List[str]: def object_detector_execute(
image: str, labels: List[str] = [], **kwargs
) -> List[str]:
"""Dummy object detector, handles missing labels gracefully.""" """Dummy object detector, handles missing labels gracefully."""
print(f"object_detector_execute called with image: {image}, labels: {labels}") print(
f"object_detector_execute called with image: {image}, labels: {labels}"
)
if not labels: if not labels:
return ["object1", "object2", "object3", "object4"] # Return default objects if no labels return [
"object1",
"object2",
"object3",
"object4",
] # Return default objects if no labels
return [f"Detected {label}" for label in labels] # Simplified return [f"Detected {label}" for label in labels] # Simplified
@ -631,7 +748,9 @@ def python_calculator_execute(expression: str, **kwargs) -> str:
try: try:
# Safely evaluate only simple expressions involving numbers and basic operations # Safely evaluate only simple expressions involving numbers and basic operations
if re.match(r"^[0-9+\-*/().\s]+$", expression): if re.match(r"^[0-9+\-*/().\s]+$", expression):
result = eval(expression, {"__builtins__": {}, "math": math}) result = eval(
expression, {"__builtins__": {}, "math": math}
)
return f"Result of {expression} is {result}" return f"Result of {expression} is {result}"
else: else:
return "Error: Invalid expression for calculator." return "Error: Invalid expression for calculator."

@ -26,7 +26,6 @@ class AgentInput(BaseModel):
system_prompt: Optional[str] = Field( system_prompt: Optional[str] = Field(
None, None,
description="The initial prompt or instructions given to the agent, up to 500 characters.", description="The initial prompt or instructions given to the agent, up to 500 characters.",
max_length=500,
) )
model_name: Optional[str] = Field( model_name: Optional[str] = Field(
"gpt-4o", "gpt-4o",

@ -1,910 +0,0 @@
import os
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from functools import lru_cache
from threading import Thread
from time import sleep, time
from typing import Any, Dict, List, Optional, Union
import pytz
import supabase
from dotenv import load_dotenv
from fastapi import (
Depends,
FastAPI,
Header,
HTTPException,
Request,
status,
)
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from swarms import Agent, SwarmRouter, SwarmType
from swarms.utils.litellm_tokenizer import count_tokens
import asyncio
load_dotenv()

# Define rate limit parameters
RATE_LIMIT = 100  # Max requests
TIME_WINDOW = 60  # Time window in seconds

# In-memory store for tracking requests.
# NOTE(review): entries are never pruned, so a long-running process
# accumulates one dict per distinct client IP — confirm this is acceptable.
request_counts = defaultdict(lambda: {"count": 0, "start_time": time()})

# In-memory store for scheduled jobs (job_id -> job metadata).
# Per-process only; jobs are lost on restart.
scheduled_jobs: Dict[str, Dict] = {}
def rate_limiter(request: Request):
    """Fixed-window, in-memory rate limiter keyed by client IP.

    Allows at most RATE_LIMIT requests per TIME_WINDOW seconds per client;
    raises HTTP 429 once the limit is exceeded. State lives in the
    module-level ``request_counts`` dict (per-process only).
    """
    ip = request.client.host
    now = time()
    entry = request_counts[ip]

    # A new window starts once the previous one has fully elapsed.
    if now - entry["start_time"] > TIME_WINDOW:
        entry["count"] = 0
        entry["start_time"] = now

    entry["count"] += 1

    if entry["count"] > RATE_LIMIT:
        raise HTTPException(
            status_code=429, detail="Rate limit exceeded. Please try again later."
        )
class AgentSpec(BaseModel):
    """Request schema describing one agent inside a swarm specification."""

    # Display/reference name for the agent (<= 100 chars).
    agent_name: Optional[str] = Field(None, description="Agent Name", max_length=100)
    # Free-form description of the agent's purpose.
    description: Optional[str] = Field(None, description="Description", max_length=500)
    # Initial instructions handed to the agent.
    system_prompt: Optional[str] = Field(
        None, description="System Prompt", max_length=500
    )
    # LLM backing the agent; defaults to gpt-4o.
    model_name: Optional[str] = Field(
        "gpt-4o", description="Model Name", max_length=500
    )
    # When True, the system prompt is generated automatically.
    auto_generate_prompt: Optional[bool] = Field(
        False, description="Auto Generate Prompt"
    )
    # Generation limits / sampling controls.
    max_tokens: Optional[int] = Field(None, description="Max Tokens")
    temperature: Optional[float] = Field(0.5, description="Temperature")
    # Role within the swarm (e.g. "worker").
    role: Optional[str] = Field("worker", description="Role")
    # How many reasoning loops the agent may run.
    max_loops: Optional[int] = Field(1, description="Max Loops")
# class ExternalAgent(BaseModel):
# base_url: str = Field(..., description="Base URL")
# parameters: Dict[str, Any] = Field(..., description="Parameters")
# headers: Dict[str, Any] = Field(..., description="Headers")
class ScheduleSpec(BaseModel):
    """Scheduling information for deferred swarm execution."""

    # Absolute execution time; compared against datetime.now(pytz.UTC)
    # in ScheduledJob.run, so it should be timezone-aware UTC.
    scheduled_time: datetime = Field(..., description="When to execute the swarm (UTC)")
    # Timezone label for the scheduled time (informational).
    timezone: Optional[str] = Field(
        "UTC", description="Timezone for the scheduled time"
    )
class SwarmSpec(BaseModel):
    """Top-level request schema for creating and running a swarm."""

    # Identity and documentation.
    name: Optional[str] = Field(None, description="Swarm Name", max_length=100)
    description: Optional[str] = Field(None, description="Description")
    # Agents may arrive as AgentSpec objects or raw dicts (create_swarm
    # converts dicts to AgentSpec).
    agents: Optional[Union[List[AgentSpec], Any]] = Field(None, description="Agents")
    max_loops: Optional[int] = Field(None, description="Max Loops")
    # Architecture of the swarm (SwarmType enum from swarms).
    swarm_type: Optional[SwarmType] = Field(None, description="Swarm Type")
    # Flow string used by rearrange-style swarms.
    rearrange_flow: Optional[str] = Field(None, description="Flow")
    # The task to execute and an optional image input.
    task: Optional[str] = Field(None, description="Task")
    img: Optional[str] = Field(None, description="Img")
    return_history: Optional[bool] = Field(True, description="Return History")
    rules: Optional[str] = Field(None, description="Rules")
    # If present, the swarm is executed later by a ScheduledJob.
    schedule: Optional[ScheduleSpec] = Field(None, description="Scheduling information")
class ScheduledJob(Thread):
    """Daemon thread that waits until ``scheduled_time``, then runs a swarm once.

    Polls once per second; setting ``cancelled = True`` stops the thread
    before execution. After running (success or failure) the job removes
    itself from the module-level ``scheduled_jobs`` registry.
    """

    def __init__(
        self, job_id: str, scheduled_time: datetime, swarm: SwarmSpec, api_key: str
    ):
        # job_id: key under which this job is registered in scheduled_jobs.
        # scheduled_time: when to fire; compared against datetime.now(pytz.UTC)
        #   below, so it must be timezone-aware — TODO confirm callers pass UTC.
        super().__init__()
        self.job_id = job_id
        self.scheduled_time = scheduled_time
        self.swarm = swarm
        self.api_key = api_key
        self.daemon = True  # Allow the thread to be terminated when main program exits
        self.cancelled = False

    def run(self):
        while not self.cancelled:
            now = datetime.now(pytz.UTC)
            if now >= self.scheduled_time:
                try:
                    # Execute the swarm
                    asyncio.run(run_swarm_completion(self.swarm, self.api_key))
                except Exception as e:
                    logger.error(
                        f"Error executing scheduled swarm {self.job_id}: {str(e)}"
                    )
                finally:
                    # Remove the job from scheduled_jobs after execution
                    scheduled_jobs.pop(self.job_id, None)
                break
            sleep(1)  # Check every second
def get_supabase_client():
    """Create a Supabase client from the SUPABASE_URL / SUPABASE_KEY env vars.

    Returns:
        A ``supabase`` client instance.

    Raises:
        RuntimeError: if either environment variable is unset, instead of
            passing ``None`` into ``supabase.create_client`` and failing
            later with a less descriptive error.
    """
    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_KEY")
    if not supabase_url or not supabase_key:
        raise RuntimeError(
            "SUPABASE_URL and SUPABASE_KEY environment variables must be set."
        )
    return supabase.create_client(supabase_url, supabase_key)
@lru_cache(maxsize=1000)
def check_api_key(api_key: str) -> bool:
    """Return True if *api_key* exists in the swarms_cloud_api_keys table.

    NOTE(review): results are memoized via lru_cache, so a key revoked in
    the database may continue to validate until its cache entry is evicted
    — confirm this matches the intended revocation model.
    """
    # Reject empty/None keys without a round-trip to the database.
    if not api_key:
        return False
    supabase_client = get_supabase_client()
    response = (
        supabase_client.table("swarms_cloud_api_keys")
        .select("*")
        .eq("key", api_key)
        .execute()
    )
    return bool(response.data)
# class ExternalAgent:
# def __init__(self, base_url: str, parameters: Dict[str, Any], headers: Dict[str, Any]):
# self.base_url = base_url
# self.parameters = parameters
# self.headers = headers
# def run(self, task: str) -> Dict[str, Any]:
# response = requests.post(self.base_url, json=self.parameters, headers=self.headers)
# return response.json()
@lru_cache(maxsize=1000)
def get_user_id_from_api_key(api_key: str) -> str:
    """
    Resolve an API key to the owning user's ID.

    Args:
        api_key (str): The API key to look up

    Returns:
        str: The user ID associated with the API key

    Raises:
        ValueError: If the API key is invalid or not found
    """
    client = get_supabase_client()
    rows = (
        client.table("swarms_cloud_api_keys")
        .select("user_id")
        .eq("key", api_key)
        .execute()
    )
    # Success path first: a matching row carries the user_id.
    if rows.data:
        return rows.data[0]["user_id"]
    raise ValueError("Invalid API key")
def verify_api_key(x_api_key: str = Header(...)) -> None:
    """
    FastAPI dependency that rejects requests lacking a valid API key.
    """
    is_valid = check_api_key(x_api_key)
    if is_valid:
        return
    raise HTTPException(status_code=403, detail="Invalid API Key")
async def get_api_key_logs(api_key: str) -> List[Dict[str, Any]]:
    """
    Retrieve all API request logs for a specific API key.

    Args:
        api_key: The API key to query logs for

    Returns:
        List[Dict[str, Any]]: List of log entries for the API key
    """
    try:
        client = get_supabase_client()
        # Pull every swarms_api_logs row recorded under this key.
        query = (
            client.table("swarms_api_logs")
            .select("*")
            .eq("api_key", api_key)
        )
        return query.execute().data
    except Exception as e:
        logger.error(f"Error retrieving API logs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve API logs: {str(e)}",
        )
def create_swarm(swarm_spec: SwarmSpec) -> Any:
    """Build a SwarmRouter from *swarm_spec*, run its task, and return the output.

    Despite the name, this function also *executes* the swarm: the original
    ``-> SwarmRouter`` annotation was wrong — what is returned is the result
    of ``swarm.run`` (a dict, given ``output_type="dict"``), not the router.

    Args:
        swarm_spec: Validated swarm specification (agents, task, type, etc.).
            Agents may be AgentSpec instances or raw dicts.

    Returns:
        The output of ``SwarmRouter.run`` for ``swarm_spec.task``.

    Raises:
        HTTPException: 500 wrapping any validation or execution failure.
    """
    try:
        # Validate swarm_spec
        if not swarm_spec.agents:
            raise ValueError("Swarm specification must include at least one agent.")

        agents = []
        for agent_spec in swarm_spec.agents:
            try:
                # Handle both dict and AgentSpec objects
                if isinstance(agent_spec, dict):
                    agent_spec = AgentSpec(**agent_spec)

                # Minimal required fields for a runnable agent.
                if not agent_spec.agent_name:
                    raise ValueError("Agent name is required.")
                if not agent_spec.model_name:
                    raise ValueError("Model name is required.")

                # Create the agent
                agent = Agent(
                    agent_name=agent_spec.agent_name,
                    description=agent_spec.description,
                    system_prompt=agent_spec.system_prompt,
                    model_name=agent_spec.model_name,
                    auto_generate_prompt=agent_spec.auto_generate_prompt,
                    max_tokens=agent_spec.max_tokens,
                    temperature=agent_spec.temperature,
                    role=agent_spec.role,
                    max_loops=agent_spec.max_loops,
                )
                agents.append(agent)
                logger.info(
                    "Successfully created agent: {}",
                    agent_spec.agent_name,
                )
            except ValueError as ve:
                logger.error(
                    "Validation error for agent {}: {}",
                    getattr(agent_spec, 'agent_name', 'unknown'),
                    str(ve),
                )
                raise
            except Exception as agent_error:
                logger.error(
                    "Error creating agent {}: {}",
                    getattr(agent_spec, 'agent_name', 'unknown'),
                    str(agent_error),
                )
                raise

        if not agents:
            raise ValueError(
                "No valid agents could be created from the swarm specification."
            )

        # Create and configure the swarm
        swarm = SwarmRouter(
            name=swarm_spec.name,
            description=swarm_spec.description,
            agents=agents,
            max_loops=swarm_spec.max_loops,
            swarm_type=swarm_spec.swarm_type,
            output_type="dict",
            return_entire_history=False,
            rules=swarm_spec.rules,
            rearrange_flow=swarm_spec.rearrange_flow,
        )

        # Run the swarm task and return its output.
        return swarm.run(task=swarm_spec.task)
    except Exception as e:
        logger.error("Error creating swarm: {}", str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create swarm: {str(e)}",
        )
# Add this function after your get_supabase_client() function
async def log_api_request(api_key: str, data: Dict[str, Any]) -> None:
    """
    Persist one API request record to the Supabase swarms_api_logs table.

    Args:
        api_key: The API key used for the request
        data: Dictionary containing request data to log
    """
    try:
        client = get_supabase_client()
        entry = {
            "api_key": api_key,
            "data": data,
        }
        result = client.table("swarms_api_logs").insert(entry).execute()
        # Best-effort logging: report failures, never raise to the caller.
        if not result.data:
            logger.error("Failed to log API request")
    except Exception as e:
        logger.error(f"Error logging API request: {str(e)}")
async def run_swarm_completion(
    swarm: SwarmSpec, x_api_key: str = None
) -> Dict[str, Any]:
    """
    Run a swarm with the specified task, bill the caller, and return the result.

    Args:
        swarm: The swarm specification to execute.
        x_api_key: API key of the requesting user (used for logging/billing).

    Returns:
        Dict[str, Any]: status, swarm output, and execution/billing metadata.

    Raises:
        HTTPException: propagated from billing, or 500 on any other failure.
    """
    # Bind these before the try block: the generic except handler below logs
    # swarm_name, which previously could be unbound (NameError) if the
    # failure happened before the assignment inside the try.
    swarm_name = swarm.name
    agents = swarm.agents
    try:
        await log_api_request(x_api_key, swarm.model_dump())

        # Log start of swarm execution
        logger.info(f"Starting swarm {swarm_name} with {len(agents)} agents")
        start_time = time()

        # Create and run the swarm
        logger.debug(f"Creating swarm object for {swarm_name}")
        result = create_swarm(swarm)
        logger.debug(f"Running swarm task: {swarm.task}")

        # Calculate execution time
        execution_time = time() - start_time
        logger.info(
            f"Swarm {swarm_name} executed in {round(execution_time, 2)} seconds"
        )

        # Calculate costs
        logger.debug(f"Calculating costs for swarm {swarm_name}")
        cost_info = calculate_swarm_cost(
            agents=agents,
            input_text=swarm.task,
            agent_outputs=result,
            execution_time=execution_time,
        )
        logger.info(f"Cost calculation completed for swarm {swarm_name}: {cost_info}")

        # Deduct credits based on calculated cost
        logger.debug(
            f"Deducting credits for swarm {swarm_name} with cost {cost_info['total_cost']}"
        )
        deduct_credits(
            x_api_key,
            cost_info["total_cost"],
            f"swarm_execution_{swarm_name}",
        )

        # Format the response
        response = {
            "status": "success",
            "swarm_name": swarm_name,
            "description": swarm.description,
            "swarm_type": swarm.swarm_type,
            "task": swarm.task,
            "output": result,
            "metadata": {
                "max_loops": swarm.max_loops,
                "num_agents": len(agents),
                "execution_time_seconds": round(execution_time, 2),
                "completion_time": time(),
                "billing_info": cost_info,
            },
        }
        logger.info(response)

        await log_api_request(x_api_key, response)
        return response

    except HTTPException as http_exc:
        # Billing/validation errors carry their own status codes; re-raise.
        logger.error("HTTPException occurred: {}", http_exc.detail)
        raise

    except Exception as e:
        logger.error("Error running swarm {}: {}", swarm_name, str(e))
        logger.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to run swarm: {str(e)}",
        )
def deduct_credits(api_key: str, amount: float, product_name: str) -> None:
    """
    Deducts the specified amount of credits for the user identified by api_key,
    preferring to use free_credit before using regular credit, and logs the transaction.

    Args:
        api_key: API key identifying the paying user.
        amount: Credit amount to deduct (converted to Decimal internally).
        product_name: Label recorded in the service transaction log.

    Raises:
        HTTPException: 404 if no credit record exists, 402 if the combined
            balance is insufficient, 500 if logging or updating the record fails.
    """
    supabase_client = get_supabase_client()
    user_id = get_user_id_from_api_key(api_key)

    # 1. Retrieve the user's credit record
    response = (
        supabase_client.table("swarms_cloud_users_credits")
        .select("*")
        .eq("user_id", user_id)
        .execute()
    )
    if not response.data:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User credits record not found.",
        )

    record = response.data[0]
    # Use Decimal for precise arithmetic
    # NOTE(review): assumes the DB stores credit fields as Decimal-parseable
    # strings or numbers — confirm the column types.
    available_credit = Decimal(record["credit"])
    free_credit = Decimal(record.get("free_credit", "0"))
    deduction = Decimal(str(amount))

    # NOTE(review): debug print left on the billing path — consider logger.debug.
    print(
        f"Available credit: {available_credit}, Free credit: {free_credit}, Deduction: {deduction}"
    )

    # 2. Verify sufficient total credits are available
    if (available_credit + free_credit) < deduction:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail="Insufficient credits.",
        )

    # 3. Log the transaction
    # NOTE(review): the charge is logged *before* the balance update below;
    # if the update fails, a charge record exists without a deduction.
    # NOTE(review): int(deduction) truncates fractional credits in the log.
    log_response = (
        supabase_client.table("swarms_cloud_services")
        .insert(
            {
                "user_id": user_id,
                "api_key": api_key,
                "charge_credit": int(
                    deduction
                ),  # Assuming credits are stored as integers
                "product_name": product_name,
            }
        )
        .execute()
    )
    if not log_response.data:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to log the credit transaction.",
        )

    # 4. Deduct credits: use free_credit first, then deduct the remainder from available_credit
    if free_credit >= deduction:
        free_credit -= deduction
    else:
        remainder = deduction - free_credit
        free_credit = Decimal("0")
        available_credit -= remainder

    # Persist the new balances (stored back as strings).
    update_response = (
        supabase_client.table("swarms_cloud_users_credits")
        .update(
            {
                "credit": str(available_credit),
                "free_credit": str(free_credit),
            }
        )
        .eq("user_id", user_id)
        .execute()
    )
    if not update_response.data:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update credits.",
        )
def calculate_swarm_cost(
    agents: List[Agent],
    input_text: str,
    execution_time: float,
    agent_outputs: Union[List[Dict[str, str]], str] = None,
) -> Dict[str, Any]:
    """
    Calculate the cost of running a swarm based on agents, tokens, and execution time.
    Includes system prompts, agent memory, and scaled output costs.

    Args:
        agents: List of agents used in the swarm
        input_text: The input task/prompt text
        execution_time: Time taken to execute in seconds
        agent_outputs: Output messages (list of {"content": ...} dicts) or a
            single output string; when absent/falsy, output tokens are estimated.

    Returns:
        Dict containing cost breakdown and total cost

    Raises:
        ValueError: If token counting or cost arithmetic fails.
    """
    # Base costs per unit (these could be moved to environment variables)
    COST_PER_AGENT = 0.01  # Base cost per agent
    COST_PER_1M_INPUT_TOKENS = 2.00  # Cost per 1M input tokens
    COST_PER_1M_OUTPUT_TOKENS = 6.00  # Cost per 1M output tokens

    # The night-time discount window is keyed to California local time.
    california_tz = pytz.timezone("America/Los_Angeles")
    current_time = datetime.now(california_tz)
    is_night_time = (
        current_time.hour >= 20 or current_time.hour < 6
    )  # 8 PM to 6 AM

    try:
        # Input tokens for the shared task prompt.
        task_tokens = count_tokens(input_text)

        # The measured output token count does not depend on the agent, so
        # compute it once here instead of once per loop iteration.
        if agent_outputs and isinstance(agent_outputs, list):
            measured_output_tokens = sum(
                count_tokens(message["content"])
                for message in agent_outputs
            )
        elif agent_outputs and isinstance(agent_outputs, str):
            measured_output_tokens = count_tokens(agent_outputs)
        else:
            # Unrecognized or missing outputs: estimate per-agent below.
            measured_output_tokens = None

        total_input_tokens = 0
        total_output_tokens = 0
        per_agent_tokens = {}

        for agent in agents:
            agent_input_tokens = task_tokens  # Base task tokens

            # Add system prompt tokens if present
            if agent.system_prompt:
                agent_input_tokens += count_tokens(agent.system_prompt)

            # Add memory tokens if available; memory access is best-effort.
            try:
                memory = agent.short_memory.return_history_as_string()
                if memory:
                    agent_input_tokens += count_tokens(str(memory))
            except Exception as e:
                logger.warning(
                    f"Could not get memory for agent {agent.agent_name}: {str(e)}"
                )

            # NOTE(review): every agent is attributed the *full* combined
            # output token count, and the costs below multiply by
            # len(agents) again — confirm this double scaling is intended.
            if measured_output_tokens is not None:
                agent_output_tokens = measured_output_tokens
            else:
                agent_output_tokens = int(
                    agent_input_tokens * 2.5
                )  # Estimated output tokens

            # Store per-agent token counts
            per_agent_tokens[agent.agent_name] = {
                "input_tokens": agent_input_tokens,
                "output_tokens": agent_output_tokens,
                "total_tokens": agent_input_tokens + agent_output_tokens,
            }

            # Add to totals
            total_input_tokens += agent_input_tokens
            total_output_tokens += agent_output_tokens

        # Calculate costs (convert to millions of tokens)
        agent_cost = len(agents) * COST_PER_AGENT
        input_token_cost = (
            (total_input_tokens / 1_000_000)
            * COST_PER_1M_INPUT_TOKENS
            * len(agents)
        )
        output_token_cost = (
            (total_output_tokens / 1_000_000)
            * COST_PER_1M_OUTPUT_TOKENS
            * len(agents)
        )

        # Apply discount during California night time hours
        if is_night_time:
            input_token_cost *= 0.25  # 75% discount
            output_token_cost *= 0.25  # 75% discount

        # Calculate total cost
        total_cost = agent_cost + input_token_cost + output_token_cost

        output = {
            "cost_breakdown": {
                "agent_cost": round(agent_cost, 6),
                "input_token_cost": round(input_token_cost, 6),
                "output_token_cost": round(output_token_cost, 6),
                "token_counts": {
                    "total_input_tokens": total_input_tokens,
                    "total_output_tokens": total_output_tokens,
                    "total_tokens": total_input_tokens + total_output_tokens,
                    "per_agent": per_agent_tokens,
                },
                "num_agents": len(agents),
                "execution_time_seconds": round(execution_time, 2),
            },
            "total_cost": round(total_cost, 6),
        }

        return output

    except Exception as e:
        logger.error(f"Error calculating swarm cost: {str(e)}")
        # Chain the original exception so the root cause is preserved.
        raise ValueError(
            f"Failed to calculate swarm cost: {str(e)}"
        ) from e
# --- FastAPI Application Setup ---
# Module-level singleton used by all route decorators below.
app = FastAPI(
    title="Swarm Agent API",
    description="API for managing and executing Python agents in the cloud without Docker/Kubernetes.",
    version="1.0.0",
    debug=True,  # NOTE(review): debug mode exposes tracebacks — disable in production.
)

# Enable CORS (adjust origins as needed)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict this to specific domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/", dependencies=[Depends(rate_limiter)])
def root():
return {
"status": "Welcome to the Swarm API. Check out the docs at https://docs.swarms.world"
}
@app.get("/health", dependencies=[Depends(rate_limiter)])
def health():
return {"status": "ok"}
@app.post(
    "/v1/swarm/completions",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def run_swarm(swarm: SwarmSpec, x_api_key=Header(...)) -> Dict[str, Any]:
    """
    Run a swarm with the specified task.

    Thin endpoint wrapper: delegates all work to run_swarm_completion.
    """
    completion = await run_swarm_completion(swarm, x_api_key)
    return completion
@app.post(
    "/v1/swarm/batch/completions",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def run_batch_completions(
    swarms: List[SwarmSpec], x_api_key=Header(...)
) -> List[Dict[str, Any]]:
    """
    Run a batch of swarms with the specified tasks.

    Each swarm is executed sequentially; a failure produces an error entry
    in the results list instead of aborting the whole batch.

    Args:
        swarms: Swarm specifications to execute, in order.
        x_api_key: Caller's API key, forwarded to each run for billing/logging.

    Returns:
        One result dict per input swarm; failed runs yield
        {"status": "error", ...} entries.
    """
    results = []
    for swarm in swarms:
        try:
            # Call the existing run_swarm function for each swarm
            result = await run_swarm_completion(swarm, x_api_key)
            results.append(result)
        except HTTPException as http_exc:
            # f-strings instead of brace-style args: stdlib logging does not
            # substitute "{}" placeholders, and the rest of the file uses
            # f-strings throughout.
            logger.error(f"HTTPException occurred: {http_exc.detail}")
            results.append(
                {
                    "status": "error",
                    "swarm_name": swarm.name,
                    "detail": http_exc.detail,
                }
            )
        except Exception as e:
            logger.error(f"Error running swarm {swarm.name}: {str(e)}")
            logger.exception(e)
            results.append(
                {
                    "status": "error",
                    "swarm_name": swarm.name,
                    "detail": f"Failed to run swarm: {str(e)}",
                }
            )

    return results
@app.get(
    "/v1/swarm/logs",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def get_logs(x_api_key: str = Header(...)) -> Dict[str, Any]:
    """
    Get all API request logs for the provided API key.
    """
    try:
        entries = await get_api_key_logs(x_api_key)
        payload = {
            "status": "success",
            "count": len(entries),
            "logs": entries,
        }
        return payload
    except Exception as e:
        logger.error(f"Error in get_logs endpoint: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )
# @app.post("/v1/swarm/cost-prediction")
# async def cost_prediction(swarm: SwarmSpec) -> Dict[str, Any]:
# """
# Predict the cost of running a swarm.
# """
# return {"status": "success", "cost": calculate_swarm_cost(swarm)})
@app.post(
    "/v1/swarm/schedule",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def schedule_swarm(
    swarm: SwarmSpec, x_api_key: str = Header(...)
) -> Dict[str, Any]:
    """
    Schedule a swarm to run at a specific time.

    Validates that schedule info is present, starts a background
    ScheduledJob, registers it in the module-level `scheduled_jobs`
    registry, and logs the scheduling request.

    Raises:
        HTTPException: 400 when no schedule is supplied; 500 on any
            failure while creating or starting the job.
    """
    if not swarm.schedule:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Schedule information is required",
        )

    try:
        # Generate a unique job ID
        # NOTE(review): second-resolution timestamp — two schedules of the
        # same swarm name within one second collide; confirm acceptable.
        job_id = f"swarm_{swarm.name}_{int(time())}"

        # Create and start the scheduled job
        job = ScheduledJob(
            job_id=job_id,
            scheduled_time=swarm.schedule.scheduled_time,
            swarm=swarm,
            api_key=x_api_key,
        )
        job.start()

        # Store the job information
        # NOTE(review): job is started *before* it is registered; a failure
        # between start() and this insert leaves an untracked running job.
        scheduled_jobs[job_id] = {
            "job": job,
            "swarm_name": swarm.name,
            "scheduled_time": swarm.schedule.scheduled_time,
            "timezone": swarm.schedule.timezone,
        }

        # Log the scheduling
        await log_api_request(
            x_api_key,
            {
                "action": "schedule_swarm",
                "swarm_name": swarm.name,
                "scheduled_time": swarm.schedule.scheduled_time.isoformat(),
                "job_id": job_id,
            },
        )

        return {
            "status": "success",
            "message": "Swarm scheduled successfully",
            "job_id": job_id,
            "scheduled_time": swarm.schedule.scheduled_time,
            "timezone": swarm.schedule.timezone,
        }

    except Exception as e:
        logger.error(f"Error scheduling swarm: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to schedule swarm: {str(e)}",
        )
@app.get(
    "/v1/swarm/schedule",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def get_scheduled_jobs(x_api_key: str = Header(...)) -> Dict[str, Any]:
    """
    Get all scheduled swarm jobs.

    Prunes entries whose scheduled time has passed, then lists the rest.
    """
    try:
        now = datetime.now(pytz.UTC)

        # Prune jobs whose scheduled time has already arrived.
        expired = [
            jid
            for jid, info in scheduled_jobs.items()
            if now >= info["scheduled_time"]
        ]
        for jid in expired:
            scheduled_jobs.pop(jid, None)

        # Summarize the remaining (active) jobs.
        active_jobs = [
            {
                "job_id": jid,
                "swarm_name": info["swarm_name"],
                "scheduled_time": info["scheduled_time"].isoformat(),
                "timezone": info["timezone"],
            }
            for jid, info in scheduled_jobs.items()
        ]

        return {"status": "success", "scheduled_jobs": active_jobs}
    except Exception as e:
        logger.error(f"Error retrieving scheduled jobs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve scheduled jobs: {str(e)}",
        )
@app.delete(
    "/v1/swarm/schedule/{job_id}",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def cancel_scheduled_job(
    job_id: str, x_api_key: str = Header(...)
) -> Dict[str, Any]:
    """
    Cancel a scheduled swarm job.

    Marks the job cancelled, removes it from the registry, and logs the
    cancellation request.
    """
    try:
        if job_id not in scheduled_jobs:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Scheduled job not found"
            )

        # Remove the registry entry and flag the job so its runner stops.
        entry = scheduled_jobs.pop(job_id)
        entry["job"].cancelled = True

        await log_api_request(
            x_api_key, {"action": "cancel_scheduled_job", "job_id": job_id}
        )

        return {
            "status": "success",
            "message": "Scheduled job cancelled successfully",
            "job_id": job_id,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling scheduled job: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to cancel scheduled job: {str(e)}",
        )
# --- Main Entrypoint ---
if __name__ == "__main__":
    import uvicorn
    # NOTE(review): uvicorn ignores `workers` when given an app *object*;
    # multi-worker serving requires an import string such as "module:app".
    # Confirm whether multiple workers are actually intended here.
    # Assumes `os` is imported at the top of the file — verify.
    uvicorn.run(app, host="0.0.0.0", port=8080, workers=os.cpu_count())