parent
7cd6f25353
commit
d33588becc
@ -1,138 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
import whisperx
|
||||
from pydub import AudioSegment
|
||||
from pytube import YouTube
|
||||
except Exception as error:
|
||||
print("Error importing pytube. Please install pytube manually.")
|
||||
print("pip install pytube")
|
||||
print("pip install pydub")
|
||||
print("pip install whisperx")
|
||||
print(f"Pytube error: {error}")
|
||||
|
||||
|
||||
class WhisperX:
|
||||
def __init__(
|
||||
self,
|
||||
video_url,
|
||||
audio_format="mp3",
|
||||
device="cuda",
|
||||
batch_size=16,
|
||||
compute_type="float16",
|
||||
hf_api_key=None,
|
||||
):
|
||||
"""
|
||||
# Example usage
|
||||
video_url = "url"
|
||||
speech_to_text = WhisperX(video_url)
|
||||
transcription = speech_to_text.transcribe_youtube_video()
|
||||
print(transcription)
|
||||
|
||||
"""
|
||||
self.video_url = video_url
|
||||
self.audio_format = audio_format
|
||||
self.device = device
|
||||
self.batch_size = batch_size
|
||||
self.compute_type = compute_type
|
||||
self.hf_api_key = hf_api_key
|
||||
|
||||
def install(self):
|
||||
subprocess.run(["pip", "install", "whisperx"])
|
||||
subprocess.run(["pip", "install", "pytube"])
|
||||
subprocess.run(["pip", "install", "pydub"])
|
||||
|
||||
def download_youtube_video(self):
|
||||
audio_file = f"video.{self.audio_format}"
|
||||
|
||||
# Download video 📥
|
||||
yt = YouTube(self.video_url)
|
||||
yt_stream = yt.streams.filter(only_audio=True).first()
|
||||
yt_stream.download(filename="video.mp4")
|
||||
|
||||
# Convert video to audio 🎧
|
||||
video = AudioSegment.from_file("video.mp4", format="mp4")
|
||||
video.export(audio_file, format=self.audio_format)
|
||||
os.remove("video.mp4")
|
||||
|
||||
return audio_file
|
||||
|
||||
def transcribe_youtube_video(self):
|
||||
audio_file = self.download_youtube_video()
|
||||
|
||||
device = "cuda"
|
||||
batch_size = 16
|
||||
compute_type = "float16"
|
||||
|
||||
# 1. Transcribe with original Whisper (batched) 🗣️
|
||||
model = whisperx.load_model(
|
||||
"large-v2", device, compute_type=compute_type
|
||||
)
|
||||
audio = whisperx.load_audio(audio_file)
|
||||
result = model.transcribe(audio, batch_size=batch_size)
|
||||
|
||||
# 2. Align Whisper output 🔍
|
||||
model_a, metadata = whisperx.load_align_model(
|
||||
language_code=result["language"], device=device
|
||||
)
|
||||
result = whisperx.align(
|
||||
result["segments"],
|
||||
model_a,
|
||||
metadata,
|
||||
audio,
|
||||
device,
|
||||
return_char_alignments=False,
|
||||
)
|
||||
|
||||
# 3. Assign speaker labels 🏷️
|
||||
diarize_model = whisperx.DiarizationPipeline(
|
||||
use_auth_token=self.hf_api_key, device=device
|
||||
)
|
||||
diarize_model(audio_file)
|
||||
|
||||
try:
|
||||
segments = result["segments"]
|
||||
transcription = " ".join(
|
||||
segment["text"] for segment in segments
|
||||
)
|
||||
return transcription
|
||||
except KeyError:
|
||||
print("The key 'segments' is not found in the result.")
|
||||
|
||||
def transcribe(self, audio_file):
|
||||
model = whisperx.load_model(
|
||||
"large-v2", self.device, self.compute_type
|
||||
)
|
||||
audio = whisperx.load_audio(audio_file)
|
||||
result = model.transcribe(audio, batch_size=self.batch_size)
|
||||
|
||||
# 2. Align Whisper output 🔍
|
||||
model_a, metadata = whisperx.load_align_model(
|
||||
language_code=result["language"], device=self.device
|
||||
)
|
||||
|
||||
result = whisperx.align(
|
||||
result["segments"],
|
||||
model_a,
|
||||
metadata,
|
||||
audio,
|
||||
self.device,
|
||||
return_char_alignments=False,
|
||||
)
|
||||
|
||||
# 3. Assign speaker labels 🏷️
|
||||
diarize_model = whisperx.DiarizationPipeline(
|
||||
use_auth_token=self.hf_api_key, device=self.device
|
||||
)
|
||||
|
||||
diarize_model(audio_file)
|
||||
|
||||
try:
|
||||
segments = result["segments"]
|
||||
transcription = " ".join(
|
||||
segment["text"] for segment in segments
|
||||
)
|
||||
return transcription
|
||||
except KeyError:
|
||||
print("The key 'segments' is not found in the result.")
|
@ -1,222 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import whisperx
|
||||
from pydub import AudioSegment
|
||||
from pytube import YouTube
|
||||
from swarms.models.whisperx_model import WhisperX
|
||||
|
||||
|
||||
# Fixture to create a temporary directory for testing
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
yield tempdir
|
||||
|
||||
|
||||
# Mock subprocess.run to prevent actual installation during tests
|
||||
@patch.object(subprocess, "run")
|
||||
def test_speech_to_text_install(mock_run):
|
||||
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
||||
stt.install()
|
||||
mock_run.assert_called_with(["pip", "install", "whisperx"])
|
||||
|
||||
|
||||
# Mock pytube.YouTube and pytube.Streams for download tests
|
||||
@patch("pytube.YouTube")
|
||||
@patch.object(YouTube, "streams")
|
||||
def test_speech_to_text_download_youtube_video(
|
||||
mock_streams, mock_youtube, temp_dir
|
||||
):
|
||||
# Mock YouTube and streams
|
||||
video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM"
|
||||
mock_stream = mock_streams().filter().first()
|
||||
mock_stream.download.return_value = os.path.join(
|
||||
temp_dir, "video.mp4"
|
||||
)
|
||||
mock_youtube.return_value = mock_youtube
|
||||
mock_youtube.streams = mock_streams
|
||||
|
||||
stt = WhisperX(video_url)
|
||||
audio_file = stt.download_youtube_video()
|
||||
|
||||
assert os.path.exists(audio_file)
|
||||
assert audio_file.endswith(".mp3")
|
||||
|
||||
|
||||
# Mock whisperx.load_model and whisperx.load_audio for transcribe tests
|
||||
@patch("whisperx.load_model")
|
||||
@patch("whisperx.load_audio")
|
||||
@patch("whisperx.load_align_model")
|
||||
@patch("whisperx.align")
|
||||
@patch.object(whisperx.DiarizationPipeline, "__call__")
|
||||
def test_speech_to_text_transcribe_youtube_video(
|
||||
mock_diarization,
|
||||
mock_align,
|
||||
mock_align_model,
|
||||
mock_load_audio,
|
||||
mock_load_model,
|
||||
temp_dir,
|
||||
):
|
||||
# Mock whisperx functions
|
||||
mock_load_model.return_value = mock_load_model
|
||||
mock_load_model.transcribe.return_value = {
|
||||
"language": "en",
|
||||
"segments": [{"text": "Hello, World!"}],
|
||||
}
|
||||
|
||||
mock_load_audio.return_value = "audio_path"
|
||||
mock_align_model.return_value = (mock_align_model, "metadata")
|
||||
mock_align.return_value = {
|
||||
"segments": [{"text": "Hello, World!"}]
|
||||
}
|
||||
|
||||
# Mock diarization pipeline
|
||||
mock_diarization.return_value = None
|
||||
|
||||
video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video"
|
||||
stt = WhisperX(video_url)
|
||||
transcription = stt.transcribe_youtube_video()
|
||||
|
||||
assert transcription == "Hello, World!"
|
||||
|
||||
|
||||
# More tests for different scenarios and edge cases can be added here.
|
||||
|
||||
|
||||
# Test transcribe method with provided audio file
|
||||
def test_speech_to_text_transcribe_audio_file(temp_dir):
|
||||
# Create a temporary audio file
|
||||
audio_file = os.path.join(temp_dir, "test_audio.mp3")
|
||||
AudioSegment.silent(duration=500).export(audio_file, format="mp3")
|
||||
|
||||
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
||||
transcription = stt.transcribe(audio_file)
|
||||
|
||||
assert transcription == ""
|
||||
|
||||
|
||||
# Test transcribe method when Whisperx fails
|
||||
@patch("whisperx.load_model")
|
||||
@patch("whisperx.load_audio")
|
||||
def test_speech_to_text_transcribe_whisperx_failure(
|
||||
mock_load_audio, mock_load_model, temp_dir
|
||||
):
|
||||
# Mock whisperx functions to raise an exception
|
||||
mock_load_model.side_effect = Exception("Whisperx failed")
|
||||
mock_load_audio.return_value = "audio_path"
|
||||
|
||||
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
||||
transcription = stt.transcribe("audio_path")
|
||||
|
||||
assert transcription == "Whisperx failed"
|
||||
|
||||
|
||||
# Test transcribe method with missing 'segments' key in Whisperx output
|
||||
@patch("whisperx.load_model")
|
||||
@patch("whisperx.load_audio")
|
||||
@patch("whisperx.load_align_model")
|
||||
@patch("whisperx.align")
|
||||
@patch.object(whisperx.DiarizationPipeline, "__call__")
|
||||
def test_speech_to_text_transcribe_missing_segments(
|
||||
mock_diarization,
|
||||
mock_align,
|
||||
mock_align_model,
|
||||
mock_load_audio,
|
||||
mock_load_model,
|
||||
):
|
||||
# Mock whisperx functions to return incomplete output
|
||||
mock_load_model.return_value = mock_load_model
|
||||
mock_load_model.transcribe.return_value = {"language": "en"}
|
||||
|
||||
mock_load_audio.return_value = "audio_path"
|
||||
mock_align_model.return_value = (mock_align_model, "metadata")
|
||||
mock_align.return_value = {}
|
||||
|
||||
# Mock diarization pipeline
|
||||
mock_diarization.return_value = None
|
||||
|
||||
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
||||
transcription = stt.transcribe("audio_path")
|
||||
|
||||
assert transcription == ""
|
||||
|
||||
|
||||
# Test transcribe method with Whisperx align failure
|
||||
@patch("whisperx.load_model")
|
||||
@patch("whisperx.load_audio")
|
||||
@patch("whisperx.load_align_model")
|
||||
@patch("whisperx.align")
|
||||
@patch.object(whisperx.DiarizationPipeline, "__call__")
|
||||
def test_speech_to_text_transcribe_align_failure(
|
||||
mock_diarization,
|
||||
mock_align,
|
||||
mock_align_model,
|
||||
mock_load_audio,
|
||||
mock_load_model,
|
||||
):
|
||||
# Mock whisperx functions to raise an exception during align
|
||||
mock_load_model.return_value = mock_load_model
|
||||
mock_load_model.transcribe.return_value = {
|
||||
"language": "en",
|
||||
"segments": [{"text": "Hello, World!"}],
|
||||
}
|
||||
|
||||
mock_load_audio.return_value = "audio_path"
|
||||
mock_align_model.return_value = (mock_align_model, "metadata")
|
||||
mock_align.side_effect = Exception("Align failed")
|
||||
|
||||
# Mock diarization pipeline
|
||||
mock_diarization.return_value = None
|
||||
|
||||
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
||||
transcription = stt.transcribe("audio_path")
|
||||
|
||||
assert transcription == "Align failed"
|
||||
|
||||
|
||||
# Test transcribe_youtube_video when Whisperx diarization fails
|
||||
@patch("pytube.YouTube")
|
||||
@patch.object(YouTube, "streams")
|
||||
@patch("whisperx.DiarizationPipeline")
|
||||
@patch("whisperx.load_audio")
|
||||
@patch("whisperx.load_align_model")
|
||||
@patch("whisperx.align")
|
||||
def test_speech_to_text_transcribe_diarization_failure(
|
||||
mock_align,
|
||||
mock_align_model,
|
||||
mock_load_audio,
|
||||
mock_diarization,
|
||||
mock_streams,
|
||||
mock_youtube,
|
||||
temp_dir,
|
||||
):
|
||||
# Mock YouTube and streams
|
||||
video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM"
|
||||
mock_stream = mock_streams().filter().first()
|
||||
mock_stream.download.return_value = os.path.join(
|
||||
temp_dir, "video.mp4"
|
||||
)
|
||||
mock_youtube.return_value = mock_youtube
|
||||
mock_youtube.streams = mock_streams
|
||||
|
||||
# Mock whisperx functions
|
||||
mock_load_audio.return_value = "audio_path"
|
||||
mock_align_model.return_value = (mock_align_model, "metadata")
|
||||
mock_align.return_value = {
|
||||
"segments": [{"text": "Hello, World!"}]
|
||||
}
|
||||
|
||||
# Mock diarization pipeline to raise an exception
|
||||
mock_diarization.side_effect = Exception("Diarization failed")
|
||||
|
||||
stt = WhisperX(video_url)
|
||||
transcription = stt.transcribe_youtube_video()
|
||||
|
||||
assert transcription == "Diarization failed"
|
||||
|
||||
|
||||
# Add more tests for other scenarios and edge cases as needed.
|
Loading…
Reference in new issue