add local stt & tts, add anticipation logic, remove video context accumulation

pull/314/head
Ben Xu 4 weeks ago
parent bd6f530be7
commit 6110e70430

@ -1,10 +1,13 @@
from livekit.rtc import VideoStream
from livekit.rtc import VideoStream, VideoFrame, VideoBufferType
from livekit.agents import JobContext
from datetime import datetime
import os
from livekit.rtc import VideoFrame
import asyncio
from typing import Callable, Coroutine, Any
# Interval settings
INTERVAL = 30 # seconds
# Define the path to the log file
LOG_FILE_PATH = 'video_processor.txt'
@ -20,34 +23,71 @@ def log_message(message: str):
log_file.write(f"{timestamp} - {message}\n")
class RemoteVideoProcessor:
"""Processes video frames from a remote participant's video stream."""
def __init__(self, video_stream: VideoStream, job_ctx: JobContext):
log_message("Initializing RemoteVideoProcessor")
self.video_stream = video_stream
self.job_ctx = job_ctx
self.current_frame = None # Store the latest VideoFrame
self.current_frame = None
self.lock = asyncio.Lock()
self.interval = INTERVAL
self.video_context = False
self.last_capture_time = 0
# Add callback for safety checks
self.on_instruction_check: Callable[[VideoFrame], Coroutine[Any, Any, None]] | None = None
async def process_frames(self):
log_message("Starting to process remote video frames.")
"""Process incoming video frames."""
async for frame_event in self.video_stream:
try:
video_frame = frame_event.frame
timestamp = frame_event.timestamp_us
rotation = frame_event.rotation
# Store the current frame safely
log_message(f"Received frame: width={video_frame.width}, height={video_frame.height}, type={video_frame.type}")
log_message(f"Processing frame at timestamp {timestamp/1000000:.3f}s")
log_message(f"Frame details: size={video_frame.width}x{video_frame.height}, type={video_frame.type}")
async with self.lock:
self.current_frame = video_frame
if self.video_context and self._check_interrupt(timestamp):
self.last_capture_time = timestamp
# Trigger instruction check callback if registered
if self.on_instruction_check:
await self.on_instruction_check(video_frame)
except Exception as e:
log_message(f"Error processing frame: {e}")
log_message(f"Error processing frame: {str(e)}")
import traceback
log_message(f"Traceback: {traceback.format_exc()}")
def register_safety_check_callback(self, callback: Callable[[VideoFrame], Coroutine[Any, Any, None]]):
"""Register a callback for safety checks"""
self.on_instruction_check = callback
log_message("Registered instruction check callback")
async def get_current_frame(self) -> VideoFrame | None:
"""Retrieve the current VideoFrame."""
log_message("called get current frame")
"""Get the most recent video frame."""
log_message("Getting current frame")
async with self.lock:
log_message("retrieving current frame: " + str(self.current_frame))
if self.current_frame is None:
log_message("No current frame available")
return self.current_frame
def set_video_context(self, context: bool):
"""Set the video context."""
log_message(f"Setting video context to: {context}")
self.video_context = context
def get_video_context(self) -> bool:
"""Get the video context."""
return self.video_context
def _check_interrupt(self, timestamp: int) -> bool:
"""Determine if the video context should be interrupted."""
return timestamp - self.last_capture_time > self.interval * 1000000

@ -2,10 +2,11 @@ import asyncio
import numpy as np
import sys
import os
import threading
from datetime import datetime
from typing import Literal, Awaitable
from livekit.agents import JobContext, WorkerOptions, cli
from livekit.agents import JobContext, WorkerOptions, cli, transcription
from livekit.agents.transcription import STTSegmentsForwarder
from livekit.agents.llm import ChatContext
from livekit import rtc
@ -13,30 +14,23 @@ from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import deepgram, openai, silero, elevenlabs, cartesia
from livekit.agents.llm.chat_context import ChatContext, ChatImage, ChatMessage
from livekit.agents.llm import LLMStream
from livekit.agents.stt import SpeechStream
from source.server.livekit.video_processor import RemoteVideoProcessor
from source.server.livekit.transcriptions import _forward_transcription
from source.server.livekit.anticipation import handle_instruction_check
from source.server.livekit.logger import log_message
from dotenv import load_dotenv
load_dotenv()
# Define the path to the log file
LOG_FILE_PATH = 'worker.txt'
DEBUG = os.getenv('DEBUG', 'false').lower() == 'true'
def log_message(message: str):
"""Append a message to the log file with a timestamp."""
if not DEBUG:
return
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
with open(LOG_FILE_PATH, 'a') as log_file:
log_file.write(f"{timestamp} - {message}\n")
start_message = """Hi! You can hold the white circle below to speak to me.
_room_lock = threading.Lock()
_connected_rooms = set()
Try asking what I can do."""
START_MESSAGE = "Hi! You can hold the white circle below to speak to me. Try asking what I can do."
# This function is the entrypoint for the agent.
async def entrypoint(ctx: JobContext):
@ -96,7 +90,7 @@ async def entrypoint(ctx: JobContext):
base_url = f"http://{interpreter_server_host}:{interpreter_server_port}/"
# For debugging
base_url = "http://127.0.0.1:8000/"
base_url = "http://127.0.0.1:9000/"
open_interpreter = openai.LLM(
model="open-interpreter", base_url=base_url, api_key="x"
@ -105,11 +99,18 @@ async def entrypoint(ctx: JobContext):
tts_provider = os.getenv('01_TTS', '').lower()
stt_provider = os.getenv('01_STT', '').lower()
tts_provider = "elevenlabs"
stt_provider = "deepgram"
# Add plugins here
if tts_provider == 'openai':
tts = openai.TTS()
elif tts_provider == 'local':
tts = openai.TTS(base_url="http://localhost:8000/v1")
print("using local tts")
elif tts_provider == 'elevenlabs':
tts = elevenlabs.TTS()
print("using elevenlabs tts")
elif tts_provider == 'cartesia':
tts = cartesia.TTS()
else:
@ -117,16 +118,20 @@ async def entrypoint(ctx: JobContext):
if stt_provider == 'deepgram':
stt = deepgram.STT()
elif stt_provider == 'local':
stt = openai.STT(base_url="http://localhost:8001/v1")
print("using local stt")
else:
raise ValueError(f"Unsupported STT provider: {stt_provider}. Please set 01_STT environment variable to 'deepgram'.")
############################################################
# initialize voice assistant states
############################################################
push_to_talk = True
push_to_talk = False
current_message: ChatMessage = ChatMessage(role='user')
submitted_message: ChatMessage = ChatMessage(role='user')
video_muted = False
video_context = False
tasks = []
############################################################
@ -175,6 +180,7 @@ async def entrypoint(ctx: JobContext):
if remote_video_processor and not video_muted:
video_frame = await remote_video_processor.get_current_frame()
if video_frame:
chat_ctx.append(role="user", images=[ChatImage(image=video_frame)])
else:
@ -202,7 +208,15 @@ async def entrypoint(ctx: JobContext):
# append image if available
if remote_video_processor and not video_muted:
if remote_video_processor.get_video_context():
log_message("context is true")
log_message("retrieving timeline frame")
video_frame = await remote_video_processor.get_timeline_frame()
else:
log_message("context is false")
log_message("retrieving current frame")
video_frame = await remote_video_processor.get_current_frame()
if video_frame:
chat_ctx.append(role="user", images=[ChatImage(image=video_frame)])
log_message(f"[on_message_received] appended image: {video_frame} to chat_ctx: {chat_ctx}")
@ -263,6 +277,19 @@ async def entrypoint(ctx: JobContext):
############################################################
# transcribe participant track
############################################################
async def _forward_transcription(
stt_stream: SpeechStream,
stt_forwarder: transcription.STTSegmentsForwarder,
):
"""Forward the transcription and log the transcript in the console"""
async for ev in stt_stream:
stt_forwarder.update(ev)
if ev.type == stt.SpeechEventType.INTERIM_TRANSCRIPT:
print(ev.alternatives[0].text, end="")
elif ev.type == stt.SpeechEventType.FINAL_TRANSCRIPT:
print("\n")
print(" -> ", ev.alternatives[0].text)
async def transcribe_track(participant: rtc.RemoteParticipant, track: rtc.Track):
audio_stream = rtc.AudioStream(track)
stt_forwarder = STTSegmentsForwarder(
@ -297,8 +324,18 @@ async def entrypoint(ctx: JobContext):
remote_video_stream = rtc.VideoStream(track=track, format=rtc.VideoBufferType.RGBA)
remote_video_processor = RemoteVideoProcessor(video_stream=remote_video_stream, job_ctx=ctx)
log_message("remote video processor." + str(remote_video_processor))
# Register safety check callback
remote_video_processor.register_safety_check_callback(
lambda frame: handle_instruction_check(assistant, frame)
)
remote_video_processor.set_video_context(video_context)
log_message(f"set video context to {video_context} from queued video context")
asyncio.create_task(remote_video_processor.process_frames())
############################################################
# on track muted callback
############################################################
@ -329,11 +366,12 @@ async def entrypoint(ctx: JobContext):
local_participant = ctx.room.local_participant
await local_participant.publish_data(payload="{CLEAR_CHAT}", topic="chat_context")
log_message("sent {CLEAR_CHAT} to chat_context for client to clear")
await assistant.say(assistant.start_message)
await assistant.say(START_MESSAGE)
@ctx.room.on("data_received")
def on_data_received(data: rtc.DataPacket):
nonlocal video_context
decoded_data = data.data.decode()
log_message(f"received data from {data.topic}: {decoded_data}")
if data.topic == "chat_context" and decoded_data == "{CLEAR_CHAT}":
@ -349,6 +387,22 @@ async def entrypoint(ctx: JobContext):
asyncio.create_task(_publish_clear_chat())
if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_ON}":
if remote_video_processor:
remote_video_processor.set_video_context(True)
log_message("set video context to True")
else:
video_context = True
log_message("no remote video processor found, queued video context to True")
if data.topic == "video_context" and decoded_data == "{VIDEO_CONTEXT_OFF}":
if remote_video_processor:
remote_video_processor.set_video_context(False)
log_message("set video context to False")
else:
video_context = False
log_message("no remote video processor found, queued video context to False")
############################################################
# Start the voice assistant with the LiveKit room
@ -367,7 +421,7 @@ async def entrypoint(ctx: JobContext):
await asyncio.sleep(1)
# Greets the user with an initial message
await assistant.say(start_message, allow_interruptions=True)
await assistant.say(START_MESSAGE, allow_interruptions=True)
############################################################
# wait for the voice assistant to finish
@ -389,12 +443,21 @@ def main(livekit_url: str):
# Workers have to be run as CLIs right now.
# So we need to simualte running "[this file] dev"
worker_start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
log_message(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===")
print(f"=== INITIALIZING NEW WORKER AT {worker_start_time} ===")
# Modify sys.argv to set the path to this file as the first argument
# and 'dev' as the second argument
sys.argv = [str(__file__), 'dev']
sys.argv = [str(__file__), 'start']
# livekit_url = "ws://localhost:7880"
# Initialize the worker with the entrypoint
cli.run_app(
WorkerOptions(entrypoint_fnc=entrypoint, api_key="devkey", api_secret="secret", ws_url=livekit_url)
WorkerOptions(
entrypoint_fnc=entrypoint,
api_key="devkey",
api_secret="secret",
ws_url=livekit_url
)
)
Loading…
Cancel
Save