You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
1.3 KiB
60 lines
1.3 KiB
1 year ago
|
import pytest
|
||
|
import torch
|
||
|
from swarms.models.open_dalle import OpenDalle
|
||
|
|
||
|
|
||
|
def test_init():
    """Constructing OpenDalle with default arguments yields an OpenDalle."""
    instance = OpenDalle()
    assert isinstance(instance, OpenDalle)
|
||
|
|
||
|
|
||
|
def test_init_custom_model():
    """A custom ``model_name`` passed to the constructor shows up on the pipeline."""
    requested_name = "custom_model"
    model = OpenDalle(model_name=requested_name)
    # NOTE(review): confirm the wrapped pipeline really exposes `model_name`.
    assert model.pipeline.model_name == requested_name
|
||
|
|
||
|
|
||
|
def test_init_custom_dtype():
    """A custom ``torch_dtype`` passed to the constructor shows up on the pipeline."""
    wanted_dtype = torch.float32
    model = OpenDalle(torch_dtype=wanted_dtype)
    # NOTE(review): confirm the wrapped pipeline exposes `torch_dtype`
    # (some pipeline types only expose `dtype`).
    assert model.pipeline.torch_dtype == wanted_dtype
|
||
|
|
||
|
|
||
|
def test_init_custom_device():
    """Constructing with ``device="cpu"`` places the pipeline on the CPU.

    ``od.pipeline.device`` may be reported either as a plain string or as a
    ``torch.device``. A ``torch.device`` is not guaranteed to compare equal to
    a bare string, so the value is normalised through ``torch.device`` and
    only its device *type* is compared. That also ignores an index suffix
    such as ``cpu:0``.
    """
    od = OpenDalle(device="cpu")
    assert torch.device(od.pipeline.device).type == "cpu"
|
||
|
|
||
|
|
||
|
def test_run():
    """run() on a plain text prompt returns a torch.Tensor."""
    model = OpenDalle()
    prompt = "A picture of a cat"
    output = model.run(prompt)
    # NOTE(review): confirm torch.Tensor is the actual return type of run().
    assert isinstance(output, torch.Tensor)
|
||
|
|
||
|
|
||
|
def test_run_no_task():
    """run() rejects a missing (None) task with a ValueError."""
    model = OpenDalle()
    missing_task = None
    with pytest.raises(ValueError, match="Task cannot be None"):
        model.run(missing_task)
|
||
|
|
||
|
|
||
|
def test_run_non_string_task():
    """run() rejects a non-string task with a TypeError."""
    model = OpenDalle()
    numeric_task = 123
    with pytest.raises(TypeError, match="Task must be a string"):
        model.run(numeric_task)
|
||
|
|
||
|
|
||
|
def test_run_empty_task():
    """run() rejects an empty-string task with a ValueError."""
    model = OpenDalle()
    blank_task = ""
    with pytest.raises(ValueError, match="Task cannot be empty"):
        model.run(blank_task)
|
||
|
|
||
|
|
||
|
def test_run_custom_args():
    """run() accepts an extra keyword argument and still returns a torch.Tensor."""
    model = OpenDalle()
    prompt = "A picture of a cat"
    output = model.run(prompt, custom_arg="custom_value")
    assert isinstance(output, torch.Tensor)
|
||
|
|
||
|
|
||
|
def test_run_error():
    """run(..., raise_error=True) raises an exception."""
    model = OpenDalle()
    prompt = "A picture of a cat"
    # NOTE(review): pytest.raises(Exception) accepts *any* exception
    # (flake8-bugbear B017), including unrelated failures. Narrow this to the
    # concrete type run() raises once it is confirmed.
    with pytest.raises(Exception):
        model.run(prompt, raise_error=True)
|