You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
69 lines
2.2 KiB
69 lines
2.2 KiB
1 year ago
|
import pytest
|
||
|
import torch
|
||
1 year ago
|
from unittest.mock import Mock
|
||
1 year ago
|
from swarms.models.huggingface import HuggingFaceLLM
|
||
1 year ago
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def mock_torch():
|
||
|
return Mock()
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def mock_autotokenizer():
|
||
|
return Mock()
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def mock_automodelforcausallm():
|
||
|
return Mock()
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def mock_bitsandbytesconfig():
|
||
|
return Mock()
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def hugging_face_llm(mock_torch, mock_autotokenizer, mock_automodelforcausallm, mock_bitsandbytesconfig):
|
||
|
HuggingFaceLLM.torch = mock_torch
|
||
|
HuggingFaceLLM.AutoTokenizer = mock_autotokenizer
|
||
|
HuggingFaceLLM.AutoModelForCausalLM = mock_automodelforcausallm
|
||
|
HuggingFaceLLM.BitsAndBytesConfig = mock_bitsandbytesconfig
|
||
|
|
||
|
return HuggingFaceLLM(model_id='test')
|
||
|
|
||
|
|
||
|
def test_init(hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm):
|
||
|
assert hugging_face_llm.model_id == 'test'
|
||
|
mock_autotokenizer.from_pretrained.assert_called_once_with('test')
|
||
|
mock_automodelforcausallm.from_pretrained.assert_called_once_with('test', quantization_config=None)
|
||
|
|
||
|
|
||
|
def test_init_with_quantize(hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm, mock_bitsandbytesconfig):
|
||
|
quantization_config = {
|
||
|
'load_in_4bit': True,
|
||
|
'bnb_4bit_use_double_quant': True,
|
||
|
'bnb_4bit_quant_type': "nf4",
|
||
|
'bnb_4bit_compute_dtype': torch.bfloat16
|
||
|
}
|
||
|
mock_bitsandbytesconfig.return_value = quantization_config
|
||
|
|
||
1 year ago
|
HuggingFaceLLM(model_id='test', quantize=True)
|
||
1 year ago
|
|
||
|
mock_bitsandbytesconfig.assert_called_once_with(**quantization_config)
|
||
|
mock_autotokenizer.from_pretrained.assert_called_once_with('test')
|
||
|
mock_automodelforcausallm.from_pretrained.assert_called_once_with('test', quantization_config=quantization_config)
|
||
|
|
||
|
|
||
|
def test_generate_text(hugging_face_llm):
|
||
|
prompt_text = 'test prompt'
|
||
|
expected_output = 'test output'
|
||
|
hugging_face_llm.tokenizer.encode.return_value = torch.tensor([0]) # Mock tensor
|
||
|
hugging_face_llm.model.generate.return_value = torch.tensor([0]) # Mock tensor
|
||
|
hugging_face_llm.tokenizer.decode.return_value = expected_output
|
||
|
|
||
|
output = hugging_face_llm.generate_text(prompt_text)
|
||
|
|
||
|
assert output == expected_output
|