"""Tests for ``DistilWhisperModel`` and the ``async_retry`` decorator."""

import os
import tempfile
import wave
from contextlib import contextmanager
from functools import wraps
from unittest.mock import patch

import numpy as np
import pytest
import torch

from swarms.models.distill_whisperx import DistilWhisperModel, async_retry

# Sample rate (Hz) used for every synthetic clip in this module.
SAMPLE_RATE = 16000


@pytest.fixture
def distil_whisper_model():
    """Return a freshly constructed ``DistilWhisperModel``."""
    return DistilWhisperModel()


def create_audio_file(data: np.ndarray, sample_rate: int, file_path: str):
    """Write ``data`` to ``file_path`` as a 16-bit PCM WAV file.

    Args:
        data: Audio samples. Floating-point input is clipped to [-1.0, 1.0]
            and scaled to the int16 range; other dtypes are cast to int16.
            A 1-D array is written as mono; a 2-D array is interpreted as
            ``(frames, channels)``.
        sample_rate: Sampling frequency in Hz, stored in the WAV header.
        file_path: Destination path; an existing file is overwritten.

    Returns:
        ``file_path``, unchanged, so the call can be used inline.
    """
    samples = np.asarray(data)
    if np.issubdtype(samples.dtype, np.floating):
        samples = np.clip(samples, -1.0, 1.0) * 32767
    pcm = samples.astype("<i2")  # WAV PCM data is little-endian
    n_channels = 1 if pcm.ndim == 1 else pcm.shape[1]

    with wave.open(file_path, "wb") as wav_file:
        wav_file.setnchannels(n_channels)
        wav_file.setsampwidth(2)  # bytes per sample (int16)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())
    return file_path


@contextmanager
def _temporary_audio_file(num_samples, sample_rate=SAMPLE_RATE):
    """Yield the path of a temporary WAV file of random samples; delete it on exit."""
    samples = np.random.rand(num_samples)
    # Reserve a unique name, then close the handle before writing: a file held
    # open by NamedTemporaryFile cannot be reopened by name on Windows.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
        path = handle.name
    try:
        yield create_audio_file(samples, sample_rate, path)
    finally:
        # Clean up even when the code under test raises.
        if os.path.exists(path):
            os.remove(path)


def test_initialization(distil_whisper_model):
    """The fixture exposes a model, a processor and a known device string."""
    assert isinstance(distil_whisper_model, DistilWhisperModel)
    assert isinstance(distil_whisper_model.model, torch.nn.Module)
    # NOTE(review): confirm the processor really is a torch.nn.Module;
    # this check is kept exactly as originally written.
    assert isinstance(distil_whisper_model.processor, torch.nn.Module)
    assert distil_whisper_model.device in ["cpu", "cuda:0"]


def test_transcribe_audio_file(distil_whisper_model):
    """Transcribing a one-second file yields a non-empty string."""
    with _temporary_audio_file(SAMPLE_RATE) as audio_file_path:
        transcription = distil_whisper_model.transcribe(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_file(distil_whisper_model):
    """Async transcription of a one-second file yields a non-empty string."""
    with _temporary_audio_file(SAMPLE_RATE) as audio_file_path:
        transcription = await distil_whisper_model.async_transcribe(
            audio_file_path
        )

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_transcribe_audio_data(distil_whisper_model):
    """Transcribing in-memory audio bytes yields a non-empty string."""
    test_data = np.random.rand(SAMPLE_RATE)  # Simulated audio data (1 second)
    # NOTE(review): these are raw float64 sample bytes with no container
    # header - confirm transcribe() accepts input in this form.
    transcription = distil_whisper_model.transcribe(test_data.tobytes())

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_data(distil_whisper_model):
    """Async transcription of in-memory audio bytes yields a non-empty string."""
    test_data = np.random.rand(SAMPLE_RATE)  # Simulated audio data (1 second)
    # NOTE(review): raw float64 sample bytes, no header - confirm this is a
    # supported input for async_transcribe().
    transcription = await distil_whisper_model.async_transcribe(
        test_data.tobytes()
    )

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_real_time_transcribe(distil_whisper_model, capsys):
    """Real-time transcription of a five-second file prints per-chunk output."""
    with _temporary_audio_file(SAMPLE_RATE * 5) as audio_file_path:
        distil_whisper_model.real_time_transcribe(
            audio_file_path, chunk_duration=1
        )

    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out
    assert "Chunk" in captured.out


def test_real_time_transcribe_audio_file_not_found(distil_whisper_model, capsys):
    """A missing input file is reported on stdout."""
    audio_file_path = "non_existent_audio.wav"
    distil_whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)

    captured = capsys.readouterr()
    assert "The audio file was not found." in captured.out


@pytest.fixture
def mock_async_retry():
    """Replace ``async_retry`` in the module under test with a pass-through.

    The stand-in accepts the keyword arguments this file passes to
    ``async_retry`` but awaits the wrapped coroutine exactly once, with no
    retries and no delay. The patch rebinds the attribute on
    ``swarms.models.distill_whisperx`` only: the ``async_retry`` name imported
    at the top of this file, and anything decorated before the patch starts,
    keep the real implementation.
    """

    def _mock_async_retry(max_retries=3, exceptions=(Exception,), delay=1):
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

            return wrapper

        return decorator

    # Install the factory itself (not the result of calling it) so that
    # ``async_retry(...)(func)`` keeps working while patched.
    with patch(
        "swarms.models.distill_whisperx.async_retry", new=_mock_async_retry
    ):
        yield


@pytest.mark.asyncio
async def test_async_retry_decorator_success():
    """A coroutine that succeeds has its result passed through unchanged."""

    async def mock_async_function():
        return "Success"

    decorated_function = async_retry()(mock_async_function)
    result = await decorated_function()
    assert result == "Success"


@pytest.mark.asyncio
async def test_async_retry_decorator_failure():
    """A coroutine that always fails ends up propagating its exception."""

    async def mock_async_function():
        raise Exception("Error")

    decorated_function = async_retry()(mock_async_function)
    with pytest.raises(Exception, match="Error"):
        await decorated_function()


@pytest.mark.asyncio
async def test_async_retry_decorator_multiple_attempts():
    """A coroutine that fails once and then succeeds returns its result."""

    async def mock_async_function():
        # Fail on the first call only; the counter lives on the function object.
        if mock_async_function.attempts == 0:
            mock_async_function.attempts += 1
            raise Exception("Error")
        return "Success"

    mock_async_function.attempts = 0
    decorated_function = async_retry(max_retries=2)(mock_async_function)
    result = await decorated_function()
    assert result == "Success"


def test_create_audio_file():
    """``create_audio_file`` writes a readable mono 16-bit WAV at the given rate."""
    test_data = np.random.rand(SAMPLE_RATE)  # Simulated audio data (1 second)
    sample_rate = SAMPLE_RATE
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
        target_path = audio_file.name
    try:
        audio_file_path = create_audio_file(test_data, sample_rate, target_path)

        assert audio_file_path == target_path
        assert os.path.exists(audio_file_path)
        with wave.open(audio_file_path, "rb") as wav_file:
            assert wav_file.getframerate() == sample_rate
            assert wav_file.getnchannels() == 1
            assert wav_file.getsampwidth() == 2
            assert wav_file.getnframes() == len(test_data)
    finally:
        if os.path.exists(target_path):
            os.remove(target_path)


if __name__ == "__main__":
    # Run only this file's tests rather than whatever pytest would collect
    # from the current working directory.
    pytest.main([__file__])