import asyncio
import os
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from functools import lru_cache
from threading import Thread
from time import sleep, time
from typing import Any, Dict, List, Optional, Union

import pytz
import supabase
from dotenv import load_dotenv
from fastapi import (
    Depends,
    FastAPI,
    Header,
    HTTPException,
    Request,
    status,
)
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from swarms import Agent, SwarmRouter, SwarmType
from swarms.utils.litellm_tokenizer import count_tokens

load_dotenv()


# Rate limit parameters
RATE_LIMIT = 100  # Max requests per client
TIME_WINDOW = 60  # Time window in seconds

# In-memory store for tracking requests (per process; resets on restart)
request_counts = defaultdict(lambda: {"count": 0, "start_time": time()})

# In-memory store for scheduled jobs
scheduled_jobs: Dict[str, Dict] = {}


def rate_limiter(request: Request):
    client_ip = request.client.host
    current_time = time()
    client_data = request_counts[client_ip]

    # Reset the count once the time window has passed
    if current_time - client_data["start_time"] > TIME_WINDOW:
        client_data["count"] = 0
        client_data["start_time"] = current_time

    # Increment the request count
    client_data["count"] += 1

    # Reject the request if the rate limit is exceeded
    if client_data["count"] > RATE_LIMIT:
        raise HTTPException(
            status_code=429, detail="Rate limit exceeded. Please try again later."
        )
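
# A minimal sketch of the limiter's behavior, assuming the API is served
# locally on port 8080 and the `requests` package is installed (both are
# assumptions, not part of this module):
#
#     import requests
#     for _ in range(RATE_LIMIT + 1):
#         resp = requests.get("http://localhost:8080/health")
#     # Once the window is exhausted, the last call should return
#     # resp.status_code == 429.
#
# Because `request_counts` lives in process memory, each worker process
# tracks its own counts and they reset on restart.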


class AgentSpec(BaseModel):
    agent_name: Optional[str] = Field(None, description="Agent Name", max_length=100)
    description: Optional[str] = Field(None, description="Description", max_length=500)
    system_prompt: Optional[str] = Field(
        None, description="System Prompt", max_length=500
    )
    model_name: Optional[str] = Field(
        "gpt-4o", description="Model Name", max_length=500
    )
    auto_generate_prompt: Optional[bool] = Field(
        False, description="Auto Generate Prompt"
    )
    max_tokens: Optional[int] = Field(None, description="Max Tokens")
    temperature: Optional[float] = Field(0.5, description="Temperature")
    role: Optional[str] = Field("worker", description="Role")
    max_loops: Optional[int] = Field(1, description="Max Loops")


# class ExternalAgent(BaseModel):
#     base_url: str = Field(..., description="Base URL")
#     parameters: Dict[str, Any] = Field(..., description="Parameters")
#     headers: Dict[str, Any] = Field(..., description="Headers")


class ScheduleSpec(BaseModel):
    scheduled_time: datetime = Field(..., description="When to execute the swarm (UTC)")
    timezone: Optional[str] = Field(
        "UTC", description="Timezone for the scheduled time"
    )


class SwarmSpec(BaseModel):
    name: Optional[str] = Field(None, description="Swarm Name", max_length=100)
    description: Optional[str] = Field(None, description="Description")
    agents: Optional[Union[List[AgentSpec], Any]] = Field(None, description="Agents")
    max_loops: Optional[int] = Field(None, description="Max Loops")
    swarm_type: Optional[SwarmType] = Field(None, description="Swarm Type")
    rearrange_flow: Optional[str] = Field(None, description="Flow")
    task: Optional[str] = Field(None, description="Task")
    img: Optional[str] = Field(None, description="Img")
    return_history: Optional[bool] = Field(True, description="Return History")
    rules: Optional[str] = Field(None, description="Rules")
    schedule: Optional[ScheduleSpec] = Field(None, description="Scheduling information")
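
# A hedged example of a request body that parses into SwarmSpec (all values
# are illustrative placeholders):
#
#     {
#         "name": "research-swarm",
#         "task": "Summarize the attached report.",
#         "swarm_type": "SequentialWorkflow",
#         "max_loops": 1,
#         "agents": [
#             {"agent_name": "summarizer", "model_name": "gpt-4o"}
#         ]
#     }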


class ScheduledJob(Thread):
    def __init__(
        self, job_id: str, scheduled_time: datetime, swarm: SwarmSpec, api_key: str
    ):
        super().__init__()
        self.job_id = job_id
        # Normalize to an aware UTC datetime so the comparison in run() cannot
        # fail on a naive input
        if scheduled_time.tzinfo is None:
            scheduled_time = pytz.UTC.localize(scheduled_time)
        self.scheduled_time = scheduled_time
        self.swarm = swarm
        self.api_key = api_key
        self.daemon = True  # Allow the thread to be terminated when the main program exits
        self.cancelled = False

    def run(self):
        while not self.cancelled:
            now = datetime.now(pytz.UTC)
            if now >= self.scheduled_time:
                try:
                    # Execute the swarm
                    asyncio.run(run_swarm_completion(self.swarm, self.api_key))
                except Exception as e:
                    logger.error(
                        f"Error executing scheduled swarm {self.job_id}: {str(e)}"
                    )
                finally:
                    # Remove the job from scheduled_jobs after execution
                    scheduled_jobs.pop(self.job_id, None)
                break
            sleep(1)  # Check every second
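
# A minimal sketch of direct use (normally jobs are created via the
# /v1/swarm/schedule endpoint below; these values are placeholders):
#
#     job = ScheduledJob(
#         job_id="demo-job",
#         scheduled_time=datetime(2030, 1, 1, tzinfo=pytz.UTC),
#         swarm=SwarmSpec(name="demo", task="Say hello."),
#         api_key="<api-key>",
#     )
#     job.start()
#     job.cancelled = True  # cooperative cancel; the loop exits within ~1 second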


def get_supabase_client():
    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_KEY")
    return supabase.create_client(supabase_url, supabase_key)


# Note: lru_cache memoizes both hits and misses, so a key created or revoked
# after its first lookup keeps returning the cached result until eviction.
@lru_cache(maxsize=1000)
def check_api_key(api_key: str) -> bool:
    supabase_client = get_supabase_client()
    response = (
        supabase_client.table("swarms_cloud_api_keys")
        .select("*")
        .eq("key", api_key)
        .execute()
    )
    return bool(response.data)
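
# If key-state changes need to be picked up immediately, one hedged option is
# to clear the memoized lookups (cache_clear() is provided by
# functools.lru_cache):
#
#     check_api_key.cache_clear()
#     get_user_id_from_api_key.cache_clear()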


# class ExternalAgent:
#     def __init__(self, base_url: str, parameters: Dict[str, Any], headers: Dict[str, Any]):
#         self.base_url = base_url
#         self.parameters = parameters
#         self.headers = headers
#
#     def run(self, task: str) -> Dict[str, Any]:
#         response = requests.post(self.base_url, json=self.parameters, headers=self.headers)
#         return response.json()


@lru_cache(maxsize=1000)
def get_user_id_from_api_key(api_key: str) -> str:
    """
    Maps an API key to its associated user ID.

    Args:
        api_key (str): The API key to look up

    Returns:
        str: The user ID associated with the API key

    Raises:
        ValueError: If the API key is invalid or not found
    """
    supabase_client = get_supabase_client()
    response = (
        supabase_client.table("swarms_cloud_api_keys")
        .select("user_id")
        .eq("key", api_key)
        .execute()
    )
    if not response.data:
        raise ValueError("Invalid API key")
    return response.data[0]["user_id"]


def verify_api_key(x_api_key: str = Header(...)) -> None:
    """
    Dependency to verify the API key.
    """
    if not check_api_key(x_api_key):
        raise HTTPException(status_code=403, detail="Invalid API Key")


async def get_api_key_logs(api_key: str) -> List[Dict[str, Any]]:
    """
    Retrieve all API request logs for a specific API key.

    Args:
        api_key: The API key to query logs for

    Returns:
        List[Dict[str, Any]]: List of log entries for the API key
    """
    try:
        supabase_client = get_supabase_client()

        # Query the swarms_api_logs table for entries matching the API key
        response = (
            supabase_client.table("swarms_api_logs")
            .select("*")
            .eq("api_key", api_key)
            .execute()
        )
        return response.data

    except Exception as e:
        logger.error(f"Error retrieving API logs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve API logs: {str(e)}",
        )


def create_swarm(swarm_spec: SwarmSpec) -> Any:
    """
    Build the agents and SwarmRouter from a SwarmSpec, run the task, and
    return the swarm's output.
    """
    try:
        # Validate swarm_spec
        if not swarm_spec.agents:
            raise ValueError("Swarm specification must include at least one agent.")

        agents = []
        for agent_spec in swarm_spec.agents:
            try:
                # Handle both dict and AgentSpec objects
                if isinstance(agent_spec, dict):
                    # Convert dict to AgentSpec
                    agent_spec = AgentSpec(**agent_spec)

                # Validate agent_spec fields
                if not agent_spec.agent_name:
                    raise ValueError("Agent name is required.")
                if not agent_spec.model_name:
                    raise ValueError("Model name is required.")

                # Create the agent
                agent = Agent(
                    agent_name=agent_spec.agent_name,
                    description=agent_spec.description,
                    system_prompt=agent_spec.system_prompt,
                    model_name=agent_spec.model_name,
                    auto_generate_prompt=agent_spec.auto_generate_prompt,
                    max_tokens=agent_spec.max_tokens,
                    temperature=agent_spec.temperature,
                    role=agent_spec.role,
                    max_loops=agent_spec.max_loops,
                )
                agents.append(agent)
                logger.info(
                    "Successfully created agent: {}",
                    agent_spec.agent_name,
                )
            except ValueError as ve:
                logger.error(
                    "Validation error for agent {}: {}",
                    getattr(agent_spec, "agent_name", "unknown"),
                    str(ve),
                )
                raise
            except Exception as agent_error:
                logger.error(
                    "Error creating agent {}: {}",
                    getattr(agent_spec, "agent_name", "unknown"),
                    str(agent_error),
                )
                raise

        if not agents:
            raise ValueError(
                "No valid agents could be created from the swarm specification."
            )

        # Create and configure the swarm
        swarm = SwarmRouter(
            name=swarm_spec.name,
            description=swarm_spec.description,
            agents=agents,
            max_loops=swarm_spec.max_loops,
            swarm_type=swarm_spec.swarm_type,
            output_type="dict",
            return_entire_history=False,
            rules=swarm_spec.rules,
            rearrange_flow=swarm_spec.rearrange_flow,
        )

        # Run the swarm task and return its output (despite the name, this
        # function executes the task rather than returning the router itself)
        output = swarm.run(task=swarm_spec.task)
        return output
    except Exception as e:
        logger.error("Error creating swarm: {}", str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create swarm: {str(e)}",
        )
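
# A hedged sketch of direct use (values are illustrative; "SequentialWorkflow"
# is one of the SwarmType values accepted by SwarmRouter):
#
#     spec = SwarmSpec(
#         name="demo",
#         agents=[AgentSpec(agent_name="writer", model_name="gpt-4o")],
#         swarm_type="SequentialWorkflow",
#         task="Draft a one-paragraph summary.",
#     )
#     output = create_swarm(spec)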


async def log_api_request(api_key: str, data: Dict[str, Any]) -> None:
    """
    Log API request data to the Supabase swarms_api_logs table.

    Args:
        api_key: The API key used for the request
        data: Dictionary containing request data to log
    """
    try:
        supabase_client = get_supabase_client()

        # Create log entry
        log_entry = {
            "api_key": api_key,
            "data": data,
        }

        # Insert into swarms_api_logs table
        response = supabase_client.table("swarms_api_logs").insert(log_entry).execute()

        if not response.data:
            logger.error("Failed to log API request")

    except Exception as e:
        logger.error(f"Error logging API request: {str(e)}")


async def run_swarm_completion(
    swarm: SwarmSpec, x_api_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Run a swarm with the specified task.
    """
    try:
        swarm_name = swarm.name
        agents = swarm.agents

        await log_api_request(x_api_key, swarm.model_dump())

        # Log start of swarm execution
        logger.info(f"Starting swarm {swarm_name} with {len(agents)} agents")
        start_time = time()

        # Create and run the swarm
        logger.debug(f"Creating swarm object for {swarm_name}")
        logger.debug(f"Running swarm task: {swarm.task}")
        result = create_swarm(swarm)

        # Calculate execution time
        execution_time = time() - start_time
        logger.info(
            f"Swarm {swarm_name} executed in {round(execution_time, 2)} seconds"
        )

        # Calculate costs
        logger.debug(f"Calculating costs for swarm {swarm_name}")
        cost_info = calculate_swarm_cost(
            agents=agents,
            input_text=swarm.task,
            agent_outputs=result,
            execution_time=execution_time,
        )
        logger.info(f"Cost calculation completed for swarm {swarm_name}: {cost_info}")

        # Deduct credits based on calculated cost
        logger.debug(
            f"Deducting credits for swarm {swarm_name} with cost {cost_info['total_cost']}"
        )
        deduct_credits(
            x_api_key,
            cost_info["total_cost"],
            f"swarm_execution_{swarm_name}",
        )

        # Format the response
        response = {
            "status": "success",
            "swarm_name": swarm_name,
            "description": swarm.description,
            "swarm_type": swarm.swarm_type,
            "task": swarm.task,
            "output": result,
            "metadata": {
                "max_loops": swarm.max_loops,
                "num_agents": len(agents),
                "execution_time_seconds": round(execution_time, 2),
                "completion_time": time(),
                "billing_info": cost_info,
            },
        }
        logger.info(response)
        await log_api_request(x_api_key, response)

        return response

    except HTTPException as http_exc:
        logger.error("HTTPException occurred: {}", http_exc.detail)
        raise
    except Exception as e:
        # Use swarm.name directly: the local swarm_name may be unbound if the
        # failure happened before it was assigned
        logger.error("Error running swarm {}: {}", swarm.name, str(e))
        logger.exception(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to run swarm: {str(e)}",
        )


def deduct_credits(api_key: str, amount: float, product_name: str) -> None:
    """
    Deducts the specified amount of credits for the user identified by api_key,
    preferring free_credit over regular credit, and logs the transaction.
    """
    supabase_client = get_supabase_client()
    user_id = get_user_id_from_api_key(api_key)

    # 1. Retrieve the user's credit record
    response = (
        supabase_client.table("swarms_cloud_users_credits")
        .select("*")
        .eq("user_id", user_id)
        .execute()
    )
    if not response.data:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User credits record not found.",
        )

    record = response.data[0]
    # Use Decimal (via str) for precise arithmetic, avoiding float artifacts
    available_credit = Decimal(str(record["credit"]))
    free_credit = Decimal(str(record.get("free_credit", "0")))
    deduction = Decimal(str(amount))

    logger.debug(
        f"Available credit: {available_credit}, Free credit: {free_credit}, Deduction: {deduction}"
    )

    # 2. Verify sufficient total credits are available
    if (available_credit + free_credit) < deduction:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail="Insufficient credits.",
        )

    # 3. Log the transaction. Note this happens before the balance update, so
    # a failed update below can leave a logged charge without a deduction.
    log_response = (
        supabase_client.table("swarms_cloud_services")
        .insert(
            {
                "user_id": user_id,
                "api_key": api_key,
                "charge_credit": int(
                    deduction
                ),  # Assuming credits are stored as integers
                "product_name": product_name,
            }
        )
        .execute()
    )
    if not log_response.data:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to log the credit transaction.",
        )

    # 4. Deduct credits: use free_credit first, then take the remainder from
    # available_credit
    if free_credit >= deduction:
        free_credit -= deduction
    else:
        remainder = deduction - free_credit
        free_credit = Decimal("0")
        available_credit -= remainder

    update_response = (
        supabase_client.table("swarms_cloud_users_credits")
        .update(
            {
                "credit": str(available_credit),
                "free_credit": str(free_credit),
            }
        )
        .eq("user_id", user_id)
        .execute()
    )
    if not update_response.data:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update credits.",
        )
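
# Worked example for step 4: with free_credit = 2.00 and deduction = 5.00,
# free credit is consumed first, the remainder (3.00) comes out of
# available_credit, and the stored balances become free_credit = 0 and
# available_credit reduced by 3.00.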


def calculate_swarm_cost(
    agents: List[Agent],
    input_text: str,
    execution_time: float,
    agent_outputs: Optional[Union[List[Dict[str, str]], str]] = None,
) -> Dict[str, Any]:
    """
    Calculate the cost of running a swarm based on agents, tokens, and execution time.
    Includes system prompts, agent memory, and scaled output costs.

    Args:
        agents: List of agents used in the swarm
        input_text: The input task/prompt text
        execution_time: Time taken to execute in seconds
        agent_outputs: Either a list of message dictionaries (each with a
            "content" key) or a single output string

    Returns:
        Dict containing cost breakdown and total cost
    """
    # Base costs per unit (these could be moved to environment variables)
    COST_PER_AGENT = 0.01  # Base cost per agent
    COST_PER_1M_INPUT_TOKENS = 2.00  # Cost per 1M input tokens
    COST_PER_1M_OUTPUT_TOKENS = 6.00  # Cost per 1M output tokens

    # Get the current time in the California timezone
    california_tz = pytz.timezone("America/Los_Angeles")
    current_time = datetime.now(california_tz)
    is_night_time = current_time.hour >= 20 or current_time.hour < 6  # 8 PM to 6 AM

    try:
        # Calculate input tokens for the task
        task_tokens = count_tokens(input_text)

        # Calculate total input tokens including system prompts and memory for each agent
        total_input_tokens = 0
        total_output_tokens = 0
        per_agent_tokens = {}

        for agent in agents:
            agent_input_tokens = task_tokens  # Base task tokens

            # Add system prompt tokens if present
            if agent.system_prompt:
                agent_input_tokens += count_tokens(agent.system_prompt)

            # Add memory tokens if available
            try:
                memory = agent.short_memory.return_history_as_string()
                if memory:
                    memory_tokens = count_tokens(str(memory))
                    agent_input_tokens += memory_tokens
            except Exception as e:
                logger.warning(
                    f"Could not get memory for agent {agent.agent_name}: {str(e)}"
                )

            # Count actual output tokens when outputs are available; otherwise
            # estimate them at 2.5x the input tokens
            if agent_outputs:
                if isinstance(agent_outputs, list):
                    # Sum tokens over each message's content
                    agent_output_tokens = sum(
                        count_tokens(message["content"]) for message in agent_outputs
                    )
                elif isinstance(agent_outputs, str):
                    agent_output_tokens = count_tokens(agent_outputs)
                else:
                    agent_output_tokens = int(agent_input_tokens * 2.5)
            else:
                agent_output_tokens = int(agent_input_tokens * 2.5)

            # Store per-agent token counts
            per_agent_tokens[agent.agent_name] = {
                "input_tokens": agent_input_tokens,
                "output_tokens": agent_output_tokens,
                "total_tokens": agent_input_tokens + agent_output_tokens,
            }

            # Add to totals
            total_input_tokens += agent_input_tokens
            total_output_tokens += agent_output_tokens

        # Calculate costs (token totals are priced per million and scaled by
        # the number of agents)
        agent_cost = len(agents) * COST_PER_AGENT
        input_token_cost = (
            (total_input_tokens / 1_000_000) * COST_PER_1M_INPUT_TOKENS * len(agents)
        )
        output_token_cost = (
            (total_output_tokens / 1_000_000) * COST_PER_1M_OUTPUT_TOKENS * len(agents)
        )

        # Apply a 75% discount during California night-time hours
        if is_night_time:
            input_token_cost *= 0.25
            output_token_cost *= 0.25

        # Calculate total cost
        total_cost = agent_cost + input_token_cost + output_token_cost

        output = {
            "cost_breakdown": {
                "agent_cost": round(agent_cost, 6),
                "input_token_cost": round(input_token_cost, 6),
                "output_token_cost": round(output_token_cost, 6),
                "token_counts": {
                    "total_input_tokens": total_input_tokens,
                    "total_output_tokens": total_output_tokens,
                    "total_tokens": total_input_tokens + total_output_tokens,
                    "per_agent": per_agent_tokens,
                },
                "num_agents": len(agents),
                "execution_time_seconds": round(execution_time, 2),
            },
            "total_cost": round(total_cost, 6),
        }

        return output

    except Exception as e:
        logger.error(f"Error calculating swarm cost: {str(e)}")
        raise ValueError(f"Failed to calculate swarm cost: {str(e)}")
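
# Worked example at daytime rates with 2 agents, 10,000 total input tokens,
# and 25,000 total output tokens:
#   agent_cost        = 2 * 0.01                        = 0.02
#   input_token_cost  = (10_000 / 1_000_000) * 2.00 * 2 = 0.04
#   output_token_cost = (25_000 / 1_000_000) * 6.00 * 2 = 0.30
#   total_cost        = 0.36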


# --- FastAPI Application Setup ---

app = FastAPI(
    title="Swarm Agent API",
    description="API for managing and executing Python agents in the cloud without Docker/Kubernetes.",
    version="1.0.0",
    debug=True,
)

# Enable CORS (adjust origins as needed)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict this to specific domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
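
# A hedged production sketch of the same middleware with restricted origins
# (the domain is a placeholder):
#
#     app.add_middleware(
#         CORSMiddleware,
#         allow_origins=["https://app.example.com"],
#         allow_credentials=True,
#         allow_methods=["GET", "POST", "DELETE"],
#         allow_headers=["x-api-key", "content-type"],
#     )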


@app.get("/", dependencies=[Depends(rate_limiter)])
def root():
    return {
        "status": "Welcome to the Swarm API. Check out the docs at https://docs.swarms.world"
    }


@app.get("/health", dependencies=[Depends(rate_limiter)])
def health():
    return {"status": "ok"}


@app.post(
    "/v1/swarm/completions",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def run_swarm(swarm: SwarmSpec, x_api_key: str = Header(...)) -> Dict[str, Any]:
    """
    Run a swarm with the specified task.
    """
    return await run_swarm_completion(swarm, x_api_key)
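
# A hedged request sketch (FastAPI maps the x_api_key parameter to an
# "x-api-key" header; the key and body values are placeholders):
#
#     curl -X POST http://localhost:8080/v1/swarm/completions \
#       -H "x-api-key: <api-key>" \
#       -H "Content-Type: application/json" \
#       -d '{"name": "demo", "task": "Say hello.",
#            "agents": [{"agent_name": "greeter", "model_name": "gpt-4o"}]}'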


@app.post(
    "/v1/swarm/batch/completions",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def run_batch_completions(
    swarms: List[SwarmSpec], x_api_key: str = Header(...)
) -> List[Dict[str, Any]]:
    """
    Run a batch of swarms with the specified tasks.
    """
    results = []
    for swarm in swarms:
        try:
            # Run each swarm through the shared completion path
            result = await run_swarm_completion(swarm, x_api_key)
            results.append(result)
        except HTTPException as http_exc:
            logger.error("HTTPException occurred: {}", http_exc.detail)
            results.append(
                {
                    "status": "error",
                    "swarm_name": swarm.name,
                    "detail": http_exc.detail,
                }
            )
        except Exception as e:
            logger.error("Error running swarm {}: {}", swarm.name, str(e))
            logger.exception(e)
            results.append(
                {
                    "status": "error",
                    "swarm_name": swarm.name,
                    "detail": f"Failed to run swarm: {str(e)}",
                }
            )

    return results


@app.get(
    "/v1/swarm/logs",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def get_logs(x_api_key: str = Header(...)) -> Dict[str, Any]:
    """
    Get all API request logs for the provided API key.
    """
    try:
        logs = await get_api_key_logs(x_api_key)
        return {"status": "success", "count": len(logs), "logs": logs}
    except Exception as e:
        logger.error(f"Error in get_logs endpoint: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )


# @app.post("/v1/swarm/cost-prediction")
# async def cost_prediction(swarm: SwarmSpec) -> Dict[str, Any]:
#     """
#     Predict the cost of running a swarm.
#     """
#     return {"status": "success", "cost": calculate_swarm_cost(swarm)}


@app.post(
    "/v1/swarm/schedule",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def schedule_swarm(
    swarm: SwarmSpec, x_api_key: str = Header(...)
) -> Dict[str, Any]:
    """
    Schedule a swarm to run at a specific time.
    """
    if not swarm.schedule:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Schedule information is required",
        )

    try:
        # Generate a unique job ID
        job_id = f"swarm_{swarm.name}_{int(time())}"

        # Normalize the scheduled time to aware UTC so later comparisons
        # against datetime.now(pytz.UTC) are valid
        scheduled_time = swarm.schedule.scheduled_time
        if scheduled_time.tzinfo is None:
            scheduled_time = pytz.UTC.localize(scheduled_time)

        # Create and start the scheduled job
        job = ScheduledJob(
            job_id=job_id,
            scheduled_time=scheduled_time,
            swarm=swarm,
            api_key=x_api_key,
        )
        job.start()

        # Store the job information
        scheduled_jobs[job_id] = {
            "job": job,
            "swarm_name": swarm.name,
            "scheduled_time": scheduled_time,
            "timezone": swarm.schedule.timezone,
        }

        # Log the scheduling
        await log_api_request(
            x_api_key,
            {
                "action": "schedule_swarm",
                "swarm_name": swarm.name,
                "scheduled_time": scheduled_time.isoformat(),
                "job_id": job_id,
            },
        )

        return {
            "status": "success",
            "message": "Swarm scheduled successfully",
            "job_id": job_id,
            "scheduled_time": scheduled_time,
            "timezone": swarm.schedule.timezone,
        }

    except Exception as e:
        logger.error(f"Error scheduling swarm: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to schedule swarm: {str(e)}",
        )
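
# A hedged request-body sketch for this endpoint (values are placeholders;
# scheduled_time is ISO-8601 and treated as UTC):
#
#     {
#         "name": "nightly-report",
#         "task": "Compile the nightly report.",
#         "agents": [{"agent_name": "reporter", "model_name": "gpt-4o"}],
#         "schedule": {"scheduled_time": "2030-01-01T00:00:00Z", "timezone": "UTC"}
#     }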


@app.get(
    "/v1/swarm/schedule",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def get_scheduled_jobs(x_api_key: str = Header(...)) -> Dict[str, Any]:
    """
    Get all scheduled swarm jobs.
    """
    try:
        jobs_list = []
        current_time = datetime.now(pytz.UTC)

        # Clean up completed jobs
        completed_jobs = [
            job_id
            for job_id, job_info in scheduled_jobs.items()
            if current_time >= job_info["scheduled_time"]
        ]
        for job_id in completed_jobs:
            scheduled_jobs.pop(job_id, None)

        # Get active jobs
        for job_id, job_info in scheduled_jobs.items():
            jobs_list.append(
                {
                    "job_id": job_id,
                    "swarm_name": job_info["swarm_name"],
                    "scheduled_time": job_info["scheduled_time"].isoformat(),
                    "timezone": job_info["timezone"],
                }
            )

        return {"status": "success", "scheduled_jobs": jobs_list}

    except Exception as e:
        logger.error(f"Error retrieving scheduled jobs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve scheduled jobs: {str(e)}",
        )


@app.delete(
    "/v1/swarm/schedule/{job_id}",
    dependencies=[
        Depends(verify_api_key),
        Depends(rate_limiter),
    ],
)
async def cancel_scheduled_job(
    job_id: str, x_api_key: str = Header(...)
) -> Dict[str, Any]:
    """
    Cancel a scheduled swarm job.
    """
    try:
        if job_id not in scheduled_jobs:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Scheduled job not found"
            )

        # Cancel and remove the job
        job_info = scheduled_jobs[job_id]
        job_info["job"].cancelled = True
        scheduled_jobs.pop(job_id)

        await log_api_request(
            x_api_key, {"action": "cancel_scheduled_job", "job_id": job_id}
        )

        return {
            "status": "success",
            "message": "Scheduled job cancelled successfully",
            "job_id": job_id,
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling scheduled job: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to cancel scheduled job: {str(e)}",
        )


# --- Main Entrypoint ---

if __name__ == "__main__":
    import uvicorn

    # uvicorn honors `workers` only when the app is passed as an import string
    # (e.g. "module:app"); with an app instance it runs a single process.
    uvicorn.run(app, host="0.0.0.0", port=8080)