# Files
# ALwrity/backend/services/llm_providers/image_generation/wavespeed_provider.py
# 2025-11-20 09:06:00 +05:30
#
# 244 lines
# 9.0 KiB
# Python

"""WaveSpeed AI image generation provider (Ideogram V3 Turbo & Qwen Image)."""
import io
import os
from typing import Optional
from PIL import Image
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
# Module-level logger shared by everything in this file; obtained from the
# project's get_service_logger helper under the "wavespeed.image_provider" name.
logger = get_service_logger("wavespeed.image_provider")
class WaveSpeedImageProvider(ImageGenerationProvider):
    """WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen.

    The provider validates the requested options, forwards them to
    ``WaveSpeedClient.generate_image`` and wraps the returned image bytes in an
    ``ImageGenerationResult`` together with descriptive metadata.
    """

    # Model used whenever ``options.model`` is empty / not set.
    DEFAULT_MODEL = "ideogram-v3-turbo"

    SUPPORTED_MODELS = {
        "ideogram-v3-turbo": {
            "name": "Ideogram V3 Turbo",
            "description": "Photorealistic generation with superior text rendering",
            "cost_per_image": 0.10,  # Estimated, adjust based on actual pricing
            "max_resolution": (1024, 1024),
            "default_steps": 20,
        },
        "qwen-image": {
            "name": "Qwen Image",
            "description": "Fast, high-quality text-to-image generation",
            "cost_per_image": 0.05,  # Estimated, adjust based on actual pricing
            "max_resolution": (1024, 1024),
            "default_steps": 15,
        },
    }

    # Short labels used in log lines and error messages, keyed by model id.
    # Kept separate from SUPPORTED_MODELS so the public model info is unchanged.
    _MODEL_LABELS = {
        "ideogram-v3-turbo": "Ideogram V3",
        "qwen-image": "Qwen Image",
    }

    def __init__(self, api_key: Optional[str] = None):
        """Initialize WaveSpeed image provider.

        Args:
            api_key: WaveSpeed API key (falls back to the ``WAVESPEED_API_KEY``
                environment variable if not provided).

        Raises:
            ValueError: If no API key is given and the env var is unset/empty.
        """
        self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
        if not self.api_key:
            raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")

        self.client = WaveSpeedClient(api_key=self.api_key)
        logger.info(
            "[WaveSpeed Image Provider] Initialized with available models: %s",
            list(self.SUPPORTED_MODELS.keys()),
        )

    def _validate_options(self, options: ImageGenerationOptions) -> None:
        """Validate generation options.

        Checks, in order: the model is supported, the requested resolution does
        not exceed the model's maximum, and the prompt is not blank.

        Args:
            options: Image generation options.

        Raises:
            ValueError: If options are invalid.
        """
        model = options.model or self.DEFAULT_MODEL
        if model not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model: {model}. "
                f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
            )

        max_width, max_height = self.SUPPORTED_MODELS[model]["max_resolution"]
        if options.width > max_width or options.height > max_height:
            raise ValueError(
                f"Resolution {options.width}x{options.height} exceeds maximum "
                f"{max_width}x{max_height} for model {model}"
            )

        if not options.prompt or not options.prompt.strip():
            raise ValueError("Prompt cannot be empty")

    def _generate_with_model(self, model: str, options: ImageGenerationOptions) -> bytes:
        """Run one generation request against the WaveSpeed client.

        Shared implementation behind the per-model ``_generate_*`` methods.

        Args:
            model: Model id; must be a key of ``SUPPORTED_MODELS``.
            options: Image generation options (assumed already validated).

        Returns:
            Image bytes as returned by the client.

        Raises:
            RuntimeError: If the client call fails or the response has an
                unexpected shape. The original exception is attached as
                ``__cause__``.
        """
        label = self._MODEL_LABELS.get(model, model)
        logger.info("[%s] Starting image generation: %s", label, options.prompt[:100])

        try:
            # NOTE(review): parameter names follow the original code; confirm
            # them against the actual WaveSpeed client / API documentation.
            params = {
                "model": model,
                "prompt": options.prompt,
                "width": options.width,
                "height": options.height,
                "num_inference_steps": options.steps or self.SUPPORTED_MODELS[model]["default_steps"],
            }

            # Optional parameters. ``is not None`` (rather than truthiness) so
            # that legitimate falsy values such as seed=0 or guidance_scale=0.0
            # are still forwarded. Assumes "unset" is represented by None.
            if options.negative_prompt:
                params["negative_prompt"] = options.negative_prompt
            if options.guidance_scale is not None:
                params["guidance_scale"] = options.guidance_scale
            if options.seed is not None:
                params["seed"] = options.seed

            result = self.client.generate_image(**params)

            # The client may hand back raw bytes or a dict carrying them
            # under "image"; anything else is treated as an error.
            if isinstance(result, bytes):
                image_bytes = result
            elif isinstance(result, dict) and "image" in result:
                image_bytes = result["image"]
            else:
                raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")

            logger.info("[%s] ✅ Successfully generated image: %d bytes", label, len(image_bytes))
            return image_bytes

        except Exception as e:
            logger.error("[%s] ❌ Error generating image: %s", label, str(e), exc_info=True)
            # Chain the cause so the underlying failure stays in the traceback.
            raise RuntimeError(f"{label} generation failed: {str(e)}") from e

    def _generate_ideogram_v3(self, options: ImageGenerationOptions) -> bytes:
        """Generate image using Ideogram V3 Turbo.

        Args:
            options: Image generation options.

        Returns:
            Image bytes.

        Raises:
            RuntimeError: If generation fails.
        """
        return self._generate_with_model("ideogram-v3-turbo", options)

    def _generate_qwen_image(self, options: ImageGenerationOptions) -> bytes:
        """Generate image using Qwen Image.

        Args:
            options: Image generation options.

        Returns:
            Image bytes.

        Raises:
            RuntimeError: If generation fails.
        """
        return self._generate_with_model("qwen-image", options)

    def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
        """Generate image using WaveSpeed AI models.

        Args:
            options: Image generation options.

        Returns:
            ImageGenerationResult with the generated image bytes, the decoded
            image's actual dimensions, and request metadata.

        Raises:
            ValueError: If options are invalid.
            RuntimeError: If generation fails.
        """
        self._validate_options(options)

        model = options.model or self.DEFAULT_MODEL

        # Dispatch through the per-model methods so subclasses that override
        # one of them keep working.
        generators = {
            "ideogram-v3-turbo": self._generate_ideogram_v3,
            "qwen-image": self._generate_qwen_image,
        }
        generator = generators.get(model)
        if generator is None:
            raise ValueError(f"Unsupported model: {model}")
        image_bytes = generator(options)

        # Decode only to read the real dimensions; the context manager
        # releases the image's resources as soon as we have them.
        with Image.open(io.BytesIO(image_bytes)) as image:
            width, height = image.size

        model_info = self.SUPPORTED_MODELS[model]
        estimated_cost = model_info["cost_per_image"]

        return ImageGenerationResult(
            image_bytes=image_bytes,
            width=width,
            height=height,
            provider="wavespeed",
            model=model,
            seed=options.seed,
            metadata={
                "provider": "wavespeed",
                "model": model,
                "model_name": model_info["name"],
                "prompt": options.prompt,
                "negative_prompt": options.negative_prompt,
                "steps": options.steps or model_info["default_steps"],
                "guidance_scale": options.guidance_scale,
                "estimated_cost": estimated_cost,
            },
        )

    @classmethod
    def get_available_models(cls) -> dict:
        """Get available models and their information.

        Returns:
            Dictionary mapping model id to a copy of its info dict. Copies are
            returned so callers cannot mutate the shared class-level
            ``SUPPORTED_MODELS`` configuration by accident.
        """
        return {model_id: dict(info) for model_id, info in cls.SUPPORTED_MODELS.items()}