You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
175 lines
4.6 KiB
175 lines
4.6 KiB
import openai
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
from termcolor import colored
|
|
from openai import OpenAI
|
|
from dotenv import load_dotenv
|
|
from pydantic import BaseModel, validator
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
|
|
|
|
# Load variables from a .env file (when one is found) into the process
# environment before the API key is looked up below.
load_dotenv()

# OpenAI credential read from the environment; None when the variable is unset.
api_key = os.environ.get("OPENAI_API_KEY")

# Configure Logging: root logger at INFO, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass
class Dalle3:
    """
    Dalle3 model class

    Thin wrapper around the OpenAI images endpoints (text-to-image generation
    and image variations).

    Attributes:
    -----------
    model: str
        Model name. NOTE(review): not currently forwarded by ``__call__``;
        see the note inside that method.
    img: str
        Optional image path. Not read by any method of this class.
    size: str
        Image size string, forwarded by ``create_variations``.
    max_retries: int
        Retry count given to the per-instance OpenAI client.
    quality: str
        Quality setting. NOTE(review): not currently forwarded.
    n: int
        Number of images requested per API call. Only the URL of the first
        returned image is handed back to the caller.
    client: OpenAI
        Client built for this instance in ``__post_init__`` (not a dataclass
        field, so it is not an ``__init__`` argument).

    Methods:
    --------
    __call__(self, task: str) -> str:
        Generates an image from a text prompt and returns its URL.
    create_variations(self, img: str) -> str:
        Uploads a local image file and returns the URL of the first variation.

    Example (performs a network call):
    --------
    >>> dalle3 = Dalle3()
    >>> task = "A painting of a dog"
    >>> image_url = dalle3(task)
    >>> print(image_url)
    https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
    """

    model: str = "dall-e-3"
    img: str = None  # optional; None means "no image"
    size: str = "1024x1024"
    max_retries: int = 3
    quality: str = "standard"
    n: int = 4

    def __post_init__(self):
        """Create this instance's OpenAI client.

        The client used to be built in the class body, i.e. at import time,
        which made importing this module fail whenever no API key was
        available (the OpenAI client raises in that case) and always used the
        class-level default ``max_retries``. Building it here honours the
        instance's ``max_retries`` and defers any key error to construction.
        """
        self.client = OpenAI(
            api_key=api_key,
            max_retries=self.max_retries,
        )

    class Config:
        """Config class for the Dalle3 model"""

        arbitrary_types_allowed = True

    # NOTE(review): Dalle3 is a stdlib @dataclass, not a pydantic model, so
    # neither the Config class above nor this validator is ever applied by
    # pydantic; "time_seconds" is also not a field of this class. Kept only so
    # the public attribute names stay available.
    @validator("max_retries", "time_seconds")
    def must_be_positive(cls, value):
        """Return ``value`` unchanged; raise ValueError if it is <= 0."""
        if value <= 0:
            raise ValueError("Must be positive")
        return value

    def read_img(self, img: str):
        """Open the image file at path ``img`` with PIL and return the Image."""
        img = Image.open(img)
        return img

    def set_width_height(self, img: str, width: int, height: int):
        """Open the image at path ``img`` and return a copy resized to (width, height)."""
        img = self.read_img(img)
        img = img.resize((width, height))
        return img

    def convert_to_bytesio(self, img: "Image.Image", format: str = "PNG") -> bytes:
        """Encode a PIL image and return the encoded data.

        Parameters:
        -----------
        img: PIL.Image.Image
            The image to encode (``.save`` is called on it, so it must be an
            Image object, not a path).
        format: str
            Output format passed to ``Image.save`` (default ``"PNG"``).

        Returns:
        --------
        bytes
            The encoded image. Despite the method name, this is the raw
            ``bytes`` taken from an in-memory BytesIO buffer, not the buffer.
        """
        byte_stream = BytesIO()
        img.save(byte_stream, format=format)
        byte_array = byte_stream.getvalue()
        return byte_array

    def __call__(self, task: str) -> str:
        """
        Text to image conversion using the images API.

        Parameters:
        -----------
        task: str
            The prompt to be converted to an image.

        Returns:
        --------
        str
            URL of the first generated image (``response.data[0].url``), even
            when ``self.n`` images were requested.

        Raises:
        -------
        openai.OpenAIError
            Re-raised after an error message is printed.

        Example (performs a network call):
        --------
        >>> dalle3 = Dalle3()
        >>> task = "A painting of a dog"
        >>> image_url = dalle3(task)
        >>> print(image_url)
        https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
        """
        try:
            # NOTE(review): model, size and quality are intentionally not sent,
            # so the API's own defaults apply. OpenAI documents that dall-e-3
            # only accepts n=1, so forwarding self.model with the default n=4
            # would be rejected. Confirm the intended model and n before wiring
            # these fields through.
            response = self.client.images.generate(
                prompt=task,
                n=self.n,
            )
            # Only the first image's URL is returned.
            return response.data[0].url
        except openai.OpenAIError as error:
            print(
                colored(
                    f"Error running Dalle3: {error} try optimizing your api key and or try again",
                    "red",
                )
            )
            raise

    def create_variations(self, img: str) -> str:
        """
        Create variations of a local image file using the images API.

        Parameters:
        -----------
        img: str
            Path to a local image file. It is opened in binary mode and
            uploaded, so it must be a filesystem path, not a URL.

        Returns:
        --------
        str
            URL of the first generated variation.

        Raises:
        -------
        Exception
            Any error (e.g. a missing file or an API error) is re-raised after
            error details are printed.

        Example (performs a network call):
        --------
        >>> dalle3 = Dalle3()
        >>> img = "dog.png"
        >>> img = dalle3.create_variations(img)
        >>> print(img)
        """
        try:
            # The context manager closes the file even if the request fails.
            with open(img, "rb") as image_file:
                # The SDK keyword is `image`; the previous `img=` keyword was
                # rejected with a TypeError before any request was sent.
                response = self.client.images.create_variation(
                    image=image_file,
                    n=self.n,
                    size=self.size,
                )
            return response.data[0].url
        except Exception as error:  # openai.OpenAIError is a subclass of Exception
            print(
                colored(
                    f"Error running Dalle3: {error} try optimizing your api key and or try again",
                    "red",
                )
            )
            # Pre-1.0 openai errors exposed `http_status`/`error`. Newer ones
            # expose `status_code`/`body`, and generic exceptions have neither.
            # Read them defensively so an AttributeError here cannot replace
            # the original exception.
            status = getattr(error, "http_status", getattr(error, "status_code", None))
            details = getattr(error, "error", getattr(error, "body", None))
            print(colored(f"Error running Dalle3: {status}", "red"))
            print(colored(f"Error running Dalle3: {details}", "red"))
            raise