Update stable_diffusion.py

pull/205/head
pliny 1 year ago committed by GitHub
parent d59ad33a22
commit 2c10018f58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,7 @@
import os
import base64 import base64
import os
import requests import requests
import uuid
from dotenv import load_dotenv from dotenv import load_dotenv
from typing import List from typing import List
@ -8,41 +9,67 @@ load_dotenv()
class StableDiffusion: class StableDiffusion:
""" """
A class to interact with the Stable Diffusion API for image generation. A class to interact with the Stable Diffusion API for generating images from text prompts.
Attributes: Attributes:
----------- -----------
api_key : str api_key : str
The API key for accessing the Stable Diffusion API. The API key for accessing the Stable Diffusion API.
api_host : str api_host : str
The host URL of the Stable Diffusion API. The host URL for the Stable Diffusion API.
engine_id : str engine_id : str
The ID of the Stable Diffusion engine. The engine ID for the Stable Diffusion API.
headers : dict cfg_scale : int
The headers for the API request. Configuration scale for image generation.
height : int
The height of the generated image.
width : int
The width of the generated image.
samples : int
The number of samples to generate.
steps : int
The number of steps for the generation process.
output_dir : str output_dir : str
Directory where generated images will be saved. Directory where the generated images will be saved.
Methods: Methods:
-------- --------
generate_image(prompt: str, cfg_scale: int, height: int, width: int, samples: int, steps: int) -> List[str]: __init__(self, api_key: str, api_host: str, cfg_scale: int, height: int, width: int, samples: int, steps: int):
Generates images based on a text prompt and returns a list of file paths to the generated images. Initializes the StableDiffusion instance with provided parameters.
generate_image(self, task: str) -> List[str]:
Generates an image based on the provided text prompt and returns the paths of the saved images.
""" """
def __init__(self, api_key: str, api_host: str = "https://api.stability.ai"): def __init__(self, api_key: str, api_host: str = "https://api.stability.ai", cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30):
""" """
Initializes the StableDiffusion class with the provided API key and host. Initialize the StableDiffusion class with API configurations.
Parameters: Parameters:
----------- -----------
api_key : str api_key : str
The API key for accessing the Stable Diffusion API. The API key for accessing the Stable Diffusion API.
api_host : str api_host : str
The host URL of the Stable Diffusion API. Default is "https://api.stability.ai". The host URL for the Stable Diffusion API.
cfg_scale : int
Configuration scale for image generation.
height : int
The height of the generated image.
width : int
The width of the generated image.
samples : int
The number of samples to generate.
steps : int
The number of steps for the generation process.
""" """
self.api_key = api_key self.api_key = api_key
self.api_host = api_host self.api_host = api_host
self.engine_id = "stable-diffusion-v1-6" self.engine_id = "stable-diffusion-v1-6"
self.cfg_scale = cfg_scale
self.height = height
self.width = width
self.samples = samples
self.steps = steps
self.headers = { self.headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json", "Content-Type": "application/json",
@ -51,45 +78,35 @@ class StableDiffusion:
self.output_dir = "images" self.output_dir = "images"
os.makedirs(self.output_dir, exist_ok=True) os.makedirs(self.output_dir, exist_ok=True)
def generate_image(self, prompt: str, cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30) -> List[str]: def run(self, task: str) -> List[str]:
""" """
Generates images based on a text prompt. Generates an image based on a given text prompt.
Parameters: Parameters:
----------- -----------
prompt : str task : str
The text prompt based on which the image will be generated. The text prompt based on which the image will be generated.
cfg_scale : int
CFG scale parameter for image generation. Default is 7.
height : int
Height of the generated image. Default is 1024.
width : int
Width of the generated image. Default is 1024.
samples : int
Number of images to generate. Default is 1.
steps : int
Number of steps for the generation process. Default is 30.
Returns: Returns:
-------- --------
List[str]: List[str]:
A list of paths to the generated images. A list of file paths where the generated images are saved.
Raises: Raises:
------- -------
Exception: Exception:
If the API response is not 200 (OK). If the API request fails and returns a non-200 response.
""" """
response = requests.post( response = requests.post(
f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image", f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image",
headers=self.headers, headers=self.headers,
json={ json={
"text_prompts": [{"text": prompt}], "text_prompts": [{"text": task}],
"cfg_scale": cfg_scale, "cfg_scale": self.cfg_scale,
"height": height, "height": self.height,
"width": width, "width": self.width,
"samples": samples, "samples": self.samples,
"steps": steps, "steps": self.steps,
}, },
) )
@ -99,14 +116,10 @@ class StableDiffusion:
data = response.json() data = response.json()
image_paths = [] image_paths = []
for i, image in enumerate(data["artifacts"]): for i, image in enumerate(data["artifacts"]):
image_path = os.path.join(self.output_dir, f"v1_txt2img_{i}.png") unique_id = uuid.uuid4() # Generate a unique identifier
image_path = os.path.join(self.output_dir, f"{unique_id}_v1_txt2img_{i}.png")
with open(image_path, "wb") as f: with open(image_path, "wb") as f:
f.write(base64.b64decode(image["base64"])) f.write(base64.b64decode(image["base64"]))
image_paths.append(image_path) image_paths.append(image_path)
return image_paths return image_paths
# Usage example:
# sd = StableDiffusion("your-api-key")
# images = sd.generate_image("A scenic landscape with mountains")
# print(images)

Loading…
Cancel
Save