From 217cb27f0968bbabc7390906863729f22595a716 Mon Sep 17 00:00:00 2001 From: pliny <133052465+elder-plinius@users.noreply.github.com> Date: Sun, 26 Nov 2023 18:45:24 -0800 Subject: [PATCH 1/3] Update .env.example --- .env.example | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.env.example b/.env.example index c6c3cade..6197a6d4 100644 --- a/.env.example +++ b/.env.example @@ -5,7 +5,7 @@ AI21_API_KEY="your_api_key_here" COHERE_API_KEY="your_api_key_here" ALEPHALPHA_API_KEY="your_api_key_here" HUGGINFACEHUB_API_KEY="your_api_key_here" - +STABILITY_API_KEY="your_api_key_here" WOLFRAM_ALPHA_APPID="your_wolfram_alpha_appid_here" ZAPIER_NLA_API_KEY="your_zapier_nla_api_key_here" @@ -41,4 +41,4 @@ REDIS_PORT= PINECONE_API_KEY="" BING_COOKIE="" -PSG_CONNECTION_STRING="" \ No newline at end of file +PSG_CONNECTION_STRING="" From fb32a182ffde88edc56499f1f6cfe30b1ba77ccb Mon Sep 17 00:00:00 2001 From: pliny <133052465+elder-plinius@users.noreply.github.com> Date: Sun, 26 Nov 2023 18:46:01 -0800 Subject: [PATCH 2/3] Create stable_diffusion.py --- playground/models/stable_diffusion.py | 58 +++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 playground/models/stable_diffusion.py diff --git a/playground/models/stable_diffusion.py b/playground/models/stable_diffusion.py new file mode 100644 index 00000000..3ca69931 --- /dev/null +++ b/playground/models/stable_diffusion.py @@ -0,0 +1,58 @@ +import os +import base64 +import requests +from dotenv import load_dotenv +from typing import List + +load_dotenv() + +class StableDiffusion: + def __init__(self, api_key: str, api_host: str = "https://api.stability.ai"): + self.api_key = api_key + self.api_host = api_host + self.engine_id = "stable-diffusion-v1-6" + self.headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "Accept": "application/json" + } + self.output_dir = "images" + os.makedirs(self.output_dir, exist_ok=True) + + def generate_image(self, prompt: str, cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30) -> List[str]: + response = requests.post( + f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image", + headers=self.headers, + json={ + "text_prompts": [{"text": prompt}], + "cfg_scale": cfg_scale, + "height": height, + "width": width, + "samples": samples, + "steps": steps, + }, + ) + + if response.status_code != 200: + raise Exception(f"Non-200 response: {response.text}") + + data = response.json() + image_paths = [] + for i, image in enumerate(data["artifacts"]): + image_path = os.path.join(self.output_dir, f"v1_txt2img_{i}.png") + with open(image_path, "wb") as f: + f.write(base64.b64decode(image["base64"])) + image_paths.append(image_path) + + return image_paths + +# Example Usage +if __name__ == "__main__": + api_key = os.getenv("STABILITY_API_KEY") + if not api_key: + raise Exception("Missing Stability API key.") + + sd_api = StableDiffusion(api_key) + images = sd_api.generate_image("A lighthouse on a cliff") + for image_path in images: + print(f"Generated image saved at: {image_path}") From d59ad33a225099f12a2d67948bdde0011ff9911b Mon Sep 17 00:00:00 2001 From: pliny <133052465+elder-plinius@users.noreply.github.com> Date: Sun, 26 Nov 2023 18:55:30 -0800 Subject: [PATCH 3/3] Update stable_diffusion.py --- playground/models/stable_diffusion.py | 74 +++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/playground/models/stable_diffusion.py b/playground/models/stable_diffusion.py index 3ca69931..3bb77c39 100644 --- a/playground/models/stable_diffusion.py +++ b/playground/models/stable_diffusion.py @@ -7,7 +7,39 @@ from typing import List load_dotenv() class StableDiffusion: + """ + A class to interact with the Stable Diffusion API for image generation. + + Attributes: + ----------- + api_key : str + The API key for accessing the Stable Diffusion API. + api_host : str + The host URL of the Stable Diffusion API. + engine_id : str + The ID of the Stable Diffusion engine. + headers : dict + The headers for the API request. + output_dir : str + Directory where generated images will be saved. + + Methods: + -------- + generate_image(prompt: str, cfg_scale: int, height: int, width: int, samples: int, steps: int) -> List[str]: + Generates images based on a text prompt and returns a list of file paths to the generated images. + """ + def __init__(self, api_key: str, api_host: str = "https://api.stability.ai"): + """ + Initializes the StableDiffusion class with the provided API key and host. + + Parameters: + ----------- + api_key : str + The API key for accessing the Stable Diffusion API. + api_host : str + The host URL of the Stable Diffusion API. Default is "https://api.stability.ai". + """ self.api_key = api_key self.api_host = api_host self.engine_id = "stable-diffusion-v1-6" @@ -20,6 +52,34 @@ class StableDiffusion: os.makedirs(self.output_dir, exist_ok=True) def generate_image(self, prompt: str, cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30) -> List[str]: + """ + Generates images based on a text prompt. + + Parameters: + ----------- + prompt : str + The text prompt based on which the image will be generated. + cfg_scale : int + CFG scale parameter for image generation. Default is 7. + height : int + Height of the generated image. Default is 1024. + width : int + Width of the generated image. Default is 1024. + samples : int + Number of images to generate. Default is 1. + steps : int + Number of steps for the generation process. Default is 30. + + Returns: + -------- + List[str]: + A list of paths to the generated images. + + Raises: + ------- + Exception: + If the API response is not 200 (OK). + """ response = requests.post( f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image", headers=self.headers, @@ -46,13 +106,7 @@ class StableDiffusion: return image_paths -# Example Usage -if __name__ == "__main__": - api_key = os.getenv("STABILITY_API_KEY") - if not api_key: - raise Exception("Missing Stability API key.") - - sd_api = StableDiffusion(api_key) - images = sd_api.generate_image("A lighthouse on a cliff") - for image_path in images: - print(f"Generated image saved at: {image_path}") +# Usage example: +# sd = StableDiffusion("your-api-key") +# images = sd.generate_image("A scenic landscape with mountains") +# print(images)