parent
							
								
									7cd6f25353
								
							
						
					
					
						commit
						d33588becc
					
				| @ -1,138 +0,0 @@ | |||||||
| import os |  | ||||||
| import subprocess |  | ||||||
| 
 |  | ||||||
| try: |  | ||||||
|     import whisperx |  | ||||||
|     from pydub import AudioSegment |  | ||||||
|     from pytube import YouTube |  | ||||||
| except Exception as error: |  | ||||||
|     print("Error importing pytube. Please install pytube manually.") |  | ||||||
|     print("pip install pytube") |  | ||||||
|     print("pip install pydub") |  | ||||||
|     print("pip install whisperx") |  | ||||||
|     print(f"Pytube error: {error}") |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class WhisperX: |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         video_url, |  | ||||||
|         audio_format="mp3", |  | ||||||
|         device="cuda", |  | ||||||
|         batch_size=16, |  | ||||||
|         compute_type="float16", |  | ||||||
|         hf_api_key=None, |  | ||||||
|     ): |  | ||||||
|         """ |  | ||||||
|         # Example usage |  | ||||||
|         video_url = "url" |  | ||||||
|         speech_to_text = WhisperX(video_url) |  | ||||||
|         transcription = speech_to_text.transcribe_youtube_video() |  | ||||||
|         print(transcription) |  | ||||||
| 
 |  | ||||||
|         """ |  | ||||||
|         self.video_url = video_url |  | ||||||
|         self.audio_format = audio_format |  | ||||||
|         self.device = device |  | ||||||
|         self.batch_size = batch_size |  | ||||||
|         self.compute_type = compute_type |  | ||||||
|         self.hf_api_key = hf_api_key |  | ||||||
| 
 |  | ||||||
|     def install(self): |  | ||||||
|         subprocess.run(["pip", "install", "whisperx"]) |  | ||||||
|         subprocess.run(["pip", "install", "pytube"]) |  | ||||||
|         subprocess.run(["pip", "install", "pydub"]) |  | ||||||
| 
 |  | ||||||
|     def download_youtube_video(self): |  | ||||||
|         audio_file = f"video.{self.audio_format}" |  | ||||||
| 
 |  | ||||||
|         # Download video 📥 |  | ||||||
|         yt = YouTube(self.video_url) |  | ||||||
|         yt_stream = yt.streams.filter(only_audio=True).first() |  | ||||||
|         yt_stream.download(filename="video.mp4") |  | ||||||
| 
 |  | ||||||
|         # Convert video to audio 🎧 |  | ||||||
|         video = AudioSegment.from_file("video.mp4", format="mp4") |  | ||||||
|         video.export(audio_file, format=self.audio_format) |  | ||||||
|         os.remove("video.mp4") |  | ||||||
| 
 |  | ||||||
|         return audio_file |  | ||||||
| 
 |  | ||||||
|     def transcribe_youtube_video(self): |  | ||||||
|         audio_file = self.download_youtube_video() |  | ||||||
| 
 |  | ||||||
|         device = "cuda" |  | ||||||
|         batch_size = 16 |  | ||||||
|         compute_type = "float16" |  | ||||||
| 
 |  | ||||||
|         # 1. Transcribe with original Whisper (batched) 🗣️ |  | ||||||
|         model = whisperx.load_model( |  | ||||||
|             "large-v2", device, compute_type=compute_type |  | ||||||
|         ) |  | ||||||
|         audio = whisperx.load_audio(audio_file) |  | ||||||
|         result = model.transcribe(audio, batch_size=batch_size) |  | ||||||
| 
 |  | ||||||
|         # 2. Align Whisper output 🔍 |  | ||||||
|         model_a, metadata = whisperx.load_align_model( |  | ||||||
|             language_code=result["language"], device=device |  | ||||||
|         ) |  | ||||||
|         result = whisperx.align( |  | ||||||
|             result["segments"], |  | ||||||
|             model_a, |  | ||||||
|             metadata, |  | ||||||
|             audio, |  | ||||||
|             device, |  | ||||||
|             return_char_alignments=False, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         # 3. Assign speaker labels 🏷️ |  | ||||||
|         diarize_model = whisperx.DiarizationPipeline( |  | ||||||
|             use_auth_token=self.hf_api_key, device=device |  | ||||||
|         ) |  | ||||||
|         diarize_model(audio_file) |  | ||||||
| 
 |  | ||||||
|         try: |  | ||||||
|             segments = result["segments"] |  | ||||||
|             transcription = " ".join( |  | ||||||
|                 segment["text"] for segment in segments |  | ||||||
|             ) |  | ||||||
|             return transcription |  | ||||||
|         except KeyError: |  | ||||||
|             print("The key 'segments' is not found in the result.") |  | ||||||
| 
 |  | ||||||
|     def transcribe(self, audio_file): |  | ||||||
|         model = whisperx.load_model( |  | ||||||
|             "large-v2", self.device, self.compute_type |  | ||||||
|         ) |  | ||||||
|         audio = whisperx.load_audio(audio_file) |  | ||||||
|         result = model.transcribe(audio, batch_size=self.batch_size) |  | ||||||
| 
 |  | ||||||
|         # 2. Align Whisper output 🔍 |  | ||||||
|         model_a, metadata = whisperx.load_align_model( |  | ||||||
|             language_code=result["language"], device=self.device |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         result = whisperx.align( |  | ||||||
|             result["segments"], |  | ||||||
|             model_a, |  | ||||||
|             metadata, |  | ||||||
|             audio, |  | ||||||
|             self.device, |  | ||||||
|             return_char_alignments=False, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         # 3. Assign speaker labels 🏷️ |  | ||||||
|         diarize_model = whisperx.DiarizationPipeline( |  | ||||||
|             use_auth_token=self.hf_api_key, device=self.device |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         diarize_model(audio_file) |  | ||||||
| 
 |  | ||||||
|         try: |  | ||||||
|             segments = result["segments"] |  | ||||||
|             transcription = " ".join( |  | ||||||
|                 segment["text"] for segment in segments |  | ||||||
|             ) |  | ||||||
|             return transcription |  | ||||||
|         except KeyError: |  | ||||||
|             print("The key 'segments' is not found in the result.") |  | ||||||
| @ -1,222 +0,0 @@ | |||||||
| import os |  | ||||||
| import subprocess |  | ||||||
| import tempfile |  | ||||||
| from unittest.mock import patch |  | ||||||
| 
 |  | ||||||
| import pytest |  | ||||||
| import whisperx |  | ||||||
| from pydub import AudioSegment |  | ||||||
| from pytube import YouTube |  | ||||||
| from swarms.models.whisperx_model import WhisperX |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Fixture to create a temporary directory for testing |  | ||||||
| @pytest.fixture |  | ||||||
| def temp_dir(): |  | ||||||
|     with tempfile.TemporaryDirectory() as tempdir: |  | ||||||
|         yield tempdir |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Mock subprocess.run to prevent actual installation during tests |  | ||||||
| @patch.object(subprocess, "run") |  | ||||||
| def test_speech_to_text_install(mock_run): |  | ||||||
|     stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") |  | ||||||
|     stt.install() |  | ||||||
|     mock_run.assert_called_with(["pip", "install", "whisperx"]) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Mock pytube.YouTube and pytube.Streams for download tests |  | ||||||
| @patch("pytube.YouTube") |  | ||||||
| @patch.object(YouTube, "streams") |  | ||||||
| def test_speech_to_text_download_youtube_video( |  | ||||||
|     mock_streams, mock_youtube, temp_dir |  | ||||||
| ): |  | ||||||
|     # Mock YouTube and streams |  | ||||||
|     video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" |  | ||||||
|     mock_stream = mock_streams().filter().first() |  | ||||||
|     mock_stream.download.return_value = os.path.join( |  | ||||||
|         temp_dir, "video.mp4" |  | ||||||
|     ) |  | ||||||
|     mock_youtube.return_value = mock_youtube |  | ||||||
|     mock_youtube.streams = mock_streams |  | ||||||
| 
 |  | ||||||
|     stt = WhisperX(video_url) |  | ||||||
|     audio_file = stt.download_youtube_video() |  | ||||||
| 
 |  | ||||||
|     assert os.path.exists(audio_file) |  | ||||||
|     assert audio_file.endswith(".mp3") |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Mock whisperx.load_model and whisperx.load_audio for transcribe tests |  | ||||||
| @patch("whisperx.load_model") |  | ||||||
| @patch("whisperx.load_audio") |  | ||||||
| @patch("whisperx.load_align_model") |  | ||||||
| @patch("whisperx.align") |  | ||||||
| @patch.object(whisperx.DiarizationPipeline, "__call__") |  | ||||||
| def test_speech_to_text_transcribe_youtube_video( |  | ||||||
|     mock_diarization, |  | ||||||
|     mock_align, |  | ||||||
|     mock_align_model, |  | ||||||
|     mock_load_audio, |  | ||||||
|     mock_load_model, |  | ||||||
|     temp_dir, |  | ||||||
| ): |  | ||||||
|     # Mock whisperx functions |  | ||||||
|     mock_load_model.return_value = mock_load_model |  | ||||||
|     mock_load_model.transcribe.return_value = { |  | ||||||
|         "language": "en", |  | ||||||
|         "segments": [{"text": "Hello, World!"}], |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     mock_load_audio.return_value = "audio_path" |  | ||||||
|     mock_align_model.return_value = (mock_align_model, "metadata") |  | ||||||
|     mock_align.return_value = { |  | ||||||
|         "segments": [{"text": "Hello, World!"}] |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     # Mock diarization pipeline |  | ||||||
|     mock_diarization.return_value = None |  | ||||||
| 
 |  | ||||||
|     video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video" |  | ||||||
|     stt = WhisperX(video_url) |  | ||||||
|     transcription = stt.transcribe_youtube_video() |  | ||||||
| 
 |  | ||||||
|     assert transcription == "Hello, World!" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # More tests for different scenarios and edge cases can be added here. |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Test transcribe method with provided audio file |  | ||||||
| def test_speech_to_text_transcribe_audio_file(temp_dir): |  | ||||||
|     # Create a temporary audio file |  | ||||||
|     audio_file = os.path.join(temp_dir, "test_audio.mp3") |  | ||||||
|     AudioSegment.silent(duration=500).export(audio_file, format="mp3") |  | ||||||
| 
 |  | ||||||
|     stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") |  | ||||||
|     transcription = stt.transcribe(audio_file) |  | ||||||
| 
 |  | ||||||
|     assert transcription == "" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Test transcribe method when Whisperx fails |  | ||||||
| @patch("whisperx.load_model") |  | ||||||
| @patch("whisperx.load_audio") |  | ||||||
| def test_speech_to_text_transcribe_whisperx_failure( |  | ||||||
|     mock_load_audio, mock_load_model, temp_dir |  | ||||||
| ): |  | ||||||
|     # Mock whisperx functions to raise an exception |  | ||||||
|     mock_load_model.side_effect = Exception("Whisperx failed") |  | ||||||
|     mock_load_audio.return_value = "audio_path" |  | ||||||
| 
 |  | ||||||
|     stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") |  | ||||||
|     transcription = stt.transcribe("audio_path") |  | ||||||
| 
 |  | ||||||
|     assert transcription == "Whisperx failed" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Test transcribe method with missing 'segments' key in Whisperx output |  | ||||||
| @patch("whisperx.load_model") |  | ||||||
| @patch("whisperx.load_audio") |  | ||||||
| @patch("whisperx.load_align_model") |  | ||||||
| @patch("whisperx.align") |  | ||||||
| @patch.object(whisperx.DiarizationPipeline, "__call__") |  | ||||||
| def test_speech_to_text_transcribe_missing_segments( |  | ||||||
|     mock_diarization, |  | ||||||
|     mock_align, |  | ||||||
|     mock_align_model, |  | ||||||
|     mock_load_audio, |  | ||||||
|     mock_load_model, |  | ||||||
| ): |  | ||||||
|     # Mock whisperx functions to return incomplete output |  | ||||||
|     mock_load_model.return_value = mock_load_model |  | ||||||
|     mock_load_model.transcribe.return_value = {"language": "en"} |  | ||||||
| 
 |  | ||||||
|     mock_load_audio.return_value = "audio_path" |  | ||||||
|     mock_align_model.return_value = (mock_align_model, "metadata") |  | ||||||
|     mock_align.return_value = {} |  | ||||||
| 
 |  | ||||||
|     # Mock diarization pipeline |  | ||||||
|     mock_diarization.return_value = None |  | ||||||
| 
 |  | ||||||
|     stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") |  | ||||||
|     transcription = stt.transcribe("audio_path") |  | ||||||
| 
 |  | ||||||
|     assert transcription == "" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Test transcribe method with Whisperx align failure |  | ||||||
| @patch("whisperx.load_model") |  | ||||||
| @patch("whisperx.load_audio") |  | ||||||
| @patch("whisperx.load_align_model") |  | ||||||
| @patch("whisperx.align") |  | ||||||
| @patch.object(whisperx.DiarizationPipeline, "__call__") |  | ||||||
| def test_speech_to_text_transcribe_align_failure( |  | ||||||
|     mock_diarization, |  | ||||||
|     mock_align, |  | ||||||
|     mock_align_model, |  | ||||||
|     mock_load_audio, |  | ||||||
|     mock_load_model, |  | ||||||
| ): |  | ||||||
|     # Mock whisperx functions to raise an exception during align |  | ||||||
|     mock_load_model.return_value = mock_load_model |  | ||||||
|     mock_load_model.transcribe.return_value = { |  | ||||||
|         "language": "en", |  | ||||||
|         "segments": [{"text": "Hello, World!"}], |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     mock_load_audio.return_value = "audio_path" |  | ||||||
|     mock_align_model.return_value = (mock_align_model, "metadata") |  | ||||||
|     mock_align.side_effect = Exception("Align failed") |  | ||||||
| 
 |  | ||||||
|     # Mock diarization pipeline |  | ||||||
|     mock_diarization.return_value = None |  | ||||||
| 
 |  | ||||||
|     stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") |  | ||||||
|     transcription = stt.transcribe("audio_path") |  | ||||||
| 
 |  | ||||||
|     assert transcription == "Align failed" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Test transcribe_youtube_video when Whisperx diarization fails |  | ||||||
| @patch("pytube.YouTube") |  | ||||||
| @patch.object(YouTube, "streams") |  | ||||||
| @patch("whisperx.DiarizationPipeline") |  | ||||||
| @patch("whisperx.load_audio") |  | ||||||
| @patch("whisperx.load_align_model") |  | ||||||
| @patch("whisperx.align") |  | ||||||
| def test_speech_to_text_transcribe_diarization_failure( |  | ||||||
|     mock_align, |  | ||||||
|     mock_align_model, |  | ||||||
|     mock_load_audio, |  | ||||||
|     mock_diarization, |  | ||||||
|     mock_streams, |  | ||||||
|     mock_youtube, |  | ||||||
|     temp_dir, |  | ||||||
| ): |  | ||||||
|     # Mock YouTube and streams |  | ||||||
|     video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" |  | ||||||
|     mock_stream = mock_streams().filter().first() |  | ||||||
|     mock_stream.download.return_value = os.path.join( |  | ||||||
|         temp_dir, "video.mp4" |  | ||||||
|     ) |  | ||||||
|     mock_youtube.return_value = mock_youtube |  | ||||||
|     mock_youtube.streams = mock_streams |  | ||||||
| 
 |  | ||||||
|     # Mock whisperx functions |  | ||||||
|     mock_load_audio.return_value = "audio_path" |  | ||||||
|     mock_align_model.return_value = (mock_align_model, "metadata") |  | ||||||
|     mock_align.return_value = { |  | ||||||
|         "segments": [{"text": "Hello, World!"}] |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     # Mock diarization pipeline to raise an exception |  | ||||||
|     mock_diarization.side_effect = Exception("Diarization failed") |  | ||||||
| 
 |  | ||||||
|     stt = WhisperX(video_url) |  | ||||||
|     transcription = stt.transcribe_youtube_video() |  | ||||||
| 
 |  | ||||||
|     assert transcription == "Diarization failed" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| # Add more tests for other scenarios and edge cases as needed. |  | ||||||
					Loading…
					
					
				
		Reference in new issue