diff --git a/images/10f498c2-e22a-4f7f-9e50-56bf1ef92629.png b/images/10f498c2-e22a-4f7f-9e50-56bf1ef92629.png new file mode 100644 index 00000000..1dece111 Binary files /dev/null and b/images/10f498c2-e22a-4f7f-9e50-56bf1ef92629.png differ diff --git a/images/1c990ee0-ed68-4375-9731-9c9c25a72fac.png b/images/1c990ee0-ed68-4375-9731-9c9c25a72fac.png new file mode 100644 index 00000000..c4b740c3 Binary files /dev/null and b/images/1c990ee0-ed68-4375-9731-9c9c25a72fac.png differ diff --git a/images/2570cc4b-fafe-4f41-8193-ea9b563156e4.png b/images/2570cc4b-fafe-4f41-8193-ea9b563156e4.png new file mode 100644 index 00000000..dfb8834f Binary files /dev/null and b/images/2570cc4b-fafe-4f41-8193-ea9b563156e4.png differ diff --git a/images/35661b4a-f230-47a1-91bf-f876935151ed.png b/images/35661b4a-f230-47a1-91bf-f876935151ed.png new file mode 100644 index 00000000..163d9a6c Binary files /dev/null and b/images/35661b4a-f230-47a1-91bf-f876935151ed.png differ diff --git a/images/4b2161eb-bc44-4ee9-b106-208408b81d42.png b/images/4b2161eb-bc44-4ee9-b106-208408b81d42.png new file mode 100644 index 00000000..5f0af5c1 Binary files /dev/null and b/images/4b2161eb-bc44-4ee9-b106-208408b81d42.png differ diff --git a/images/4e4ea9d1-e1e3-4609-a200-8d83b5912a44.png b/images/4e4ea9d1-e1e3-4609-a200-8d83b5912a44.png new file mode 100644 index 00000000..1aa8fb4b Binary files /dev/null and b/images/4e4ea9d1-e1e3-4609-a200-8d83b5912a44.png differ diff --git a/images/5081867f-bb73-4ece-b746-df2247e55da5.png b/images/5081867f-bb73-4ece-b746-df2247e55da5.png new file mode 100644 index 00000000..6967f2e2 Binary files /dev/null and b/images/5081867f-bb73-4ece-b746-df2247e55da5.png differ diff --git a/images/a3fd26f3-0ee7-49b1-9e05-60b1dde1a1a8.png b/images/a3fd26f3-0ee7-49b1-9e05-60b1dde1a1a8.png new file mode 100644 index 00000000..725208ef Binary files /dev/null and b/images/a3fd26f3-0ee7-49b1-9e05-60b1dde1a1a8.png differ diff --git a/images/af8e6856-9d24-46d5-81fc-c9b2010d5d77.png b/images/af8e6856-9d24-46d5-81fc-c9b2010d5d77.png new file mode 100644 index 00000000..8d77ea96 Binary files /dev/null and b/images/af8e6856-9d24-46d5-81fc-c9b2010d5d77.png differ diff --git a/images/f0f1d0e8-1672-4b9c-af1f-e6979f8a407c.png b/images/f0f1d0e8-1672-4b9c-af1f-e6979f8a407c.png new file mode 100644 index 00000000..14827bf8 Binary files /dev/null and b/images/f0f1d0e8-1672-4b9c-af1f-e6979f8a407c.png differ diff --git a/images/f4992864-b211-4510-9e4a-1148470dd5ec.png b/images/f4992864-b211-4510-9e4a-1148470dd5ec.png new file mode 100644 index 00000000..f28f6bc0 Binary files /dev/null and b/images/f4992864-b211-4510-9e4a-1148470dd5ec.png differ diff --git a/images/ffd2e03f-4238-4b6d-b29e-a3b41624ceae.png b/images/ffd2e03f-4238-4b6d-b29e-a3b41624ceae.png new file mode 100644 index 00000000..5b99e316 Binary files /dev/null and b/images/ffd2e03f-4238-4b6d-b29e-a3b41624ceae.png differ diff --git a/playground/models/dalle3_concurrent.py b/playground/models/dalle3_concurrent.py new file mode 100644 index 00000000..af18db3a --- /dev/null +++ b/playground/models/dalle3_concurrent.py @@ -0,0 +1,24 @@ +""" + +User task ->> GPT4 for prompt enrichment ->> Dalle3V for image generation +->> GPT4Vision for image captioning ->> Dalle3 better image + +""" +from swarms.models.dalle3 import Dalle3 +import os + +api_key = os.environ["OPENAI_API_KEY"] + +dalle3 = Dalle3(openai_api_key=api_key, n=1) + +# task = "Swarm of robots working super industrial ambience concept art" + +# image_url = dalle3(task) + +tasks = ["A painting of a dog", "A painting of a cat"] +results = dalle3.process_batch_concurrently(tasks) + +# print(results) + + + diff --git a/pyproject.toml b/pyproject.toml index 9af0ab78..8ff1df05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "2.1.6" +version = "2.1.7" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -35,6 +35,7 @@ langchain-experimental = "*" playwright = "*" duckduckgo-search = "*" faiss-cpu = "*" +backoff = "*" datasets = "*" diffusers = "*" accelerate = "*" diff --git a/requirements.txt b/requirements.txt index d28e75e7..1a74a36e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,6 +35,7 @@ tabulate colored griptape addict +backoff ratelimit albumentations basicsr diff --git a/swarms/models/autotemp.py b/swarms/models/autotemp.py index d238e117..3c89ad73 100644 --- a/swarms/models/autotemp.py +++ b/swarms/models/autotemp.py @@ -2,6 +2,7 @@ import re from concurrent.futures import ThreadPoolExecutor, as_completed from swarms.models.auto_temp import OpenAIChat + class AutoTempAgent: """ AutoTemp is a tool for automatically selecting the best temperature setting for a given task. @@ -31,6 +32,7 @@ class AutoTempAgent: Generate a 10,000 word blog on mental clarity and the benefits of meditation. """ + def __init__( self, temperature: float = 0.5, diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index c24f262d..bb20b968 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -1,9 +1,15 @@ +import concurrent.futures import logging import os +import uuid from dataclasses import dataclass from io import BytesIO +from typing import List, Optional +import backoff import openai +import requests +from cachetools import TTLCache from dotenv import load_dotenv from openai import OpenAI from PIL import Image @@ -19,6 +25,17 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + +def handle_errors(self, function): + def wrapper(*args, **kwargs): + try: + return function(*args, **kwargs) + except Exception as error: + logger.error(error) + raise + return wrapper + + @dataclass class Dalle3: """ @@ -49,12 +66,26 @@ class Dalle3: size: str = "1024x1024" max_retries: int = 3 quality: str = "standard" - api_key: str = None - n: int = 4 + openai_api_key: str = None + n: int = 1 + save_path: str = "images" + max_time_seconds: int = 60 + save_folder: str = "images" + image_format: str = "png" client = OpenAI( - api_key=api_key, - max_retries=max_retries, + api_key=openai_api_key, ) + cache = TTLCache(maxsize=100, ttl=3600) + dashboard: bool = False + + def __post_init__(self): + """Post init method""" + if self.openai_api_key is None: + raise ValueError("Please provide an openai api key") + if self.img is not None: + self.img = self.convert_to_bytesio(self.img) + + os.makedirs(self.save_path, exist_ok=True) class Config: """Config class for the Dalle3 model""" @@ -84,8 +115,8 @@ class Dalle3: img.save(byte_stream, format=format) byte_array = byte_stream.getvalue() return byte_array - - # @lru_cache(maxsize=32) + + @backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds) def __call__(self, task: str): """ Text to image conversion using the Dalle3 API @@ -108,6 +139,10 @@ class Dalle3: >>> print(image_url) https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png """ + if self.dashboard: + self.print_dashboard() + if task in self.cache: + return self.cache[task] try: # Making a call to the the Dalle3 API response = self.client.images.generate( @@ -119,7 +154,16 @@ class Dalle3: ) # Extracting the image url from the response img = response.data[0].url - return img + + filename = f"{self._generate_uuid()}.{self.image_format}" + + # Download and save the image + self._download_image(img, filename) + + img_path = os.path.join(self.save_path, filename) + self.cache[task] = img_path + + return img_path except openai.OpenAIError as error: # Handling exceptions and printing the errors details print( @@ -133,6 +177,29 @@ class Dalle3: ) raise error + def _generate_image_name(self, task: str): + """Generate a sanitized file name based on the task""" + sanitized_task = "".join( + char for char in task if char.isalnum() or char in " _ -" + ).rstrip() + return f"{sanitized_task}.{self.image_format}" + + def _download_image(self, img_url: str, filename: str): + """ + Download the image from the given URL and save it to a specified filename within self.save_path. + + Args: + img_url (str): URL of the image to download. + filename (str): Filename to save the image. + """ + full_path = os.path.join(self.save_path, filename) + response = requests.get(img_url) + if response.status_code == 200: + with open(full_path, 'wb') as file: + file.write(response.content) + else: + raise ValueError(f"Failed to download image from {img_url}") + def create_variations(self, img: str): """ Create variations of an image using the Dalle3 API @@ -176,3 +243,100 @@ class Dalle3: print(colored(f"Error running Dalle3: {error.http_status}", "red")) print(colored(f"Error running Dalle3: {error.error}", "red")) raise error + + def print_dashboard( + self + ): + """Print the Dalle3 dashboard""" + print( + colored( + ( + f"""Dalle3 Dashboard: + -------------------- + + Model: {self.model} + Image: {self.img} + Size: {self.size} + Max Retries: {self.max_retries} + Quality: {self.quality} + N: {self.n} + Save Path: {self.save_path} + Time Seconds: {self.time_seconds} + Save Folder: {self.save_folder} + Image Format: {self.image_format} + -------------------- + + + """ + ), + "green", + ) + ) + + def process_batch_concurrently( + self, + tasks: List[str], + max_workers: int = 5 + ): + """ + + Process a batch of tasks concurrently + + Args: + tasks (List[str]): A list of tasks to be processed + max_workers (int): The maximum number of workers to use for the concurrent processing + + Returns: + -------- + results (List[str]): A list of image urls generated by the Dalle3 API + + Example: + -------- + >>> dalle3 = Dalle3() + >>> tasks = ["A painting of a dog", "A painting of a cat"] + >>> results = dalle3.process_batch_concurrently(tasks) + >>> print(results) + ['https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png', + + """ + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + future_to_task = {executor.submit(self, task): task for task in tasks} + results = [] + for future in concurrent.futures.as_completed(future_to_task): + task = future_to_task[future] + try: + img = future.result() + results.append(img) + + print(f"Task {task} completed: {img}") + except Exception as error: + print( + colored( + ( + f"Error running Dalle3: {error} try optimizing your api key and" + " or try again" + ), + "red", + ) + ) + print(colored(f"Error running Dalle3: {error.http_status}", "red")) + print(colored(f"Error running Dalle3: {error.error}", "red")) + raise error + def _generate_uuid(self): + """Generate a uuid""" + return str(uuid.uuid4()) + + def __repr__(self): + """Repr method for the Dalle3 class""" + return f"Dalle3(image_url={self.image_url})" + + def __str__(self): + """Str method for the Dalle3 class""" + return f"Dalle3(image_url={self.image_url})" + + @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries) + def rate_limited_call(self, task: str): + """Rate limited call to the Dalle3 API""" + return self.__call__(task) \ No newline at end of file diff --git a/tests/models/auto_temp.py b/tests/models/auto_temp.py index 14468379..a3461769 100644 --- a/tests/models/auto_temp.py +++ b/tests/models/auto_temp.py @@ -11,6 +11,7 @@ api_key = os.getenv("OPENAI_API_KEY") load_dotenv() + @pytest.fixture def auto_temp_agent(): return AutoTempAgent(api_key=api_key) @@ -47,7 +48,9 @@ def test_run_no_scores(auto_temp_agent): task = "Invalid task." temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4" with ThreadPoolExecutor(max_workers=auto_temp_agent.max_workers) as executor: - with patch.object(executor, "submit", side_effect=[None, None, None, None, None, None]): + with patch.object( + executor, "submit", side_effect=[None, None, None, None, None, None] + ): result = auto_temp_agent.run(task, temperature_string) assert result == "No valid outputs generated."