import os
import subprocess
import tempfile
from unittest.mock import patch

import pytest
import whisperx
from pydub import AudioSegment
from pytube import YouTube

from swarms.models.whisperx import WhisperX

# Video URL passed to the WhisperX constructor by most tests below.
VIDEO_URL = "https://www.youtube.com/watch?v=MJd6pr16LRM"


def _hello_segments():
    """Return a fresh single-segment list carrying the text "Hello, World!"."""
    return [{"text": "Hello, World!"}]


def _stub_youtube_download(mock_streams, mock_youtube, temp_dir):
    """Configure the pytube mocks used by the YouTube-based tests.

    The stream reached through ``streams().filter().first()`` reports a
    ``video.mp4`` path inside ``temp_dir`` from ``download()``, and the
    patched ``YouTube`` returns itself when called and exposes the patched
    ``streams`` attribute.
    """
    filtered = mock_streams().filter()
    first_stream = filtered.first()
    first_stream.download.return_value = os.path.join(temp_dir, "video.mp4")
    mock_youtube.return_value = mock_youtube
    mock_youtube.streams = mock_streams


@pytest.fixture
def temp_dir():
    """Yield the path of a temporary directory that is removed afterwards."""
    with tempfile.TemporaryDirectory() as scratch_path:
        yield scratch_path


# subprocess.run is patched so no real installation happens during the test.
@patch.object(subprocess, "run")
def test_speech_to_text_install(mock_run):
    """install() must leave subprocess.run last called with the pip command."""
    transcriber = WhisperX(VIDEO_URL)
    transcriber.install()

    expected_command = ["pip", "install", "whisperx"]
    mock_run.assert_called_with(expected_command)


# patch decorators apply bottom-up, so the mock for the lowest decorator is
# the first positional argument of the test function.
@patch("pytube.YouTube")
@patch.object(YouTube, "streams")
def test_speech_to_text_download_youtube_video(mock_streams, mock_youtube, temp_dir):
    """download_youtube_video() must return an existing path ending in .mp3."""
    _stub_youtube_download(mock_streams, mock_youtube, temp_dir)

    transcriber = WhisperX(VIDEO_URL)
    mp3_path = transcriber.download_youtube_video()

    assert os.path.exists(mp3_path)
    assert mp3_path.endswith(".mp3")


@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch.object(whisperx.DiarizationPipeline, "__call__")
def test_speech_to_text_transcribe_youtube_video(
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
    temp_dir,
):
    """transcribe_youtube_video() must return the mocked segment text."""
    # The patched load_model returns itself, so the transcribe() stub
    # configured here is what the code under test sees on the loaded model.
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {
        "language": "en",
        "segments": _hello_segments(),
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {"segments": _hello_segments()}

    # Calling a DiarizationPipeline instance yields None.
    mock_diarization.return_value = None

    # This test passes a URL with a trailing "/video" segment.
    transcriber = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM/video")
    result = transcriber.transcribe_youtube_video()

    assert result == "Hello, World!"


# More tests for different scenarios and edge cases can be added here.


def test_speech_to_text_transcribe_audio_file(temp_dir):
    """transcribe() on half a second of silence must return an empty string."""
    mp3_path = os.path.join(temp_dir, "test_audio.mp3")
    silence = AudioSegment.silent(duration=500)
    silence.export(mp3_path, format="mp3")

    transcriber = WhisperX(VIDEO_URL)
    result = transcriber.transcribe(mp3_path)

    assert result == ""


@patch("whisperx.load_model")
@patch("whisperx.load_audio")
def test_speech_to_text_transcribe_whisperx_failure(
    mock_load_audio, mock_load_model, temp_dir
):
    """With load_model raising, transcribe() must return the error message."""
    mock_load_model.side_effect = Exception("Whisperx failed")
    mock_load_audio.return_value = "audio_path"

    transcriber = WhisperX(VIDEO_URL)
    result = transcriber.transcribe("audio_path")

    assert result == "Whisperx failed"


@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch.object(whisperx.DiarizationPipeline, "__call__")
def test_speech_to_text_transcribe_missing_segments(
    mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model
):
    """transcribe() must return "" when the mocked output has no 'segments'."""
    # Neither the transcription result nor the alignment result carries a
    # "segments" key.
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {"language": "en"}
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {}

    mock_diarization.return_value = None

    transcriber = WhisperX(VIDEO_URL)
    result = transcriber.transcribe("audio_path")

    assert result == ""


@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch.object(whisperx.DiarizationPipeline, "__call__")
def test_speech_to_text_transcribe_align_failure(
    mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model
):
    """With whisperx.align raising, transcribe() must return the error message."""
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {
        "language": "en",
        "segments": _hello_segments(),
    }
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.side_effect = Exception("Align failed")

    mock_diarization.return_value = None

    transcriber = WhisperX(VIDEO_URL)
    result = transcriber.transcribe("audio_path")

    assert result == "Align failed"


@patch("pytube.YouTube")
@patch.object(YouTube, "streams")
@patch("whisperx.DiarizationPipeline")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
def test_speech_to_text_transcribe_diarization_failure(
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_diarization,
    mock_streams,
    mock_youtube,
    temp_dir,
):
    """With DiarizationPipeline raising, the error message must be returned."""
    _stub_youtube_download(mock_streams, mock_youtube, temp_dir)

    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {"segments": _hello_segments()}

    # Here the DiarizationPipeline class itself is patched, so the exception
    # is raised when the class is called.
    mock_diarization.side_effect = Exception("Diarization failed")

    transcriber = WhisperX(VIDEO_URL)
    result = transcriber.transcribe_youtube_video()

    assert result == "Diarization failed"


# Add more tests for other scenarios and edge cases as needed.