You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/swarms/models/stable_diffusion.py

138 lines
4.1 KiB

import base64
import os
import uuid
from typing import List, Optional

import requests
from dotenv import load_dotenv
# Populate os.environ from a .env file at import time (python-dotenv side effect).
# NOTE(review): StableDiffusion below never reads os.environ itself -- the API key
# must still be passed explicitly. Presumably this exists so callers can fetch the
# key via os.getenv(...) before constructing the class; confirm it is still needed.
load_dotenv()
class StableDiffusion:
    """
    A class to interact with the Stable Diffusion API for generating images from text prompts.

    Attributes:
    -----------
    api_key : str
        The API key for accessing the Stable Diffusion API.
    api_host : str
        The host URL for the Stable Diffusion API.
    engine_id : str
        The engine ID for the Stable Diffusion API.
    cfg_scale : int
        Configuration scale for image generation.
    height : int
        The height of the generated image.
    width : int
        The width of the generated image.
    samples : int
        The number of samples to generate.
    steps : int
        The number of steps for the generation process.
    output_dir : str
        Directory where the generated images will be saved.
    timeout : Optional[float]
        Seconds to wait for the API before giving up (``None`` waits forever).
    headers : dict
        HTTP headers (bearer token + JSON content negotiation) sent with every request.

    Methods:
    --------
    __init__(self, api_key, api_host="https://api.stability.ai", cfg_scale=7,
             height=1024, width=1024, samples=1, steps=30,
             engine_id="stable-diffusion-v1-6", output_dir="images", timeout=120.0):
        Initializes the StableDiffusion instance with provided parameters.
    run(self, task: str) -> List[str]:
        Generates images from the provided text prompt and returns the paths of the saved images.
    """

    def __init__(
        self,
        api_key: str,
        api_host: str = "https://api.stability.ai",
        cfg_scale: int = 7,
        height: int = 1024,
        width: int = 1024,
        samples: int = 1,
        steps: int = 30,
        engine_id: str = "stable-diffusion-v1-6",
        output_dir: str = "images",
        timeout: Optional[float] = 120.0,
    ):
        """
        Initialize the StableDiffusion class with API configurations.

        The new keyword parameters (``engine_id``, ``output_dir``, ``timeout``)
        are appended after ``steps`` so existing positional callers keep working.

        Parameters:
        -----------
        api_key : str
            The API key for accessing the Stable Diffusion API.
        api_host : str
            The host URL for the Stable Diffusion API.
        cfg_scale : int
            Configuration scale for image generation.
        height : int
            The height of the generated image.
        width : int
            The width of the generated image.
        samples : int
            The number of samples to generate.
        steps : int
            The number of steps for the generation process.
        engine_id : str
            The engine ID used in the generation endpoint path.
        output_dir : str
            Directory where generated images are written; created if missing.
        timeout : Optional[float]
            Request timeout in seconds passed to ``requests.post``; ``None``
            disables the timeout (the previous, hang-prone behavior).
        """
        self.api_key = api_key
        self.api_host = api_host
        self.engine_id = engine_id
        self.cfg_scale = cfg_scale
        self.height = height
        self.width = width
        self.samples = samples
        self.steps = steps
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        self.output_dir = output_dir
        # Side effect: the output directory is created eagerly at construction time.
        os.makedirs(self.output_dir, exist_ok=True)

    def run(self, task: str) -> List[str]:
        """
        Generates images based on a given text prompt.

        Parameters:
        -----------
        task : str
            The text prompt based on which the image will be generated.

        Returns:
        --------
        List[str]:
            A list of file paths where the generated images are saved.

        Raises:
        -------
        Exception:
            If the API request fails and returns a non-200 response.
        requests.exceptions.Timeout:
            If the API does not respond within ``self.timeout`` seconds.
        """
        # rstrip avoids a "//v1/..." path when api_host is given with a trailing slash.
        url = (
            f"{self.api_host.rstrip('/')}"
            f"/v1/generation/{self.engine_id}/text-to-image"
        )
        payload = {
            "text_prompts": [{"text": task}],
            "cfg_scale": self.cfg_scale,
            "height": self.height,
            "width": self.width,
            "samples": self.samples,
            "steps": self.steps,
        }
        response = requests.post(
            url,
            headers=self.headers,
            json=payload,
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise Exception(f"Non-200 response: {response.text}")
        data = response.json()
        # Assumes the body carries an "artifacts" list of dicts, each with a
        # "base64" field -- the shape read below; TODO confirm against API docs.
        return self._save_artifacts(data["artifacts"])

    def _save_artifacts(self, artifacts: List[dict]) -> List[str]:
        """
        Decode base64-encoded image artifacts and write each to a PNG file.

        Parameters:
        -----------
        artifacts : List[dict]
            Items with a ``"base64"`` key holding the encoded image bytes.

        Returns:
        --------
        List[str]:
            Paths of the written files, in the same order as ``artifacts``.
        """
        image_paths = []
        for i, image in enumerate(artifacts):
            # uuid4 prefix keeps files from separate run() calls from colliding.
            image_path = os.path.join(
                self.output_dir, f"{uuid.uuid4()}_v1_txt2img_{i}.png"
            )
            with open(image_path, "wb") as f:
                f.write(base64.b64decode(image["base64"]))
            image_paths.append(image_path)
        return image_paths