Use cuda only if available

pull/388/head
Wyatt Stanke 11 months ago
parent b44ca7919d
commit 127ef4a521

@@ -49,7 +49,7 @@ class SSD1B:
     max_time_seconds: int = 60
     save_folder: str = "images"
     image_format: str = "png"
-    device: str = "cuda"
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
     dashboard: bool = False
     cache = TTLCache(maxsize=100, ttl=3600)
     pipe = StableDiffusionXLPipeline.from_pretrained(
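The change above defaults the `device` field to CUDA only when a GPU is actually usable, so the class no longer assumes a GPU on CPU-only machines. A minimal sketch of the same pattern in isolation (the standalone `pick_device` helper is illustrative, not part of the class):

```python
import torch

def pick_device() -> str:
    # Fall back to CPU when no CUDA-capable GPU (or no CUDA build of torch) is present.
    return "cuda" if torch.cuda.is_available() else "cpu"

print(pick_device())  # "cuda" on a machine with a working GPU, otherwise "cpu"
```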
@@ -96,9 +96,7 @@ class SSD1B:
         byte_array = byte_stream.getvalue()
         return byte_array
 
-    @backoff.on_exception(
-        backoff.expo, Exception, max_time=max_time_seconds
-    )
+    @backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds)
     def __call__(self, task: str, neg_prompt: str):
         """
         Text to image conversion using the SSD1B API
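The decorator is only collapsed onto one line here, but its behaviour is worth spelling out: `backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds)` keeps re-invoking `__call__` with exponentially growing waits until the call succeeds or the time budget runs out. A toy sketch under assumed values (the flaky function and the 10-second budget are made up for the example):

```python
import random

import backoff

@backoff.on_exception(backoff.expo, RuntimeError, max_time=10)
def flaky_call() -> str:
    # Hypothetical transient failure: backoff retries with exponential waits
    # (roughly 1 s, 2 s, 4 s, ...) until success or until 10 seconds have elapsed.
    if random.random() < 0.7:
        raise RuntimeError("transient failure")
    return "ok"

print(flaky_call())
```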
@@ -126,9 +124,7 @@ class SSD1B:
         if task in self.cache:
             return self.cache[task]
         try:
-            img = self.pipe(
-                prompt=task, neg_prompt=neg_prompt
-            ).images[0]
+            img = self.pipe(prompt=task, neg_prompt=neg_prompt).images[0]
 
             # Generate a unique filename for the image
             img_name = f"{uuid.uuid4()}.{self.image_format}"
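Context for the hunk above: `__call__` checks the class-level `TTLCache` (maxsize 100, 1-hour TTL) before running the pipeline, so repeated prompts within an hour are served from memory. A toy sketch of that cache-then-compute flow, with an invented `expensive_render` standing in for `self.pipe(...)`:

```python
from cachetools import TTLCache

cache = TTLCache(maxsize=100, ttl=3600)  # up to 100 prompts, each kept for an hour

def expensive_render(prompt: str) -> str:
    # Placeholder for the real diffusion call, e.g. pipe(prompt=...).images[0]
    return f"<image for {prompt!r}>"

def render(prompt: str) -> str:
    if prompt in cache:       # cache hit: skip the slow pipeline entirely
        return cache[prompt]
    img = expensive_render(prompt)
    cache[prompt] = img       # remember the result for identical prompts later
    return img
```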
@@ -144,10 +140,7 @@ class SSD1B:
             # Handling exceptions and printing the errors details
             print(
                 colored(
-                    (
-                        f"Error running SSD1B: {error} try optimizing"
-                        " your api key and or try again"
-                    ),
+                    (f"Error running SSD1B: {error} try optimizing" " your api key and or try again"),
                     "red",
                 )
             )
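Collapsing the message onto one line does not change what is printed, because adjacent string literals are concatenated by the Python compiler before `colored` (from `termcolor`) wraps the text in ANSI colour codes. A quick check, with a stand-in value for the caught exception:

```python
from termcolor import colored

error = "CUDA out of memory"  # stand-in for the exception caught above
msg = (f"Error running SSD1B: {error} try optimizing" " your api key and or try again")
assert "optimizing your api key" in msg  # the two literals fuse with no visible seam
print(colored(msg, "red"))
```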
@@ -155,9 +148,7 @@ class SSD1B:
     def _generate_image_name(self, task: str):
         """Generate a sanitized file name based on the task"""
-        sanitized_task = "".join(
-            char for char in task if char.isalnum() or char in " _ -"
-        ).rstrip()
+        sanitized_task = "".join(char for char in task if char.isalnum() or char in " _ -").rstrip()
         return f"{sanitized_task}.{self.image_format}"
 
     def _download_image(self, img: Image, filename: str):
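`_generate_image_name` keeps only alphanumerics plus the characters in `" _ -"` and trims trailing whitespace before appending the image extension. A standalone sketch of the same expression (the sample task string is invented):

```python
def generate_image_name(task: str, image_format: str = "png") -> str:
    # Drop anything that is not alphanumeric or in " _ -", then trim trailing whitespace.
    sanitized_task = "".join(char for char in task if char.isalnum() or char in " _ -").rstrip()
    return f"{sanitized_task}.{image_format}"

print(generate_image_name("A cat: riding a bike!"))  # -> "A cat riding a bike.png"
```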
@@ -192,9 +183,7 @@ class SSD1B:
             )
         )
 
-    def process_batch_concurrently(
-        self, tasks: List[str], max_workers: int = 5
-    ):
+    def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5):
         """
         Process a batch of tasks concurrently
@@ -215,16 +204,10 @@ class SSD1B:
         >>> print(results)
         """
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=max_workers
-        ) as executor:
-            future_to_task = {
-                executor.submit(self, task): task for task in tasks
-            }
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            future_to_task = {executor.submit(self, task): task for task in tasks}
 
             results = []
-            for future in concurrent.futures.as_completed(
-                future_to_task
-            ):
+            for future in concurrent.futures.as_completed(future_to_task):
                 task = future_to_task[future]
                 try:
                     img = future.result()
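`process_batch_concurrently` fans the prompts out to a thread pool and collects each image as its future completes, which is why results can come back in a different order than the input list. A minimal sketch of that submit/`as_completed` pattern with a dummy worker in place of the model call:

```python
import concurrent.futures

def worker(task: str) -> str:
    # Stand-in for self.__call__(task); pretend it returns an image handle.
    return f"image::{task}"

tasks = ["a red fox", "a snowy street", "a paper boat"]
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
    future_to_task = {executor.submit(worker, task): task for task in tasks}
    results = []
    for future in concurrent.futures.as_completed(future_to_task):
        task = future_to_task[future]
        try:
            results.append(future.result())  # re-raises here if the worker raised
        except Exception as error:
            print(f"Task {task!r} failed: {error}")

print(results)  # ordered by completion time, not by submission order
```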
@@ -234,20 +217,13 @@ class SSD1B:
                 except Exception as error:
                     print(
                         colored(
-                            (
-                                f"Error running SSD1B: {error} try"
-                                " optimizing your api key and or try"
-                                " again"
-                            ),
+                            (f"Error running SSD1B: {error} try" " optimizing your api key and or try" " again"),
                             "red",
                         )
                     )
                     print(
                         colored(
-                            (
-                                "Error running SSD1B:"
-                                f" {error.http_status}"
-                            ),
+                            ("Error running SSD1B:" f" {error.http_status}"),
                             "red",
                         )
                     )
@@ -271,9 +247,7 @@ class SSD1B:
         """Str method for the SSD1B class"""
         return f"SSD1B(image_url={self.image_url})"
 
-    @backoff.on_exception(
-        backoff.expo, Exception, max_tries=max_retries
-    )
+    @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries)
     def rate_limited_call(self, task: str):
         """Rate limited call to the SSD1B API"""
         return self.__call__(task)
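Unlike the `max_time` budget on `__call__`, `rate_limited_call` bounds the number of attempts with `max_tries`. A small sketch of that difference (the 3-try limit and the always-failing function are illustrative):

```python
import backoff

attempts = 0

@backoff.on_exception(backoff.expo, RuntimeError, max_tries=3)
def limited() -> str:
    # With max_tries=3 the function runs at most three times before the
    # final exception is re-raised, no matter how long the waits take.
    global attempts
    attempts += 1
    raise RuntimeError(f"attempt {attempts} failed")

try:
    limited()
except RuntimeError as error:
    print(error)     # "attempt 3 failed"
    print(attempts)  # 3
```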
