244 lines
9.0 KiB
Python
244 lines
9.0 KiB
Python
"""WaveSpeed AI image generation provider (Ideogram V3 Turbo & Qwen Image)."""
|
|
|
|
import io
|
|
import os
|
|
from typing import Optional
|
|
from PIL import Image
|
|
|
|
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
|
|
from services.wavespeed.client import WaveSpeedClient
|
|
from utils.logger_utils import get_service_logger
|
|
|
|
|
|
logger = get_service_logger("wavespeed.image_provider")
|
|
|
|
|
|
class WaveSpeedImageProvider(ImageGenerationProvider):
|
|
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""
|
|
|
|
SUPPORTED_MODELS = {
|
|
"ideogram-v3-turbo": {
|
|
"name": "Ideogram V3 Turbo",
|
|
"description": "Photorealistic generation with superior text rendering",
|
|
"cost_per_image": 0.10, # Estimated, adjust based on actual pricing
|
|
"max_resolution": (1024, 1024),
|
|
"default_steps": 20,
|
|
},
|
|
"qwen-image": {
|
|
"name": "Qwen Image",
|
|
"description": "Fast, high-quality text-to-image generation",
|
|
"cost_per_image": 0.05, # Estimated, adjust based on actual pricing
|
|
"max_resolution": (1024, 1024),
|
|
"default_steps": 15,
|
|
}
|
|
}
|
|
|
|
def __init__(self, api_key: Optional[str] = None):
|
|
"""Initialize WaveSpeed image provider.
|
|
|
|
Args:
|
|
api_key: WaveSpeed API key (falls back to env var if not provided)
|
|
"""
|
|
self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
|
|
if not self.api_key:
|
|
raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")
|
|
|
|
self.client = WaveSpeedClient(api_key=self.api_key)
|
|
logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
|
|
list(self.SUPPORTED_MODELS.keys()))
|
|
|
|
def _validate_options(self, options: ImageGenerationOptions) -> None:
|
|
"""Validate generation options.
|
|
|
|
Args:
|
|
options: Image generation options
|
|
|
|
Raises:
|
|
ValueError: If options are invalid
|
|
"""
|
|
model = options.model or "ideogram-v3-turbo"
|
|
|
|
if model not in self.SUPPORTED_MODELS:
|
|
raise ValueError(
|
|
f"Unsupported model: {model}. "
|
|
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
|
)
|
|
|
|
model_info = self.SUPPORTED_MODELS[model]
|
|
max_width, max_height = model_info["max_resolution"]
|
|
|
|
if options.width > max_width or options.height > max_height:
|
|
raise ValueError(
|
|
f"Resolution {options.width}x{options.height} exceeds maximum "
|
|
f"{max_width}x{max_height} for model {model}"
|
|
)
|
|
|
|
if not options.prompt or len(options.prompt.strip()) == 0:
|
|
raise ValueError("Prompt cannot be empty")
|
|
|
|
def _generate_ideogram_v3(self, options: ImageGenerationOptions) -> bytes:
|
|
"""Generate image using Ideogram V3 Turbo.
|
|
|
|
Args:
|
|
options: Image generation options
|
|
|
|
Returns:
|
|
Image bytes
|
|
"""
|
|
logger.info("[Ideogram V3] Starting image generation: %s", options.prompt[:100])
|
|
|
|
try:
|
|
# Prepare parameters for WaveSpeed Ideogram V3 API
|
|
# Note: Adjust these based on actual WaveSpeed API documentation
|
|
params = {
|
|
"model": "ideogram-v3-turbo",
|
|
"prompt": options.prompt,
|
|
"width": options.width,
|
|
"height": options.height,
|
|
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["ideogram-v3-turbo"]["default_steps"],
|
|
}
|
|
|
|
# Add optional parameters
|
|
if options.negative_prompt:
|
|
params["negative_prompt"] = options.negative_prompt
|
|
|
|
if options.guidance_scale:
|
|
params["guidance_scale"] = options.guidance_scale
|
|
|
|
if options.seed:
|
|
params["seed"] = options.seed
|
|
|
|
# Call WaveSpeed API (using generic image generation method)
|
|
# This will need to be adjusted based on actual WaveSpeed client implementation
|
|
result = self.client.generate_image(**params)
|
|
|
|
# Extract image bytes from result
|
|
# Adjust based on actual WaveSpeed API response format
|
|
if isinstance(result, bytes):
|
|
image_bytes = result
|
|
elif isinstance(result, dict) and "image" in result:
|
|
image_bytes = result["image"]
|
|
else:
|
|
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
|
|
|
logger.info("[Ideogram V3] ✅ Successfully generated image: %d bytes", len(image_bytes))
|
|
return image_bytes
|
|
|
|
except Exception as e:
|
|
logger.error("[Ideogram V3] ❌ Error generating image: %s", str(e), exc_info=True)
|
|
raise RuntimeError(f"Ideogram V3 generation failed: {str(e)}")
|
|
|
|
def _generate_qwen_image(self, options: ImageGenerationOptions) -> bytes:
|
|
"""Generate image using Qwen Image.
|
|
|
|
Args:
|
|
options: Image generation options
|
|
|
|
Returns:
|
|
Image bytes
|
|
"""
|
|
logger.info("[Qwen Image] Starting image generation: %s", options.prompt[:100])
|
|
|
|
try:
|
|
# Prepare parameters for WaveSpeed Qwen Image API
|
|
params = {
|
|
"model": "qwen-image",
|
|
"prompt": options.prompt,
|
|
"width": options.width,
|
|
"height": options.height,
|
|
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["qwen-image"]["default_steps"],
|
|
}
|
|
|
|
# Add optional parameters
|
|
if options.negative_prompt:
|
|
params["negative_prompt"] = options.negative_prompt
|
|
|
|
if options.guidance_scale:
|
|
params["guidance_scale"] = options.guidance_scale
|
|
|
|
if options.seed:
|
|
params["seed"] = options.seed
|
|
|
|
# Call WaveSpeed API
|
|
result = self.client.generate_image(**params)
|
|
|
|
# Extract image bytes from result
|
|
if isinstance(result, bytes):
|
|
image_bytes = result
|
|
elif isinstance(result, dict) and "image" in result:
|
|
image_bytes = result["image"]
|
|
else:
|
|
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
|
|
|
logger.info("[Qwen Image] ✅ Successfully generated image: %d bytes", len(image_bytes))
|
|
return image_bytes
|
|
|
|
except Exception as e:
|
|
logger.error("[Qwen Image] ❌ Error generating image: %s", str(e), exc_info=True)
|
|
raise RuntimeError(f"Qwen Image generation failed: {str(e)}")
|
|
|
|
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
|
"""Generate image using WaveSpeed AI models.
|
|
|
|
Args:
|
|
options: Image generation options
|
|
|
|
Returns:
|
|
ImageGenerationResult with generated image
|
|
|
|
Raises:
|
|
ValueError: If options are invalid
|
|
RuntimeError: If generation fails
|
|
"""
|
|
# Validate options
|
|
self._validate_options(options)
|
|
|
|
# Determine model
|
|
model = options.model or "ideogram-v3-turbo"
|
|
|
|
# Generate based on model
|
|
if model == "ideogram-v3-turbo":
|
|
image_bytes = self._generate_ideogram_v3(options)
|
|
elif model == "qwen-image":
|
|
image_bytes = self._generate_qwen_image(options)
|
|
else:
|
|
raise ValueError(f"Unsupported model: {model}")
|
|
|
|
# Load image to get dimensions
|
|
image = Image.open(io.BytesIO(image_bytes))
|
|
width, height = image.size
|
|
|
|
# Calculate estimated cost
|
|
model_info = self.SUPPORTED_MODELS[model]
|
|
estimated_cost = model_info["cost_per_image"]
|
|
|
|
# Return result
|
|
return ImageGenerationResult(
|
|
image_bytes=image_bytes,
|
|
width=width,
|
|
height=height,
|
|
provider="wavespeed",
|
|
model=model,
|
|
seed=options.seed,
|
|
metadata={
|
|
"provider": "wavespeed",
|
|
"model": model,
|
|
"model_name": model_info["name"],
|
|
"prompt": options.prompt,
|
|
"negative_prompt": options.negative_prompt,
|
|
"steps": options.steps or model_info["default_steps"],
|
|
"guidance_scale": options.guidance_scale,
|
|
"estimated_cost": estimated_cost,
|
|
}
|
|
)
|
|
|
|
@classmethod
|
|
def get_available_models(cls) -> dict:
|
|
"""Get available models and their information.
|
|
|
|
Returns:
|
|
Dictionary of available models
|
|
"""
|
|
return cls.SUPPORTED_MODELS
|
|
|