diff --git a/swarms/models/ssd_1b.py b/swarms/models/ssd_1b.py index 3e98a08c..7b17fdac 100644 --- a/swarms/models/ssd_1b.py +++ b/swarms/models/ssd_1b.py @@ -96,7 +96,9 @@ class SSD1B: byte_array = byte_stream.getvalue() return byte_array - @backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds) + @backoff.on_exception( + backoff.expo, Exception, max_time=max_time_seconds + ) def __call__(self, task: str, neg_prompt: str): """ Text to image conversion using the SSD1B API @@ -124,7 +126,9 @@ class SSD1B: if task in self.cache: return self.cache[task] try: - img = self.pipe(prompt=task, neg_prompt=neg_prompt).images[0] + img = self.pipe( + prompt=task, neg_prompt=neg_prompt + ).images[0] # Generate a unique filename for the image img_name = f"{uuid.uuid4()}.{self.image_format}" @@ -140,7 +144,10 @@ class SSD1B: # Handling exceptions and printing the errors details print( colored( - (f"Error running SSD1B: {error} try optimizing" " your api key and or try again"), + ( + f"Error running SSD1B: {error} try optimizing" + " your api key and or try again" + ), "red", ) ) @@ -148,7 +155,9 @@ class SSD1B: def _generate_image_name(self, task: str): """Generate a sanitized file name based on the task""" - sanitized_task = "".join(char for char in task if char.isalnum() or char in " _ -").rstrip() + sanitized_task = "".join( + char for char in task if char.isalnum() or char in " _ -" + ).rstrip() return f"{sanitized_task}.{self.image_format}" def _download_image(self, img: Image, filename: str): @@ -183,7 +192,9 @@ class SSD1B: ) ) - def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5): + def process_batch_concurrently( + self, tasks: List[str], max_workers: int = 5 + ): """ Process a batch of tasks concurrently @@ -204,10 +215,16 @@ class SSD1B: >>> print(results) """ - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_task = {executor.submit(self, task): task for task in tasks} + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + future_to_task = { + executor.submit(self, task): task for task in tasks + } results = [] - for future in concurrent.futures.as_completed(future_to_task): + for future in concurrent.futures.as_completed( + future_to_task + ): task = future_to_task[future] try: img = future.result() @@ -217,13 +234,20 @@ class SSD1B: except Exception as error: print( colored( - (f"Error running SSD1B: {error} try" " optimizing your api key and or try" " again"), + ( + f"Error running SSD1B: {error} try" + " optimizing your api key and or try" + " again" + ), "red", ) ) print( colored( - ("Error running SSD1B:" f" {error.http_status}"), + ( + "Error running SSD1B:" + f" {error.http_status}" + ), "red", ) ) @@ -247,7 +271,9 @@ class SSD1B: """Str method for the SSD1B class""" return f"SSD1B(image_url={self.image_url})" - @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries) + @backoff.on_exception( + backoff.expo, Exception, max_tries=max_retries + ) def rate_limited_call(self, task: str): """Rate limited call to the SSD1B API""" return self.__call__(task)