From 6110e7043019e8c09ccc575105acf31a93e70170 Mon Sep 17 00:00:00 2001 From: Ben Xu Date: Mon, 30 Dec 2024 15:25:55 -0500 Subject: [PATCH] add local stt & tts, add anticipation logic, remove video context accumulation --- .../source/server/livekit/video_processor.py | 72 +++++++++--- software/source/server/livekit/worker.py | 111 ++++++++++++++---- 2 files changed, 143 insertions(+), 40 deletions(-) diff --git a/software/source/server/livekit/video_processor.py b/software/source/server/livekit/video_processor.py index 4167cfa..8cc8336 100644 --- a/software/source/server/livekit/video_processor.py +++ b/software/source/server/livekit/video_processor.py @@ -1,10 +1,13 @@ -from livekit.rtc import VideoStream +from livekit.rtc import VideoStream, VideoFrame, VideoBufferType from livekit.agents import JobContext from datetime import datetime import os - -from livekit.rtc import VideoFrame import asyncio +from typing import Callable, Coroutine, Any + + +# Interval settings +INTERVAL = 30 # seconds # Define the path to the log file LOG_FILE_PATH = 'video_processor.txt' @@ -20,34 +23,71 @@ def log_message(message: str): log_file.write(f"{timestamp} - {message}\n") class RemoteVideoProcessor: - """Processes video frames from a remote participant's video stream.""" - def __init__(self, video_stream: VideoStream, job_ctx: JobContext): + log_message("Initializing RemoteVideoProcessor") self.video_stream = video_stream self.job_ctx = job_ctx - self.current_frame = None # Store the latest VideoFrame + self.current_frame = None self.lock = asyncio.Lock() - + + self.interval = INTERVAL + self.video_context = False + self.last_capture_time = 0 + + # Add callback for safety checks + self.on_instruction_check: Callable[[VideoFrame], Coroutine[Any, Any, None]] | None = None async def process_frames(self): - log_message("Starting to process remote video frames.") + """Process incoming video frames.""" async for frame_event in self.video_stream: try: video_frame = frame_event.frame timestamp = frame_event.timestamp_us - rotation = frame_event.rotation + + log_message(f"Processing frame at timestamp {timestamp/1000000:.3f}s") + log_message(f"Frame details: size={video_frame.width}x{video_frame.height}, type={video_frame.type}") - # Store the current frame safely - log_message(f"Received frame: width={video_frame.width}, height={video_frame.height}, type={video_frame.type}") async with self.lock: self.current_frame = video_frame + + if self.video_context and self._check_interrupt(timestamp): + self.last_capture_time = timestamp + # Trigger instruction check callback if registered + if self.on_instruction_check: + await self.on_instruction_check(video_frame) except Exception as e: - log_message(f"Error processing frame: {e}") + log_message(f"Error processing frame: {str(e)}") + import traceback + log_message(f"Traceback: {traceback.format_exc()}") + + + def register_safety_check_callback(self, callback: Callable[[VideoFrame], Coroutine[Any, Any, None]]): + """Register a callback for safety checks""" + self.on_instruction_check = callback + log_message("Registered instruction check callback") + async def get_current_frame(self) -> VideoFrame | None: - """Retrieve the current VideoFrame.""" - log_message("called get current frame") + """Get the most recent video frame.""" + log_message("Getting current frame") async with self.lock: - log_message("retrieving current frame: " + str(self.current_frame)) - return self.current_frame \ No newline at end of file + if self.current_frame is None: + log_message("No current frame available") + return self.current_frame + + + def set_video_context(self, context: bool): + """Set the video context.""" + log_message(f"Setting video context to: {context}") + self.video_context = context + + + def get_video_context(self) -> bool: + """Get the video context.""" + return self.video_context + + + def _check_interrupt(self, timestamp: int) -> bool: + """Determine if the video context should be interrupted.""" + return timestamp - self.last_capture_time > self.interval * 1000000 diff --git a/software/source/server/livekit/worker.py b/software/source/server/livekit/worker.py index 86b8ee3..caeb840 100644 --- a/software/source/server/livekit/worker.py +++ b/software/source/server/livekit/worker.py @@ -2,10 +2,11 @@ import asyncio import numpy as np import sys import os +import threading from datetime import datetime from typing import Literal, Awaitable -from livekit.agents import JobContext, WorkerOptions, cli +from livekit.agents import JobContext, WorkerOptions, cli, transcription from livekit.agents.transcription import STTSegmentsForwarder from livekit.agents.llm import ChatContext from livekit import rtc @@ -13,30 +14,23 @@ from livekit.agents.pipeline import VoicePipelineAgent from livekit.plugins import deepgram, openai, silero, elevenlabs, cartesia from livekit.agents.llm.chat_context import ChatContext, ChatImage, ChatMessage from livekit.agents.llm import LLMStream +from livekit.agents.stt import SpeechStream from source.server.livekit.video_processor import RemoteVideoProcessor - -from source.server.livekit.transcriptions import _forward_transcription +from source.server.livekit.anticipation import handle_instruction_check +from source.server.livekit.logger import log_message from dotenv import load_dotenv - load_dotenv() -# Define the path to the log file -LOG_FILE_PATH = 'worker.txt' -DEBUG = os.getenv('DEBUG', 'false').lower() == 'true' -def log_message(message: str): - """Append a message to the log file with a timestamp.""" - if not DEBUG: - return - timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - with open(LOG_FILE_PATH, 'a') as log_file: - log_file.write(f"{timestamp} - {message}\n") -start_message = """Hi! You can hold the white circle below to speak to me. +_room_lock = threading.Lock() +_connected_rooms = set() -Try asking what I can do.""" + + +START_MESSAGE = "Hi! You can hold the white circle below to speak to me. Try asking what I can do." # This function is the entrypoint for the agent. async def entrypoint(ctx: JobContext): @@ -96,7 +90,7 @@ async def entrypoint(ctx: JobContext): base_url = f"http://{interpreter_server_host}:{interpreter_server_port}/" # For debugging - base_url = "http://127.0.0.1:8000/" + base_url = "http://127.0.0.1:9000/" open_interpreter = openai.LLM( model="open-interpreter", base_url=base_url, api_key="x" @@ -105,11 +99,18 @@ async def entrypoint(ctx: JobContext): tts_provider = os.getenv('01_TTS', '').lower() stt_provider = os.getenv('01_STT', '').lower() + tts_provider = "elevenlabs" + stt_provider = "deepgram" + # Add plugins here if tts_provider == 'openai': tts = openai.TTS() + elif tts_provider == 'local': + tts = openai.TTS(base_url="http://localhost:8000/v1") + print("using local tts") elif tts_provider == 'elevenlabs': tts = elevenlabs.TTS() + print("using elevenlabs tts") elif tts_provider == 'cartesia': tts = cartesia.TTS() else: @@ -117,16 +118,20 @@ async def entrypoint(ctx: JobContext): if stt_provider == 'deepgram': stt = deepgram.STT() + elif stt_provider == 'local': + stt = openai.STT(base_url="http://localhost:8001/v1") + print("using local stt") else: raise ValueError(f"Unsupported STT provider: {stt_provider}. Please set 01_STT environment variable to 'deepgram'.") ############################################################ # initialize voice assistant states ############################################################ - push_to_talk = True + push_to_talk = False current_message: ChatMessage = ChatMessage(role='user') submitted_message: ChatMessage = ChatMessage(role='user') video_muted = False + video_context = False tasks = [] ############################################################ @@ -175,6 +180,7 @@ async def entrypoint(ctx: JobContext): if remote_video_processor and not video_muted: video_frame = await remote_video_processor.get_current_frame() + if video_frame: chat_ctx.append(role="user", images=[ChatImage(image=video_frame)]) else: @@ -202,7 +208,15 @@ async def entrypoint(ctx: JobContext): # append image if available if remote_video_processor and not video_muted: - video_frame = await remote_video_processor.get_current_frame() + if remote_video_processor.get_video_context(): + log_message("context is true") + log_message("retrieving timeline frame") + video_frame = await remote_video_processor.get_timeline_frame() + else: + log_message("context is false") + log_message("retrieving current frame") + video_frame = await remote_video_processor.get_current_frame() + if video_frame: chat_ctx.append(role="user", images=[ChatImage(image=video_frame)]) log_message(f"[on_message_received] appended image: {video_frame} to chat_ctx: {chat_ctx}") @@ -263,6 +277,19 @@ async def entrypoint(ctx: JobContext): ############################################################ # transcribe participant track ############################################################ + async def _forward_transcription( + stt_stream: SpeechStream, + stt_forwarder: transcription.STTSegmentsForwarder, + ): + """Forward the transcription and log the transcript in the console""" + async for ev in stt_stream: + stt_forwarder.update(ev) + if ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT: + print(ev.alternatives[0].text, end="") + elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT: + print("\n") + print(" -> ", ev.alternatives[0].text) + async def transcribe_track(participant: rtc.RemoteParticipant, track: rtc.Track): audio_stream = rtc.AudioStream(track) stt_forwarder = STTSegmentsForwarder( @@ -297,8 +324,18 @@ async def entrypoint(ctx: JobContext): remote_video_stream = rtc.VideoStream(track=track, format=rtc.VideoBufferType.RGBA) remote_video_processor = RemoteVideoProcessor(video_stream=remote_video_stream, job_ctx=ctx) log_message("remote video processor." + str(remote_video_processor)) + + # Register safety check callback + remote_video_processor.register_safety_check_callback( + lambda frame: handle_instruction_check(assistant, frame) + ) + + remote_video_processor.set_video_context(video_context) + log_message(f"set video context to {video_context} from queued video context") + asyncio.create_task(remote_video_processor.process_frames()) + ############################################################ # on track muted callback ############################################################ @@ -329,11 +366,12 @@ async def entrypoint(ctx: JobContext): local_participant = ctx.room.local_participant await local_participant.publish_data(payload="{CLEAR_CHAT}", topic="chat_context") log_message("sent {CLEAR_CHAT} to chat_context for client to clear") - await assistant.say(assistant.start_message) + await assistant.say(START_MESSAGE) @ctx.room.on("data_received") def on_data_received(data: rtc.DataPacket): + nonlocal video_context decoded_data = data.data.decode() log_message(f"received data from {data.topic}: {decoded_data}") if data.topic == "chat_context" and decoded_data == "{CLEAR_CHAT}": @@ -349,6 +387,22 @@ async def entrypoint(ctx: JobContext): asyncio.create_task(_publish_clear_chat()) + if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_ON}": + if remote_video_processor: + remote_video_processor.set_video_context(True) + log_message("set video context to True") + else: + video_context = True + log_message("no remote video processor found, queued video context to True") + + if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_OFF}": + if remote_video_processor: + remote_video_processor.set_video_context(False) + log_message("set video context to False") + else: + video_context = False + log_message("no remote video processor found, queued video context to False") + ############################################################ # Start the voice assistant with the LiveKit room @@ -367,7 +421,7 @@ async def entrypoint(ctx: JobContext): await asyncio.sleep(1) # Greets the user with an initial message - await assistant.say(start_message, allow_interruptions=True) + await assistant.say(START_MESSAGE, allow_interruptions=True) ############################################################ # wait for the voice assistant to finish @@ -389,12 +443,21 @@ def main(livekit_url: str): # Workers have to be run as CLIs right now. # So we need to simualte running "[this file] dev" + worker_start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + log_message(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===") + print(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===") + # Modify sys.argv to set the path to this file as the first argument # and 'dev' as the second argument - sys.argv = [str(__file__), 'dev'] + sys.argv = [str(__file__), 'start'] - # livekit_url = "ws://localhost:7880" # Initialize the worker with the entrypoint cli.run_app( - WorkerOptions(entrypoint_fnc=entrypoint, api_key="devkey", api_secret="secret", ws_url=livekit_url) + WorkerOptions( + entrypoint_fnc=entrypoint, + api_key="devkey", + api_secret="secret", + ws_url=livekit_url + ) + ) \ No newline at end of file