stash server changes

pull/279/head
Ben Xu 7 months ago
parent 2627fba481
commit 4b25239d0f

@ -2,6 +2,7 @@ from dotenv import load_dotenv
load_dotenv() # take environment variables from .env.
import requests
import subprocess
import os
import sys
@ -12,6 +13,7 @@ from pynput import keyboard
import json
import traceback
import websockets
import websockets.sync.client
import queue
from pydub import AudioSegment
from pydub.playback import play
@ -169,11 +171,11 @@ class Device:
elapsed_time = time.time() - self.playback_latency
print(f"Time from request to playback: {elapsed_time} seconds")
self.playback_latency = None
"""
if audio is not None:
mpv_process.stdin.write(audio) # type: ignore
mpv_process.stdin.flush() # type: ignore
"""
args = ["ffplay", "-autoexit", "-", "-nodisp"]
proc = subprocess.Popen(
args=args,
@ -183,9 +185,8 @@ class Device:
)
out, err = proc.communicate(input=audio)
proc.poll()
play(audio)
"""
play(audio)
# self.audiosegments.remove(audio)
# await asyncio.sleep(0.1)
except asyncio.exceptions.CancelledError:
@ -361,7 +362,7 @@ class Device:
async def websocket_communication(self, WS_URL):
print("websocket communication was called!!!!")
show_connection_log = True
"""
async def exec_ws_communication(websocket):
if CAMERA_ENABLED:
print(
@ -373,11 +374,11 @@ class Device:
asyncio.create_task(self.message_sender(websocket))
while True:
await asyncio.sleep(0.0001)
await asyncio.sleep(0)
chunk = await websocket.recv()
logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
# print((f"Got this message from the server: {type(chunk)} {chunk}"))
#logger.debug(f"Got this message from the server: {type(chunk)} {chunk}")
print((f"Got this message from the server: {type(chunk)}"))
if type(chunk) == str:
chunk = json.loads(chunk)
@ -388,7 +389,7 @@ class Device:
continue
# At this point, we have our message
# print("checkpoint reached!", message)
print("checkpoint reached!")
if isinstance(message, bytes):
# if message["type"] == "audio" and message["format"].startswith("bytes"):
@ -398,23 +399,23 @@ class Device:
audio_bytes = message
# Create an AudioSegment instance with the raw data
"""
audio = AudioSegment(
# raw audio data (bytes)
data=audio_bytes,
# signed 16-bit little-endian format
sample_width=2,
# 24,000 Hz frame rate
frame_rate=16000,
frame_rate=24000,
# mono sound
channels=1,
)
"""
# print("audio segment was created")
await self.audiosegments.put(audio_bytes)
# await self.audiosegments.put(audio)
print("audio segment was created")
#await self.audiosegments.put(audio_bytes)
await self.audiosegments.put(audio)
# Run the code if that's the client's job
if os.getenv("CODE_RUNNER") == "client":
@ -424,42 +425,65 @@ class Device:
result = interpreter.computer.run(language, code)
send_queue.put(result)
"""
if is_win10():
logger.info("Windows 10 detected")
# Workaround for Windows 10 not latching to the websocket server.
# See https://github.com/OpenInterpreter/01/issues/197
try:
ws = websockets.connect(WS_URL)
await exec_ws_communication(ws)
# await exec_ws_communication(ws)
except Exception as e:
logger.error(f"Error while attempting to connect: {e}")
else:
print("websocket url is", WS_URL)
while True:
i = 0
# while True:
# try:
# i += 1
# print("i is", i)
# # Hit the /ping endpoint
# ping_url = f"http://{self.server_url}/ping"
# response = requests.get(ping_url)
# print(response.text)
# # async with aiohttp.ClientSession() as session:
# # async with session.get(ping_url) as response:
# # print(f"Ping response: {await response.text()}")
for i in range(3):
print(i)
try:
async with websockets.connect(WS_URL) as websocket:
print("awaiting exec_ws_communication")
await exec_ws_communication(websocket)
print("happi happi happi :DDDDDDDDDDDDD")
# await exec_ws_communication(websocket)
# print("exiting exec_ws_communication")
except:
logger.info(traceback.format_exc())
if show_connection_log:
logger.info(f"Connecting to `{WS_URL}`...")
show_connection_log = False
await asyncio.sleep(2)
print("exception in websocket communication!!!!!!!!!!!!!!!!!")
traceback.print_exc()
# except:
# print("exception in websocket communication!!!!!!!!!!!!!!!!!")
# traceback.print_exc()
# if show_connection_log:
# logger.info(f"Connecting to `{WS_URL}`...")
# show_connection_log = False
# await asyncio.sleep(2)
async def start_async(self):
print("start async was called!!!!!")
# Configuration for WebSocket
WS_URL = f"ws://{self.server_url}"
WS_URL = f"ws://{self.server_url}/"
# Start the WebSocket communication
asyncio.create_task(self.websocket_communication(WS_URL))
await self.websocket_communication(WS_URL)
"""
# Start watching the kernel if it's your job to do that
if os.getenv("CODE_RUNNER") == "client":
asyncio.create_task(put_kernel_messages_into_queue(send_queue))
asyncio.create_task(self.play_audiosegments())
#asyncio.create_task(self.play_audiosegments())
# If Raspberry Pi, add the button listener, otherwise use the spacebar
if current_platform.startswith("raspberry-pi"):
@ -483,12 +507,11 @@ class Device:
else:
break
else:
# Keyboard listener for spacebar press/release
listener = keyboard.Listener(
on_press=self.on_press, on_release=self.on_release
)
listener.start()
print("listener for keyboard started!!!!!")
"""
# Keyboard listener for spacebar press/release
# listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release)
# listener.start()
# print("listener for keyboard started!!!!!")
def start(self):
print("device was started!!!!!!")

@ -38,7 +38,7 @@ class AsyncInterpreter:
self.stt = AudioToTextRecorder(
model="tiny.en", spinner=False, use_microphone=False
)
self.stt.stop() # It needs this for some reason
self.stt.stop()
# TTS
if self.interpreter.tts == "coqui":
@ -118,8 +118,6 @@ class AsyncInterpreter:
"""
self.interpreter.messages = self.active_chat_messages
# self.beeper.start()
self.stt.stop()
# message = self.stt.text()
# print("THE MESSAGE:", message)
@ -137,15 +135,9 @@ class AsyncInterpreter:
self.stt_latency = end_stt - start_stt
print("STT LATENCY", self.stt_latency)
# print(message)
end_interpreter = 0
# print(message)
def generate(message):
last_lmc_start_flag = self._last_lmc_start_flag
self.interpreter.messages = self.active_chat_messages
# print("🍀🍀🍀🍀GENERATING, using these messages: ", self.interpreter.messages)
# print("🍀 🍀 🍀 🍀 active_chat_messages: ", self.active_chat_messages)
print("message is", message)
for chunk in self.interpreter.chat(message, display=True, stream=True):
@ -209,7 +201,7 @@ class AsyncInterpreter:
text_iterator = generate(message)
self.tts.feed(text_iterator)
self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=True)
self.tts.play_async(on_audio_chunk=self.on_tts_chunk, muted=False)
while True:
if self.tts.is_playing():
@ -236,7 +228,7 @@ class AsyncInterpreter:
break
async def _on_tts_chunk_async(self, chunk):
# print("SENDING TTS CHUNK")
print(f"Adding chunk to output queue")
await self._add_to_queue(self._output_queue, chunk)
def on_tts_chunk(self, chunk):
@ -244,4 +236,7 @@ class AsyncInterpreter:
asyncio.run(self._on_tts_chunk_async(chunk))
async def output(self):
return await self._output_queue.get()
print("entering output method")
value = await self._output_queue.get()
print("output method returning")
return value

@ -5,6 +5,7 @@ from fastapi import FastAPI, WebSocket, Header
from fastapi.responses import PlainTextResponse
from uvicorn import Config, Server
from interpreter import interpreter as base_interpreter
from starlette.websockets import WebSocketDisconnect
from .async_interpreter import AsyncInterpreter
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any
@ -23,18 +24,11 @@ base_interpreter.llm.model = "groq/llama3-8b-8192"
base_interpreter.llm.api_key = os.environ["GROQ_API_KEY"]
base_interpreter.llm.supports_functions = False
base_interpreter.auto_run = True
base_interpreter.tts = "elevenlabs"
os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server"
# Parse command line arguments for port number
"""
parser = argparse.ArgumentParser(description="FastAPI server.")
parser.add_argument("--port", type=int, default=8000, help="Port to run on.")
args = parser.parse_args()
"""
base_interpreter.tts = "coqui"
async def main(server_host, server_port):
interpreter = AsyncInterpreter(base_interpreter)
@ -60,84 +54,111 @@ async def main(server_host, server_port):
print("🪼🪼🪼🪼🪼🪼 Messages loaded: ", interpreter.active_chat_messages)
return {"status": "success"}
print("About to set up the websocker endpoint!!!!!!!!!!!!!!!!!!!!!!!!!")
@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
print("websocket hit")
await websocket.accept()
try:
print("websocket accepted")
async def receive_input():
async def send_output():
try:
while True:
if websocket.client_state == "DISCONNECTED":
break
output = await interpreter.output()
if isinstance(output, bytes):
print("server sending bytes output")
try:
await websocket.send_bytes(output)
print("server successfully sent bytes output")
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
elif isinstance(output, dict):
print("server sending text output")
try:
await websocket.send_text(json.dumps(output))
print("server successfully sent text output")
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
except asyncio.CancelledError:
print("WebSocket connection closed")
traceback.print_exc()
async def receive_input():
try:
while True:
print("server awaiting input")
data = await websocket.receive()
if isinstance(data, bytes):
await interpreter.input(data)
try:
await interpreter.input(data)
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
elif "bytes" in data:
await interpreter.input(data["bytes"])
# print("RECEIVED INPUT", data)
try:
await interpreter.input(data["bytes"])
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
elif "text" in data:
# print("RECEIVED INPUT", data)
await interpreter.input(data["text"])
try:
await interpreter.input(data["text"])
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
except asyncio.CancelledError:
print("WebSocket connection closed")
traceback.print_exc()
async def send_output():
while True:
output = await interpreter.output()
try:
send_task = asyncio.create_task(send_output())
receive_task = asyncio.create_task(receive_input())
print("server starting to handle ws connection")
"""
done, pending = await asyncio.wait(
[send_task, receive_task],
return_when=asyncio.FIRST_COMPLETED,
)
if isinstance(output, bytes):
# print(f"Sending {len(output)} bytes of audio data.")
await websocket.send_bytes(output)
# we dont send out bytes rn, no TTS
for task in pending:
task.cancel()
elif isinstance(output, dict):
# print("sending text")
await websocket.send_text(json.dumps(output))
for task in done:
if task.exception() is not None:
raise
"""
await asyncio.gather(send_task, receive_task)
print("server finished handling ws connection")
await asyncio.gather(send_output(), receive_input())
except WebSocketDisconnect:
print("WebSocket disconnected")
except Exception as e:
print(f"WebSocket connection closed with exception: {e}")
traceback.print_exc()
finally:
if not websocket.client_state == "DISCONNECTED":
await websocket.close()
print("server closing ws connection")
await websocket.close()
print(f"Starting server on {server_host}:{server_port}")
config = Config(app, host=server_host, port=server_port, lifespan="on")
server = Server(config)
await server.serve()
class Rename(BaseModel):
input: str
@app.post("/rename-chat")
async def rename_chat(body_content: Rename, x_api_key: str = Header(None)):
print("RENAME CHAT REQUEST in PY 🌙🌙🌙🌙")
input_value = body_content.input
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=x_api_key,
)
try:
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "user",
"content": f"Given the following chat snippet, create a unique and descriptive title in less than 8 words. Your answer must not be related to customer service.\n\n{input_value}",
}
],
temperature=0.3,
stream=False,
)
print(response)
completion = response["choices"][0]["message"]["content"]
return {"data": {"content": completion}}
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
return {"error": str(e)}
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main("localhost", 8000))

Loading…
Cancel
Save