You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/swarm_cloud_api/api.py

227 lines
6.1 KiB

"""
SkyServe API: Production-grade FastAPI server for SimpleSkyServe.
This module provides a REST API interface for managing SkyPilot services with
proper error handling, validation, and production configurations.
"""
import multiprocessing
import os
from typing import List, Optional
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
from pydantic import BaseModel, Field
from pydantic.v1 import validator
from swarm_cloud_code import ServiceConfig, SimpleSkyServe, UpdateMode
# Calculate optimal number of workers
CPU_COUNT = multiprocessing.cpu_count()
WORKERS = CPU_COUNT * 2
# Configure logging
logger.add(
"logs/skyserve-api.log",
rotation="500 MB",
retention="10 days",
level="INFO",
)
# Initialize FastAPI app
app = FastAPI(
title="SkyServe API",
description="REST API for managing SkyPilot services",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
)
# Configure CORS
app.add_middleware(
CORSMiddleware,
allow_origins=os.getenv("ALLOWED_ORIGINS", "*").split(","),
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Pydantic models for request/response validation
class ServiceConfigRequest(BaseModel):
"""Request model for service configuration."""
code: str = Field(
..., description="Python code to run as a service"
)
requirements: Optional[List[str]] = Field(
default=None, description="List of pip packages"
)
envs: Optional[dict] = Field(
default=None, description="Environment variables"
)
name: Optional[str] = Field(
default=None, description="Service name"
)
num_cpus: int = Field(
default=2, ge=1, description="Number of CPUs"
)
memory: int = Field(default=4, ge=1, description="Memory in GB")
use_spot: bool = Field(
default=False, description="Use spot instances"
)
num_nodes: int = Field(
default=1, ge=1, description="Number of nodes"
)
@validator("name")
def validate_name(cls, v):
if v and not v.isalnum():
raise ValueError("Service name must be alphanumeric")
return v
class DeploymentResponse(BaseModel):
"""Response model for deployment information."""
service_name: str
endpoint: str
class ServiceStatusResponse(BaseModel):
"""Response model for service status."""
name: str
status: str
versions: List[int]
replicas: int
resources: str
uptime: int
endpoint: Optional[str]
@app.post(
"/services/",
response_model=DeploymentResponse,
status_code=status.HTTP_201_CREATED,
tags=["services"],
)
async def create_service(config: ServiceConfigRequest):
"""Deploy a new service."""
try:
service_config = ServiceConfig(
code=config.code,
requirements=config.requirements,
envs=config.envs,
name=config.name,
num_cpus=config.num_cpus,
memory=config.memory,
use_spot=config.use_spot,
num_nodes=config.num_nodes,
)
name, endpoint = SimpleSkyServe.deploy(service_config)
return {"service_name": name, "endpoint": endpoint}
except Exception as e:
logger.error(f"Failed to create service: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@app.get(
"/services/",
response_model=List[ServiceStatusResponse],
tags=["services"],
)
async def list_services(name: Optional[str] = None):
"""Get status of all services or a specific service."""
try:
deployments = SimpleSkyServe.get_deployments(name)
return deployments
except Exception as e:
logger.error(f"Failed to list services: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@app.put(
"/services/{service_name}",
status_code=status.HTTP_200_OK,
tags=["services"],
)
async def update_service(
service_name: str,
config: ServiceConfigRequest,
mode: UpdateMode = UpdateMode.GRADUAL,
):
"""Update an existing service."""
try:
service_config = ServiceConfig(
code=config.code,
requirements=config.requirements,
envs=config.envs,
name=config.name,
num_cpus=config.num_cpus,
memory=config.memory,
use_spot=config.use_spot,
num_nodes=config.num_nodes,
)
SimpleSkyServe.update(service_name, service_config, mode)
return {
"message": f"Service {service_name} updated successfully"
}
except Exception as e:
logger.error(f"Failed to update service: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@app.delete(
"/services/{service_name}",
status_code=status.HTTP_204_NO_CONTENT,
tags=["services"],
)
async def delete_service(service_name: str, purge: bool = False):
"""Delete a service."""
try:
SimpleSkyServe.delete(service_name, purge)
except Exception as e:
logger.error(f"Failed to delete service: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
"""Global exception handler."""
logger.error(f"Unhandled exception: {str(exc)}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": "Internal server error"},
)
# Entry point for uvicorn
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"api:app",
host="0.0.0.0",
port=8000,
workers=WORKERS,
log_level="info",
reload=False, # Disable in production
proxy_headers=True,
forwarded_allow_ips="*",
access_log=True,
)