You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.3 KiB
68 lines
2.3 KiB
from typing import Any, Optional
|
|
|
|
import torch
|
|
from diffusers import AutoPipelineForText2Image
|
|
|
|
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
|
|
|
|
|
class OpenDalle(BaseMultiModalModel):
    """Text-to-image model backed by a diffusers ``AutoPipelineForText2Image``.

    The pipeline is loaded once at construction time and moved to the
    requested device; ``run`` then turns a text prompt into an image.

    Attributes:
        model_name (str): The name or path of the model that was loaded.
            Defaults to "dataautogpt3/OpenDalleV1.1".
        torch_dtype (torch.dtype): The torch data type the pipeline was
            loaded with. Defaults to torch.float16.
        device (str): The device the pipeline was moved to.
            Defaults to "cuda".
        pipeline: The loaded pipeline object, already moved to ``device``.

    Examples:
        >>> from swarms.models.open_dalle import OpenDalle
        >>> od = OpenDalle()
        >>> od.run("A picture of a cat")
    """

    def __init__(
        self,
        model_name: str = "dataautogpt3/OpenDalleV1.1",
        torch_dtype: Any = torch.float16,
        device: str = "cuda",
        *args,
        **kwargs,
    ):
        """Load the pipeline and move it to ``device``.

        Args:
            model_name (str, optional): The name or path of the model to be
                used. Defaults to "dataautogpt3/OpenDalleV1.1".
            torch_dtype (torch.dtype, optional): The torch data type to be
                used. Defaults to torch.float16.
            device (str, optional): The device to be used for computation.
                Defaults to "cuda".
            *args: Extra positional arguments forwarded to
                ``AutoPipelineForText2Image.from_pretrained``.
            **kwargs: Extra keyword arguments forwarded to
                ``AutoPipelineForText2Image.from_pretrained``.
        """
        # NOTE(review): BaseMultiModalModel.__init__ is not invoked here, so
        # any state the base class sets up is absent on this instance.
        # Its signature is not visible from this file -- confirm whether a
        # super().__init__(...) call should be added.

        # Keep the construction parameters on the instance so the attributes
        # promised by the class docstring actually exist.
        self.model_name = model_name
        self.torch_dtype = torch_dtype
        self.device = device

        self.pipeline = AutoPipelineForText2Image.from_pretrained(
            model_name, *args, torch_dtype=torch_dtype, **kwargs
        ).to(device)

    def run(self, task: Optional[str] = None, *args, **kwargs):
        """Generate an image for the given text prompt.

        Args:
            task (str, optional): The text prompt to render. Although the
                default is None, a non-empty string is required.
            *args: Extra positional arguments forwarded to the pipeline call.
            **kwargs: Extra keyword arguments forwarded to the pipeline call.

        Returns:
            The first element of the pipeline result's ``images`` sequence
            (presumably a PIL image under diffusers' default output type --
            confirm against the pipeline configuration).

        Raises:
            ValueError: If ``task`` is None or an empty string.
            TypeError: If ``task`` is not a string.
            Exception: Any error raised by the pipeline is logged to stdout
                and re-raised unchanged.
        """
        try:
            if task is None:
                raise ValueError("Task cannot be None")
            if not isinstance(task, str):
                raise TypeError("Task must be a string")
            if not task:
                raise ValueError("Task cannot be empty")

            return self.pipeline(task, *args, **kwargs).images[0]
        except Exception as error:
            # Log for visibility, then propagate the original exception
            # (bare raise keeps the traceback untouched).
            print(f"[ERROR][OpenDalle] {error}")
            raise
|