import pytest
import torch
from transformers import PreTrainedTokenizerBase
from swarms.models.yi_200k import Yi34B200k


# Fixture providing a shared Yi34B200k instance; module scope avoids
# reloading the 34B-parameter model for every single test.
@pytest.fixture(scope="module")
def yi34b_model():
    return Yi34B200k()


# Test cases for the Yi34B200k class
def test_yi34b_init(yi34b_model):
    assert isinstance(yi34b_model.model, torch.nn.Module)
    # AutoTokenizer is only a factory class, so instances are never of that
    # type; check against the shared tokenizer base class instead.
    assert isinstance(yi34b_model.tokenizer, PreTrainedTokenizerBase)


def test_yi34b_generate_text(yi34b_model):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt)
    assert isinstance(generated_text, str)
    assert len(generated_text) > 0


@pytest.mark.parametrize("max_length", [64, 128, 256, 512])
def test_yi34b_generate_text_with_length(yi34b_model, max_length):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, max_length=max_length)
    # max_length caps the number of generated tokens, not characters, so
    # measure the output with the model's tokenizer before comparing.
    assert len(yi34b_model.tokenizer.encode(generated_text)) <= max_length


# Values stay inside the 0.01-1.0 range documented by the
# invalid-temperature test below; the previous 1.5 case fell outside it.
@pytest.mark.parametrize("temperature", [0.5, 0.75, 1.0])
def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, temperature=temperature)
    assert isinstance(generated_text, str)


def test_yi34b_generate_text_with_invalid_prompt(yi34b_model):
    prompt = None  # Invalid prompt
    with pytest.raises(
        ValueError, match="Input prompt must be a non-empty string"
    ):
        yi34b_model(prompt)
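

# Hedged companion to the None-prompt test: the error message above implies
# a non-empty string is required, so an empty string is assumed to be
# rejected through the same validation path.
def test_yi34b_generate_text_with_empty_prompt(yi34b_model):
    prompt = ""  # Empty prompt, assumed invalid like None
    with pytest.raises(
        ValueError, match="Input prompt must be a non-empty string"
    ):
        yi34b_model(prompt)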


def test_yi34b_generate_text_with_invalid_max_length(yi34b_model):
    prompt = "There's a place where time stands still."
    max_length = -1  # Invalid max_length
    with pytest.raises(
        ValueError, match="max_length must be a positive integer"
    ):
        yi34b_model(prompt, max_length=max_length)
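

# Hedged addition: the "positive integer" wording above suggests zero is
# rejected as well, which this sketch assumes.
def test_yi34b_generate_text_with_zero_max_length(yi34b_model):
    prompt = "There's a place where time stands still."
    with pytest.raises(
        ValueError, match="max_length must be a positive integer"
    ):
        yi34b_model(prompt, max_length=0)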


def test_yi34b_generate_text_with_invalid_temperature(yi34b_model):
    prompt = "There's a place where time stands still."
    temperature = 2.0  # Invalid temperature
    with pytest.raises(
        ValueError, match="temperature must be between 0.01 and 1.0"
    ):
        yi34b_model(prompt, temperature=temperature)
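

# Hedged addition: assumes the 0.01 lower bound in the message above is
# enforced, so a temperature of 0.0 should also raise.
def test_yi34b_generate_text_with_too_low_temperature(yi34b_model):
    prompt = "There's a place where time stands still."
    temperature = 0.0  # Below the documented 0.01 lower bound
    with pytest.raises(
        ValueError, match="temperature must be between 0.01 and 1.0"
    ):
        yi34b_model(prompt, temperature=temperature)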


@pytest.mark.parametrize("top_k", [20, 30, 50])
def test_yi34b_generate_text_with_top_k(yi34b_model, top_k):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, top_k=top_k)
    assert isinstance(generated_text, str)


@pytest.mark.parametrize("top_p", [0.5, 0.7, 0.9])
def test_yi34b_generate_text_with_top_p(yi34b_model, top_p):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, top_p=top_p)
    assert isinstance(generated_text, str)


def test_yi34b_generate_text_with_invalid_top_k(yi34b_model):
    prompt = "There's a place where time stands still."
    top_k = -1  # Invalid top_k
    with pytest.raises(
        ValueError, match="top_k must be a non-negative integer"
    ):
        yi34b_model(prompt, top_k=top_k)


def test_yi34b_generate_text_with_invalid_top_p(yi34b_model):
    prompt = "There's a place where time stands still."
    top_p = 1.5  # Invalid top_p
    with pytest.raises(ValueError, match="top_p must be between 0.0 and 1.0"):
        yi34b_model(prompt, top_p=top_p)
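

# Hedged addition: assumes the 0.0-1.0 range is enforced at both ends, so a
# negative top_p should be rejected like a value above 1.0.
def test_yi34b_generate_text_with_negative_top_p(yi34b_model):
    prompt = "There's a place where time stands still."
    top_p = -0.1  # Below the valid 0.0-1.0 range
    with pytest.raises(ValueError, match="top_p must be between 0.0 and 1.0"):
        yi34b_model(prompt, top_p=top_p)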


# The "repitition_penalty" spelling mirrors the keyword argument the model
# wrapper accepts, so it is left as-is here.
@pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5])
def test_yi34b_generate_text_with_repitition_penalty(
    yi34b_model, repitition_penalty
):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(prompt, repitition_penalty=repitition_penalty)
    assert isinstance(generated_text, str)


def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model):
    prompt = "There's a place where time stands still."
    repitition_penalty = 0.0  # Invalid repitition_penalty
    with pytest.raises(
        ValueError, match="repitition_penalty must be a positive float"
    ):
        yi34b_model(prompt, repitition_penalty=repitition_penalty)


def test_yi34b_generate_text_with_invalid_device(yi34b_model):
    prompt = "There's a place where time stands still."
    device_map = "invalid_device"  # Invalid device_map
    with pytest.raises(ValueError, match="Invalid device_map"):
        yi34b_model(prompt, device_map=device_map)
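

# Sketch of a combined-sampling test; assumes Yi34B200k accepts these keyword
# arguments together, since each one is accepted individually above.
def test_yi34b_generate_text_with_combined_sampling(yi34b_model):
    prompt = "There's a place where time stands still."
    generated_text = yi34b_model(
        prompt,
        max_length=128,
        temperature=0.7,
        top_k=40,
        top_p=0.9,
    )
    assert isinstance(generated_text, str)
    assert len(generated_text) > 0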