# Dockerfile.api - container image for the Swarms agent API.
#
# Stage 1 (builder) compiles wheels for the requirements AND their
# dependencies, so the runtime stage can install everything offline and
# does not need a compiler toolchain.

# ---------------------------------------------------------------- Build stage
FROM python:3.11-slim AS builder

WORKDIR /app

# build-essential is only needed to compile wheels that ship C extensions.
RUN apt-get update && apt-get install -y --no-install-recommends \
        build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements from api folder
COPY api/requirements.txt .

# Build wheels for the full dependency tree (no --no-deps): otherwise the
# runtime stage has to resolve, download and possibly compile transitive
# dependencies itself, where no compiler is installed.
RUN pip install --no-cache-dir wheel && \
    pip wheel --no-cache-dir --wheel-dir /app/wheels -r requirements.txt

# ---------------------------------------------------------------- Final stage
FROM python:3.11-slim

ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PATH="/app/venv/bin:$PATH" \
    PYTHONPATH=/app \
    PORT=8080

# Create the unprivileged runtime user and a log directory it owns.
RUN useradd -m -s /bin/bash app && \
    mkdir -p /app/logs && \
    chown -R app:app /app

WORKDIR /app

# curl is required by the HEALTHCHECK below.
RUN apt-get update && apt-get install -y --no-install-recommends \
        curl \
    && rm -rf /var/lib/apt/lists/*

# Copy wheels from builder
COPY --from=builder /app/wheels /app/wheels

# Create the virtual environment and install strictly from the prebuilt
# wheels (--no-index) so the runtime build is reproducible and offline.
RUN python -m venv /app/venv && \
    /app/venv/bin/pip install --no-cache-dir --no-index \
        --find-links=/app/wheels /app/wheels/*

# Copy application code
COPY --chown=app:app ./api ./api

# Drop privileges; /app/logs already exists and is owned by this user.
USER app

# Runtime configuration. SUPABASE_URL and SUPABASE_SERVICE_KEY are empty
# placeholders: supply the real values at run time (docker run -e / secrets),
# never bake them into the image.
ENV SUPABASE_URL="" \
    SUPABASE_SERVICE_KEY="" \
    ENVIRONMENT="production" \
    LOG_LEVEL="info" \
    WORKERS=4 \
    MAX_REQUESTS_PER_MINUTE=60 \
    API_KEY_LENGTH=32

EXPOSE $PORT

# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:$PORT/health || exit 1

# The shell form is needed for variable expansion; `exec` replaces the shell
# with uvicorn so it runs as PID 1 and receives SIGTERM on `docker stop`.
# NOTE(review): the application module in this change set is api/main.py;
# confirm that `api.api:app` is the intended import path.
CMD ["sh", "-c", "exec uvicorn api.api:app --host 0.0.0.0 --port $PORT --workers $WORKERS --log-level $LOG_LEVEL"]
class APIKey(BaseModel):
    """Model matching Supabase api_keys table"""

    id: UUID
    created_at: datetime
    name: str
    user_id: UUID
    key: str
    # Remaining credit in dollars; None means no limit is enforced.
    limit_credit_dollar: Optional[float] = None
    is_deleted: bool = False


class User(BaseModel):
    """Authenticated caller, as derived from a validated API key row."""

    id: UUID
    name: str
    is_active: bool = True
    is_admin: bool = False


@lru_cache()
def get_supabase() -> Client:
    """Return the process-wide Supabase client, creating it on first use.

    Raises:
        ValueError: if SUPABASE_URL or SUPABASE_SERVICE_KEY is unset or empty.
    """
    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_SERVICE_KEY")
    if not supabase_url or not supabase_key:
        raise ValueError("Supabase configuration is missing")
    return create_client(supabase_url, supabase_key)


async def get_current_user(
    api_key: str = Header(..., description="API key for authentication"),
) -> User:
    """Validate API key against Supabase and return current user.

    Args:
        api_key: Value of the API key request header; must start with ``sk-``.

    Returns:
        A ``User`` built from the matching ``api_keys`` row.

    Raises:
        HTTPException: 401 if the key is malformed, unknown, deleted, or the
            lookup itself fails; 403 if the key's credit limit is exhausted.
    """
    if not api_key or not api_key.startswith('sk-'):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key format",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    try:
        supabase = get_supabase()

        # Query the api_keys table
        response = supabase.table('api_keys').select(
            'id, name, user_id, key, limit_credit_dollar, is_deleted'
        ).eq('key', api_key).single().execute()

        if not response.data:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "ApiKey"},
            )

        key_data = response.data

        # Check if key is deleted
        if key_data['is_deleted']:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="API key has been deleted",
                headers={"WWW-Authenticate": "ApiKey"},
            )

        # Check credit limit if applicable
        credit_limit = key_data['limit_credit_dollar']
        if credit_limit is not None and credit_limit <= 0:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="API key credit limit exceeded",
            )

        # Create user object
        return User(
            id=key_data['user_id'],
            name=key_data['name'],
            is_active=not key_data['is_deleted'],
        )

    except HTTPException:
        # Deliberate auth failures raised above must reach the client with
        # their own status code and detail; without this clause the generic
        # handler below rewrote the 403 (and every specific 401) into a
        # generic 401 "API key validation failed".
        raise
    except Exception as e:
        # Anything else (missing configuration, Supabase/network errors,
        # lookups that match no row) is reported as a failed validation.
        logger.error(f"Error validating API key: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key validation failed",
            headers={"WWW-Authenticate": "ApiKey"},
        )
List[UUID]] = {} # user_id -> [agent_ids] self.executor = ThreadPoolExecutor(max_workers=4) self._ensure_directories() @@ -227,31 +300,6 @@ class AgentStore: Path("logs").mkdir(exist_ok=True) Path("states").mkdir(exist_ok=True) - def create_api_key(self, user_id: UUID, key_name: str) -> APIKey: - """Create a new API key for a user.""" - if user_id not in self.users: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found", - ) - - # Generate a secure random API key - api_key = secrets.token_urlsafe(API_KEY_LENGTH) - - # Create the API key object - key_object = APIKey( - key=api_key, - name=key_name, - created_at=datetime.utcnow(), - last_used=datetime.utcnow(), - ) - - # Store the API key - self.users[user_id].api_keys[api_key] = key_object - self.api_keys[api_key] = user_id - - return key_object - async def verify_agent_access( self, agent_id: UUID, user_id: UUID ) -> bool: @@ -656,94 +704,11 @@ class SwarmsAPI: self._setup_routes() def _setup_routes(self): - """Set up API routes.""" - - # In your API code - - # Modify the create_user endpoint - @self.app.post("/v1/users", response_model=Dict[str, Any]) - async def create_user(request: Request): - """Create a new user and initial API key.""" - try: - body = await request.json() - username = body.get("username") - if not username or len(username) < 3: - raise HTTPException( - status_code=400, detail="Invalid username" - ) - - user_id = uuid4() - user = User(id=user_id, username=username) - self.store.users[user_id] = user - - # Always create initial API key - initial_key = self.store.create_api_key( - user_id, "Initial Key" - ) - if not initial_key: - raise HTTPException( - status_code=500, - detail="Failed to create initial API key", - ) - - return { - "user_id": user_id, - "api_key": initial_key.key, - } - except Exception as e: - logger.error(f"Error creating user: {str(e)}") - raise HTTPException(status_code=400, detail=str(e)) - - @self.app.get( - 
"/v1/users/{user_id}/api-keys", - response_model=List[APIKey], - ) - async def list_api_keys( - user_id: UUID, - current_user: User = Depends(get_current_user), - ): - """List all API keys for a user.""" - if ( - current_user.id != user_id - and not current_user.is_admin - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to view API keys for this user", - ) - - return list(self.store.users[user_id].api_keys.values()) - - @self.app.delete("/v1/users/{user_id}/api-keys/{key}") - async def revoke_api_key( - user_id: UUID, - key: str, - current_user: User = Depends(get_current_user), - ): - """Revoke an API key.""" - if ( - current_user.id != user_id - and not current_user.is_admin - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to revoke API keys for this user", - ) - - if key in self.store.users[user_id].api_keys: - self.store.users[user_id].api_keys[ - key - ].is_active = False - del self.store.api_keys[key] - return {"status": "API key revoked"} - - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="API key not found", - ) + """Set up API routes with Supabase authentication.""" @self.app.get( - "/v1/users/me/agents", response_model=List[AgentSummary] + "/v1/users/me/agents", + response_model=List[AgentSummary] ) async def list_user_agents( current_user: User = Depends(get_current_user), @@ -751,24 +716,20 @@ class SwarmsAPI: status: Optional[AgentStatus] = None, ): """List all agents owned by the current user.""" - user_agents = self.store.user_agents.get( - current_user.id, [] - ) + user_agents = self.store.user_agents.get(current_user.id, []) return [ agent - for agent in await self.store.list_agents( - tags, status - ) + for agent in await self.store.list_agents(tags, status) if agent.agent_id in user_agents ] - # Modify existing routes to use API key authentication @self.app.post("/v1/agent", response_model=Dict[str, UUID]) async def create_agent( 
config: AgentConfig, current_user: User = Depends(get_current_user), ): """Create a new agent with the specified configuration.""" + logger.info(f"User {current_user.id} creating new agent") agent_id = await self.store.create_agent( config, current_user.id ) @@ -776,51 +737,115 @@ class SwarmsAPI: @self.app.get("/v1/agents", response_model=List[AgentSummary]) async def list_agents( + current_user: User = Depends(get_current_user), tags: Optional[List[str]] = Query(None), status: Optional[AgentStatus] = None, ): """List all agents, optionally filtered by tags and status.""" - return await self.store.list_agents(tags, status) + agents = await self.store.list_agents(tags, status) + # Filter agents based on user access + return [ + agent for agent in agents + if await self.store.verify_agent_access(agent.agent_id, current_user.id) + ] @self.app.patch( - "/v1/agent/{agent_id}", response_model=Dict[str, str] + "/v1/agent/{agent_id}", + response_model=Dict[str, str] ) - async def update_agent(agent_id: UUID, update: AgentUpdate): + async def update_agent( + agent_id: UUID, + update: AgentUpdate, + current_user: User = Depends(get_current_user) + ): """Update an existing agent's configuration.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to update this agent" + ) + await self.store.update_agent(agent_id, update) return {"status": "updated"} @self.app.get( "/v1/agent/{agent_id}/metrics", - response_model=AgentMetrics, + response_model=AgentMetrics ) - async def get_agent_metrics(agent_id: UUID): + async def get_agent_metrics( + agent_id: UUID, + current_user: User = Depends(get_current_user) + ): """Get performance metrics for a specific agent.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to view this agent's metrics" + ) + return await 
self.store.get_agent_metrics(agent_id) @self.app.post( "/v1/agent/{agent_id}/clone", - response_model=Dict[str, UUID], + response_model=Dict[str, UUID] ) - async def clone_agent(agent_id: UUID, new_name: str): + async def clone_agent( + agent_id: UUID, + new_name: str, + current_user: User = Depends(get_current_user) + ): """Clone an existing agent with a new name.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to clone this agent" + ) + new_id = await self.store.clone_agent(agent_id, new_name) + # Add the cloned agent to user's agents + if current_user.id not in self.store.user_agents: + self.store.user_agents[current_user.id] = [] + self.store.user_agents[current_user.id].append(new_id) + return {"agent_id": new_id} @self.app.delete("/v1/agent/{agent_id}") - async def delete_agent(agent_id: UUID): + async def delete_agent( + agent_id: UUID, + current_user: User = Depends(get_current_user) + ): """Delete an agent.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to delete this agent" + ) + await self.store.delete_agent(agent_id) + # Remove from user's agents list + if current_user.id in self.store.user_agents: + self.store.user_agents[current_user.id] = [ + aid for aid in self.store.user_agents[current_user.id] + if aid != agent_id + ] return {"status": "deleted"} @self.app.post( - "/v1/agent/completions", response_model=CompletionResponse + "/v1/agent/completions", + response_model=CompletionResponse ) async def create_completion( request: CompletionRequest, background_tasks: BackgroundTasks, + current_user: User = Depends(get_current_user) ): """Process a completion request with the specified agent.""" + if not await self.store.verify_agent_access(request.agent_id, current_user.id): + raise HTTPException( + 
status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to use this agent" + ) + try: agent = await self.store.get_agent(request.agent_id) @@ -830,12 +855,13 @@ class SwarmsAPI: request.prompt, request.agent_id, request.max_tokens, - 0.5, + request.temperature_override ) # Schedule background cleanup background_tasks.add_task( - self._cleanup_old_metrics, request.agent_id + self._cleanup_old_metrics, + request.agent_id ) return response @@ -844,25 +870,54 @@ class SwarmsAPI: logger.error(f"Error processing completion: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error processing completion: {str(e)}", + detail=f"Error processing completion: {str(e)}" ) @self.app.get("/v1/agent/{agent_id}/status") - async def get_agent_status(agent_id: UUID): + async def get_agent_status( + agent_id: UUID, + current_user: User = Depends(get_current_user) + ): """Get the current status of an agent.""" + if not await self.store.verify_agent_access(agent_id, current_user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to view this agent's status" + ) + metadata = self.store.agent_metadata.get(agent_id) if not metadata: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Agent {agent_id} not found", + detail=f"Agent {agent_id} not found" ) + return { "agent_id": agent_id, "status": metadata["status"], "last_used": metadata["last_used"], "total_completions": metadata["total_completions"], - "error_count": metadata["error_count"], + "error_count": metadata["error_count"] } + + @self.app.get("/health") + async def health_check(): + """Health check endpoint - no auth required.""" + try: + # Test Supabase connection + supabase = get_supabase() + supabase.table('api_keys').select('count', count='exact').execute() + return {"status": "healthy", "database": "connected"} + except Exception as e: + logger.error(f"Health check failed: {str(e)}") + return JSONResponse( + 
status_code=503, + content={ + "status": "unhealthy", + "database": "disconnected", + "error": str(e) + } + ) async def _cleanup_old_metrics(self, agent_id: UUID): """Clean up old metrics data to prevent memory bloat.""" @@ -888,7 +943,7 @@ class SwarmsAPI: class APIServer: def __init__( - self, app: FastAPI, host: str = "0.0.0.0", port: int = 8000 + self, app: FastAPI, host: str = "0.0.0.0", port: int = 8080 ): self.app = app self.host = host diff --git a/api/requirements.txt b/api/requirements.txt index 1c93bff9..20dc6a40 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -8,4 +8,6 @@ opentelemetry-api opentelemetry-sdk opentelemetry-instrumentation-fastapi opentelemetry-instrumentation-requests -opentelemetry-exporter-otlp-proto-grpc \ No newline at end of file +opentelemetry-exporter-otlp-proto-grpc +swarms +supabase \ No newline at end of file diff --git a/voice.py b/voice.py new file mode 100644 index 00000000..e09b406d --- /dev/null +++ b/voice.py @@ -0,0 +1,353 @@ + +from __future__ import annotations + +import asyncio +import base64 +import io +import threading +from os import getenv +from typing import Any, Awaitable, Callable, cast + +import numpy as np + +try: + import pyaudio +except ImportError: + import subprocess + subprocess.check_call(["pip", "install", "pyaudio"]) + import pyaudio +try: + import sounddevice as sd +except ImportError: + import subprocess + subprocess.check_call(["pip", "install", "sounddevice"]) + import sounddevice as sd +from loguru import logger +from openai import AsyncOpenAI +from openai.resources.beta.realtime.realtime import ( + AsyncRealtimeConnection, +) +from openai.types.beta.realtime.session import Session + +try: + from pydub import AudioSegment +except ImportError: + import subprocess + subprocess.check_call(["pip", "install", "pydub"]) + from pydub import AudioSegment + +from dotenv import load_dotenv + +load_dotenv() + + +CHUNK_LENGTH_S = 0.05 # 100ms +SAMPLE_RATE = 24000 +FORMAT = pyaudio.paInt16 
+CHANNELS = 1 + +# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false + + +def audio_to_pcm16_base64(audio_bytes: bytes) -> bytes: + # load the audio file from the byte stream + audio = AudioSegment.from_file(io.BytesIO(audio_bytes)) + print(f"Loaded audio: {audio.frame_rate=} {audio.channels=} {audio.sample_width=} {audio.frame_width=}") + # resample to 24kHz mono pcm16 + pcm_audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(CHANNELS).set_sample_width(2).raw_data + return pcm_audio + + +class AudioPlayerAsync: + def __init__(self): + self.queue = [] + self.lock = threading.Lock() + self.stream = sd.OutputStream( + callback=self.callback, + samplerate=SAMPLE_RATE, + channels=CHANNELS, + dtype=np.int16, + blocksize=int(CHUNK_LENGTH_S * SAMPLE_RATE), + ) + self.playing = False + self._frame_count = 0 + + def callback(self, outdata, frames, time, status): # noqa + with self.lock: + data = np.empty(0, dtype=np.int16) + + # get next item from queue if there is still space in the buffer + while len(data) < frames and len(self.queue) > 0: + item = self.queue.pop(0) + frames_needed = frames - len(data) + data = np.concatenate((data, item[:frames_needed])) + if len(item) > frames_needed: + self.queue.insert(0, item[frames_needed:]) + + self._frame_count += len(data) + + # fill the rest of the frames with zeros if there is no more data + if len(data) < frames: + data = np.concatenate((data, np.zeros(frames - len(data), dtype=np.int16))) + + outdata[:] = data.reshape(-1, 1) + + def reset_frame_count(self): + self._frame_count = 0 + + def get_frame_count(self): + return self._frame_count + + def add_data(self, data: bytes): + with self.lock: + # bytes is pcm16 single channel audio data, convert to numpy array + np_data = np.frombuffer(data, dtype=np.int16) + self.queue.append(np_data) + if not self.playing: + self.start() + + def start(self): + self.playing = True + self.stream.start() + + def stop(self): + 
self.playing = False + self.stream.stop() + with self.lock: + self.queue = [] + + def terminate(self): + self.stream.close() + + +async def send_audio_worker_sounddevice( + connection: AsyncRealtimeConnection, + should_send: Callable[[], bool] | None = None, + start_send: Callable[[], Awaitable[None]] | None = None, +): + sent_audio = False + + device_info = sd.query_devices() + print(device_info) + + read_size = int(SAMPLE_RATE * 0.02) + + stream = sd.InputStream( + channels=CHANNELS, + samplerate=SAMPLE_RATE, + dtype="int16", + ) + stream.start() + + try: + while True: + if stream.read_available < read_size: + await asyncio.sleep(0) + continue + + data, _ = stream.read(read_size) + + if should_send() if should_send else True: + if not sent_audio and start_send: + await start_send() + await connection.send( + {"type": "input_audio_buffer.append", "audio": base64.b64encode(data).decode("utf-8")} + ) + sent_audio = True + + elif sent_audio: + print("Done, triggering inference") + await connection.send({"type": "input_audio_buffer.commit"}) + await connection.send({"type": "response.create", "response": {}}) + sent_audio = False + + await asyncio.sleep(0) + + except KeyboardInterrupt: + pass + finally: + stream.stop() + stream.close() + +class RealtimeApp: + """ + A console-based application to handle real-time audio recording and streaming, + connecting to OpenAI's GPT-4 Realtime API. + + Features: + - Streams microphone input to the GPT-4 Realtime API. + - Logs transcription results. + - Sends text prompts to the GPT-4 Realtime API. 
+ """ + + def __init__(self, system_prompt: str = None) -> None: + self.connection: AsyncRealtimeConnection | None = None + self.session: Session | None = None + self.client = AsyncOpenAI(api_key=getenv("OPENAI_API_KEY")) + self.audio_player = AudioPlayerAsync() + self.last_audio_item_id: str | None = None + self.should_send_audio = asyncio.Event() + self.connected = asyncio.Event() + self.system_prompt = system_prompt + + async def initialize_text_prompt(self, text: str) -> None: + """Initialize and send a text prompt to the OpenAI Realtime API.""" + try: + async with self.client.beta.realtime.connect(model="gpt-4o-realtime-preview-2024-10-01") as conn: + self.connection = conn + await conn.session.update(session={"modalities": ["text"]}) + + await conn.conversation.item.create( + item={ + "type": "message", + "role": "system", + "content": [{"type": "input_text", "text": text}], + } + ) + await conn.response.create() + + async for event in conn: + if event.type == "response.text.delta": + print(event.delta, flush=True, end="") + + elif event.type == "response.text.done": + print() + + elif event.type == "response.done": + break + except Exception as e: + logger.exception(f"Error initializing text prompt: {e}") + + async def handle_realtime_connection(self) -> None: + """Handle the connection to the OpenAI Realtime API.""" + try: + async with self.client.beta.realtime.connect(model="gpt-4o-realtime-preview-2024-10-01") as conn: + self.connection = conn + self.connected.set() + logger.info("Connected to OpenAI Realtime API.") + + await conn.session.update(session={"turn_detection": {"type": "server_vad"}}) + + acc_items: dict[str, Any] = {} + + async for event in conn: + if event.type == "session.created": + self.session = event.session + assert event.session.id is not None + logger.info(f"Session created with ID: {event.session.id}") + continue + + if event.type == "session.updated": + self.session = event.session + logger.info("Session updated.") + continue + + 
if event.type == "response.audio.delta": + if event.item_id != self.last_audio_item_id: + self.audio_player.reset_frame_count() + self.last_audio_item_id = event.item_id + + bytes_data = base64.b64decode(event.delta) + self.audio_player.add_data(bytes_data) + continue + + if event.type == "response.audio_transcript.delta": + try: + text = acc_items[event.item_id] + except KeyError: + acc_items[event.item_id] = event.delta + else: + acc_items[event.item_id] = text + event.delta + + logger.debug(f"Transcription updated: {acc_items[event.item_id]}") + continue + + if event.type == "response.text.delta": + print(event.delta, flush=True, end="") + continue + + if event.type == "response.text.done": + print() + continue + + if event.type == "response.done": + break + except Exception as e: + logger.exception(f"Error in realtime connection handler: {e}") + + async def _get_connection(self) -> AsyncRealtimeConnection: + """Wait for and return the realtime connection.""" + await self.connected.wait() + assert self.connection is not None + return self.connection + + async def send_text_prompt(self, text: str) -> None: + """Send a text prompt to the OpenAI Realtime API.""" + try: + connection = await self._get_connection() + if not self.session: + logger.error("Session is not initialized. 
Cannot send prompt.") + return + + logger.info(f"Sending prompt to the model: {text}") + await connection.conversation.item.create( + item={ + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": text}], + } + ) + await connection.response.create() + except Exception as e: + logger.exception(f"Error sending text prompt: {e}") + + async def send_mic_audio(self) -> None: + """Stream microphone audio to the OpenAI Realtime API.""" + import sounddevice as sd # type: ignore + + sent_audio = False + + try: + read_size = int(SAMPLE_RATE * 0.02) + stream = sd.InputStream( + channels=CHANNELS, samplerate=SAMPLE_RATE, dtype="int16" + ) + stream.start() + + while True: + if stream.read_available < read_size: + await asyncio.sleep(0) + continue + + await self.should_send_audio.wait() + + data, _ = stream.read(read_size) + + connection = await self._get_connection() + if not sent_audio: + asyncio.create_task(connection.send({"type": "response.cancel"})) + sent_audio = True + + await connection.input_audio_buffer.append(audio=base64.b64encode(cast(Any, data)).decode("utf-8")) + await asyncio.sleep(0) + except Exception as e: + logger.exception(f"Error in microphone audio streaming: {e}") + finally: + stream.stop() + stream.close() + + async def run(self) -> None: + """Start the application tasks.""" + logger.info("Starting application tasks.") + + await asyncio.gather( + # self.initialize_text_prompt(self.system_prompt), + self.handle_realtime_connection(), + self.send_mic_audio() + ) + +if __name__ == "__main__": + logger.add("realtime_app.log", rotation="10 MB", retention="10 days", level="DEBUG") + logger.info("Starting RealtimeApp.") + app = RealtimeApp() + asyncio.run(app.run())