# Files
# ALwrity/backend/services/image_studio/create_service.py
# 2025-11-20 09:06:00 +05:30
#
# 459 lines
# 17 KiB
# Python
"""Create Studio service for AI-powered image generation."""
import os
from typing import Optional, Dict, Any, List, Literal
from dataclasses import dataclass
from services.llm_providers.image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
HuggingFaceImageProvider,
GeminiImageProvider,
StabilityImageProvider,
WaveSpeedImageProvider,
)
from .templates import TemplateManager, ImageTemplate, Platform, TemplateCategory
from utils.logger_utils import get_service_logger
# Module-level logger for this service, obtained from the project's logger
# utility under the name "image_studio.create".
logger = get_service_logger("image_studio.create")
@dataclass
class CreateStudioRequest:
    """Input parameters for a Create Studio image-generation call.

    Only ``prompt`` is required; every other field is optional and falls
    back to the default declared below. Instances are passed to
    ``CreateStudioService.generate``.
    """

    # -- What to draw ------------------------------------------------------
    prompt: str
    template_id: Optional[str] = None

    # -- Which backend draws it ----------------------------------------------
    # Accepted names: "auto", "stability", "wavespeed", "huggingface", "gemini".
    provider: Optional[str] = None
    model: Optional[str] = None

    # -- Output geometry -----------------------------------------------------
    width: Optional[int] = None
    height: Optional[int] = None
    aspect_ratio: Optional[str] = None  # "W:H" string such as "1:1" or "16:9"

    # -- Look and sampling controls -----------------------------------------
    style_preset: Optional[str] = None
    quality: Literal["draft", "standard", "premium"] = "standard"
    negative_prompt: Optional[str] = None
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None

    # -- Batch size and prompt handling --------------------------------------
    num_variations: int = 1
    enhance_prompt: bool = True
    use_persona: bool = False
    persona_id: Optional[str] = None
class CreateStudioService:
    """Service for Create Studio image generation operations.

    For a ``CreateStudioRequest`` it resolves an optional template, the output
    dimensions, an optionally style-enhanced prompt and a provider/model pair,
    then calls the chosen provider once per requested variation.
    """

    # Provider-to-model mapping for smart recommendations
    # (not currently consulted by _select_provider_and_model).
    PROVIDER_MODELS = {
        "stability": {
            "ultra": "stability-ultra",  # Best quality, 8 credits
            "core": "stability-core",  # Fast & affordable, 3 credits
            "sd3": "sd3.5-large",  # SD3.5 model
        },
        "wavespeed": {
            "ideogram-v3-turbo": "ideogram-v3-turbo",  # Photorealistic, text rendering
            "qwen-image": "qwen-image",  # Fast generation
        },
        "huggingface": {
            "flux": "black-forest-labs/FLUX.1-Krea-dev",
        },
        "gemini": {
            "imagen": "imagen-3.0-generate-001",
        },
    }

    # Quality tier -> ordered candidates, each "provider" or "provider:model".
    # Only the first candidate of each tier is used today.
    QUALITY_PROVIDERS = {
        "draft": ["huggingface", "wavespeed:qwen-image"],  # Fast, low cost
        "standard": ["stability:core", "wavespeed:ideogram-v3-turbo"],  # Balanced
        "premium": ["wavespeed:ideogram-v3-turbo", "stability:ultra"],  # Best quality
    }

    # Template recommendations may name a model family rather than a provider.
    # Each alias maps to (provider that serves it, default model for it).
    TEMPLATE_PROVIDER_ALIASES = {
        "ideogram": ("wavespeed", "ideogram-v3-turbo"),
        "qwen": ("wavespeed", "qwen-image"),
    }

    # Suffixes appended to the prompt by _enhance_prompt, keyed by style preset.
    STYLE_ENHANCEMENTS = {
        "photographic": ", professional photography, high quality, detailed, sharp focus, natural lighting",
        "digital-art": ", digital art, vibrant colors, detailed, high quality, artstation trending",
        "cinematic": ", cinematic lighting, dramatic, film grain, high quality, professional",
        "3d-model": ", 3D render, octane render, unreal engine, high quality, detailed",
        "anime": ", anime style, vibrant colors, detailed, high quality",
        "line-art": ", clean line art, detailed linework, high contrast, professional",
    }

    # Length of the longer side when only an aspect ratio is supplied.
    ASPECT_RATIO_BASE_SIZE = 1920
    # Size used when neither full dimensions nor a usable aspect ratio are given.
    DEFAULT_DIMENSIONS = (1024, 1024)

    def __init__(self):
        """Initialize Create Studio service with a fresh TemplateManager."""
        self.template_manager = TemplateManager()
        logger.info("[Create Studio] Initialized with template manager")

    def _get_provider_instance(self, provider_name: str, api_key: Optional[str] = None):
        """Instantiate the image provider registered under ``provider_name``.

        Args:
            provider_name: One of "stability", "wavespeed", "huggingface", "gemini".
            api_key: Optional API key; when falsy the provider-specific
                environment variable is read instead.

        Returns:
            A provider instance exposing ``generate(options)``.

        Raises:
            ValueError: If ``provider_name`` is not one of the supported names.
        """
        if provider_name == "stability":
            return StabilityImageProvider(api_key=api_key or os.getenv("STABILITY_API_KEY"))
        if provider_name == "wavespeed":
            return WaveSpeedImageProvider(api_key=api_key or os.getenv("WAVESPEED_API_KEY"))
        if provider_name == "huggingface":
            # The HuggingFace provider takes its credential as ``api_token``.
            return HuggingFaceImageProvider(api_token=api_key or os.getenv("HF_API_KEY"))
        if provider_name == "gemini":
            return GeminiImageProvider(api_key=api_key or os.getenv("GEMINI_API_KEY"))
        raise ValueError(f"Unsupported provider: {provider_name}")

    def _select_provider_and_model(
        self,
        request: CreateStudioRequest,
        template: Optional[ImageTemplate] = None,
    ) -> tuple[str, Optional[str]]:
        """Choose the provider and model for a request.

        Precedence: an explicit, non-"auto" ``request.provider``; then the
        template's ``recommended_provider``; then the first candidate of the
        request's quality tier in ``QUALITY_PROVIDERS``.

        Args:
            request: Create studio request.
            template: Optional template carrying a provider recommendation.

        Returns:
            Tuple of (provider_name, model_name); model_name may be None.
        """
        # 1. Explicit provider selection wins outright.
        if request.provider and request.provider != "auto":
            provider = request.provider
            model = request.model
            logger.info("[Provider Selection] User specified: %s (model: %s)", provider, model)
            return provider, model

        # 2. Template recommendation.
        if template and template.recommended_provider:
            recommended = template.recommended_provider
            logger.info("[Provider Selection] Template recommends: %s", recommended)
            if recommended in self.TEMPLATE_PROVIDER_ALIASES:
                # Always translate aliases to the real provider name, even when the
                # caller picked a model; "ideogram"/"qwen" are not accepted by
                # _get_provider_instance.
                provider, default_model = self.TEMPLATE_PROVIDER_ALIASES[recommended]
            elif recommended == "stability":
                provider = "stability"
                default_model = "stability-ultra" if request.quality == "premium" else "stability-core"
            else:
                provider, default_model = recommended, None
            # A caller-chosen model takes priority over the template default.
            return provider, request.model or default_model

        # 3. Quality-based selection.
        quality_options = self.QUALITY_PROVIDERS.get(request.quality, self.QUALITY_PROVIDERS["standard"])
        selected = quality_options[0]  # Pick first option
        if ":" in selected:
            provider, model = selected.split(":", 1)
        else:
            provider, model = selected, None
        # NOTE(review): this path yields short model keys such as "core"/"ultra",
        # whereas the template path yields "stability-core"/"stability-ultra"
        # (PROVIDER_MODELS maps one to the other). Confirm which form
        # StabilityImageProvider expects.
        logger.info("[Provider Selection] Quality-based (%s): %s (model: %s)",
                    request.quality, provider, model)
        return provider, model

    def _enhance_prompt(self, prompt: str, style_preset: Optional[str] = None) -> str:
        """Append the style-specific descriptor suffix for ``style_preset``.

        Args:
            prompt: Original prompt.
            style_preset: Style preset key; unknown or empty presets leave the
                prompt unchanged.

        Returns:
            The prompt, with the preset's suffix appended when one exists.
        """
        suffix = self.STYLE_ENHANCEMENTS.get(style_preset) if style_preset else None
        enhanced = prompt + suffix if suffix else prompt
        logger.info("[Prompt Enhancement] Original: %s", prompt[:100])
        logger.info("[Prompt Enhancement] Enhanced: %s", enhanced[:100])
        return enhanced

    def _apply_template(self, request: CreateStudioRequest, template: ImageTemplate) -> CreateStudioRequest:
        """Return a copy of ``request`` with template defaults filled in.

        The caller's request object is not modified.

        Args:
            request: Original request.
            template: Template to apply.

        Returns:
            A new request. Template width/height are used when the request set
            neither dimension, and the template style is used when no style was
            given. The template quality is used when the request quality is
            "standard".
        """
        from dataclasses import replace

        updates: Dict[str, Any] = {}
        if not request.width and not request.height:
            updates["width"] = template.aspect_ratio.width
            updates["height"] = template.aspect_ratio.height
        if not request.style_preset:
            updates["style_preset"] = template.style_preset
        # "standard" is the dataclass default, so an explicit "standard" cannot be
        # told apart from "not specified"; both are overridden by the template.
        if request.quality == "standard":
            updates["quality"] = template.quality

        applied = replace(request, **updates)
        # %s rather than %d: a request may carry only one dimension (the other None).
        logger.info("[Template Applied] %s -> %sx%s, style=%s, quality=%s",
                    template.name, applied.width, applied.height,
                    applied.style_preset, applied.quality)
        return applied

    @staticmethod
    def _parse_aspect_ratio(aspect_ratio: str) -> Optional[tuple[float, float]]:
        """Parse a "W:H" string into positive, finite ratio parts, or None if malformed."""
        import math

        parts = aspect_ratio.split(":")
        if len(parts) != 2:
            return None
        try:
            w_ratio, h_ratio = float(parts[0]), float(parts[1])
        except ValueError:
            return None
        # Reject zero/negative/inf/nan parts: they would divide by zero or
        # produce zero- or negative-sized images.
        if not (math.isfinite(w_ratio) and math.isfinite(h_ratio)) or w_ratio <= 0 or h_ratio <= 0:
            return None
        return w_ratio, h_ratio

    def _calculate_dimensions(
        self,
        width: Optional[int],
        height: Optional[int],
        aspect_ratio: Optional[str],
    ) -> tuple[int, int]:
        """Calculate image dimensions from width/height or aspect ratio.

        Args:
            width: Explicit width.
            height: Explicit height.
            aspect_ratio: Aspect ratio string (e.g. "16:9"; decimal parts such
                as "1.91:1" are accepted).

        Returns:
            Tuple of (width, height). Explicit width and height win. Otherwise a
            valid aspect ratio derives the missing side from the given one, or
            uses ``ASPECT_RATIO_BASE_SIZE`` for the longer side.
            ``DEFAULT_DIMENSIONS`` is the fallback.
        """
        if width and height:
            return width, height

        if aspect_ratio:
            ratio = self._parse_aspect_ratio(aspect_ratio)
            if ratio is None:
                logger.warning("[Dimensions] Invalid aspect ratio: %s", aspect_ratio)
            else:
                w_ratio, h_ratio = ratio
                if width:
                    return width, int(width * h_ratio / w_ratio)
                if height:
                    return int(height * w_ratio / h_ratio), height
                base = self.ASPECT_RATIO_BASE_SIZE
                if w_ratio >= h_ratio:
                    # Landscape or square: the base size is the width.
                    return base, int(base * h_ratio / w_ratio)
                # Portrait: the base size is the height.
                return int(base * w_ratio / h_ratio), base

        return self.DEFAULT_DIMENSIONS

    def _run_preflight_validation(self, user_id: str, num_images: int) -> None:
        """Run subscription/usage pre-flight checks for ``num_images`` images.

        Raises:
            HTTPException: Re-raised unchanged when validation fails (presumably
                raised by validate_image_generation_operations — confirm).
        """
        # Imported lazily: these are only needed when a user_id is supplied.
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_generation_operations
        from fastapi import HTTPException

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Create Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_generation_operations(
                pricing_service=pricing_service,
                user_id=user_id,
                num_images=num_images,
            )
            logger.info("[Create Studio] ✅ Pre-flight validation passed - proceeding with generation")
        except HTTPException:
            logger.error("[Create Studio] ❌ Pre-flight validation failed - blocking generation")
            raise
        finally:
            db.close()

    async def generate(
        self,
        request: CreateStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Generate image(s) using Create Studio.

        Args:
            request: Create studio request.
            user_id: User ID for pre-flight validation; validation is skipped
                (with a warning) when it is falsy.

        Returns:
            Dictionary with the resolved request summary, one result entry per
            variation (either image data or an ``error`` string), and
            generated/failed counts.

        Raises:
            ValueError: If ``num_variations`` < 1 or the template is not found.
            RuntimeError: If the provider cannot be initialized.
        """
        logger.info("[Create Studio] Starting generation: prompt=%s, template=%s",
                    request.prompt[:100], request.template_id)

        if request.num_variations < 1:
            raise ValueError(f"num_variations must be at least 1, got {request.num_variations}")

        # Pre-flight validation: Check subscription and usage limits
        if user_id:
            self._run_preflight_validation(user_id, request.num_variations)
        else:
            logger.warning("[Create Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        # Load and apply template if specified (works on a copy of the request).
        template = None
        if request.template_id:
            template = self.template_manager.get_by_id(request.template_id)
            if not template:
                raise ValueError(f"Template not found: {request.template_id}")
            request = self._apply_template(request, template)

        width, height = self._calculate_dimensions(
            request.width, request.height, request.aspect_ratio
        )

        prompt = request.prompt
        if request.enhance_prompt:
            prompt = self._enhance_prompt(prompt, request.style_preset)

        provider_name, model = self._select_provider_and_model(request, template)

        try:
            provider = self._get_provider_instance(provider_name)
        except Exception as e:
            logger.error("[Create Studio] ❌ Failed to initialize provider %s: %s",
                         provider_name, str(e))
            raise RuntimeError(f"Provider initialization failed: {str(e)}") from e

        results: List[Dict[str, Any]] = []
        for i in range(request.num_variations):
            variation = i + 1
            logger.info("[Create Studio] Generating variation %d/%d",
                        variation, request.num_variations)
            try:
                options = ImageGenerationOptions(
                    prompt=prompt,
                    negative_prompt=request.negative_prompt,
                    width=width,
                    height=height,
                    guidance_scale=request.guidance_scale,
                    steps=request.steps,
                    # Offset the seed per variation; 0 is a valid seed, so test for None.
                    seed=request.seed + i if request.seed is not None else None,
                    model=model,
                    extra={"style_preset": request.style_preset} if request.style_preset else {},
                )
                # NOTE(review): provider.generate is called synchronously (no await)
                # inside this coroutine; if it does blocking network I/O it stalls
                # the event loop. Consider asyncio.to_thread — confirm provider
                # thread-safety first.
                result: ImageGenerationResult = provider.generate(options)
                results.append({
                    "image_bytes": result.image_bytes,
                    "width": result.width,
                    "height": result.height,
                    "provider": result.provider,
                    "model": result.model,
                    "seed": result.seed,
                    "metadata": result.metadata,
                    "variation": variation,
                })
                logger.info("[Create Studio] ✅ Variation %d generated successfully", variation)
            except Exception as e:
                # Best-effort per variation: record the failure and keep going.
                logger.error("[Create Studio] ❌ Failed to generate variation %d: %s",
                             variation, str(e), exc_info=True)
                results.append({
                    "error": str(e),
                    "variation": variation,
                })

        # NOTE(review): "success" is True even when every variation failed;
        # callers should inspect total_generated / total_failed.
        return {
            "success": True,
            "request": {
                "prompt": request.prompt,
                "enhanced_prompt": prompt if request.enhance_prompt else None,
                "template_id": request.template_id,
                "template_name": template.name if template else None,
                "provider": provider_name,
                "model": model,
                "dimensions": f"{width}x{height}",
                "quality": request.quality,
                "num_variations": request.num_variations,
            },
            "results": results,
            "total_generated": sum(1 for r in results if "image_bytes" in r),
            "total_failed": sum(1 for r in results if "error" in r),
        }

    def get_templates(
        self,
        platform: Optional[Platform] = None,
        category: Optional[TemplateCategory] = None,
    ) -> List[ImageTemplate]:
        """Return available templates, optionally filtered.

        Args:
            platform: Filter by platform; takes precedence over ``category``.
            category: Filter by category; ignored when ``platform`` is given.

        Returns:
            List of templates (all templates when no filter is given).
        """
        if platform:
            return self.template_manager.get_by_platform(platform)
        if category:
            return self.template_manager.get_by_category(category)
        return self.template_manager.get_all_templates()

    def search_templates(self, query: str) -> List[ImageTemplate]:
        """Search templates by query (delegates to TemplateManager.search).

        Args:
            query: Search query.

        Returns:
            List of matching templates.
        """
        return self.template_manager.search(query)

    def recommend_templates(
        self,
        use_case: str,
        platform: Optional[Platform] = None,
    ) -> List[ImageTemplate]:
        """Recommend templates for a use case (delegates to TemplateManager).

        Args:
            use_case: Description of use case.
            platform: Optional platform filter.

        Returns:
            List of recommended templates.
        """
        return self.template_manager.recommend_for_use_case(use_case, platform)