import pytest
import torch
from unittest.mock import Mock
from swarms.models.huggingface import HuggingFaceLLM


@pytest.fixture
def mock_torch():
    """Provide a fresh ``Mock`` that the ``hugging_face_llm`` fixture assigns
    to ``HuggingFaceLLM.torch``."""
    torch_double = Mock()
    return torch_double


@pytest.fixture
def mock_autotokenizer():
    """Provide a fresh ``Mock`` that the ``hugging_face_llm`` fixture assigns
    to ``HuggingFaceLLM.AutoTokenizer``."""
    tokenizer_double = Mock()
    return tokenizer_double


@pytest.fixture
def mock_automodelforcausallm():
    """Provide a fresh ``Mock`` that the ``hugging_face_llm`` fixture assigns
    to ``HuggingFaceLLM.AutoModelForCausalLM``."""
    model_double = Mock()
    return model_double


@pytest.fixture
def mock_bitsandbytesconfig():
    """Provide a fresh ``Mock`` that the ``hugging_face_llm`` fixture assigns
    to ``HuggingFaceLLM.BitsAndBytesConfig``."""
    config_double = Mock()
    return config_double


@pytest.fixture
def hugging_face_llm(
    mock_torch, mock_autotokenizer, mock_automodelforcausallm, mock_bitsandbytesconfig
):
    """Yield a ``HuggingFaceLLM(model_id="test")`` built with mocks installed.

    The four mock fixtures are assigned to the ``torch``, ``AutoTokenizer``,
    ``AutoModelForCausalLM`` and ``BitsAndBytesConfig`` attributes of the
    ``HuggingFaceLLM`` class before the instance is constructed. On teardown
    those class attributes are put back exactly as they were, so the patched
    state does not leak into tests that run afterwards.

    NOTE(review): this only affects lookups that go through the class
    attributes; whether ``HuggingFaceLLM`` resolves these names that way
    cannot be seen from this file -- confirm against the implementation.
    """
    replacements = {
        "torch": mock_torch,
        "AutoTokenizer": mock_autotokenizer,
        "AutoModelForCausalLM": mock_automodelforcausallm,
        "BitsAndBytesConfig": mock_bitsandbytesconfig,
    }

    # Sentinel distinguishing "attribute was absent" from "attribute was None".
    absent = object()
    # Read the class's own __dict__ (not getattr) so inherited attributes and
    # descriptors are restored verbatim rather than copied down or unwrapped.
    saved = {
        name: HuggingFaceLLM.__dict__.get(name, absent) for name in replacements
    }

    for name, double in replacements.items():
        setattr(HuggingFaceLLM, name, double)

    try:
        # Construction happens inside the try so a failing constructor still
        # triggers the restore below.
        yield HuggingFaceLLM(model_id="test")
    finally:
        for name, original in saved.items():
            if original is absent:
                delattr(HuggingFaceLLM, name)
            else:
                setattr(HuggingFaceLLM, name, original)


def test_init(hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm):
    """Check the state and loader calls left behind by the fixture's constructor.

    The instance must expose ``model_id == "test"``, and both mocked
    ``from_pretrained`` loaders must have been called exactly once: the
    tokenizer with ``"test"`` and the model with ``"test"`` plus
    ``quantization_config=None``.
    """
    assert hugging_face_llm.model_id == "test"

    tokenizer_loader = mock_autotokenizer.from_pretrained
    tokenizer_loader.assert_called_once_with("test")

    model_loader = mock_automodelforcausallm.from_pretrained
    model_loader.assert_called_once_with("test", quantization_config=None)


def test_init_with_quantize(
    hugging_face_llm,
    mock_autotokenizer,
    mock_automodelforcausallm,
    mock_bitsandbytesconfig,
):
    """Check that ``quantize=True`` builds a config and hands it to the model loader.

    ``hugging_face_llm`` is requested because that fixture installs the mocks
    on ``HuggingFaceLLM``. The fixture also constructs an instance of its own
    against the same mocks, so their call records are cleared before the
    instance under test is created; otherwise the ``assert_called_once_with``
    checks would count the fixture's calls as well.
    """
    quantization_config = {
        "load_in_4bit": True,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    }

    # Drop the calls recorded while the fixture built its own instance.
    # reset_mock() also clears child mocks such as ``from_pretrained``.
    for double in (
        mock_autotokenizer,
        mock_automodelforcausallm,
        mock_bitsandbytesconfig,
    ):
        double.reset_mock()

    # Configure the return value after the reset so it is the value in effect
    # when the constructor under test runs.
    mock_bitsandbytesconfig.return_value = quantization_config

    HuggingFaceLLM(model_id="test", quantize=True)

    mock_bitsandbytesconfig.assert_called_once_with(**quantization_config)
    mock_autotokenizer.from_pretrained.assert_called_once_with("test")
    mock_automodelforcausallm.from_pretrained.assert_called_once_with(
        "test", quantization_config=quantization_config
    )


def test_generate_text(hugging_face_llm):
    """Check that ``generate_text`` returns what the tokenizer's ``decode`` yields.

    ``tokenizer.encode`` and ``model.generate`` are stubbed to return
    one-element tensors and ``tokenizer.decode`` to return a fixed string;
    the call under test must hand that string back.
    """
    prompt = "test prompt"
    decoded = "test output"

    # Stub the encode -> generate -> decode chain with placeholder values.
    hugging_face_llm.tokenizer.encode.return_value = torch.tensor([0])
    hugging_face_llm.model.generate.return_value = torch.tensor([0])
    hugging_face_llm.tokenizer.decode.return_value = decoded

    result = hugging_face_llm.generate_text(prompt)

    assert result == decoded