parent
							
								
									217cb27f09
								
							
						
					
					
						commit
						fb32a182ff
					
				| @ -0,0 +1,58 @@ | |||||||
|  | import os | ||||||
|  | import base64 | ||||||
|  | import requests | ||||||
|  | from dotenv import load_dotenv | ||||||
|  | from typing import List | ||||||
|  | 
 | ||||||
|  | load_dotenv() | ||||||
|  | 
 | ||||||
|  | class StableDiffusion: | ||||||
|  |     def __init__(self, api_key: str, api_host: str = "https://api.stability.ai"): | ||||||
|  |         self.api_key = api_key | ||||||
|  |         self.api_host = api_host | ||||||
|  |         self.engine_id = "stable-diffusion-v1-6" | ||||||
|  |         self.headers = { | ||||||
|  |             "Authorization": f"Bearer {self.api_key}", | ||||||
|  |             "Content-Type": "application/json", | ||||||
|  |             "Accept": "application/json" | ||||||
|  |         } | ||||||
|  |         self.output_dir = "images" | ||||||
|  |         os.makedirs(self.output_dir, exist_ok=True) | ||||||
|  | 
 | ||||||
|  |     def generate_image(self, prompt: str, cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30) -> List[str]: | ||||||
|  |         response = requests.post( | ||||||
|  |             f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image", | ||||||
|  |             headers=self.headers, | ||||||
|  |             json={ | ||||||
|  |                 "text_prompts": [{"text": prompt}], | ||||||
|  |                 "cfg_scale": cfg_scale, | ||||||
|  |                 "height": height, | ||||||
|  |                 "width": width, | ||||||
|  |                 "samples": samples, | ||||||
|  |                 "steps": steps, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if response.status_code != 200: | ||||||
|  |             raise Exception(f"Non-200 response: {response.text}") | ||||||
|  | 
 | ||||||
|  |         data = response.json() | ||||||
|  |         image_paths = [] | ||||||
|  |         for i, image in enumerate(data["artifacts"]): | ||||||
|  |             image_path = os.path.join(self.output_dir, f"v1_txt2img_{i}.png") | ||||||
|  |             with open(image_path, "wb") as f: | ||||||
|  |                 f.write(base64.b64decode(image["base64"])) | ||||||
|  |             image_paths.append(image_path) | ||||||
|  | 
 | ||||||
|  |         return image_paths | ||||||
|  | 
 | ||||||
|  | # Example Usage | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     api_key = os.getenv("STABILITY_API_KEY") | ||||||
|  |     if not api_key: | ||||||
|  |         raise Exception("Missing Stability API key.") | ||||||
|  | 
 | ||||||
|  |     sd_api = StableDiffusion(api_key) | ||||||
|  |     images = sd_api.generate_image("A lighthouse on a cliff") | ||||||
|  |     for image_path in images: | ||||||
|  |         print(f"Generated image saved at: {image_path}") | ||||||
					Loading…
					
					
				
		Reference in new issue