Use cuda only if avaliable

pull/388/head
Wyatt Stanke 11 months ago
parent b44ca7919d
commit 127ef4a521
No known key found for this signature in database
GPG Key ID: CE6BA5FFF135536D

@ -49,7 +49,7 @@ class SSD1B:
max_time_seconds: int = 60 max_time_seconds: int = 60
save_folder: str = "images" save_folder: str = "images"
image_format: str = "png" image_format: str = "png"
device: str = "cuda" device: str = "cuda" if torch.cuda.is_available() else "cpu"
dashboard: bool = False dashboard: bool = False
cache = TTLCache(maxsize=100, ttl=3600) cache = TTLCache(maxsize=100, ttl=3600)
pipe = StableDiffusionXLPipeline.from_pretrained( pipe = StableDiffusionXLPipeline.from_pretrained(
@ -96,9 +96,7 @@ class SSD1B:
byte_array = byte_stream.getvalue() byte_array = byte_stream.getvalue()
return byte_array return byte_array
@backoff.on_exception( @backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds)
backoff.expo, Exception, max_time=max_time_seconds
)
def __call__(self, task: str, neg_prompt: str): def __call__(self, task: str, neg_prompt: str):
""" """
Text to image conversion using the SSD1B API Text to image conversion using the SSD1B API
@ -126,9 +124,7 @@ class SSD1B:
if task in self.cache: if task in self.cache:
return self.cache[task] return self.cache[task]
try: try:
img = self.pipe( img = self.pipe(prompt=task, neg_prompt=neg_prompt).images[0]
prompt=task, neg_prompt=neg_prompt
).images[0]
# Generate a unique filename for the image # Generate a unique filename for the image
img_name = f"{uuid.uuid4()}.{self.image_format}" img_name = f"{uuid.uuid4()}.{self.image_format}"
@ -144,10 +140,7 @@ class SSD1B:
# Handling exceptions and printing the errors details # Handling exceptions and printing the errors details
print( print(
colored( colored(
( (f"Error running SSD1B: {error} try optimizing" " your api key and or try again"),
f"Error running SSD1B: {error} try optimizing"
" your api key and or try again"
),
"red", "red",
) )
) )
@ -155,9 +148,7 @@ class SSD1B:
def _generate_image_name(self, task: str): def _generate_image_name(self, task: str):
"""Generate a sanitized file name based on the task""" """Generate a sanitized file name based on the task"""
sanitized_task = "".join( sanitized_task = "".join(char for char in task if char.isalnum() or char in " _ -").rstrip()
char for char in task if char.isalnum() or char in " _ -"
).rstrip()
return f"{sanitized_task}.{self.image_format}" return f"{sanitized_task}.{self.image_format}"
def _download_image(self, img: Image, filename: str): def _download_image(self, img: Image, filename: str):
@ -192,9 +183,7 @@ class SSD1B:
) )
) )
def process_batch_concurrently( def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5):
self, tasks: List[str], max_workers: int = 5
):
""" """
Process a batch of tasks concurrently Process a batch of tasks concurrently
@ -215,16 +204,10 @@ class SSD1B:
>>> print(results) >>> print(results)
""" """
with concurrent.futures.ThreadPoolExecutor( with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
max_workers=max_workers future_to_task = {executor.submit(self, task): task for task in tasks}
) as executor:
future_to_task = {
executor.submit(self, task): task for task in tasks
}
results = [] results = []
for future in concurrent.futures.as_completed( for future in concurrent.futures.as_completed(future_to_task):
future_to_task
):
task = future_to_task[future] task = future_to_task[future]
try: try:
img = future.result() img = future.result()
@ -234,20 +217,13 @@ class SSD1B:
except Exception as error: except Exception as error:
print( print(
colored( colored(
( (f"Error running SSD1B: {error} try" " optimizing your api key and or try" " again"),
f"Error running SSD1B: {error} try"
" optimizing your api key and or try"
" again"
),
"red", "red",
) )
) )
print( print(
colored( colored(
( ("Error running SSD1B:" f" {error.http_status}"),
"Error running SSD1B:"
f" {error.http_status}"
),
"red", "red",
) )
) )
@ -271,9 +247,7 @@ class SSD1B:
"""Str method for the SSD1B class""" """Str method for the SSD1B class"""
return f"SSD1B(image_url={self.image_url})" return f"SSD1B(image_url={self.image_url})"
@backoff.on_exception( @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries)
backoff.expo, Exception, max_tries=max_retries
)
def rate_limited_call(self, task: str): def rate_limited_call(self, task: str):
"""Rate limited call to the SSD1B API""" """Rate limited call to the SSD1B API"""
return self.__call__(task) return self.__call__(task)

Loading…
Cancel
Save