From b474cf9a5c3f7c10326f75349ad9baee4550280f Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 11 Nov 2023 17:41:02 -0500 Subject: [PATCH] SpeechToText -> WhisperX --- pyproject.toml | 2 +- swarms/models/whisperx.py | 2 +- tests/models/whisperx.py | 20 +++++++++----------- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 145594ae..973ea8ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "2.1.8" +version = "2.1.9" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] diff --git a/swarms/models/whisperx.py b/swarms/models/whisperx.py index 102ae7d7..e980cf0a 100644 --- a/swarms/models/whisperx.py +++ b/swarms/models/whisperx.py @@ -21,7 +21,7 @@ class WhisperX: """ # Example usage video_url = "url" - speech_to_text = SpeechToText(video_url) + speech_to_text = WhisperX(video_url) transcription = speech_to_text.transcribe_youtube_video() print(transcription) diff --git a/tests/models/whisperx.py b/tests/models/whisperx.py index af2fe219..bcbd02e9 100644 --- a/tests/models/whisperx.py +++ b/tests/models/whisperx.py @@ -7,9 +7,7 @@ import pytest import whisperx from pydub import AudioSegment from pytube import YouTube -from your_module import SpeechToText - -from swarms.models.whisperx import SpeechToText +from swarms.models.whisperx import WhisperX # Fixture to create a temporary directory for testing @@ -22,7 +20,7 @@ def temp_dir(): # Mock subprocess.run to prevent actual installation during tests @patch.object(subprocess, "run") def test_speech_to_text_install(mock_run): - stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") stt.install() mock_run.assert_called_with(["pip", "install", "whisperx"]) @@ -38,7 +36,7 @@ def test_speech_to_text_download_youtube_video(mock_streams, mock_youtube, temp_ mock_youtube.return_value = 
mock_youtube mock_youtube.streams = mock_streams - stt = SpeechToText(video_url) + stt = WhisperX(video_url) audio_file = stt.download_youtube_video() assert os.path.exists(audio_file) @@ -74,7 +72,7 @@ def test_speech_to_text_transcribe_youtube_video( mock_diarization.return_value = None video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video" - stt = SpeechToText(video_url) + stt = WhisperX(video_url) transcription = stt.transcribe_youtube_video() assert transcription == "Hello, World!" @@ -89,7 +87,7 @@ def test_speech_to_text_transcribe_audio_file(temp_dir): audio_file = os.path.join(temp_dir, "test_audio.mp3") AudioSegment.silent(duration=500).export(audio_file, format="mp3") - stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") transcription = stt.transcribe(audio_file) assert transcription == "" @@ -105,7 +103,7 @@ def test_speech_to_text_transcribe_whisperx_failure( mock_load_model.side_effect = Exception("Whisperx failed") mock_load_audio.return_value = "audio_path" - stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") transcription = stt.transcribe("audio_path") assert transcription == "Whisperx failed" @@ -131,7 +129,7 @@ def test_speech_to_text_transcribe_missing_segments( # Mock diarization pipeline mock_diarization.return_value = None - stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") transcription = stt.transcribe("audio_path") assert transcription == "" @@ -160,7 +158,7 @@ def test_speech_to_text_transcribe_align_failure( # Mock diarization pipeline mock_diarization.return_value = None - stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM") + stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") transcription = stt.transcribe("audio_path") assert transcription == "Align failed" @@ -197,7 
+195,7 @@ def test_speech_to_text_transcribe_diarization_failure( # Mock diarization pipeline to raise an exception mock_diarization.side_effect = Exception("Diarization failed") - stt = SpeechToText(video_url) + stt = WhisperX(video_url) transcription = stt.transcribe_youtube_video() assert transcription == "Diarization failed"