Fix formatting

pull/388/head
Wyatt Stanke 11 months ago
parent 1a5aed51cd
commit 72a1d90ee1
No known key found for this signature in database
GPG Key ID: CE6BA5FFF135536D

@ -96,7 +96,9 @@ class SSD1B:
byte_array = byte_stream.getvalue()
return byte_array
@backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds)
@backoff.on_exception(
backoff.expo, Exception, max_time=max_time_seconds
)
def __call__(self, task: str, neg_prompt: str):
"""
Text to image conversion using the SSD1B API
@ -124,7 +126,9 @@ class SSD1B:
if task in self.cache:
return self.cache[task]
try:
img = self.pipe(prompt=task, neg_prompt=neg_prompt).images[0]
img = self.pipe(
prompt=task, neg_prompt=neg_prompt
).images[0]
# Generate a unique filename for the image
img_name = f"{uuid.uuid4()}.{self.image_format}"
@ -140,7 +144,10 @@ class SSD1B:
# Handling exceptions and printing the errors details
print(
colored(
(f"Error running SSD1B: {error} try optimizing" " your api key and or try again"),
(
f"Error running SSD1B: {error} try optimizing"
" your api key and or try again"
),
"red",
)
)
@ -148,7 +155,9 @@ class SSD1B:
def _generate_image_name(self, task: str):
"""Generate a sanitized file name based on the task"""
sanitized_task = "".join(char for char in task if char.isalnum() or char in " _ -").rstrip()
sanitized_task = "".join(
char for char in task if char.isalnum() or char in " _ -"
).rstrip()
return f"{sanitized_task}.{self.image_format}"
def _download_image(self, img: Image, filename: str):
@ -183,7 +192,9 @@ class SSD1B:
)
)
def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5):
def process_batch_concurrently(
self, tasks: List[str], max_workers: int = 5
):
"""
Process a batch of tasks concurrently
@ -204,10 +215,16 @@ class SSD1B:
>>> print(results)
"""
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_task = {executor.submit(self, task): task for task in tasks}
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
future_to_task = {
executor.submit(self, task): task for task in tasks
}
results = []
for future in concurrent.futures.as_completed(future_to_task):
for future in concurrent.futures.as_completed(
future_to_task
):
task = future_to_task[future]
try:
img = future.result()
@ -217,13 +234,20 @@ class SSD1B:
except Exception as error:
print(
colored(
(f"Error running SSD1B: {error} try" " optimizing your api key and or try" " again"),
(
f"Error running SSD1B: {error} try"
" optimizing your api key and or try"
" again"
),
"red",
)
)
print(
colored(
("Error running SSD1B:" f" {error.http_status}"),
(
"Error running SSD1B:"
f" {error.http_status}"
),
"red",
)
)
@ -247,7 +271,9 @@ class SSD1B:
"""Str method for the SSD1B class"""
return f"SSD1B(image_url={self.image_url})"
@backoff.on_exception(backoff.expo, Exception, max_tries=max_retries)
@backoff.on_exception(
backoff.expo, Exception, max_tries=max_retries
)
def rate_limited_call(self, task: str):
"""Rate limited call to the SSD1B API"""
return self.__call__(task)

Loading…
Cancel
Save