import os
import tempfile
import wave
from functools import wraps
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

from swarms.models.distilled_whisperx import (
    DistilWhisperModel,
    async_retry,
)


@pytest.fixture
def distil_whisper_model():
    """Provide a DistilWhisperModel built with its default arguments."""
    return DistilWhisperModel()


def create_audio_file(
    data: np.ndarray, sample_rate: int, file_path: str
):
    """Write ``data`` to ``file_path`` as a mono 16-bit PCM WAV file.

    Args:
        data: Audio samples. Floating-point input is clipped to
            [-1.0, 1.0] and scaled to the int16 range; integer input is
            cast to int16 as-is. Multi-dimensional input is flattened.
        sample_rate: Sampling rate in Hz, stored in the WAV header.
        file_path: Destination path; an existing file is overwritten.

    Returns:
        ``file_path``, unchanged, so the call can be used inline.
    """
    samples = np.asarray(data).ravel()
    if np.issubdtype(samples.dtype, np.floating):
        samples = np.clip(samples, -1.0, 1.0) * 32767
    # WAV PCM data is little-endian; "<i2" pins that regardless of host.
    pcm = samples.astype("<i2")

    with wave.open(file_path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample == 16-bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())
    return file_path


# Named distinctly from the later ``test_initialization`` in this module:
# two tests with the same name shadow each other and only the last one runs.
def test_initialization_default_model(distil_whisper_model):
    """The default model exposes a torch model, a processor and a device."""
    assert isinstance(distil_whisper_model, DistilWhisperModel)
    assert isinstance(distil_whisper_model.model, torch.nn.Module)
    # Hugging Face processors are not torch.nn.Module subclasses, so only
    # check that one was loaded.
    # NOTE(review): assumes the processor comes from AutoProcessor — confirm
    # against DistilWhisperModel.
    assert distil_whisper_model.processor is not None
    assert distil_whisper_model.device in ["cpu", "cuda:0"]


def test_transcribe_audio_file(distil_whisper_model):
    """Transcribing a WAV file path yields a non-empty string."""
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    # Reserve a path, then close the handle before writing to it.
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = audio_file.name
    try:
        create_audio_file(test_data, 16000, audio_file_path)
        transcription = distil_whisper_model.transcribe(
            audio_file_path
        )
    finally:
        # Clean up even when transcription raises.
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_file(distil_whisper_model):
    """Async transcription of a WAV file path yields a non-empty string."""
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = audio_file.name
    try:
        create_audio_file(test_data, 16000, audio_file_path)
        transcription = await distil_whisper_model.async_transcribe(
            audio_file_path
        )
    finally:
        # Clean up even when transcription raises.
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_transcribe_audio_data(distil_whisper_model):
    """Transcribing raw sample bytes yields a non-empty string."""
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    transcription = distil_whisper_model.transcribe(
        test_data.tobytes()
    )

    assert isinstance(transcription, str)
    assert transcription.strip() != ""
@pytest.mark.asyncio
async def test_async_transcribe_audio_data(distil_whisper_model):
    """Async transcription of raw sample bytes yields a non-empty string."""
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    transcription = await distil_whisper_model.async_transcribe(
        test_data.tobytes()
    )

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_real_time_transcribe(distil_whisper_model, capsys):
    """Real-time transcription prints a start banner and per-chunk output."""
    test_data = np.random.rand(
        16000 * 5
    )  # Simulated audio data (5 seconds)
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = audio_file.name
    try:
        create_audio_file(test_data, 16000, audio_file_path)
        distil_whisper_model.real_time_transcribe(
            audio_file_path, chunk_duration=1
        )
    finally:
        # Clean up even when transcription raises.
        os.remove(audio_file_path)

    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out
    assert "Chunk" in captured.out


def test_real_time_transcribe_audio_file_not_found(
    distil_whisper_model, capsys
):
    """A missing input file is reported on stdout."""
    audio_file_path = "non_existent_audio.wav"
    distil_whisper_model.real_time_transcribe(
        audio_file_path, chunk_duration=1
    )

    captured = capsys.readouterr()
    assert "The audio file was not found." in captured.out


@pytest.fixture
def mock_async_retry():
    """Swap ``async_retry`` in the module under test for a pass-through.

    The replacement is a decorator factory with the same keyword names the
    tests pass to the real ``async_retry`` (``max_retries``, ``exceptions``,
    ``delay``). It ignores them and awaits the wrapped coroutine once.
    """

    def _mock_async_retry(
        max_retries=3, exceptions=(Exception,), delay=1
    ):
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

            return wrapper

        return decorator

    # Patch the attribute on the real module, and install the factory itself
    # (not the result of calling it) so ``async_retry(...)`` still works.
    with patch(
        "swarms.models.distilled_whisperx.async_retry",
        new=_mock_async_retry,
    ):
        yield


@pytest.mark.asyncio
async def test_async_retry_decorator_success():
    """A coroutine that succeeds returns its value through the decorator."""

    async def mock_async_function():
        return "Success"

    decorated_function = async_retry()(mock_async_function)
    result = await decorated_function()
    assert result == "Success"


@pytest.mark.asyncio
async def test_async_retry_decorator_failure():
    """A coroutine that always fails surfaces its exception."""

    async def mock_async_function():
        raise Exception("Error")

    # delay=0 so the test does not wait between attempts.
    decorated_function = async_retry(delay=0)(mock_async_function)
    with pytest.raises(Exception, match="Error"):
        await decorated_function()


@pytest.mark.asyncio
async def test_async_retry_decorator_multiple_attempts():
    """A coroutine that fails once and then succeeds returns its value."""

    async def mock_async_function():
        # Fail on the first call only; the counter lives on the function.
        if mock_async_function.attempts == 0:
            mock_async_function.attempts += 1
            raise Exception("Error")
        return "Success"

    mock_async_function.attempts = 0
    # delay=0 so the test does not wait between attempts.
    decorated_function = async_retry(max_retries=2, delay=0)(
        mock_async_function
    )
    result = await decorated_function()
    assert result == "Success"


def test_create_audio_file():
    """``create_audio_file`` leaves a file at the requested path."""
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    sample_rate = 16000
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        temp_path = audio_file.name
    try:
        audio_file_path = create_audio_file(
            test_data, sample_rate, temp_path
        )
        assert os.path.exists(audio_file_path)
    finally:
        os.remove(temp_path)


# test_distilled_whisperx.py


# Fixtures for setting up model, processor, and audio files
@pytest.fixture(scope="module")
def model_id():
    """Identifier passed to DistilWhisperModel by ``whisper_model``."""
    return "distil-whisper/distil-large-v2"


@pytest.fixture(scope="module")
def whisper_model(model_id):
    """Provide one DistilWhisperModel per test module."""
    return DistilWhisperModel(model_id)


@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
    """Create a one-second mono 16 kHz WAV file and return its path.

    The file holds a 440 Hz sine tone written as 16-bit PCM, so the tests
    that need a valid audio file have one that exists on disk.
    """
    import wave

    sample_rate = 16000
    timeline = np.arange(sample_rate) / sample_rate
    tone = np.sin(2 * np.pi * 440 * timeline)
    pcm = (tone * 32767).astype("<i2")  # little-endian 16-bit samples

    target = tmp_path_factory.mktemp("audio") / "valid_audio.wav"
    with wave.open(str(target), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())
    return str(target)


@pytest.fixture(scope="session")
def invalid_audio_file_path():
    """Return a path that is not expected to exist."""
    return "path/to/invalid_audio.mp3"


@pytest.fixture(scope="session")
def audio_dict():
    """Return an ``array`` / ``sampling_rate`` audio dictionary."""
    # This should represent a valid audio dictionary as expected by the model
    return {"array": torch.randn(1, 16000), "sampling_rate": 16000}


# Test initialization
def test_initialization(whisper_model):
    """The model built from ``model_id`` has a model and a processor."""
    assert whisper_model.model is not None
    assert whisper_model.processor is not None


# Test successful transcription with file path
def test_transcribe_with_file_path(whisper_model, audio_file_path):
    """Transcribing a file path returns a string."""
    transcription = whisper_model.transcribe(audio_file_path)
    assert isinstance(transcription, str)


# Test successful transcription with audio dict
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
    """Transcribing an audio dictionary returns a string."""
    transcription = whisper_model.transcribe(audio_dict)
    assert isinstance(transcription, str)


# Test for file not found error
def test_file_not_found(whisper_model, invalid_audio_file_path):
    """Transcribing a missing file raises."""
    with pytest.raises(Exception):
        whisper_model.transcribe(invalid_audio_file_path)


# Asynchronous tests
@pytest.mark.asyncio
async def test_async_transcription_success(
    whisper_model, audio_file_path
):
    """Async transcription of a file path returns a string."""
    transcription = await whisper_model.async_transcribe(
        audio_file_path
    )
    assert isinstance(transcription, str)


@pytest.mark.asyncio
async def test_async_transcription_failure(
    whisper_model, invalid_audio_file_path
):
    """Async transcription of a missing file raises."""
    with pytest.raises(Exception):
        await whisper_model.async_transcribe(invalid_audio_file_path)


# Testing real-time transcription simulation
def test_real_time_transcription(
    whisper_model, audio_file_path, capsys
):
    """Real-time transcription prints its start banner."""
    whisper_model.real_time_transcribe(
        audio_file_path, chunk_duration=1
    )
    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out


# Testing retry decorator for asynchronous function
@pytest.mark.asyncio
async def test_async_retry():
    """A coroutine that always raises ValueError still raises it."""

    @async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
    async def failing_func():
        raise ValueError("Test")

    with pytest.raises(ValueError):
        await failing_func()


# Mocking the actual model to avoid GPU/CPU intensive operations during test
@pytest.fixture
def mocked_model(monkeypatch):
    """Patch both ``from_pretrained`` loaders and return the two mocks.

    Returns:
        ``(model_mock, processor_mock)``: the callables installed in place
        of ``AutoModelForSpeechSeq2Seq.from_pretrained`` and
        ``AutoProcessor.from_pretrained``.
    """
    # from_pretrained is called synchronously, so it needs a MagicMock;
    # an AsyncMock would hand back a coroutine instead of a model object.
    model_mock = MagicMock(AutoModelForSpeechSeq2Seq)
    processor_mock = MagicMock(AutoProcessor)
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
        model_mock,
    )
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
        processor_mock,
    )
    return model_mock, processor_mock


@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(
    mocked_model, audio_file_path
):
    """With mocked loaders, async transcription returns the decoded text."""
    model_mock, processor_mock = mocked_model
    # Set up what the mock should return when it's called
    model_mock.return_value.generate.return_value = torch.tensor(
        [[0]]
    )
    processor_mock.return_value.batch_decode.return_value = [
        "mocked transcription"
    ]
    model_wrapper = DistilWhisperModel()
    transcription = await model_wrapper.async_transcribe(
        audio_file_path
    )
    assert transcription == "mocked transcription"