dalle3 production grade ready

Former-commit-id: 41e5f17115
grit/923f7c6f-0958-480b-8748-ea6bbf1c2084
Kye 1 year ago
parent 252f5afc3e
commit 1cc295415e

12 binary image files added (not shown; 3.0 MiB each).

@@ -0,0 +1,24 @@
"""
User task ->> GPT4 for prompt enrichment ->> Dalle3 for image generation
->> GPT4Vision for image captioning ->> Dalle3 better image
"""
import os

from swarms.models.dalle3 import Dalle3

api_key = os.environ["OPENAI_API_KEY"]

dalle3 = Dalle3(openai_api_key=api_key, n=1)

# task = "Swarm of robots working super industrial ambience concept art"
# image_url = dalle3(task)

tasks = ["A painting of a dog", "A painting of a cat"]
results = dalle3.process_batch_concurrently(tasks)

# print(results)
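
Assuming each entry in results is the string returned by Dalle3.__call__ (after this commit a local file path, previously a URL), a minimal sketch of consuming the batch output:

# Inspect the batch output; entries arrive in completion order, not
# submission order, so they may not line up one-to-one with `tasks`.
for image in results:
    print(image)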

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "swarms"
-version = "2.1.6"
+version = "2.1.7"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]

@@ -35,6 +35,7 @@ langchain-experimental = "*"
playwright = "*"
duckduckgo-search = "*"
faiss-cpu = "*"
+backoff = "*"
datasets = "*"
diffusers = "*"
accelerate = "*"

@@ -35,6 +35,7 @@ tabulate
colored
griptape
addict
+backoff
ratelimit
albumentations
basicsr
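
The new backoff dependency (added to both pyproject.toml and requirements.txt) powers the retry decorators introduced in dalle3.py below. A minimal sketch of the pattern, with generate_image as a hypothetical stand-in for the wrapped OpenAI call:

import backoff

# Retry with exponential backoff for up to 60 seconds while the wrapped
# call keeps raising; generate_image is only an illustrative placeholder.
@backoff.on_exception(backoff.expo, Exception, max_time=60)
def generate_image(prompt: str) -> str:
    ...  # call the image API here; any raised exception triggers a retry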

@@ -2,6 +2,7 @@ import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from swarms.models.auto_temp import OpenAIChat

class AutoTempAgent:
    """
    AutoTemp is a tool for automatically selecting the best temperature setting for a given task.

@@ -31,6 +32,7 @@ class AutoTempAgent:
        Generate a 10,000 word blog on mental clarity and the benefits of meditation.
    """

    def __init__(
        self,
        temperature: float = 0.5,

@@ -1,9 +1,15 @@
+import concurrent.futures
import logging
import os
+import uuid
from dataclasses import dataclass
from io import BytesIO
+from typing import List, Optional

+import backoff
import openai
+import requests
+from cachetools import TTLCache
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
@@ -19,6 +25,17 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

+def handle_errors(function):
+    def wrapper(*args, **kwargs):
+        try:
+            return function(*args, **kwargs)
+        except Exception as error:
+            logger.error(error)
+            raise
+
+    return wrapper

@dataclass
class Dalle3:
    """
@@ -49,12 +66,26 @@ class Dalle3:
    size: str = "1024x1024"
    max_retries: int = 3
    quality: str = "standard"
-    api_key: str = None
-    n: int = 4
+    openai_api_key: str = None
+    n: int = 1
+    save_path: str = "images"
+    max_time_seconds: int = 60
+    save_folder: str = "images"
+    image_format: str = "png"
    client = OpenAI(
-        api_key=api_key,
+        api_key=openai_api_key,
+        max_retries=max_retries,
    )
+    cache = TTLCache(maxsize=100, ttl=3600)
+    dashboard: bool = False
+
+    def __post_init__(self):
+        """Post init method"""
+        if self.openai_api_key is None:
+            raise ValueError("Please provide an openai api key")
+        if self.img is not None:
+            self.img = self.convert_to_bytesio(self.img)
+
+        os.makedirs(self.save_path, exist_ok=True)

    class Config:
        """Config class for the Dalle3 model"""
@@ -84,8 +115,8 @@ class Dalle3:
        img.save(byte_stream, format=format)
        byte_array = byte_stream.getvalue()
        return byte_array

-    # @lru_cache(maxsize=32)
+    @backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds)
    def __call__(self, task: str):
        """
        Text to image conversion using the Dalle3 API

@@ -108,6 +139,10 @@ class Dalle3:
        >>> print(image_url)
        https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
        """
+        if self.dashboard:
+            self.print_dashboard()
+        if task in self.cache:
+            return self.cache[task]
        try:
            # Making a call to the Dalle3 API
            response = self.client.images.generate(

@@ -119,7 +154,16 @@ class Dalle3:
            )
            # Extracting the image url from the response
            img = response.data[0].url
-            return img
+
+            filename = f"{self._generate_uuid()}.{self.image_format}"
+
+            # Download and save the image
+            self._download_image(img, filename)
+
+            img_path = os.path.join(self.save_path, filename)
+            self.cache[task] = img_path
+
+            return img_path
        except openai.OpenAIError as error:
            # Handling exceptions and printing the error details
            print(
@@ -133,6 +177,29 @@
            )
            raise error

+    def _generate_image_name(self, task: str):
+        """Generate a sanitized file name based on the task"""
+        sanitized_task = "".join(
+            char for char in task if char.isalnum() or char in " _ -"
+        ).rstrip()
+        return f"{sanitized_task}.{self.image_format}"
+
+    def _download_image(self, img_url: str, filename: str):
+        """
+        Download the image from the given URL and save it to a specified filename within self.save_path.
+
+        Args:
+            img_url (str): URL of the image to download.
+            filename (str): Filename to save the image.
+        """
+        full_path = os.path.join(self.save_path, filename)
+        response = requests.get(img_url)
+        if response.status_code == 200:
+            with open(full_path, 'wb') as file:
+                file.write(response.content)
+        else:
+            raise ValueError(f"Failed to download image from {img_url}")

    def create_variations(self, img: str):
        """
        Create variations of an image using the Dalle3 API

@@ -176,3 +243,100 @@
            print(colored(f"Error running Dalle3: {error.http_status}", "red"))
            print(colored(f"Error running Dalle3: {error.error}", "red"))
            raise error
+
+    def print_dashboard(self):
+        """Print the Dalle3 dashboard"""
+        print(
+            colored(
+                (
+                    f"""Dalle3 Dashboard:
+                    --------------------
+
+                    Model: {self.model}
+                    Image: {self.img}
+                    Size: {self.size}
+                    Max Retries: {self.max_retries}
+                    Quality: {self.quality}
+                    N: {self.n}
+                    Save Path: {self.save_path}
+                    Time Seconds: {self.max_time_seconds}
+                    Save Folder: {self.save_folder}
+                    Image Format: {self.image_format}
+                    --------------------
+                    """
+                ),
+                "green",
+            )
+        )
+
+    def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5):
+        """
+        Process a batch of tasks concurrently
+
+        Args:
+            tasks (List[str]): A list of tasks to be processed
+            max_workers (int): The maximum number of workers to use for the concurrent processing
+
+        Returns:
+        --------
+            results (List[str]): A list of image urls generated by the Dalle3 API
+
+        Example:
+        --------
+        >>> dalle3 = Dalle3()
+        >>> tasks = ["A painting of a dog", "A painting of a cat"]
+        >>> results = dalle3.process_batch_concurrently(tasks)
+        >>> print(results)
+        ['https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png',
+        """
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            future_to_task = {executor.submit(self, task): task for task in tasks}
+            results = []
+            for future in concurrent.futures.as_completed(future_to_task):
+                task = future_to_task[future]
+                try:
+                    img = future.result()
+                    results.append(img)
+
+                    print(f"Task {task} completed: {img}")
+                except Exception as error:
+                    print(
+                        colored(
+                            (
+                                f"Error running Dalle3: {error} try optimizing your api key and"
+                                " or try again"
+                            ),
+                            "red",
+                        )
+                    )
+                    print(colored(f"Error running Dalle3: {error.http_status}", "red"))
+                    print(colored(f"Error running Dalle3: {error.error}", "red"))
+                    raise error
+
+            return results
+
+    def _generate_uuid(self):
+        """Generate a uuid"""
+        return str(uuid.uuid4())
+
+    def __repr__(self):
+        """Repr method for the Dalle3 class"""
+        return f"Dalle3(image_url={self.image_url})"
+
+    def __str__(self):
+        """Str method for the Dalle3 class"""
+        return f"Dalle3(image_url={self.image_url})"
+
+    @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries)
+    def rate_limited_call(self, task: str):
+        """Rate limited call to the Dalle3 API"""
+        return self.__call__(task)
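
process_batch_concurrently above uses the standard submit/as_completed fan-out from concurrent.futures. A self-contained sketch of that pattern, with a hypothetical render function standing in for dalle3(task):

from concurrent.futures import ThreadPoolExecutor, as_completed

def render(task: str) -> str:
    # hypothetical stand-in for dalle3(task)
    return f"images/{task.replace(' ', '_')}.png"

tasks = ["A painting of a dog", "A painting of a cat"]

# Submit every task up front, then collect each result as its future
# finishes; completion order is not submission order.
with ThreadPoolExecutor(max_workers=5) as executor:
    future_to_task = {executor.submit(render, task): task for task in tasks}
    for future in as_completed(future_to_task):
        print(future_to_task[future], "->", future.result())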

@@ -11,6 +11,7 @@ api_key = os.getenv("OPENAI_API_KEY")
load_dotenv()

@pytest.fixture
def auto_temp_agent():
    return AutoTempAgent(api_key=api_key)

@@ -47,7 +48,9 @@ def test_run_no_scores(auto_temp_agent):
    task = "Invalid task."
    temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4"
    with ThreadPoolExecutor(max_workers=auto_temp_agent.max_workers) as executor:
-        with patch.object(executor, "submit", side_effect=[None, None, None, None, None, None]):
+        with patch.object(
+            executor, "submit", side_effect=[None, None, None, None, None, None]
+        ):
            result = auto_temp_agent.run(task, temperature_string)
            assert result == "No valid outputs generated."
