from unittest.mock import Mock, patch

import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from swarms.models.huggingface import HuggingfaceLLM


@pytest.fixture
def huggingface_llm():
    # Create an instance of HuggingfaceLLM for testing.
    # "gpt2" is the smallest publicly available GPT-2 checkpoint on the Hub.
    model_id = "gpt2"
    return HuggingfaceLLM(model_id=model_id)


def test_initialization(huggingface_llm):
    # Test the default attribute values set when HuggingfaceLLM is constructed.
    assert huggingface_llm.model_id == "gpt2"
    assert huggingface_llm.device in ["cpu", "cuda"]
    assert huggingface_llm.max_length == 20
    assert huggingface_llm.verbose is False
    assert huggingface_llm.distributed is False
    assert huggingface_llm.decoding is False
    # The model and tokenizer are loaded lazily, so both start out as None.
    assert huggingface_llm.model is None
    assert huggingface_llm.tokenizer is None


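# The defaults checked above suggest that max_length and verbose are plain
# constructor arguments of HuggingfaceLLM. The test below is a sketch built on
# that assumption; if the class configures these attributes differently, the
# keyword arguments will need to be adjusted.
def test_initialization_with_custom_arguments():
    llm = HuggingfaceLLM(model_id="gpt2", max_length=50, verbose=True)
    assert llm.model_id == "gpt2"
    assert llm.max_length == 50
    assert llm.verbose is True

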
def test_load_model(huggingface_llm):
    # Test loading the model and tokenizer.
    huggingface_llm.load_model()
    # isinstance() checks against the Auto* factory classes always fail, because
    # from_pretrained() returns concrete classes (e.g. GPT2LMHeadModel), so the
    # loaded objects are validated against torch.nn.Module and for presence.
    assert isinstance(huggingface_llm.model, torch.nn.Module)
    assert huggingface_llm.tokenizer is not None


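# Sketch of a negative test for load_model(): a repository id that does not
# exist on the Hugging Face Hub should make from_pretrained fail. Like
# test_load_model_failure below, it assumes load_model() propagates the
# underlying error rather than swallowing it; the model id is deliberately
# fictitious.
def test_load_model_invalid_model_id():
    llm = HuggingfaceLLM(model_id="this-model-id-does-not-exist")
    with pytest.raises(Exception):
        llm.load_model()

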
def test_run(huggingface_llm):
    # Test the run method of HuggingfaceLLM.
    prompt_text = "Once upon a time"
    generated_text = huggingface_llm.run(prompt_text)
    assert isinstance(generated_text, str)
    assert len(generated_text) > 0


def test_call_method(huggingface_llm):
    # Test the __call__ method of HuggingfaceLLM.
    prompt_text = "Once upon a time"
    generated_text = huggingface_llm(prompt_text)
    assert isinstance(generated_text, str)
    assert len(generated_text) > 0


def test_load_model_failure():
    # Test that a failure inside from_pretrained propagates out of load_model().
    # The patch target points at the symbol as it is looked up in the module
    # under test (swarms.models.huggingface), not at transformers itself.
    with patch(
        "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained",
        side_effect=Exception("Model load failed"),
    ):
        with pytest.raises(Exception):
            huggingface_llm = HuggingfaceLLM(model_id="gpt2")
            huggingface_llm.load_model()


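# Companion sketch to test_load_model_failure for the tokenizer path. It assumes
# that swarms.models.huggingface imports AutoTokenizer at module level,
# mirroring the AutoModelForCausalLM patch target above; if the module resolves
# the tokenizer differently, the patch target must be adjusted.
def test_load_tokenizer_failure():
    with patch(
        "swarms.models.huggingface.AutoTokenizer.from_pretrained",
        side_effect=Exception("Tokenizer load failed"),
    ):
        with pytest.raises(Exception):
            huggingface_llm = HuggingfaceLLM(model_id="gpt2")
            huggingface_llm.load_model()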