add mobile flag

pull/256/head
Ben Xu 8 months ago
parent 2e0ab15e5b
commit 3dea99470a

@ -182,7 +182,7 @@ const Main: React.FC<MainProps> = ({ route }) => {
try { try {
const message = JSON.parse(e.data); const message = JSON.parse(e.data);
if (message.content && typeof message.content === "string") { if (message.content && message.type === "audio") {
console.log("✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅ Audio message"); console.log("✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅ Audio message");
const buffer = message.content; const buffer = message.content;

@ -39,6 +39,8 @@ print("")
setup_logging() setup_logging()
accumulator_global = Accumulator()
app = FastAPI() app = FastAPI()
app_dir = user_data_dir("01") app_dir = user_data_dir("01")
@ -196,26 +198,11 @@ async def send_messages(websocket: WebSocket):
try: try:
if isinstance(message, dict): if isinstance(message, dict):
# print(f"Sending to the device: {type(message)} {str(message)[:100]}") print(f"Sending to the device: {type(message)} {str(message)[:100]}")
await websocket.send_json(message) await websocket.send_json(message)
elif isinstance(message, bytes): elif isinstance(message, bytes):
message = base64.b64encode(message) print(f"Sending to the device: {type(message)} {str(message)[:100]}")
# print(f"Sending to the device: {type(message)} {str(message)[:100]}")
await websocket.send_bytes(message) await websocket.send_bytes(message)
"""
str_bytes = str(message)
json_bytes = {
"role": "assistant",
"type": "audio",
"format": "message",
"content": str_bytes,
}
print(
f"Sending to the device: {type(json_bytes)} {str(json_bytes)[:100]}"
)
await websocket.send_json(json_bytes)
"""
else: else:
raise TypeError("Message must be a dict or bytes") raise TypeError("Message must be a dict or bytes")
except: except:
@ -224,10 +211,11 @@ async def send_messages(websocket: WebSocket):
raise raise
async def listener(): async def listener(mobile: bool):
while True: while True:
try: try:
accumulator = Accumulator() if mobile:
accumulator_mobile = Accumulator()
while True: while True:
if not from_user.empty(): if not from_user.empty():
@ -238,7 +226,11 @@ async def listener():
break break
await asyncio.sleep(1) await asyncio.sleep(1)
message = accumulator.accumulate(chunk) if mobile:
message = accumulator_mobile.accumulate(chunk, mobile)
else:
message = accumulator_global.accumulate(chunk, mobile)
if message == None: if message == None:
# Will be None until we have a full message ready # Will be None until we have a full message ready
continue continue
@ -305,8 +297,9 @@ async def listener():
logger.debug("Got chunk:", chunk) logger.debug("Got chunk:", chunk)
# Send it to the user # Send it to the user
# await to_device.put(chunk) await to_device.put(chunk)
# Yield to the event loop, so you actually send it out
# Yield to the event loop, so you actually send it out
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
if os.getenv("TTS_RUNNER") == "server": if os.getenv("TTS_RUNNER") == "server":
@ -328,11 +321,11 @@ async def listener():
if is_full_sentence(sentences[-1]): if is_full_sentence(sentences[-1]):
for sentence in sentences: for sentence in sentences:
await stream_tts_to_device(sentence) await stream_tts_to_device(sentence, mobile)
accumulated_text = "" accumulated_text = ""
else: else:
for sentence in sentences[:-1]: for sentence in sentences[:-1]:
await stream_tts_to_device(sentence) await stream_tts_to_device(sentence, mobile)
accumulated_text = sentences[-1] accumulated_text = sentences[-1]
# If we're going to speak, say we're going to stop sending text. # If we're going to speak, say we're going to stop sending text.
@ -376,7 +369,7 @@ async def listener():
traceback.print_exc() traceback.print_exc()
async def stream_tts_to_device(sentence): async def stream_tts_to_device(sentence, mobile: bool):
force_task_completion_responses = [ force_task_completion_responses = [
"the task is done", "the task is done",
"the task is impossible", "the task is impossible",
@ -385,41 +378,23 @@ async def stream_tts_to_device(sentence):
if sentence.lower().strip().strip(".!?").strip() in force_task_completion_responses: if sentence.lower().strip().strip(".!?").strip() in force_task_completion_responses:
return return
for chunk in stream_tts(sentence): for chunk in stream_tts(sentence, mobile):
await to_device.put(chunk) await to_device.put(chunk)
def stream_tts(sentence): def stream_tts(sentence, mobile: bool):
audio_file = tts(sentence) audio_file = tts(sentence, mobile)
# Read the entire WAV file
with open(audio_file, "rb") as f: with open(audio_file, "rb") as f:
audio_bytes = f.read() audio_bytes = f.read()
desktop_path = os.path.join(os.path.expanduser("~"), "Desktop")
desktop_audio_file = os.path.join(
desktop_path, f"{datetime.datetime.now()}" + os.path.basename(audio_file)
)
shutil.copy(audio_file, desktop_audio_file)
print(f"Audio file saved to Desktop: {desktop_audio_file}")
# storage_client = storage.Client(project="react-native-421323")
# bucket = storage_client.bucket("01-audio")
# blob = bucket.blob(f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
# generation_match_precondition = 0
# blob.upload_from_filename(
# audio_file, if_generation_match=generation_match_precondition
# )
# print(
# f"Audio file {audio_file} uploaded to {datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav"
# )
if mobile:
file_type = "audio/wav" file_type = "audio/wav"
# Read the entire WAV file
with open(audio_file, "rb") as f:
audio_bytes = f.read()
os.remove(audio_file) os.remove(audio_file)
# Stream the audio as a single message # stream the audio as a single sentence
yield { yield {
"role": "assistant", "role": "assistant",
"type": "audio", "type": "audio",
@ -429,6 +404,19 @@ def stream_tts(sentence):
"end": True, "end": True,
} }
else:
# stream the audio in chunk sizes
os.remove(audio_file)
file_type = "bytes.raw"
chunk_size = 1024
yield {"role": "assistant", "type": "audio", "format": file_type, "start": True}
for i in range(0, len(audio_bytes), chunk_size):
chunk = audio_bytes[i : i + chunk_size]
yield chunk
yield {"role": "assistant", "type": "audio", "format": file_type, "end": True}
from uvicorn import Config, Server from uvicorn import Config, Server
import os import os
@ -464,6 +452,7 @@ async def main(
temperature, temperature,
tts_service, tts_service,
stt_service, stt_service,
mobile,
): ):
global HOST global HOST
global PORT global PORT
@ -515,7 +504,7 @@ async def main(
interpreter.llm.completions = llm interpreter.llm.completions = llm
# Start listening # Start listening
asyncio.create_task(listener()) asyncio.create_task(listener(mobile))
# Start watching the kernel if it's your job to do that # Start watching the kernel if it's your job to do that
if True: # in the future, code can run on device. for now, just server. if True: # in the future, code can run on device. for now, just server.

@ -25,7 +25,7 @@ class Tts:
def __init__(self, config): def __init__(self, config):
pass pass
def tts(self, text): def tts(self, text, mobile):
response = client.audio.speech.create( response = client.audio.speech.create(
model="tts-1", model="tts-1",
voice=os.getenv("OPENAI_VOICE_NAME", "alloy"), voice=os.getenv("OPENAI_VOICE_NAME", "alloy"),
@ -36,9 +36,15 @@ class Tts:
response.stream_to_file(temp_file.name) response.stream_to_file(temp_file.name)
# TODO: hack to format audio correctly for device # TODO: hack to format audio correctly for device
if mobile:
outfile = tempfile.gettempdir() + "/" + "output.wav" outfile = tempfile.gettempdir() + "/" + "output.wav"
ffmpeg.input(temp_file.name).output( ffmpeg.input(temp_file.name).output(
outfile, f="wav", ar="16000", ac="1", loglevel="panic" outfile, f="wav", ar="16000", ac="1", loglevel="panic"
).run() ).run()
else:
outfile = tempfile.gettempdir() + "/" + "raw.dat"
ffmpeg.input(temp_file.name).output(
outfile, f="s16le", ar="16000", ac="1", loglevel="panic"
).run()
return outfile return outfile

@ -3,7 +3,7 @@ class Accumulator:
self.template = {"role": None, "type": None, "format": None, "content": None} self.template = {"role": None, "type": None, "format": None, "content": None}
self.message = self.template self.message = self.template
def accumulate(self, chunk): def accumulate(self, chunk, mobile):
# print(str(chunk)[:100]) # print(str(chunk)[:100])
if type(chunk) == dict: if type(chunk) == dict:
if "format" in chunk and chunk["format"] == "active_line": if "format" in chunk and chunk["format"] == "active_line":
@ -44,6 +44,10 @@ class Accumulator:
if "content" not in self.message or type(self.message["content"]) != bytes: if "content" not in self.message or type(self.message["content"]) != bytes:
self.message["content"] = b"" self.message["content"] = b""
self.message["content"] += chunk self.message["content"] += chunk
if mobile:
self.message["type"] = "audio" self.message["type"] = "audio"
self.message["format"] = "bytes.wav" self.message["format"] = "bytes.wav"
return self.message return self.message
else:
return None

@ -72,13 +72,16 @@ def run(
False, "--local", help="Use recommended local services for LLM, STT, and TTS" False, "--local", help="Use recommended local services for LLM, STT, and TTS"
), ),
qr: bool = typer.Option(False, "--qr", help="Print the QR code for the server URL"), qr: bool = typer.Option(False, "--qr", help="Print the QR code for the server URL"),
mobile: bool = typer.Option(
False, "--mobile", help="Toggle server to support mobile app"
),
): ):
_run( _run(
server=server, server=server or mobile,
server_host=server_host, server_host=server_host,
server_port=server_port, server_port=server_port,
tunnel_service=tunnel_service, tunnel_service=tunnel_service,
expose=expose, expose=expose or mobile,
client=client, client=client,
server_url=server_url, server_url=server_url,
client_type=client_type, client_type=client_type,
@ -92,7 +95,8 @@ def run(
tts_service=tts_service, tts_service=tts_service,
stt_service=stt_service, stt_service=stt_service,
local=local, local=local,
qr=qr, qr=qr or mobile,
mobile=mobile,
) )
@ -116,6 +120,7 @@ def _run(
stt_service: str = "openai", stt_service: str = "openai",
local: bool = False, local: bool = False,
qr: bool = False, qr: bool = False,
mobile: bool = False,
): ):
if local: if local:
tts_service = "piper" tts_service = "piper"
@ -136,6 +141,7 @@ def _run(
signal.signal(signal.SIGINT, handle_exit) signal.signal(signal.SIGINT, handle_exit)
if server: if server:
print(f"Starting server with mobile = {mobile}")
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
server_thread = threading.Thread( server_thread = threading.Thread(
@ -153,6 +159,7 @@ def _run(
temperature, temperature,
tts_service, tts_service,
stt_service, stt_service,
mobile,
), ),
), ),
) )

Loading…
Cancel
Save