add local stt & tts, add anticipation logic, remove video context accumulation

pull/314/head
Ben Xu 1 month ago
parent bd6f530be7
commit 6110e70430

source/server/livekit/video_processor.py

@@ -1,10 +1,13 @@
-from livekit.rtc import VideoStream
+from livekit.rtc import VideoStream, VideoFrame, VideoBufferType
 from livekit.agents import JobContext
 from datetime import datetime
 import os
-from livekit.rtc import VideoFrame
 import asyncio
+from typing import Callable, Coroutine, Any
+
+# Interval settings
+INTERVAL = 30  # seconds
 
 # Define the path to the log file
 LOG_FILE_PATH = 'video_processor.txt'
@@ -20,34 +23,71 @@ def log_message(message: str):
         log_file.write(f"{timestamp} - {message}\n")
 
 class RemoteVideoProcessor:
+    """Processes video frames from a remote participant's video stream."""
+
     def __init__(self, video_stream: VideoStream, job_ctx: JobContext):
-        log_message("Initializing RemoteVideoProcessor")
         self.video_stream = video_stream
         self.job_ctx = job_ctx
-        self.current_frame = None  # Store the latest VideoFrame
+        self.current_frame = None
         self.lock = asyncio.Lock()
+        self.interval = INTERVAL
+        self.video_context = False
+        self.last_capture_time = 0
+
+        # Add callback for safety checks
+        self.on_instruction_check: Callable[[VideoFrame], Coroutine[Any, Any, None]] | None = None
 
     async def process_frames(self):
-        log_message("Starting to process remote video frames.")
+        """Process incoming video frames."""
         async for frame_event in self.video_stream:
             try:
                 video_frame = frame_event.frame
                 timestamp = frame_event.timestamp_us
-                rotation = frame_event.rotation
 
-                # Store the current frame safely
-                log_message(f"Received frame: width={video_frame.width}, height={video_frame.height}, type={video_frame.type}")
+                log_message(f"Processing frame at timestamp {timestamp/1000000:.3f}s")
+                log_message(f"Frame details: size={video_frame.width}x{video_frame.height}, type={video_frame.type}")
 
                 async with self.lock:
                     self.current_frame = video_frame
+
+                if self.video_context and self._check_interrupt(timestamp):
+                    self.last_capture_time = timestamp
+                    # Trigger instruction check callback if registered
+                    if self.on_instruction_check:
+                        await self.on_instruction_check(video_frame)
+
             except Exception as e:
-                log_message(f"Error processing frame: {e}")
+                log_message(f"Error processing frame: {str(e)}")
+                import traceback
+                log_message(f"Traceback: {traceback.format_exc()}")
+
+    def register_safety_check_callback(self, callback: Callable[[VideoFrame], Coroutine[Any, Any, None]]):
+        """Register a callback for safety checks"""
+        self.on_instruction_check = callback
+        log_message("Registered instruction check callback")
 
     async def get_current_frame(self) -> VideoFrame | None:
-        """Retrieve the current VideoFrame."""
-        log_message("called get current frame")
+        """Get the most recent video frame."""
+        log_message("Getting current frame")
         async with self.lock:
-            log_message("retrieving current frame: " + str(self.current_frame))
+            if self.current_frame is None:
+                log_message("No current frame available")
             return self.current_frame
+
+    def set_video_context(self, context: bool):
+        """Set the video context."""
+        log_message(f"Setting video context to: {context}")
+        self.video_context = context
+
+    def get_video_context(self) -> bool:
+        """Get the video context."""
+        return self.video_context
+
+    def _check_interrupt(self, timestamp: int) -> bool:
+        """Determine if the video context should be interrupted."""
+        return timestamp - self.last_capture_time > self.interval * 1000000
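Note on units: `frame_event.timestamp_us` is in microseconds while `INTERVAL` is in seconds, which is why `_check_interrupt` multiplies by 1000000. A minimal standalone sketch of the same gating pattern, with made-up timestamp values and no LiveKit dependency:

# Sketch of the interval-gated trigger above; the values are hypothetical.
INTERVAL = 30  # seconds

class IntervalGate:
    def __init__(self, interval_s: int = INTERVAL):
        self.interval_us = interval_s * 1_000_000  # seconds -> microseconds
        self.last_capture_time = 0

    def should_fire(self, now_us: int) -> bool:
        # Fires at most once per interval; stamps last_capture_time on fire.
        if now_us - self.last_capture_time > self.interval_us:
            self.last_capture_time = now_us
            return True
        return False

gate = IntervalGate()
assert gate.should_fire(31_000_000)      # 31 s in: past the interval, fires
assert not gate.should_fire(45_000_000)  # only 14 s since last fire: gated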

source/server/livekit/worker.py

@@ -2,10 +2,11 @@ import asyncio
 import numpy as np
 import sys
 import os
+import threading
 from datetime import datetime
 from typing import Literal, Awaitable
-from livekit.agents import JobContext, WorkerOptions, cli
+from livekit.agents import JobContext, WorkerOptions, cli, transcription
 from livekit.agents.transcription import STTSegmentsForwarder
 from livekit.agents.llm import ChatContext
 from livekit import rtc
@@ -13,30 +14,23 @@ from livekit.agents.pipeline import VoicePipelineAgent
 from livekit.plugins import deepgram, openai, silero, elevenlabs, cartesia
 from livekit.agents.llm.chat_context import ChatContext, ChatImage, ChatMessage
 from livekit.agents.llm import LLMStream
+from livekit.agents.stt import SpeechStream
 from source.server.livekit.video_processor import RemoteVideoProcessor
-from source.server.livekit.transcriptions import _forward_transcription
+from source.server.livekit.anticipation import handle_instruction_check
+from source.server.livekit.logger import log_message
 from dotenv import load_dotenv
 
 load_dotenv()
 
-# Define the path to the log file
-LOG_FILE_PATH = 'worker.txt'
-DEBUG = os.getenv('DEBUG', 'false').lower() == 'true'
-
-def log_message(message: str):
-    """Append a message to the log file with a timestamp."""
-    if not DEBUG:
-        return
-    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-    with open(LOG_FILE_PATH, 'a') as log_file:
-        log_file.write(f"{timestamp} - {message}\n")
-
-start_message = """Hi! You can hold the white circle below to speak to me.
-
-Try asking what I can do."""
+_room_lock = threading.Lock()
+_connected_rooms = set()
+
+START_MESSAGE = "Hi! You can hold the white circle below to speak to me. Try asking what I can do."
 
 # This function is the entrypoint for the agent.
 async def entrypoint(ctx: JobContext):
@@ -96,7 +90,7 @@ async def entrypoint(ctx: JobContext):
     base_url = f"http://{interpreter_server_host}:{interpreter_server_port}/"
 
     # For debugging
-    base_url = "http://127.0.0.1:8000/"
+    base_url = "http://127.0.0.1:9000/"
 
     open_interpreter = openai.LLM(
         model="open-interpreter", base_url=base_url, api_key="x"
@@ -105,11 +99,18 @@ async def entrypoint(ctx: JobContext):
     tts_provider = os.getenv('01_TTS', '').lower()
     stt_provider = os.getenv('01_STT', '').lower()
+
+    tts_provider = "elevenlabs"
+    stt_provider = "deepgram"
 
     # Add plugins here
     if tts_provider == 'openai':
         tts = openai.TTS()
+    elif tts_provider == 'local':
+        tts = openai.TTS(base_url="http://localhost:8000/v1")
+        print("using local tts")
     elif tts_provider == 'elevenlabs':
         tts = elevenlabs.TTS()
+        print("using elevenlabs tts")
     elif tts_provider == 'cartesia':
         tts = cartesia.TTS()
     else:
@@ -117,16 +118,20 @@ async def entrypoint(ctx: JobContext):
     if stt_provider == 'deepgram':
         stt = deepgram.STT()
+    elif stt_provider == 'local':
+        stt = openai.STT(base_url="http://localhost:8001/v1")
+        print("using local stt")
     else:
         raise ValueError(f"Unsupported STT provider: {stt_provider}. Please set 01_STT environment variable to 'deepgram'.")
 
     ############################################################
     # initialize voice assistant states
     ############################################################
-    push_to_talk = True
+    push_to_talk = False
     current_message: ChatMessage = ChatMessage(role='user')
     submitted_message: ChatMessage = ChatMessage(role='user')
     video_muted = False
+    video_context = False
     tasks = []
 
     ############################################################
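The `local` branches reuse the OpenAI plugins against any OpenAI-compatible server; the ports here (8000 for TTS, 8001 for STT) are simply whatever the local services listen on. Note that, as committed, the hardcoded `tts_provider = "elevenlabs"` / `stt_provider = "deepgram"` assignments shadow the `01_TTS`/`01_STT` environment variables, so the local branches are unreachable until those debug overrides are removed. A small sketch of the same selection written as a factory, under those assumptions:

# Sketch: provider selection as a factory (same env vars and endpoints as above).
import os
from livekit.plugins import deepgram, openai

def build_stt():
    provider = os.getenv("01_STT", "").lower()
    if provider == "deepgram":
        return deepgram.STT()  # requires DEEPGRAM_API_KEY in the environment
    if provider == "local":
        # Any OpenAI-compatible STT endpoint, e.g. a local Whisper wrapper
        return openai.STT(base_url="http://localhost:8001/v1")
    raise ValueError(f"Unsupported STT provider: {provider!r}")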
@@ -175,6 +180,7 @@ async def entrypoint(ctx: JobContext):
         if remote_video_processor and not video_muted:
             video_frame = await remote_video_processor.get_current_frame()
+
             if video_frame:
                 chat_ctx.append(role="user", images=[ChatImage(image=video_frame)])
             else:
@@ -202,7 +208,15 @@ async def entrypoint(ctx: JobContext):
         # append image if available
         if remote_video_processor and not video_muted:
-            video_frame = await remote_video_processor.get_current_frame()
+            if remote_video_processor.get_video_context():
+                log_message("context is true")
+                log_message("retrieving timeline frame")
+                video_frame = await remote_video_processor.get_timeline_frame()
+            else:
+                log_message("context is false")
+                log_message("retrieving current frame")
+                video_frame = await remote_video_processor.get_current_frame()
+
             if video_frame:
                 chat_ctx.append(role="user", images=[ChatImage(image=video_frame)])
                 log_message(f"[on_message_received] appended image: {video_frame} to chat_ctx: {chat_ctx}")
@@ -263,6 +277,19 @@ async def entrypoint(ctx: JobContext):
     ############################################################
     # transcribe participant track
     ############################################################
+    async def _forward_transcription(
+        stt_stream: SpeechStream,
+        stt_forwarder: transcription.STTSegmentsForwarder,
+    ):
+        """Forward the transcription and log the transcript in the console"""
+        async for ev in stt_stream:
+            stt_forwarder.update(ev)
+            if ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT:
+                print(ev.alternatives[0].text, end="")
+            elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
+                print("\n")
+                print(" -> ", ev.alternatives[0].text)
+
     async def transcribe_track(participant: rtc.RemoteParticipant, track: rtc.Track):
         audio_stream = rtc.AudioStream(track)
         stt_forwarder = STTSegmentsForwarder(
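One subtlety in the inlined forwarder: within `entrypoint`, the name `stt` is bound to the STT plugin instance, while `SpeechEventType` is defined on the `livekit.agents.stt` module, so comparing against the module is the safer spelling. A sketch of the same loop with an explicit module alias:

# Sketch: same forwarding loop, with SpeechEventType taken from the
# livekit.agents.stt module instead of the local `stt` plugin instance.
from livekit.agents import stt as agents_stt, transcription
from livekit.agents.stt import SpeechStream

async def forward_transcription(
    stt_stream: SpeechStream,
    stt_forwarder: transcription.STTSegmentsForwarder,
):
    async for ev in stt_stream:
        stt_forwarder.update(ev)  # push segments to the LiveKit room
        if ev.type == agents_stt.SpeechEventType.INTERIM_TRANSCRIPT:
            print(ev.alternatives[0].text, end="")
        elif ev.type == agents_stt.SpeechEventType.FINAL_TRANSCRIPT:
            print("\n ->", ev.alternatives[0].text)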
@@ -297,8 +324,18 @@ async def entrypoint(ctx: JobContext):
             remote_video_stream = rtc.VideoStream(track=track, format=rtc.VideoBufferType.RGBA)
             remote_video_processor = RemoteVideoProcessor(video_stream=remote_video_stream, job_ctx=ctx)
             log_message("remote video processor." + str(remote_video_processor))
+
+            # Register safety check callback
+            remote_video_processor.register_safety_check_callback(
+                lambda frame: handle_instruction_check(assistant, frame)
+            )
+
+            remote_video_processor.set_video_context(video_context)
+            log_message(f"set video context to {video_context} from queued video context")
+
             asyncio.create_task(remote_video_processor.process_frames())
 
     ############################################################
     # on track muted callback
     ############################################################
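`handle_instruction_check` comes from the new `anticipation` module, which this diff does not show; the lambda satisfies the `Callable[[VideoFrame], Coroutine]` slot because calling an async function returns a coroutine that `process_frames` then awaits. A purely hypothetical sketch of what such a handler could look like, assuming the agent's `say` API:

# Hypothetical sketch only: the real handle_instruction_check lives in
# source.server.livekit.anticipation, which is not part of this diff.
from livekit.rtc import VideoFrame

async def handle_instruction_check(assistant, frame: VideoFrame) -> None:
    # Placeholder predicate: a real implementation would run a vision
    # check on the sampled frame before deciding to speak.
    if frame.width > 0 and frame.height > 0:
        await assistant.say(
            "Let me know if you want a hand with that.",
            allow_interruptions=True,
        )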
@@ -329,11 +366,12 @@ async def entrypoint(ctx: JobContext):
         local_participant = ctx.room.local_participant
         await local_participant.publish_data(payload="{CLEAR_CHAT}", topic="chat_context")
         log_message("sent {CLEAR_CHAT} to chat_context for client to clear")
-        await assistant.say(assistant.start_message)
+        await assistant.say(START_MESSAGE)
 
     @ctx.room.on("data_received")
     def on_data_received(data: rtc.DataPacket):
+        nonlocal video_context
         decoded_data = data.data.decode()
         log_message(f"received data from {data.topic}: {decoded_data}")
 
         if data.topic == "chat_context" and decoded_data == "{CLEAR_CHAT}":
@@ -349,6 +387,22 @@ async def entrypoint(ctx: JobContext):
             asyncio.create_task(_publish_clear_chat())
 
+        if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_ON}":
+            if remote_video_processor:
+                remote_video_processor.set_video_context(True)
+                log_message("set video context to True")
+            else:
+                video_context = True
+                log_message("no remote video processor found, queued video context to True")
+
+        if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_OFF}":
+            if remote_video_processor:
+                remote_video_processor.set_video_context(False)
+                log_message("set video context to False")
+            else:
+                video_context = False
+                log_message("no remote video processor found, queued video context to False")
+
     ############################################################
     # Start the voice assistant with the LiveKit room
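Clients toggle anticipation by publishing the literal `{VIDEO_CONTEXT_ON}` / `{VIDEO_CONTEXT_OFF}` strings on the `video_context` data topic; the `else` branches queue the flag so it can be applied once the video processor exists. A minimal client-side sketch, assuming an already-connected `rtc.Room`:

# Sketch: toggling video context from a connected client.
# Assumes `room` is an already-connected livekit.rtc.Room.
async def set_video_context(room, enabled: bool) -> None:
    payload = "{VIDEO_CONTEXT_ON}" if enabled else "{VIDEO_CONTEXT_OFF}"
    await room.local_participant.publish_data(payload=payload, topic="video_context")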
@@ -367,7 +421,7 @@ async def entrypoint(ctx: JobContext):
     await asyncio.sleep(1)
 
     # Greets the user with an initial message
-    await assistant.say(start_message, allow_interruptions=True)
+    await assistant.say(START_MESSAGE, allow_interruptions=True)
 
     ############################################################
     # wait for the voice assistant to finish
@@ -389,12 +443,21 @@ def main(livekit_url: str):
     # Workers have to be run as CLIs right now.
     # So we need to simulate running "[this file] dev"
+    worker_start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+    log_message(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===")
+    print(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===")
+
     # Modify sys.argv to set the path to this file as the first argument
     # and 'dev' as the second argument
-    sys.argv = [str(__file__), 'dev']
+    sys.argv = [str(__file__), 'start']
+
+    # livekit_url = "ws://localhost:7880"
 
     # Initialize the worker with the entrypoint
     cli.run_app(
-        WorkerOptions(entrypoint_fnc=entrypoint, api_key="devkey", api_secret="secret", ws_url=livekit_url)
+        WorkerOptions(
+            entrypoint_fnc=entrypoint,
+            api_key="devkey",
+            api_secret="secret",
+            ws_url=livekit_url
+        )
     )