import logging
import os
from dataclasses import dataclass, field
from io import BytesIO
from typing import Optional, Union

import openai
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
from pydantic import validator  # noqa: F401  kept so existing imports of this name keep working
from termcolor import colored

load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

# Configure Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class Dalle3:
    """
    Dalle3 model class: a thin wrapper around the OpenAI Images API.

    Attributes:
    -----------
    model: str
        Image model name passed to ``images.generate``.
    img: Optional[str]
        Optional image path stored on the instance (not read by the methods below).
    size: str
        Requested image size, e.g. "1024x1024".
    max_retries: int
        Retry count handed to the OpenAI client; must be positive.
    quality: str
        Quality setting passed to ``images.generate``.
    n: int
        Number of images requested per call; must be positive. Only the first
        result's URL is returned.
    client: Optional[OpenAI]
        OpenAI client to use. When None, one is created from the
        OPENAI_API_KEY environment variable and ``max_retries``.

    Methods:
    --------
    __call__(self, task: str) -> str:
        Makes a call to the Dalle3 API and returns the image url

    Example:
    --------
    >>> dalle3 = Dalle3()
    >>> task = "A painting of a dog"
    >>> image_url = dalle3(task)
    >>> print(image_url)  # illustrative output
    https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
    """

    model: str = "dall-e-3"
    img: Optional[str] = None
    size: str = "1024x1024"
    max_retries: int = 3
    quality: str = "standard"
    # dall-e-3 accepts only n=1 per request (OpenAI Images API docs); larger
    # values make the default generate call fail.
    n: int = 1
    # Built lazily in __post_init__ (not at import time) so a missing API key
    # does not break importing this module, and so the instance's max_retries
    # is honoured. A pre-built client may be injected instead.
    client: Optional[OpenAI] = field(default=None, repr=False, compare=False)

    class Config:
        """Config class for the Dalle3 model"""

        # NOTE: pydantic-style config; it has no effect on a stdlib dataclass.
        arbitrary_types_allowed = True

    def __post_init__(self):
        """Validate numeric settings and create the OpenAI client if needed."""
        # A pydantic @validator would never run on a stdlib dataclass, so
        # validate explicitly here.
        self.must_be_positive(self.max_retries)
        self.must_be_positive(self.n)
        if self.client is None:
            self.client = OpenAI(api_key=api_key, max_retries=self.max_retries)

    @classmethod
    def must_be_positive(cls, value):
        """Return ``value`` unchanged, or raise ValueError if it is <= 0."""
        if value <= 0:
            raise ValueError("Must be positive")
        return value

    def read_img(self, img: str) -> "Image.Image":
        """Open the image file at path ``img`` with PIL and return it."""
        return Image.open(img)

    def set_width_height(self, img: str, width: int, height: int) -> "Image.Image":
        """Open the image at path ``img`` and return a copy resized to (width, height)."""
        image = self.read_img(img)
        return image.resize((width, height))

    def convert_to_bytesio(
        self, img: Union[str, "Image.Image"], format: str = "PNG"
    ) -> bytes:
        """
        Encode an image and return the encoded bytes.

        Parameters:
        -----------
        img: str or PIL.Image.Image
            Path to an image file, or an already-opened PIL image.
        format: str
            PIL output format name (default "PNG").

        Returns:
        --------
        bytes: the encoded image data.
        """
        if isinstance(img, str):
            img = self.read_img(img)
        with BytesIO() as byte_stream:
            img.save(byte_stream, format=format)
            return byte_stream.getvalue()

    def __call__(self, task: str) -> str:
        """
        Text to image conversion using the Dalle3 API

        Parameters:
        -----------
        task: str
            The prompt to be converted to an image

        Returns:
        --------
        str: URL of the first image generated by the Dalle3 API

        Raises:
        -------
        openai.OpenAIError: re-raised after printing an error message.

        Example:
        --------
        >>> dalle3 = Dalle3()
        >>> task = "A painting of a dog"
        >>> image_url = dalle3(task)
        >>> print(image_url)  # illustrative output
        https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
        """
        try:
            # Making a call to the Dalle3 API
            response = self.client.images.generate(
                model=self.model,
                prompt=task,
                size=self.size,
                quality=self.quality,
                n=self.n,
            )
        except openai.OpenAIError as error:
            # Print the error details, then propagate the original exception
            print(
                colored(
                    f"Error running Dalle3: {error} try optimizing your api key and or try again",
                    "red",
                )
            )
            raise
        # Only the first generated image's URL is returned
        return response.data[0].url

    def create_variations(self, img: str) -> str:
        """
        Create variations of a local image file using the OpenAI Images API

        Parameters:
        -----------
        img: str
            Path to a local image file to upload.

        Returns:
        --------
        str: URL of the first variation returned by the API

        Raises:
        -------
        OSError: if ``img`` cannot be opened.
        openai.OpenAIError: re-raised after printing an error message.

        Example:
        --------
        >>> dalle3 = Dalle3()
        >>> url = dalle3.create_variations("dog.png")
        >>> print(url)
        """
        # NOTE(review): per the OpenAI API docs, variations are only supported
        # by dall-e-2 and expect a square PNG under 4 MB. No model is passed
        # here, so the server-side default applies. Confirm.
        try:
            # `with` guarantees the uploaded file handle is closed
            with open(img, "rb") as image_file:
                response = self.client.images.create_variation(
                    image=image_file, n=self.n, size=self.size
                )
        except openai.OpenAIError as error:
            print(
                colored(
                    f"Error running Dalle3: {error} try optimizing your api key and or try again",
                    "red",
                )
            )
            # Only API status errors carry an HTTP status code
            status = getattr(error, "status_code", None)
            if status is not None:
                print(colored(f"Error running Dalle3: {status}", "red"))
            raise
        return response.data[0].url