import pytest
import os
import torch
from swarms.models.speecht5 import SpeechT5


# Fixtures shared by the tests below
@pytest.fixture
def speecht5_model():
    """Provide a ``SpeechT5`` instance constructed with no arguments.

    No ``scope`` is given, so pytest's default (function) scope applies and
    each test that requests this fixture receives its own instance.
    """
    model = SpeechT5()
    return model


# Test cases for the SpeechT5 class


def test_speecht5_init(speecht5_model):
    """Check that construction populates every component attribute.

    The components are inspected on the instance only. Deriving the expected
    type from ``SpeechT5.<attr>.__class__`` is not a usable type check: the
    lookup happens on the class, so it either raises ``AttributeError`` (when
    the attribute is only assigned per instance) or returns the type of
    whatever class-level placeholder exists, not the type of the component.
    """
    for component in ("processor", "model", "vocoder"):
        # The attribute name is the assertion message so a failure says
        # which component was left unset.
        assert getattr(speecht5_model, component) is not None, component

    # NOTE(review): presumes the embeddings dataset is a subclass of torch's
    # Dataset -- confirm against the loader SpeechT5 actually uses.
    assert isinstance(
        speecht5_model.embeddings_dataset, torch.utils.data.Dataset
    )


def test_speecht5_call(speecht5_model):
    """Calling the model with a sentence should produce a torch tensor."""
    result = speecht5_model("Hello, how are you?")
    assert isinstance(result, torch.Tensor)


def test_speecht5_save_speech(speecht5_model):
    """Check that ``save_speech`` creates the requested file.

    The file is written inside a temporary directory rather than the current
    working directory, so it is removed even when the assertion fails and the
    test never collides with (or leaves behind) files in the checkout.
    """
    import tempfile

    text = "Hello, how are you?"
    speech = speecht5_model(text)

    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, "test_speech.wav")
        speecht5_model.save_speech(speech, filename)
        assert os.path.isfile(filename)


def test_speecht5_set_model(speecht5_model):
    """Switch to another model name, check where it is recorded, switch back."""
    original = speecht5_model.model_name
    replacement = "facebook/speecht5-tts"

    speecht5_model.set_model(replacement)

    # The new name is expected on the wrapper, the processor and the
    # model config alike.
    assert speecht5_model.model_name == replacement
    assert speecht5_model.processor.model_name == replacement
    configured_path = speecht5_model.model.config.model_name_or_path
    assert configured_path == replacement

    # Put the fixture back on the model it started with.
    speecht5_model.set_model(original)


def test_speecht5_set_vocoder(speecht5_model):
    """Swap the vocoder name, check where it is recorded, then swap back."""
    original = speecht5_model.vocoder_name
    replacement = "facebook/speecht5-hifigan"

    speecht5_model.set_vocoder(replacement)

    assert speecht5_model.vocoder_name == replacement
    configured_path = speecht5_model.vocoder.config.model_name_or_path
    assert configured_path == replacement

    # Return the fixture to the vocoder it started with.
    speecht5_model.set_vocoder(original)


def test_speecht5_set_embeddings_dataset(speecht5_model):
    """Point the model at another embeddings dataset, verify, then revert."""
    original = speecht5_model.dataset_name
    replacement = "Matthijs/cmu-arctic-xvectors-test"

    speecht5_model.set_embeddings_dataset(replacement)

    assert speecht5_model.dataset_name == replacement
    dataset = speecht5_model.embeddings_dataset
    assert isinstance(dataset, torch.utils.data.Dataset)

    # Revert to the dataset the fixture started with.
    speecht5_model.set_embeddings_dataset(original)


def test_speecht5_get_sampling_rate(speecht5_model):
    """The reported sampling rate is expected to be 16000."""
    assert speecht5_model.get_sampling_rate() == 16000


def test_speecht5_print_model_details(speecht5_model, capsys):
    """``print_model_details`` should write both name labels to stdout."""
    speecht5_model.print_model_details()
    printed = capsys.readouterr().out
    for label in ("Model Name: ", "Vocoder Name: "):
        assert label in printed


def test_speecht5_quick_synthesize(speecht5_model):
    """``quick_synthesize`` should return a list whose first item is a dict with "audio"."""
    result = speecht5_model.quick_synthesize("Hello, how are you?")
    assert isinstance(result, list)
    first = result[0]
    assert isinstance(first, dict)
    assert "audio" in first


def test_speecht5_change_dataset_split(speecht5_model):
    """After changing the split, the dataset should report the "test" split."""
    speecht5_model.change_dataset_split("test")
    assert speecht5_model.embeddings_dataset.split == "test"


def test_speecht5_load_custom_embedding(speecht5_model):
    """A custom x-vector should come back as the same values with a leading axis."""
    values = [0.1, 0.2, 0.3, 0.4, 0.5]
    loaded = speecht5_model.load_custom_embedding(values)

    # Reference tensor: the raw values with one extra leading dimension.
    expected = torch.tensor(values).unsqueeze(0)
    elementwise_match = torch.eq(loaded, expected)
    assert torch.all(elementwise_match)


def test_speecht5_with_different_speakers(speecht5_model):
    """Synthesis should return a tensor for each of several speaker ids."""
    sentence = "Hello, how are you?"
    for candidate_id in (7306, 5324, 1234):
        output = speecht5_model(sentence, speaker_id=candidate_id)
        assert isinstance(output, torch.Tensor)


def test_speecht5_save_speech_with_different_extensions(
    speecht5_model,
):
    """Check that ``save_speech`` creates a file for each tested extension.

    Files are written inside a temporary directory instead of the current
    working directory, so everything is cleaned up even if saving or an
    assertion fails part-way through the loop.
    """
    import tempfile

    text = "Hello, how are you?"
    speech = speecht5_model(text)

    with tempfile.TemporaryDirectory() as tmp_dir:
        for extension in (".wav", ".flac"):
            filename = os.path.join(tmp_dir, f"test_speech{extension}")
            speecht5_model.save_speech(speech, filename)
            # The extension is the assertion message so a failure says
            # which format was not written.
            assert os.path.isfile(filename), extension


def test_speecht5_invalid_speaker_id(speecht5_model):
    """A speaker id outside the dataset is expected to raise ``IndexError``."""
    # 9999 is chosen as an id expected to be absent from the dataset.
    missing_id = 9999
    with pytest.raises(IndexError):
        speecht5_model("Hello, how are you?", speaker_id=missing_id)


def test_speecht5_invalid_save_path(speecht5_model):
    """Saving under the test's invalid directory is expected to raise ``FileNotFoundError``."""
    audio = speecht5_model("Hello, how are you?")
    bad_target = "/invalid_directory/test_speech.wav"
    with pytest.raises(FileNotFoundError):
        speecht5_model.save_speech(audio, bad_target)


def test_speecht5_change_vocoder_model(speecht5_model):
    """Synthesis should still return a tensor after swapping the vocoder."""
    original = speecht5_model.vocoder_name

    speecht5_model.set_vocoder("facebook/speecht5-hifigan-ljspeech")
    output = speecht5_model("Hello, how are you?")
    assert isinstance(output, torch.Tensor)

    # Return the fixture to the vocoder it started with.
    speecht5_model.set_vocoder(original)