|
|
@ -49,7 +49,7 @@ class SSD1B:
|
|
|
|
max_time_seconds: int = 60
|
|
|
|
max_time_seconds: int = 60
|
|
|
|
save_folder: str = "images"
|
|
|
|
save_folder: str = "images"
|
|
|
|
image_format: str = "png"
|
|
|
|
image_format: str = "png"
|
|
|
|
device: str = "cuda"
|
|
|
|
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
dashboard: bool = False
|
|
|
|
dashboard: bool = False
|
|
|
|
cache = TTLCache(maxsize=100, ttl=3600)
|
|
|
|
cache = TTLCache(maxsize=100, ttl=3600)
|
|
|
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
|
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
|
@ -96,9 +96,7 @@ class SSD1B:
|
|
|
|
byte_array = byte_stream.getvalue()
|
|
|
|
byte_array = byte_stream.getvalue()
|
|
|
|
return byte_array
|
|
|
|
return byte_array
|
|
|
|
|
|
|
|
|
|
|
|
@backoff.on_exception(
|
|
|
|
@backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds)
|
|
|
|
backoff.expo, Exception, max_time=max_time_seconds
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
def __call__(self, task: str, neg_prompt: str):
|
|
|
|
def __call__(self, task: str, neg_prompt: str):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Text to image conversion using the SSD1B API
|
|
|
|
Text to image conversion using the SSD1B API
|
|
|
@ -126,9 +124,7 @@ class SSD1B:
|
|
|
|
if task in self.cache:
|
|
|
|
if task in self.cache:
|
|
|
|
return self.cache[task]
|
|
|
|
return self.cache[task]
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
img = self.pipe(
|
|
|
|
img = self.pipe(prompt=task, neg_prompt=neg_prompt).images[0]
|
|
|
|
prompt=task, neg_prompt=neg_prompt
|
|
|
|
|
|
|
|
).images[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Generate a unique filename for the image
|
|
|
|
# Generate a unique filename for the image
|
|
|
|
img_name = f"{uuid.uuid4()}.{self.image_format}"
|
|
|
|
img_name = f"{uuid.uuid4()}.{self.image_format}"
|
|
|
@ -144,10 +140,7 @@ class SSD1B:
|
|
|
|
# Handling exceptions and printing the errors details
|
|
|
|
# Handling exceptions and printing the errors details
|
|
|
|
print(
|
|
|
|
print(
|
|
|
|
colored(
|
|
|
|
colored(
|
|
|
|
(
|
|
|
|
(f"Error running SSD1B: {error} try optimizing" " your api key and or try again"),
|
|
|
|
f"Error running SSD1B: {error} try optimizing"
|
|
|
|
|
|
|
|
" your api key and or try again"
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
"red",
|
|
|
|
"red",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -155,9 +148,7 @@ class SSD1B:
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_image_name(self, task: str):
|
|
|
|
def _generate_image_name(self, task: str):
|
|
|
|
"""Generate a sanitized file name based on the task"""
|
|
|
|
"""Generate a sanitized file name based on the task"""
|
|
|
|
sanitized_task = "".join(
|
|
|
|
sanitized_task = "".join(char for char in task if char.isalnum() or char in " _ -").rstrip()
|
|
|
|
char for char in task if char.isalnum() or char in " _ -"
|
|
|
|
|
|
|
|
).rstrip()
|
|
|
|
|
|
|
|
return f"{sanitized_task}.{self.image_format}"
|
|
|
|
return f"{sanitized_task}.{self.image_format}"
|
|
|
|
|
|
|
|
|
|
|
|
def _download_image(self, img: Image, filename: str):
|
|
|
|
def _download_image(self, img: Image, filename: str):
|
|
|
@ -192,9 +183,7 @@ class SSD1B:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def process_batch_concurrently(
|
|
|
|
def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5):
|
|
|
|
self, tasks: List[str], max_workers: int = 5
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
Process a batch of tasks concurrently
|
|
|
|
Process a batch of tasks concurrently
|
|
|
@ -215,16 +204,10 @@ class SSD1B:
|
|
|
|
>>> print(results)
|
|
|
|
>>> print(results)
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
with concurrent.futures.ThreadPoolExecutor(
|
|
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
|
max_workers=max_workers
|
|
|
|
future_to_task = {executor.submit(self, task): task for task in tasks}
|
|
|
|
) as executor:
|
|
|
|
|
|
|
|
future_to_task = {
|
|
|
|
|
|
|
|
executor.submit(self, task): task for task in tasks
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
results = []
|
|
|
|
results = []
|
|
|
|
for future in concurrent.futures.as_completed(
|
|
|
|
for future in concurrent.futures.as_completed(future_to_task):
|
|
|
|
future_to_task
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
task = future_to_task[future]
|
|
|
|
task = future_to_task[future]
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
img = future.result()
|
|
|
|
img = future.result()
|
|
|
@ -234,20 +217,13 @@ class SSD1B:
|
|
|
|
except Exception as error:
|
|
|
|
except Exception as error:
|
|
|
|
print(
|
|
|
|
print(
|
|
|
|
colored(
|
|
|
|
colored(
|
|
|
|
(
|
|
|
|
(f"Error running SSD1B: {error} try" " optimizing your api key and or try" " again"),
|
|
|
|
f"Error running SSD1B: {error} try"
|
|
|
|
|
|
|
|
" optimizing your api key and or try"
|
|
|
|
|
|
|
|
" again"
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
"red",
|
|
|
|
"red",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
print(
|
|
|
|
print(
|
|
|
|
colored(
|
|
|
|
colored(
|
|
|
|
(
|
|
|
|
("Error running SSD1B:" f" {error.http_status}"),
|
|
|
|
"Error running SSD1B:"
|
|
|
|
|
|
|
|
f" {error.http_status}"
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
"red",
|
|
|
|
"red",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -271,9 +247,7 @@ class SSD1B:
|
|
|
|
"""Str method for the SSD1B class"""
|
|
|
|
"""Str method for the SSD1B class"""
|
|
|
|
return f"SSD1B(image_url={self.image_url})"
|
|
|
|
return f"SSD1B(image_url={self.image_url})"
|
|
|
|
|
|
|
|
|
|
|
|
@backoff.on_exception(
|
|
|
|
@backoff.on_exception(backoff.expo, Exception, max_tries=max_retries)
|
|
|
|
backoff.expo, Exception, max_tries=max_retries
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
def rate_limited_call(self, task: str):
|
|
|
|
def rate_limited_call(self, task: str):
|
|
|
|
"""Rate limited call to the SSD1B API"""
|
|
|
|
"""Rate limited call to the SSD1B API"""
|
|
|
|
return self.__call__(task)
|
|
|
|
return self.__call__(task)
|
|
|
|