You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
2.2 KiB
87 lines
2.2 KiB
1 year ago
|
import pytest
|
||
|
import torch
|
||
1 year ago
|
from unittest.mock import Mock
|
||
1 year ago
|
from swarms.models.huggingface import HuggingFaceLLM
|
||
1 year ago
|
|
||
|
|
||
|
@pytest.fixture
def mock_torch():
    """Provide a fresh stand-in object that ``hugging_face_llm`` installs as ``HuggingFaceLLM.torch``."""
    fake_torch = Mock()
    return fake_torch
|
||
|
|
||
|
|
||
|
@pytest.fixture
def mock_autotokenizer():
    """Provide a fresh stand-in that ``hugging_face_llm`` installs as ``HuggingFaceLLM.AutoTokenizer``."""
    fake_tokenizer_cls = Mock()
    return fake_tokenizer_cls
|
||
|
|
||
|
|
||
|
@pytest.fixture
def mock_automodelforcausallm():
    """Provide a fresh stand-in that ``hugging_face_llm`` installs as ``HuggingFaceLLM.AutoModelForCausalLM``."""
    fake_model_cls = Mock()
    return fake_model_cls
|
||
|
|
||
|
|
||
|
@pytest.fixture
def mock_bitsandbytesconfig():
    """Provide a fresh stand-in that ``hugging_face_llm`` installs as ``HuggingFaceLLM.BitsAndBytesConfig``."""
    fake_bnb_config_cls = Mock()
    return fake_bnb_config_cls
|
||
|
|
||
|
|
||
|
@pytest.fixture
def hugging_face_llm(
    mock_torch,
    mock_autotokenizer,
    mock_automodelforcausallm,
    mock_bitsandbytesconfig,
    monkeypatch,
):
    """Build ``HuggingFaceLLM(model_id="test")`` with its collaborators mocked.

    The four mock fixtures are installed as attributes of the
    ``HuggingFaceLLM`` class for the duration of one test. ``monkeypatch``
    restores (or removes) them at teardown. Plain class-attribute assignment
    would leave the mocks on the class for the rest of the session and leak
    into unrelated tests.

    Returns:
        The ``HuggingFaceLLM`` instance constructed while the mocks are active.
    """
    # NOTE(review): this only takes effect if HuggingFaceLLM resolves these
    # names through the class/instance. If swarms.models.huggingface uses its
    # own module-level imports instead, the patch target should be that module
    # (e.g. "swarms.models.huggingface.AutoTokenizer") — confirm.
    replacements = (
        ("torch", mock_torch),
        ("AutoTokenizer", mock_autotokenizer),
        ("AutoModelForCausalLM", mock_automodelforcausallm),
        ("BitsAndBytesConfig", mock_bitsandbytesconfig),
    )
    for attr_name, replacement in replacements:
        # raising=False: the class need not define the attribute beforehand;
        # monkeypatch then deletes it again on undo.
        monkeypatch.setattr(HuggingFaceLLM, attr_name, replacement, raising=False)

    return HuggingFaceLLM(model_id="test")
|
||
1 year ago
|
|
||
|
|
||
|
def test_init(hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm):
    """A default-constructed model loads tokenizer and model without a quantization config."""
    expected_id = "test"

    assert hugging_face_llm.model_id == expected_id

    tokenizer_loader = mock_autotokenizer.from_pretrained
    tokenizer_loader.assert_called_once_with(expected_id)

    model_loader = mock_automodelforcausallm.from_pretrained
    model_loader.assert_called_once_with(expected_id, quantization_config=None)
|
||
|
|
||
|
|
||
|
def test_init_with_quantize(
    hugging_face_llm,
    mock_autotokenizer,
    mock_automodelforcausallm,
    mock_bitsandbytesconfig,
):
    """``quantize=True`` builds a 4-bit BitsAndBytesConfig and passes it to the model loader."""
    # ``hugging_face_llm`` is requested only for its side effect of installing
    # the mocks. Building it already invoked ``from_pretrained`` on these same
    # mock objects, so clear the recorded calls first. Otherwise the
    # ``assert_called_once_with`` checks below would see two calls.
    # reset_mock() keeps configured return values by default.
    for mocked in (
        mock_autotokenizer,
        mock_automodelforcausallm,
        mock_bitsandbytesconfig,
    ):
        mocked.reset_mock()

    quantization_config = {
        "load_in_4bit": True,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    }
    # The mocked config class hands back the dict itself, so the model loader
    # should receive exactly this object as its quantization_config.
    mock_bitsandbytesconfig.return_value = quantization_config

    HuggingFaceLLM(model_id="test", quantize=True)

    mock_bitsandbytesconfig.assert_called_once_with(**quantization_config)
    mock_autotokenizer.from_pretrained.assert_called_once_with("test")
    mock_automodelforcausallm.from_pretrained.assert_called_once_with(
        "test", quantization_config=quantization_config
    )
|
||
1 year ago
|
|
||
|
|
||
|
def test_generate_text(hugging_face_llm):
    """``generate_text`` returns the string produced by the tokenizer's ``decode``."""
    prompt_text = "test prompt"
    expected_output = "test output"

    tokenizer = hugging_face_llm.tokenizer
    # Single-element placeholder tensors stand in for encoder and model output.
    tokenizer.encode.return_value = torch.tensor([0])
    hugging_face_llm.model.generate.return_value = torch.tensor([0])
    tokenizer.decode.return_value = expected_output

    assert hugging_face_llm.generate_text(prompt_text) == expected_output
|