Base code

This commit is contained in:
Kunthawat Greethong
2026-01-08 22:39:53 +07:00
parent 697115c61a
commit c35fa52117
2169 changed files with 626670 additions and 0 deletions

View File

@@ -0,0 +1,34 @@
"""Image Studio service package for centralized image operations."""
from .studio_manager import ImageStudioManager
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .transform_service import (
TransformStudioService,
TransformImageToVideoRequest,
TalkingAvatarRequest,
)
from .templates import PlatformTemplates, TemplateManager
# Public API of the package: each name imported above is re-exported here,
# grouped as service class followed by its request type(s).
__all__ = [
"ImageStudioManager",
"CreateStudioService",
"CreateStudioRequest",
"EditStudioService",
"EditStudioRequest",
"UpscaleStudioService",
"UpscaleStudioRequest",
"ControlStudioService",
"ControlStudioRequest",
"SocialOptimizerService",
"SocialOptimizerRequest",
"TransformStudioService",
"TransformImageToVideoRequest",
"TalkingAvatarRequest",
"PlatformTemplates",
"TemplateManager",
]

View File

@@ -0,0 +1,277 @@
"""Control Studio service for AI-powered controlled image generation."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional
from PIL import Image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
# Module-level logger shared by everything in this file.
logger = get_service_logger("image_studio.control")
# Closed set of operation identifiers; these are the same strings used as keys
# of ControlStudioService.SUPPORTED_OPERATIONS and dispatched on in
# ControlStudioService.process_control.
ControlOperationType = Literal[
"sketch",
"structure",
"style",
"style_transfer",
]
@dataclass
class ControlStudioRequest:
    """Normalized request payload for Control Studio operations.

    Only ``operation``, ``prompt`` and ``control_image_base64`` are required.
    Every tuning knob defaults to ``None``; which knobs are read depends on
    the selected operation (see the per-field comments).
    """

    # --- required -----------------------------------------------------------
    operation: ControlOperationType
    prompt: str
    # Sketch, structure, or style reference image (raw base64 or data URL).
    control_image_base64: str

    # --- optional images / text ----------------------------------------------
    # Second image, read by "style_transfer" only.
    style_image_base64: Optional[str] = None
    negative_prompt: Optional[str] = None

    # --- optional tuning knobs ------------------------------------------------
    # Read by "sketch" and "structure".
    control_strength: Optional[float] = None
    # Read by "style".
    fidelity: Optional[float] = None
    # The next three are read by "style_transfer".
    style_strength: Optional[float] = None
    composition_fidelity: Optional[float] = None
    change_strength: Optional[float] = None
    # Read by "style".
    aspect_ratio: Optional[str] = None
    style_preset: Optional[str] = None
    seed: Optional[int] = None

    # --- output ---------------------------------------------------------------
    output_format: str = "png"
class ControlStudioService:
    """Service layer orchestrating Control Studio operations.

    Decodes the caller's base64 images, dispatches to the matching
    ``StabilityAIService`` control endpoint, and wraps the resulting image in
    a normalized response dictionary.
    """

    # Catalogue of operations. ``fields`` flags say which request inputs an
    # operation uses; ``list_operations`` exposes this table for UI rendering.
    SUPPORTED_OPERATIONS: Dict[ControlOperationType, Dict[str, Any]] = {
        "sketch": {
            "label": "Sketch to Image",
            "description": "Transform sketches into refined images with precise control.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": True,
                "fidelity": False,
                "style_strength": False,
                "aspect_ratio": False,
            },
        },
        "structure": {
            "label": "Structure Control",
            "description": "Generate images maintaining the structure of an input image.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": True,
                "fidelity": False,
                "style_strength": False,
                "aspect_ratio": False,
            },
        },
        "style": {
            "label": "Style Control",
            "description": "Generate images using style from a reference image.",
            "provider": "stability",
            "fields": {
                "control_image": True,
                "style_image": False,
                "control_strength": False,
                "fidelity": True,
                "style_strength": False,
                "aspect_ratio": True,
            },
        },
        "style_transfer": {
            "label": "Style Transfer",
            "description": "Apply visual characteristics from a style image to a target image.",
            "provider": "stability",
            "fields": {
                "control_image": True,  # init_image
                "style_image": True,
                "control_strength": False,
                "fidelity": False,
                "style_strength": True,
                "aspect_ratio": False,
            },
        },
    }

    def __init__(self):
        """Create the service; only emits a startup log line."""
        logger.info("[Control Studio] Initialized control service")

    @staticmethod
    def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
        """Decode a base64 (or data URL) string to bytes.

        Args:
            value: Raw base64 text or a ``data:<mime>;base64,<payload>`` URL.

        Returns:
            The decoded bytes, or ``None`` when ``value`` is empty/``None``.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None
        try:
            # Handle data URLs (data:image/png;base64,...)
            if value.startswith("data:"):
                _, b64data = value.split(",", 1)
            else:
                b64data = value
            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error(f"[Control Studio] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
        """Return ``{"width": ..., "height": ...}`` read from encoded image bytes."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {
                "width": img.width,
                "height": img.height,
            }

    @staticmethod
    def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
        """Wrap raw bytes in a ``data:image/<output_format>;base64,`` URL."""
        b64 = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/{output_format};base64,{b64}"

    @staticmethod
    def _value_or_default(value: Optional[float], default: float) -> float:
        """Return ``value`` unless it is ``None``, in which case return ``default``.

        Unlike ``value or default`` this keeps an explicit ``0.0`` supplied by
        the caller instead of silently replacing it with the default.
        """
        return default if value is None else value

    def list_operations(self) -> Dict[str, Dict[str, Any]]:
        """Expose supported operations for UI rendering.

        The class-level table itself is returned (not a copy), so callers
        must treat it as read-only.
        """
        return self.SUPPORTED_OPERATIONS

    @staticmethod
    def _run_preflight_validation(user_id: str) -> None:
        """Run the subscription pre-flight check for one control image.

        ``HTTPException`` raised by the validator is logged and re-raised; the
        database session is closed in every case.
        """
        # Imported lazily, as in the rest of this package's services.
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_control_operations
        from fastapi import HTTPException

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Control Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_control_operations(
                pricing_service=pricing_service,
                user_id=user_id,
                num_images=1,
            )
            logger.info("[Control Studio] ✅ Pre-flight validation passed")
        except HTTPException:
            logger.error("[Control Studio] ❌ Pre-flight validation failed")
            raise
        finally:
            db.close()

    async def process_control(
        self,
        request: ControlStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Process control request and return normalized response.

        Args:
            request: Operation, prompt, images and tuning values.
            user_id: When given, subscription pre-flight validation runs
                first; when omitted, validation is skipped with a warning.

        Returns:
            Dict with ``success``, ``operation``, ``provider``,
            ``image_base64`` (data URL), ``width``, ``height`` and ``metadata``.

        Raises:
            ValueError: Missing/invalid images or an unsupported operation.
            RuntimeError: The provider response contains no image.
        """
        if user_id:
            self._run_preflight_validation(user_id)
        else:
            logger.warning("[Control Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        control_image_bytes = self._decode_base64_image(request.control_image_base64)
        if not control_image_bytes:
            raise ValueError("Control image payload is required")
        style_image_bytes = self._decode_base64_image(request.style_image_base64)

        operation = request.operation
        logger.info("[Control Studio] Processing operation='%s' for user=%s", operation, user_id)
        if operation not in self.SUPPORTED_OPERATIONS:
            raise ValueError(f"Unsupported control operation: {operation}")

        # Defaults apply only to values left as None; an explicit 0.0 is kept.
        pick = self._value_or_default

        stability_service = StabilityAIService()
        async with stability_service:
            if operation == "sketch":
                result = await stability_service.control_sketch(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    control_strength=pick(request.control_strength, 0.7),
                    negative_prompt=request.negative_prompt,
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            elif operation == "structure":
                result = await stability_service.control_structure(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    control_strength=pick(request.control_strength, 0.7),
                    negative_prompt=request.negative_prompt,
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            elif operation == "style":
                result = await stability_service.control_style(
                    image=control_image_bytes,
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    aspect_ratio=request.aspect_ratio or "1:1",
                    fidelity=pick(request.fidelity, 0.5),
                    seed=request.seed,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                )
            elif operation == "style_transfer":
                if not style_image_bytes:
                    raise ValueError("Style image is required for style transfer")
                result = await stability_service.control_style_transfer(
                    init_image=control_image_bytes,
                    style_image=style_image_bytes,
                    prompt=request.prompt or "",
                    negative_prompt=request.negative_prompt,
                    style_strength=pick(request.style_strength, 1.0),
                    composition_fidelity=pick(request.composition_fidelity, 0.9),
                    change_strength=pick(request.change_strength, 0.9),
                    seed=request.seed,
                    output_format=request.output_format,
                )
            else:
                # Defensive: only reachable if the table and this chain drift apart.
                raise ValueError(f"Unsupported control operation: {operation}")

        image_bytes = self._extract_image_bytes(result)
        metadata = self._image_bytes_to_metadata(image_bytes)
        metadata.update(
            {
                "operation": operation,
                "style_preset": request.style_preset,
                "provider": self.SUPPORTED_OPERATIONS[operation]["provider"],
            }
        )
        response = {
            "success": True,
            "operation": operation,
            "provider": metadata["provider"],
            "image_base64": self._bytes_to_base64(image_bytes, request.output_format),
            "width": metadata["width"],
            "height": metadata["height"],
            "metadata": metadata,
        }
        logger.info("[Control Studio] ✅ Operation '%s' completed", operation)
        return response

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Normalize Stability responses into raw image bytes.

        Accepts raw ``bytes`` or a dict whose ``artifacts``/``data``/``images``
        list holds dicts with a ``base64`` or ``b64_json`` entry; the first
        non-empty entry wins.

        Raises:
            RuntimeError: If no image payload can be found.
        """
        if isinstance(result, bytes):
            return result
        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if isinstance(artifact, dict):
                    if artifact.get("base64"):
                        return base64.b64decode(artifact["base64"])
                    if artifact.get("b64_json"):
                        return base64.b64decode(artifact["b64_json"])
        raise RuntimeError("Unable to extract image bytes from provider response")

View File

@@ -0,0 +1,458 @@
"""Create Studio service for AI-powered image generation."""
import os
from typing import Optional, Dict, Any, List, Literal
from dataclasses import dataclass
from services.llm_providers.image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
HuggingFaceImageProvider,
GeminiImageProvider,
StabilityImageProvider,
WaveSpeedImageProvider,
)
from .templates import TemplateManager, ImageTemplate, Platform, TemplateCategory
from utils.logger_utils import get_service_logger
# Module-level logger shared by everything in this file.
logger = get_service_logger("image_studio.create")
@dataclass
class CreateStudioRequest:
    """Request for image generation in Create Studio.

    ``prompt`` is the only required field; everything else either has a
    concrete default or is ``None``, meaning "let the service decide".
    """

    prompt: str

    # --- template / provider selection -----------------------------------
    template_id: Optional[str] = None
    # One of "auto", "stability", "wavespeed", "huggingface", "gemini".
    provider: Optional[str] = None
    model: Optional[str] = None

    # --- geometry -----------------------------------------------------------
    width: Optional[int] = None
    height: Optional[int] = None
    # Ratio string such as "1:1" or "16:9".
    aspect_ratio: Optional[str] = None

    # --- look and quality ---------------------------------------------------
    style_preset: Optional[str] = None
    quality: Literal["draft", "standard", "premium"] = "standard"
    negative_prompt: Optional[str] = None

    # --- sampler controls ---------------------------------------------------
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None

    # --- batch / prompt handling -------------------------------------------
    num_variations: int = 1
    enhance_prompt: bool = True

    # --- persona ------------------------------------------------------------
    use_persona: bool = False
    persona_id: Optional[str] = None
class CreateStudioService:
    """Service for Create Studio image generation operations.

    For each request the service optionally applies a template, computes the
    output dimensions, optionally decorates the prompt with style keywords,
    picks a provider/model, and asks that provider for ``num_variations``
    images.
    """

    # Provider-to-model mapping for smart recommendations
    PROVIDER_MODELS = {
        "stability": {
            "ultra": "stability-ultra",  # Best quality, 8 credits
            "core": "stability-core",  # Fast & affordable, 3 credits
            "sd3": "sd3.5-large",  # SD3.5 model
        },
        "wavespeed": {
            "ideogram-v3-turbo": "ideogram-v3-turbo",  # Photorealistic, text rendering
            "qwen-image": "qwen-image",  # Fast generation
        },
        "huggingface": {
            "flux": "black-forest-labs/FLUX.1-Krea-dev",
        },
        "gemini": {
            "imagen": "imagen-3.0-generate-001",
        },
    }

    # Quality-to-provider mapping; entries are "provider" or "provider:model",
    # and the first entry of each list is the one actually used.
    QUALITY_PROVIDERS = {
        "draft": ["huggingface", "wavespeed:qwen-image"],  # Fast, low cost
        "standard": ["stability:core", "wavespeed:ideogram-v3-turbo"],  # Balanced
        "premium": ["wavespeed:ideogram-v3-turbo", "stability:ultra"],  # Best quality
    }

    # Template recommendations may name a model family instead of a provider
    # that _get_provider_instance understands: alias -> (provider, default model).
    _PROVIDER_ALIASES = {
        "ideogram": ("wavespeed", "ideogram-v3-turbo"),
        "qwen": ("wavespeed", "qwen-image"),
    }

    # Keyword suffixes appended to the prompt for a given style preset.
    _STYLE_ENHANCEMENTS = {
        "photographic": ", professional photography, high quality, detailed, sharp focus, natural lighting",
        "digital-art": ", digital art, vibrant colors, detailed, high quality, artstation trending",
        "cinematic": ", cinematic lighting, dramatic, film grain, high quality, professional",
        "3d-model": ", 3D render, octane render, unreal engine, high quality, detailed",
        "anime": ", anime style, vibrant colors, detailed, high quality",
        "line-art": ", clean line art, detailed linework, high contrast, professional",
    }

    def __init__(self):
        """Initialize Create Studio service."""
        self.template_manager = TemplateManager()
        logger.info("[Create Studio] Initialized with template manager")

    def _get_provider_instance(self, provider_name: str, api_key: Optional[str] = None):
        """Get provider instance by name.

        Args:
            provider_name: Name of the provider
            api_key: Optional API key (uses env vars if not provided)

        Returns:
            Provider instance

        Raises:
            ValueError: If provider is not supported
        """
        if provider_name == "stability":
            return StabilityImageProvider(api_key=api_key or os.getenv("STABILITY_API_KEY"))
        if provider_name == "wavespeed":
            return WaveSpeedImageProvider(api_key=api_key or os.getenv("WAVESPEED_API_KEY"))
        if provider_name == "huggingface":
            # This provider takes its credential under a different keyword.
            return HuggingFaceImageProvider(api_token=api_key or os.getenv("HF_API_KEY"))
        if provider_name == "gemini":
            return GeminiImageProvider(api_key=api_key or os.getenv("GEMINI_API_KEY"))
        raise ValueError(f"Unsupported provider: {provider_name}")

    @staticmethod
    def _resolve_template_provider(
        recommended: str,
        model: Optional[str],
        quality: str,
    ) -> tuple[str, Optional[str]]:
        """Turn a template's recommended provider into ``(provider, model)``.

        Aliases ("ideogram", "qwen") always resolve to their real provider,
        even when the caller pinned a model; previously the alias leaked
        through in that case and provider construction failed. A caller-
        supplied ``model`` always wins over the default model.
        """
        alias = CreateStudioService._PROVIDER_ALIASES.get(recommended)
        if alias is not None:
            provider, default_model = alias
            return provider, model or default_model
        if recommended == "stability" and not model:
            # Only "premium" upgrades to ultra; draft and standard share core.
            return "stability", "stability-ultra" if quality == "premium" else "stability-core"
        return recommended, model

    def _select_provider_and_model(
        self,
        request: CreateStudioRequest,
        template: Optional[ImageTemplate] = None
    ) -> tuple[str, Optional[str]]:
        """Smart provider and model selection.

        Precedence: explicit (non-"auto") provider on the request, then the
        template's recommendation, then the first ``QUALITY_PROVIDERS`` entry
        for the request's quality.

        Args:
            request: Create studio request
            template: Optional template with recommendations

        Returns:
            Tuple of (provider_name, model_name)
        """
        # Explicit provider selection
        if request.provider and request.provider != "auto":
            logger.info(
                "[Provider Selection] User specified: %s (model: %s)",
                request.provider, request.model,
            )
            return request.provider, request.model

        # Template recommendation
        if template and template.recommended_provider:
            recommended = template.recommended_provider
            logger.info("[Provider Selection] Template recommends: %s", recommended)
            return self._resolve_template_provider(recommended, request.model, request.quality)

        # Quality-based selection
        quality_options = self.QUALITY_PROVIDERS.get(
            request.quality, self.QUALITY_PROVIDERS["standard"]
        )
        provider, has_model, model_name = quality_options[0].partition(":")
        model = model_name if has_model else None
        logger.info("[Provider Selection] Quality-based (%s): %s (model: %s)",
                    request.quality, provider, model)
        return provider, model

    def _enhance_prompt(self, prompt: str, style_preset: Optional[str] = None) -> str:
        """Enhance prompt with style and quality descriptors.

        Args:
            prompt: Original prompt
            style_preset: Style preset to apply; unknown presets leave the
                prompt unchanged

        Returns:
            Enhanced prompt
        """
        enhanced = prompt + self._STYLE_ENHANCEMENTS.get(style_preset, "") if style_preset else prompt
        logger.info("[Prompt Enhancement] Original: %s", prompt[:100])
        logger.info("[Prompt Enhancement] Enhanced: %s", enhanced[:100])
        return enhanced

    def _apply_template(self, request: CreateStudioRequest, template: ImageTemplate) -> CreateStudioRequest:
        """Apply template settings to request.

        The request is modified in place and also returned. A quality of
        "standard" is indistinguishable from "not specified" and is therefore
        replaced by the template's quality.

        Args:
            request: Original request
            template: Template to apply

        Returns:
            Modified request
        """
        # Apply template dimensions if not specified
        if not request.width and not request.height:
            request.width = template.aspect_ratio.width
            request.height = template.aspect_ratio.height
        # Apply template style if not specified
        if not request.style_preset:
            request.style_preset = template.style_preset
        # Apply template quality if not specified
        if request.quality == "standard":
            request.quality = template.quality
        # %s rather than %d: one dimension may still be None when the caller
        # supplied only the other one.
        logger.info("[Template Applied] %s -> %sx%s, style=%s, quality=%s",
                    template.name, request.width, request.height,
                    request.style_preset, request.quality)
        return request

    @staticmethod
    def _parse_aspect_ratio(aspect_ratio: str) -> Optional[tuple[int, int]]:
        """Parse ``"W:H"`` into positive ints, or return ``None`` if invalid.

        Zero or negative terms are rejected so callers never divide by zero.
        """
        try:
            w_ratio, h_ratio = map(int, aspect_ratio.split(":"))
        except ValueError:
            return None
        if w_ratio <= 0 or h_ratio <= 0:
            return None
        return w_ratio, h_ratio

    def _calculate_dimensions(
        self,
        width: Optional[int],
        height: Optional[int],
        aspect_ratio: Optional[str]
    ) -> tuple[int, int]:
        """Calculate image dimensions from width/height or aspect ratio.

        Args:
            width: Explicit width
            height: Explicit height
            aspect_ratio: Aspect ratio string (e.g., "16:9")

        Returns:
            Tuple of (width, height). Falls back to 1024x1024 when neither
            both dimensions nor a valid aspect ratio are available.
        """
        # Both dimensions specified
        if width and height:
            return width, height

        if aspect_ratio:
            ratio = self._parse_aspect_ratio(aspect_ratio)
            if ratio is None:
                logger.warning("[Dimensions] Invalid aspect ratio: %s", aspect_ratio)
            else:
                w_ratio, h_ratio = ratio
                # Derive the missing side from whichever side was given.
                if width:
                    return width, int(width * h_ratio / w_ratio)
                if height:
                    return int(height * w_ratio / h_ratio), height
                # Neither side given: the long edge is fixed at 1920.
                if w_ratio >= h_ratio:
                    return 1920, int(1920 * h_ratio / w_ratio)  # landscape or square
                return int(1920 * w_ratio / h_ratio), 1920  # portrait

        # Default dimensions
        return 1024, 1024

    @staticmethod
    def _variation_seed(seed: Optional[int], index: int) -> Optional[int]:
        """Return the seed for variation ``index`` (0-based).

        ``None`` means "no seed"; ``0`` is a real seed and is kept.
        """
        return None if seed is None else seed + index

    @staticmethod
    def _run_preflight_validation(user_id: str, num_images: int) -> None:
        """Check subscription and usage limits before generating.

        ``HTTPException`` raised by the validator is logged and re-raised; the
        database session is closed in every case.
        """
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_generation_operations
        from fastapi import HTTPException

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Create Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_generation_operations(
                pricing_service=pricing_service,
                user_id=user_id,
                num_images=num_images
            )
            logger.info("[Create Studio] ✅ Pre-flight validation passed - proceeding with generation")
        except HTTPException:
            logger.error("[Create Studio] ❌ Pre-flight validation failed - blocking generation")
            raise
        finally:
            db.close()

    async def generate(
        self,
        request: CreateStudioRequest,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Generate image(s) using Create Studio.

        A failure of an individual variation is recorded in ``results`` as an
        ``{"error", "variation"}`` entry rather than raised, so the returned
        ``success`` flag is ``True`` even if some or all variations failed;
        check ``total_generated`` / ``total_failed``.

        Args:
            request: Create studio request
            user_id: User ID for validation and tracking

        Returns:
            Dictionary with generation results

        Raises:
            ValueError: If request is invalid
            RuntimeError: If the provider cannot be initialized
        """
        logger.info("[Create Studio] Starting generation: prompt=%s, template=%s",
                    request.prompt[:100], request.template_id)

        # Pre-flight validation: Check subscription and usage limits
        if user_id:
            self._run_preflight_validation(user_id, request.num_variations)
        else:
            logger.warning("[Create Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        # Load template if specified
        template = None
        if request.template_id:
            template = self.template_manager.get_by_id(request.template_id)
            if not template:
                raise ValueError(f"Template not found: {request.template_id}")
            # Apply template settings
            request = self._apply_template(request, template)

        # Calculate dimensions
        width, height = self._calculate_dimensions(
            request.width, request.height, request.aspect_ratio
        )

        # Enhance prompt if requested
        prompt = request.prompt
        if request.enhance_prompt:
            prompt = self._enhance_prompt(prompt, request.style_preset)

        # Select provider and model
        provider_name, model = self._select_provider_and_model(request, template)

        # Get provider instance
        try:
            provider = self._get_provider_instance(provider_name)
        except Exception as e:
            logger.error("[Create Studio] ❌ Failed to initialize provider %s: %s",
                         provider_name, str(e))
            raise RuntimeError(f"Provider initialization failed: {str(e)}") from e

        # Generate images
        results = []
        for i in range(request.num_variations):
            logger.info("[Create Studio] Generating variation %d/%d",
                        i + 1, request.num_variations)
            try:
                # Prepare options
                options = ImageGenerationOptions(
                    prompt=prompt,
                    negative_prompt=request.negative_prompt,
                    width=width,
                    height=height,
                    guidance_scale=request.guidance_scale,
                    steps=request.steps,
                    # Offset the seed per variation so the images differ.
                    seed=self._variation_seed(request.seed, i),
                    model=model,
                    extra={"style_preset": request.style_preset} if request.style_preset else {}
                )
                # Generate image
                result: ImageGenerationResult = provider.generate(options)
                results.append({
                    "image_bytes": result.image_bytes,
                    "width": result.width,
                    "height": result.height,
                    "provider": result.provider,
                    "model": result.model,
                    "seed": result.seed,
                    "metadata": result.metadata,
                    "variation": i + 1,
                })
                logger.info("[Create Studio] ✅ Variation %d generated successfully", i + 1)
            except Exception as e:
                # Best effort: keep going so the other variations still run.
                logger.error("[Create Studio] ❌ Failed to generate variation %d: %s",
                             i + 1, str(e), exc_info=True)
                results.append({
                    "error": str(e),
                    "variation": i + 1,
                })

        # Return results
        return {
            "success": True,
            "request": {
                "prompt": request.prompt,
                "enhanced_prompt": prompt if request.enhance_prompt else None,
                "template_id": request.template_id,
                "template_name": template.name if template else None,
                "provider": provider_name,
                "model": model,
                "dimensions": f"{width}x{height}",
                "quality": request.quality,
                "num_variations": request.num_variations,
            },
            "results": results,
            "total_generated": sum(1 for r in results if "image_bytes" in r),
            "total_failed": sum(1 for r in results if "error" in r),
        }

    def get_templates(
        self,
        platform: Optional[Platform] = None,
        category: Optional[TemplateCategory] = None
    ) -> List[ImageTemplate]:
        """Get available templates.

        When both filters are given, ``platform`` takes precedence and
        ``category`` is ignored.

        Args:
            platform: Filter by platform
            category: Filter by category

        Returns:
            List of templates
        """
        if platform:
            return self.template_manager.get_by_platform(platform)
        if category:
            return self.template_manager.get_by_category(category)
        return self.template_manager.get_all_templates()

    def search_templates(self, query: str) -> List[ImageTemplate]:
        """Search templates by query.

        Args:
            query: Search query

        Returns:
            List of matching templates
        """
        return self.template_manager.search(query)

    def recommend_templates(
        self,
        use_case: str,
        platform: Optional[Platform] = None
    ) -> List[ImageTemplate]:
        """Recommend templates based on use case.

        Args:
            use_case: Description of use case
            platform: Optional platform filter

        Returns:
            List of recommended templates
        """
        return self.template_manager.recommend_for_use_case(use_case, platform)

View File

@@ -0,0 +1,461 @@
"""Edit Studio service for AI-powered image editing and transformations."""
from __future__ import annotations
import asyncio
import base64
import io
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from PIL import Image
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
# Module-level logger shared by everything in this file.
logger = get_service_logger("image_studio.edit")
# Closed set of operation identifiers; these are the same strings used as keys
# of EditStudioService.SUPPORTED_OPERATIONS. "general_edit" is the only one
# routed to the Hugging Face editor; the rest go to Stability AI.
EditOperationType = Literal[
"remove_background",
"inpaint",
"outpaint",
"search_replace",
"search_recolor",
"relight",
"general_edit",
]
@dataclass
class EditStudioRequest:
    """Normalized request payload for Edit Studio operations.

    ``image_base64`` and ``operation`` are required. All other fields are
    optional and only consulted by the operations that need them.
    """

    # --- required -----------------------------------------------------------
    image_base64: str
    operation: EditOperationType

    # --- prompts ------------------------------------------------------------
    prompt: Optional[str] = None
    negative_prompt: Optional[str] = None

    # --- region selection ---------------------------------------------------
    mask_base64: Optional[str] = None
    search_prompt: Optional[str] = None  # read by "search_replace"
    select_prompt: Optional[str] = None  # read by "search_recolor"

    # --- reference images for "relight" --------------------------------------
    background_image_base64: Optional[str] = None
    lighting_image_base64: Optional[str] = None

    # --- canvas growth for "outpaint" ----------------------------------------
    expand_left: Optional[int] = None
    expand_right: Optional[int] = None
    expand_up: Optional[int] = None
    expand_down: Optional[int] = None

    # --- provider / sampler settings ------------------------------------------
    provider: Optional[str] = None
    model: Optional[str] = None
    style_preset: Optional[str] = None
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None

    # --- output ---------------------------------------------------------------
    output_format: str = "png"
    # Free-form extras (e.g. "grow_mask" for inpainting); a fresh dict per instance.
    options: Dict[str, Any] = field(default_factory=dict)
class EditStudioService:
"""Service layer orchestrating Edit Studio operations."""
# Catalogue of edit operations, keyed by the EditOperationType string.
# Each entry carries:
#   label / description - human-readable text
#   provider            - copied into the response by process_edit
#   async               - True only for "relight", whose provider call may
#                         return a generation id that is then polled
#   fields              - per-input flags; list_operations exposes this table
#                         for UI rendering
# NOTE(review): "outpaint" marks "prompt" as False here, yet
# _handle_stability_edit still forwards request.prompt to the outpaint call;
# confirm whether the UI is meant to offer a prompt for outpainting.
SUPPORTED_OPERATIONS: Dict[EditOperationType, Dict[str, Any]] = {
"remove_background": {
"label": "Remove Background",
"description": "Isolate the main subject and remove the background.",
"provider": "stability",
"async": False,
"fields": {
"prompt": False,
"mask": False,
"negative_prompt": False,
"search_prompt": False,
"select_prompt": False,
"background": False,
"lighting": False,
"expansion": False,
},
},
"inpaint": {
"label": "Inpaint & Fix",
"description": "Edit specific regions using prompts and optional masks.",
"provider": "stability",
"async": False,
"fields": {
"prompt": True,
"mask": True,
"negative_prompt": True,
"search_prompt": False,
"select_prompt": False,
"background": False,
"lighting": False,
"expansion": False,
},
},
"outpaint": {
"label": "Outpaint",
"description": "Extend the canvas in any direction with smart fill.",
"provider": "stability",
"async": False,
"fields": {
"prompt": False,
"mask": False,
"negative_prompt": True,
"search_prompt": False,
"select_prompt": False,
"background": False,
"lighting": False,
"expansion": True,
},
},
"search_replace": {
"label": "Search & Replace",
"description": "Locate objects via search prompt and replace them. Optional mask for precise control.",
"provider": "stability",
"async": False,
"fields": {
"prompt": True,
"mask": True, # Optional mask for precise region selection
"negative_prompt": False,
"search_prompt": True,
"select_prompt": False,
"background": False,
"lighting": False,
"expansion": False,
},
},
"search_recolor": {
"label": "Search & Recolor",
"description": "Select elements via prompt and recolor them. Optional mask for exact region selection.",
"provider": "stability",
"async": False,
"fields": {
"prompt": True,
"mask": True, # Optional mask for precise region selection
"negative_prompt": False,
"search_prompt": False,
"select_prompt": True,
"background": False,
"lighting": False,
"expansion": False,
},
},
"relight": {
"label": "Replace Background & Relight",
"description": "Swap backgrounds and relight using reference images.",
"provider": "stability",
"async": True,
"fields": {
"prompt": False,
"mask": False,
"negative_prompt": False,
"search_prompt": False,
"select_prompt": False,
"background": True,
"lighting": True,
"expansion": False,
},
},
"general_edit": {
"label": "Prompt-based Edit",
"description": "Free-form editing powered by Hugging Face image-to-image models. Optional mask for selective editing.",
"provider": "huggingface",
"async": False,
"fields": {
"prompt": True,
"mask": True, # Optional mask for selective region editing
"negative_prompt": True,
"search_prompt": False,
"select_prompt": False,
"background": False,
"lighting": False,
"expansion": False,
},
},
}
def __init__(self):
    """Create the service; the only side effect is a startup log line."""
    logger.info("[Edit Studio] Initialized edit service")
@staticmethod
def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
"""Decode a base64 (or data URL) string to bytes."""
if not value:
return None
try:
# Handle data URLs (data:image/png;base64,...)
if value.startswith("data:"):
_, b64data = value.split(",", 1)
else:
b64data = value
return base64.b64decode(b64data)
except Exception as exc:
logger.error(f"[Edit Studio] Failed to decode base64 image: {exc}")
raise ValueError("Invalid base64 image payload") from exc
@staticmethod
def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
    """Open encoded image bytes and report their pixel dimensions.

    Returns:
        ``{"width": <int>, "height": <int>}`` as reported by the opened image.
    """
    buffer = io.BytesIO(image_bytes)
    with Image.open(buffer) as opened:
        width, height = opened.width, opened.height
    return {"width": width, "height": height}
@staticmethod
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
"""Convert raw bytes to base64 data URL."""
b64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:image/{output_format};base64,{b64}"
def list_operations(self) -> Dict[str, Dict[str, Any]]:
"""Expose supported operations for UI rendering."""
return self.SUPPORTED_OPERATIONS
async def process_edit(
    self,
    request: EditStudioRequest,
    user_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Run one edit operation end to end and return a normalized response.

    Steps: optional subscription pre-flight check, base64 decoding of all
    supplied images, dispatch to the Stability or general-edit handler, then
    packaging of the edited image together with its dimensions.

    Args:
        request: Operation name, images and tuning values.
        user_id: When given, pre-flight validation runs first; when omitted,
            validation is skipped with a warning.

    Returns:
        Dict with ``success``, ``operation``, ``provider``, ``image_base64``
        (data URL), ``width``, ``height`` and ``metadata``.

    Raises:
        ValueError: Missing primary image or unsupported operation.
    """
    if not user_id:
        logger.warning("[Edit Studio] ⚠️ No user_id provided - skipping pre-flight validation")
    else:
        # Imported lazily, only when validation actually runs.
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_editing_operations
        from fastapi import HTTPException

        session = next(get_db())
        try:
            pricing = PricingService(session)
            logger.info(f"[Edit Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_editing_operations(
                pricing_service=pricing,
                user_id=user_id,
            )
            logger.info("[Edit Studio] ✅ Pre-flight validation passed")
        except HTTPException:
            logger.error("[Edit Studio] ❌ Pre-flight validation failed")
            raise
        finally:
            session.close()

    source_bytes = self._decode_base64_image(request.image_base64)
    if not source_bytes:
        raise ValueError("Primary image payload is required")
    mask_bytes = self._decode_base64_image(request.mask_base64)
    background_bytes = self._decode_base64_image(request.background_image_base64)
    lighting_bytes = self._decode_base64_image(request.lighting_image_base64)

    operation = request.operation
    logger.info("[Edit Studio] Processing operation='%s' for user=%s", operation, user_id)
    if operation not in self.SUPPORTED_OPERATIONS:
        raise ValueError(f"Unsupported edit operation: {operation}")

    # Everything except the free-form edit is served by Stability AI.
    stability_operations = {
        "remove_background",
        "inpaint",
        "outpaint",
        "search_replace",
        "search_recolor",
        "relight",
    }
    if operation in stability_operations:
        edited_bytes = await self._handle_stability_edit(
            operation=operation,
            request=request,
            image_bytes=source_bytes,
            mask_bytes=mask_bytes,
            background_bytes=background_bytes,
            lighting_bytes=lighting_bytes,
        )
    else:
        edited_bytes = await self._handle_general_edit(
            request=request,
            image_bytes=source_bytes,
            mask_bytes=mask_bytes,
            user_id=user_id,
        )

    metadata = self._image_bytes_to_metadata(edited_bytes)
    metadata["operation"] = operation
    metadata["style_preset"] = request.style_preset
    metadata["provider"] = self.SUPPORTED_OPERATIONS[operation]["provider"]

    encoded_image = self._bytes_to_base64(edited_bytes, request.output_format)
    response = {
        "success": True,
        "operation": operation,
        "provider": metadata["provider"],
        "image_base64": encoded_image,
        "width": metadata["width"],
        "height": metadata["height"],
        "metadata": metadata,
    }
    logger.info("[Edit Studio] ✅ Operation '%s' completed", operation)
    return response
async def _handle_stability_edit(
self,
operation: EditOperationType,
request: EditStudioRequest,
image_bytes: bytes,
mask_bytes: Optional[bytes],
background_bytes: Optional[bytes],
lighting_bytes: Optional[bytes],
) -> bytes:
"""Execute Stability AI edit workflows."""
stability_service = StabilityAIService()
async with stability_service:
if operation == "remove_background":
result = await stability_service.remove_background(
image=image_bytes,
output_format=request.output_format,
)
elif operation == "inpaint":
if not request.prompt:
raise ValueError("Prompt is required for inpainting")
result = await stability_service.inpaint(
image=image_bytes,
prompt=request.prompt,
mask=mask_bytes,
negative_prompt=request.negative_prompt,
output_format=request.output_format,
style_preset=request.style_preset,
grow_mask=request.options.get("grow_mask", 5),
)
elif operation == "outpaint":
result = await stability_service.outpaint(
image=image_bytes,
prompt=request.prompt,
negative_prompt=request.negative_prompt,
output_format=request.output_format,
left=request.expand_left or 0,
right=request.expand_right or 0,
up=request.expand_up or 0,
down=request.expand_down or 0,
style_preset=request.style_preset,
)
elif operation == "search_replace":
if not (request.prompt and request.search_prompt):
raise ValueError("Both prompt and search_prompt are required for search & replace")
result = await stability_service.search_and_replace(
image=image_bytes,
prompt=request.prompt,
search_prompt=request.search_prompt,
mask=mask_bytes, # Optional mask for precise region selection
output_format=request.output_format,
)
elif operation == "search_recolor":
if not (request.prompt and request.select_prompt):
raise ValueError("Both prompt and select_prompt are required for search & recolor")
result = await stability_service.search_and_recolor(
image=image_bytes,
prompt=request.prompt,
select_prompt=request.select_prompt,
mask=mask_bytes, # Optional mask for precise region selection
output_format=request.output_format,
)
elif operation == "relight":
if not background_bytes and not lighting_bytes:
raise ValueError("At least one reference (background or lighting) is required for relight")
result = await stability_service.replace_background_and_relight(
subject_image=image_bytes,
background_reference=background_bytes,
light_reference=lighting_bytes,
output_format=request.output_format,
)
if isinstance(result, dict) and result.get("id"):
result = await self._poll_stability_result(
stability_service,
generation_id=result["id"],
output_format=request.output_format,
)
else:
raise ValueError(f"Unsupported Stability operation: {operation}")
return self._extract_image_bytes(result)
async def _handle_general_edit(
self,
request: EditStudioRequest,
image_bytes: bytes,
mask_bytes: Optional[bytes],
user_id: Optional[str],
) -> bytes:
"""Execute Hugging Face powered general editing (synchronous API)."""
if not request.prompt:
raise ValueError("Prompt is required for general edits")
options = {
"provider": request.provider or "huggingface",
"model": request.model,
"guidance_scale": request.guidance_scale,
"steps": request.steps,
"seed": request.seed,
}
# huggingface edit is synchronous - run in thread
result = await asyncio.to_thread(
huggingface_edit_image,
image_bytes,
request.prompt,
options,
user_id,
mask_bytes, # Optional mask for selective editing
)
return result.image_bytes
@staticmethod
def _extract_image_bytes(result: Any) -> bytes:
"""Normalize Stability responses into raw image bytes."""
if isinstance(result, bytes):
return result
if isinstance(result, dict):
artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
for artifact in artifacts:
if isinstance(artifact, dict):
if artifact.get("base64"):
return base64.b64decode(artifact["base64"])
if artifact.get("b64_json"):
return base64.b64decode(artifact["b64_json"])
raise RuntimeError("Unable to extract image bytes from provider response")
async def _poll_stability_result(
self,
stability_service: StabilityAIService,
generation_id: str,
output_format: str,
timeout_seconds: int = 240,
interval_seconds: float = 2.0,
) -> bytes:
"""Poll Stability async endpoint until result is ready."""
elapsed = 0.0
while elapsed < timeout_seconds:
result = await stability_service.get_generation_result(
generation_id=generation_id,
accept_type="*/*",
)
if isinstance(result, bytes):
return result
if isinstance(result, dict):
state = (result.get("state") or result.get("status") or "").lower()
if state in {"succeeded", "success", "ready", "completed"}:
return self._extract_image_bytes(result)
if state in {"failed", "error"}:
raise RuntimeError(f"Stability generation failed: {result}")
await asyncio.sleep(interval_seconds)
elapsed += interval_seconds
raise RuntimeError("Timed out waiting for Stability generation result")

View File

@@ -0,0 +1,155 @@
"""InfiniteTalk adapter for Transform Studio."""
import asyncio
from typing import Any, Dict, Optional
from fastapi import HTTPException
from loguru import logger
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.infinitetalk")
class InfiniteTalkService:
    """Adapter exposing InfiniteTalk talking-avatar generation to Transform Studio."""

    def __init__(self, client: Optional[WaveSpeedClient] = None):
        """Store the WaveSpeed client, creating a default one when none is given."""
        self.client = client or WaveSpeedClient()
        logger.info("[InfiniteTalk Adapter] Service initialized")

    def calculate_cost(self, resolution: str, duration: float) -> float:
        """Calculate cost for an InfiniteTalk video.

        Args:
            resolution: Output resolution; "480p" uses the lower rate, any
                other value is priced at the 720p rate.
            duration: Video duration in seconds.

        Returns:
            Cost in USD.
        """
        # InfiniteTalk pricing: $0.03/s (480p) or $0.06/s (720p)
        cost_per_second = 0.03 if resolution == "480p" else 0.06
        # Minimum charge: 5 seconds
        billable_seconds = max(5.0, duration)
        return cost_per_second * billable_seconds

    @staticmethod
    def _decode_media(payload: str, default_mime: str) -> tuple[bytes, str]:
        """Decode a raw base64 string or ``data:`` URI into ``(bytes, mime)``.

        For a data URI the MIME type is read from the header (falling back to
        ``default_mime`` when the header carries none); a bare base64 string
        always gets ``default_mime``.

        Raises:
            ValueError: If a data URI has no comma separating header and data.
        """
        import base64  # local import: not part of this module's top-level imports

        if not payload.startswith("data:"):
            return base64.b64decode(payload), default_mime
        if "," not in payload:
            raise ValueError("Invalid data URI format: missing comma separator")
        header, encoded = payload.split(",", 1)
        mime = header.split(":")[1].split(";")[0].strip() or default_mime
        return base64.b64decode(encoded), mime

    async def create_talking_avatar(
        self,
        image_base64: str,
        audio_base64: str,
        resolution: str = "720p",
        prompt: Optional[str] = None,
        mask_image_base64: Optional[str] = None,
        seed: Optional[int] = None,
        user_id: str = "transform_studio",
    ) -> Dict[str, Any]:
        """Create a talking avatar video using InfiniteTalk.

        Args:
            image_base64: Person image in base64 or data URI.
            audio_base64: Audio file in base64 or data URI.
            resolution: Output resolution (480p or 720p).
            prompt: Optional prompt, forwarded as ``prompt_override``.
            mask_image_base64: Optional mask for animatable regions.
            seed: Optional random seed.
            user_id: User ID for tracking.

        Returns:
            The dictionary returned by ``animate_scene_with_voiceover``,
            updated with ``cost``, ``resolution``, ``width`` and ``height``.

        Raises:
            HTTPException: 400 for an unsupported resolution or undecodable
                image/audio payload; 500 when generation itself fails.
        """
        if resolution not in ("480p", "720p"):
            raise HTTPException(
                status_code=400,
                detail="Resolution must be '480p' or '720p' for InfiniteTalk"
            )

        try:
            image_bytes, image_mime = self._decode_media(image_base64, "image/png")
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Failed to decode image: {str(e)}"
            ) from e

        try:
            audio_bytes, audio_mime = self._decode_media(audio_base64, "audio/mpeg")
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Failed to decode audio: {str(e)}"
            ) from e

        # NOTE(review): mask_image_base64 and seed are accepted but not passed
        # to animate_scene_with_voiceover here -- confirm whether that function
        # supports them before wiring them through.

        # The InfiniteTalk function is synchronous, so it runs in a worker
        # thread. scene_data/story_context are empty because Transform Studio
        # has no story context.
        try:
            result = await asyncio.to_thread(
                animate_scene_with_voiceover,
                image_bytes=image_bytes,
                audio_bytes=audio_bytes,
                scene_data={},
                story_context={},
                user_id=user_id,
                resolution=resolution,
                prompt_override=prompt,
                image_mime=image_mime,
                audio_mime=audio_mime,
                client=self.client,
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"[InfiniteTalk Adapter] Error: {str(e)}", exc_info=True)
            raise HTTPException(
                status_code=500,
                detail=f"InfiniteTalk generation failed: {str(e)}"
            )

        # A missing or None duration falls back to the 5-second minimum rather
        # than crashing in max() after the video has already been generated.
        duration = result.get("duration")
        if duration is None:
            duration = 5.0
        actual_cost = self.calculate_cost(resolution, duration)

        result["cost"] = actual_cost
        result["resolution"] = resolution

        resolution_dims = {
            "480p": (854, 480),
            "720p": (1280, 720),
        }
        width, height = resolution_dims.get(resolution, (1280, 720))
        result["width"] = width
        result["height"] = height

        logger.info(
            f"[InfiniteTalk Adapter] ✅ Generated talking avatar: "
            f"resolution={resolution}, duration={duration}s, cost=${actual_cost:.2f}"
        )

        return result

View File

@@ -0,0 +1,502 @@
"""Social Optimizer service for platform-specific image optimization."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from PIL import Image, ImageDraw, ImageFont
from .templates import Platform
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.social_optimizer")
@dataclass
class SafeZone:
    """Margins to keep clear of text, expressed as fractions of the image size.

    Each value is multiplied by the image height (top/bottom) or width
    (left/right) to locate the safe-zone edges, so 0.1 reserves 10% of
    that side.
    """

    top: float = 0.1     # fraction of the height reserved at the top edge
    bottom: float = 0.1  # fraction of the height reserved at the bottom edge
    left: float = 0.1    # fraction of the width reserved at the left edge
    right: float = 0.1   # fraction of the width reserved at the right edge
@dataclass
class PlatformFormat:
    """One named output format (pixel size, ratio label, safe zone) of a platform."""

    name: str                  # display name, also used to select the format
    width: int                 # target width in pixels
    height: int                # target height in pixels
    ratio: str                 # human-readable aspect-ratio label, e.g. "9:16"
    safe_zone: SafeZone        # margins to keep clear for overlays
    file_type: str = "PNG"     # file type reported for this format
    max_size_mb: float = 5.0   # size limit reported for this format
# Output formats per platform, each with its pixel size, ratio label and safe
# zone. The first entry of every list is the default format for that platform.
PLATFORM_FORMATS: Dict[Platform, List[PlatformFormat]] = {
    Platform.INSTAGRAM: [
        PlatformFormat(
            "Feed Post (Square)", 1080, 1080, "1:1",
            SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Feed Post (Portrait)", 1080, 1350, "4:5",
            SafeZone(top=0.2, bottom=0.2, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Story", 1080, 1920, "9:16",
            SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Reel", 1080, 1920, "9:16",
            SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
        ),
    ],
    Platform.FACEBOOK: [
        PlatformFormat(
            "Feed Post", 1200, 630, "1.91:1",
            SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Feed Post (Square)", 1080, 1080, "1:1",
            SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Story", 1080, 1920, "9:16",
            SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Cover Photo", 820, 312, "16:9",
            SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
        ),
    ],
    Platform.TWITTER: [
        PlatformFormat(
            "Post", 1200, 675, "16:9",
            SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Card", 1200, 600, "2:1",
            SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Header", 1500, 500, "3:1",
            SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
        ),
    ],
    Platform.LINKEDIN: [
        PlatformFormat(
            "Feed Post", 1200, 628, "1.91:1",
            SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Feed Post (Square)", 1080, 1080, "1:1",
            SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Article", 1200, 627, "2:1",
            SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Company Cover", 1128, 191, "4:1",
            SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
        ),
    ],
    Platform.YOUTUBE: [
        PlatformFormat(
            "Thumbnail", 1280, 720, "16:9",
            SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Channel Art", 2560, 1440, "16:9",
            SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
        ),
    ],
    Platform.PINTEREST: [
        PlatformFormat(
            "Pin", 1000, 1500, "2:3",
            SafeZone(top=0.2, bottom=0.2, left=0.1, right=0.1),
        ),
        PlatformFormat(
            "Story Pin", 1080, 1920, "9:16",
            SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
        ),
    ],
    Platform.TIKTOK: [
        PlatformFormat(
            "Video Cover", 1080, 1920, "9:16",
            SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
        ),
    ],
}
@dataclass
class SocialOptimizerRequest:
    """Input describing which platforms an image should be optimized for."""

    # Source image as plain base64 or a ``data:`` URL.
    image_base64: str
    # Platforms to produce an optimized image for.
    platforms: List[Platform]
    # Optional format name per platform; the platform's first format is used
    # when no name is given.
    format_names: Optional[Dict[Platform, str]] = None
    # When true, the safe-zone overlay is drawn onto each output image.
    show_safe_zones: bool = False
    # "smart", "fit", or anything else for a plain center crop.
    crop_mode: str = "smart"
    # Relative focal point for smart crop, e.g. {"x": 0.5, "y": 0.5}.
    focal_point: Optional[Dict[str, float]] = None
    # Requested encoding; "jpg"/"jpeg" produce JPEG, everything else PNG.
    output_format: str = "png"
    # Free-form extra options.
    options: Dict[str, Any] = field(default_factory=dict)
class SocialOptimizerService:
    """Resize, crop and annotate images for social media platform formats."""

    def __init__(self):
        logger.info("[Social Optimizer] Initialized service")

    @staticmethod
    def _decode_base64_image(value: str) -> bytes:
        """Decode a base64 string (optionally wrapped in a data URL) to bytes.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        try:
            if value.startswith("data:"):
                _, b64data = value.split(",", 1)
            else:
                b64data = value
            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error(f"[Social Optimizer] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
        """Wrap raw image bytes in a ``data:image/<output_format>;base64,`` URL."""
        b64 = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/{output_format};base64,{b64}"

    @staticmethod
    def _resolve_output_format(output_format: str) -> tuple[str, str]:
        """Map a requested format to ``(PIL save format, MIME subtype)``.

        Only JPEG and PNG are produced: "jpg"/"jpeg" (any case) select JPEG,
        every other value falls back to PNG. Returning the MIME subtype from
        the same place keeps the data URL consistent with the bytes actually
        written.
        """
        if output_format.lower() in ("jpg", "jpeg"):
            return "JPEG", "jpeg"
        return "PNG", "png"

    @staticmethod
    def _smart_crop(
        image: Image.Image,
        target_width: int,
        target_height: int,
        focal_point: Optional[Dict[str, float]] = None,
    ) -> Image.Image:
        """Crop around a focal point to the target ratio, then resize.

        Args:
            image: Source image.
            target_width: Output width in pixels.
            target_height: Output height in pixels.
            focal_point: Optional relative position ``{"x": .., "y": ..}``;
                a missing coordinate defaults to 0.5 (the center). Without a
                focal point the crop is centered.
        """
        img_width, img_height = image.size
        target_ratio = target_width / target_height
        img_ratio = img_width / img_height

        if focal_point:
            focal_x = int(focal_point.get("x", 0.5) * img_width)
            focal_y = int(focal_point.get("y", 0.5) * img_height)
        else:
            focal_x = img_width // 2
            focal_y = img_height // 2

        if img_ratio > target_ratio:
            # Image is wider than target - crop width. The window is clamped
            # so it never leaves the image.
            new_width = int(img_height * target_ratio)
            left = max(0, min(focal_x - new_width // 2, img_width - new_width))
            cropped = image.crop((left, 0, left + new_width, img_height))
        else:
            # Image is taller than target - crop height.
            new_height = int(img_width / target_ratio)
            top = max(0, min(focal_y - new_height // 2, img_height - new_height))
            cropped = image.crop((0, top, img_width, top + new_height))

        return cropped.resize((target_width, target_height), Image.Resampling.LANCZOS)

    @staticmethod
    def _fit_dimensions(
        img_width: int,
        img_height: int,
        target_width: int,
        target_height: int,
    ) -> tuple[int, int]:
        """Return the largest size with the image's ratio that fits the target box."""
        if img_width * target_height > target_width * img_height:
            # Image is relatively wider: width is the limiting dimension.
            new_width = target_width
            new_height = round(img_height * target_width / img_width)
        else:
            # Image is relatively taller (or same ratio): height limits.
            new_height = target_height
            new_width = round(img_width * target_height / img_height)
        return (
            max(1, min(target_width, new_width)),
            max(1, min(target_height, new_height)),
        )

    @staticmethod
    def _fit_image(
        image: Image.Image,
        target_width: int,
        target_height: int,
    ) -> Image.Image:
        """Scale the whole image into the target box and pad the rest with white.

        The image is scaled along its limiting dimension so nothing is cut
        off (the previous implementation scaled along the other dimension,
        which overflowed the canvas and cropped instead of padding).
        """
        img_width, img_height = image.size
        new_width, new_height = SocialOptimizerService._fit_dimensions(
            img_width, img_height, target_width, target_height
        )
        resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        canvas = Image.new("RGB", (target_width, target_height), (255, 255, 255))
        paste_x = (target_width - new_width) // 2
        paste_y = (target_height - new_height) // 2
        canvas.paste(resized, (paste_x, paste_y))
        return canvas

    @staticmethod
    def _center_crop(
        image: Image.Image,
        target_width: int,
        target_height: int,
    ) -> Image.Image:
        """Crop the centered region matching the target ratio, then resize."""
        img_width, img_height = image.size
        target_ratio = target_width / target_height
        img_ratio = img_width / img_height

        if img_ratio > target_ratio:
            # Image is wider - crop width
            new_width = int(img_height * target_ratio)
            left = (img_width - new_width) // 2
            cropped = image.crop((left, 0, left + new_width, img_height))
        else:
            # Image is taller - crop height
            new_height = int(img_width / target_ratio)
            top = (img_height - new_height) // 2
            cropped = image.crop((0, top, img_width, top + new_height))

        return cropped.resize((target_width, target_height), Image.Resampling.LANCZOS)

    @staticmethod
    def _draw_safe_zone(
        image: Image.Image,
        safe_zone: SafeZone,
    ) -> Image.Image:
        """Return an RGBA copy of the image with the unsafe margins shaded.

        The area outside the safe zone is darkened and the safe zone itself
        is outlined in yellow.
        """
        width, height = image.size

        # Safe zone boundaries in pixels
        top = int(height * safe_zone.top)
        bottom = int(height * (1 - safe_zone.bottom))
        left = int(width * safe_zone.left)
        right = int(width * (1 - safe_zone.right))

        overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        overlay_draw = ImageDraw.Draw(overlay)
        shade = (0, 0, 0, 100)

        overlay_draw.rectangle([(0, 0), (width, top)], fill=shade)            # top
        overlay_draw.rectangle([(0, bottom), (width, height)], fill=shade)    # bottom
        overlay_draw.rectangle([(0, top), (left, bottom)], fill=shade)        # left
        overlay_draw.rectangle([(right, top), (width, bottom)], fill=shade)   # right

        # Safe zone border: yellow with transparency
        overlay_draw.rectangle(
            [(left, top), (right, bottom)],
            outline=(255, 255, 0, 200),
            width=2,
        )

        if image.mode != "RGBA":
            image = image.convert("RGBA")
        return Image.alpha_composite(image, overlay)

    @staticmethod
    def _flatten_to_rgb(image: Image.Image) -> Image.Image:
        """Convert any image to RGB, compositing transparency onto white.

        "RGBA", "LA" and "P" images are all routed through RGBA so their
        alpha channel is used as the paste mask (previously "LA" images were
        pasted without a mask, which discarded their transparency).
        """
        if image.mode in ("RGBA", "LA", "P"):
            rgba = image.convert("RGBA")
            background = Image.new("RGB", rgba.size, (255, 255, 255))
            background.paste(rgba, mask=rgba.split()[-1])
            return background
        if image.mode != "RGB":
            return image.convert("RGB")
        return image

    def get_platform_formats(self, platform: Platform) -> List[Dict[str, Any]]:
        """Return the formats defined for a platform as plain dictionaries.

        An unknown platform yields an empty list.
        """
        formats = PLATFORM_FORMATS.get(platform, [])
        return [
            {
                "name": fmt.name,
                "width": fmt.width,
                "height": fmt.height,
                "ratio": fmt.ratio,
                "safe_zone": {
                    "top": fmt.safe_zone.top,
                    "bottom": fmt.safe_zone.bottom,
                    "left": fmt.safe_zone.left,
                    "right": fmt.safe_zone.right,
                },
                "file_type": fmt.file_type,
                "max_size_mb": fmt.max_size_mb,
            }
            for fmt in formats
        ]

    def optimize_image(
        self,
        request: SocialOptimizerRequest,
    ) -> Dict[str, Any]:
        """Produce one optimized image per requested platform.

        Args:
            request: Source image, platforms, crop mode and output options.

        Returns:
            Dict with ``success``, ``results`` (one entry per platform that
            has formats defined, each carrying the image as a data URL) and
            ``total_optimized``.

        Raises:
            ValueError: If the input image payload is not valid base64.
        """
        logger.info(
            f"[Social Optimizer] Processing optimization for {len(request.platforms)} platform(s)"
        )

        image_bytes = self._decode_base64_image(request.image_base64)
        original_image = self._flatten_to_rgb(Image.open(io.BytesIO(image_bytes)))

        # Resolved once: the data URL must advertise the encoding actually
        # written, not whatever string the caller sent.
        pil_format, mime_subtype = self._resolve_output_format(request.output_format)

        results = []
        for platform in request.platforms:
            formats = PLATFORM_FORMATS.get(platform, [])
            if not formats:
                logger.warning(f"[Social Optimizer] No formats found for platform: {platform}")
                continue

            # Use the requested format when it exists, else the first one.
            format_name = (request.format_names or {}).get(platform)
            platform_format = None
            if format_name:
                platform_format = next(
                    (fmt for fmt in formats if fmt.name == format_name), None
                )
                if platform_format is None:
                    logger.warning(
                        f"[Social Optimizer] Format '{format_name}' not found for "
                        f"platform: {platform}; using default format"
                    )
            if platform_format is None:
                platform_format = formats[0]

            if request.crop_mode == "smart":
                optimized_image = self._smart_crop(
                    original_image,
                    platform_format.width,
                    platform_format.height,
                    request.focal_point,
                )
            elif request.crop_mode == "fit":
                optimized_image = self._fit_image(
                    original_image,
                    platform_format.width,
                    platform_format.height,
                )
            else:  # center
                optimized_image = self._center_crop(
                    original_image,
                    platform_format.width,
                    platform_format.height,
                )

            if request.show_safe_zones:
                optimized_image = self._draw_safe_zone(optimized_image, platform_format.safe_zone)

            output_buffer = io.BytesIO()
            if pil_format == "JPEG":
                # JPEG has no alpha channel; the overlay step may return RGBA.
                optimized_image.convert("RGB").save(output_buffer, format="JPEG", quality=95)
            else:
                optimized_image.save(output_buffer, format="PNG")
            output_bytes = output_buffer.getvalue()

            results.append(
                {
                    "platform": platform.value,
                    "format": platform_format.name,
                    "width": platform_format.width,
                    "height": platform_format.height,
                    "ratio": platform_format.ratio,
                    "image_base64": self._bytes_to_base64(output_bytes, mime_subtype),
                    "safe_zone": {
                        "top": platform_format.safe_zone.top,
                        "bottom": platform_format.safe_zone.bottom,
                        "left": platform_format.safe_zone.left,
                        "right": platform_format.safe_zone.right,
                    },
                }
            )

        logger.info(f"[Social Optimizer] ✅ Generated {len(results)} optimized images")

        return {
            "success": True,
            "results": results,
            "total_optimized": len(results),
        }

View File

@@ -0,0 +1,379 @@
"""Image Studio Manager - Main orchestration service for all image operations."""
from typing import Optional, Dict, Any, List
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .transform_service import (
TransformStudioService,
TransformImageToVideoRequest,
TalkingAvatarRequest,
)
from .templates import Platform, TemplateCategory, ImageTemplate
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.manager")
class ImageStudioManager:
    """Facade that routes Image Studio requests to the per-feature services."""

    def __init__(self):
        """Instantiate one service object per studio feature."""
        self.create_service = CreateStudioService()
        self.edit_service = EditStudioService()
        self.upscale_service = UpscaleStudioService()
        self.control_service = ControlStudioService()
        self.social_optimizer_service = SocialOptimizerService()
        self.transform_service = TransformStudioService()
        logger.info("[Image Studio Manager] Initialized successfully")

    # ====================
    # CREATE STUDIO
    # ====================

    async def create_image(
        self,
        request: CreateStudioRequest,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Create/generate an image using Create Studio.

        Args:
            request: Create studio request.
            user_id: User ID forwarded to the create service.

        Returns:
            Dictionary with generation results.
        """
        # f-strings are used for logging to match the other Image Studio
        # services; they render correctly whatever formatting the logger uses.
        logger.info(f"[Image Studio] Create image request from user: {user_id}")
        return await self.create_service.generate(request, user_id=user_id)

    # ====================
    # EDIT STUDIO
    # ====================

    async def edit_image(
        self,
        request: EditStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run an Edit Studio operation through the edit service."""
        logger.info(f"[Image Studio] Edit image request from user: {user_id}")
        return await self.edit_service.process_edit(request, user_id=user_id)

    def get_edit_operations(self) -> Dict[str, Any]:
        """Expose the edit service's operation list for the UI."""
        return self.edit_service.list_operations()

    # ====================
    # UPSCALE STUDIO
    # ====================

    async def upscale_image(
        self,
        request: UpscaleStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run an Upscale Studio operation through the upscale service."""
        logger.info(f"[Image Studio] Upscale request from user: {user_id}")
        return await self.upscale_service.process_upscale(request, user_id=user_id)

    def get_templates(
        self,
        platform: Optional[Platform] = None,
        category: Optional[TemplateCategory] = None
    ) -> List[ImageTemplate]:
        """Get available templates.

        Args:
            platform: Filter by platform.
            category: Filter by category.

        Returns:
            List of templates.
        """
        return self.create_service.get_templates(platform=platform, category=category)

    def search_templates(self, query: str) -> List[ImageTemplate]:
        """Search templates by query.

        Args:
            query: Search query.

        Returns:
            List of matching templates.
        """
        return self.create_service.search_templates(query)

    def recommend_templates(
        self,
        use_case: str,
        platform: Optional[Platform] = None
    ) -> List[ImageTemplate]:
        """Recommend templates based on use case.

        Args:
            use_case: Use case description.
            platform: Optional platform filter.

        Returns:
            List of recommended templates.
        """
        return self.create_service.recommend_templates(use_case, platform)

    def get_providers(self) -> Dict[str, Any]:
        """Get available image providers and their capabilities.

        Returns:
            Dictionary keyed by provider id with name, models, capabilities,
            maximum resolution and a cost-range description.
        """
        return {
            "stability": {
                "name": "Stability AI",
                "models": ["ultra", "core", "sd3.5-large"],
                "capabilities": ["text-to-image", "editing", "upscaling", "control", "3d"],
                "max_resolution": (2048, 2048),
                "cost_range": "3-8 credits per image",
            },
            "wavespeed": {
                "name": "WaveSpeed AI",
                "models": ["ideogram-v3-turbo", "qwen-image"],
                "capabilities": ["text-to-image", "photorealistic", "fast-generation"],
                "max_resolution": (1024, 1024),
                "cost_range": "$0.05-$0.10 per image",
            },
            "huggingface": {
                "name": "HuggingFace",
                "models": ["FLUX.1-Krea-dev", "RunwayML"],
                "capabilities": ["text-to-image", "image-to-image"],
                "max_resolution": (1024, 1024),
                "cost_range": "Free tier available",
            },
            "gemini": {
                "name": "Google Gemini",
                "models": ["imagen-3.0"],
                "capabilities": ["text-to-image", "conversational-editing"],
                "max_resolution": (1024, 1024),
                "cost_range": "Free tier available",
            }
        }

    # ====================
    # COST ESTIMATION
    # ====================

    def estimate_cost(
        self,
        provider: str,
        model: Optional[str],
        operation: str,
        num_images: int = 1,
        resolution: Optional[tuple[int, int]] = None
    ) -> Dict[str, Any]:
        """Estimate cost for image operations.

        The price is looked up by provider and model only; ``operation`` and
        ``resolution`` are echoed back in the result but do not affect it.
        Unknown providers/models fall back to the provider's ``default``
        price, or 0.0 when there is none.

        Args:
            provider: Provider name.
            model: Model name.
            operation: Operation type (generate, edit, upscale, etc.).
            num_images: Number of images.
            resolution: Image resolution (width, height).

        Returns:
            Cost estimation details.
        """
        # Base costs (adjust based on actual pricing)
        base_costs = {
            "stability": {
                "ultra": 0.08,  # 8 credits
                "core": 0.03,  # 3 credits
                "sd3": 0.065,  # 6.5 credits
                # Same price under the model id advertised by get_providers();
                # without it that model was estimated at 0.0.
                "sd3.5-large": 0.065,
            },
            "wavespeed": {
                "ideogram-v3-turbo": 0.10,
                "qwen-image": 0.05,
            },
            "huggingface": {
                "default": 0.0,  # Free tier
            },
            "gemini": {
                "default": 0.0,  # Free tier
            }
        }

        provider_costs = base_costs.get(provider, {})
        cost_per_image = provider_costs.get(model, provider_costs.get("default", 0.0))
        total_cost = cost_per_image * num_images

        return {
            "provider": provider,
            "model": model,
            "operation": operation,
            "num_images": num_images,
            "resolution": f"{resolution[0]}x{resolution[1]}" if resolution else "default",
            "cost_per_image": cost_per_image,
            "total_cost": total_cost,
            "currency": "USD",
            "estimated": True,
        }

    # ====================
    # CONTROL STUDIO
    # ====================

    async def control_image(
        self,
        request: ControlStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run a Control Studio operation through the control service."""
        logger.info(f"[Image Studio] Control request from user: {user_id}")
        return await self.control_service.process_control(request, user_id=user_id)

    def get_control_operations(self) -> Dict[str, Any]:
        """Expose the control service's operation list for the UI."""
        return self.control_service.list_operations()

    # ====================
    # SOCIAL OPTIMIZER
    # ====================

    async def optimize_for_social(
        self,
        request: SocialOptimizerRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Optimize an image for social media platforms.

        ``optimize_image`` is a synchronous call, so it is run in a worker
        thread to avoid blocking the event loop, the same way the other
        Image Studio services wrap synchronous work.
        """
        import asyncio  # local import: this module has no top-level asyncio import

        logger.info(f"[Image Studio] Social optimization request from user: {user_id}")
        return await asyncio.to_thread(
            self.social_optimizer_service.optimize_image, request
        )

    def get_social_platform_formats(self, platform: Platform) -> List[Dict[str, Any]]:
        """Get available formats for a social platform."""
        return self.social_optimizer_service.get_platform_formats(platform)

    # ====================
    # PLATFORM SPECS
    # ====================

    def get_platform_specs(self, platform: Platform) -> Dict[str, Any]:
        """Get platform specifications and requirements.

        Args:
            platform: Platform to get specs for.

        Returns:
            Platform specifications (name, formats, file types, maximum file
            size), or an empty dict for a platform without an entry.
        """
        specs = {
            Platform.INSTAGRAM: {
                "name": "Instagram",
                "formats": [
                    {"name": "Feed Post (Square)", "ratio": "1:1", "size": "1080x1080"},
                    {"name": "Feed Post (Portrait)", "ratio": "4:5", "size": "1080x1350"},
                    {"name": "Story", "ratio": "9:16", "size": "1080x1920"},
                    {"name": "Reel", "ratio": "9:16", "size": "1080x1920"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "30MB",
            },
            Platform.FACEBOOK: {
                "name": "Facebook",
                "formats": [
                    {"name": "Feed Post", "ratio": "1.91:1", "size": "1200x630"},
                    {"name": "Feed Post (Square)", "ratio": "1:1", "size": "1080x1080"},
                    {"name": "Story", "ratio": "9:16", "size": "1080x1920"},
                    {"name": "Cover Photo", "ratio": "16:9", "size": "820x312"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "30MB",
            },
            Platform.TWITTER: {
                "name": "Twitter/X",
                "formats": [
                    {"name": "Post", "ratio": "16:9", "size": "1200x675"},
                    {"name": "Card", "ratio": "2:1", "size": "1200x600"},
                    {"name": "Header", "ratio": "3:1", "size": "1500x500"},
                ],
                "file_types": ["JPG", "PNG", "GIF"],
                "max_file_size": "5MB",
            },
            Platform.LINKEDIN: {
                "name": "LinkedIn",
                "formats": [
                    {"name": "Feed Post", "ratio": "1.91:1", "size": "1200x628"},
                    {"name": "Feed Post (Square)", "ratio": "1:1", "size": "1080x1080"},
                    {"name": "Article", "ratio": "2:1", "size": "1200x627"},
                    {"name": "Company Cover", "ratio": "4:1", "size": "1128x191"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "8MB",
            },
            Platform.YOUTUBE: {
                "name": "YouTube",
                "formats": [
                    {"name": "Thumbnail", "ratio": "16:9", "size": "1280x720"},
                    {"name": "Channel Art", "ratio": "16:9", "size": "2560x1440"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "2MB",
            },
            Platform.PINTEREST: {
                "name": "Pinterest",
                "formats": [
                    {"name": "Pin", "ratio": "2:3", "size": "1000x1500"},
                    {"name": "Story Pin", "ratio": "9:16", "size": "1080x1920"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "20MB",
            },
            Platform.TIKTOK: {
                "name": "TikTok",
                "formats": [
                    {"name": "Video Cover", "ratio": "9:16", "size": "1080x1920"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "10MB",
            },
        }

        return specs.get(platform, {})

    # ====================
    # TRANSFORM STUDIO
    # ====================

    async def transform_image_to_video(
        self,
        request: TransformImageToVideoRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Transform image to video using WAN 2.5.

        A missing user id is passed to the transform service as "anonymous".
        """
        logger.info(f"[Image Studio] Transform image-to-video request from user: {user_id}")
        return await self.transform_service.transform_image_to_video(request, user_id=user_id or "anonymous")

    async def create_talking_avatar(
        self,
        request: TalkingAvatarRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create talking avatar using InfiniteTalk.

        A missing user id is passed to the transform service as "anonymous".
        """
        logger.info(f"[Image Studio] Talking avatar request from user: {user_id}")
        return await self.transform_service.create_talking_avatar(request, user_id=user_id or "anonymous")

    def estimate_transform_cost(
        self,
        operation: str,
        resolution: str,
        duration: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Estimate cost for a transform operation via the transform service."""
        return self.transform_service.estimate_cost(operation, resolution, duration)

View File

@@ -0,0 +1,555 @@
"""Template system for Image Studio with platform-specific presets."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Literal
from enum import Enum
class Platform(str, Enum):
    """Publishing targets known to Image Studio.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value.
    """

    # Social networks
    INSTAGRAM = "instagram"
    FACEBOOK = "facebook"
    TWITTER = "twitter"
    LINKEDIN = "linkedin"
    YOUTUBE = "youtube"
    PINTEREST = "pinterest"
    TIKTOK = "tiktok"
    # Other destinations
    BLOG = "blog"
    EMAIL = "email"
    WEBSITE = "website"
class TemplateCategory(str, Enum):
    """Groupings used to classify image templates.

    Members are also ``str`` instances equal to their snake_case value.
    """

    SOCIAL_MEDIA = "social_media"
    BLOG_CONTENT = "blog_content"
    AD_CREATIVE = "ad_creative"
    PRODUCT = "product"
    BRAND_ASSETS = "brand_assets"
    EMAIL_MARKETING = "email_marketing"
@dataclass
class AspectRatio:
    """A named aspect ratio together with its concrete pixel dimensions."""

    ratio: str   # ratio label, e.g. "1:1" or "16:9"
    width: int   # width in pixels
    height: int  # height in pixels
    label: str   # display label, e.g. "Square" or "Widescreen"
@dataclass
class ImageTemplate:
    """Preset describing how to generate an image for a given purpose.

    ``use_cases`` defaults to ``None`` (not an empty list), so its annotation
    is ``Optional[List[str]]``; callers should handle the ``None`` case.
    """

    id: str                                   # unique template identifier
    name: str                                 # display name
    category: TemplateCategory                # grouping the template belongs to
    platform: Optional[Platform]              # target platform, if any
    aspect_ratio: AspectRatio                 # ratio label plus pixel dimensions
    description: str                          # short human-readable description
    recommended_provider: str                 # provider suggested for this template
    style_preset: str                         # style preset name
    quality: Literal["draft", "standard", "premium"]
    prompt_template: Optional[str] = None
    negative_prompt_template: Optional[str] = None
    use_cases: Optional[List[str]] = None     # example uses; None when unspecified
class PlatformTemplates:
    """Catalogue of ready-made image templates, grouped by platform.

    Class attributes hold the shared ``AspectRatio`` presets. The private
    ``_<platform>_templates`` classmethods build fresh ``ImageTemplate``
    lists on every call, and ``get_platform_templates`` collects them into
    a single mapping.
    """

    # Aspect Ratios
    # NOTE(review): for some presets the ratio text is not the exact quotient
    # of width/height (820x312 is labelled "16:9", 1128x191 is labelled "4:1",
    # 1200x627 is labelled "2:1"). Presumably the text is the nearest ratio a
    # provider accepts — confirm before using ``ratio`` for exact math.
    SQUARE_1_1 = AspectRatio("1:1", 1080, 1080, "Square")
    PORTRAIT_4_5 = AspectRatio("4:5", 1080, 1350, "Portrait")
    STORY_9_16 = AspectRatio("9:16", 1080, 1920, "Story/Reel")
    LANDSCAPE_16_9 = AspectRatio("16:9", 1920, 1080, "Landscape")
    WIDE_21_9 = AspectRatio("21:9", 2560, 1080, "Ultra Wide")
    TWITTER_2_1 = AspectRatio("2:1", 1200, 600, "Twitter Card")
    TWITTER_3_1 = AspectRatio("3:1", 1500, 500, "Twitter Header")
    FACEBOOK_1_91_1 = AspectRatio("1.91:1", 1200, 630, "Facebook Feed")
    LINKEDIN_1_91_1 = AspectRatio("1.91:1", 1200, 628, "LinkedIn Feed")
    LINKEDIN_2_1 = AspectRatio("2:1", 1200, 627, "LinkedIn Article")
    LINKEDIN_4_1 = AspectRatio("4:1", 1128, 191, "LinkedIn Cover")
    PINTEREST_2_3 = AspectRatio("2:3", 1000, 1500, "Pinterest Pin")
    YOUTUBE_16_9 = AspectRatio("16:9", 1280, 720, "YouTube Thumbnail")
    FACEBOOK_COVER_16_9 = AspectRatio("16:9", 820, 312, "Facebook Cover")

    @staticmethod
    def _build(category, platform, rows) -> List[ImageTemplate]:
        """Expand compact rows into ``ImageTemplate`` objects.

        Each row is ``(id, name, aspect_ratio, description, style_preset,
        quality, use_cases)``. ``category`` and ``platform`` apply to every
        row, and the recommended provider is ``"ideogram"`` for all of them.
        """
        return [
            ImageTemplate(
                id=template_id,
                name=name,
                category=category,
                platform=platform,
                aspect_ratio=aspect_ratio,
                description=description,
                recommended_provider="ideogram",
                style_preset=style_preset,
                quality=quality,
                use_cases=use_cases,
            )
            for (template_id, name, aspect_ratio, description,
                 style_preset, quality, use_cases) in rows
        ]

    @classmethod
    def get_platform_templates(cls) -> Dict[Platform, List[ImageTemplate]]:
        """Return a new mapping of every platform to its template list."""
        builders = (
            (Platform.INSTAGRAM, cls._instagram_templates),
            (Platform.FACEBOOK, cls._facebook_templates),
            (Platform.TWITTER, cls._twitter_templates),
            (Platform.LINKEDIN, cls._linkedin_templates),
            (Platform.YOUTUBE, cls._youtube_templates),
            (Platform.PINTEREST, cls._pinterest_templates),
            (Platform.TIKTOK, cls._tiktok_templates),
            (Platform.BLOG, cls._blog_templates),
            (Platform.EMAIL, cls._email_templates),
            (Platform.WEBSITE, cls._website_templates),
        )
        return {platform: build() for platform, build in builders}

    @classmethod
    def _instagram_templates(cls) -> List[ImageTemplate]:
        """Instagram feed, story and reel-cover templates."""
        return cls._build(TemplateCategory.SOCIAL_MEDIA, Platform.INSTAGRAM, [
            ("instagram_feed_square", "Instagram Feed Post (Square)", cls.SQUARE_1_1,
             "Perfect for Instagram feed posts with maximum visibility",
             "photographic", "premium",
             ["Product showcase", "Lifestyle posts", "Brand content"]),
            ("instagram_feed_portrait", "Instagram Feed Post (Portrait)", cls.PORTRAIT_4_5,
             "Vertical format for maximum feed real estate",
             "photographic", "premium",
             ["Fashion", "Food", "Product photography"]),
            ("instagram_story", "Instagram Story", cls.STORY_9_16,
             "Full-screen vertical stories",
             "digital-art", "standard",
             ["Behind-the-scenes", "Announcements", "Quick updates"]),
            ("instagram_reel_cover", "Instagram Reel Cover", cls.STORY_9_16,
             "Eye-catching reel cover images",
             "cinematic", "premium",
             ["Video covers", "Thumbnails", "Highlights"]),
        ])

    @classmethod
    def _facebook_templates(cls) -> List[ImageTemplate]:
        """Facebook feed, story and cover templates."""
        return cls._build(TemplateCategory.SOCIAL_MEDIA, Platform.FACEBOOK, [
            ("facebook_feed", "Facebook Feed Post", cls.FACEBOOK_1_91_1,
             "Optimized for Facebook news feed",
             "photographic", "standard",
             ["Page posts", "Shared content", "Community posts"]),
            ("facebook_feed_square", "Facebook Feed Post (Square)", cls.SQUARE_1_1,
             "Square format for feed posts",
             "photographic", "standard",
             ["Page posts", "Product highlights"]),
            ("facebook_story", "Facebook Story", cls.STORY_9_16,
             "Full-screen vertical stories",
             "digital-art", "standard",
             ["Quick updates", "Promotions", "Events"]),
            ("facebook_cover", "Facebook Cover Photo", cls.FACEBOOK_COVER_16_9,
             "Wide cover photo for pages",
             "photographic", "premium",
             ["Page branding", "Events", "Seasonal updates"]),
        ])

    @classmethod
    def _twitter_templates(cls) -> List[ImageTemplate]:
        """Twitter/X post, card and header templates."""
        return cls._build(TemplateCategory.SOCIAL_MEDIA, Platform.TWITTER, [
            ("twitter_post", "Twitter/X Post", cls.LANDSCAPE_16_9,
             "Optimized for Twitter feed",
             "photographic", "standard",
             ["Tweets", "News", "Updates"]),
            ("twitter_card", "Twitter Card", cls.TWITTER_2_1,
             "Twitter card with link preview",
             "digital-art", "standard",
             ["Link sharing", "Articles", "Blog posts"]),
            ("twitter_header", "Twitter Header", cls.TWITTER_3_1,
             "Profile header image",
             "photographic", "premium",
             ["Profile branding", "Personal brand", "Business identity"]),
        ])

    @classmethod
    def _linkedin_templates(cls) -> List[ImageTemplate]:
        """LinkedIn post, article and company-cover templates."""
        return cls._build(TemplateCategory.SOCIAL_MEDIA, Platform.LINKEDIN, [
            ("linkedin_post", "LinkedIn Post", cls.LINKEDIN_1_91_1,
             "Professional feed posts",
             "photographic", "premium",
             ["Professional content", "Industry news", "Thought leadership"]),
            ("linkedin_post_square", "LinkedIn Post (Square)", cls.SQUARE_1_1,
             "Square format for LinkedIn feed",
             "photographic", "premium",
             ["Quick tips", "Infographics", "Quotes"]),
            ("linkedin_article", "LinkedIn Article Header", cls.LINKEDIN_2_1,
             "Article header images",
             "photographic", "premium",
             ["Long-form content", "Articles", "Newsletters"]),
            ("linkedin_cover", "LinkedIn Company Cover", cls.LINKEDIN_4_1,
             "Company page cover photo",
             "photographic", "premium",
             ["Company branding", "Recruitment", "Brand identity"]),
        ])

    @classmethod
    def _youtube_templates(cls) -> List[ImageTemplate]:
        """YouTube thumbnail and channel-art templates."""
        return cls._build(TemplateCategory.SOCIAL_MEDIA, Platform.YOUTUBE, [
            ("youtube_thumbnail", "YouTube Thumbnail", cls.YOUTUBE_16_9,
             "Eye-catching video thumbnails",
             "cinematic", "premium",
             ["Video thumbnails", "Channel branding", "Playlists"]),
            ("youtube_channel_art", "YouTube Channel Art", cls.LANDSCAPE_16_9,
             "Channel banner art",
             "photographic", "premium",
             ["Channel branding", "Personal brand", "Business identity"]),
        ])

    @classmethod
    def _pinterest_templates(cls) -> List[ImageTemplate]:
        """Pinterest pin and story-pin templates."""
        return cls._build(TemplateCategory.SOCIAL_MEDIA, Platform.PINTEREST, [
            ("pinterest_pin", "Pinterest Pin", cls.PINTEREST_2_3,
             "Vertical pin format",
             "photographic", "premium",
             ["Product pins", "DIY guides", "Recipes", "Inspiration"]),
            ("pinterest_story", "Pinterest Story Pin", cls.STORY_9_16,
             "Full-screen story pins",
             "digital-art", "standard",
             ["Step-by-step guides", "Tutorials", "Quick tips"]),
        ])

    @classmethod
    def _tiktok_templates(cls) -> List[ImageTemplate]:
        """TikTok video-cover template."""
        return cls._build(TemplateCategory.SOCIAL_MEDIA, Platform.TIKTOK, [
            ("tiktok_video_cover", "TikTok Video Cover", cls.STORY_9_16,
             "Vertical video cover",
             "cinematic", "premium",
             ["Video covers", "Thumbnails", "Profile highlights"]),
        ])

    @classmethod
    def _blog_templates(cls) -> List[ImageTemplate]:
        """Blog header templates."""
        return cls._build(TemplateCategory.BLOG_CONTENT, Platform.BLOG, [
            ("blog_header", "Blog Header", cls.LANDSCAPE_16_9,
             "Blog post featured image",
             "photographic", "premium",
             ["Featured images", "Article headers", "Post thumbnails"]),
            ("blog_header_wide", "Blog Header (Wide)", cls.WIDE_21_9,
             "Ultra-wide blog header",
             "photographic", "premium",
             ["Hero sections", "Wide headers", "Landing pages"]),
        ])

    @classmethod
    def _email_templates(cls) -> List[ImageTemplate]:
        """Email marketing banner and product templates."""
        return cls._build(TemplateCategory.EMAIL_MARKETING, Platform.EMAIL, [
            ("email_banner", "Email Banner", cls.LANDSCAPE_16_9,
             "Email header banner",
             "photographic", "standard",
             ["Email headers", "Newsletter banners", "Promotions"]),
            ("email_product", "Email Product Image", cls.SQUARE_1_1,
             "Product showcase for emails",
             "photographic", "premium",
             ["Product highlights", "Promotions", "Offers"]),
        ])

    @classmethod
    def _website_templates(cls) -> List[ImageTemplate]:
        """Website hero and banner templates."""
        return cls._build(TemplateCategory.BRAND_ASSETS, Platform.WEBSITE, [
            ("website_hero", "Website Hero Image", cls.WIDE_21_9,
             "Hero section background",
             "photographic", "premium",
             ["Hero sections", "Landing pages", "Home page banners"]),
            ("website_banner", "Website Banner", cls.LANDSCAPE_16_9,
             "Section banners",
             "photographic", "premium",
             ["Section headers", "Category pages", "Feature sections"]),
        ])
class TemplateManager:
    """Lookup, search and recommendation layer over the platform templates."""

    def __init__(self):
        """Load the per-platform mapping; the flat list is built on first use."""
        self.templates = PlatformTemplates.get_platform_templates()
        self._all_templates: Optional[List[ImageTemplate]] = None

    def get_all_templates(self) -> List[ImageTemplate]:
        """Return every template across all platforms (cached after the first call)."""
        if self._all_templates is None:
            self._all_templates = [
                template
                for group in self.templates.values()
                for template in group
            ]
        return self._all_templates

    def get_by_platform(self, platform: Platform) -> List[ImageTemplate]:
        """Return the templates registered for ``platform``, or an empty list."""
        return self.templates.get(platform, [])

    def get_by_category(self, category: TemplateCategory) -> List[ImageTemplate]:
        """Return every template whose category equals ``category``."""
        return [
            template
            for template in self.get_all_templates()
            if template.category == category
        ]

    def get_by_id(self, template_id: str) -> Optional[ImageTemplate]:
        """Return the first template with the given id, or ``None`` if absent."""
        return next(
            (t for t in self.get_all_templates() if t.id == template_id),
            None,
        )

    def search(self, query: str) -> List[ImageTemplate]:
        """Return templates whose name, description or use cases contain ``query``.

        Matching is a case-insensitive substring test against the three
        fields joined by single spaces.
        """
        needle = query.lower()

        def haystack(template) -> str:
            return " ".join((
                template.name.lower(),
                template.description.lower(),
                " ".join(template.use_cases or []).lower(),
            ))

        return [t for t in self.get_all_templates() if needle in haystack(t)]

    def recommend_for_use_case(self, use_case: str, platform: Optional[Platform] = None) -> List[ImageTemplate]:
        """Return templates having a use case that contains ``use_case``.

        The match is a case-insensitive substring test. When ``platform`` is
        given, only that platform's templates are considered.
        """
        wanted = use_case.lower()
        candidates = self.get_all_templates()
        if platform:
            candidates = [t for t in candidates if t.platform == platform]
        return [
            t
            for t in candidates
            if t.use_cases and any(wanted in case.lower() for case in t.use_cases)
        ]

    def get_aspect_ratio_options(self) -> List[AspectRatio]:
        """Return the aspect-ratio presets defined on ``PlatformTemplates``."""
        preset_names = (
            "SQUARE_1_1",
            "PORTRAIT_4_5",
            "STORY_9_16",
            "LANDSCAPE_16_9",
            "WIDE_21_9",
            "TWITTER_2_1",
            "TWITTER_3_1",
            "FACEBOOK_1_91_1",
            "LINKEDIN_1_91_1",
            "LINKEDIN_2_1",
            "LINKEDIN_4_1",
            "PINTEREST_2_3",
            "YOUTUBE_16_9",
            "FACEBOOK_COVER_16_9",
        )
        return [getattr(PlatformTemplates, name) for name in preset_names]

View File

@@ -0,0 +1,370 @@
"""Transform Studio service for image-to-video and talking avatar generation."""
import os
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from dataclasses import dataclass
from fastapi import HTTPException
from loguru import logger
from .wan25_service import WAN25Service
from .infinitetalk_adapter import InfiniteTalkService
from services.llm_providers.main_video_generation import ai_video_generate
from utils.logger_utils import get_service_logger
from utils.file_storage import save_file_safely, sanitize_filename
logger = get_service_logger("image_studio.transform")
@dataclass
class TransformImageToVideoRequest:
    """Input payload for a WAN 2.5 image-to-video transformation."""

    # Required inputs
    image_base64: str
    prompt: str

    # Optional tuning
    audio_base64: Optional[str] = None
    # One of 480p, 720p, 1080p
    resolution: str = "720p"
    # Clip length in seconds: 5 or 10
    duration: int = 5
    negative_prompt: Optional[str] = None
    seed: Optional[int] = None
    enable_prompt_expansion: bool = True
@dataclass
class TalkingAvatarRequest:
    """Input payload for an InfiniteTalk talking-avatar video."""

    # Required inputs
    image_base64: str
    audio_base64: str

    # Optional tuning
    # Either 480p or 720p
    resolution: str = "720p"
    prompt: Optional[str] = None
    mask_image_base64: Optional[str] = None
    seed: Optional[int] = None
class TransformStudioService:
    """Service for Transform Studio operations.

    Generates videos from still images (image-to-video and talking avatar),
    writes each result to a per-user folder under ``output_dir`` and makes a
    best-effort attempt to record it in the asset library.
    """

    def __init__(self):
        """Create the provider services and ensure the output directory exists.

        Raises:
            RuntimeError: If the output directory is missing after ``mkdir``.
        """
        self.wan25_service = WAN25Service()
        self.infinitetalk_service = InfiniteTalkService()

        # Video output directory.
        # NOTE(review): the original comment states that this file is
        # backend/services/image_studio/transform_service.py and that the
        # target is backend/transform_videos, but four ``.parent`` hops from
        # that path land one level ABOVE ``backend``. Left unchanged here —
        # confirm against the route that serves /api/image-studio/videos.
        base_dir = Path(__file__).parent.parent.parent.parent
        self.output_dir = base_dir / "transform_videos"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Verify directory was created
        if not self.output_dir.exists():
            raise RuntimeError(f"Failed to create transform_videos directory: {self.output_dir}")

        logger.info(f"[Transform Studio] Initialized with output directory: {self.output_dir}")

    def _save_video_file(
        self,
        video_bytes: bytes,
        operation_type: str,
        user_id: str,
    ) -> Dict[str, Any]:
        """Save video file to disk.

        Args:
            video_bytes: Video content as bytes
            operation_type: Type of operation (e.g., "image-to-video", "talking-avatar")
            user_id: User ID for directory organization

        Returns:
            Dictionary with filename, file_path, file_url and file_size

        Raises:
            HTTPException: 500 when ``save_file_safely`` reports an error.
        """
        # Create user-specific directory.
        # NOTE(review): user_id is used verbatim as a path component —
        # presumably a trusted auth identifier; confirm it can never contain
        # path separators.
        user_dir = self.output_dir / user_id
        user_dir.mkdir(parents=True, exist_ok=True)

        # Short random suffix keeps names unique within the user's folder.
        filename = f"{operation_type}_{uuid.uuid4().hex[:8]}.mp4"
        filename = sanitize_filename(filename)

        file_path, error = save_file_safely(
            content=video_bytes,
            directory=user_dir,
            filename=filename,
            max_file_size=500 * 1024 * 1024  # 500MB max for videos
        )

        if error:
            raise HTTPException(
                status_code=500,
                detail=f"Failed to save video file: {error}"
            )

        file_url = f"/api/image-studio/videos/{user_id}/{filename}"

        return {
            "filename": filename,
            "file_path": str(file_path),
            "file_url": file_url,
            "file_size": len(video_bytes),
        }

    def _save_to_asset_library(
        self,
        *,
        user_id: str,
        save_result: Dict[str, Any],
        title: str,
        description: str,
        prompt: str,
        tags: list,
        provider: str,
        model: str,
        cost: float,
        asset_metadata: Dict[str, Any],
    ) -> None:
        """Record a saved video in the asset library (best effort).

        Any failure — including a failed import of the database or tracker
        modules — is logged as a warning and never propagated, so a library
        problem cannot fail a generation that already succeeded.
        """
        try:
            from services.database import get_db
            from utils.asset_tracker import save_asset_to_library

            db = next(get_db())
            try:
                save_asset_to_library(
                    db=db,
                    user_id=user_id,
                    asset_type="video",
                    source_module="image_studio",
                    filename=save_result["filename"],
                    file_url=save_result["file_url"],
                    file_path=save_result["file_path"],
                    file_size=save_result["file_size"],
                    mime_type="video/mp4",
                    title=title,
                    description=description,
                    prompt=prompt,
                    tags=tags,
                    provider=provider,
                    model=model,
                    cost=cost,
                    asset_metadata=asset_metadata,
                )
                logger.info(f"[Transform Studio] Video saved to asset library")
            finally:
                db.close()
        except Exception as e:
            logger.warning(f"[Transform Studio] Failed to save to asset library: {e}")

    async def transform_image_to_video(
        self,
        request: TransformImageToVideoRequest,
        user_id: str,
    ) -> Dict[str, Any]:
        """Transform image to video using unified video generation entry point.

        Args:
            request: Transform request
            user_id: User ID for tracking and file organization

        Returns:
            Dictionary with video URL, metadata, and cost
        """
        logger.info(
            f"[Transform Studio] Image-to-video request from user {user_id}: "
            f"resolution={request.resolution}, duration={request.duration}s"
        )

        # Use unified video generation entry point
        # This handles pre-flight validation, generation, and usage tracking
        # Returns dict with video_bytes and full metadata
        # NOTE(review): called without ``await`` — assumes ai_video_generate
        # is a plain synchronous function; if so it blocks the event loop for
        # the whole generation. Confirm and consider offloading to a thread.
        result = ai_video_generate(
            image_base64=request.image_base64,
            prompt=request.prompt,
            operation_type="image-to-video",
            provider="wavespeed",
            user_id=user_id,
            duration=request.duration,
            resolution=request.resolution,
            negative_prompt=request.negative_prompt,
            seed=request.seed,
            audio_base64=request.audio_base64,
            enable_prompt_expansion=request.enable_prompt_expansion,
            model="alibaba/wan-2.5/image-to-video",
        )

        # Save video to disk
        save_result = self._save_video_file(
            video_bytes=result["video_bytes"],
            operation_type="image-to-video",
            user_id=user_id,
        )

        # Provider metadata, falling back to request values / known defaults.
        provider = result.get("provider", "wavespeed")
        model = result.get("model_name", "alibaba/wan-2.5/image-to-video")
        cost = result.get("cost", 0.0)
        duration = result.get("duration", float(request.duration))
        width = result.get("width", 1280)
        height = result.get("height", 720)

        self._save_to_asset_library(
            user_id=user_id,
            save_result=save_result,
            title=f"Transform: Image-to-Video ({request.resolution})",
            description=f"Generated video using WAN 2.5: {request.prompt[:100]}",
            prompt=result.get("prompt", request.prompt),
            tags=["image_studio", "transform", "video", "image-to-video", request.resolution],
            provider=provider,
            model=model,
            cost=cost,
            asset_metadata={
                "resolution": request.resolution,
                "duration": duration,
                "operation": "image-to-video",
                "width": width,
                "height": height,
            },
        )

        return {
            "success": True,
            "video_url": save_result["file_url"],
            "video_base64": None,  # Don't include base64 for large videos
            "duration": duration,
            "resolution": result.get("resolution", request.resolution),
            "width": width,
            "height": height,
            "file_size": save_result["file_size"],
            "cost": cost,
            "provider": provider,
            "model": model,
            "metadata": result.get("metadata", {}),
        }

    async def create_talking_avatar(
        self,
        request: TalkingAvatarRequest,
        user_id: str,
    ) -> Dict[str, Any]:
        """Create talking avatar using InfiniteTalk.

        Args:
            request: Talking avatar request
            user_id: User ID for tracking and file organization

        Returns:
            Dictionary with video URL, metadata, and cost
        """
        logger.info(
            f"[Transform Studio] Talking avatar request from user {user_id}: "
            f"resolution={request.resolution}"
        )

        # Generate video using InfiniteTalk
        result = await self.infinitetalk_service.create_talking_avatar(
            image_base64=request.image_base64,
            audio_base64=request.audio_base64,
            resolution=request.resolution,
            prompt=request.prompt,
            mask_image_base64=request.mask_image_base64,
            seed=request.seed,
            user_id=user_id,
        )

        # Save video to disk
        save_result = self._save_video_file(
            video_bytes=result["video_bytes"],
            operation_type="talking-avatar",
            user_id=user_id,
        )

        # Track usage (best effort: failures are logged, never raised).
        try:
            # ``track_video_usage`` was never imported at module level, so the
            # call below always raised NameError, which the broad ``except``
            # turned into a warning — usage was silently never recorded.
            # NOTE(review): assumes the function lives next to
            # ai_video_generate; if it does not, the ImportError is logged by
            # the handler below exactly as the NameError was before.
            from services.llm_providers.main_video_generation import track_video_usage

            usage_info = track_video_usage(
                user_id=user_id,
                provider=result["provider"],
                model_name=result["model_name"],
                prompt=result.get("prompt", ""),
                video_bytes=result["video_bytes"],
                cost_override=result["cost"],
            )
            logger.info(
                f"[Transform Studio] Usage tracked: {usage_info.get('current_calls', 0)} / "
                f"{usage_info.get('video_limit_display', '∞')} videos, "
                f"cost=${result['cost']:.2f}"
            )
        except Exception as e:
            logger.warning(f"[Transform Studio] Failed to track usage: {e}")

        duration = result.get("duration", 5.0)
        width = result.get("width", 1280)
        height = result.get("height", 720)

        self._save_to_asset_library(
            user_id=user_id,
            save_result=save_result,
            title=f"Transform: Talking Avatar ({request.resolution})",
            description="Generated talking avatar video using InfiniteTalk",
            prompt=result.get("prompt", ""),
            tags=["image_studio", "transform", "video", "talking-avatar", request.resolution],
            provider=result["provider"],
            model=result["model_name"],
            cost=result["cost"],
            asset_metadata={
                "resolution": request.resolution,
                "duration": duration,
                "operation": "talking-avatar",
                "width": width,
                "height": height,
            },
        )

        return {
            "success": True,
            "video_url": save_result["file_url"],
            "video_base64": None,  # Don't include base64 for large videos
            "duration": duration,
            "resolution": result.get("resolution", request.resolution),
            "width": width,
            "height": height,
            "file_size": save_result["file_size"],
            "cost": result["cost"],
            "provider": result["provider"],
            "model": result["model_name"],
            "metadata": result.get("metadata", {}),
        }

    def estimate_cost(
        self,
        operation: str,
        resolution: str,
        duration: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Estimate cost for transform operation.

        Args:
            operation: Operation type ("image-to-video" or "talking-avatar")
            resolution: Output resolution
            duration: Video duration in seconds; defaults to 5 when omitted

        Returns:
            Cost estimation details

        Raises:
            ValueError: If ``operation`` is not one of the two supported types.
        """
        if operation == "image-to-video":
            if duration is None:
                duration = 5
            cost = self.wan25_service.calculate_cost(resolution, duration)
            return {
                "estimated_cost": cost,
                "breakdown": {
                    "base_cost": 0.0,
                    "per_second": self.wan25_service.calculate_cost(resolution, 1),
                    "duration": duration,
                    "total": cost,
                },
                "currency": "USD",
                "provider": "wavespeed",
                "model": "alibaba/wan-2.5/image-to-video",
            }
        elif operation == "talking-avatar":
            # InfiniteTalk minimum is 5 seconds
            estimated_duration = duration or 5.0
            cost = self.infinitetalk_service.calculate_cost(resolution, estimated_duration)
            return {
                "estimated_cost": cost,
                "breakdown": {
                    "base_cost": 0.0,
                    "per_second": self.infinitetalk_service.calculate_cost(resolution, 1.0),
                    "duration": estimated_duration,
                    "total": cost,
                },
                "currency": "USD",
                "provider": "wavespeed",
                "model": "wavespeed-ai/infinitetalk",
            }
        else:
            raise ValueError(f"Unknown operation: {operation}")

View File

@@ -0,0 +1,154 @@
import base64
import io
from dataclasses import dataclass
from typing import Literal, Optional, Dict, Any
from fastapi import HTTPException
from PIL import Image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.upscale")
# Upscale strategies a caller may request; "auto" lets the service choose.
UpscaleMode = Literal["fast", "conservative", "creative", "auto"]


@dataclass
class UpscaleStudioRequest:
    """Input payload for an upscale operation."""

    image_base64: str
    mode: UpscaleMode = "auto"
    # Optional target dimensions
    target_width: Optional[int] = None
    target_height: Optional[int] = None
    # e.g., web/print/social
    preset: Optional[str] = None
    # used for conservative/creative modes
    prompt: Optional[str] = None
class UpscaleStudioService:
    """Handles image upscaling workflows."""

    def __init__(self):
        logger.info("[Upscale Studio] Service initialized")

    async def process_upscale(
        self,
        request: UpscaleStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Upscale the request image and return it as a base64 data URI.

        Args:
            request: Upscale parameters; ``mode="auto"`` is resolved by
                ``_resolve_mode``.
            user_id: When given, subscription pre-flight validation runs
                before any work is done.

        Returns:
            Dict with ``success``, the resolved ``mode``, ``image_base64``,
            the output ``width``/``height`` and an echo of the request
            settings under ``metadata``.

        Raises:
            ValueError: Missing or undecodable image, or unsupported mode.
            HTTPException: 502 when the provider response has no image.
        """
        if user_id:
            self._run_preflight(user_id)

        image_bytes = self._decode_base64(request.image_base64)
        if not image_bytes:
            raise ValueError("Primary image is required for upscaling")

        mode = self._resolve_mode(request)

        async with StabilityAIService() as stability_service:
            # f-strings (as in the sibling studio services) so the message is
            # complete regardless of which formatting style the logger uses.
            logger.info(f"[Upscale Studio] Running '{mode}' upscale for user={user_id}")

            # Only forward the target dimensions the caller actually set.
            requested = {
                "target_width": request.target_width,
                "target_height": request.target_height,
            }
            params = {key: value for key, value in requested.items() if value is not None}

            if mode == "fast":
                result = await stability_service.upscale_fast(
                    image=image_bytes,
                    **params,
                )
            elif mode == "conservative":
                prompt = request.prompt or "High fidelity upscale preserving original details"
                result = await stability_service.upscale_conservative(
                    image=image_bytes,
                    prompt=prompt,
                    **params,
                )
            elif mode == "creative":
                prompt = request.prompt or "Creative upscale with enhanced artistic details"
                result = await stability_service.upscale_creative(
                    image=image_bytes,
                    prompt=prompt,
                    **params,
                )
            else:
                raise ValueError(f"Unsupported upscale mode: {mode}")

        upscaled_bytes = self._extract_image_bytes(result)
        metadata = self._image_metadata(upscaled_bytes)

        return {
            "success": True,
            "mode": mode,
            "image_base64": self._to_base64(upscaled_bytes),
            "width": metadata["width"],
            "height": metadata["height"],
            "metadata": {
                "preset": request.preset,
                "target_width": request.target_width,
                "target_height": request.target_height,
                "prompt": request.prompt,
            },
        }

    @staticmethod
    def _run_preflight(user_id: str) -> None:
        """Run subscription pre-flight validation for ``user_id``.

        Any exception raised by the validator propagates to the caller; the
        database session is closed either way.
        """
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_upscale_operations

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Upscale Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_upscale_operations(pricing_service=pricing_service, user_id=user_id)
        finally:
            db.close()

    @staticmethod
    def _decode_base64(value: Optional[str]) -> Optional[bytes]:
        """Decode raw base64 or a ``data:`` URI; return ``None`` for empty input.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None
        try:
            if value.startswith("data:"):
                # Drop the "data:<mime>;base64" header before the comma.
                _, b64data = value.split(",", 1)
            else:
                b64data = value
            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error(f"[Upscale Studio] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _to_base64(image_bytes: bytes) -> str:
        """Encode bytes as a base64 data URI.

        NOTE(review): the MIME type is always ``image/png`` regardless of the
        actual image format — confirm the provider always returns PNG.
        """
        return f"data:image/png;base64,{base64.b64encode(image_bytes).decode('utf-8')}"

    @staticmethod
    def _image_metadata(image_bytes: bytes) -> Dict[str, int]:
        """Return the pixel ``width`` and ``height`` of the encoded image."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {"width": img.width, "height": img.height}

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Pull image bytes out of a provider response.

        Accepts raw ``bytes`` or a dict whose ``artifacts`` / ``data`` /
        ``images`` list contains an entry with a ``base64`` or ``b64_json``
        field; the first such entry wins.

        Raises:
            HTTPException: 502 when no image can be found.
        """
        if isinstance(result, bytes):
            return result
        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if isinstance(artifact, dict):
                    if artifact.get("base64"):
                        return base64.b64decode(artifact["base64"])
                    if artifact.get("b64_json"):
                        return base64.b64decode(artifact["b64_json"])
        raise HTTPException(status_code=502, detail="Unable to extract image from provider response")

    @staticmethod
    def _resolve_mode(request: UpscaleStudioRequest) -> UpscaleMode:
        """Return the explicit mode, or pick one when the request says "auto"."""
        if request.mode != "auto":
            return request.mode

        # simple heuristic: if target >= 3000px, use conservative, else fast
        if (request.target_width and request.target_width >= 3000) or (
            request.target_height and request.target_height >= 3000
        ):
            return "conservative"
        return "fast"

View File

@@ -0,0 +1,297 @@
"""WAN 2.5 service for Alibaba image-to-video generation via WaveSpeed."""
import base64
import asyncio
from typing import Any, Dict, Optional, Callable
import requests
from fastapi import HTTPException
from loguru import logger
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.wan25")
# Model identifiers (path and display name currently share one value)
WAN25_MODEL_PATH = "alibaba/wan-2.5/image-to-video"
WAN25_MODEL_NAME = "alibaba/wan-2.5/image-to-video"

# USD per second of generated video, keyed by resolution (from WaveSpeed docs)
PRICING = {
    "480p": 0.05,
    "720p": 0.10,
    "1080p": 0.15,
}

# Upload size limits in bytes
MAX_IMAGE_BYTES = 10 * 1024 ** 2  # 10MB (recommended)
MAX_AUDIO_BYTES = 15 * 1024 ** 2  # 15MB (API limit)

# Accepted audio length, in seconds
MIN_AUDIO_DURATION = 3
MAX_AUDIO_DURATION = 30
def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
"""Convert bytes to data URI."""
encoded = base64.b64encode(content_bytes).decode("utf-8")
return f"data:{mime_type};base64,{encoded}"
def _decode_base64_image(image_base64: str) -> tuple[bytes, str]:
"""Decode base64 image, handling data URIs."""
if image_base64.startswith("data:"):
# Extract mime type and base64 data
if "," not in image_base64:
raise ValueError("Invalid data URI format: missing comma separator")
header, encoded = image_base64.split(",", 1)
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
mime_type = mime_parts.strip()
if not mime_type:
mime_type = "image/png"
image_bytes = base64.b64decode(encoded)
else:
# Assume it's raw base64
image_bytes = base64.b64decode(image_base64)
mime_type = "image/png" # Default
return image_bytes, mime_type
def _decode_base64_audio(audio_base64: str) -> tuple[bytes, str]:
"""Decode base64 audio, handling data URIs."""
if audio_base64.startswith("data:"):
if "," not in audio_base64:
raise ValueError("Invalid data URI format: missing comma separator")
header, encoded = audio_base64.split(",", 1)
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "audio/mpeg"
mime_type = mime_parts.strip()
if not mime_type:
mime_type = "audio/mpeg"
audio_bytes = base64.b64decode(encoded)
else:
audio_bytes = base64.b64decode(audio_base64)
mime_type = "audio/mpeg" # Default
return audio_bytes, mime_type
class WAN25Service:
    """Service for Alibaba WAN 2.5 image-to-video generation."""

    # Nominal (width, height) reported in the result for each resolution label.
    _RESOLUTION_DIMENSIONS = {
        "480p": (854, 480),
        "720p": (1280, 720),
        "1080p": (1920, 1080),
    }
    # Clip lengths, in seconds, accepted by generate_video.
    _ALLOWED_DURATIONS = (5, 10)

    def __init__(self, client: Optional[WaveSpeedClient] = None):
        """Initialize WAN 2.5 service.

        Args:
            client: WaveSpeed client to use; a default ``WaveSpeedClient()``
                is created when omitted.
        """
        self.client = client or WaveSpeedClient()
        logger.info("[WAN 2.5] Service initialized")

    def calculate_cost(self, resolution: str, duration: int) -> float:
        """Calculate cost for video generation.

        Args:
            resolution: Output resolution (480p, 720p, 1080p)
            duration: Video duration in seconds (5 or 10)

        Returns:
            Cost in USD. Unknown resolutions are priced at the 720p rate.
        """
        cost_per_second = PRICING.get(resolution, PRICING["720p"])
        return cost_per_second * duration

    def _build_payload(
        self,
        image_base64: str,
        prompt: str,
        audio_base64: Optional[str],
        resolution: str,
        duration: int,
        negative_prompt: Optional[str],
        seed: Optional[int],
        enable_prompt_expansion: bool,
    ) -> Dict[str, Any]:
        """Validate the request and build the WaveSpeed submission payload.

        Raises:
            HTTPException: 400 for any invalid or undecodable input.
        """
        # Validate resolution
        if resolution not in PRICING:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid resolution: {resolution}. Must be one of: {list(PRICING.keys())}"
            )

        # Validate duration
        if duration not in self._ALLOWED_DURATIONS:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid duration: {duration}. Must be 5 or 10 seconds"
            )

        # Validate prompt
        if not prompt or not prompt.strip():
            raise HTTPException(
                status_code=400,
                detail="Prompt is required and cannot be empty"
            )

        # Decode image. Any decoding failure is a client error, so the broad
        # except is deliberate, but the try body is kept to the decode call.
        try:
            image_bytes, image_mime = _decode_base64_image(image_base64)
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Failed to decode image: {str(e)}"
            ) from e

        # Reject empty input before spending an API call on it.
        if not image_bytes:
            raise HTTPException(
                status_code=400,
                detail="Image data is empty"
            )

        # Validate image size
        if len(image_bytes) > MAX_IMAGE_BYTES:
            raise HTTPException(
                status_code=400,
                detail=f"Image exceeds {MAX_IMAGE_BYTES / (1024*1024):.0f}MB limit"
            )

        payload: Dict[str, Any] = {
            "image": _as_data_uri(image_bytes, image_mime),
            "prompt": prompt,
            "resolution": resolution,
            "duration": duration,
            "enable_prompt_expansion": enable_prompt_expansion,
        }

        # Add optional audio
        if audio_base64:
            # Only the decode call is guarded: the size check below raises its
            # own HTTPException, which must not be rewrapped as a decode error.
            try:
                audio_bytes, audio_mime = _decode_base64_audio(audio_base64)
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Failed to decode audio: {str(e)}"
                ) from e

            # Validate audio size
            if len(audio_bytes) > MAX_AUDIO_BYTES:
                raise HTTPException(
                    status_code=400,
                    detail=f"Audio exceeds {MAX_AUDIO_BYTES / (1024*1024):.0f}MB limit"
                )

            # Note: Audio duration validation would require audio analysis
            # For now, we rely on API to handle it (API keeps first 5s/10s if longer)
            payload["audio"] = _as_data_uri(audio_bytes, audio_mime)

        # Add optional parameters
        if negative_prompt:
            payload["negative_prompt"] = negative_prompt
        if seed is not None:
            payload["seed"] = seed

        return payload

    async def _download_video(self, video_url: str, prediction_id: Any) -> bytes:
        """Download the generated video, mapping any failure to a 502.

        The prediction id is included in the error detail so a completed
        generation is still identifiable when only the download failed.
        """
        try:
            # requests is synchronous; run it in a worker thread.
            video_response = await asyncio.to_thread(
                requests.get,
                video_url,
                timeout=180
            )
        except requests.RequestException as e:
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "Failed to download WAN 2.5 video",
                    "message": str(e),
                    "prediction_id": prediction_id,
                }
            ) from e

        if video_response.status_code != 200:
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "Failed to download WAN 2.5 video",
                    "status_code": video_response.status_code,
                    "response": video_response.text[:200],
                    "prediction_id": prediction_id,
                }
            )
        return video_response.content

    async def generate_video(
        self,
        image_base64: str,
        prompt: str,
        audio_base64: Optional[str] = None,
        resolution: str = "720p",
        duration: int = 5,
        negative_prompt: Optional[str] = None,
        seed: Optional[int] = None,
        enable_prompt_expansion: bool = True,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> Dict[str, Any]:
        """Generate video using WAN 2.5.

        Args:
            image_base64: Image in base64 or data URI format
            prompt: Text prompt describing the video
            audio_base64: Optional audio file (wav/mp3, 3-30s, ≤15MB)
            resolution: Output resolution (480p, 720p, 1080p)
            duration: Video duration in seconds (5 or 10)
            negative_prompt: Optional negative prompt
            seed: Optional random seed for reproducibility
            enable_prompt_expansion: Enable prompt optimizer
            progress_callback: Optional callable forwarded to the client's
                polling loop. Polling runs in a worker thread, so the
                callback is invoked from that thread.

        Returns:
            Dictionary with video bytes, metadata, and cost

        Raises:
            HTTPException: 400 for invalid input; 502 when the provider
                returns no usable output or the download fails; errors from
                the client are propagated, with ``prediction_id`` and
                ``resume_available`` added to dict details on polling
                failures.
        """
        payload = self._build_payload(
            image_base64,
            prompt,
            audio_base64,
            resolution,
            duration,
            negative_prompt,
            seed,
            enable_prompt_expansion,
        )

        # Submit to WaveSpeed
        logger.info(
            f"[WAN 2.5] Submitting video generation request: resolution={resolution}, duration={duration}s"
        )
        try:
            # The client call is synchronous; run it off the event loop so
            # other requests are not stalled while it waits on the network.
            prediction_id = await asyncio.to_thread(
                self.client.submit_image_to_video,
                WAN25_MODEL_PATH,
                payload,
                timeout=60
            )
        except HTTPException as e:
            logger.error(f"[WAN 2.5] Submission failed: {e.detail}")
            raise

        # Poll for completion
        logger.info(f"[WAN 2.5] Polling for completion: prediction_id={prediction_id}")
        try:
            # WAN 2.5 typically takes 1-2 minutes; polling synchronously would
            # freeze the event loop for that long, so it runs in a thread.
            result = await asyncio.to_thread(
                self.client.poll_until_complete,
                prediction_id,
                timeout_seconds=180,  # 3 minutes max
                interval_seconds=2.0,
                progress_callback=progress_callback,
            )
        except HTTPException as e:
            detail = e.detail or {}
            if isinstance(detail, dict):
                # Let the caller resume the already-submitted prediction.
                detail.setdefault("prediction_id", prediction_id)
                detail.setdefault("resume_available", True)
            raise HTTPException(status_code=e.status_code, detail=detail) from e

        # Extract video URL
        outputs = result.get("outputs") or []
        if not outputs:
            raise HTTPException(
                status_code=502,
                detail="WAN 2.5 completed but returned no outputs"
            )

        video_url = outputs[0]
        if not isinstance(video_url, str) or not video_url.startswith("http"):
            raise HTTPException(
                status_code=502,
                detail=f"Invalid video URL format: {video_url}"
            )

        # Download video
        logger.info(f"[WAN 2.5] Downloading video from: {video_url}")
        video_bytes = await self._download_video(video_url, prediction_id)
        metadata = result.get("metadata") or {}

        # Calculate cost
        cost = self.calculate_cost(resolution, duration)

        # Get video dimensions from resolution
        width, height = self._RESOLUTION_DIMENSIONS.get(resolution, (1280, 720))

        logger.info(
            f"[WAN 2.5] ✅ Generated video: {len(video_bytes)} bytes, "
            f"resolution={resolution}, duration={duration}s, cost=${cost:.2f}"
        )

        return {
            "video_bytes": video_bytes,
            "prompt": prompt,
            "duration": float(duration),
            "model_name": WAN25_MODEL_NAME,
            "cost": cost,
            "provider": "wavespeed",
            "source_video_url": video_url,
            "prediction_id": prediction_id,
            "resolution": resolution,
            "width": width,
            "height": height,
            "metadata": metadata,
        }