You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
205 lines
6.5 KiB
205 lines
6.5 KiB
import os
|
|
import subprocess
|
|
import tempfile
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import whisperx
|
|
from pydub import AudioSegment
|
|
from pytube import YouTube
|
|
from swarms.models.whisperx import WhisperX
|
|
|
|
|
|
# Fixture to create a temporary directory for testing
@pytest.fixture
def temp_dir():
    """Yield the path of a freshly created temporary directory.

    The directory is managed by ``tempfile.TemporaryDirectory``; it is
    cleaned up when the ``with`` block exits, i.e. once the generator is
    resumed after the requesting test has used the path.
    """
    with tempfile.TemporaryDirectory() as scratch_path:
        yield scratch_path
|
|
|
|
|
|
# Mock subprocess.run to prevent actual installation during tests
@patch.object(subprocess, "run")
def test_speech_to_text_install(mock_run):
    """Check that ``install()`` calls ``subprocess.run`` with the pip command.

    ``subprocess.run`` is replaced by ``mock_run`` so nothing is actually
    installed while the test executes.
    """
    expected_command = ["pip", "install", "whisperx"]

    WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM").install()

    mock_run.assert_called_with(expected_command)
|
|
|
|
|
|
# Mock pytube.YouTube and pytube.Streams for download tests
@patch("pytube.YouTube")
@patch.object(YouTube, "streams")
def test_speech_to_text_download_youtube_video(mock_streams, mock_youtube, temp_dir):
    """Check that ``download_youtube_video()`` returns an existing ``.mp3`` path.

    Patch decorators are applied bottom-up, so ``mock_streams`` stands in for
    ``YouTube.streams`` and ``mock_youtube`` for ``pytube.YouTube``.

    NOTE(review): the fake download path is configured on ``mock_streams()``
    (the result of *calling* the streams mock). If the implementation reads
    ``yt.streams.filter(...)`` without calling ``streams``, this return value
    is never reached -- confirm against ``WhisperX.download_youtube_video``.
    """
    # Calling the YouTube mock returns the same mock, which exposes the
    # patched streams object.
    mock_youtube.return_value = mock_youtube
    mock_youtube.streams = mock_streams

    # The stream picked via streams().filter().first() reports a download
    # location inside the temporary directory.
    fake_download_path = os.path.join(temp_dir, "video.mp4")
    chosen_stream = mock_streams().filter().first()
    chosen_stream.download.return_value = fake_download_path

    downloader = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    audio_file = downloader.download_youtube_video()

    assert os.path.exists(audio_file)
    assert audio_file.endswith(".mp3")
|
|
|
|
|
|
# Mock whisperx.load_model and whisperx.load_audio for transcribe tests
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch.object(whisperx.DiarizationPipeline, "__call__")
def test_speech_to_text_transcribe_youtube_video(
    mock_diarization,
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_load_model,
    temp_dir,
):
    """Check that ``transcribe_youtube_video()`` returns the mocked segment text.

    Every whisperx entry point used for loading, transcribing, aligning and
    diarizing is replaced by a mock; mock arguments arrive in bottom-up
    decorator order. ``temp_dir`` is requested but not used in the body.
    """
    # Audio loading and the diarization call are reduced to trivial stubs.
    mock_load_audio.return_value = "audio_path"
    mock_diarization.return_value = None

    # load_model() returns the mock itself, whose transcribe() yields one
    # English segment.
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }

    # The alignment model loader returns a (model, metadata) pair and align()
    # echoes a single segment.
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {"segments": [{"text": "Hello, World!"}]}

    transcriber = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM/video")
    transcription = transcriber.transcribe_youtube_video()

    assert transcription == "Hello, World!"
|
|
|
|
|
|
# More tests for different scenarios and edge cases can be added here.
|
|
|
|
|
|
# Test transcribe method with provided audio file
def test_speech_to_text_transcribe_audio_file(temp_dir):
    """Check that ``transcribe()`` returns an empty string for a silent clip.

    A 500 ms silent MP3 is exported into the temporary directory and passed
    to ``transcribe()``; nothing is mocked in this test.
    """
    # Write the silent clip that serves as transcription input.
    audio_file = os.path.join(temp_dir, "test_audio.mp3")
    silence = AudioSegment.silent(duration=500)
    silence.export(audio_file, format="mp3")

    transcriber = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    transcription = transcriber.transcribe(audio_file)

    assert transcription == ""
|
|
|
|
|
|
# Test transcribe method when Whisperx fails
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
def test_speech_to_text_transcribe_whisperx_failure(
    mock_load_audio, mock_load_model, temp_dir
):
    """Check that ``transcribe()`` returns the error text when model loading raises.

    ``whisperx.load_model`` is made to raise ``Exception("Whisperx failed")``
    and the test expects that message back as the transcription.
    ``temp_dir`` is requested but not used in the body.
    """
    mock_load_audio.return_value = "audio_path"
    # Any attempt to load the model blows up with a known message.
    mock_load_model.side_effect = Exception("Whisperx failed")

    transcriber = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    transcription = transcriber.transcribe("audio_path")

    assert transcription == "Whisperx failed"
|
|
|
|
|
|
# Test transcribe method with missing 'segments' key in Whisperx output
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch.object(whisperx.DiarizationPipeline, "__call__")
def test_speech_to_text_transcribe_missing_segments(
    mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model
):
    """Check that ``transcribe()`` returns ``""`` when no segments are produced.

    The mocked transcription result carries only a ``"language"`` key and the
    mocked alignment result is an empty dict.
    """
    # Audio loading and the diarization call are reduced to trivial stubs.
    mock_load_audio.return_value = "audio_path"
    mock_diarization.return_value = None

    # load_model() returns the mock itself; its transcribe() output has no
    # "segments" entry.
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {"language": "en"}

    # Alignment likewise yields nothing usable.
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {}

    transcriber = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    transcription = transcriber.transcribe("audio_path")

    assert transcription == ""
|
|
|
|
|
|
# Test transcribe method with Whisperx align failure
@patch("whisperx.load_model")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
@patch.object(whisperx.DiarizationPipeline, "__call__")
def test_speech_to_text_transcribe_align_failure(
    mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model
):
    """Check that ``transcribe()`` returns the error text when alignment raises.

    Transcription is mocked to succeed with one segment, while
    ``whisperx.align`` raises ``Exception("Align failed")``; the test expects
    that message back as the transcription.
    """
    # Audio loading and the diarization call are reduced to trivial stubs.
    mock_load_audio.return_value = "audio_path"
    mock_diarization.return_value = None

    # load_model() returns the mock itself, whose transcribe() yields one
    # English segment.
    mock_load_model.return_value = mock_load_model
    mock_load_model.transcribe.return_value = {
        "language": "en",
        "segments": [{"text": "Hello, World!"}],
    }

    # The alignment model loads fine, but the align step itself fails.
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.side_effect = Exception("Align failed")

    transcriber = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    transcription = transcriber.transcribe("audio_path")

    assert transcription == "Align failed"
|
|
|
|
|
|
# Test transcribe_youtube_video when Whisperx diarization fails
@patch("pytube.YouTube")
@patch.object(YouTube, "streams")
@patch("whisperx.DiarizationPipeline")
@patch("whisperx.load_audio")
@patch("whisperx.load_align_model")
@patch("whisperx.align")
def test_speech_to_text_transcribe_diarization_failure(
    mock_align,
    mock_align_model,
    mock_load_audio,
    mock_diarization,
    mock_streams,
    mock_youtube,
    temp_dir,
):
    """Check that ``transcribe_youtube_video()`` returns the diarization error text.

    The YouTube download path and the whisperx audio/alignment helpers are
    mocked, and constructing ``whisperx.DiarizationPipeline`` raises
    ``Exception("Diarization failed")``. Mock arguments arrive in bottom-up
    decorator order.

    NOTE(review): ``whisperx.load_model`` is not patched in this test --
    confirm that is intended.
    NOTE(review): the fake download path is configured on ``mock_streams()``
    (the result of *calling* the streams mock); verify the implementation
    actually reaches it.
    """
    # Calling the YouTube mock returns the same mock, which exposes the
    # patched streams object.
    mock_youtube.return_value = mock_youtube
    mock_youtube.streams = mock_streams

    # The stream picked via streams().filter().first() reports a download
    # location inside the temporary directory.
    fake_download_path = os.path.join(temp_dir, "video.mp4")
    chosen_stream = mock_streams().filter().first()
    chosen_stream.download.return_value = fake_download_path

    # Audio loading and alignment succeed with canned values.
    mock_load_audio.return_value = "audio_path"
    mock_align_model.return_value = (mock_align_model, "metadata")
    mock_align.return_value = {"segments": [{"text": "Hello, World!"}]}

    # Instantiating the diarization pipeline fails with a known message.
    mock_diarization.side_effect = Exception("Diarization failed")

    transcriber = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM")
    transcription = transcriber.transcribe_youtube_video()

    assert transcription == "Diarization failed"
|
|
|
|
|
|
# Add more tests for other scenarios and edge cases as needed.
|