Merge pull request #170 from TashaSkyUp/fix-169

Implement cross-platform compatibility in stt.py by replacing system-…
pull/173/head^2
Ty Fiero 10 months ago committed by GitHub
commit 6ef8e86173
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -12,6 +12,8 @@ import subprocess
import os import os
import subprocess import subprocess
import platform
import urllib.request
class Stt: class Stt:
@@ -23,7 +25,6 @@ class Stt:
return stt(self.service_directory, audio_file_path) return stt(self.service_directory, audio_file_path)
def install(service_dir): def install(service_dir):
### INSTALL ### INSTALL
@@ -42,27 +43,29 @@ def install(service_dir):
# Check if whisper-rust executable exists before attempting to build # Check if whisper-rust executable exists before attempting to build
if not os.path.isfile(os.path.join(WHISPER_RUST_PATH, "target/release/whisper-rust")): if not os.path.isfile(os.path.join(WHISPER_RUST_PATH, "target/release/whisper-rust")):
# Check if Rust is installed. Needed to build whisper executable # Check if Rust is installed. Needed to build whisper executable
rust_check = subprocess.call('command -v rustc', shell=True) rustc_path = shutil.which('rustc')
if rust_check != 0: if rustc_path is None:
print("Rust is not installed or is not in system PATH. Please install Rust before proceeding.") print("Rust is not installed or is not in system PATH. Please install Rust before proceeding.")
exit(1) exit(1)
# Build Whisper Rust executable if not found # Build Whisper Rust executable if not found
subprocess.call('cargo build --release', shell=True) subprocess.run(['cargo', 'build', '--release'], check=True)
else: else:
print("Whisper Rust executable already exists. Skipping build.") print("Whisper Rust executable already exists. Skipping build.")
WHISPER_MODEL_PATH = os.path.join(service_dir, "model") WHISPER_MODEL_PATH = os.path.join(service_dir, "model")
WHISPER_MODEL_NAME = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin') WHISPER_MODEL_NAME = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin')
WHISPER_MODEL_URL = os.getenv('WHISPER_MODEL_URL', 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/') WHISPER_MODEL_URL = os.getenv('WHISPER_MODEL_URL', 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/')
if not os.path.isfile(os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)): if not os.path.isfile(os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)):
os.makedirs(WHISPER_MODEL_PATH, exist_ok=True) os.makedirs(WHISPER_MODEL_PATH, exist_ok=True)
subprocess.call(f'curl -L "{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}" -o "{os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)}"', shell=True) urllib.request.urlretrieve(f"{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}",
os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME))
else: else:
print("Whisper model already exists. Skipping download.") print("Whisper model already exists. Skipping download.")
def convert_mime_type_to_format(mime_type: str) -> str: def convert_mime_type_to_format(mime_type: str) -> str:
if mime_type == "audio/x-wav" or mime_type == "audio/wav": if mime_type == "audio/x-wav" or mime_type == "audio/wav":
return "wav" return "wav"
@@ -73,6 +76,7 @@ def convert_mime_type_to_format(mime_type: str) -> str:
return mime_type return mime_type
@contextlib.contextmanager @contextlib.contextmanager
def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str: def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
@@ -105,10 +109,12 @@ def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str:
os.remove(input_path) os.remove(input_path)
os.remove(output_path) os.remove(output_path)
def run_command(command): def run_command(command):
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
return result.stdout, result.stderr return result.stdout, result.stderr
def get_transcription_file(service_directory, wav_file_path: str): def get_transcription_file(service_directory, wav_file_path: str):
local_path = os.path.join(service_directory, 'model') local_path = os.path.join(service_directory, 'model')
whisper_rust_path = os.path.join(service_directory, 'whisper-rust', 'target', 'release') whisper_rust_path = os.path.join(service_directory, 'whisper-rust', 'target', 'release')
@@ -124,14 +130,15 @@ def get_transcription_file(service_directory, wav_file_path: str):
def stt_wav(service_directory, wav_file_path: str): def stt_wav(service_directory, wav_file_path: str):
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
output_path = os.path.join(temp_dir, f"output_stt_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav") output_path = os.path.join(temp_dir, f"output_stt_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav")
ffmpeg.input(wav_file_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run() ffmpeg.input(wav_file_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k').run()
try: try:
transcript = get_transcription_file(service_directory, output_path) transcript = get_transcription_file(service_directory, output_path)
finally: finally:
os.remove(output_path) os.remove(output_path)
return transcript return transcript
def stt(service_directory, input_data): def stt(service_directory, input_data):
return stt_wav(service_directory, input_data) return stt_wav(service_directory, input_data)

Loading…
Cancel
Save