From 127ef4a521c3cf8b88b60dd2aaa37fb338b2d64d Mon Sep 17 00:00:00 2001 From: Wyatt Stanke Date: Thu, 29 Feb 2024 15:38:21 -0500 Subject: [PATCH] Use cuda only if avaliable --- swarms/models/ssd_1b.py | 50 ++++++++++------------------------------- 1 file changed, 12 insertions(+), 38 deletions(-) diff --git a/swarms/models/ssd_1b.py b/swarms/models/ssd_1b.py index 1dc6c00a..3e98a08c 100644 --- a/swarms/models/ssd_1b.py +++ b/swarms/models/ssd_1b.py @@ -49,7 +49,7 @@ class SSD1B: max_time_seconds: int = 60 save_folder: str = "images" image_format: str = "png" - device: str = "cuda" + device: str = "cuda" if torch.cuda.is_available() else "cpu" dashboard: bool = False cache = TTLCache(maxsize=100, ttl=3600) pipe = StableDiffusionXLPipeline.from_pretrained( @@ -96,9 +96,7 @@ class SSD1B: byte_array = byte_stream.getvalue() return byte_array - @backoff.on_exception( - backoff.expo, Exception, max_time=max_time_seconds - ) + @backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds) def __call__(self, task: str, neg_prompt: str): """ Text to image conversion using the SSD1B API @@ -126,9 +124,7 @@ class SSD1B: if task in self.cache: return self.cache[task] try: - img = self.pipe( - prompt=task, neg_prompt=neg_prompt - ).images[0] + img = self.pipe(prompt=task, neg_prompt=neg_prompt).images[0] # Generate a unique filename for the image img_name = f"{uuid.uuid4()}.{self.image_format}" @@ -144,10 +140,7 @@ class SSD1B: # Handling exceptions and printing the errors details print( colored( - ( - f"Error running SSD1B: {error} try optimizing" - " your api key and or try again" - ), + (f"Error running SSD1B: {error} try optimizing" " your api key and or try again"), "red", ) ) @@ -155,9 +148,7 @@ class SSD1B: def _generate_image_name(self, task: str): """Generate a sanitized file name based on the task""" - sanitized_task = "".join( - char for char in task if char.isalnum() or char in " _ -" - ).rstrip() + sanitized_task = "".join(char for char in task if char.isalnum() or char in " _ -").rstrip() return f"{sanitized_task}.{self.image_format}" def _download_image(self, img: Image, filename: str): @@ -192,9 +183,7 @@ class SSD1B: ) ) - def process_batch_concurrently( - self, tasks: List[str], max_workers: int = 5 - ): + def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5): """ Process a batch of tasks concurrently @@ -215,16 +204,10 @@ class SSD1B: >>> print(results) """ - with concurrent.futures.ThreadPoolExecutor( - max_workers=max_workers - ) as executor: - future_to_task = { - executor.submit(self, task): task for task in tasks - } + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_task = {executor.submit(self, task): task for task in tasks} results = [] - for future in concurrent.futures.as_completed( - future_to_task - ): + for future in concurrent.futures.as_completed(future_to_task): task = future_to_task[future] try: img = future.result() @@ -234,20 +217,13 @@ class SSD1B: except Exception as error: print( colored( - ( - f"Error running SSD1B: {error} try" - " optimizing your api key and or try" - " again" - ), + (f"Error running SSD1B: {error} try" " optimizing your api key and or try" " again"), "red", ) ) print( colored( - ( - "Error running SSD1B:" - f" {error.http_status}" - ), + ("Error running SSD1B:" f" {error.http_status}"), "red", ) ) @@ -271,9 +247,7 @@ class SSD1B: """Str method for the SSD1B class""" return f"SSD1B(image_url={self.image_url})" - @backoff.on_exception( - backoff.expo, Exception, max_tries=max_retries - ) + @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries) def rate_limited_call(self, task: str): """Rate limited call to the SSD1B API""" return self.__call__(task)