You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
432 lines
15 KiB
432 lines
15 KiB
from dotenv import load_dotenv
|
|
load_dotenv() # take environment variables from .env.
|
|
|
|
import traceback
|
|
from platformdirs import user_data_dir
|
|
import ast
|
|
import json
|
|
import queue
|
|
import os
|
|
import traceback
|
|
import datetime
|
|
from .utils.bytes_to_wav import bytes_to_wav
|
|
import re
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import PlainTextResponse
|
|
from starlette.websockets import WebSocket, WebSocketDisconnect
|
|
from pathlib import Path
|
|
import asyncio
|
|
import urllib.parse
|
|
from .utils.kernel import put_kernel_messages_into_queue
|
|
from .i import configure_interpreter
|
|
from interpreter import interpreter
|
|
from ..utils.accumulator import Accumulator
|
|
from .utils.logs import setup_logging
|
|
from .utils.logs import logger
|
|
|
|
from ..utils.print_markdown import print_markdown
|
|
|
|
# Both speech-to-text and text-to-speech run inside this server process.
os.environ.update({"STT_RUNNER": "server", "TTS_RUNNER": "server"})

# Startup banner shown on the console before anything else is initialised.
markdown = """
○

*Starting...*
"""

print()
print_markdown(markdown)
print()

setup_logging()

# Assembles streamed start/content/end chunks into complete messages.
accumulator = Accumulator()

app = FastAPI()

# Per-user application data lives under the platform's data directory for "01".
app_dir = user_data_dir('01')
conversation_history_path = os.path.join(app_dir, 'conversations', 'user.json')

SERVER_LOCAL_PORT = int(os.environ.get('SERVER_LOCAL_PORT', 10001))
|
|
|
|
|
|
# Sentence helpers: text-to-speech only ever speaks complete sentences.

def is_full_sentence(text):
    """Return True when *text* ends in sentence-final punctuation ('.', '!' or '?')."""
    return text[-1:] in ('.', '!', '?')


# Whitespace that directly follows sentence-final punctuation.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')


def split_into_sentences(text):
    """Split *text* after each '.', '!' or '?' that is followed by whitespace.

    The punctuation stays attached to its sentence; the trailing piece may be an
    incomplete sentence.
    """
    return _SENTENCE_BOUNDARY.split(text)
|
|
|
|
# --- Message queues ----------------------------------------------------------
# from_computer: computer messages coming from the device. It is a synchronous
#   queue.Queue because the code-execution path that reads it is synchronous.
# from_user: user messages coming from the device.
# to_device: everything the server sends back to the device.
# NOTE(review): the asyncio queues are created at import time; on Python < 3.10
# they may bind to a different event loop than the one asyncio.run() creates —
# confirm the minimum supported Python version.
from_computer = queue.Queue()
from_user, to_device = asyncio.Queue(), asyncio.Queue()
|
|
|
|
# Switch code executor to device if that's set
if os.getenv('CODE_RUNNER') == "device":

    # (This should probably just loop through all languages and apply these changes instead)

    class Python:
        """Interpreter language that forwards Python code to the device for execution.

        Code is sent to the device through ``to_device``, wrapped in start/end
        flags. The device's output is read back from ``from_computer`` until a
        chunk carrying an ``"end"`` key arrives.
        """

        # This is the name that will appear to the LLM.
        name = "python"

        def __init__(self):
            # NOTE(review): stop() sets this flag but run() never reads it, so
            # stopping a running execution is not wired up yet.
            self.halt = False

        def run(self, code):
            """Generator that yields a dictionary in LMC Format."""

            # Prepare the data
            message = {"role": "assistant", "type": "code", "format": "python", "content": code}

            # Unless it was just sent to the device, send it wrapped in flags.
            # run() is synchronous, and asyncio.Queue.put is a coroutine: calling
            # it without await never enqueues anything. put_nowait enqueues
            # immediately and cannot raise QueueFull because to_device is unbounded.
            # NOTE(review): asyncio.Queue is not thread-safe; if run() is ever
            # invoked from a worker thread this needs loop.call_soon_threadsafe.
            if not (interpreter.messages and interpreter.messages[-1] == message):
                to_device.put_nowait({"role": "assistant", "type": "code", "format": "python", "start": True})
                to_device.put_nowait(message)
                to_device.put_nowait({"role": "assistant", "type": "code", "format": "python", "end": True})

            # Stream the response
            # NOTE(review): from_computer.get() blocks the calling thread. If this
            # runs on the event-loop thread, the coroutines that fill from_computer
            # (e.g. receive_messages) cannot run — confirm where run() is called from.
            logger.info("Waiting for the device to respond...")
            while True:
                chunk = from_computer.get()
                logger.info(f"Server received from device: {chunk}")
                if "end" in chunk:
                    break
                yield chunk

        def stop(self):
            """Request a halt by setting the (currently unread) ``halt`` flag."""
            self.halt = True

        def terminate(self):
            """Terminates the entire process."""
            # dramatic!! do nothing
            pass

    interpreter.computer.languages = [Python]

# Configure interpreter
interpreter = configure_interpreter(interpreter)
|
|
|
|
@app.get("/ping")
async def ping():
    """Liveness probe: always answers with the plain-text body "pong"."""
    reply = "pong"
    return PlainTextResponse(reply)
|
|
|
|
@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one device connection until either direction stops.

    Incoming frames are handled by receive_messages() and outgoing messages by
    send_messages(). Both run as tasks, and both are cancelled as soon as
    either one finishes or fails. Errors are logged, not raised.
    """
    await websocket.accept()

    receive_task = asyncio.create_task(receive_messages(websocket))
    send_task = asyncio.create_task(send_messages(websocket))

    try:
        # Stop as soon as either side finishes. Waiting for both (gather) would
        # leave send_task blocked on to_device.get() after the client has gone,
        # where it could take a message meant for the next connection.
        done, _pending = await asyncio.wait(
            {receive_task, send_task}, return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            task.result()  # re-raise any error from the task that finished
    except Exception as e:
        logger.debug(traceback.format_exc())
        logger.info(f"Connection lost. Error: {e}")
    finally:
        for task in (receive_task, send_task):
            task.cancel()
|
|
|
|
|
|
@app.post("/")
async def add_computer_message(request: Request):
    """Queue a text message from an HTTP client as if the user had sent it.

    Expects a JSON object body with a non-empty ``"text"`` field. The text is
    placed on ``from_user`` as a user message, framed by start/end flags.

    Returns a 422 JSON error response when the body is not a JSON object or
    has no usable ``"text"``.
    """
    # Local import: only this handler needs it.
    from fastapi.responses import JSONResponse

    try:
        body = await request.json()
    except ValueError:  # covers json.JSONDecodeError and undecodable bytes
        body = None

    text = body.get("text") if isinstance(body, dict) else None
    if not text:
        # A (dict, status) tuple would be serialised by FastAPI as a JSON list
        # with status 200, so return a real response object to send 422.
        return JSONResponse(status_code=422, content={"error": "Missing 'text' in request body"})

    message = {"role": "user", "type": "message", "content": text}

    await from_user.put({"role": "user", "type": "message", "start": True})
    await from_user.put(message)
    await from_user.put({"role": "user", "type": "message", "end": True})
|
|
|
|
|
|
async def receive_messages(websocket: WebSocket):
    """Route incoming websocket frames to the server's inbound queues until disconnect.

    - Text frames with JSON ``"role": "computer"`` go to ``from_computer``,
      to be handled by interpreter.computer.run.
    - Text frames with JSON ``"role": "user"`` go to ``from_user``.
    - Binary frames are put on ``from_user`` as raw bytes.
    - Text frames that are not valid JSON are dropped.

    Returns on a normal close (code 1000) or if receiving fails.

    Raises:
        WebSocketDisconnect: the client closed with any other code.
        ValueError: a JSON text frame is not an object with a known role.
    """
    while True:
        try:
            data = await websocket.receive()
        except Exception as e:
            logger.info(f"Stopped receiving from websocket: {e}")
            return

        # A disconnect arrives as a message, not an exception.
        if data.get("type") == "websocket.disconnect":
            code = data.get("code", 1000)
            if code == 1000:
                logger.info("Websocket connection closed normally.")
                return
            raise WebSocketDisconnect(code)

        # ASGI may include both keys with one of them set to None.
        text = data.get("text")
        raw = data.get("bytes")

        if text is not None:
            try:
                payload = json.loads(text)
            except json.JSONDecodeError:
                continue  # data is not JSON; it is dropped

            role = payload.get("role") if isinstance(payload, dict) else None
            if role == "computer":
                from_computer.put(payload)  # To be handled by interpreter.computer.run
            elif role == "user":
                await from_user.put(payload)
            else:
                raise ValueError(f"Unknown role: {payload}")

        elif raw is not None:
            await from_user.put(raw)  # binary data
|
|
|
|
|
|
async def send_messages(websocket: WebSocket):
    """Forward everything queued on ``to_device`` to the websocket, forever.

    Dicts are sent as JSON text frames and bytes as binary frames. Any other
    type is logged and dropped.

    If a send fails, or the task is cancelled mid-send, the message is put back
    on ``to_device`` before the error propagates. It goes to the back of the
    queue, so ordering is not preserved.
    """
    while True:
        message = await to_device.get()

        if isinstance(message, dict):
            send = websocket.send_json
        elif isinstance(message, bytes):
            send = websocket.send_bytes
        else:
            # Never re-queue an unsendable message: it would fail again and
            # end the send loop of every future connection.
            logger.error(
                f"Dropping message of unsupported type {type(message).__name__}; "
                "message must be a dict or bytes"
            )
            continue

        try:
            await send(message)
        except BaseException:
            # BaseException so task cancellation mid-send is also covered.
            # Make sure to put the message back in the queue if sending failed.
            await to_device.put(message)
            raise
|
|
|
|
async def _next_incoming_chunk():
    """Return the next inbound chunk, preferring user input over computer output.

    Polls both queues once per second. ``from_computer`` is a synchronous
    ``queue.Queue`` and cannot be awaited, which is why this polls.
    """
    while True:
        if not from_user.empty():
            return await from_user.get()
        if not from_computer.empty():
            return from_computer.get()
        await asyncio.sleep(1)


def _transcribe_audio_message(message):
    """Turn an accumulated audio message into a user text message using ``stt``.

    Returns None when the message has no content (nothing / silence).
    """
    if message.get("content") is None:
        return None

    # Convert bytes to audio file. Format will be bytes.wav or bytes.opus
    mime_type = "audio/" + message["format"].split(".")[1]
    audio_file_path = bytes_to_wav(message["content"], mime_type)

    text = stt(audio_file_path)
    print("> ", text)
    return {"role": "user", "type": "message", "content": text}


def _save_history(messages):
    """Overwrite the conversation history file with *messages*."""
    with open(conversation_history_path, 'w') as file:
        json.dump(messages, file, indent=4)


def _append_to_history(message):
    """Load the saved history, append *message*, save it, and return the list.

    A missing history file starts a fresh conversation; its directory is created.
    """
    try:
        with open(conversation_history_path, 'r') as file:
            messages = json.load(file)
    except FileNotFoundError:
        os.makedirs(os.path.dirname(conversation_history_path), exist_ok=True)
        messages = []
    messages.append(message)
    _save_history(messages)
    return messages


def _should_use_vision(messages):
    """True when the current model is a ``gpt-`` model and *messages* contain an image."""
    return interpreter.llm.model.startswith("gpt-") and any(
        m.get("type") == "image" for m in messages
    )


async def _speak_complete_sentences(accumulated_text):
    """Speak every complete sentence in *accumulated_text* and return the unspoken rest.

    While sentences are being spoken, the current text message is closed with
    an ``end`` flag and reopened afterwards with a ``start`` flag.
    """
    sentences = split_into_sentences(accumulated_text)

    # If we're going to speak, say we're going to stop sending text.
    # This should be fixed probably, we should be able to do both in parallel, or only one.
    speaking = any(is_full_sentence(sentence) for sentence in sentences)
    if speaking:
        await to_device.put({"role": "assistant", "type": "message", "end": True})

    if is_full_sentence(sentences[-1]):
        to_speak, remainder = sentences, ""
    else:
        to_speak, remainder = sentences[:-1], sentences[-1]
    for sentence in to_speak:
        await stream_tts_to_device(sentence)

    if speaking:
        await to_device.put({"role": "assistant", "type": "message", "start": True})
    return remainder


async def listener():
    """Consume inbound messages forever and stream the interpreter's replies.

    For each complete message assembled by ``accumulator``:
    - audio is transcribed to text;
    - a bare "stop" is ignored;
    - otherwise the message is appended to the saved conversation history and
      ``interpreter.chat`` runs on that history.

    Every reply chunk is queued for the device. When ``TTS_RUNNER == "server"``,
    complete assistant sentences are also spoken. A pending computer message
    interrupts the reply.

    Errors are printed and the loop continues. Task cancellation is not caught,
    so the listener can be stopped.
    """
    while True:
        try:
            chunk = await _next_incoming_chunk()

            message = accumulator.accumulate(chunk)
            if message is None:
                # Will be None until we have a full message ready
                continue

            if message["type"] == "audio" and message["format"].startswith("bytes"):
                message = _transcribe_audio_message(message)
                if message is None:  # It was nothing / silence
                    continue

            # At this point, we have only text messages
            if not isinstance(message["content"], str):
                print("This should be a string, but it's not:", message["content"])
                message["content"] = message["content"].decode()

            # A bare "stop" is dropped instead of being sent to the interpreter
            if message["content"].lower().strip(".,! ") == "stop":
                continue

            messages = _append_to_history(message)
            accumulated_text = ""

            if _should_use_vision(messages):
                interpreter.llm.model = "gpt-4-vision-preview"
                interpreter.llm.supports_vision = True

            for chunk in interpreter.chat(messages, stream=True, display=True):

                # Re-check as the conversation grows during the reply, with the
                # same gpt-only guard as above.
                if _should_use_vision(interpreter.messages):
                    interpreter.llm.model = "gpt-4-vision-preview"
                    interpreter.llm.supports_vision = True

                logger.debug("Got chunk: %s", chunk)

                # Send it to the user, then yield to the event loop so it actually goes out
                await to_device.put(chunk)
                await asyncio.sleep(0.01)

                # Speak full sentences out loud
                if (
                    os.getenv('TTS_RUNNER') == "server"
                    and chunk["role"] == "assistant"
                    and "content" in chunk
                    and chunk["type"] == "message"
                ):
                    accumulated_text = await _speak_complete_sentences(
                        accumulated_text + chunk["content"]
                    )

                # If we have a new message, save our progress
                if not from_user.empty():
                    # Check if it's just an end flag. We ignore those.
                    temp_message = await from_user.get()
                    if (
                        isinstance(temp_message, dict)
                        and temp_message.get("role") == "user"
                        and temp_message.get("end")
                    ):
                        continue  # Yup. False alarm.

                    # Whoops! Put that back (it goes to the back of the queue)
                    await from_user.put(temp_message)
                    _save_history(interpreter.messages)
                    # TODO: breaking out here to handle the new user message is
                    # disabled because it was triggering seemingly randomly.

                # Also check if there's any new computer messages
                if not from_computer.empty():
                    _save_history(interpreter.messages)
                    logger.info("New computer message recieved. Breaking.")
                    break

        except Exception:
            traceback.print_exc()
|
|
|
|
async def stream_tts_to_device(sentence):
    """Synthesize *sentence* and queue the resulting audio chunks on ``to_device``.

    Fixed task-completion phrases are skipped and never spoken. The comparison
    ignores case, surrounding whitespace and trailing '.', '!' or '?'.
    """
    silent_phrases = {
        "the task is done",
        "the task is impossible",
        "let me know what you'd like to do next",
    }
    normalized = sentence.lower().strip().strip(".!?").strip()
    if normalized not in silent_phrases:
        for audio_chunk in stream_tts(sentence):
            await to_device.put(audio_chunk)
|
|
|
|
def stream_tts(sentence, chunk_size=1024):
    """Synthesize *sentence* with ``tts`` and yield it as LMC audio messages.

    Yields a start-flag dict, then the audio file's bytes in slices of at most
    ``chunk_size`` bytes, then an end-flag dict. The file produced by ``tts``
    is deleted once read, even if reading it fails.

    Raises:
        ValueError: ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    audio_file = tts(sentence)
    try:
        with open(audio_file, "rb") as f:
            audio_bytes = f.read()
    finally:
        # Remove the file tts() wrote even when reading it failed. A file that
        # is already gone needs no cleanup.
        try:
            os.remove(audio_file)
        except FileNotFoundError:
            pass

    file_type = "bytes.raw"

    # Stream the audio
    yield {"role": "assistant", "type": "audio", "format": file_type, "start": True}
    for offset in range(0, len(audio_bytes), chunk_size):
        yield audio_bytes[offset:offset + chunk_size]
    yield {"role": "assistant", "type": "audio", "format": file_type, "end": True}
|
|
|
|
from uvicorn import Config, Server
|
|
import os
|
|
import platform
|
|
from importlib import import_module
|
|
|
|
# Placeholder bind address; main() overwrites both values before the server starts.
HOST, PORT = '', 0
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Print the "Ready." banner when the app starts up."""
    print("")
    print_markdown("\n*Ready.*\n")
    print("")
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Announce on the console that the server is stopping."""
    notice = "*Server is shutting down*"
    print_markdown(notice)
|
|
|
|
async def main(server_host, server_port, llm_service, model, llm_supports_vision, llm_supports_functions, context_window, max_tokens, temperature, tts_service, stt_service):
    """Load the configured services, start background tasks, and run uvicorn.

    Args:
        server_host, server_port: address uvicorn binds to. They are also
            stored in the module-level HOST / PORT.
        llm_service, tts_service, stt_service: implementation names. Each is
            imported from ``source.server.services.<kind>.<name>.<kind>``.
        model, llm_supports_vision, llm_supports_functions, context_window,
        max_tokens, temperature: settings passed to the LLM service's config.

    Side effects:
        - Writes a millisecond timestamp session id to ``01-session.txt`` in
          the app data directory.
        - Publishes each service's callable as the module-level global
          ``llm`` / ``tts`` / ``stt``.
        - Runs until the uvicorn server stops.
    """
    global HOST
    global PORT
    PORT = server_port
    HOST = server_host

    # Setup services
    application_directory = user_data_dir('01')
    services_directory = os.path.join(application_directory, 'services')
    service_dict = {'llm': llm_service, 'tts': tts_service, 'stt': stt_service}

    # Create a file with the session number; the data dir may not exist yet.
    os.makedirs(application_directory, exist_ok=True)
    session_file_path = os.path.join(application_directory, '01-session.txt')
    with open(session_file_path, 'w') as session_file:
        session_id = int(datetime.datetime.now().timestamp() * 1000)
        session_file.write(str(session_id))

    for service, implementation in service_dict.items():
        service_directory = os.path.join(services_directory, service, implementation)

        # This is the folder they can mess around in
        service_config = {"service_directory": service_directory}

        if service == "llm":
            service_config.update({
                "interpreter": interpreter,
                "model": model,
                "llm_supports_vision": llm_supports_vision,
                "llm_supports_functions": llm_supports_functions,
                "context_window": context_window,
                "max_tokens": max_tokens,
                "temperature": temperature
            })

        module = import_module(f'.server.services.{service}.{implementation}.{service}', package='source')
        ServiceClass = getattr(module, service.capitalize())
        service_instance = ServiceClass(service_config)
        # Publish as a module global (llm / tts / stt); listener() and
        # stream_tts() call stt / tts by these names.
        globals()[service] = getattr(service_instance, service)

    interpreter.llm.completions = globals()["llm"]

    # asyncio keeps only weak references to running tasks, so hold strong ones
    # for as long as the server runs.
    background_tasks = [
        # Start listening
        asyncio.create_task(listener()),
        # Watch the kernel. In the future, code can run on device; for now the
        # server always does this.
        asyncio.create_task(put_kernel_messages_into_queue(from_computer)),
    ]

    uvicorn_config = Config(app, host=server_host, port=int(server_port), lifespan='on')
    server = Server(uvicorn_config)
    await server.serve()

    del background_tasks  # references were only needed while serving
|
|
|
|
# Run the FastAPI app
if __name__ == "__main__":
    # main() takes eleven required settings (host, port, service names and
    # model options) and has no defaults, so calling it with no arguments can
    # only raise TypeError. Exit with an explanation instead.
    raise SystemExit(
        "This module is not a standalone entry point: call "
        "main(server_host, server_port, llm_service, model, llm_supports_vision, "
        "llm_supports_functions, context_window, max_tokens, temperature, "
        "tts_service, stt_service) with explicit settings."
    )