Adding Whisper model validation.

pull/241/head
Robert Brisita 9 months ago
parent 1324789123
commit 66ff6b1016

@ -5,11 +5,11 @@ Defines a function which takes a path to an audio file and turns it into text.
from datetime import datetime
import os
import contextlib
import platform
import tempfile
import shutil
import ffmpeg
import subprocess
import urllib.request
@ -56,21 +56,92 @@ def install(service_dir):
print("Whisper Rust executable already exists. Skipping build.")
WHISPER_MODEL_PATH = os.path.join(service_dir, "model")
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL_NAME", "ggml-tiny.en.bin")
while not valid_model(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME):
print(f"Downloading Whisper model '{WHISPER_MODEL_NAME}'.")
WHISPER_MODEL_URL = os.getenv(
"WHISPER_MODEL_URL",
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/",
)
if not os.path.isfile(os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME)):
os.makedirs(WHISPER_MODEL_PATH, exist_ok=True)
urllib.request.urlretrieve(
f"{WHISPER_MODEL_URL}{WHISPER_MODEL_NAME}",
os.path.join(WHISPER_MODEL_PATH, WHISPER_MODEL_NAME),
)
else:
print("Whisper model already exists. Skipping download.")
print(f"Whisper model '{WHISPER_MODEL_NAME}' installed.")
def valid_model(
    model_path: str,
    model_file: str,
    *,
    details_base_url: str = "https://huggingface.co/ggerganov/whisper.cpp/raw/main/",
    timeout: float = 30.0,
) -> bool:
    """Check a downloaded Whisper model against its published SHA-256 hash.

    The expected hash is read from the model's details file (a Git LFS
    pointer containing an ``oid sha256:<hex>`` line) fetched from
    ``details_base_url + model_file``. The local hash is computed with
    ``hashlib``, so the check behaves the same on every platform and no
    longer depends on ``shasum``/``sha256sum``/``certutil``/``pwsh`` or on
    shell quoting of the model path.

    Validation is best-effort: the result of the comparison is only
    reported on stdout.

    Args:
        model_path: Directory that contains the model file.
        model_file: File name of the model inside ``model_path``.
        details_base_url: URL prefix the details file is fetched from; the
            model file name is appended to it.
        timeout: Seconds to wait for the details file before giving up.

    Returns:
        ``False`` only when the model file does not exist. ``True`` in every
        other case: hash matched, hash mismatched (a warning is printed), or
        validation was skipped because the details file could not be fetched
        or did not contain a SHA-256.
    """
    # Function-scope import keeps this block self-contained with respect to
    # the file's top-level import list.
    import http.client

    model_file_path = os.path.join(model_path, model_file)
    if not os.path.isfile(model_file_path):
        return False

    # Download details file and get hash.
    details_file = f"{details_base_url}{model_file}"
    try:
        with urllib.request.urlopen(details_file, timeout=timeout) as response:
            body_bytes = response.read()
    except (OSError, http.client.HTTPException):
        # URLError/HTTPError and socket timeouts are OSError subclasses;
        # HTTPException covers truncated or malformed responses. Anything
        # else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        print("Internet connection not detected. Skipping validation.")
        return True

    details_hash = _parse_lfs_sha256(body_bytes)
    if details_hash is None:
        # The response was not a pointer with a usable hash (for example an
        # error page); there is nothing to compare against.
        print(
            f"No SHA-256 found in details file for '{model_file}'. "
            "Skipping validation."
        )
        return True

    model_hash = _sha256_of_file(model_file_path)

    if details_hash == model_hash:
        print(f"Whisper model '{model_file}' file is valid.")
    else:
        msg = (
            "\n"
            f"The model '{model_file}' did not validate. STT may not function correctly.\n"
            f"The model path is '{model_path}'.\n"
            "Manually download and verify the model's hash to get better functionality.\n"
            "Continuing.\n"
        )
        print(msg)

    # A mismatch is reported but still returns True, so a caller that retries
    # while the result is False does not loop on a file that already exists.
    return True


def _parse_lfs_sha256(pointer_bytes: bytes):
    """Return the lowercase hex digest from an ``oid sha256:`` line, else None."""
    prefix = b"oid sha256:"
    for raw_line in pointer_bytes.splitlines():
        line = raw_line.strip()
        if not line.startswith(prefix):
            continue
        candidate = line[len(prefix):].strip().decode("ascii", errors="replace").lower()
        is_hex = all(ch in "0123456789abcdef" for ch in candidate)
        return candidate if len(candidate) == 64 and is_hex else None
    return None


def _sha256_of_file(path: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the lowercase hex SHA-256 of the file at ``path``, read in chunks."""
    import hashlib

    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # Chunked reads keep memory flat for multi-gigabyte model files.
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
def convert_mime_type_to_format(mime_type: str) -> str:

Loading…
Cancel
Save