|
|
@ -11,14 +11,14 @@ from swarms.models.huggingface import (
|
|
|
|
# Fixture for the class instance
|
|
|
|
# Fixture for the class instance
|
|
|
|
@pytest.fixture
|
|
|
|
@pytest.fixture
|
|
|
|
def llm_instance():
|
|
|
|
def llm_instance():
|
|
|
|
model_id = "gpt2-small"
|
|
|
|
model_id = "NousResearch/Nous-Hermes-2-Vision-Alpha"
|
|
|
|
instance = HuggingfaceLLM(model_id=model_id)
|
|
|
|
instance = HuggingfaceLLM(model_id=model_id)
|
|
|
|
return instance
|
|
|
|
return instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test for instantiation and attributes
|
|
|
|
# Test for instantiation and attributes
|
|
|
|
def test_llm_initialization(llm_instance):
|
|
|
|
def test_llm_initialization(llm_instance):
|
|
|
|
assert llm_instance.model_id == "gpt2-small"
|
|
|
|
assert llm_instance.model_id == "NousResearch/Nous-Hermes-2-Vision-Alpha"
|
|
|
|
assert llm_instance.max_length == 500
|
|
|
|
assert llm_instance.max_length == 500
|
|
|
|
# ... add more assertions for all default attributes
|
|
|
|
# ... add more assertions for all default attributes
|
|
|
|
|
|
|
|
|
|
|
@ -75,9 +75,9 @@ def test_llm_memory_consumption(llm_instance):
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"model_id, max_length",
|
|
|
|
"model_id, max_length",
|
|
|
|
[
|
|
|
|
[
|
|
|
|
("gpt2-small", 100),
|
|
|
|
("NousResearch/Nous-Hermes-2-Vision-Alpha", 100),
|
|
|
|
("gpt2-medium", 200),
|
|
|
|
("microsoft/Orca-2-13b", 200),
|
|
|
|
("gpt2-large", None), # None to check default behavior
|
|
|
|
("berkeley-nest/Starling-LM-7B-alpha", None), # None to check default behavior
|
|
|
|
],
|
|
|
|
],
|
|
|
|
)
|
|
|
|
)
|
|
|
|
def test_llm_initialization_params(model_id, max_length):
|
|
|
|
def test_llm_initialization_params(model_id, max_length):
|
|
|
@ -99,12 +99,6 @@ def test_llm_set_invalid_device(llm_instance):
|
|
|
|
llm_instance.set_device("quantum_processor")
|
|
|
|
llm_instance.set_device("quantum_processor")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test for model download progress bar
|
|
|
|
|
|
|
|
@patch("swarms.models.huggingface.HuggingfaceLLM._download_model")
|
|
|
|
|
|
|
|
def test_llm_model_download_progress(mock_download, llm_instance):
|
|
|
|
|
|
|
|
llm_instance.download_model_with_progress()
|
|
|
|
|
|
|
|
mock_download.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Mocking external API call to test run method without network
|
|
|
|
# Mocking external API call to test run method without network
|
|
|
|
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
|
|
|
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
|
|
@ -209,7 +203,6 @@ def test_llm_force_gpu_when_unavailable(
|
|
|
|
|
|
|
|
|
|
|
|
# Test for proper cleanup after model use (releasing resources)
|
|
|
|
# Test for proper cleanup after model use (releasing resources)
|
|
|
|
@patch("swarms.models.huggingface.HuggingfaceLLM._model")
|
|
|
|
@patch("swarms.models.huggingface.HuggingfaceLLM._model")
|
|
|
|
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer")
|
|
|
|
|
|
|
|
def test_llm_cleanup(mock_model, mock_tokenizer, llm_instance):
|
|
|
|
def test_llm_cleanup(mock_model, mock_tokenizer, llm_instance):
|
|
|
|
llm_instance.cleanup()
|
|
|
|
llm_instance.cleanup()
|
|
|
|
# Assuming cleanup method is meant to free resources
|
|
|
|
# Assuming cleanup method is meant to free resources
|
|
|
@ -217,11 +210,6 @@ def test_llm_cleanup(mock_model, mock_tokenizer, llm_instance):
|
|
|
|
mock_tokenizer.delete.assert_called_once()
|
|
|
|
mock_tokenizer.delete.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test updating the configuration after instantiation
|
|
|
|
|
|
|
|
def test_llm_update_configuration(llm_instance):
|
|
|
|
|
|
|
|
new_config = {"temperature": 0.7}
|
|
|
|
|
|
|
|
llm_instance.update_configuration(new_config)
|
|
|
|
|
|
|
|
assert llm_instance.configuration["temperature"] == 0.7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test if the model is re-downloaded when changing the model_id
|
|
|
|
# Test if the model is re-downloaded when changing the model_id
|
|
|
|