You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/distill_whisper.py

282 lines
8.7 KiB

import os
import tempfile
from functools import wraps
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry
@pytest.fixture
def distil_whisper_model():
    """Provide a ``DistilWhisperModel`` built with its default arguments."""
    model = DistilWhisperModel()
    return model
def create_audio_file(data: np.ndarray, sample_rate: int, file_path: str):
    """Write ``data`` to ``file_path`` as a mono 16-bit PCM WAV file.

    Previously the raw array bytes were dumped with ``ndarray.tofile`` and
    ``sample_rate`` was ignored, so the resulting ``.wav`` file had no RIFF
    header and carried no sampling-rate information. This version emits a
    well-formed WAV container using only the standard library.

    Args:
        data: Audio samples. Floating-point input is treated as normalised
            amplitude and clipped to ``[-1.0, 1.0]``; integer input is
            clipped to the signed 16-bit range. Multi-dimensional input is
            flattened.
        sample_rate: Sampling rate in Hz stored in the WAV header.
        file_path: Destination path; any existing file is overwritten.

    Returns:
        ``file_path``, unchanged, for call-chaining convenience.
    """
    import wave  # local import so the helper does not rely on file-level imports

    samples = np.asarray(data).reshape(-1)
    if np.issubdtype(samples.dtype, np.floating):
        # Scale normalised floats to the int16 range expected by 16-bit PCM.
        pcm = (np.clip(samples, -1.0, 1.0) * 32767.0).astype("<i2")
    else:
        pcm = np.clip(samples, -32768, 32767).astype("<i2")

    with wave.open(file_path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        wav_file.setframerate(int(sample_rate))
        wav_file.writeframes(pcm.tobytes())
    return file_path
def test_initialization(distil_whisper_model):
    """Check that a default-constructed model exposes its core attributes.

    NOTE(review): a second ``test_initialization`` is defined later in this
    module. Python keeps only the last binding of a name, so pytest collects
    only that later test unless one of the two is renamed.
    """
    assert isinstance(distil_whisper_model, DistilWhisperModel)
    assert isinstance(distil_whisper_model.model, torch.nn.Module)
    # The processor bundles feature extraction and tokenisation; it is
    # presumably a plain Python object rather than a ``torch.nn.Module``
    # (confirm against the ``AutoProcessor`` docs), so only require that it
    # was loaded.
    assert distil_whisper_model.processor is not None
    assert distil_whisper_model.device in ["cpu", "cuda:0"]
def test_transcribe_audio_file(distil_whisper_model):
    """Transcribing an on-disk audio file should return a non-blank string.

    The temporary file is closed before it is written by path, and it is
    removed in a ``finally`` clause so a failing transcription does not leave
    the file behind.
    """
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    # Reserve a unique path, then release the handle before writing by name.
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = audio_file.name
    try:
        create_audio_file(test_data, 16000, audio_file_path)
        transcription = distil_whisper_model.transcribe(audio_file_path)
    finally:
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""
@pytest.mark.asyncio
async def test_async_transcribe_audio_file(distil_whisper_model):
    """Async transcription of an on-disk file should return a non-blank string.

    The temporary file is closed before it is written by path, and it is
    removed in a ``finally`` clause so a failing transcription does not leave
    the file behind.
    """
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    # Reserve a unique path, then release the handle before writing by name.
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = audio_file.name
    try:
        create_audio_file(test_data, 16000, audio_file_path)
        transcription = await distil_whisper_model.async_transcribe(
            audio_file_path
        )
    finally:
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""
def test_transcribe_audio_data(distil_whisper_model):
    """Transcribing raw in-memory audio bytes should give a non-blank string."""
    # 16000 random samples stand in for one second of audio.
    raw_audio = np.random.rand(16000).tobytes()
    result = distil_whisper_model.transcribe(raw_audio)
    assert isinstance(result, str)
    assert result.strip() != ""
@pytest.mark.asyncio
async def test_async_transcribe_audio_data(distil_whisper_model):
    """Async transcription of raw audio bytes should give a non-blank string."""
    # 16000 random samples stand in for one second of audio.
    raw_audio = np.random.rand(16000).tobytes()
    result = await distil_whisper_model.async_transcribe(raw_audio)
    assert isinstance(result, str)
    assert result.strip() != ""
def test_real_time_transcribe(distil_whisper_model, capsys):
    """Real-time transcription should announce itself and report chunks.

    The temporary file is closed before it is written by path, and it is
    removed in a ``finally`` clause so a failure inside
    ``real_time_transcribe`` does not leave the file behind.
    """
    test_data = np.random.rand(16000 * 5)  # Simulated audio data (5 seconds)
    # Reserve a unique path, then release the handle before writing by name.
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = audio_file.name
    try:
        create_audio_file(test_data, 16000, audio_file_path)
        distil_whisper_model.real_time_transcribe(
            audio_file_path, chunk_duration=1
        )
    finally:
        os.remove(audio_file_path)

    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out
    assert "Chunk" in captured.out
def test_real_time_transcribe_audio_file_not_found(
    distil_whisper_model, capsys
):
    """A missing input file should be reported on stdout."""
    missing_path = "non_existent_audio.wav"
    distil_whisper_model.real_time_transcribe(missing_path, chunk_duration=1)
    stdout_text = capsys.readouterr().out
    assert "The audio file was not found." in stdout_text
@pytest.fixture
def mock_async_retry():
    """Swap ``async_retry`` in the module under test for a no-retry stand-in.

    Fixes relative to the previous version:

    * The patch target is the module that this file imports ``async_retry``
      from, not the fixture name ``distil_whisper_model``, which is not an
      importable module.
    * ``new`` receives the decorator *factory* itself rather than the result
      of calling it, so ``async_retry(...)`` keeps returning a decorator.
    * The first keyword is ``max_retries``, matching how other tests in this
      file call ``async_retry``.

    NOTE(review): functions already decorated at import time keep the real
    retry wrapper; this patch only affects decorations applied while the
    fixture is active.
    """

    def _mock_async_retry(max_retries=3, exceptions=(Exception,), delay=1):
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                # Call straight through: no retry loop, no delay.
                return await func(*args, **kwargs)

            return wrapper

        return decorator

    with patch(
        "swarms.models.distilled_whisperx.async_retry",
        new=_mock_async_retry,
    ):
        yield
@pytest.mark.asyncio
async def test_async_retry_decorator_success():
    """A coroutine that succeeds returns its value through ``async_retry``."""

    @async_retry()
    async def mock_async_function():
        return "Success"

    outcome = await mock_async_function()
    assert outcome == "Success"
@pytest.mark.asyncio
async def test_async_retry_decorator_failure():
    """A coroutine that always raises still raises through ``async_retry``."""

    @async_retry()
    async def mock_async_function():
        raise Exception("Error")

    with pytest.raises(Exception, match="Error"):
        await mock_async_function()
@pytest.mark.asyncio
async def test_async_retry_decorator_multiple_attempts():
    """A coroutine failing once then succeeding yields the success value."""
    failures_so_far = 0

    async def mock_async_function():
        nonlocal failures_so_far
        if failures_so_far != 0:
            return "Success"
        # First invocation: record the failure, then raise.
        failures_so_far += 1
        raise Exception("Error")

    decorated_function = async_retry(max_retries=2)(mock_async_function)
    assert await decorated_function() == "Success"
def test_create_audio_file():
    """``create_audio_file`` should leave a file at the path it returns.

    The temporary file is closed before it is written by path, and it is
    removed in a ``finally`` clause so a failing assertion does not leave the
    file behind.
    """
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    sample_rate = 16000
    # Reserve a unique path, then release the handle before writing by name.
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        reserved_path = audio_file.name
    try:
        audio_file_path = create_audio_file(
            test_data, sample_rate, reserved_path
        )
        assert os.path.exists(audio_file_path)
    finally:
        os.remove(reserved_path)
# test_distilled_whisperx.py
# Fixtures for setting up model, processor, and audio files
@pytest.fixture(scope="module")
def model_id():
    """Return the model identifier string used by the module-scoped model."""
    identifier = "distil-whisper/distil-large-v2"
    return identifier
@pytest.fixture(scope="module")
def whisper_model(model_id):
    """Build one ``DistilWhisperModel`` from ``model_id`` per test module."""
    model = DistilWhisperModel(model_id)
    return model
@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
    """Create a short, real audio file and return its path as a string.

    The previous version returned the placeholder ``"path/to/valid_audio.mp3"``,
    which does not exist, and never used ``tmp_path_factory``. This version
    writes one second of 16 kHz mono silence as 16-bit PCM WAV into a
    session-scoped temporary directory, using only the standard library.

    NOTE(review): WAV is used instead of MP3 because the standard library has
    no MP3 encoder; confirm the model's loader accepts WAV input.

    Args:
        tmp_path_factory: pytest's session-scoped temporary directory factory.

    Returns:
        Path of the generated WAV file, as ``str``.
    """
    import wave  # local import: the file-level import block is left untouched

    sample_rate = 16000
    wav_path = tmp_path_factory.mktemp("audio") / "valid_audio.wav"
    with wave.open(str(wav_path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        wav_file.setframerate(sample_rate)
        # One second of silence: ``sample_rate`` zero-valued 16-bit frames.
        wav_file.writeframes(b"\x00\x00" * sample_rate)
    return str(wav_path)
@pytest.fixture(scope="session")
def invalid_audio_file_path():
    """Return the path string the failure-path tests pass to the model."""
    bogus_path = "path/to/invalid_audio.mp3"
    return bogus_path
@pytest.fixture(scope="session")
def audio_dict():
    """Return an audio dictionary with a random waveform and its sample rate.

    This is meant to mirror the dictionary input the model accepts.
    """
    sampling_rate = 16000
    waveform = torch.randn(1, sampling_rate)
    return {"array": waveform, "sampling_rate": sampling_rate}
# Test initialization
def test_initialization(whisper_model):
    """The wrapper must expose both a ``model`` and a ``processor``."""
    for attribute in ("model", "processor"):
        assert getattr(whisper_model, attribute) is not None
# Test successful transcription with file path
def test_transcribe_with_file_path(whisper_model, audio_file_path):
    """``transcribe`` given a file path should return a string."""
    result = whisper_model.transcribe(audio_file_path)
    assert isinstance(result, str)
# Test successful transcription with audio dict
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
    """``transcribe`` given an audio dictionary should return a string."""
    result = whisper_model.transcribe(audio_dict)
    assert isinstance(result, str)
# Test for file not found error
def test_file_not_found(whisper_model, invalid_audio_file_path):
    """``transcribe`` should raise when handed the invalid path."""
    bad_path = invalid_audio_file_path
    with pytest.raises(Exception):
        whisper_model.transcribe(bad_path)
# Asynchronous tests
@pytest.mark.asyncio
async def test_async_transcription_success(whisper_model, audio_file_path):
    """``async_transcribe`` given a file path should return a string."""
    result = await whisper_model.async_transcribe(audio_file_path)
    assert isinstance(result, str)
@pytest.mark.asyncio
async def test_async_transcription_failure(
    whisper_model, invalid_audio_file_path
):
    """``async_transcribe`` should raise when handed the invalid path."""
    bad_path = invalid_audio_file_path
    with pytest.raises(Exception):
        await whisper_model.async_transcribe(bad_path)
# Testing real-time transcription simulation
def test_real_time_transcription(whisper_model, audio_file_path, capsys):
    """Real-time transcription should print its start banner to stdout."""
    whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)
    stdout_text = capsys.readouterr().out
    assert "Starting real-time transcription..." in stdout_text
# Testing retry decorator for asynchronous function
@pytest.mark.asyncio
async def test_async_retry():
    """A coroutine that always raises ``ValueError`` still raises it."""

    async def failing_func():
        raise ValueError("Test")

    # Apply the decorator explicitly instead of with ``@`` syntax.
    retrying = async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
    wrapped = retrying(failing_func)

    with pytest.raises(ValueError):
        await wrapped()
# Mocking the actual model to avoid GPU/CPU intensive operations during test
@pytest.fixture
def mocked_model(monkeypatch):
    """Replace both ``from_pretrained`` loaders with mocks.

    ``from_pretrained`` is patched on the names imported by
    ``swarms.models.distilled_whisperx`` so that constructing the wrapper
    does not load real weights.

    The model loader was previously an ``AsyncMock``. Calling an ``AsyncMock``
    returns a coroutine, so the wrapper would have received a coroutine
    object instead of the configured ``return_value``. ``from_pretrained`` is
    presumably called without ``await`` (confirm in the wrapper), so a plain
    ``MagicMock`` is used for both loaders.

    Returns:
        ``(model_mock, processor_mock)``: the callables standing in for the
        two ``from_pretrained`` class methods. Configure their
        ``return_value`` to shape the fake model and processor.
    """
    model_mock = MagicMock(AutoModelForSpeechSeq2Seq)
    processor_mock = MagicMock(AutoProcessor)
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
        model_mock,
    )
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
        processor_mock,
    )
    return model_mock, processor_mock
@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(
    mocked_model, audio_file_path
):
    """With mocked loaders, ``async_transcribe`` returns the decoded text."""
    expected_text = "mocked transcription"
    fake_model_loader, fake_processor_loader = mocked_model

    # Configure what the mocked model and processor hand back.
    fake_model_loader.return_value.generate.return_value = torch.tensor([[0]])
    fake_processor_loader.return_value.batch_decode.return_value = [
        expected_text
    ]

    wrapper = DistilWhisperModel()
    result = await wrapper.async_transcribe(audio_file_path)
    assert result == expected_text