You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
367 lines
11 KiB
367 lines
11 KiB
import concurrent.futures
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from io import BytesIO
|
|
from typing import List
|
|
|
|
import backoff
|
|
import openai
|
|
import requests
|
|
from cachetools import TTLCache
|
|
from dotenv import load_dotenv
|
|
from openai import OpenAI
|
|
from PIL import Image
|
|
from pydantic import validator
|
|
from termcolor import colored
|
|
|
|
# Pull settings from a local .env file into the process environment before
# anything further down the module reads os.environ.
load_dotenv()

# Root logging is switched to INFO at import time; the module-level logger
# below is named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def handle_errors(self, function):
    """
    Wrap ``function`` so that any exception it raises is logged and re-raised.

    Parameters:
    -----------
    self:
        Unused; kept only so the existing call signature does not change.
    function: callable
        The callable to wrap.

    Returns:
    --------
    callable
        A wrapper that forwards all positional and keyword arguments to
        ``function`` and returns its result. The wrapper keeps the wrapped
        callable's name and docstring.
    """
    # Imported locally so this block does not depend on a file-level import.
    import functools

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        except Exception as error:
            # Record the failure, then let the original exception propagate
            # with its traceback intact.
            logger.error(error)
            raise

    return wrapper
|
|
|
|
|
|
@dataclass
class Dalle3:
    """
    Wrapper around the OpenAI images API that generates an image for a text
    prompt, downloads it and stores it under ``save_path``.

    Attributes:
    -----------
    model: str
        Model name sent to the images API.
    img: str
        Optional source image (path or PIL image). When provided it is
        replaced by its encoded bytes during initialisation.
    size: str
        Image size sent to the images API.
    max_retries: int
        Maximum number of attempts made by ``rate_limited_call``.
    quality: str
        Quality setting sent to the images API.
    openai_api_key: str
        API key; defaults to the ``OPENAI_API_KEY`` environment variable.
    n: int
        Number of images requested per call (only the first is used).
    save_path: str
        Directory generated images are written to (created if missing).
    max_time_seconds: int
        Total time budget for retries in ``__call__``; also used as the
        download timeout.
    save_folder: str
        Stored and shown on the dashboard; not otherwise used by this class.
    image_format: str
        File extension used for saved images.
    dashboard: bool
        When True, ``__call__`` prints the configuration before running.
    image_url: str
        URL of the most recently generated image, or None before any call.

    Methods:
    --------
    __call__(self, task: str) -> str:
        Generates an image for ``task`` and returns the saved file path.

    Example:
    --------
    >>> dalle3 = Dalle3()
    >>> task = "A painting of a dog"
    >>> image_path = dalle3(task)
    >>> print(image_path)
    images/<uuid>.png

    """

    model: str = "dall-e-3"
    img: str = None
    size: str = "1024x1024"
    max_retries: int = 3
    quality: str = "standard"
    openai_api_key: str = os.getenv("OPENAI_API_KEY")
    n: int = 1
    save_path: str = "images"
    max_time_seconds: int = 60
    save_folder: str = "images"
    image_format: str = "png"
    dashboard: bool = False

    # The names below are deliberately NOT annotated, so @dataclass does not
    # turn them into constructor parameters.
    # API client; created per instance in __post_init__ so the instance's own
    # key is used and importing the module does not require a key.
    client = None
    # NOTE: this cache lives on the class, so every instance shares it and
    # entries are keyed by the prompt text only.
    cache = TTLCache(maxsize=100, ttl=3600)
    # URL of the last generated image; read by __repr__ / __str__.
    image_url = None

    def __post_init__(self):
        """
        Validate the configuration, build the API client and prepare the
        output directory.

        Raises:
            ValueError: If no API key is available or a retry limit is not
                positive.
        """
        if self.openai_api_key is None:
            # The field default was read at import time; look again in case
            # the variable was exported afterwards.
            self.openai_api_key = os.getenv("OPENAI_API_KEY")
        if self.openai_api_key is None:
            raise ValueError("Please provide an openai api key")

        # The pydantic validator below never runs on a plain dataclass, so
        # the same rule is enforced explicitly here.
        for name in ("max_retries", "max_time_seconds"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive")

        # Respect a client that was injected on the class (e.g. a stub).
        if self.client is None:
            self.client = OpenAI(api_key=self.openai_api_key)

        if self.img is not None:
            self.img = self.convert_to_bytesio(self.img)

        os.makedirs(self.save_path, exist_ok=True)

    class Config:
        """Config class for the Dalle3 model"""

        arbitrary_types_allowed = True

    @validator("max_retries", "max_time_seconds")
    def must_be_positive(cls, value):
        """Return ``value`` unchanged, raising ValueError if it is <= 0."""
        if value <= 0:
            raise ValueError("Must be positive")
        return value

    def read_img(self, img: str):
        """Open ``img`` with PIL and return the resulting image object."""
        return Image.open(img)

    def set_width_height(self, img: str, width: int, height: int):
        """Open ``img`` and return a copy resized to ``(width, height)``."""
        return self.read_img(img).resize((width, height))

    def convert_to_bytesio(self, img: str, format: str = "PNG"):
        """
        Encode an image in the given format and return the encoded bytes.

        Parameters:
        -----------
        img: str
            A file path, or an already opened PIL image.
        format: str
            Image format name handed to PIL's ``save``.

        Returns:
        --------
        bytes
            The encoded image data.
        """
        # A path has no ``save`` method, so open it first.
        if isinstance(img, (str, os.PathLike)):
            img = self.read_img(img)
        byte_stream = BytesIO()
        img.save(byte_stream, format=format)
        return byte_stream.getvalue()

    def __call__(self, task: str):
        """
        Text to image conversion using the Dalle3 API

        Parameters:
        -----------
        task: str
            The prompt to be converted to an image

        Returns:
        --------
        str
            Path of the saved image file inside ``save_path``. A prompt that
            is still in the cache returns its stored path without a new API
            call.

        Raises:
        -------
        Exception
            Whatever the last failed attempt raised once the retry budget of
            ``max_time_seconds`` is used up.

        Example:
        --------
        >>> dalle3 = Dalle3()
        >>> task = "A painting of a dog"
        >>> image_path = dalle3(task)
        >>> print(image_path)
        images/<uuid>.png
        """
        if self.dashboard:
            self.print_dashboard()

        # EAFP: a TTL entry can expire between a membership test and the read.
        try:
            return self.cache[task]
        except KeyError:
            pass

        # Backoff is applied here rather than as a class-level decorator so
        # that this instance's max_time_seconds is the value actually used.
        generate = backoff.on_exception(
            backoff.expo, Exception, max_time=self.max_time_seconds
        )(self._generate_image)
        return generate(task)

    def _generate_image(self, task: str):
        """Run one generation attempt: call the API, download, cache, return the path."""
        try:
            # Making a call to the the Dalle3 API
            response = self.client.images.generate(
                model=self.model,
                prompt=task,
                size=self.size,
                quality=self.quality,
                n=self.n,
            )
        except openai.OpenAIError as error:
            # Handling exceptions and printing the errors details
            self._print_error(error, detailed=False)
            raise

        # Extracting the image url from the response
        img_url = response.data[0].url

        filename = f"{self._generate_uuid()}.{self.image_format}"

        # Download and save the image
        self._download_image(img_url, filename)

        img_path = os.path.join(self.save_path, filename)
        self.image_url = img_url
        self.cache[task] = img_path

        return img_path

    def _print_error(self, error: Exception, detailed: bool = True):
        """Print an error in red; optionally add its status and payload when present."""
        print(
            colored(
                (
                    f"Error running Dalle3: {error} try"
                    " optimizing your api key and or try again"
                ),
                "red",
            )
        )
        if not detailed:
            return
        # Not every exception carries these attributes; reading them
        # unconditionally would hide the real error behind an AttributeError.
        status = getattr(
            error, "http_status", getattr(error, "status_code", None)
        )
        payload = getattr(error, "error", None)
        print(colored(f"Error running Dalle3: {status}", "red"))
        print(colored(f"Error running Dalle3: {payload}", "red"))

    def _generate_image_name(self, task: str):
        """Generate a sanitized file name based on the task"""
        sanitized_task = "".join(
            char for char in task if char.isalnum() or char in " _ -"
        ).rstrip()
        return f"{sanitized_task}.{self.image_format}"

    def _download_image(self, img_url: str, filename: str):
        """
        Download the image from the given URL and save it to a specified filename within self.save_path.

        Args:
            img_url (str): URL of the image to download.
            filename (str): Filename to save the image.

        Raises:
            ValueError: If the server does not answer with status 200.
        """
        full_path = os.path.join(self.save_path, filename)
        # Without a timeout a stalled connection would block forever.
        response = requests.get(img_url, timeout=self.max_time_seconds)
        if response.status_code != 200:
            raise ValueError(f"Failed to download image from {img_url}")
        with open(full_path, "wb") as file:
            file.write(response.content)

    def create_variations(self, img: str):
        """
        Create variations of an image using the OpenAI images API

        Parameters:
        -----------
        img: str
            Path of the local image file to upload

        Returns:
        --------
        img: str
            The url of the first variation returned by the API

        Example:
        --------
        >>> dalle3 = Dalle3()
        >>> img = "images/dog.png"
        >>> img = dalle3.create_variations(img)
        >>> print(img)

        """
        try:
            # The file handle is closed as soon as the request finishes.
            with open(img, "rb") as image_file:
                response = self.client.images.create_variation(
                    image=image_file, n=self.n, size=self.size
                )
            return response.data[0].url
        except Exception as error:
            self._print_error(error)
            raise

    def print_dashboard(self):
        """Print the Dalle3 dashboard"""
        print(
            colored(
                f"""Dalle3 Dashboard:
                    --------------------

                    Model: {self.model}
                    Image: {self.img}
                    Size: {self.size}
                    Max Retries: {self.max_retries}
                    Quality: {self.quality}
                    N: {self.n}
                    Save Path: {self.save_path}
                    Time Seconds: {self.max_time_seconds}
                    Save Folder: {self.save_folder}
                    Image Format: {self.image_format}
                    --------------------


                    """,
                "green",
            )
        )

    def process_batch_concurrently(
        self, tasks: List[str], max_workers: int = 5
    ):
        """
        Process a batch of tasks concurrently

        Args:
            tasks (List[str]): A list of tasks to be processed
            max_workers (int): The maximum number of workers to use for the concurrent processing

        Returns:
        --------
        results (List[str]): Saved image paths, in the order the tasks
            finished (not necessarily the order of ``tasks``)

        Raises:
        -------
        Exception
            The first task failure encountered is printed and re-raised.

        Example:
        --------
        >>> dalle3 = Dalle3()
        >>> tasks = ["A painting of a dog", "A painting of a cat"]
        >>> results = dalle3.process_batch_concurrently(tasks)
        >>> print(results)
        ['images/<uuid>.png', 'images/<uuid>.png']

        """
        results = []
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers
        ) as executor:
            future_to_task = {
                executor.submit(self, task): task for task in tasks
            }
            for future in concurrent.futures.as_completed(future_to_task):
                task = future_to_task[future]
                try:
                    img = future.result()
                except Exception as error:
                    self._print_error(error)
                    raise
                results.append(img)
                print(f"Task {task} completed: {img}")
        return results

    def _generate_uuid(self):
        """Generate a uuid"""
        return str(uuid.uuid4())

    def __repr__(self):
        """Repr method for the Dalle3 class"""
        return f"Dalle3(image_url={self.image_url})"

    def __str__(self):
        """Str method for the Dalle3 class"""
        return f"Dalle3(image_url={self.image_url})"

    def rate_limited_call(self, task: str):
        """
        Call the instance with exponential backoff, giving up after
        ``max_retries`` attempts.
        """
        # Applied at call time so this instance's max_retries is honoured.
        limited = backoff.on_exception(
            backoff.expo, Exception, max_tries=self.max_retries
        )(self.__call__)
        return limited(task)
|