You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/swarms/models/dalle3.py

175 lines
4.6 KiB

import logging
import os
from dataclasses import dataclass, field
from functools import lru_cache
from io import BytesIO

import openai
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel, validator
from termcolor import colored
# Populate the process environment via python-dotenv before the key is read;
# ``api_key`` is None when OPENAI_API_KEY is not set anywhere.
load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")

# Logging setup. Note: basicConfig acts on the root logger and runs at import.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
@dataclass
class Dalle3:
    """
    Dalle3 model class

    Thin wrapper around the OpenAI Images API plus a few PIL helpers.

    Attributes:
    -----------
    model: str
        Model name sent with generation requests (default "dall-e-3").
    img: str
        Optional image path; not read by any method of this class.
    size: str
        Requested image size, e.g. "1024x1024".
    max_retries: int
        Retry count given to the OpenAI client when this class builds it.
        Must be non-negative.
    quality: str
        Generation quality; only sent when ``model`` is "dall-e-3".
    n: int
        Number of images to request. Must be positive. Generation with
        "dall-e-3" always sends n=1, since that model accepts only one.
    client: OpenAI
        Optional pre-built OpenAI client. When omitted, one is created in
        ``__post_init__`` from the OPENAI_API_KEY value and ``max_retries``.

    Methods:
    --------
    __call__(self, task: str) -> str:
        Generates an image for ``task`` and returns the first image URL.
    create_variations(self, img: str) -> str:
        Creates variations of a local image file and returns the first URL.

    Example:
    --------
    >>> dalle3 = Dalle3()
    >>> task = "A painting of a dog"
    >>> image_url = dalle3(task)
    >>> print(image_url)
    https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
    """

    model: str = "dall-e-3"
    img: str = None
    size: str = "1024x1024"
    max_retries: int = 3
    quality: str = "standard"
    n: int = 4
    # Built per instance (not at class-definition/import time) so that each
    # instance's ``max_retries`` is honoured; excluded from repr/eq so two
    # identically configured instances still compare equal.
    client: "OpenAI" = field(default=None, repr=False, compare=False)

    def __post_init__(self):
        """Validate numeric settings and create the OpenAI client if needed.

        Raises:
            ValueError: if ``n`` is not positive or ``max_retries`` is negative.
        """
        self.must_be_positive(self.n)
        if self.max_retries < 0:
            raise ValueError("max_retries must be non-negative")
        if self.client is None:
            self.client = OpenAI(
                api_key=api_key,
                max_retries=self.max_retries,
            )

    @classmethod
    def must_be_positive(cls, value):
        """Return ``value`` unchanged, raising ValueError if it is not > 0."""
        if value <= 0:
            raise ValueError("Must be positive")
        return value

    def read_img(self, img: str):
        """Open the image at path ``img`` with PIL and return the Image."""
        img = Image.open(img)
        return img

    def set_width_height(self, img: str, width: int, height: int):
        """Open the image at path ``img`` and return it resized to (width, height)."""
        img = self.read_img(img)
        img = img.resize((width, height))
        return img

    def convert_to_bytesio(self, img: "Image.Image", format: str = "PNG"):
        """Encode an image object in ``format`` and return the encoded bytes.

        Despite the method name, the return value is ``bytes`` (the contents
        of an in-memory buffer), not a BytesIO object.

        Parameters:
        -----------
        img: PIL.Image.Image
            Any object with a PIL-style ``save(stream, format=...)`` method.
        format: str
            Output format passed to ``img.save`` (default "PNG").
        """
        byte_stream = BytesIO()
        img.save(byte_stream, format=format)
        return byte_stream.getvalue()

    def __call__(self, task: str) -> str:
        """
        Text to image conversion using the configured model

        Parameters:
        -----------
        task: str
            The prompt to be converted to an image

        Returns:
        --------
        str:
            URL of the first image in the API response

        Raises:
        -------
        openai.OpenAIError:
            Re-raised after printing an error message.

        Example:
        --------
        >>> dalle3 = Dalle3()
        >>> task = "A painting of a dog"
        >>> image_url = dalle3(task)
        >>> print(image_url)
        https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
        """
        is_dalle3 = self.model == "dall-e-3"
        request = {
            "model": self.model,
            "prompt": task,
            "size": self.size,
            # DALL·E 3 accepts only n=1; only the first URL is returned anyway.
            "n": 1 if is_dalle3 else self.n,
        }
        if is_dalle3:
            # The Images API documents `quality` as a DALL·E 3-only option.
            request["quality"] = self.quality

        try:
            response = self.client.images.generate(**request)
        except openai.OpenAIError as error:
            print(
                colored(
                    f"Error running Dalle3: {error} try optimizing your api key and or try again",
                    "red",
                )
            )
            raise

        # Extracting the image url from the response
        return response.data[0].url

    def create_variations(self, img: str) -> str:
        """
        Create variations of a local image file using the Images API

        Parameters:
        -----------
        img: str
            Path to a local image file (it is opened in binary mode, so a
            URL will not work).

        Returns:
        --------
        str:
            URL of the first variation in the API response

        Raises:
        -------
        OSError:
            If the file cannot be opened.
        openai.OpenAIError:
            Re-raised after printing an error message.

        Example:
        --------
        >>> dalle3 = Dalle3()
        >>> url = dalle3.create_variations("dog.png")
        >>> print(url)
        """
        try:
            # The context manager closes the upload handle once the request
            # has returned (the previous version leaked it).
            with open(img, "rb") as image_file:
                response = self.client.images.create_variation(
                    image=image_file,
                    n=self.n,
                    size=self.size,
                )
        except openai.OpenAIError as error:
            print(
                colored(
                    f"Error running Dalle3: {error} try optimizing your api key and or try again",
                    "red",
                )
            )
            # Only HTTP status errors carry a status code; read it defensively
            # so the handler itself cannot mask the original exception.
            status_code = getattr(error, "status_code", None)
            if status_code is not None:
                print(colored(f"Error running Dalle3: {status_code}", "red"))
            raise

        return response.data[0].url