# swarms/tests/models/test_distill_whisper.py


import os
import tempfile
from functools import wraps
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from swarms.models.distilled_whisperx import (
    DistilWhisperModel,
    async_retry,
)


@pytest.fixture
def distil_whisper_model():
    return DistilWhisperModel()


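# Note: the helper below writes the raw float64 samples with no WAV header, so
# the file-based tests assume DistilWhisperModel accepts headerless PCM data.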
def create_audio_file(
    data: np.ndarray, sample_rate: int, file_path: str
):
    data.tofile(file_path)
    return file_path


def test_initialization(distil_whisper_model):
    assert isinstance(distil_whisper_model, DistilWhisperModel)
    assert isinstance(distil_whisper_model.model, torch.nn.Module)
    assert distil_whisper_model.processor is not None
    assert distil_whisper_model.device in ["cpu", "cuda:0"]


def test_transcribe_audio_file(distil_whisper_model):
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = create_audio_file(
            test_data, 16000, audio_file.name
        )
        transcription = distil_whisper_model.transcribe(
            audio_file_path
        )
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_file(distil_whisper_model):
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = create_audio_file(
            test_data, 16000, audio_file.name
        )
        transcription = await distil_whisper_model.async_transcribe(
            audio_file_path
        )
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_transcribe_audio_data(distil_whisper_model):
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    transcription = distil_whisper_model.transcribe(
        test_data.tobytes()
    )
    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_data(distil_whisper_model):
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    transcription = await distil_whisper_model.async_transcribe(
        test_data.tobytes()
    )
    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_real_time_transcribe(distil_whisper_model, capsys):
    test_data = np.random.rand(
        16000 * 5
    )  # Simulated audio data (5 seconds)
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = create_audio_file(
            test_data, 16000, audio_file.name
        )
        distil_whisper_model.real_time_transcribe(
            audio_file_path, chunk_duration=1
        )
        os.remove(audio_file_path)

    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out
    assert "Chunk" in captured.out


def test_real_time_transcribe_audio_file_not_found(
    distil_whisper_model, capsys
):
    audio_file_path = "non_existent_audio.wav"
    distil_whisper_model.real_time_transcribe(
        audio_file_path, chunk_duration=1
    )
    captured = capsys.readouterr()
    assert "The audio file was not found." in captured.out


@pytest.fixture
def mock_async_retry():
    def _mock_async_retry(
        retries=3, exceptions=(Exception,), delay=1
    ):
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

            return wrapper

        return decorator

    with patch(
        "swarms.models.distilled_whisperx.async_retry",
        new=_mock_async_retry(),
    ):
        yield


@pytest.mark.asyncio
async def test_async_retry_decorator_success():
    async def mock_async_function():
        return "Success"

    decorated_function = async_retry()(mock_async_function)
    result = await decorated_function()
    assert result == "Success"


@pytest.mark.asyncio
async def test_async_retry_decorator_failure():
    async def mock_async_function():
        raise Exception("Error")

    decorated_function = async_retry()(mock_async_function)
    with pytest.raises(Exception, match="Error"):
        await decorated_function()


@pytest.mark.asyncio
async def test_async_retry_decorator_multiple_attempts():
    async def mock_async_function():
        if mock_async_function.attempts == 0:
            mock_async_function.attempts += 1
            raise Exception("Error")
        else:
            return "Success"

    mock_async_function.attempts = 0
    decorated_function = async_retry(max_retries=2)(
        mock_async_function
    )
    result = await decorated_function()
    assert result == "Success"


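# For reference, a minimal sketch of the decorator shape the tests above assume
# for async_retry (hypothetical; the real implementation is imported from
# swarms.models.distilled_whisperx and may differ in details):
def _sketch_async_retry(max_retries=3, exceptions=(Exception,), delay=1):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            import asyncio

            for attempt in range(max_retries):
                try:
                    return await func(*args, **kwargs)
                except exceptions:
                    # Out of retries: surface the last error to the caller.
                    if attempt == max_retries - 1:
                        raise
                    await asyncio.sleep(delay)

        return wrapper

    return decorator

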
def test_create_audio_file():
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    sample_rate = 16000
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = create_audio_file(
            test_data, sample_rate, audio_file.name
        )

    assert os.path.exists(audio_file_path)
    os.remove(audio_file_path)


# test_distilled_whisperx.py

# Fixtures for setting up model, processor, and audio files
@pytest.fixture(scope="module")
def model_id():
    return "distil-whisper/distil-large-v2"


@pytest.fixture(scope="module")
def whisper_model(model_id):
return DistilWhisperModel(model_id)
@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
# You would create a small temporary MP3 file here for testing
# or use a public domain MP3 file's path
return "path/to/valid_audio.mp3"
@pytest.fixture(scope="session")
def invalid_audio_file_path():
return "path/to/invalid_audio.mp3"
@pytest.fixture(scope="session")
def audio_dict():
# This should represent a valid audio dictionary as expected by the model
return {"array": torch.randn(1, 16000), "sampling_rate": 16000}
# Test initialization (renamed to avoid shadowing test_initialization above)
def test_whisper_model_initialization(whisper_model):
    assert whisper_model.model is not None
    assert whisper_model.processor is not None


# Test successful transcription with file path
def test_transcribe_with_file_path(whisper_model, audio_file_path):
    transcription = whisper_model.transcribe(audio_file_path)
    assert isinstance(transcription, str)


# Test successful transcription with audio dict
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
    transcription = whisper_model.transcribe(audio_dict)
    assert isinstance(transcription, str)


# Test for file not found error
def test_file_not_found(whisper_model, invalid_audio_file_path):
    with pytest.raises(Exception):
        whisper_model.transcribe(invalid_audio_file_path)


# Asynchronous tests
@pytest.mark.asyncio
async def test_async_transcription_success(
    whisper_model, audio_file_path
):
    transcription = await whisper_model.async_transcribe(
        audio_file_path
    )
    assert isinstance(transcription, str)


@pytest.mark.asyncio
async def test_async_transcription_failure(
    whisper_model, invalid_audio_file_path
):
    with pytest.raises(Exception):
        await whisper_model.async_transcribe(invalid_audio_file_path)


# Testing real-time transcription simulation
def test_real_time_transcription(
    whisper_model, audio_file_path, capsys
):
    whisper_model.real_time_transcribe(
        audio_file_path, chunk_duration=1
    )
    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out


# Testing retry decorator for asynchronous function
@pytest.mark.asyncio
async def test_async_retry():
    @async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
    async def failing_func():
        raise ValueError("Test")

    with pytest.raises(ValueError):
        await failing_func()


# Mocking the actual model to avoid GPU/CPU intensive operations during tests
@pytest.fixture
def mocked_model(monkeypatch):
    model_mock = AsyncMock(AutoModelForSpeechSeq2Seq)
    processor_mock = MagicMock(AutoProcessor)
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
        model_mock,
    )
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
        processor_mock,
    )
    return model_mock, processor_mock


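# With from_pretrained patched by mocked_model, constructing DistilWhisperModel
# in the test below is expected not to load real weights; the mocks stand in
# for the Hugging Face model and processor.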
@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(
    mocked_model, audio_file_path
):
    model_mock, processor_mock = mocked_model
    # Set up what the mock should return when it's called
    model_mock.return_value.generate.return_value = torch.tensor(
        [[0]]
    )
    processor_mock.return_value.batch_decode.return_value = [
        "mocked transcription"
    ]
    model_wrapper = DistilWhisperModel()
    transcription = await model_wrapper.async_transcribe(
        audio_file_path
    )
    assert transcription == "mocked transcription"