Implemented `profiles`

pull/284/head
killian 7 months ago
parent 632af7f7ba
commit fda23e95b2

@ -1,14 +1,3 @@
# import from the profiles directory the interpreter to be served
# add other profiles to the directory to define other interpreter instances and import them here
# {.profiles.fast: optimizes for STT/TTS latency with the fastest models }
# {.profiles.local: uses local models and local STT/TTS }
# {.profiles.default: uses default interpreter settings with optimized TTS latency }
# from .profiles.fast import interpreter as base_interpreter
# from .profiles.local import interpreter as base_interpreter
from .profiles.default import interpreter as base_interpreter
import asyncio import asyncio
import traceback import traceback
import json import json
@ -19,6 +8,7 @@ from .async_interpreter import AsyncInterpreter
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any from typing import List, Dict, Any
import os import os
import importlib.util
os.environ["STT_RUNNER"] = "server" os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server" os.environ["TTS_RUNNER"] = "server"
@ -50,14 +40,6 @@ async def websocket_endpoint(
): ):
await websocket.accept() await websocket.accept()
# interpreter.tts set in the profiles directory!!!!
interpreter = AsyncInterpreter(base_interpreter, debug)
# Send the tts_service value to the client
await websocket.send_text(
json.dumps({"type": "config", "tts_service": interpreter.interpreter.tts})
)
try: try:
async def receive_input(): async def receive_input():
@ -98,9 +80,21 @@ async def websocket_endpoint(
await websocket.close() await websocket.close()
async def main(server_host, server_port, debug): async def main(server_host, server_port, profile, debug):
app.state.debug = debug app.state.debug = debug
# Load the profile module from the provided path
spec = importlib.util.spec_from_file_location("profile", profile)
profile_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(profile_module)
# Get the interpreter from the profile
interpreter = profile_module.interpreter
# Make it async
interpreter = AsyncInterpreter(interpreter, debug)
print(f"Starting server on {server_host}:{server_port}") print(f"Starting server on {server_host}:{server_port}")
config = Config(app, host=server_host, port=server_port, lifespan="on") config = Config(app, host=server_host, port=server_port, lifespan="on")
server = Server(config) server = Server(config)

@ -6,7 +6,7 @@ from ..utils.print_markdown import print_markdown
def create_tunnel( def create_tunnel(
tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False, domain=None
): ):
print_markdown("Exposing server to the internet...") print_markdown("Exposing server to the internet...")
@ -99,8 +99,13 @@ def create_tunnel(
# If ngrok is installed, start it on the specified port # If ngrok is installed, start it on the specified port
# process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE) # process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE)
if domain:
domain = f"--domain={domain}"
else:
domain = ""
process = subprocess.Popen( process = subprocess.Popen(
f"ngrok http {server_port} --scheme http,https --log=stdout", f"ngrok http {server_port} --scheme http,https {domain} --log=stdout",
shell=True, shell=True,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
) )

@ -6,6 +6,7 @@ import os
import importlib import importlib
from source.server.tunnel import create_tunnel from source.server.tunnel import create_tunnel
from source.server.async_server import main from source.server.async_server import main
import subprocess
import signal import signal
@ -41,11 +42,25 @@ def run(
qr: bool = typer.Option( qr: bool = typer.Option(
False, "--qr", help="Display QR code to scan to connect to the server" False, "--qr", help="Display QR code to scan to connect to the server"
), ),
domain: str = typer.Option(
None, "--domain", help="Connect ngrok to a custom domain"
),
profiles: bool = typer.Option(
False,
"--profiles",
help="Opens the folder where this script is contained",
),
profile: str = typer.Option(
"default.py", # default
"--profile",
help="Specify the path to the profile, or the name of the file if it's in the `profiles` directory (run `--profiles` to open the profiles directory)",
),
debug: bool = typer.Option( debug: bool = typer.Option(
False, False,
"--debug", "--debug",
help="Print latency measurements and save microphone recordings locally for manual playback.", help="Print latency measurements and save microphone recordings locally for manual playback.",
), ),
): ):
_run( _run(
server=server, server=server,
@ -58,6 +73,9 @@ def run(
client_type=client_type, client_type=client_type,
qr=qr, qr=qr,
debug=debug, debug=debug,
domain=domain,
profiles=profiles,
profile=profile,
) )
@ -72,8 +90,33 @@ def _run(
client_type: str = "auto", client_type: str = "auto",
qr: bool = False, qr: bool = False,
debug: bool = False, debug: bool = False,
domain = None,
profiles = None,
profile = None,
): ):
profiles_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "source", "server", "profiles")
if profiles:
if platform.system() == "Windows":
subprocess.Popen(['explorer', profiles_dir])
elif platform.system() == "Darwin":
subprocess.Popen(['open', profiles_dir])
elif platform.system() == "Linux":
subprocess.Popen(['xdg-open', profiles_dir])
else:
subprocess.Popen(['open', profiles_dir])
exit(0)
if profile:
if not os.path.isfile(profile):
profile = os.path.join(profiles_dir, profile)
if not os.path.isfile(profile):
profile += ".py"
if not os.path.isfile(profile):
print(f"Invalid profile path: {profile}")
exit(1)
system_type = platform.system() system_type = platform.system()
if system_type == "Windows": if system_type == "Windows":
server_host = "localhost" server_host = "localhost"
@ -91,7 +134,6 @@ def _run(
signal.signal(signal.SIGINT, handle_exit) signal.signal(signal.SIGINT, handle_exit)
if server: if server:
# print(f"Starting server with mobile = {mobile}")
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
server_thread = threading.Thread( server_thread = threading.Thread(
@ -100,6 +142,7 @@ def _run(
main( main(
server_host, server_host,
server_port, server_port,
profile,
debug, debug,
), ),
), ),
@ -108,7 +151,7 @@ def _run(
if expose: if expose:
tunnel_thread = threading.Thread( tunnel_thread = threading.Thread(
target=create_tunnel, args=[tunnel_service, server_host, server_port, qr] target=create_tunnel, args=[tunnel_service, server_host, server_port, qr, domain]
) )
tunnel_thread.start() tunnel_thread.start()

Loading…
Cancel
Save