import pytest
import torch

from swarms.models.open_dalle import OpenDalle


def test_init():
    """Constructing OpenDalle with no arguments yields an OpenDalle instance."""
    generator = OpenDalle()
    assert isinstance(generator, OpenDalle)


def test_init_custom_model():
    """The ``model_name`` passed to the constructor is visible on the pipeline."""
    expected_name = "custom_model"
    generator = OpenDalle(model_name=expected_name)
    assert generator.pipeline.model_name == expected_name


def test_init_custom_dtype():
    """The ``torch_dtype`` passed to the constructor is visible on the pipeline."""
    expected_dtype = torch.float32
    generator = OpenDalle(torch_dtype=expected_dtype)
    assert generator.pipeline.torch_dtype == expected_dtype


def test_init_custom_device():
    """The ``device`` passed to the constructor is the device the pipeline reports.

    The type of ``pipeline.device`` is not visible from this test module: it
    may be a plain string or a ``torch.device``. Both sides are therefore
    normalised through ``torch.device`` before comparing, so the assertion
    holds for either representation.
    """
    od = OpenDalle(device="cpu")
    # torch.device() accepts both a string and an existing torch.device.
    assert torch.device(od.pipeline.device) == torch.device("cpu")


def test_run():
    """Running a text prompt is expected to return a ``torch.Tensor``."""
    prompt = "A picture of a cat"
    generator = OpenDalle()
    output = generator.run(prompt)
    assert isinstance(output, torch.Tensor)


def test_run_no_task():
    """Passing ``None`` as the task must raise ``ValueError``."""
    generator = OpenDalle()
    expected_message = "Task cannot be None"
    with pytest.raises(ValueError, match=expected_message):
        generator.run(None)


def test_run_non_string_task():
    """Passing a non-string task (here an int) must raise ``TypeError``."""
    generator = OpenDalle()
    expected_message = "Task must be a string"
    with pytest.raises(TypeError, match=expected_message):
        generator.run(123)


def test_run_empty_task():
    """Passing an empty string as the task must raise ``ValueError``."""
    generator = OpenDalle()
    expected_message = "Task cannot be empty"
    with pytest.raises(ValueError, match=expected_message):
        generator.run("")


def test_run_custom_args():
    """An extra keyword argument to ``run`` is accepted and a tensor is returned."""
    prompt = "A picture of a cat"
    extra_kwargs = {"custom_arg": "custom_value"}
    generator = OpenDalle()
    output = generator.run(prompt, **extra_kwargs)
    assert isinstance(output, torch.Tensor)


def test_run_error():
    """Calling ``run`` with ``raise_error=True`` is expected to raise.

    NOTE(review): how ``raise_error`` is interpreted is not visible from this
    module, and any ``Exception`` subclass satisfies the check — confirm the
    intended exception type against the OpenDalle implementation.
    """
    prompt = "A picture of a cat"
    generator = OpenDalle()
    with pytest.raises(Exception):
        generator.run(prompt, raise_error=True)