You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/tests/models/idefics.py

120 lines
3.6 KiB

1 year ago
import pytest
from unittest.mock import patch
import torch
from swarms.models.idefics import Idefics, IdeficsForVisionText2Text, AutoProcessor
1 year ago
@pytest.fixture
def idefics_instance():
    """Provide an ``Idefics`` wrapper built while CUDA is reported as absent.

    ``torch.cuda.is_available`` is forced to ``False`` only for the duration
    of the constructor call; the instance is handed to the test afterwards.

    NOTE(review): the ``from_pretrained`` loaders are not mocked here, so
    construction presumably loads real checkpoint weights — confirm, and
    consider mocking them if test runtime or network access matters.
    """
    # Tests are assumed to run on CPU for simplicity.
    no_cuda = patch("torch.cuda.is_available", return_value=False)
    with no_cuda:
        return Idefics()
1 year ago
# Basic Tests
def test_init_default(idefics_instance):
    """A default instance sits on CPU, has max_length 100 and no history."""
    expected_device = "cpu"
    expected_max_length = 100

    assert idefics_instance.device == expected_device
    assert idefics_instance.max_length == expected_max_length
    # An empty history container is falsy.
    assert not idefics_instance.chat_history
1 year ago
@pytest.mark.parametrize(
    "device,expected",
    [
        (None, "cpu"),
        (None, "cuda"),
        ("cuda", "cuda"),
        ("cpu", "cpu"),
    ],
)
def test_init_device(device, expected):
    """Check explicit device selection and CUDA auto-detection.

    ``torch.cuda.is_available`` is forced to report CUDA exactly when the
    expected device is ``"cuda"``.

    The model and processor loaders are mocked as well. These are the same
    ``from_pretrained`` entry points that ``test_set_checkpoint`` patches.
    Faking ``is_available`` does not create a GPU, so without this mocking
    the ``"cuda"`` cases would presumably try to place real weights on a
    non-existent device on CPU-only runners, and every case would load the
    full checkpoint.

    Args:
        device: Value passed to ``Idefics(device=...)``; ``None`` means
            "let the wrapper auto-detect".
        expected: Device string the instance should end up reporting.
    """
    cuda_available = expected == "cuda"
    with patch(
        "torch.cuda.is_available", return_value=cuda_available
    ), patch.object(IdeficsForVisionText2Text, "from_pretrained"), patch.object(
        AutoProcessor, "from_pretrained"
    ):
        instance = Idefics(device=device)

    assert instance.device == expected
1 year ago
# Test `run` method
def test_run(idefics_instance):
    """``run`` returns the texts produced by the (mocked) processor decode."""
    prompts = [["User: Test"]]
    with patch.object(
        idefics_instance, "model"
    ) as fake_model, patch.object(idefics_instance, "processor") as fake_processor:
        # Stub the processor call, generation and decoding so that no real
        # model work happens.
        fake_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
        fake_model.generate.return_value = torch.tensor([1, 2, 3])
        fake_processor.batch_decode.return_value = ["Test"]

        outcome = idefics_instance.run(prompts)

    assert outcome == ["Test"]
1 year ago
# Test `__call__` method (using the same logic as run for simplicity)
def test_call(idefics_instance):
    """Invoking the instance directly yields the mocked decoded texts.

    Mirrors ``test_run`` but goes through ``__call__`` instead of ``run``.
    """
    prompts = [["User: Test"]]
    with patch.object(
        idefics_instance, "model"
    ) as fake_model, patch.object(idefics_instance, "processor") as fake_processor:
        # Same stubs as the `run` test: nothing touches real weights.
        fake_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])}
        fake_model.generate.return_value = torch.tensor([1, 2, 3])
        fake_processor.batch_decode.return_value = ["Test"]

        outcome = idefics_instance(prompts)

    assert outcome == ["Test"]
1 year ago
# Test `chat` method
def test_chat(idefics_instance):
    """``chat`` returns the reply from (mocked) ``run`` and logs both turns."""
    user_input = "User: Hello"
    response = "Model: Hi there!"

    with patch.object(idefics_instance, "run", return_value=[response]):
        reply = idefics_instance.chat(user_input)

    assert reply == response
    expected_history = [user_input, response]
    assert idefics_instance.chat_history == expected_history
1 year ago
# Test `set_checkpoint` method
def test_set_checkpoint(idefics_instance):
    """``set_checkpoint`` reloads both the model and the processor.

    The model loader must be called exactly once, with the new checkpoint
    and bfloat16 weights. The processor loader was previously patched but
    never checked; it must also be pointed at the new checkpoint.
    """
    new_checkpoint = "new_checkpoint"
    with patch.object(
        IdeficsForVisionText2Text, "from_pretrained"
    ) as mock_model_loader, patch.object(
        AutoProcessor, "from_pretrained"
    ) as mock_processor_loader:
        idefics_instance.set_checkpoint(new_checkpoint)

    mock_model_loader.assert_called_once_with(
        new_checkpoint, torch_dtype=torch.bfloat16
    )

    # Only require that the checkpoint is forwarded somewhere in the call, so
    # the check does not depend on positional-vs-keyword argument style.
    mock_processor_loader.assert_called_once()
    processor_call = mock_processor_loader.call_args
    assert new_checkpoint in (*processor_call.args, *processor_call.kwargs.values())
1 year ago
# Test `set_device` method
def test_set_device(idefics_instance):
    """``set_device`` records the new device and moves the model onto it.

    ``model.to`` is patched so no real transfer happens. Previously the
    mock was never inspected, so the test also passed if the model was
    never moved. Now it must receive the requested device.
    """
    new_device = "cuda"
    with patch.object(idefics_instance.model, "to") as mock_to:
        idefics_instance.set_device(new_device)

    assert idefics_instance.device == new_device

    # Accept the device either positionally or as a keyword argument.
    mock_to.assert_called_once()
    move_call = mock_to.call_args
    assert new_device in (*move_call.args, *move_call.kwargs.values())
1 year ago
# Test `set_max_length` method
def test_set_max_length(idefics_instance):
    """``set_max_length`` stores the value it is given."""
    target_length = 150

    idefics_instance.set_max_length(target_length)

    assert idefics_instance.max_length == target_length
1 year ago
# Test `clear_chat_history` method
def test_clear_chat_history(idefics_instance):
    """``clear_chat_history`` leaves the history empty (falsy)."""
    seeded_history = ["User: Test", "Model: Response"]
    idefics_instance.chat_history = seeded_history

    idefics_instance.clear_chat_history()

    assert not idefics_instance.chat_history
1 year ago
# Exception Tests
def test_run_with_empty_prompts(idefics_instance):
    """Calling ``run`` with an empty prompt list is expected to raise."""
    # TODO(review): replace the broad ``Exception`` with the specific error
    # ``run`` raises for an empty prompt list once that behavior is pinned down.
    no_prompts = []
    with pytest.raises(Exception):
        idefics_instance.run(no_prompts)