You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/huggingface.py

59 lines
1.9 KiB

1 year ago
import pytest
import torch
from unittest.mock import Mock, patch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from swarms.models.huggingface import HuggingfaceLLM
@pytest.fixture
def huggingface_llm():
    """Provide a ``HuggingfaceLLM`` constructed with the "gpt2-small" model id.

    The fixture is function-scoped (pytest's default), so each test that
    requests it receives its own instance.
    """
    return HuggingfaceLLM(model_id="gpt2-small")
def test_initialization(huggingface_llm):
    """Check the attribute values of a freshly constructed ``HuggingfaceLLM``.

    Args:
        huggingface_llm: Instance supplied by the ``huggingface_llm`` fixture.
    """
    assert huggingface_llm.model_id == "gpt2-small"
    assert huggingface_llm.device in ("cpu", "cuda")
    assert huggingface_llm.max_length == 20
    # Identity checks rather than ``== False``: a merely falsy value such as
    # 0 would compare equal to False and let a wrong default slip through.
    assert huggingface_llm.verbose is False
    assert huggingface_llm.distributed is False
    assert huggingface_llm.decoding is False
    # This test expects no model or tokenizer on a fresh instance.
    assert huggingface_llm.model is None
    assert huggingface_llm.tokenizer is None
def test_load_model(huggingface_llm):
    """Check that ``load_model`` populates the model and tokenizer attributes.

    Args:
        huggingface_llm: Instance supplied by the ``huggingface_llm`` fixture.
    """
    # Local import keeps this test self-contained; ``transformers`` is already
    # a dependency of this module.
    from transformers import PreTrainedModel, PreTrainedTokenizerBase

    huggingface_llm.load_model()

    # ``AutoModelForCausalLM`` and ``AutoTokenizer`` are factory classes: their
    # ``from_pretrained`` returns an instance of a concrete architecture or
    # tokenizer class, never of the Auto* class itself. An ``isinstance`` check
    # against the Auto* classes therefore can never succeed, so compare
    # against the common base classes instead.
    # NOTE(review): assumes ``load_model`` stores the objects unwrapped
    # (e.g. not inside a distributed wrapper) - confirm against the class.
    assert isinstance(huggingface_llm.model, PreTrainedModel)
    assert isinstance(huggingface_llm.tokenizer, PreTrainedTokenizerBase)
def test_run(huggingface_llm):
    """Call ``run`` with a short prompt and expect a non-empty string back.

    Args:
        huggingface_llm: Instance supplied by the ``huggingface_llm`` fixture.
    """
    output = huggingface_llm.run("Once upon a time")

    assert isinstance(output, str)
    assert len(output) > 0
def test_call_method(huggingface_llm):
    """Invoke the instance as a callable and expect a non-empty string back.

    Args:
        huggingface_llm: Instance supplied by the ``huggingface_llm`` fixture.
    """
    prompt = "Once upon a time"
    result = huggingface_llm(prompt)

    assert isinstance(result, str)
    assert len(result) > 0
def test_load_model_failure():
    """Check that a failure while fetching the model surfaces as an exception.

    ``AutoModelForCausalLM.from_pretrained`` is replaced with a mock that
    raises, so no real model is downloaded.
    """
    # Patch the attribute on the class object itself. The previous string
    # target pointed at a placeholder module ("your_module") that cannot be
    # imported, so the patch failed before the code under test ever ran.
    # Patching the class works however the module under test imported it.
    with patch.object(
        AutoModelForCausalLM,
        "from_pretrained",
        side_effect=Exception("Model load failed"),
    ):
        # Construction stays inside the ``raises`` block so the test passes
        # whether the failure surfaces in the constructor or in load_model().
        with pytest.raises(Exception):
            llm = HuggingfaceLLM(model_id="gpt2-small")
            llm.load_model()