import os
import tempfile
from functools import wraps
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

from swarms.models.distilled_whisperx import (
    DistilWhisperModel,
    async_retry,
)


@pytest.fixture(scope="module")
def distil_whisper_model():
    """Build the default ``DistilWhisperModel`` once per test module.

    Constructing the wrapper loads a full checkpoint, so the instance is
    shared by every test in this module rather than rebuilt per test
    (consistent with the module-scoped ``whisper_model`` fixture).

    Returns:
        A ``DistilWhisperModel`` created with its default arguments.
    """
    return DistilWhisperModel()


def create_audio_file(
    data: np.ndarray, sample_rate: int, file_path: str
):
    """Write ``data`` to ``file_path`` as a mono 16-bit PCM WAV file.

    Args:
        data: Audio samples, expected to lie in ``[-1.0, 1.0]``. Values
            outside that range are clipped. Multi-dimensional input is
            flattened in C order.
        sample_rate: Frame rate stored in the WAV header.
        file_path: Destination path; the file is created or overwritten.

    Returns:
        ``file_path``, so the call can be used inline.
    """
    import wave

    samples = np.asarray(data, dtype=np.float64).reshape(-1)
    # Scale floats to the int16 range; "<i2" pins little-endian byte order,
    # which is what the WAV format stores.
    pcm = (
        np.clip(samples, -1.0, 1.0) * np.iinfo(np.int16).max
    ).astype("<i2")
    with wave.open(file_path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        wav_file.setframerate(int(sample_rate))
        wav_file.writeframes(pcm.tobytes())
    return file_path


def test_default_initialization(distil_whisper_model):
    """A default-constructed wrapper exposes a model, a processor and a device.

    Named distinctly from the later ``test_initialization(whisper_model)`` so
    that definition does not shadow this one and both tests are collected.
    """
    assert isinstance(distil_whisper_model, DistilWhisperModel)
    assert isinstance(distil_whisper_model.model, torch.nn.Module)
    # A Hugging Face processor (AutoProcessor) bundles a feature extractor
    # and tokenizer; it is not a torch.nn.Module, so only check presence.
    assert distil_whisper_model.processor is not None
    assert distil_whisper_model.device in ["cpu", "cuda:0"]


def test_transcribe_audio_file(distil_whisper_model):
    """Transcribing an audio file on disk yields a non-empty string."""
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    # Close the temp handle first so the file can be reopened by path on
    # every platform; with delete=False the test owns the cleanup.
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = audio_file.name
    try:
        create_audio_file(test_data, 16000, audio_file_path)
        transcription = distil_whisper_model.transcribe(
            audio_file_path
        )
    finally:
        # Remove the file even if transcription raises.
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_file(distil_whisper_model):
    """Async transcription of an audio file on disk yields a non-empty string."""
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    # Close the temp handle first so the file can be reopened by path on
    # every platform; with delete=False the test owns the cleanup.
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = audio_file.name
    try:
        create_audio_file(test_data, 16000, audio_file_path)
        transcription = await distil_whisper_model.async_transcribe(
            audio_file_path
        )
    finally:
        # Remove the file even if transcription raises.
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_transcribe_audio_data(distil_whisper_model):
    """Raw audio bytes handed straight to ``transcribe`` produce text."""
    one_second = np.random.rand(16000)  # Simulated audio data (1 second)
    raw_bytes = one_second.tobytes()
    result = distil_whisper_model.transcribe(raw_bytes)

    assert isinstance(result, str)
    assert result.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_data(distil_whisper_model):
    """Raw audio bytes handed to ``async_transcribe`` produce text."""
    one_second = np.random.rand(16000)  # Simulated audio data (1 second)
    raw_bytes = one_second.tobytes()
    result = await distil_whisper_model.async_transcribe(raw_bytes)

    assert isinstance(result, str)
    assert result.strip() != ""


def test_real_time_transcribe(distil_whisper_model, capsys):
    """Chunked "real-time" transcription reports progress on stdout."""
    test_data = np.random.rand(
        16000 * 5
    )  # Simulated audio data (5 seconds)
    # Close the temp handle first so the file can be reopened by path on
    # every platform; with delete=False the test owns the cleanup.
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = audio_file.name
    try:
        create_audio_file(test_data, 16000, audio_file_path)
        distil_whisper_model.real_time_transcribe(
            audio_file_path, chunk_duration=1
        )
    finally:
        # Remove the file even if transcription raises.
        os.remove(audio_file_path)

    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out
    assert "Chunk" in captured.out


def test_real_time_transcribe_audio_file_not_found(
    distil_whisper_model, capsys
):
    """A missing input file is reported on stdout."""
    distil_whisper_model.real_time_transcribe(
        "non_existent_audio.wav", chunk_duration=1
    )

    stdout_text = capsys.readouterr().out
    assert "The audio file was not found." in stdout_text


@pytest.fixture
def mock_async_retry():
    """Replace ``async_retry`` with a pass-through decorator factory.

    The replacement accepts the same keyword arguments callers use
    (``max_retries``, ``exceptions``, ``delay``) but never retries. It simply
    awaits the wrapped coroutine once.

    NOTE(review): patching the module attribute only affects code that looks
    ``async_retry`` up at call time. Methods already decorated when the module
    was imported keep the real decorator.
    """

    def _mock_async_retry(
        max_retries=3, exceptions=(Exception,), delay=1
    ):
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

            return wrapper

        return decorator

    # Patch the real defining module. Install the factory itself (not its
    # result), because callers invoke async_retry(...)(func).
    with patch(
        "swarms.models.distilled_whisperx.async_retry",
        new=_mock_async_retry,
    ):
        yield


@pytest.mark.asyncio
async def test_async_retry_decorator_success():
    """A coroutine that succeeds has its result passed through the decorator."""

    async def mock_async_function():
        return "Success"

    retrying = async_retry()
    assert await retrying(mock_async_function)() == "Success"


@pytest.mark.asyncio
async def test_async_retry_decorator_failure():
    """A coroutine that always fails still surfaces its error through the decorator."""

    async def mock_async_function():
        raise Exception("Error")

    # delay=0 avoids waiting between attempts (assumes ``delay`` is the pause
    # between retries). This test only checks that the error propagates.
    decorated_function = async_retry(delay=0)(mock_async_function)
    with pytest.raises(Exception, match="Error"):
        await decorated_function()


@pytest.mark.asyncio
async def test_async_retry_decorator_multiple_attempts():
    """A coroutine that fails once and then succeeds is retried to success."""
    attempts = 0

    async def mock_async_function():
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise Exception("Error")
        return "Success"

    # delay=0 avoids waiting between attempts (assumes ``delay`` is the pause
    # between retries).
    decorated_function = async_retry(max_retries=2, delay=0)(
        mock_async_function
    )
    result = await decorated_function()
    assert result == "Success"
    # One failing call plus one successful retry.
    assert attempts == 2


def test_create_audio_file():
    """``create_audio_file`` writes a file at the requested path."""
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    sample_rate = 16000
    # Close the temp handle first so the helper can reopen it by path on
    # every platform; with delete=False the test owns the cleanup.
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        target_path = audio_file.name
    try:
        audio_file_path = create_audio_file(
            test_data, sample_rate, target_path
        )
        assert os.path.exists(audio_file_path)
    finally:
        os.remove(target_path)


# --- Tests carried over from test_distilled_whisperx.py ---


# Fixtures for setting up model, processor, and audio files
@pytest.fixture(scope="module")
def model_id():
    """Hugging Face checkpoint id used to build ``whisper_model``."""
    checkpoint = "distil-whisper/distil-large-v2"
    return checkpoint


@pytest.fixture(scope="module")
def whisper_model(model_id):
    """A ``DistilWhisperModel`` built once per module from ``model_id``."""
    model = DistilWhisperModel(model_id)
    return model


@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
    """Create a short, valid WAV file once per session and return its path.

    Tests that expect a successful transcription need a file that actually
    exists, so this writes one second of a quiet 440 Hz tone as 16 kHz mono
    16-bit PCM into a session-scoped temporary directory.

    Returns:
        The file path as a ``str``.
    """
    import wave

    sample_rate = 16000
    t = np.arange(sample_rate) / sample_rate
    tone = 0.1 * np.sin(2 * np.pi * 440 * t)
    # "<i2" = little-endian int16, the byte order WAV stores.
    pcm = (tone * np.iinfo(np.int16).max).astype("<i2")

    path = tmp_path_factory.mktemp("audio") / "valid_audio.wav"
    with wave.open(str(path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())
    return str(path)


@pytest.fixture(scope="session")
def invalid_audio_file_path():
    """A path that is expected not to exist, used by the failure tests."""
    missing = "path/to/invalid_audio.mp3"
    return missing


@pytest.fixture(scope="session")
def audio_dict():
    """In-memory audio payload: one second of Gaussian noise plus its rate.

    NOTE(review): meant to match the dict format the model accepts; the
    ``(1, 16000)`` tensor shape has not been checked against the wrapper.
    """
    rate = 16000
    return {"array": torch.randn(1, rate), "sampling_rate": rate}


# Test initialization
def test_initialization(whisper_model):
    """The explicitly configured wrapper loads both its model and processor."""
    for component in (whisper_model.model, whisper_model.processor):
        assert component is not None


# Test successful transcription with file path
def test_transcribe_with_file_path(whisper_model, audio_file_path):
    """Transcribing from a path on disk returns a string."""
    result = whisper_model.transcribe(audio_file_path)
    assert isinstance(result, str)


# Test successful transcription with audio dict
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
    """Transcribing from an in-memory audio dict returns a string."""
    result = whisper_model.transcribe(audio_dict)
    assert isinstance(result, str)


# Test for file not found error
def test_file_not_found(whisper_model, invalid_audio_file_path):
    """``transcribe`` raises when given a path that does not exist."""
    pytest.raises(
        Exception, whisper_model.transcribe, invalid_audio_file_path
    )


# Asynchronous tests
@pytest.mark.asyncio
async def test_async_transcription_success(
    whisper_model, audio_file_path
):
    """``async_transcribe`` on a file path resolves to a string."""
    result = await whisper_model.async_transcribe(audio_file_path)
    assert isinstance(result, str)


@pytest.mark.asyncio
async def test_async_transcription_failure(
    whisper_model, invalid_audio_file_path
):
    """``async_transcribe`` raises for a path that does not exist."""
    bad_path = invalid_audio_file_path
    with pytest.raises(Exception):
        await whisper_model.async_transcribe(bad_path)


# Testing real-time transcription simulation
def test_real_time_transcription(
    whisper_model, audio_file_path, capsys
):
    """Chunked transcription announces its start on stdout."""
    whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)

    stdout_text = capsys.readouterr().out
    assert "Starting real-time transcription..." in stdout_text


# Testing retry decorator for asynchronous function
@pytest.mark.asyncio
async def test_async_retry():
    """A coroutine that always raises ValueError still raises it when wrapped."""

    async def failing_func():
        raise ValueError("Test")

    retry_twice = async_retry(
        max_retries=2, exceptions=(ValueError,), delay=0
    )
    wrapped = retry_twice(failing_func)
    with pytest.raises(ValueError):
        await wrapped()


# Mocking the actual model to avoid GPU/CPU intensive operations during test
@pytest.fixture
def mocked_model(monkeypatch):
    """Patch the model and processor ``from_pretrained`` factories with mocks.

    This avoids downloading or running a real checkpoint. ``monkeypatch``
    restores the original attributes after the test.

    Returns:
        ``(model_mock, processor_mock)``. Configure ``.return_value`` on each
        to control the objects the wrapper receives from ``from_pretrained``.
    """
    # ``from_pretrained`` is a synchronous factory, so it must be replaced
    # by a plain MagicMock. An AsyncMock would return a coroutine object
    # instead of ``model_mock.return_value``.
    model_mock = MagicMock(AutoModelForSpeechSeq2Seq)
    processor_mock = MagicMock(AutoProcessor)
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
        model_mock,
    )
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
        processor_mock,
    )
    return model_mock, processor_mock


@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(
    mocked_model, audio_file_path
):
    """With mocked model/processor, the decoded mock text is returned."""
    model_mock, processor_mock = mocked_model

    # Canned outputs for the mocked generate / batch_decode calls.
    fake_model = model_mock.return_value
    fake_model.generate.return_value = torch.tensor([[0]])
    fake_processor = processor_mock.return_value
    fake_processor.batch_decode.return_value = ["mocked transcription"]

    wrapper = DistilWhisperModel()
    result = await wrapper.async_transcribe(audio_file_path)
    assert result == "mocked transcription"