|
|
|
@ -7,7 +7,39 @@ from typing import List
|
|
|
|
|
load_dotenv()
|
|
|
|
|
|
|
|
|
|
class StableDiffusion:
|
|
|
|
|
"""
|
|
|
|
|
A class to interact with the Stable Diffusion API for image generation.
|
|
|
|
|
|
|
|
|
|
Attributes:
|
|
|
|
|
-----------
|
|
|
|
|
api_key : str
|
|
|
|
|
The API key for accessing the Stable Diffusion API.
|
|
|
|
|
api_host : str
|
|
|
|
|
The host URL of the Stable Diffusion API.
|
|
|
|
|
engine_id : str
|
|
|
|
|
The ID of the Stable Diffusion engine.
|
|
|
|
|
headers : dict
|
|
|
|
|
The headers for the API request.
|
|
|
|
|
output_dir : str
|
|
|
|
|
Directory where generated images will be saved.
|
|
|
|
|
|
|
|
|
|
Methods:
|
|
|
|
|
--------
|
|
|
|
|
generate_image(prompt: str, cfg_scale: int, height: int, width: int, samples: int, steps: int) -> List[str]:
|
|
|
|
|
Generates images based on a text prompt and returns a list of file paths to the generated images.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, api_key: str, api_host: str = "https://api.stability.ai"):
|
|
|
|
|
"""
|
|
|
|
|
Initializes the StableDiffusion class with the provided API key and host.
|
|
|
|
|
|
|
|
|
|
Parameters:
|
|
|
|
|
-----------
|
|
|
|
|
api_key : str
|
|
|
|
|
The API key for accessing the Stable Diffusion API.
|
|
|
|
|
api_host : str
|
|
|
|
|
The host URL of the Stable Diffusion API. Default is "https://api.stability.ai".
|
|
|
|
|
"""
|
|
|
|
|
self.api_key = api_key
|
|
|
|
|
self.api_host = api_host
|
|
|
|
|
self.engine_id = "stable-diffusion-v1-6"
|
|
|
|
@ -20,6 +52,34 @@ class StableDiffusion:
|
|
|
|
|
os.makedirs(self.output_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
def generate_image(self, prompt: str, cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30) -> List[str]:
|
|
|
|
|
"""
|
|
|
|
|
Generates images based on a text prompt.
|
|
|
|
|
|
|
|
|
|
Parameters:
|
|
|
|
|
-----------
|
|
|
|
|
prompt : str
|
|
|
|
|
The text prompt based on which the image will be generated.
|
|
|
|
|
cfg_scale : int
|
|
|
|
|
CFG scale parameter for image generation. Default is 7.
|
|
|
|
|
height : int
|
|
|
|
|
Height of the generated image. Default is 1024.
|
|
|
|
|
width : int
|
|
|
|
|
Width of the generated image. Default is 1024.
|
|
|
|
|
samples : int
|
|
|
|
|
Number of images to generate. Default is 1.
|
|
|
|
|
steps : int
|
|
|
|
|
Number of steps for the generation process. Default is 30.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
--------
|
|
|
|
|
List[str]:
|
|
|
|
|
A list of paths to the generated images.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
-------
|
|
|
|
|
Exception:
|
|
|
|
|
If the API response is not 200 (OK).
|
|
|
|
|
"""
|
|
|
|
|
response = requests.post(
|
|
|
|
|
f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image",
|
|
|
|
|
headers=self.headers,
|
|
|
|
@ -46,13 +106,7 @@ class StableDiffusion:
|
|
|
|
|
|
|
|
|
|
return image_paths
|
|
|
|
|
|
|
|
|
|
# Example Usage
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
api_key = os.getenv("STABILITY_API_KEY")
|
|
|
|
|
if not api_key:
|
|
|
|
|
raise Exception("Missing Stability API key.")
|
|
|
|
|
|
|
|
|
|
sd_api = StableDiffusion(api_key)
|
|
|
|
|
images = sd_api.generate_image("A lighthouse on a cliff")
|
|
|
|
|
for image_path in images:
|
|
|
|
|
print(f"Generated image saved at: {image_path}")
|
|
|
|
|
# Usage example:
|
|
|
|
|
# sd = StableDiffusion("your-api-key")
|
|
|
|
|
# images = sd.generate_image("A scenic landscape with mountains")
|
|
|
|
|
# print(images)
|
|
|
|
|