diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index 4cb61b9a..a0bec07f 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -16,6 +16,7 @@ from swarms.models.kosmos_two import Kosmos from swarms.models.vilt import Vilt from swarms.models.nougat import Nougat from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA + # from swarms.models.distilled_whisperx import DistilWhisperModel # from swarms.models.fuyu import Fuyu # Not working, wait until they update diff --git a/tests/models/distilled_whisperx.py b/tests/models/distilled_whisperx.py new file mode 100644 index 00000000..bab8cd0e --- /dev/null +++ b/tests/models/distilled_whisperx.py @@ -0,0 +1,120 @@ +# test_distilled_whisperx.py + +from unittest.mock import AsyncMock, MagicMock + +import pytest +import torch +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor + +from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry + + +# Fixtures for setting up model, processor, and audio files +@pytest.fixture(scope="module") +def model_id(): + return "distil-whisper/distil-large-v2" + + +@pytest.fixture(scope="module") +def whisper_model(model_id): + return DistilWhisperModel(model_id) + + +@pytest.fixture(scope="session") +def audio_file_path(tmp_path_factory): + # You would create a small temporary MP3 file here for testing + # or use a public domain MP3 file's path + return "path/to/valid_audio.mp3" + + +@pytest.fixture(scope="session") +def invalid_audio_file_path(): + return "path/to/invalid_audio.mp3" + + +@pytest.fixture(scope="session") +def audio_dict(): + # This should represent a valid audio dictionary as expected by the model + return {"array": torch.randn(1, 16000), "sampling_rate": 16000} + + +# Test initialization +def test_initialization(whisper_model): + assert whisper_model.model is not None + assert whisper_model.processor is not None + + +# Test successful transcription with file path +def test_transcribe_with_file_path(whisper_model, audio_file_path): + transcription = whisper_model.transcribe(audio_file_path) + assert isinstance(transcription, str) + + +# Test successful transcription with audio dict +def test_transcribe_with_audio_dict(whisper_model, audio_dict): + transcription = whisper_model.transcribe(audio_dict) + assert isinstance(transcription, str) + + +# Test for file not found error +def test_file_not_found(whisper_model, invalid_audio_file_path): + with pytest.raises(Exception): + whisper_model.transcribe(invalid_audio_file_path) + + +# Asynchronous tests +@pytest.mark.asyncio +async def test_async_transcription_success(whisper_model, audio_file_path): + transcription = await whisper_model.async_transcribe(audio_file_path) + assert isinstance(transcription, str) + + +@pytest.mark.asyncio +async def test_async_transcription_failure(whisper_model, invalid_audio_file_path): + with pytest.raises(Exception): + await whisper_model.async_transcribe(invalid_audio_file_path) + + +# Testing real-time transcription simulation +def test_real_time_transcription(whisper_model, audio_file_path, capsys): + whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1) + captured = capsys.readouterr() + assert "Starting real-time transcription..." in captured.out + + +# Testing retry decorator for asynchronous function +@pytest.mark.asyncio +async def test_async_retry(): + @async_retry(max_retries=2, exceptions=(ValueError,), delay=0) + async def failing_func(): + raise ValueError("Test") + + with pytest.raises(ValueError): + await failing_func() + + +# Mocking the actual model to avoid GPU/CPU intensive operations during test +@pytest.fixture +def mocked_model(monkeypatch): + model_mock = AsyncMock(AutoModelForSpeechSeq2Seq) + processor_mock = MagicMock(AutoProcessor) + monkeypatch.setattr( + "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained", + model_mock, + ) + monkeypatch.setattr( + "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", processor_mock + ) + return model_mock, processor_mock + + +@pytest.mark.asyncio +async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path): + model_mock, processor_mock = mocked_model + # Set up what the mock should return when it's called + model_mock.return_value.generate.return_value = torch.tensor([[0]]) + processor_mock.return_value.batch_decode.return_value = ["mocked transcription"] + model_wrapper = DistilWhisperModel() + transcription = await model_wrapper.async_transcribe(audio_file_path) + assert transcription == "mocked transcription" +