API CHANGES]

pull/712/head
Kye Gomez 10 months ago
parent 33e7f69450
commit 3a1a614d7b

@ -0,0 +1,74 @@
# Build stage
FROM python:3.11-slim as builder
# Set working directory
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements from api folder
COPY api/requirements.txt .
RUN pip install --no-cache-dir wheel && \
pip wheel --no-cache-dir --no-deps --wheel-dir /app/wheels -r requirements.txt
# Final stage
FROM python:3.11-slim
# Set environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PATH="/app/venv/bin:$PATH" \
PYTHONPATH=/app \
PORT=8080
# Create app user
RUN useradd -m -s /bin/bash app && \
mkdir -p /app/logs && \
chown -R app:app /app
# Set working directory
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy wheels from builder
COPY --from=builder /app/wheels /app/wheels
# Create and activate virtual environment
RUN python -m venv /app/venv && \
/app/venv/bin/pip install --no-cache-dir /app/wheels/*
# Copy application code
COPY --chown=app:app ./api ./api
# Switch to app user
USER app
# Create directories for logs
RUN mkdir -p /app/logs
# Required environment variables
ENV SUPABASE_URL="" \
SUPABASE_SERVICE_KEY="" \
ENVIRONMENT="production" \
LOG_LEVEL="info" \
WORKERS=4 \
MAX_REQUESTS_PER_MINUTE=60 \
API_KEY_LENGTH=32
# Expose port
EXPOSE $PORT
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD curl -f http://localhost:$PORT/health || exit 1
# Start command
CMD ["sh", "-c", "uvicorn api.api:app --host 0.0.0.0 --port $PORT --workers $WORKERS --log-level $LOG_LEVEL"]

@ -1,16 +1,15 @@
import asyncio import asyncio
import os import os
import secrets
import signal import signal
import traceback import traceback
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from fastapi.concurrency import asynccontextmanager
import uvicorn import uvicorn
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi import ( from fastapi import (
@ -20,12 +19,14 @@ from fastapi import (
Header, Header,
HTTPException, HTTPException,
Query, Query,
Request,
status, status,
) )
from fastapi.concurrency import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger from loguru import logger
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from supabase import Client, create_client
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
@ -33,6 +34,90 @@ from swarms.structs.agent import Agent
load_dotenv() load_dotenv()
class APIKey(BaseModel):
"""Model matching Supabase api_keys table"""
id: UUID
created_at: datetime
name: str
user_id: UUID
key: str
limit_credit_dollar: Optional[float] = None
is_deleted: bool = False
class User(BaseModel):
id: UUID
name: str
is_active: bool = True
is_admin: bool = False
@lru_cache()
def get_supabase() -> Client:
"""Get cached Supabase client"""
supabase_url = os.getenv("SUPABASE_URL")
supabase_key = os.getenv("SUPABASE_SERVICE_KEY")
if not supabase_url or not supabase_key:
raise ValueError("Supabase configuration is missing")
return create_client(supabase_url, supabase_key)
async def get_current_user(
api_key: str = Header(..., description="API key for authentication"),
) -> User:
"""Validate API key against Supabase and return current user."""
if not api_key or not api_key.startswith('sk-'):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key format",
headers={"WWW-Authenticate": "ApiKey"},
)
try:
supabase = get_supabase()
# Query the api_keys table
response = supabase.table('api_keys').select(
'id, name, user_id, key, limit_credit_dollar, is_deleted'
).eq('key', api_key).single().execute()
if not response.data:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "ApiKey"},
)
key_data = response.data
# Check if key is deleted
if key_data['is_deleted']:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API key has been deleted",
headers={"WWW-Authenticate": "ApiKey"},
)
# Check credit limit if applicable
if key_data['limit_credit_dollar'] is not None and key_data['limit_credit_dollar'] <= 0:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="API key credit limit exceeded"
)
# Create user object
return User(
id=key_data['user_id'],
name=key_data['name'],
is_active=not key_data['is_deleted']
)
except Exception as e:
logger.error(f"Error validating API key: {str(e)}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API key validation failed",
headers={"WWW-Authenticate": "ApiKey"},
)
class UvicornServer(uvicorn.Server): class UvicornServer(uvicorn.Server):
"""Customized uvicorn server with graceful shutdown support""" """Customized uvicorn server with graceful shutdown support"""
@ -59,14 +144,6 @@ class AgentStatus(str, Enum):
API_KEY_LENGTH = 32 # Length of generated API keys API_KEY_LENGTH = 32 # Length of generated API keys
class APIKey(BaseModel):
key: str
name: str
created_at: datetime
last_used: datetime
is_active: bool = True
class APIKeyCreate(BaseModel): class APIKeyCreate(BaseModel):
name: str # A friendly name for the API key name: str # A friendly name for the API key
@ -214,11 +291,7 @@ class AgentStore:
def __init__(self): def __init__(self):
self.agents: Dict[UUID, Agent] = {} self.agents: Dict[UUID, Agent] = {}
self.agent_metadata: Dict[UUID, Dict[str, Any]] = {} self.agent_metadata: Dict[UUID, Dict[str, Any]] = {}
self.users: Dict[UUID, User] = {} # user_id -> User self.user_agents: Dict[UUID, List[UUID]] = {} # user_id -> [agent_ids]
self.api_keys: Dict[str, UUID] = {} # api_key -> user_id
self.user_agents: Dict[UUID, List[UUID]] = (
{}
) # user_id -> [agent_ids]
self.executor = ThreadPoolExecutor(max_workers=4) self.executor = ThreadPoolExecutor(max_workers=4)
self._ensure_directories() self._ensure_directories()
@ -227,31 +300,6 @@ class AgentStore:
Path("logs").mkdir(exist_ok=True) Path("logs").mkdir(exist_ok=True)
Path("states").mkdir(exist_ok=True) Path("states").mkdir(exist_ok=True)
def create_api_key(self, user_id: UUID, key_name: str) -> APIKey:
"""Create a new API key for a user."""
if user_id not in self.users:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
# Generate a secure random API key
api_key = secrets.token_urlsafe(API_KEY_LENGTH)
# Create the API key object
key_object = APIKey(
key=api_key,
name=key_name,
created_at=datetime.utcnow(),
last_used=datetime.utcnow(),
)
# Store the API key
self.users[user_id].api_keys[api_key] = key_object
self.api_keys[api_key] = user_id
return key_object
async def verify_agent_access( async def verify_agent_access(
self, agent_id: UUID, user_id: UUID self, agent_id: UUID, user_id: UUID
) -> bool: ) -> bool:
@ -656,94 +704,11 @@ class SwarmsAPI:
self._setup_routes() self._setup_routes()
def _setup_routes(self): def _setup_routes(self):
"""Set up API routes.""" """Set up API routes with Supabase authentication."""
# In your API code
# Modify the create_user endpoint
@self.app.post("/v1/users", response_model=Dict[str, Any])
async def create_user(request: Request):
"""Create a new user and initial API key."""
try:
body = await request.json()
username = body.get("username")
if not username or len(username) < 3:
raise HTTPException(
status_code=400, detail="Invalid username"
)
user_id = uuid4()
user = User(id=user_id, username=username)
self.store.users[user_id] = user
# Always create initial API key
initial_key = self.store.create_api_key(
user_id, "Initial Key"
)
if not initial_key:
raise HTTPException(
status_code=500,
detail="Failed to create initial API key",
)
return {
"user_id": user_id,
"api_key": initial_key.key,
}
except Exception as e:
logger.error(f"Error creating user: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
@self.app.get( @self.app.get(
"/v1/users/{user_id}/api-keys", "/v1/users/me/agents",
response_model=List[APIKey], response_model=List[AgentSummary]
)
async def list_api_keys(
user_id: UUID,
current_user: User = Depends(get_current_user),
):
"""List all API keys for a user."""
if (
current_user.id != user_id
and not current_user.is_admin
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to view API keys for this user",
)
return list(self.store.users[user_id].api_keys.values())
@self.app.delete("/v1/users/{user_id}/api-keys/{key}")
async def revoke_api_key(
user_id: UUID,
key: str,
current_user: User = Depends(get_current_user),
):
"""Revoke an API key."""
if (
current_user.id != user_id
and not current_user.is_admin
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to revoke API keys for this user",
)
if key in self.store.users[user_id].api_keys:
self.store.users[user_id].api_keys[
key
].is_active = False
del self.store.api_keys[key]
return {"status": "API key revoked"}
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found",
)
@self.app.get(
"/v1/users/me/agents", response_model=List[AgentSummary]
) )
async def list_user_agents( async def list_user_agents(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
@ -751,24 +716,20 @@ class SwarmsAPI:
status: Optional[AgentStatus] = None, status: Optional[AgentStatus] = None,
): ):
"""List all agents owned by the current user.""" """List all agents owned by the current user."""
user_agents = self.store.user_agents.get( user_agents = self.store.user_agents.get(current_user.id, [])
current_user.id, []
)
return [ return [
agent agent
for agent in await self.store.list_agents( for agent in await self.store.list_agents(tags, status)
tags, status
)
if agent.agent_id in user_agents if agent.agent_id in user_agents
] ]
# Modify existing routes to use API key authentication
@self.app.post("/v1/agent", response_model=Dict[str, UUID]) @self.app.post("/v1/agent", response_model=Dict[str, UUID])
async def create_agent( async def create_agent(
config: AgentConfig, config: AgentConfig,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Create a new agent with the specified configuration.""" """Create a new agent with the specified configuration."""
logger.info(f"User {current_user.id} creating new agent")
agent_id = await self.store.create_agent( agent_id = await self.store.create_agent(
config, current_user.id config, current_user.id
) )
@ -776,51 +737,115 @@ class SwarmsAPI:
@self.app.get("/v1/agents", response_model=List[AgentSummary]) @self.app.get("/v1/agents", response_model=List[AgentSummary])
async def list_agents( async def list_agents(
current_user: User = Depends(get_current_user),
tags: Optional[List[str]] = Query(None), tags: Optional[List[str]] = Query(None),
status: Optional[AgentStatus] = None, status: Optional[AgentStatus] = None,
): ):
"""List all agents, optionally filtered by tags and status.""" """List all agents, optionally filtered by tags and status."""
return await self.store.list_agents(tags, status) agents = await self.store.list_agents(tags, status)
# Filter agents based on user access
return [
agent for agent in agents
if await self.store.verify_agent_access(agent.agent_id, current_user.id)
]
@self.app.patch( @self.app.patch(
"/v1/agent/{agent_id}", response_model=Dict[str, str] "/v1/agent/{agent_id}",
response_model=Dict[str, str]
) )
async def update_agent(agent_id: UUID, update: AgentUpdate): async def update_agent(
agent_id: UUID,
update: AgentUpdate,
current_user: User = Depends(get_current_user)
):
"""Update an existing agent's configuration.""" """Update an existing agent's configuration."""
if not await self.store.verify_agent_access(agent_id, current_user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to update this agent"
)
await self.store.update_agent(agent_id, update) await self.store.update_agent(agent_id, update)
return {"status": "updated"} return {"status": "updated"}
@self.app.get( @self.app.get(
"/v1/agent/{agent_id}/metrics", "/v1/agent/{agent_id}/metrics",
response_model=AgentMetrics, response_model=AgentMetrics
) )
async def get_agent_metrics(agent_id: UUID): async def get_agent_metrics(
agent_id: UUID,
current_user: User = Depends(get_current_user)
):
"""Get performance metrics for a specific agent.""" """Get performance metrics for a specific agent."""
if not await self.store.verify_agent_access(agent_id, current_user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to view this agent's metrics"
)
return await self.store.get_agent_metrics(agent_id) return await self.store.get_agent_metrics(agent_id)
@self.app.post( @self.app.post(
"/v1/agent/{agent_id}/clone", "/v1/agent/{agent_id}/clone",
response_model=Dict[str, UUID], response_model=Dict[str, UUID]
) )
async def clone_agent(agent_id: UUID, new_name: str): async def clone_agent(
agent_id: UUID,
new_name: str,
current_user: User = Depends(get_current_user)
):
"""Clone an existing agent with a new name.""" """Clone an existing agent with a new name."""
if not await self.store.verify_agent_access(agent_id, current_user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to clone this agent"
)
new_id = await self.store.clone_agent(agent_id, new_name) new_id = await self.store.clone_agent(agent_id, new_name)
# Add the cloned agent to user's agents
if current_user.id not in self.store.user_agents:
self.store.user_agents[current_user.id] = []
self.store.user_agents[current_user.id].append(new_id)
return {"agent_id": new_id} return {"agent_id": new_id}
@self.app.delete("/v1/agent/{agent_id}") @self.app.delete("/v1/agent/{agent_id}")
async def delete_agent(agent_id: UUID): async def delete_agent(
agent_id: UUID,
current_user: User = Depends(get_current_user)
):
"""Delete an agent.""" """Delete an agent."""
if not await self.store.verify_agent_access(agent_id, current_user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to delete this agent"
)
await self.store.delete_agent(agent_id) await self.store.delete_agent(agent_id)
# Remove from user's agents list
if current_user.id in self.store.user_agents:
self.store.user_agents[current_user.id] = [
aid for aid in self.store.user_agents[current_user.id]
if aid != agent_id
]
return {"status": "deleted"} return {"status": "deleted"}
@self.app.post( @self.app.post(
"/v1/agent/completions", response_model=CompletionResponse "/v1/agent/completions",
response_model=CompletionResponse
) )
async def create_completion( async def create_completion(
request: CompletionRequest, request: CompletionRequest,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
current_user: User = Depends(get_current_user)
): ):
"""Process a completion request with the specified agent.""" """Process a completion request with the specified agent."""
if not await self.store.verify_agent_access(request.agent_id, current_user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to use this agent"
)
try: try:
agent = await self.store.get_agent(request.agent_id) agent = await self.store.get_agent(request.agent_id)
@ -830,12 +855,13 @@ class SwarmsAPI:
request.prompt, request.prompt,
request.agent_id, request.agent_id,
request.max_tokens, request.max_tokens,
0.5, request.temperature_override
) )
# Schedule background cleanup # Schedule background cleanup
background_tasks.add_task( background_tasks.add_task(
self._cleanup_old_metrics, request.agent_id self._cleanup_old_metrics,
request.agent_id
) )
return response return response
@ -844,26 +870,55 @@ class SwarmsAPI:
logger.error(f"Error processing completion: {str(e)}") logger.error(f"Error processing completion: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error processing completion: {str(e)}", detail=f"Error processing completion: {str(e)}"
) )
@self.app.get("/v1/agent/{agent_id}/status") @self.app.get("/v1/agent/{agent_id}/status")
async def get_agent_status(agent_id: UUID): async def get_agent_status(
agent_id: UUID,
current_user: User = Depends(get_current_user)
):
"""Get the current status of an agent.""" """Get the current status of an agent."""
if not await self.store.verify_agent_access(agent_id, current_user.id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to view this agent's status"
)
metadata = self.store.agent_metadata.get(agent_id) metadata = self.store.agent_metadata.get(agent_id)
if not metadata: if not metadata:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Agent {agent_id} not found", detail=f"Agent {agent_id} not found"
) )
return { return {
"agent_id": agent_id, "agent_id": agent_id,
"status": metadata["status"], "status": metadata["status"],
"last_used": metadata["last_used"], "last_used": metadata["last_used"],
"total_completions": metadata["total_completions"], "total_completions": metadata["total_completions"],
"error_count": metadata["error_count"], "error_count": metadata["error_count"]
} }
@self.app.get("/health")
async def health_check():
"""Health check endpoint - no auth required."""
try:
# Test Supabase connection
supabase = get_supabase()
supabase.table('api_keys').select('count', count='exact').execute()
return {"status": "healthy", "database": "connected"}
except Exception as e:
logger.error(f"Health check failed: {str(e)}")
return JSONResponse(
status_code=503,
content={
"status": "unhealthy",
"database": "disconnected",
"error": str(e)
}
)
async def _cleanup_old_metrics(self, agent_id: UUID): async def _cleanup_old_metrics(self, agent_id: UUID):
"""Clean up old metrics data to prevent memory bloat.""" """Clean up old metrics data to prevent memory bloat."""
metadata = self.store.agent_metadata.get(agent_id) metadata = self.store.agent_metadata.get(agent_id)
@ -888,7 +943,7 @@ class SwarmsAPI:
class APIServer: class APIServer:
def __init__( def __init__(
self, app: FastAPI, host: str = "0.0.0.0", port: int = 8000 self, app: FastAPI, host: str = "0.0.0.0", port: int = 8080
): ):
self.app = app self.app = app
self.host = host self.host = host

@ -9,3 +9,5 @@ opentelemetry-sdk
opentelemetry-instrumentation-fastapi opentelemetry-instrumentation-fastapi
opentelemetry-instrumentation-requests opentelemetry-instrumentation-requests
opentelemetry-exporter-otlp-proto-grpc opentelemetry-exporter-otlp-proto-grpc
swarms
supabase

@ -0,0 +1,353 @@
from __future__ import annotations
import asyncio
import base64
import io
import threading
from os import getenv
from typing import Any, Awaitable, Callable, cast
import numpy as np
try:
import pyaudio
except ImportError:
import subprocess
subprocess.check_call(["pip", "install", "pyaudio"])
import pyaudio
try:
import sounddevice as sd
except ImportError:
import subprocess
subprocess.check_call(["pip", "install", "sounddevice"])
import sounddevice as sd
from loguru import logger
from openai import AsyncOpenAI
from openai.resources.beta.realtime.realtime import (
AsyncRealtimeConnection,
)
from openai.types.beta.realtime.session import Session
try:
from pydub import AudioSegment
except ImportError:
import subprocess
subprocess.check_call(["pip", "install", "pydub"])
from pydub import AudioSegment
from dotenv import load_dotenv
load_dotenv()
CHUNK_LENGTH_S = 0.05 # 100ms
SAMPLE_RATE = 24000
FORMAT = pyaudio.paInt16
CHANNELS = 1
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
def audio_to_pcm16_base64(audio_bytes: bytes) -> bytes:
# load the audio file from the byte stream
audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
print(f"Loaded audio: {audio.frame_rate=} {audio.channels=} {audio.sample_width=} {audio.frame_width=}")
# resample to 24kHz mono pcm16
pcm_audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(CHANNELS).set_sample_width(2).raw_data
return pcm_audio
class AudioPlayerAsync:
def __init__(self):
self.queue = []
self.lock = threading.Lock()
self.stream = sd.OutputStream(
callback=self.callback,
samplerate=SAMPLE_RATE,
channels=CHANNELS,
dtype=np.int16,
blocksize=int(CHUNK_LENGTH_S * SAMPLE_RATE),
)
self.playing = False
self._frame_count = 0
def callback(self, outdata, frames, time, status): # noqa
with self.lock:
data = np.empty(0, dtype=np.int16)
# get next item from queue if there is still space in the buffer
while len(data) < frames and len(self.queue) > 0:
item = self.queue.pop(0)
frames_needed = frames - len(data)
data = np.concatenate((data, item[:frames_needed]))
if len(item) > frames_needed:
self.queue.insert(0, item[frames_needed:])
self._frame_count += len(data)
# fill the rest of the frames with zeros if there is no more data
if len(data) < frames:
data = np.concatenate((data, np.zeros(frames - len(data), dtype=np.int16)))
outdata[:] = data.reshape(-1, 1)
def reset_frame_count(self):
self._frame_count = 0
def get_frame_count(self):
return self._frame_count
def add_data(self, data: bytes):
with self.lock:
# bytes is pcm16 single channel audio data, convert to numpy array
np_data = np.frombuffer(data, dtype=np.int16)
self.queue.append(np_data)
if not self.playing:
self.start()
def start(self):
self.playing = True
self.stream.start()
def stop(self):
self.playing = False
self.stream.stop()
with self.lock:
self.queue = []
def terminate(self):
self.stream.close()
async def send_audio_worker_sounddevice(
connection: AsyncRealtimeConnection,
should_send: Callable[[], bool] | None = None,
start_send: Callable[[], Awaitable[None]] | None = None,
):
sent_audio = False
device_info = sd.query_devices()
print(device_info)
read_size = int(SAMPLE_RATE * 0.02)
stream = sd.InputStream(
channels=CHANNELS,
samplerate=SAMPLE_RATE,
dtype="int16",
)
stream.start()
try:
while True:
if stream.read_available < read_size:
await asyncio.sleep(0)
continue
data, _ = stream.read(read_size)
if should_send() if should_send else True:
if not sent_audio and start_send:
await start_send()
await connection.send(
{"type": "input_audio_buffer.append", "audio": base64.b64encode(data).decode("utf-8")}
)
sent_audio = True
elif sent_audio:
print("Done, triggering inference")
await connection.send({"type": "input_audio_buffer.commit"})
await connection.send({"type": "response.create", "response": {}})
sent_audio = False
await asyncio.sleep(0)
except KeyboardInterrupt:
pass
finally:
stream.stop()
stream.close()
class RealtimeApp:
"""
A console-based application to handle real-time audio recording and streaming,
connecting to OpenAI's GPT-4 Realtime API.
Features:
- Streams microphone input to the GPT-4 Realtime API.
- Logs transcription results.
- Sends text prompts to the GPT-4 Realtime API.
"""
def __init__(self, system_prompt: str = None) -> None:
self.connection: AsyncRealtimeConnection | None = None
self.session: Session | None = None
self.client = AsyncOpenAI(api_key=getenv("OPENAI_API_KEY"))
self.audio_player = AudioPlayerAsync()
self.last_audio_item_id: str | None = None
self.should_send_audio = asyncio.Event()
self.connected = asyncio.Event()
self.system_prompt = system_prompt
async def initialize_text_prompt(self, text: str) -> None:
"""Initialize and send a text prompt to the OpenAI Realtime API."""
try:
async with self.client.beta.realtime.connect(model="gpt-4o-realtime-preview-2024-10-01") as conn:
self.connection = conn
await conn.session.update(session={"modalities": ["text"]})
await conn.conversation.item.create(
item={
"type": "message",
"role": "system",
"content": [{"type": "input_text", "text": text}],
}
)
await conn.response.create()
async for event in conn:
if event.type == "response.text.delta":
print(event.delta, flush=True, end="")
elif event.type == "response.text.done":
print()
elif event.type == "response.done":
break
except Exception as e:
logger.exception(f"Error initializing text prompt: {e}")
async def handle_realtime_connection(self) -> None:
"""Handle the connection to the OpenAI Realtime API."""
try:
async with self.client.beta.realtime.connect(model="gpt-4o-realtime-preview-2024-10-01") as conn:
self.connection = conn
self.connected.set()
logger.info("Connected to OpenAI Realtime API.")
await conn.session.update(session={"turn_detection": {"type": "server_vad"}})
acc_items: dict[str, Any] = {}
async for event in conn:
if event.type == "session.created":
self.session = event.session
assert event.session.id is not None
logger.info(f"Session created with ID: {event.session.id}")
continue
if event.type == "session.updated":
self.session = event.session
logger.info("Session updated.")
continue
if event.type == "response.audio.delta":
if event.item_id != self.last_audio_item_id:
self.audio_player.reset_frame_count()
self.last_audio_item_id = event.item_id
bytes_data = base64.b64decode(event.delta)
self.audio_player.add_data(bytes_data)
continue
if event.type == "response.audio_transcript.delta":
try:
text = acc_items[event.item_id]
except KeyError:
acc_items[event.item_id] = event.delta
else:
acc_items[event.item_id] = text + event.delta
logger.debug(f"Transcription updated: {acc_items[event.item_id]}")
continue
if event.type == "response.text.delta":
print(event.delta, flush=True, end="")
continue
if event.type == "response.text.done":
print()
continue
if event.type == "response.done":
break
except Exception as e:
logger.exception(f"Error in realtime connection handler: {e}")
async def _get_connection(self) -> AsyncRealtimeConnection:
"""Wait for and return the realtime connection."""
await self.connected.wait()
assert self.connection is not None
return self.connection
async def send_text_prompt(self, text: str) -> None:
"""Send a text prompt to the OpenAI Realtime API."""
try:
connection = await self._get_connection()
if not self.session:
logger.error("Session is not initialized. Cannot send prompt.")
return
logger.info(f"Sending prompt to the model: {text}")
await connection.conversation.item.create(
item={
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": text}],
}
)
await connection.response.create()
except Exception as e:
logger.exception(f"Error sending text prompt: {e}")
async def send_mic_audio(self) -> None:
"""Stream microphone audio to the OpenAI Realtime API."""
import sounddevice as sd # type: ignore
sent_audio = False
try:
read_size = int(SAMPLE_RATE * 0.02)
stream = sd.InputStream(
channels=CHANNELS, samplerate=SAMPLE_RATE, dtype="int16"
)
stream.start()
while True:
if stream.read_available < read_size:
await asyncio.sleep(0)
continue
await self.should_send_audio.wait()
data, _ = stream.read(read_size)
connection = await self._get_connection()
if not sent_audio:
asyncio.create_task(connection.send({"type": "response.cancel"}))
sent_audio = True
await connection.input_audio_buffer.append(audio=base64.b64encode(cast(Any, data)).decode("utf-8"))
await asyncio.sleep(0)
except Exception as e:
logger.exception(f"Error in microphone audio streaming: {e}")
finally:
stream.stop()
stream.close()
async def run(self) -> None:
"""Start the application tasks."""
logger.info("Starting application tasks.")
await asyncio.gather(
# self.initialize_text_prompt(self.system_prompt),
self.handle_realtime_connection(),
self.send_mic_audio()
)
if __name__ == "__main__":
logger.add("realtime_app.log", rotation="10 MB", retention="10 days", level="DEBUG")
logger.info("Starting RealtimeApp.")
app = RealtimeApp()
asyncio.run(app.run())
Loading…
Cancel
Save