# NOTE: header text captured from the web code viewer, kept here as comments.
# "You cannot select more than 25 topics. Topics must start with a letter or
# number, can include dashes ('-'), and can be up to 35 characters long."
# File: 01/software/source/server/async_server.py
# 125 lines, 4.8 KiB

import importlib
import traceback
import json
import os
from RealtimeTTS import TextToAudioStream, CoquiEngine, OpenAIEngine, ElevenlabsEngine
from RealtimeSTT import AudioToTextRecorder
import types
import time
import wave
import asyncio
from fastapi.responses import PlainTextResponse
def start_server(server_host, server_port, profile, debug):
    """Load a profile's interpreter, wrap it in a voice interface, and run its server.

    The profile file is executed as a module and must expose an ``interpreter``
    attribute. That interpreter gets a speech-to-text recorder, a text-to-speech
    stream, replacement ``input``/``output`` coroutines that translate between
    audio and text messages, and a ``/ping`` health-check route. Finally
    ``interpreter.server.run()`` is called.

    Args:
        server_host: Value assigned to ``interpreter.server.host``.
        server_port: Value assigned to ``interpreter.server.port``.
        profile: Filesystem path of the Python profile file to load.
        debug: Value assigned to ``interpreter.verbose``.

    Raises:
        ValueError: If ``interpreter.tts`` names an unsupported TTS provider.
        KeyError: If the "elevenlabs" provider is selected and the
            ``ELEVEN_LABS_API_KEY`` environment variable is not set.
    """
    # `import importlib` alone does not guarantee the `util` submodule is
    # loaded, so import it explicitly before using it.
    import importlib.util

    # Load the profile module from the provided path
    spec = importlib.util.spec_from_file_location("profile", profile)
    profile_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(profile_module)

    # Get the interpreter from the profile
    interpreter = profile_module.interpreter

    # STT
    interpreter.stt = AudioToTextRecorder(
        model="tiny.en", spinner=False, use_microphone=False
    )
    interpreter.stt.stop()  # It needs this for some reason

    # TTS
    if not hasattr(interpreter, "tts"):
        print("Setting TTS provider to default: openai")
        interpreter.tts = "openai"

    if interpreter.tts == "coqui":
        engine = CoquiEngine()
    elif interpreter.tts == "openai":
        engine = OpenAIEngine(voice="onyx")
    elif interpreter.tts == "elevenlabs":
        engine = ElevenlabsEngine(api_key=os.environ["ELEVEN_LABS_API_KEY"])
        engine.set_voice("Michael")
    else:
        # Previously formatted `interpreter.interpreter.tts`, which referenced
        # the wrong attribute path while building this message.
        raise ValueError(f"Unsupported TTS engine: {interpreter.tts}")

    # From here on `interpreter.tts` is the audio stream, not the provider name.
    interpreter.tts = TextToAudioStream(engine)

    # Misc Settings
    interpreter.verbose = debug
    interpreter.server.host = server_host
    interpreter.server.port = server_port
    interpreter.audio_chunks = []

    # Keep the original coroutines so the wrappers below can delegate to them.
    old_input = interpreter.input
    old_output = interpreter.output

    # Debugging aid (off by default): write the raw audio of each user turn to
    # 'audio.wav' in the current working directory.
    dump_user_audio = False

    # Characters that mark a sentence fragment boundary for TTS playback.
    delimiters = ".?!;,\n…)]}"

    async def new_input(self, chunk):
        """Accept raw audio bytes or start/end control dicts from the client."""
        await asyncio.sleep(0)

        if isinstance(chunk, bytes):
            self.stt.feed_audio(chunk)
            self.audio_chunks.append(chunk)
            return

        if not isinstance(chunk, dict):
            return

        if "start" in chunk:
            self.stt.start()
            self.audio_chunks = []
            await old_input({"role": "user", "type": "message", "start": True})

        if "end" in chunk:
            self.stt.stop()
            content = self.stt.text()
            print("User: ", content)

            if dump_user_audio:
                audio_bytes = bytearray(b"".join(self.audio_chunks))
                with wave.open('audio.wav', 'wb') as wav_file:
                    wav_file.setnchannels(1)
                    wav_file.setsampwidth(2)  # Assuming 16-bit audio
                    wav_file.setframerate(16000)  # Assuming 16kHz sample rate
                    wav_file.writeframes(audio_bytes)
                print(os.path.abspath('audio.wav'))

            await old_input({"role": "user", "type": "message", "content": content})
            await old_input({"role": "user", "type": "message", "end": True})

    async def new_output(self):
        """Return the next item for the client: audio bytes or an audio start/end flag."""
        while True:
            output = await old_output()

            # Byte chunks are passed straight through to the client.
            if isinstance(output, bytes):
                return output

            await asyncio.sleep(0)

            content = output.get("content", "")
            if output["type"] == "message" and content:
                self.tts.feed(content)

                # Start playing once the first delimiter is encountered.
                if not self.tts.is_playing() and any(c in delimiters for c in content):
                    self.tts.play_async(
                        on_audio_chunk=self.on_tts_chunk,
                        muted=True,
                        sentence_fragment_delimiters=delimiters,
                    )
                    return {"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True}

            if output == {"role": "assistant", "type": "message", "end": True}:
                # We put this here in case it never outputs a delimiter and
                # never triggers play_async above.
                if not self.tts.is_playing():
                    self.tts.play_async(
                        on_audio_chunk=self.on_tts_chunk,
                        muted=True,
                        sentence_fragment_delimiters=delimiters,
                    )
                    return {"role": "assistant", "type": "audio", "format": "bytes.wav", "start": True}
                return {"role": "assistant", "type": "audio", "format": "bytes.wav", "end": True}

    def on_tts_chunk(self, chunk):
        """Put a synthesized audio chunk on the interpreter's output queue."""
        self.output_queue.sync_q.put(chunk)

    # Wrap in voice interface
    interpreter.input = types.MethodType(new_input, interpreter)
    interpreter.output = types.MethodType(new_output, interpreter)
    interpreter.on_tts_chunk = types.MethodType(on_tts_chunk, interpreter)

    @interpreter.server.app.get("/ping")
    async def ping():
        return PlainTextResponse("pong")

    # Start server
    interpreter.server.run()