From fda23e95b22a679f6b1197c3d1f78d9caece8521 Mon Sep 17 00:00:00 2001 From: killian <63927363+KillianLucas@users.noreply.github.com> Date: Wed, 10 Jul 2024 10:56:54 -0700 Subject: [PATCH] Implemented `profiles` --- software/source/server/async_server.py | 34 ++++++++----------- software/source/server/tunnel.py | 9 +++-- software/start.py | 47 ++++++++++++++++++++++++-- 3 files changed, 66 insertions(+), 24 deletions(-) diff --git a/software/source/server/async_server.py b/software/source/server/async_server.py index 8bb91a3..849f72d 100644 --- a/software/source/server/async_server.py +++ b/software/source/server/async_server.py @@ -1,14 +1,3 @@ -# import from the profiles directory the interpreter to be served - -# add other profiles to the directory to define other interpreter instances and import them here -# {.profiles.fast: optimizes for STT/TTS latency with the fastest models } -# {.profiles.local: uses local models and local STT/TTS } -# {.profiles.default: uses default interpreter settings with optimized TTS latency } - -# from .profiles.fast import interpreter as base_interpreter -# from .profiles.local import interpreter as base_interpreter -from .profiles.default import interpreter as base_interpreter - import asyncio import traceback import json @@ -19,6 +8,7 @@ from .async_interpreter import AsyncInterpreter from fastapi.middleware.cors import CORSMiddleware from typing import List, Dict, Any import os +import importlib.util os.environ["STT_RUNNER"] = "server" os.environ["TTS_RUNNER"] = "server" @@ -50,14 +40,6 @@ async def websocket_endpoint( ): await websocket.accept() - # interpreter.tts set in the profiles directory!!!! - interpreter = AsyncInterpreter(base_interpreter, debug) - - # Send the tts_service value to the client - await websocket.send_text( - json.dumps({"type": "config", "tts_service": interpreter.interpreter.tts}) - ) - try: async def receive_input(): @@ -98,9 +80,21 @@ async def websocket_endpoint( await websocket.close() -async def main(server_host, server_port, debug): +async def main(server_host, server_port, profile, debug): + app.state.debug = debug + # Load the profile module from the provided path + spec = importlib.util.spec_from_file_location("profile", profile) + profile_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(profile_module) + + # Get the interpreter from the profile + interpreter = profile_module.interpreter + + # Make it async + interpreter = AsyncInterpreter(interpreter, debug) + print(f"Starting server on {server_host}:{server_port}") config = Config(app, host=server_host, port=server_port, lifespan="on") server = Server(config) diff --git a/software/source/server/tunnel.py b/software/source/server/tunnel.py index f25a0b3..a40c0f3 100644 --- a/software/source/server/tunnel.py +++ b/software/source/server/tunnel.py @@ -6,7 +6,7 @@ from ..utils.print_markdown import print_markdown def create_tunnel( - tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False + tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False, domain=None ): print_markdown("Exposing server to the internet...") @@ -99,8 +99,13 @@ def create_tunnel( # If ngrok is installed, start it on the specified port # process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE) + + if domain: + domain = f"--domain={domain}" + else: + domain = "" process = subprocess.Popen( - f"ngrok http {server_port} --scheme http,https --log=stdout", + f"ngrok http {server_port} --scheme http,https {domain} --log=stdout", shell=True, stdout=subprocess.PIPE, ) diff --git a/software/start.py b/software/start.py index 0808b0f..28c5675 100644 --- a/software/start.py +++ b/software/start.py @@ -6,6 +6,7 @@ import os import importlib from source.server.tunnel import create_tunnel from source.server.async_server import main +import subprocess import signal @@ -41,11 +42,25 @@ def run( qr: bool = typer.Option( False, "--qr", help="Display QR code to scan to connect to the server" ), + domain: str = typer.Option( + None, "--domain", help="Connect ngrok to a custom domain" + ), + profiles: bool = typer.Option( + False, + "--profiles", + help="Opens the folder where this script is contained", + ), + profile: str = typer.Option( + "default.py", # default + "--profile", + help="Specify the path to the profile, or the name of the file if it's in the `profiles` directory (run `--profiles` to open the profiles directory)", + ), debug: bool = typer.Option( False, "--debug", help="Print latency measurements and save microphone recordings locally for manual playback.", ), + ): _run( server=server, @@ -58,6 +73,9 @@ def run( client_type=client_type, qr=qr, debug=debug, + domain=domain, + profiles=profiles, + profile=profile, ) @@ -72,8 +90,33 @@ def _run( client_type: str = "auto", qr: bool = False, debug: bool = False, + domain = None, + profiles = None, + profile = None, ): + profiles_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "source", "server", "profiles") + + if profiles: + if platform.system() == "Windows": + subprocess.Popen(['explorer', profiles_dir]) + elif platform.system() == "Darwin": + subprocess.Popen(['open', profiles_dir]) + elif platform.system() == "Linux": + subprocess.Popen(['xdg-open', profiles_dir]) + else: + subprocess.Popen(['open', profiles_dir]) + exit(0) + + if profile: + if not os.path.isfile(profile): + profile = os.path.join(profiles_dir, profile) + if not os.path.isfile(profile): + profile += ".py" + if not os.path.isfile(profile): + print(f"Invalid profile path: {profile}") + exit(1) + system_type = platform.system() if system_type == "Windows": server_host = "localhost" @@ -91,7 +134,6 @@ def _run( signal.signal(signal.SIGINT, handle_exit) if server: - # print(f"Starting server with mobile = {mobile}") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) server_thread = threading.Thread( @@ -100,6 +142,7 @@ def _run( main( server_host, server_port, + profile, debug, ), ), @@ -108,7 +151,7 @@ def _run( if expose: tunnel_thread = threading.Thread( - target=create_tunnel, args=[tunnel_service, server_host, server_port, qr] + target=create_tunnel, args=[tunnel_service, server_host, server_port, qr, domain] ) tunnel_thread.start()