parent
7cd6f25353
commit
d33588becc
@ -1,138 +0,0 @@
|
|||||||
import os
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
try:
|
|
||||||
import whisperx
|
|
||||||
from pydub import AudioSegment
|
|
||||||
from pytube import YouTube
|
|
||||||
except Exception as error:
|
|
||||||
print("Error importing pytube. Please install pytube manually.")
|
|
||||||
print("pip install pytube")
|
|
||||||
print("pip install pydub")
|
|
||||||
print("pip install whisperx")
|
|
||||||
print(f"Pytube error: {error}")
|
|
||||||
|
|
||||||
|
|
||||||
class WhisperX:
    """Transcribe a YouTube video or a local audio file with WhisperX.

    Example usage::

        video_url = "url"
        speech_to_text = WhisperX(video_url)
        transcription = speech_to_text.transcribe_youtube_video()
        print(transcription)
    """

    def __init__(
        self,
        video_url,
        audio_format="mp3",
        device="cuda",
        batch_size=16,
        compute_type="float16",
        hf_api_key=None,
    ):
        """Store the configuration used by the download and transcription steps.

        Args:
            video_url: URL handed to ``pytube.YouTube`` when downloading.
            audio_format: Extension/format of the exported audio file.
            device: Device string forwarded to every whisperx call.
            batch_size: Batch size forwarded to ``model.transcribe``.
            compute_type: Compute type forwarded to ``whisperx.load_model``.
            hf_api_key: Token forwarded to ``whisperx.DiarizationPipeline``
                as ``use_auth_token``.
        """
        self.video_url = video_url
        self.audio_format = audio_format
        self.device = device
        self.batch_size = batch_size
        self.compute_type = compute_type
        self.hf_api_key = hf_api_key

    def install(self):
        """Run ``pip install`` for whisperx, pytube and pydub, in that order."""
        for package in ("whisperx", "pytube", "pydub"):
            subprocess.run(["pip", "install", package])

    def download_youtube_video(self):
        """Download the video's audio stream and convert it to ``audio_format``.

        The stream is saved as ``video.mp4`` in the current working directory,
        converted with pydub, and the intermediate file is removed afterwards.

        Returns:
            The name of the exported audio file (``video.<audio_format>``).

        Raises:
            ValueError: If the video exposes no audio-only stream.
        """
        audio_file = f"video.{self.audio_format}"
        video_file = "video.mp4"

        # Download video 📥
        yt = YouTube(self.video_url)
        yt_stream = yt.streams.filter(only_audio=True).first()
        if yt_stream is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(
                f"No audio-only stream found for {self.video_url}"
            )
        yt_stream.download(filename=video_file)

        # Convert video to audio 🎧
        try:
            video = AudioSegment.from_file(video_file, format="mp4")
            video.export(audio_file, format=self.audio_format)
        finally:
            # Always drop the intermediate download, even if conversion fails.
            os.remove(video_file)

        return audio_file

    def transcribe_youtube_video(self):
        """Download ``video_url`` and return its transcription.

        Uses the device, batch size and compute type given to the constructor
        (the defaults match the values that used to be hard-coded here).

        Returns:
            Whatever :meth:`transcribe` returns for the downloaded audio file.
        """
        audio_file = self.download_youtube_video()
        return self.transcribe(audio_file)

    def transcribe(self, audio_file):
        """Transcribe ``audio_file`` and return the joined segment texts.

        Args:
            audio_file: Path passed to ``whisperx.load_audio`` and to the
                diarization pipeline.

        Returns:
            The segment texts joined by single spaces, or ``None`` (after
            printing a message) when the aligned result has no ``"segments"``
            key.
        """
        # 1. Transcribe with original Whisper (batched) 🗣️
        # compute_type is passed by keyword so it cannot land in a different
        # positional slot of load_model.
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Assign speaker labels 🏷️
        # NOTE(review): the diarization output is discarded, so the returned
        # text carries no speaker labels. Kept to preserve existing behaviour.
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        diarize_model(audio_file)

        return self._join_segments(result)

    @staticmethod
    def _join_segments(result):
        """Join the ``"text"`` of every segment in ``result`` with spaces.

        Returns ``None`` after printing a message when a required key is
        missing, matching the previous implicit behaviour.
        """
        try:
            return " ".join(
                segment["text"] for segment in result["segments"]
            )
        except KeyError:
            print("The key 'segments' is not found in the result.")
            return None
|
|
@ -1,222 +0,0 @@
|
|||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import whisperx
|
|
||||||
from pydub import AudioSegment
|
|
||||||
from pytube import YouTube
|
|
||||||
from swarms.models.whisperx_model import WhisperX
|
|
||||||
|
|
||||||
|
|
||||||
# Per-test scratch directory
@pytest.fixture
def temp_dir():
    """Yield the path of a temporary directory that exists only for the test."""
    with tempfile.TemporaryDirectory() as scratch_path:
        yield scratch_path
|
|
||||||
|
|
||||||
|
|
||||||
# Mock subprocess.run to prevent actual installation during tests
@patch.object(subprocess, "run")
def test_speech_to_text_install(mock_run):
    """install() must issue one ``pip install`` command per required package."""
    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    stt.install()

    # assert_called_with only inspects the most recent call (the pydub
    # install), so every expected command is checked explicitly instead.
    for package in ("whisperx", "pytube", "pydub"):
        mock_run.assert_any_call(["pip", "install", package])
    assert mock_run.call_count == 3
|
|
||||||
|
|
||||||
|
|
||||||
# Patch YouTube and AudioSegment where the module under test looks them up.
# It binds both with ``from ... import``, so patching ``pytube.YouTube``
# would leave the real class in use.
@patch("os.remove")
@patch("swarms.models.whisperx_model.AudioSegment")
@patch("swarms.models.whisperx_model.YouTube")
def test_speech_to_text_download_youtube_video(
    mock_youtube, mock_audio_segment, mock_remove
):
    """download_youtube_video() downloads, converts and cleans up.

    No network access, ffmpeg or real files are involved: every external
    collaborator is replaced by a mock and only the calls are verified.
    """
    video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM"
    mock_streams = mock_youtube.return_value.streams
    mock_stream = mock_streams.filter.return_value.first.return_value

    stt = WhisperX(video_url)
    audio_file = stt.download_youtube_video()

    mock_youtube.assert_called_once_with(video_url)
    mock_streams.filter.assert_called_once_with(only_audio=True)
    mock_stream.download.assert_called_once_with(filename="video.mp4")
    mock_audio_segment.from_file.assert_called_once_with(
        "video.mp4", format="mp4"
    )
    mock_audio_segment.from_file.return_value.export.assert_called_once_with(
        audio_file, format="mp3"
    )
    # The intermediate download must be removed after conversion.
    mock_remove.assert_called_once_with("video.mp4")
    assert audio_file.endswith(".mp3")
|
|
||||||
|
|
||||||
|
|
||||||
# Mock every whisperx entry point plus the download step so the test is
# hermetic. DiarizationPipeline is patched as a whole class because patching
# only ``__call__`` would still run its real constructor.
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch("whisperx.DiarizationPipeline")
@patch.object(WhisperX, "download_youtube_video", return_value="video.mp3")
def test_speech_to_text_transcribe_youtube_video(
    mock_download,
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
):
    """transcribe_youtube_video() returns the joined text of aligned segments."""
    # Mock whisperx functions
    mock_load_model.return_value.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {"segments": [{"text": "Hello, World!"}]}

    video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video"
    stt = WhisperX(video_url)
    transcription = stt.transcribe_youtube_video()

    assert transcription == "Hello, World!"
    mock_download.assert_called_once_with()
    # The diarization pipeline is invoked on the downloaded audio file.
    mock_diarization.return_value.assert_called_once_with("video.mp3")
|
|
||||||
|
|
||||||
|
|
||||||
# More tests for different scenarios and edge cases can be added here.
|
|
||||||
|
|
||||||
|
|
||||||
# Transcribing a caller-supplied audio file
def test_speech_to_text_transcribe_audio_file(temp_dir):
    """Half a second of silence should transcribe to an empty string.

    NOTE(review): nothing is mocked here, so this runs the real transcription
    pipeline. It presumably needs the model weights and a usable device;
    confirm before relying on it in CI.
    """
    # Build a short silent clip inside the scratch directory.
    audio_file = os.path.join(temp_dir, "test_audio.mp3")
    silence = AudioSegment.silent(duration=500)
    silence.export(audio_file, format="mp3")

    transcription = WhisperX(
        "https://www.youtube.com/watch?v=MJd6pr16LRM"
    ).transcribe(audio_file)

    assert transcription == ""
|
|
||||||
|
|
||||||
|
|
||||||
# Test transcribe method when Whisperx fails
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
def test_speech_to_text_transcribe_whisperx_failure(
    mock_load_audio, mock_load_model
):
    """A failure while loading the model propagates out of transcribe()."""
    # Mock whisperx functions to raise an exception
    mock_load_model.side_effect = Exception("Whisperx failed")
    mock_load_audio.return_value = "audio_path"

    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")

    # transcribe() does not catch this error, so it must surface to the
    # caller rather than come back as a return value.
    with pytest.raises(Exception, match="Whisperx failed"):
        stt.transcribe("audio_path")
|
|
||||||
|
|
||||||
|
|
||||||
# Test transcribe method with missing 'segments' key in the aligned output
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch("whisperx.DiarizationPipeline")
def test_speech_to_text_transcribe_missing_segments(
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
    capsys,
):
    """An aligned result without 'segments' is reported and yields None."""
    # The raw transcription must contain "segments" because transcribe()
    # indexes it before calling align; only the aligned result is incomplete.
    mock_load_model.return_value.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {}

    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    transcription = stt.transcribe("audio_path")

    # The KeyError is handled by printing a message; no text is returned.
    assert transcription is None
    assert "segments" in capsys.readouterr().out
|
|
||||||
|
|
||||||
|
|
||||||
# Test transcribe method with Whisperx align failure
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch("whisperx.DiarizationPipeline")
def test_speech_to_text_transcribe_align_failure(
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
):
    """A failure inside whisperx.align propagates out of transcribe()."""
    # Mock whisperx functions to raise an exception during align
    mock_load_model.return_value.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.side_effect = Exception("Align failed")

    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")

    with pytest.raises(Exception, match="Align failed"):
        stt.transcribe("audio_path")

    # Alignment failed first, so diarization must never have been started.
    mock_diarization.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
# Test transcribe_youtube_video when Whisperx diarization fails
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch("whisperx.DiarizationPipeline")
@patch.object(WhisperX, "download_youtube_video", return_value="video.mp3")
def test_speech_to_text_transcribe_diarization_failure(
    mock_download,
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
):
    """A diarization failure propagates out of transcribe_youtube_video()."""
    # Mock whisperx functions. load_model is patched too, so no real model
    # is loaded before the diarization step is reached.
    mock_load_model.return_value.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {"segments": [{"text": "Hello, World!"}]}

    # Mock diarization pipeline to raise an exception when constructed
    mock_diarization.side_effect = Exception("Diarization failed")

    stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")

    with pytest.raises(Exception, match="Diarization failed"):
        stt.transcribe_youtube_video()

    mock_download.assert_called_once_with()
|
|
||||||
|
|
||||||
|
|
||||||
# Add more tests for other scenarios and edge cases as needed.
|
|
Loading…
Reference in new issue