diff --git a/swarms/models/whisperx_model.py b/swarms/models/whisperx_model.py deleted file mode 100644 index e3b76fae..00000000 --- a/swarms/models/whisperx_model.py +++ /dev/null @@ -1,138 +0,0 @@ -import os -import subprocess - -try: - import whisperx - from pydub import AudioSegment - from pytube import YouTube -except Exception as error: - print("Error importing pytube. Please install pytube manually.") - print("pip install pytube") - print("pip install pydub") - print("pip install whisperx") - print(f"Pytube error: {error}") - - -class WhisperX: - def __init__( - self, - video_url, - audio_format="mp3", - device="cuda", - batch_size=16, - compute_type="float16", - hf_api_key=None, - ): - """ - # Example usage - video_url = "url" - speech_to_text = WhisperX(video_url) - transcription = speech_to_text.transcribe_youtube_video() - print(transcription) - - """ - self.video_url = video_url - self.audio_format = audio_format - self.device = device - self.batch_size = batch_size - self.compute_type = compute_type - self.hf_api_key = hf_api_key - - def install(self): - subprocess.run(["pip", "install", "whisperx"]) - subprocess.run(["pip", "install", "pytube"]) - subprocess.run(["pip", "install", "pydub"]) - - def download_youtube_video(self): - audio_file = f"video.{self.audio_format}" - - # Download video 📥 - yt = YouTube(self.video_url) - yt_stream = yt.streams.filter(only_audio=True).first() - yt_stream.download(filename="video.mp4") - - # Convert video to audio 🎧 - video = AudioSegment.from_file("video.mp4", format="mp4") - video.export(audio_file, format=self.audio_format) - os.remove("video.mp4") - - return audio_file - - def transcribe_youtube_video(self): - audio_file = self.download_youtube_video() - - device = "cuda" - batch_size = 16 - compute_type = "float16" - - # 1. Transcribe with original Whisper (batched) 🗣️ - model = whisperx.load_model( - "large-v2", device, compute_type=compute_type - ) - audio = whisperx.load_audio(audio_file) - result = model.transcribe(audio, batch_size=batch_size) - - # 2. Align Whisper output 🔍 - model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=device - ) - result = whisperx.align( - result["segments"], - model_a, - metadata, - audio, - device, - return_char_alignments=False, - ) - - # 3. Assign speaker labels 🏷️ - diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=device - ) - diarize_model(audio_file) - - try: - segments = result["segments"] - transcription = " ".join( - segment["text"] for segment in segments - ) - return transcription - except KeyError: - print("The key 'segments' is not found in the result.") - - def transcribe(self, audio_file): - model = whisperx.load_model( - "large-v2", self.device, self.compute_type - ) - audio = whisperx.load_audio(audio_file) - result = model.transcribe(audio, batch_size=self.batch_size) - - # 2. Align Whisper output 🔍 - model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=self.device - ) - - result = whisperx.align( - result["segments"], - model_a, - metadata, - audio, - self.device, - return_char_alignments=False, - ) - - # 3. Assign speaker labels 🏷️ - diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=self.device - ) - - diarize_model(audio_file) - - try: - segments = result["segments"] - transcription = " ".join( - segment["text"] for segment in segments - ) - return transcription - except KeyError: - print("The key 'segments' is not found in the result.") diff --git a/tests/models/test_whisperx.py b/tests/models/test_whisperx.py deleted file mode 100644 index 4b0e4120..00000000 --- a/tests/models/test_whisperx.py +++ /dev/null @@ -1,222 +0,0 @@ -import os -import subprocess -import tempfile -from unittest.mock import patch - -import pytest -import whisperx -from pydub import AudioSegment -from pytube import YouTube -from swarms.models.whisperx_model import WhisperX - - -# Fixture to create a temporary directory for testing -@pytest.fixture -def temp_dir(): - with tempfile.TemporaryDirectory() as tempdir: - yield tempdir - - -# Mock subprocess.run to prevent actual installation during tests -@patch.object(subprocess, "run") -def test_speech_to_text_install(mock_run): - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - stt.install() - mock_run.assert_called_with(["pip", "install", "whisperx"]) - - -# Mock pytube.YouTube and pytube.Streams for download tests -@patch("pytube.YouTube") -@patch.object(YouTube, "streams") -def test_speech_to_text_download_youtube_video( - mock_streams, mock_youtube, temp_dir -): - # Mock YouTube and streams - video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" - mock_stream = mock_streams().filter().first() - mock_stream.download.return_value = os.path.join( - temp_dir, "video.mp4" - ) - mock_youtube.return_value = mock_youtube - mock_youtube.streams = mock_streams - - stt = WhisperX(video_url) - audio_file = stt.download_youtube_video() - - assert os.path.exists(audio_file) - assert audio_file.endswith(".mp3") - - -# Mock whisperx.load_model and whisperx.load_audio for transcribe tests -@patch("whisperx.load_model") -@patch("whisperx.load_audio") -@patch("whisperx.load_align_model") -@patch("whisperx.align") -@patch.object(whisperx.DiarizationPipeline, "__call__") -def test_speech_to_text_transcribe_youtube_video( - mock_diarization, - mock_align, - mock_align_model, - mock_load_audio, - mock_load_model, - temp_dir, -): - # Mock whisperx functions - mock_load_model.return_value = mock_load_model - mock_load_model.transcribe.return_value = { - "language": "en", - "segments": [{"text": "Hello, World!"}], - } - - mock_load_audio.return_value = "audio_path" - mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.return_value = { - "segments": [{"text": "Hello, World!"}] - } - - # Mock diarization pipeline - mock_diarization.return_value = None - - video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video" - stt = WhisperX(video_url) - transcription = stt.transcribe_youtube_video() - - assert transcription == "Hello, World!" - - -# More tests for different scenarios and edge cases can be added here. - - -# Test transcribe method with provided audio file -def test_speech_to_text_transcribe_audio_file(temp_dir): - # Create a temporary audio file - audio_file = os.path.join(temp_dir, "test_audio.mp3") - AudioSegment.silent(duration=500).export(audio_file, format="mp3") - - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - transcription = stt.transcribe(audio_file) - - assert transcription == "" - - -# Test transcribe method when Whisperx fails -@patch("whisperx.load_model") -@patch("whisperx.load_audio") -def test_speech_to_text_transcribe_whisperx_failure( - mock_load_audio, mock_load_model, temp_dir -): - # Mock whisperx functions to raise an exception - mock_load_model.side_effect = Exception("Whisperx failed") - mock_load_audio.return_value = "audio_path" - - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - transcription = stt.transcribe("audio_path") - - assert transcription == "Whisperx failed" - - -# Test transcribe method with missing 'segments' key in Whisperx output -@patch("whisperx.load_model") -@patch("whisperx.load_audio") -@patch("whisperx.load_align_model") -@patch("whisperx.align") -@patch.object(whisperx.DiarizationPipeline, "__call__") -def test_speech_to_text_transcribe_missing_segments( - mock_diarization, - mock_align, - mock_align_model, - mock_load_audio, - mock_load_model, -): - # Mock whisperx functions to return incomplete output - mock_load_model.return_value = mock_load_model - mock_load_model.transcribe.return_value = {"language": "en"} - - mock_load_audio.return_value = "audio_path" - mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.return_value = {} - - # Mock diarization pipeline - mock_diarization.return_value = None - - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - transcription = stt.transcribe("audio_path") - - assert transcription == "" - - -# Test transcribe method with Whisperx align failure -@patch("whisperx.load_model") -@patch("whisperx.load_audio") -@patch("whisperx.load_align_model") -@patch("whisperx.align") -@patch.object(whisperx.DiarizationPipeline, "__call__") -def test_speech_to_text_transcribe_align_failure( - mock_diarization, - mock_align, - mock_align_model, - mock_load_audio, - mock_load_model, -): - # Mock whisperx functions to raise an exception during align - mock_load_model.return_value = mock_load_model - mock_load_model.transcribe.return_value = { - "language": "en", - "segments": [{"text": "Hello, World!"}], - } - - mock_load_audio.return_value = "audio_path" - mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.side_effect = Exception("Align failed") - - # Mock diarization pipeline - mock_diarization.return_value = None - - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - transcription = stt.transcribe("audio_path") - - assert transcription == "Align failed" - - -# Test transcribe_youtube_video when Whisperx diarization fails -@patch("pytube.YouTube") -@patch.object(YouTube, "streams") -@patch("whisperx.DiarizationPipeline") -@patch("whisperx.load_audio") -@patch("whisperx.load_align_model") -@patch("whisperx.align") -def test_speech_to_text_transcribe_diarization_failure( - mock_align, - mock_align_model, - mock_load_audio, - mock_diarization, - mock_streams, - mock_youtube, - temp_dir, -): - # Mock YouTube and streams - video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" - mock_stream = mock_streams().filter().first() - mock_stream.download.return_value = os.path.join( - temp_dir, "video.mp4" - ) - mock_youtube.return_value = mock_youtube - mock_youtube.streams = mock_streams - - # Mock whisperx functions - mock_load_audio.return_value = "audio_path" - mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.return_value = { - "segments": [{"text": "Hello, World!"}] - } - - # Mock diarization pipeline to raise an exception - mock_diarization.side_effect = Exception("Diarization failed") - - stt = WhisperX(video_url) - transcription = stt.transcribe_youtube_video() - - assert transcription == "Diarization failed" - - -# Add more tests for other scenarios and edge cases as needed.