[CLEANUP][WhisperX]

pull/334/head
Kye 1 year ago
parent 7cd6f25353
commit d33588becc

@ -1,138 +0,0 @@
import os
import subprocess
try:
import whisperx
from pydub import AudioSegment
from pytube import YouTube
except Exception as error:
print("Error importing pytube. Please install pytube manually.")
print("pip install pytube")
print("pip install pydub")
print("pip install whisperx")
print(f"Pytube error: {error}")
class WhisperX:
    """Transcribe YouTube videos or local audio files with WhisperX.

    Args:
        video_url: URL of the YouTube video to download and transcribe.
        audio_format: Audio format (and file extension) the downloaded
            stream is converted to. Defaults to ``"mp3"``.
        device: Device string handed to every WhisperX call.
        batch_size: Batch size used for the transcription pass.
        compute_type: Compute type used when loading the Whisper model.
        hf_api_key: Token forwarded to the diarization pipeline as
            ``use_auth_token``.

    Example:
        video_url = "url"
        speech_to_text = WhisperX(video_url)
        transcription = speech_to_text.transcribe_youtube_video()
        print(transcription)
    """

    def __init__(
        self,
        video_url,
        audio_format="mp3",
        device="cuda",
        batch_size=16,
        compute_type="float16",
        hf_api_key=None,
    ):
        self.video_url = video_url
        self.audio_format = audio_format
        self.device = device
        self.batch_size = batch_size
        self.compute_type = compute_type
        self.hf_api_key = hf_api_key

    def install(self):
        """Install the third-party packages this class relies on via pip."""
        for package in ("whisperx", "pytube", "pydub"):
            subprocess.run(["pip", "install", package])

    def download_youtube_video(self):
        """Download the video's audio stream and convert it.

        Returns:
            The name of the converted audio file, ``video.<audio_format>``,
            written to the current working directory.

        Raises:
            ValueError: If the video exposes no audio-only stream.
        """
        audio_file = f"video.{self.audio_format}"
        downloaded_file = "video.mp4"

        # Download video 📥
        yt = YouTube(self.video_url)
        yt_stream = yt.streams.filter(only_audio=True).first()
        if yt_stream is None:
            raise ValueError(
                f"No audio-only stream found for {self.video_url!r}"
            )
        yt_stream.download(filename=downloaded_file)

        # Convert video to audio 🎧
        try:
            video = AudioSegment.from_file(downloaded_file, format="mp4")
            video.export(audio_file, format=self.audio_format)
        finally:
            # Always drop the intermediate download, even when conversion
            # fails, but never delete the file we are about to return
            # (the two names coincide when audio_format == "mp4").
            if audio_file != downloaded_file and os.path.exists(
                downloaded_file
            ):
                os.remove(downloaded_file)
        return audio_file

    def transcribe_youtube_video(self):
        """Download ``self.video_url`` and return its transcription.

        Uses the device, batch size and compute type configured on the
        instance; see :meth:`transcribe` for the return value.
        """
        audio_file = self.download_youtube_video()
        return self.transcribe(audio_file)

    def transcribe(self, audio_file):
        """Transcribe, align and diarize ``audio_file``.

        Args:
            audio_file: Path of the audio file to transcribe.

        Returns:
            The text of all aligned segments joined with single spaces, or
            ``None`` (after printing a message) when a ``KeyError`` occurs
            while reading the aligned segments.
        """
        # 1. Transcribe with original Whisper (batched) 🗣️
        # compute_type is passed by keyword: in some WhisperX releases the
        # third positional parameter of load_model is device_index.
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Assign speaker labels 🏷️
        # NOTE(review): the diarization output is currently discarded; the
        # returned transcription carries no speaker labels.
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        diarize_model(audio_file)

        try:
            segments = result["segments"]
            transcription = " ".join(
                segment["text"] for segment in segments
            )
            return transcription
        except KeyError:
            print("The key 'segments' is not found in the result.")
            return None

@ -1,222 +0,0 @@
import os
import subprocess
import tempfile
from unittest.mock import patch
import pytest
import whisperx
from pydub import AudioSegment
from pytube import YouTube
from swarms.models.whisperx_model import WhisperX
# Fixture handing each test its own throwaway directory
@pytest.fixture
def temp_dir():
    """Yield the path of a temporary directory, removed after the test."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield scratch.name
    finally:
        scratch.cleanup()
# Mock subprocess.run to prevent actual installation during tests
@patch.object(subprocess, "run")
def test_speech_to_text_install(mock_run):
    """install() issues exactly one ``pip install`` per required package."""
    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    stt.install()
    # assert_called_with only inspects the most recent call, so it cannot be
    # used to check the first of several installs; check every call instead.
    for package in ("whisperx", "pytube", "pydub"):
        mock_run.assert_any_call(["pip", "install", package])
    assert mock_run.call_count == 3
# Mock the YouTube client and the audio converter for download tests.
# The names are patched where the module under test looks them up
# (it uses ``from pytube import YouTube``), not on the pytube package.
@patch("os.remove")
@patch("swarms.models.whisperx_model.AudioSegment")
@patch("swarms.models.whisperx_model.YouTube")
def test_speech_to_text_download_youtube_video(
    mock_youtube, mock_audio_segment, mock_remove, temp_dir
):
    """download_youtube_video() fetches the audio stream and converts it."""
    video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM"
    mock_stream = (
        mock_youtube.return_value.streams.filter.return_value.first.return_value
    )
    mock_stream.download.return_value = os.path.join(
        temp_dir, "video.mp4"
    )

    stt = WhisperX(video_url)
    audio_file = stt.download_youtube_video()

    mock_youtube.assert_called_once_with(video_url)
    mock_youtube.return_value.streams.filter.assert_called_once_with(
        only_audio=True
    )
    mock_stream.download.assert_called_once_with(filename="video.mp4")
    mock_audio_segment.from_file.assert_called_once_with(
        "video.mp4", format="mp4"
    )
    # Nothing is written to disk because the converter is mocked, so the
    # export call is verified instead of the file's existence.
    mock_audio_segment.from_file.return_value.export.assert_called_once_with(
        audio_file, format="mp3"
    )
    assert audio_file.endswith(".mp3")
# Mock the whisperx entry points and the download step for transcribe tests
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch("whisperx.DiarizationPipeline")
def test_speech_to_text_transcribe_youtube_video(
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
    temp_dir,
):
    """transcribe_youtube_video() joins the text of the aligned segments."""
    # Mock whisperx functions
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {
        "segments": [{"text": "Hello, World!"}]
    }
    # The whole DiarizationPipeline class is mocked (not just __call__) so
    # that constructing it does not run the real initializer.
    mock_diarization.return_value.return_value = None

    video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video"
    audio_path = os.path.join(temp_dir, "video.mp3")
    stt = WhisperX(video_url)
    # Stub the download so the test never touches the network.
    with patch.object(
        WhisperX, "download_youtube_video", return_value=audio_path
    ) as mock_download:
        transcription = stt.transcribe_youtube_video()

    mock_download.assert_called_once_with()
    mock_load_audio.assert_called_once_with(audio_path)
    assert transcription == "Hello, World!"
# More tests for different scenarios and edge cases can be added here.
# Test transcribe method with provided audio file
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch("whisperx.DiarizationPipeline")
def test_speech_to_text_transcribe_audio_file(
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
    temp_dir,
):
    """transcribe() returns an empty string when no segments are produced."""
    # whisperx is mocked so the test does not load a real model; the audio
    # file therefore only needs a path, not real content.
    audio_file = os.path.join(temp_dir, "test_audio.mp3")
    mock_load_model.return_value.transcribe.return_value = {
        "language": "en",
        "segments": [],
    }
    mock_load_audio.return_value = "audio_data"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {"segments": []}

    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    transcription = stt.transcribe(audio_file)

    mock_load_audio.assert_called_once_with(audio_file)
    assert transcription == ""
# Test transcribe method when Whisperx fails
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
def test_speech_to_text_transcribe_whisperx_failure(
    mock_load_audio, mock_load_model, temp_dir
):
    """A model-loading failure propagates out of transcribe()."""
    # Mock whisperx functions to raise an exception
    mock_load_model.side_effect = Exception("Whisperx failed")
    mock_load_audio.return_value = "audio_path"
    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    # transcribe() does not catch this error, so it is raised to the caller
    # rather than returned as the transcription.
    with pytest.raises(Exception, match="Whisperx failed"):
        stt.transcribe("audio_path")
# Test transcribe method with missing 'segments' key in the aligned output
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch("whisperx.DiarizationPipeline")
def test_speech_to_text_transcribe_missing_segments(
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
):
    """transcribe() returns None when the aligned result lacks 'segments'."""
    mock_load_model.return_value = mock_load_model
    # The raw transcription must contain 'segments', otherwise transcribe()
    # raises KeyError while building the align() arguments and never
    # reaches the branch under test.
    mock_load_model.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    # Incomplete alignment output: no 'segments' key
    mock_align.return_value = {}
    # Mock diarization pipeline
    mock_diarization.return_value.return_value = None
    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    transcription = stt.transcribe("audio_path")
    # The handled KeyError branch prints a message and yields no value.
    assert transcription is None
# Test transcribe method with Whisperx align failure
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch("whisperx.DiarizationPipeline")
def test_speech_to_text_transcribe_align_failure(
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
):
    """An alignment failure propagates out of transcribe()."""
    # Mock whisperx functions to raise an exception during align
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.side_effect = Exception("Align failed")
    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    # transcribe() does not catch this error, so it is raised to the caller
    # rather than returned as the transcription.
    with pytest.raises(Exception, match="Align failed"):
        stt.transcribe("audio_path")
    # Alignment runs before diarization, so the pipeline is never built.
    mock_diarization.assert_not_called()
# Test transcribe_youtube_video when Whisperx diarization fails
@patch("whisperx.load_model")
@patch("whisperx.DiarizationPipeline")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
def test_speech_to_text_transcribe_diarization_failure(
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_diarization,
    mock_load_model,
    temp_dir,
):
    """A diarization failure propagates out of transcribe_youtube_video()."""
    video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM"
    audio_path = os.path.join(temp_dir, "video.mp3")
    # Mock whisperx functions (load_model included, so no real model loads)
    mock_load_model.return_value.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {
        "segments": [{"text": "Hello, World!"}]
    }
    # Mock diarization pipeline to raise an exception
    mock_diarization.side_effect = Exception("Diarization failed")
    stt = WhisperX(video_url)
    # Stub the download so the test never touches the network; the error
    # is not caught by the class, so it is raised to the caller.
    with patch.object(
        WhisperX, "download_youtube_video", return_value=audio_path
    ):
        with pytest.raises(Exception, match="Diarization failed"):
            stt.transcribe_youtube_video()
# Add more tests for other scenarios and edge cases as needed.
Loading…
Cancel
Save