From 4c82587db8874ddb134a547cfaf833109d645a52 Mon Sep 17 00:00:00 2001 From: killian <63927363+KillianLucas@users.noreply.github.com> Date: Thu, 7 Mar 2024 11:29:29 -0800 Subject: [PATCH] Fixed local whisper --- .../server/services/stt/local-whisper/stt.py | 38 ++++++------------- 01OS/01OS/server/services/stt/openai/stt.py | 4 +- 01OS/start.py | 2 +- 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/01OS/01OS/server/services/stt/local-whisper/stt.py b/01OS/01OS/server/services/stt/local-whisper/stt.py index e7bf150..9514c1d 100644 --- a/01OS/01OS/server/services/stt/local-whisper/stt.py +++ b/01OS/01OS/server/services/stt/local-whisper/stt.py @@ -16,11 +16,11 @@ import subprocess class Stt: def __init__(self, config): - service_directory = config["service_directory"] - install(service_directory) + self.service_directory = config["service_directory"] + install(self.service_directory) def stt(self, audio_file_path): - return stt(audio_file_path) + return stt(self.service_directory, audio_file_path) @@ -109,14 +109,12 @@ def run_command(command): result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) return result.stdout, result.stderr -def get_transcription_file(wav_file_path: str): - local_path = os.path.join(os.path.dirname(__file__), 'model') - whisper_rust_path = os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release') - model_name = os.getenv('WHISPER_MODEL_NAME') - if not model_name: - raise EnvironmentError("WHISPER_MODEL_NAME environment variable is not set.") +def get_transcription_file(service_directory, wav_file_path: str): + local_path = os.path.join(service_directory, 'model') + whisper_rust_path = os.path.join(service_directory, 'whisper-rust', 'target', 'release') + model_name = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin') - output, error = run_command([ + output, _ = run_command([ os.path.join(whisper_rust_path, 'whisper-rust'), '--model-path', os.path.join(local_path, model_name), '--file-path', wav_file_path @@ -124,28 +122,16 @@ def get_transcription_file(wav_file_path: str): return output -def get_transcription_bytes(audio_bytes: bytearray, mime_type): - with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path: - return get_transcription_file(wav_file_path) -def stt_bytes(audio_bytes: bytearray, mime_type="audio/wav"): - with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path: - return stt_wav(wav_file_path) - -def stt_wav(wav_file_path: str): +def stt_wav(service_directory, wav_file_path: str): temp_dir = tempfile.gettempdir() output_path = os.path.join(temp_dir, f"output_stt_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav") ffmpeg.input(wav_file_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run() try: - transcript = get_transcription_file(output_path) + transcript = get_transcription_file(service_directory, output_path) finally: os.remove(output_path) return transcript -def stt(input_data, mime_type="audio/wav"): - if isinstance(input_data, str): - return stt_wav(input_data) - elif isinstance(input_data, bytearray): - return stt_bytes(input_data, mime_type) - else: - raise ValueError("Input data should be either a path to a wav file (str) or audio bytes (bytearray)") \ No newline at end of file +def stt(service_directory, input_data): + return stt_wav(service_directory, input_data) \ No newline at end of file diff --git a/01OS/01OS/server/services/stt/openai/stt.py b/01OS/01OS/server/services/stt/openai/stt.py index 40308cf..4823965 100644 --- a/01OS/01OS/server/services/stt/openai/stt.py +++ b/01OS/01OS/server/services/stt/openai/stt.py @@ -68,9 +68,7 @@ def run_command(command): def get_transcription_file(wav_file_path: str): local_path = os.path.join(os.path.dirname(__file__), 'local_service') whisper_rust_path = os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release') - model_name = os.getenv('WHISPER_MODEL_NAME') - if not model_name: - raise EnvironmentError("WHISPER_MODEL_NAME environment variable is not set.") + model_name = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin') output, error = run_command([ os.path.join(whisper_rust_path, 'whisper-rust'), diff --git a/01OS/start.py b/01OS/start.py index 34a0775..8cf5dfd 100644 --- a/01OS/start.py +++ b/01OS/start.py @@ -40,7 +40,7 @@ def run( if local: tts_service = "piper" - llm_service = "llamafile" + # llm_service = "llamafile" stt_service = "local-whisper" if not server_url: