From 07672f498be0aa2a3222a262fbac76e99c3c97d3 Mon Sep 17 00:00:00 2001 From: Ben Xu Date: Mon, 9 Dec 2024 13:37:19 -0800 Subject: [PATCH] add voice assistant state communication and clear chat context --- software/source/server/livekit/worker.py | 54 ++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/software/source/server/livekit/worker.py b/software/source/server/livekit/worker.py index 0e84e2b..86b8ee3 100644 --- a/software/source/server/livekit/worker.py +++ b/software/source/server/livekit/worker.py @@ -24,9 +24,12 @@ load_dotenv() # Define the path to the log file LOG_FILE_PATH = 'worker.txt' +DEBUG = os.getenv('DEBUG', 'false').lower() == 'true' def log_message(message: str): """Append a message to the log file with a timestamp.""" + if not DEBUG: + return timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') with open(LOG_FILE_PATH, 'a') as log_file: log_file.write(f"{timestamp} - {message}\n") @@ -101,8 +104,6 @@ async def entrypoint(ctx: JobContext): tts_provider = os.getenv('01_TTS', '').lower() stt_provider = os.getenv('01_STT', '').lower() - tts_provider='elevenlabs' - stt_provider='deepgram' # Add plugins here if tts_provider == 'openai': @@ -170,6 +171,8 @@ async def entrypoint(ctx: JobContext): else: async def process_query(): + log_message(f"[before_llm_cb] processing query in VAD with chat_ctx: {chat_ctx}") + if remote_video_processor and not video_muted: video_frame = await remote_video_processor.get_current_frame() if video_frame: @@ -185,7 +188,7 @@ async def entrypoint(ctx: JobContext): return process_query() ############################################################ - # on_message_received implementation + # on_message_received helper ############################################################ async def _on_message_received(msg: str): nonlocal push_to_talk @@ -318,9 +321,39 @@ async def entrypoint(ctx: JobContext): video_muted = False log_message(f"Track unmuted: {publication.kind}") + + ############################################################ + # on data received callback + ############################################################ + async def _publish_clear_chat(): + local_participant = ctx.room.local_participant + await local_participant.publish_data(payload="{CLEAR_CHAT}", topic="chat_context") + log_message("sent {CLEAR_CHAT} to chat_context for client to clear") + await assistant.say(assistant.start_message) + + + @ctx.room.on("data_received") + def on_data_received(data: rtc.DataPacket): + decoded_data = data.data.decode() + log_message(f"received data from {data.topic}: {decoded_data}") + if data.topic == "chat_context" and decoded_data == "{CLEAR_CHAT}": + assistant.chat_ctx.messages.clear() + assistant.chat_ctx.append( + role="system", + text=( + "Only take into context the user's image if their message is relevant or pertaining to the image. Otherwise just keep in context that the image is present but do not acknowledge or mention it in your response." + ), + ) + log_message(f"cleared chat_ctx") + log_message(f"chat_ctx is now {assistant.chat_ctx}") + + asyncio.create_task(_publish_clear_chat()) + + ############################################################ # Start the voice assistant with the LiveKit room ############################################################ + assistant = VoicePipelineAgent( vad=silero.VAD.load(), stt=stt, @@ -336,6 +369,21 @@ async def entrypoint(ctx: JobContext): # Greets the user with an initial message await assistant.say(start_message, allow_interruptions=True) + ############################################################ + # wait for the voice assistant to finish + ############################################################ + @assistant.on("agent_started_speaking") + def on_agent_started_speaking(): + asyncio.create_task(ctx.room.local_participant.publish_data(payload="{AGENT_STARTED_SPEAKING}", topic="agent_state")) + log_message("Agent started speaking") + return + + @assistant.on("agent_stopped_speaking") + def on_agent_stopped_speaking(): + asyncio.create_task(ctx.room.local_participant.publish_data(payload="{AGENT_STOPPED_SPEAKING}", topic="agent_state")) + log_message("Agent stopped speaking") + return + def main(livekit_url: str): # Workers have to be run as CLIs right now.