#!/usr/bin/env python3
"""Command-line client for the Chutes Z-Image-Turbo text-to-image endpoint.

Reads ``CHUTES_API_TOKEN`` from the environment (optionally populated from a
``.env`` file next to this script), posts a generation request to
``API_URL`` and writes the returned image into the current working directory
as ``generated_<unix-timestamp>[_N].<ext>``.
"""
import argparse
import base64
import os
import sys
import time
from pathlib import Path

import requests


def load_env():
    """Populate ``os.environ`` from a ``.env`` file beside this script.

    Each non-blank line that is not a ``#`` comment and contains ``=`` is
    split at the first ``=``. Surrounding quote characters are stripped from
    the value. Values are applied with ``os.environ.setdefault``, so variables
    already present in the real environment take precedence.
    """
    env_path = Path(__file__).parent / ".env"
    if not env_path.exists():
        return
    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        os.environ.setdefault(key.strip(), value.strip().strip("\"'"))


load_env()

API_TOKEN = os.environ.get("CHUTES_API_TOKEN")
API_URL = "https://chutes-z-image-turbo.chutes.ai/generate"

# File extension for each image MIME type the endpoint might return. Unknown
# or missing types fall back to ".png", which was previously always used.
_IMAGE_EXTENSIONS = {
    "image/png": ".png",
    "image/jpeg": ".jpg",
    "image/jpg": ".jpg",
    "image/webp": ".webp",
    "image/gif": ".gif",
}
_DEFAULT_EXTENSION = ".png"

# Maximum number of characters of an HTTP error body echoed to the user.
_ERROR_BODY_LIMIT = 500


def _fail(message):
    """Print ``Error: <message>`` to stderr and exit with status 1."""
    print(f"Error: {message}", file=sys.stderr)
    sys.exit(1)


def _extension_for(mime_type):
    """Return the file extension for *mime_type* (``.png`` if unknown).

    Parameters such as ``; charset=...`` or ``;base64`` are ignored.
    """
    base_type = mime_type.split(";", 1)[0].strip().lower()
    return _IMAGE_EXTENSIONS.get(base_type, _DEFAULT_EXTENSION)


def _extract_image(response):
    """Return ``(image_bytes, extension)`` from a successful API response.

    A response whose Content-Type contains ``image/`` is used verbatim.
    Otherwise the body is parsed as JSON and must be a non-empty list whose
    first element is a dict with a base64 ``"data"`` string. That string may
    be prefixed as a ``data:image/<type>;base64,`` URI.

    Raises:
        ValueError: the body matches neither shape or decodes to no bytes.
            ``binascii.Error`` from malformed base64 is also a ValueError.
        requests.exceptions.RequestException: the JSON body could not be
            parsed (raised by ``response.json()``).
    """
    content_type = response.headers.get("Content-Type", "")
    if "image/" in content_type:
        image_bytes = response.content
        mime_type = content_type
    else:
        result = response.json()
        if not isinstance(result, list) or not result:
            raise ValueError("Invalid response format")
        item = result[0]
        if not isinstance(item, dict):
            raise ValueError("Invalid response format")
        image_data = item.get("data", "")
        if not isinstance(image_data, str):
            raise ValueError("Invalid response format")
        mime_type = ""
        if image_data.startswith("data:image"):
            # "data:image/webp;base64,<payload>" -> header / payload
            header, separator, image_data = image_data.partition(",")
            if not separator:
                raise ValueError("Invalid response format")
            mime_type = header[len("data:"):]
        image_bytes = base64.b64decode(image_data)

    # Never write a zero-byte "image" file.
    if not image_bytes:
        raise ValueError("Response contained no image data")
    return image_bytes, _extension_for(mime_type)


def _save_image(image_bytes, timestamp, extension):
    """Write *image_bytes* to a new file in the CWD and return its name.

    The name is ``generated_<timestamp><extension>``. If that file already
    exists (e.g. two runs within the same second), ``_1``, ``_2``, ... is
    appended instead of overwriting the earlier image. Exclusive-create mode
    makes the existence check and the creation a single atomic step.
    """
    counter = 0
    while True:
        suffix = f"_{counter}" if counter else ""
        filename = f"generated_{timestamp}{suffix}{extension}"
        try:
            with open(filename, "xb") as f:
                f.write(image_bytes)
        except FileExistsError:
            counter += 1
            continue
        return filename


def _http_error_detail(error):
    """Return ``" (<body excerpt>)"`` for an error carrying a response, else ``""``.

    Servers usually explain 4xx/5xx failures in the body, which the bare
    ``raise_for_status`` message omits.
    """
    response = getattr(error, "response", None)
    if response is None:
        return ""
    body = (response.text or "").strip()
    if not body:
        return ""
    if len(body) > _ERROR_BODY_LIMIT:
        body = body[:_ERROR_BODY_LIMIT] + "..."
    return f" ({body})"


def generate_image(
    prompt,
    width=1024,
    height=1024,
    steps=9,
    seed=None,
    guidance_scale=0,
    shift=3,
    max_seq_len=512,
):
    """Request one image for *prompt* and save it in the current directory.

    Args:
        prompt: Text prompt, 3-1200 characters.
        width, height: Output size in pixels. Not range-checked here;
            ``main`` enforces 576-2048.
        steps: Sent as ``num_inference_steps``.
        seed: Sent as-is; ``None`` is serialised as JSON ``null``.
        guidance_scale, shift: Sent as-is.
        max_seq_len: Sent as ``max_sequence_length``.

    Returns:
        The filename the image was written to.

    Exits the process with status 1, after printing ``Error: ...`` to stderr,
    when the token is missing, the prompt length is out of range, the request
    fails, the response holds no usable image, or the file cannot be written.
    """
    if not API_TOKEN:
        _fail("CHUTES_API_TOKEN not set in environment")
    if not prompt or len(prompt) < 3:
        _fail("Prompt must be at least 3 characters")
    if len(prompt) > 1200:
        _fail("Prompt exceeds maximum length of 1200 characters")

    payload = {
        "prompt": prompt,
        "width": width,
        "height": height,
        "num_inference_steps": steps,
        "guidance_scale": guidance_scale,
        "shift": shift,
        "max_sequence_length": max_seq_len,
        "seed": seed,
    }
    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=300)
        response.raise_for_status()
        image_bytes, extension = _extract_image(response)
        timestamp = int(time.time())
        filename = _save_image(image_bytes, timestamp, extension)
    except requests.exceptions.RequestException as e:
        _fail(f"API request failed - {e}{_http_error_detail(e)}")
    except Exception as e:  # CLI boundary: report anything else and exit 1
        _fail(e)

    print(f"Image saved: {filename} [{timestamp}]")
    return filename


def main():
    """Parse CLI arguments, validate their ranges, and generate one image."""
    parser = argparse.ArgumentParser(description="Generate images from text prompts")
    parser.add_argument("prompt", help="Text prompt for image generation")
    parser.add_argument(
        "--width", type=int, default=1024, help="Image width (576-2048)"
    )
    parser.add_argument(
        "--height", type=int, default=1024, help="Image height (576-2048)"
    )
    parser.add_argument("--steps", type=int, default=9, help="Inference steps (1-100)")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument(
        "--guidance-scale", type=float, default=0, help="Guidance scale (0-5)"
    )
    parser.add_argument("--shift", type=float, default=3, help="Shift parameter (1-10)")
    parser.add_argument(
        "--max-seq-len", type=int, default=512, help="Max sequence length (256-2048)"
    )
    args = parser.parse_args()

    # (value, low, high, message): value must satisfy low <= value <= high.
    # Checks run in this order; the first failure exits.
    range_checks = [
        (args.width, 576, 2048, "width must be between 576 and 2048"),
        (args.height, 576, 2048, "height must be between 576 and 2048"),
        (args.steps, 1, 100, "steps must be between 1 and 100"),
        (args.seed, 0, 4294967295, "seed must be between 0 and 4294967295"),
        (args.guidance_scale, 0, 5, "guidance-scale must be between 0 and 5"),
        (args.shift, 1, 10, "shift must be between 1 and 10"),
        (args.max_seq_len, 256, 2048, "max-seq-len must be between 256 and 2048"),
    ]
    for value, low, high, message in range_checks:
        # Only --seed can be None (its default); it is checked only when given.
        if value is not None and not (low <= value <= high):
            _fail(message)

    generate_image(
        prompt=args.prompt,
        width=args.width,
        height=args.height,
        steps=args.steps,
        seed=args.seed,
        guidance_scale=args.guidance_scale,
        shift=args.shift,
        max_seq_len=args.max_seq_len,
    )


if __name__ == "__main__":
    main()