Adding Whisper model validation.

pull/241/head
Robert Brisita 9 months ago
parent 1324789123
commit 66ff6b1016

@ -5,11 +5,11 @@ Defines a function which takes a path to an audio file and turns it into text.
from datetime import datetime from datetime import datetime
import os import os
import contextlib import contextlib
import platform
import tempfile import tempfile
import shutil import shutil
import ffmpeg import ffmpeg
import subprocess import subprocess
import urllib.request import urllib.request
@ -56,21 +56,92 @@ def install(service_dir):
print("Whisper Rust executable already exists. Skipping build.") print("Whisper Rust executable already exists. Skipping build.")
WHISPER_MODEL_PATH = os.path.join(service_dir, "model") WHISPER_MODEL_PATH = os.path.join(service_dir, "model")
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL_NAME", "ggml-tiny.en.bin") WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL_NAME", "ggml-tiny.en.bin")
WHISPER_MODEL_URL = os.getenv( while not valid_model(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME):
"WHISPER_MODEL_URL", print(f"Downloading Whisper model '{WHISPER_MODEL_NAME}'.")
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/", WHISPER_MODEL_URL = os.getenv(
) "WHISPER_MODEL_URL",
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/",
if not os.path.isfile(os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)): )
os.makedirs(WHISPER_MODEL_PATH, exist_ok=True) os.makedirs(WHISPER_MODEL_PATH, exist_ok=True)
urllib.request.urlretrieve( urllib.request.urlretrieve(
f"{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}", f"{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}",
os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME), os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME),
) )
else: else:
print("Whisper model already exists. Skipping download.") print(f"Whisper model '{WHISPER_MODEL_NAME}' installed.")
def valid_model(model_path: str, model_file: str) -> bool:
# Try to validate model through cryptographic hash comparison
model_file_path = os.path.join(model_path, model_file)
if not os.path.isfile(model_file_path):
return False
# Download details file and get hash
details_file = f"https://huggingface.co/ggerganov/whisper.cpp/raw/main/{model_file}"
try:
with urllib.request.urlopen(details_file) as response:
body_bytes = response.read()
except:
print("Internet connection not detected. Skipping validation.")
return True
lines = body_bytes.splitlines()
colon_index = lines[1].find(b':')
details_hash = lines[1][colon_index + 1:].decode()
# Generate model hash using native commands
model_hash = None
system = platform.system()
if system == 'Darwin':
shasum_path = shutil.which('shasum')
model_hash = subprocess.check_output(
f"{shasum_path} -a 256 {model_file_path} | cut -d' ' -f1",
text=True,
shell=True
)
elif system == 'Linux':
sha256sum_path = shutil.which('sha256sum')
model_hash = subprocess.check_output(
f"{sha256sum_path} {model_file_path} | cut -d' ' -f1",
text=True,
shell=True
)
elif system == 'Windows':
comspec = os.getenv("COMSPEC")
if comspec.endswith('cmd.exe'): # Most likely
certutil_path = shutil.which('certutil')
first_op = f"{certutil_path} -hashfile {model_file_path} sha256"
second_op = 'findstr /v "SHA256 CertUtil"' # Prints only lines that do not contain a match.
model_hash = subprocess.check_output(f"{first_op} | {second_op}", text=True, shell=True)
else:
first_op = f"Get-FileHash -LiteralPath {model_file_path} -Algorithm SHA256"
subsequent_ops = "Select-Object Hash | Format-Table -HideTableHeaders | Out-String"
model_hash = subprocess.check_output([
'pwsh',
'-Command',
f"({first_op} | {subsequent_ops}).trim().toLower()"
],
text=True
)
else:
print(f"System '{system}' not supported. Skipping validation.")
return True
if details_hash == model_hash.strip():
print(f"Whisper model '{model_file}' file is valid.")
else:
msg = f'''
The model '{model_file}' did not validate. STT may not function correctly.
The model path is '{model_path}'.
Manually download and verify the model's hash to get better functionality.
Continuing.
'''
print(msg)
return True
def convert_mime_type_to_format(mime_type: str) -> str: def convert_mime_type_to_format(mime_type: str) -> str:

Loading…
Cancel
Save