You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/swarms/models/open_dalle.py

68 lines
2.3 KiB

from typing import Any, Optional
import torch
from diffusers import AutoPipelineForText2Image
from swarms.models.base_multimodal_model import BaseMultiModalModel
class OpenDalle(BaseMultiModalModel):
"""OpenDalle model class
Attributes:
model_name (str): The name or path of the model to be used. Defaults to "dataautogpt3/OpenDalleV1.1".
torch_dtype (torch.dtype): The torch data type to be used. Defaults to torch.float16.
device (str): The device to be used for computation. Defaults to "cuda".
Examples:
>>> from swarms.models.open_dalle import OpenDalle
>>> od = OpenDalle()
>>> od.run("A picture of a cat")
"""
def __init__(
self,
model_name: str = "dataautogpt3/OpenDalleV1.1",
torch_dtype: Any = torch.float16,
device: str = "cuda",
*args,
**kwargs,
):
"""
Initializes the OpenDalle model.
Args:
model_name (str, optional): The name or path of the model to be used. Defaults to "dataautogpt3/OpenDalleV1.1".
torch_dtype (torch.dtype, optional): The torch data type to be used. Defaults to torch.float16.
device (str, optional): The device to be used for computation. Defaults to "cuda".
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
self.pipeline = AutoPipelineForText2Image.from_pretrained(
model_name, torch_dtype=torch_dtype, *args, **kwargs
).to(device)
def run(self, task: Optional[str] = None, *args, **kwargs):
"""Run the OpenDalle model
Args:
task (str, optional): The task to be performed. Defaults to None.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
[type]: [description]
"""
try:
if task is None:
raise ValueError("Task cannot be None")
if not isinstance(task, str):
raise TypeError("Task must be a string")
if len(task) < 1:
raise ValueError("Task cannot be empty")
return self.pipeline(task, *args, **kwargs).images[0]
except Exception as error:
print(f"[ERROR][OpenDalle] {error}")
raise error