|
|
@ -7,9 +7,7 @@ import pytest
|
|
|
|
import whisperx
|
|
|
|
import whisperx
|
|
|
|
from pydub import AudioSegment
|
|
|
|
from pydub import AudioSegment
|
|
|
|
from pytube import YouTube
|
|
|
|
from pytube import YouTube
|
|
|
|
from your_module import SpeechToText
|
|
|
|
from swarms.models.whisperx import WhisperX
|
|
|
|
|
|
|
|
|
|
|
|
from swarms.models.whisperx import SpeechToText
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fixture to create a temporary directory for testing
|
|
|
|
# Fixture to create a temporary directory for testing
|
|
|
@ -22,7 +20,7 @@ def temp_dir():
|
|
|
|
# Mock subprocess.run to prevent actual installation during tests
|
|
|
|
# Mock subprocess.run to prevent actual installation during tests
|
|
|
|
@patch.object(subprocess, "run")
|
|
|
|
@patch.object(subprocess, "run")
|
|
|
|
def test_speech_to_text_install(mock_run):
|
|
|
|
def test_speech_to_text_install(mock_run):
|
|
|
|
stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
stt.install()
|
|
|
|
stt.install()
|
|
|
|
mock_run.assert_called_with(["pip", "install", "whisperx"])
|
|
|
|
mock_run.assert_called_with(["pip", "install", "whisperx"])
|
|
|
|
|
|
|
|
|
|
|
@ -38,7 +36,7 @@ def test_speech_to_text_download_youtube_video(mock_streams, mock_youtube, temp_
|
|
|
|
mock_youtube.return_value = mock_youtube
|
|
|
|
mock_youtube.return_value = mock_youtube
|
|
|
|
mock_youtube.streams = mock_streams
|
|
|
|
mock_youtube.streams = mock_streams
|
|
|
|
|
|
|
|
|
|
|
|
stt = SpeechToText(video_url)
|
|
|
|
stt = WhisperX(video_url)
|
|
|
|
audio_file = stt.download_youtube_video()
|
|
|
|
audio_file = stt.download_youtube_video()
|
|
|
|
|
|
|
|
|
|
|
|
assert os.path.exists(audio_file)
|
|
|
|
assert os.path.exists(audio_file)
|
|
|
@ -74,7 +72,7 @@ def test_speech_to_text_transcribe_youtube_video(
|
|
|
|
mock_diarization.return_value = None
|
|
|
|
mock_diarization.return_value = None
|
|
|
|
|
|
|
|
|
|
|
|
video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video"
|
|
|
|
video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video"
|
|
|
|
stt = SpeechToText(video_url)
|
|
|
|
stt = WhisperX(video_url)
|
|
|
|
transcription = stt.transcribe_youtube_video()
|
|
|
|
transcription = stt.transcribe_youtube_video()
|
|
|
|
|
|
|
|
|
|
|
|
assert transcription == "Hello, World!"
|
|
|
|
assert transcription == "Hello, World!"
|
|
|
@ -89,7 +87,7 @@ def test_speech_to_text_transcribe_audio_file(temp_dir):
|
|
|
|
audio_file = os.path.join(temp_dir, "test_audio.mp3")
|
|
|
|
audio_file = os.path.join(temp_dir, "test_audio.mp3")
|
|
|
|
AudioSegment.silent(duration=500).export(audio_file, format="mp3")
|
|
|
|
AudioSegment.silent(duration=500).export(audio_file, format="mp3")
|
|
|
|
|
|
|
|
|
|
|
|
stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
transcription = stt.transcribe(audio_file)
|
|
|
|
transcription = stt.transcribe(audio_file)
|
|
|
|
|
|
|
|
|
|
|
|
assert transcription == ""
|
|
|
|
assert transcription == ""
|
|
|
@ -105,7 +103,7 @@ def test_speech_to_text_transcribe_whisperx_failure(
|
|
|
|
mock_load_model.side_effect = Exception("Whisperx failed")
|
|
|
|
mock_load_model.side_effect = Exception("Whisperx failed")
|
|
|
|
mock_load_audio.return_value = "audio_path"
|
|
|
|
mock_load_audio.return_value = "audio_path"
|
|
|
|
|
|
|
|
|
|
|
|
stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
transcription = stt.transcribe("audio_path")
|
|
|
|
transcription = stt.transcribe("audio_path")
|
|
|
|
|
|
|
|
|
|
|
|
assert transcription == "Whisperx failed"
|
|
|
|
assert transcription == "Whisperx failed"
|
|
|
@ -131,7 +129,7 @@ def test_speech_to_text_transcribe_missing_segments(
|
|
|
|
# Mock diarization pipeline
|
|
|
|
# Mock diarization pipeline
|
|
|
|
mock_diarization.return_value = None
|
|
|
|
mock_diarization.return_value = None
|
|
|
|
|
|
|
|
|
|
|
|
stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
transcription = stt.transcribe("audio_path")
|
|
|
|
transcription = stt.transcribe("audio_path")
|
|
|
|
|
|
|
|
|
|
|
|
assert transcription == ""
|
|
|
|
assert transcription == ""
|
|
|
@ -160,7 +158,7 @@ def test_speech_to_text_transcribe_align_failure(
|
|
|
|
# Mock diarization pipeline
|
|
|
|
# Mock diarization pipeline
|
|
|
|
mock_diarization.return_value = None
|
|
|
|
mock_diarization.return_value = None
|
|
|
|
|
|
|
|
|
|
|
|
stt = SpeechToText("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
|
|
|
|
transcription = stt.transcribe("audio_path")
|
|
|
|
transcription = stt.transcribe("audio_path")
|
|
|
|
|
|
|
|
|
|
|
|
assert transcription == "Align failed"
|
|
|
|
assert transcription == "Align failed"
|
|
|
@ -197,7 +195,7 @@ def test_speech_to_text_transcribe_diarization_failure(
|
|
|
|
# Mock diarization pipeline to raise an exception
|
|
|
|
# Mock diarization pipeline to raise an exception
|
|
|
|
mock_diarization.side_effect = Exception("Diarization failed")
|
|
|
|
mock_diarization.side_effect = Exception("Diarization failed")
|
|
|
|
|
|
|
|
|
|
|
|
stt = SpeechToText(video_url)
|
|
|
|
stt = WhisperX(video_url)
|
|
|
|
transcription = stt.transcribe_youtube_video()
|
|
|
|
transcription = stt.transcribe_youtube_video()
|
|
|
|
|
|
|
|
|
|
|
|
assert transcription == "Diarization failed"
|
|
|
|
assert transcription == "Diarization failed"
|
|
|
|