import typer
import asyncio
import platform
import threading
import os
import importlib
import signal

from source.server.tunnel import create_tunnel
from source.server.server import main

app = typer.Typer()
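

# CLI entry point. Typer turns the keyword arguments below into --flags;
# run() only collects the options and forwards them to _run(), which holds
# the actual startup logic.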
@app.command()
def run(
    server: bool = typer.Option(False, "--server", help="Run server"),
    server_host: str = typer.Option("0.0.0.0", "--server-host", help="Specify the server host where the server will deploy"),
    server_port: int = typer.Option(10001, "--server-port", help="Specify the server port where the server will deploy"),

    tunnel_service: str = typer.Option("ngrok", "--tunnel-service", help="Specify the tunnel service"),
    expose: bool = typer.Option(False, "--expose", help="Expose server to internet"),

    client: bool = typer.Option(False, "--client", help="Run client"),
    server_url: str = typer.Option(None, "--server-url", help="Specify the server URL that the client should expect. Defaults to server-host and server-port"),
    client_type: str = typer.Option("auto", "--client-type", help="Specify the client type"),

    llm_service: str = typer.Option("litellm", "--llm-service", help="Specify the LLM service"),

    model: str = typer.Option("gpt-4", "--model", help="Specify the model"),
    llm_supports_vision: bool = typer.Option(False, "--llm-supports-vision", help="Specify if the LLM service supports vision"),
    llm_supports_functions: bool = typer.Option(False, "--llm-supports-functions", help="Specify if the LLM service supports functions"),
    context_window: int = typer.Option(2048, "--context-window", help="Specify the context window size"),
    max_tokens: int = typer.Option(4096, "--max-tokens", help="Specify the maximum number of tokens"),
    temperature: float = typer.Option(0.8, "--temperature", help="Specify the temperature for generation"),

    tts_service: str = typer.Option("openai", "--tts-service", help="Specify the TTS service"),

    stt_service: str = typer.Option("openai", "--stt-service", help="Specify the STT service"),

    local: bool = typer.Option(False, "--local", help="Use recommended local services for LLM, STT, and TTS"),
):
    _run(
        server=server,
        server_host=server_host,
        server_port=server_port,
        tunnel_service=tunnel_service,
        expose=expose,
        client=client,
        server_url=server_url,
        client_type=client_type,
        llm_service=llm_service,
        model=model,
        llm_supports_vision=llm_supports_vision,
        llm_supports_functions=llm_supports_functions,
        context_window=context_window,
        max_tokens=max_tokens,
        temperature=temperature,
        tts_service=tts_service,
        stt_service=stt_service,
        local=local,
    )
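

# Shared implementation behind the CLI. The defaults below only apply when
# _run() is called directly; the CLI supplies its own defaults (note run()
# defaults --tunnel-service to "ngrok" while _run() defaults to "bore").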
def _run(
    server: bool = False,
    server_host: str = "0.0.0.0",
    server_port: int = 10001,

    tunnel_service: str = "bore",
    expose: bool = False,

    client: bool = False,
    server_url: str = None,
    client_type: str = "auto",

    llm_service: str = "litellm",

    model: str = "gpt-4",
    llm_supports_vision: bool = False,
    llm_supports_functions: bool = False,
    context_window: int = 2048,
    max_tokens: int = 4096,
    temperature: float = 0.8,

    tts_service: str = "openai",

    stt_service: str = "openai",

    local: bool = False,
):
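    # --local swaps in locally-hosted speech services; the local LLM swap is
    # still commented out below.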
    if local:
        tts_service = "piper"
        # llm_service = "llamafile"
        stt_service = "local-whisper"

    if not server_url:
        server_url = f"{server_host}:{server_port}"

    if not server and not client:
        server = True
        client = True
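
    # Exit the whole process immediately on Ctrl+C; os._exit() bypasses normal
    # cleanup so the background threads started below cannot block shutdown.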
    def handle_exit(signum, frame):
        os._exit(0)

    signal.signal(signal.SIGINT, handle_exit)
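
    # Run the server's async entry point on its own event loop in a background
    # thread, so the client can run in parallel in the same process.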
    if server:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        server_thread = threading.Thread(
            target=loop.run_until_complete,
            args=(main(server_host, server_port, llm_service, model, llm_supports_vision, llm_supports_functions, context_window, max_tokens, temperature, tts_service, stt_service),),
        )
        server_thread.start()
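
    # Optionally expose the local server to the internet via the chosen
    # tunnel service.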
    if expose:
        tunnel_thread = threading.Thread(target=create_tunnel, args=[tunnel_service, server_host, server_port])
        tunnel_thread.start()
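
    # Pick a client implementation: use an explicit --client-type, or
    # auto-detect macOS, Raspberry Pi (via the device-tree model), or
    # generic Linux.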
    if client:
        if client_type == "auto":
            system_type = platform.system()
            if system_type == "Darwin":  # macOS
                client_type = "mac"
            elif system_type == "Linux":
                try:
                    with open('/proc/device-tree/model', 'r') as m:
                        if 'raspberry pi' in m.read().lower():
                            client_type = "rpi"
                        else:
                            client_type = "linux"
                except FileNotFoundError:
                    client_type = "linux"

        module = importlib.import_module(f".clients.{client_type}.device", package='source')
        client_thread = threading.Thread(target=module.main, args=[server_url])
        client_thread.start()
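
    # Block until every started thread finishes; a stray KeyboardInterrupt is
    # converted to SIGINT so handle_exit() above runs.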
    try:
        if server:
            server_thread.join()
        if expose:
            tunnel_thread.join()
        if client:
            client_thread.join()
    except KeyboardInterrupt:
        os.kill(os.getpid(), signal.SIGINT)