diff --git a/playground/models/stable_diffusion.py b/playground/models/stable_diffusion.py index 3bb77c39..f45d5892 100644 --- a/playground/models/stable_diffusion.py +++ b/playground/models/stable_diffusion.py @@ -1,6 +1,7 @@ -import os import base64 +import os import requests +import uuid from dotenv import load_dotenv from typing import List @@ -8,41 +9,67 @@ load_dotenv() class StableDiffusion: """ - A class to interact with the Stable Diffusion API for image generation. + A class to interact with the Stable Diffusion API for generating images from text prompts. Attributes: ----------- api_key : str The API key for accessing the Stable Diffusion API. api_host : str - The host URL of the Stable Diffusion API. + The host URL for the Stable Diffusion API. engine_id : str - The ID of the Stable Diffusion engine. - headers : dict - The headers for the API request. + The engine ID for the Stable Diffusion API. + cfg_scale : int + Configuration scale for image generation. + height : int + The height of the generated image. + width : int + The width of the generated image. + samples : int + The number of samples to generate. + steps : int + The number of steps for the generation process. output_dir : str - Directory where generated images will be saved. + Directory where the generated images will be saved. Methods: -------- - generate_image(prompt: str, cfg_scale: int, height: int, width: int, samples: int, steps: int) -> List[str]: - Generates images based on a text prompt and returns a list of file paths to the generated images. + __init__(self, api_key: str, api_host: str, cfg_scale: int, height: int, width: int, samples: int, steps: int): + Initializes the StableDiffusion instance with provided parameters. + + generate_image(self, task: str) -> List[str]: + Generates an image based on the provided text prompt and returns the paths of the saved images. """ - def __init__(self, api_key: str, api_host: str = "https://api.stability.ai"): + def __init__(self, api_key: str, api_host: str = "https://api.stability.ai", cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30): """ - Initializes the StableDiffusion class with the provided API key and host. + Initialize the StableDiffusion class with API configurations. Parameters: ----------- api_key : str The API key for accessing the Stable Diffusion API. api_host : str - The host URL of the Stable Diffusion API. Default is "https://api.stability.ai". + The host URL for the Stable Diffusion API. + cfg_scale : int + Configuration scale for image generation. + height : int + The height of the generated image. + width : int + The width of the generated image. + samples : int + The number of samples to generate. + steps : int + The number of steps for the generation process. """ self.api_key = api_key self.api_host = api_host self.engine_id = "stable-diffusion-v1-6" + self.cfg_scale = cfg_scale + self.height = height + self.width = width + self.samples = samples + self.steps = steps self.headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", @@ -51,45 +78,35 @@ class StableDiffusion: self.output_dir = "images" os.makedirs(self.output_dir, exist_ok=True) - def generate_image(self, prompt: str, cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30) -> List[str]: + def run(self, task: str) -> List[str]: """ - Generates images based on a text prompt. + Generates an image based on a given text prompt. Parameters: ----------- - prompt : str + task : str The text prompt based on which the image will be generated. - cfg_scale : int - CFG scale parameter for image generation. Default is 7. - height : int - Height of the generated image. Default is 1024. - width : int - Width of the generated image. Default is 1024. - samples : int - Number of images to generate. Default is 1. - steps : int - Number of steps for the generation process. Default is 30. Returns: -------- List[str]: - A list of paths to the generated images. + A list of file paths where the generated images are saved. Raises: ------- Exception: - If the API response is not 200 (OK). + If the API request fails and returns a non-200 response. """ response = requests.post( f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image", headers=self.headers, json={ - "text_prompts": [{"text": prompt}], - "cfg_scale": cfg_scale, - "height": height, - "width": width, - "samples": samples, - "steps": steps, + "text_prompts": [{"text": task}], + "cfg_scale": self.cfg_scale, + "height": self.height, + "width": self.width, + "samples": self.samples, + "steps": self.steps, }, ) @@ -99,14 +116,10 @@ class StableDiffusion: data = response.json() image_paths = [] for i, image in enumerate(data["artifacts"]): - image_path = os.path.join(self.output_dir, f"v1_txt2img_{i}.png") + unique_id = uuid.uuid4() # Generate a unique identifier + image_path = os.path.join(self.output_dir, f"{unique_id}_v1_txt2img_{i}.png") with open(image_path, "wb") as f: f.write(base64.b64decode(image["base64"])) image_paths.append(image_path) return image_paths - -# Usage example: -# sd = StableDiffusion("your-api-key") -# images = sd.generate_image("A scenic landscape with mountains") -# print(images)