AI Image Studio Phase 1

This commit is contained in:
ajaysi
2025-11-20 09:06:00 +05:30
parent e96525347b
commit eede21ad42
58 changed files with 12951 additions and 8 deletions

View File

@@ -0,0 +1,20 @@
"""Image Studio service package for centralized image operations."""
from .studio_manager import ImageStudioManager
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .templates import PlatformTemplates, TemplateManager
# Names exported by ``from <package> import *``; each one is imported from a
# submodule of this package by the import statements above.
__all__ = [
"ImageStudioManager",
"CreateStudioService",
"CreateStudioRequest",
"EditStudioService",
"EditStudioRequest",
"UpscaleStudioService",
"UpscaleStudioRequest",
"PlatformTemplates",
"TemplateManager",
]

View File

@@ -0,0 +1,458 @@
"""Create Studio service for AI-powered image generation."""
import os
from typing import Optional, Dict, Any, List, Literal
from dataclasses import dataclass
from services.llm_providers.image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
HuggingFaceImageProvider,
GeminiImageProvider,
StabilityImageProvider,
WaveSpeedImageProvider,
)
from .templates import TemplateManager, ImageTemplate, Platform, TemplateCategory
from utils.logger_utils import get_service_logger
# Module-wide logger, obtained from the project's logging helper under the
# "image_studio.create" service name.
logger = get_service_logger("image_studio.create")
@dataclass
class CreateStudioRequest:
    """Parameters describing one Create Studio image-generation job.

    Only ``prompt`` is mandatory; every other field has a default and is
    filled in from a template, from the quality tier or from built-in
    fallbacks while the request is processed.
    """

    # --- what to draw ---------------------------------------------------
    prompt: str
    # Identifier looked up in the template manager; None means "no template".
    template_id: Optional[str] = None

    # --- which backend --------------------------------------------------
    # "auto", "stability", "wavespeed", "huggingface", "gemini"
    provider: Optional[str] = None
    model: Optional[str] = None

    # --- output size ----------------------------------------------------
    width: Optional[int] = None
    height: Optional[int] = None
    # e.g., "1:1", "16:9"; consulted when width and height are not both set.
    aspect_ratio: Optional[str] = None

    # --- look and feel --------------------------------------------------
    style_preset: Optional[str] = None
    # "standard" doubles as "not chosen": a template replaces it with its own
    # quality when one is applied.
    quality: Literal["draft", "standard", "premium"] = "standard"
    negative_prompt: Optional[str] = None

    # --- sampler settings forwarded to the provider options ---------------
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    # Optional base seed; the generation loop offsets it per variation.
    seed: Optional[int] = None

    # --- batch / prompt handling ------------------------------------------
    # One provider call is made per variation.
    num_variations: int = 1
    # When True the prompt is extended with style descriptors before sending.
    enhance_prompt: bool = True

    # --- persona ----------------------------------------------------------
    # NOTE(review): not read by the generation code visible in this module.
    use_persona: bool = False
    persona_id: Optional[str] = None
class CreateStudioService:
    """Service for Create Studio image generation operations.

    For each request the service optionally applies a template, resolves the
    output size, optionally enriches the prompt, chooses a provider/model and
    then asks that provider for one image per requested variation.
    """

    # Provider-to-model mapping for smart recommendations
    # (short key -> model identifier).
    PROVIDER_MODELS = {
        "stability": {
            "ultra": "stability-ultra",  # Best quality, 8 credits
            "core": "stability-core",  # Fast & affordable, 3 credits
            "sd3": "sd3.5-large",  # SD3.5 model
        },
        "wavespeed": {
            "ideogram-v3-turbo": "ideogram-v3-turbo",  # Photorealistic, text rendering
            "qwen-image": "qwen-image",  # Fast generation
        },
        "huggingface": {
            "flux": "black-forest-labs/FLUX.1-Krea-dev",
        },
        "gemini": {
            "imagen": "imagen-3.0-generate-001",
        },
    }

    # Quality-to-provider mapping. Entries are "provider" or "provider:model";
    # only the first entry of each list is used by the selection logic.
    QUALITY_PROVIDERS = {
        "draft": ["huggingface", "wavespeed:qwen-image"],  # Fast, low cost
        "standard": ["stability:core", "wavespeed:ideogram-v3-turbo"],  # Balanced
        "premium": ["wavespeed:ideogram-v3-turbo", "stability:ultra"],  # Best quality
    }

    # Template recommendations that name a model family rather than one of the
    # providers known to _get_provider_instance:
    # recommendation -> (provider, default model).
    _TEMPLATE_PROVIDER_ALIASES = {
        "ideogram": ("wavespeed", "ideogram-v3-turbo"),
        "qwen": ("wavespeed", "qwen-image"),
    }

    def __init__(self):
        """Initialize Create Studio service."""
        self.template_manager = TemplateManager()
        logger.info("[Create Studio] Initialized with template manager")

    def _get_provider_instance(self, provider_name: str, api_key: Optional[str] = None):
        """Get provider instance by name.

        Args:
            provider_name: Name of the provider
            api_key: Optional API key (uses env vars if not provided)

        Returns:
            Provider instance

        Raises:
            ValueError: If provider is not supported
        """
        if provider_name == "stability":
            return StabilityImageProvider(api_key=api_key or os.getenv("STABILITY_API_KEY"))
        elif provider_name == "wavespeed":
            return WaveSpeedImageProvider(api_key=api_key or os.getenv("WAVESPEED_API_KEY"))
        elif provider_name == "huggingface":
            # This provider takes its credential as ``api_token``.
            return HuggingFaceImageProvider(api_token=api_key or os.getenv("HF_API_KEY"))
        elif provider_name == "gemini":
            return GeminiImageProvider(api_key=api_key or os.getenv("GEMINI_API_KEY"))
        else:
            raise ValueError(f"Unsupported provider: {provider_name}")

    def _select_provider_and_model(
        self,
        request: CreateStudioRequest,
        template: Optional[ImageTemplate] = None
    ) -> tuple[str, Optional[str]]:
        """Smart provider and model selection.

        Precedence: an explicit (non-"auto") provider on the request, then the
        template's recommendation, then the first entry configured for the
        request's quality tier.

        Args:
            request: Create studio request
            template: Optional template with recommendations

        Returns:
            Tuple of (provider_name, model_name); model_name may be None.
        """
        # Explicit provider selection
        if request.provider and request.provider != "auto":
            logger.info("[Provider Selection] User specified: %s (model: %s)",
                        request.provider, request.model)
            return request.provider, request.model

        # Template recommendation
        if template and template.recommended_provider:
            recommended = template.recommended_provider
            logger.info("[Provider Selection] Template recommends: %s", recommended)

            # "ideogram" / "qwen" are model families, not providers. Resolve
            # them even when the caller supplied a model, otherwise the alias
            # itself would be handed to _get_provider_instance and rejected.
            alias = self._TEMPLATE_PROVIDER_ALIASES.get(recommended)
            if alias is not None:
                provider, default_model = alias
                return provider, request.model or default_model

            if recommended == "stability" and not request.model:
                # Only the premium tier gets the ultra model.
                if request.quality == "premium":
                    return "stability", "stability-ultra"
                return "stability", "stability-core"

            return recommended, request.model

        # Quality-based selection
        # NOTE(review): this path yields short keys such as "core", while the
        # template path yields "stability-core" - confirm which form the
        # providers expect.
        quality_options = self.QUALITY_PROVIDERS.get(
            request.quality, self.QUALITY_PROVIDERS["standard"]
        )
        selected = quality_options[0]  # Pick first option
        if ":" in selected:
            provider, model = selected.split(":", 1)
        else:
            provider, model = selected, None
        logger.info("[Provider Selection] Quality-based (%s): %s (model: %s)",
                    request.quality, provider, model)
        return provider, model

    def _enhance_prompt(self, prompt: str, style_preset: Optional[str] = None) -> str:
        """Enhance prompt with style and quality descriptors.

        Args:
            prompt: Original prompt
            style_preset: Style preset to apply; unknown presets leave the
                prompt unchanged.

        Returns:
            Enhanced prompt
        """
        # Suffix appended for each known style preset.
        style_enhancements = {
            "photographic": ", professional photography, high quality, detailed, sharp focus, natural lighting",
            "digital-art": ", digital art, vibrant colors, detailed, high quality, artstation trending",
            "cinematic": ", cinematic lighting, dramatic, film grain, high quality, professional",
            "3d-model": ", 3D render, octane render, unreal engine, high quality, detailed",
            "anime": ", anime style, vibrant colors, detailed, high quality",
            "line-art": ", clean line art, detailed linework, high contrast, professional",
        }
        enhanced = prompt
        if style_preset and style_preset in style_enhancements:
            enhanced += style_enhancements[style_preset]
        # Log only the first 100 characters of each version.
        logger.info("[Prompt Enhancement] Original: %s", prompt[:100])
        logger.info("[Prompt Enhancement] Enhanced: %s", enhanced[:100])
        return enhanced

    def _apply_template(self, request: CreateStudioRequest, template: ImageTemplate) -> CreateStudioRequest:
        """Apply template settings to request.

        The request object is updated in place and also returned.

        Args:
            request: Original request
            template: Template to apply

        Returns:
            Modified request
        """
        # Apply template dimensions if not specified
        if not request.width and not request.height:
            request.width = template.aspect_ratio.width
            request.height = template.aspect_ratio.height
        # Apply template style if not specified
        if not request.style_preset:
            request.style_preset = template.style_preset
        # "standard" is the dataclass default, so it is treated as "not chosen".
        if request.quality == "standard":
            request.quality = template.quality
        # %s rather than %d: when the caller supplied only one of width/height
        # the other is still None here and cannot be formatted as an integer.
        logger.info("[Template Applied] %s -> %sx%s, style=%s, quality=%s",
                    template.name, request.width, request.height,
                    request.style_preset, request.quality)
        return request

    def _calculate_dimensions(
        self,
        width: Optional[int],
        height: Optional[int],
        aspect_ratio: Optional[str]
    ) -> tuple[int, int]:
        """Calculate image dimensions from width/height or aspect ratio.

        Ratio terms may be integers or decimals (e.g. "16:9", "1.91:1"). A
        ratio that cannot be parsed, or whose terms are not positive finite
        numbers, is logged and ignored.

        Args:
            width: Explicit width
            height: Explicit height
            aspect_ratio: Aspect ratio string (e.g., "16:9")

        Returns:
            Tuple of (width, height). Falls back to (1024, 1024) when the
            inputs do not determine a size.
        """
        # Both dimensions specified
        if width and height:
            return width, height

        # Aspect ratio specified
        if aspect_ratio:
            try:
                w_ratio, h_ratio = map(float, aspect_ratio.split(":"))
                # Rejects zero, negative, NaN and infinite terms so the
                # divisions below are always well defined.
                if not (0 < w_ratio < float("inf") and 0 < h_ratio < float("inf")):
                    raise ValueError("aspect ratio terms must be positive and finite")
            except ValueError:
                logger.warning("[Dimensions] Invalid aspect ratio: %s", aspect_ratio)
            else:
                # Derive the missing side from the one that was supplied.
                if width:
                    return width, int(width * h_ratio / w_ratio)
                if height:
                    return int(height * w_ratio / h_ratio), height
                # Neither side given: make the longer side 1920 pixels.
                if w_ratio >= h_ratio:
                    # Landscape or square
                    return 1920, int(1920 * h_ratio / w_ratio)
                # Portrait
                return int(1920 * w_ratio / h_ratio), 1920

        # Default dimensions
        return 1024, 1024

    def _run_preflight_validation(self, user_id: str, num_images: int) -> None:
        """Check subscription and usage limits before generating.

        Args:
            user_id: User whose limits are validated
            num_images: Number of images about to be generated

        Raises:
            HTTPException: Re-raised unchanged when validation rejects the request.
        """
        # Imported at call time so they are only loaded when validation runs.
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_generation_operations
        from fastapi import HTTPException

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Create Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_generation_operations(
                pricing_service=pricing_service,
                user_id=user_id,
                num_images=num_images,
            )
            logger.info("[Create Studio] ✅ Pre-flight validation passed - proceeding with generation")
        except HTTPException:
            logger.error("[Create Studio] ❌ Pre-flight validation failed - blocking generation")
            raise
        finally:
            # The session is closed whether or not validation succeeded.
            db.close()

    async def generate(
        self,
        request: CreateStudioRequest,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Generate image(s) using Create Studio.

        A failure of an individual variation is recorded in ``results`` as an
        ``{"error", "variation"}`` entry instead of aborting the batch, and
        ``success`` is True whenever the method returns; check
        ``total_generated`` / ``total_failed`` for the actual outcome.

        Args:
            request: Create studio request
            user_id: User ID for validation and tracking; validation is
                skipped (with a warning) when it is not given.

        Returns:
            Dictionary with generation results

        Raises:
            ValueError: If the requested template does not exist
            RuntimeError: If the provider cannot be initialized
        """
        logger.info("[Create Studio] Starting generation: prompt=%s, template=%s",
                    request.prompt[:100], request.template_id)

        # Pre-flight validation: Check subscription and usage limits
        if user_id:
            self._run_preflight_validation(user_id, request.num_variations)
        else:
            logger.warning("[Create Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        # Load template if specified
        template = None
        if request.template_id:
            template = self.template_manager.get_by_id(request.template_id)
            if not template:
                raise ValueError(f"Template not found: {request.template_id}")
            # Apply template settings
            request = self._apply_template(request, template)

        # Calculate dimensions
        width, height = self._calculate_dimensions(
            request.width, request.height, request.aspect_ratio
        )

        # Enhance prompt if requested
        prompt = request.prompt
        if request.enhance_prompt:
            prompt = self._enhance_prompt(prompt, request.style_preset)

        # Select provider and model
        provider_name, model = self._select_provider_and_model(request, template)

        # Get provider instance
        try:
            provider = self._get_provider_instance(provider_name)
        except Exception as e:
            logger.error("[Create Studio] ❌ Failed to initialize provider %s: %s",
                         provider_name, str(e))
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Provider initialization failed: {str(e)}") from e

        # Generate images
        results = []
        for i in range(request.num_variations):
            logger.info("[Create Studio] Generating variation %d/%d",
                        i + 1, request.num_variations)
            try:
                # Prepare options. The seed test is ``is not None`` so that a
                # seed of 0 is honoured instead of being treated as "unset".
                options = ImageGenerationOptions(
                    prompt=prompt,
                    negative_prompt=request.negative_prompt,
                    width=width,
                    height=height,
                    guidance_scale=request.guidance_scale,
                    steps=request.steps,
                    seed=request.seed + i if request.seed is not None else None,
                    model=model,
                    extra={"style_preset": request.style_preset} if request.style_preset else {}
                )
                # Generate image
                result: ImageGenerationResult = provider.generate(options)
                results.append({
                    "image_bytes": result.image_bytes,
                    "width": result.width,
                    "height": result.height,
                    "provider": result.provider,
                    "model": result.model,
                    "seed": result.seed,
                    "metadata": result.metadata,
                    "variation": i + 1,
                })
                logger.info("[Create Studio] ✅ Variation %d generated successfully", i + 1)
            except Exception as e:
                # Best effort: record the failure and continue with the rest.
                logger.error("[Create Studio] ❌ Failed to generate variation %d: %s",
                             i + 1, str(e), exc_info=True)
                results.append({
                    "error": str(e),
                    "variation": i + 1,
                })

        # Return results
        return {
            "success": True,
            "request": {
                "prompt": request.prompt,
                "enhanced_prompt": prompt if request.enhance_prompt else None,
                "template_id": request.template_id,
                "template_name": template.name if template else None,
                "provider": provider_name,
                "model": model,
                "dimensions": f"{width}x{height}",
                "quality": request.quality,
                "num_variations": request.num_variations,
            },
            "results": results,
            "total_generated": sum(1 for r in results if "image_bytes" in r),
            "total_failed": sum(1 for r in results if "error" in r),
        }

    def get_templates(
        self,
        platform: Optional[Platform] = None,
        category: Optional[TemplateCategory] = None
    ) -> List[ImageTemplate]:
        """Get available templates.

        When both filters are given only ``platform`` is applied.

        Args:
            platform: Filter by platform
            category: Filter by category

        Returns:
            List of templates
        """
        if platform:
            return self.template_manager.get_by_platform(platform)
        elif category:
            return self.template_manager.get_by_category(category)
        else:
            return self.template_manager.get_all_templates()

    def search_templates(self, query: str) -> List[ImageTemplate]:
        """Search templates by query.

        Args:
            query: Search query

        Returns:
            List of matching templates
        """
        return self.template_manager.search(query)

    def recommend_templates(
        self,
        use_case: str,
        platform: Optional[Platform] = None
    ) -> List[ImageTemplate]:
        """Recommend templates based on use case.

        Args:
            use_case: Description of use case
            platform: Optional platform filter

        Returns:
            List of recommended templates
        """
        return self.template_manager.recommend_for_use_case(use_case, platform)

View File

@@ -0,0 +1,458 @@
"""Edit Studio service for AI-powered image editing and transformations."""
from __future__ import annotations
import asyncio
import base64
import io
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from PIL import Image
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
# Module-wide logger, obtained from the project's logging helper under the
# "image_studio.edit" service name.
logger = get_service_logger("image_studio.edit")
# Closed set of operation identifiers accepted by Edit Studio. Each value is
# also a key of EditStudioService.SUPPORTED_OPERATIONS.
EditOperationType = Literal[
"remove_background",
"inpaint",
"outpaint",
"search_replace",
"search_recolor",
"relight",
"general_edit",
]
@dataclass
class EditStudioRequest:
    """Normalized request payload for Edit Studio operations.

    Images travel as base64 strings (plain or ``data:`` URLs). Which of the
    optional fields are consulted depends on ``operation``.
    """

    # --- required ---------------------------------------------------------
    image_base64: str
    operation: EditOperationType

    # --- text guidance ----------------------------------------------------
    prompt: Optional[str] = None
    negative_prompt: Optional[str] = None

    # --- region / target selection ----------------------------------------
    mask_base64: Optional[str] = None  # decoded and passed to inpainting
    search_prompt: Optional[str] = None  # required by "search_replace"
    select_prompt: Optional[str] = None  # required by "search_recolor"

    # --- reference images for "relight" (at least one is required) ---------
    background_image_base64: Optional[str] = None
    lighting_image_base64: Optional[str] = None

    # --- canvas growth for "outpaint"; None is sent as 0 -------------------
    expand_left: Optional[int] = None
    expand_right: Optional[int] = None
    expand_up: Optional[int] = None
    expand_down: Optional[int] = None

    # --- backend and sampler settings forwarded to "general_edit" ----------
    provider: Optional[str] = None
    model: Optional[str] = None
    style_preset: Optional[str] = None
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None

    # --- output -----------------------------------------------------------
    # Also used as the image subtype of the returned data URL.
    output_format: str = "png"
    # Free-form extras; "grow_mask" is read for inpainting.
    options: Dict[str, Any] = field(default_factory=dict)
# Field flags reported for every operation, in the order they are exposed.
_EDIT_FIELD_NAMES = (
    "prompt",
    "mask",
    "negative_prompt",
    "search_prompt",
    "select_prompt",
    "background",
    "lighting",
    "expansion",
)


def _operation_spec(label, description, provider, is_async=False, fields=()):
    """Build one SUPPORTED_OPERATIONS entry; names in ``fields`` are flagged True."""
    return {
        "label": label,
        "description": description,
        "provider": provider,
        "async": is_async,
        "fields": {name: name in fields for name in _EDIT_FIELD_NAMES},
    }


class EditStudioService:
    """Service layer orchestrating Edit Studio operations.

    Decodes the incoming base64 images, routes the request to the Stability AI
    service or to the Hugging Face edit helper according to the operation's
    configured provider, and returns the edited image as a data URL together
    with its dimensions.
    """

    # Catalogue of operations: UI label/description, backing provider, whether
    # the provider call is asynchronous, and a flag per request field.
    SUPPORTED_OPERATIONS: Dict[EditOperationType, Dict[str, Any]] = {
        "remove_background": _operation_spec(
            "Remove Background",
            "Isolate the main subject and remove the background.",
            "stability",
        ),
        "inpaint": _operation_spec(
            "Inpaint & Fix",
            "Edit specific regions using prompts and optional masks.",
            "stability",
            fields=("prompt", "mask", "negative_prompt"),
        ),
        "outpaint": _operation_spec(
            "Outpaint",
            "Extend the canvas in any direction with smart fill.",
            "stability",
            fields=("negative_prompt", "expansion"),
        ),
        "search_replace": _operation_spec(
            "Search & Replace",
            "Locate objects via search prompt and replace them.",
            "stability",
            fields=("prompt", "search_prompt"),
        ),
        "search_recolor": _operation_spec(
            "Search & Recolor",
            "Select elements via prompt and recolor them.",
            "stability",
            fields=("prompt", "select_prompt"),
        ),
        "relight": _operation_spec(
            "Replace Background & Relight",
            "Swap backgrounds and relight using reference images.",
            "stability",
            is_async=True,
            fields=("background", "lighting"),
        ),
        "general_edit": _operation_spec(
            "Prompt-based Edit",
            "Free-form editing powered by Hugging Face image-to-image models.",
            "huggingface",
            fields=("prompt", "negative_prompt"),
        ),
    }

    def __init__(self):
        """Initialize the edit service (only logs; no state is held)."""
        logger.info("[Edit Studio] Initialized edit service")

    @staticmethod
    def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
        """Decode a base64 (or data URL) string to bytes.

        Args:
            value: Base64 text, optionally prefixed as ``data:...;base64,``.

        Returns:
            Decoded bytes, or None when ``value`` is empty or None.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None
        try:
            # Handle data URLs (data:image/png;base64,...)
            if value.startswith("data:"):
                _, b64data = value.split(",", 1)
            else:
                b64data = value
            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error(f"[Edit Studio] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
        """Extract width/height metadata from image bytes."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {
                "width": img.width,
                "height": img.height,
            }

    @staticmethod
    def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
        """Convert raw bytes to base64 data URL.

        ``output_format`` is inserted verbatim as the ``image/<format>`` subtype.
        """
        b64 = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/{output_format};base64,{b64}"

    def list_operations(self) -> Dict[str, Dict[str, Any]]:
        """Expose supported operations for UI rendering.

        Returns the class-level catalogue itself, not a copy.
        """
        return self.SUPPORTED_OPERATIONS

    def _run_preflight_validation(self, user_id: str) -> None:
        """Check subscription and usage limits before editing.

        Raises:
            HTTPException: Re-raised unchanged when validation rejects the request.
        """
        # Imported at call time so they are only loaded when validation runs.
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_editing_operations
        from fastapi import HTTPException

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Edit Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_editing_operations(
                pricing_service=pricing_service,
                user_id=user_id,
            )
            logger.info("[Edit Studio] ✅ Pre-flight validation passed")
        except HTTPException:
            logger.error("[Edit Studio] ❌ Pre-flight validation failed")
            raise
        finally:
            # The session is closed whether or not validation succeeded.
            db.close()

    async def process_edit(
        self,
        request: EditStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Process edit request and return normalized response.

        Args:
            request: Edit request payload
            user_id: User ID for pre-flight validation; validation is skipped
                (with a warning) when it is not given.

        Returns:
            Dictionary with ``success``, ``operation``, ``provider``,
            ``image_base64`` (data URL), ``width``, ``height`` and ``metadata``.

        Raises:
            ValueError: If the primary image is missing or undecodable, the
                operation is unknown, or operation-specific inputs are missing.
        """
        if user_id:
            self._run_preflight_validation(user_id)
        else:
            logger.warning("[Edit Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        image_bytes = self._decode_base64_image(request.image_base64)
        if not image_bytes:
            raise ValueError("Primary image payload is required")
        mask_bytes = self._decode_base64_image(request.mask_base64)
        background_bytes = self._decode_base64_image(request.background_image_base64)
        lighting_bytes = self._decode_base64_image(request.lighting_image_base64)

        operation = request.operation
        logger.info("[Edit Studio] Processing operation='%s' for user=%s", operation, user_id)
        if operation not in self.SUPPORTED_OPERATIONS:
            raise ValueError(f"Unsupported edit operation: {operation}")

        # Route on the provider recorded in the catalogue so the dispatch
        # cannot drift from SUPPORTED_OPERATIONS.
        provider = self.SUPPORTED_OPERATIONS[operation]["provider"]
        if provider == "stability":
            image_bytes = await self._handle_stability_edit(
                operation=operation,
                request=request,
                image_bytes=image_bytes,
                mask_bytes=mask_bytes,
                background_bytes=background_bytes,
                lighting_bytes=lighting_bytes,
            )
        else:
            image_bytes = await self._handle_general_edit(
                request=request,
                image_bytes=image_bytes,
                mask_bytes=mask_bytes,
                user_id=user_id,
            )

        # Dimensions are read from the edited image, not from the input.
        metadata = self._image_bytes_to_metadata(image_bytes)
        metadata.update(
            {
                "operation": operation,
                "style_preset": request.style_preset,
                "provider": provider,
            }
        )
        response = {
            "success": True,
            "operation": operation,
            "provider": metadata["provider"],
            "image_base64": self._bytes_to_base64(image_bytes, request.output_format),
            "width": metadata["width"],
            "height": metadata["height"],
            "metadata": metadata,
        }
        logger.info("[Edit Studio] ✅ Operation '%s' completed", operation)
        return response

    async def _handle_stability_edit(
        self,
        operation: EditOperationType,
        request: EditStudioRequest,
        image_bytes: bytes,
        mask_bytes: Optional[bytes],
        background_bytes: Optional[bytes],
        lighting_bytes: Optional[bytes],
    ) -> bytes:
        """Execute Stability AI edit workflows.

        Args:
            operation: Which Stability-backed operation to run
            request: Original request (prompts, format, expansion, options)
            image_bytes: Decoded primary image
            mask_bytes: Decoded mask, used by inpainting
            background_bytes: Decoded background reference, used by relight
            lighting_bytes: Decoded lighting reference, used by relight

        Returns:
            Raw bytes of the edited image.

        Raises:
            ValueError: If inputs required by the operation are missing or the
                operation is not a Stability operation.
            RuntimeError: If no image can be extracted from the response, or
                an asynchronous generation fails or times out.
        """
        stability_service = StabilityAIService()
        async with stability_service:
            if operation == "remove_background":
                result = await stability_service.remove_background(
                    image=image_bytes,
                    output_format=request.output_format,
                )
            elif operation == "inpaint":
                if not request.prompt:
                    raise ValueError("Prompt is required for inpainting")
                result = await stability_service.inpaint(
                    image=image_bytes,
                    prompt=request.prompt,
                    mask=mask_bytes,
                    negative_prompt=request.negative_prompt,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                    grow_mask=request.options.get("grow_mask", 5),
                )
            elif operation == "outpaint":
                result = await stability_service.outpaint(
                    image=image_bytes,
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    output_format=request.output_format,
                    # Unset directions are sent as 0.
                    left=request.expand_left or 0,
                    right=request.expand_right or 0,
                    up=request.expand_up or 0,
                    down=request.expand_down or 0,
                    style_preset=request.style_preset,
                )
            elif operation == "search_replace":
                if not (request.prompt and request.search_prompt):
                    raise ValueError("Both prompt and search_prompt are required for search & replace")
                result = await stability_service.search_and_replace(
                    image=image_bytes,
                    prompt=request.prompt,
                    search_prompt=request.search_prompt,
                    output_format=request.output_format,
                )
            elif operation == "search_recolor":
                if not (request.prompt and request.select_prompt):
                    raise ValueError("Both prompt and select_prompt are required for search & recolor")
                result = await stability_service.search_and_recolor(
                    image=image_bytes,
                    prompt=request.prompt,
                    select_prompt=request.select_prompt,
                    output_format=request.output_format,
                )
            elif operation == "relight":
                if not background_bytes and not lighting_bytes:
                    raise ValueError("At least one reference (background or lighting) is required for relight")
                result = await stability_service.replace_background_and_relight(
                    subject_image=image_bytes,
                    background_reference=background_bytes,
                    light_reference=lighting_bytes,
                    output_format=request.output_format,
                )
                # A dict carrying an "id" is treated as a pending job that
                # must be polled for its result.
                if isinstance(result, dict) and result.get("id"):
                    result = await self._poll_stability_result(
                        stability_service,
                        generation_id=result["id"],
                        output_format=request.output_format,
                    )
            else:
                raise ValueError(f"Unsupported Stability operation: {operation}")
        return self._extract_image_bytes(result)

    async def _handle_general_edit(
        self,
        request: EditStudioRequest,
        image_bytes: bytes,
        mask_bytes: Optional[bytes],
        user_id: Optional[str],
    ) -> bytes:
        """Execute Hugging Face powered general editing (synchronous API).

        ``mask_bytes`` is accepted for signature symmetry but not used here.

        Raises:
            ValueError: If the request has no prompt.
        """
        if not request.prompt:
            raise ValueError("Prompt is required for general edits")
        options = {
            "provider": request.provider or "huggingface",
            "model": request.model,
            "guidance_scale": request.guidance_scale,
            "steps": request.steps,
            "seed": request.seed,
        }
        # huggingface edit is synchronous - run in thread
        result = await asyncio.to_thread(
            huggingface_edit_image,
            image_bytes,
            request.prompt,
            options,
            user_id,
        )
        return result.image_bytes

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Normalize Stability responses into raw image bytes.

        Accepts raw bytes, or a dict whose ``artifacts`` / ``data`` / ``images``
        list contains an item with a ``base64`` or ``b64_json`` payload.

        Raises:
            RuntimeError: If no image payload is found.
        """
        if isinstance(result, bytes):
            return result
        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if isinstance(artifact, dict):
                    if artifact.get("base64"):
                        return base64.b64decode(artifact["base64"])
                    if artifact.get("b64_json"):
                        return base64.b64decode(artifact["b64_json"])
        raise RuntimeError("Unable to extract image bytes from provider response")

    async def _poll_stability_result(
        self,
        stability_service: StabilityAIService,
        generation_id: str,
        output_format: str,
        timeout_seconds: int = 240,
        interval_seconds: float = 2.0,
    ) -> bytes:
        """Poll Stability async endpoint until result is ready.

        Args:
            stability_service: Open Stability service used for polling
            generation_id: Identifier of the pending generation
            output_format: Requested output format (not used while polling)
            timeout_seconds: Maximum wall-clock time to keep polling
            interval_seconds: Pause between two polls

        Returns:
            Raw bytes of the finished image.

        Raises:
            RuntimeError: If the generation reports failure or the timeout
                elapses first.
        """
        # Measure against the event loop's clock so that time spent inside
        # each poll request counts toward the timeout, and a zero interval
        # cannot spin forever.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout_seconds
        while loop.time() < deadline:
            result = await stability_service.get_generation_result(
                generation_id=generation_id,
                accept_type="*/*",
            )
            if isinstance(result, bytes):
                return result
            if isinstance(result, dict):
                state = (result.get("state") or result.get("status") or "").lower()
                if state in {"succeeded", "success", "ready", "completed"}:
                    return self._extract_image_bytes(result)
                if state in {"failed", "error"}:
                    raise RuntimeError(f"Stability generation failed: {result}")
            await asyncio.sleep(interval_seconds)
        raise RuntimeError("Timed out waiting for Stability generation result")

View File

@@ -0,0 +1,304 @@
"""Image Studio Manager - Main orchestration service for all image operations."""
from typing import Optional, Dict, Any, List
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .templates import Platform, TemplateCategory, ImageTemplate
from utils.logger_utils import get_service_logger
# Module-wide logger, obtained from the project's logging helper under the
# "image_studio.manager" service name.
logger = get_service_logger("image_studio.manager")
class ImageStudioManager:
"""Main manager for Image Studio operations."""
def __init__(self):
    """Create the per-studio services this manager delegates to."""
    # One service per studio, built in create -> edit -> upscale order.
    self.create_service = CreateStudioService()
    self.edit_service = EditStudioService()
    self.upscale_service = UpscaleStudioService()

    logger.info("[Image Studio Manager] Initialized successfully")
# ====================
# CREATE STUDIO
# ====================
async def create_image(
    self,
    request: CreateStudioRequest,
    user_id: Optional[str] = None
) -> Dict[str, Any]:
    """Generate image(s) by delegating to the Create Studio service.

    Args:
        request: Create studio request
        user_id: User ID, logged here and forwarded to the service

    Returns:
        Dictionary with generation results, as produced by the service
    """
    logger.info("[Image Studio] Create image request from user: %s", user_id)
    generation = await self.create_service.generate(request, user_id=user_id)
    return generation
# ====================
# EDIT STUDIO
# ====================
async def edit_image(
    self,
    request: EditStudioRequest,
    user_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Run an Edit Studio operation by delegating to the edit service.

    Args:
        request: Edit studio request
        user_id: User ID, logged here and forwarded to the service

    Returns:
        The response dictionary produced by the edit service
    """
    logger.info("[Image Studio] Edit image request from user: %s", user_id)
    edited = await self.edit_service.process_edit(request, user_id=user_id)
    return edited
def get_edit_operations(self) -> Dict[str, Any]:
    """Return the edit service's operation catalogue for UI rendering."""
    operations = self.edit_service.list_operations()
    return operations
# ====================
# UPSCALE STUDIO
# ====================
async def upscale_image(
    self,
    request: UpscaleStudioRequest,
    user_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Run an Upscale Studio operation by delegating to the upscale service.

    Args:
        request: Upscale studio request
        user_id: User ID, logged here and forwarded to the service

    Returns:
        The response dictionary produced by the upscale service
    """
    logger.info("[Image Studio] Upscale request from user: %s", user_id)
    upscaled = await self.upscale_service.process_upscale(request, user_id=user_id)
    return upscaled
def get_templates(
    self,
    platform: Optional[Platform] = None,
    category: Optional[TemplateCategory] = None
) -> List[ImageTemplate]:
    """List templates via the Create Studio service.

    Args:
        platform: Filter by platform
        category: Filter by category

    Returns:
        List of templates, as returned by the create service
    """
    templates = self.create_service.get_templates(
        platform=platform,
        category=category,
    )
    return templates
def search_templates(self, query: str) -> List[ImageTemplate]:
    """Search templates via the Create Studio service.

    Args:
        query: Search query

    Returns:
        List of matching templates
    """
    matches = self.create_service.search_templates(query)
    return matches
def recommend_templates(
    self,
    use_case: str,
    platform: Optional[Platform] = None
) -> List[ImageTemplate]:
    """Recommend templates for a use case via the Create Studio service.

    Args:
        use_case: Use case description
        platform: Optional platform filter

    Returns:
        List of recommended templates
    """
    recommendations = self.create_service.recommend_templates(use_case, platform)
    return recommendations
def get_providers(self) -> Dict[str, Any]:
    """Describe the image providers and what each one offers.

    Returns:
        Dictionary keyed by provider id; each entry carries a display name,
        model list, capability tags, maximum resolution and a cost hint.
        The data is a static table built on every call.
    """
    free_tier = "Free tier available"

    catalog: Dict[str, Any] = {}
    catalog["stability"] = {
        "name": "Stability AI",
        "models": ["ultra", "core", "sd3.5-large"],
        "capabilities": ["text-to-image", "editing", "upscaling", "control", "3d"],
        "max_resolution": (2048, 2048),
        "cost_range": "3-8 credits per image",
    }
    catalog["wavespeed"] = {
        "name": "WaveSpeed AI",
        "models": ["ideogram-v3-turbo", "qwen-image"],
        "capabilities": ["text-to-image", "photorealistic", "fast-generation"],
        "max_resolution": (1024, 1024),
        "cost_range": "$0.05-$0.10 per image",
    }
    catalog["huggingface"] = {
        "name": "HuggingFace",
        "models": ["FLUX.1-Krea-dev", "RunwayML"],
        "capabilities": ["text-to-image", "image-to-image"],
        "max_resolution": (1024, 1024),
        "cost_range": free_tier,
    }
    catalog["gemini"] = {
        "name": "Google Gemini",
        "models": ["imagen-3.0"],
        "capabilities": ["text-to-image", "conversational-editing"],
        "max_resolution": (1024, 1024),
        "cost_range": free_tier,
    }
    return catalog
# ====================
# COST ESTIMATION
# ====================
def estimate_cost(
    self,
    provider: str,
    model: Optional[str],
    operation: str,
    num_images: int = 1,
    resolution: Optional[tuple[int, int]] = None
) -> Dict[str, Any]:
    """Estimate cost for image operations.

    The price is looked up per provider and model and multiplied by the
    number of images; ``operation`` and ``resolution`` are echoed back but do
    not influence the amount. Unknown providers or models fall back to the
    provider's "default" price, or 0.0 when there is none.

    Args:
        provider: Provider name
        model: Model name; either the short pricing key (e.g. "ultra") or
            the full model identifier (e.g. "stability-ultra")
        operation: Operation type (generate, edit, upscale, etc.)
        num_images: Number of images
        resolution: Image resolution (width, height)

    Returns:
        Cost estimation details
    """
    # Base costs (adjust based on actual pricing)
    base_costs = {
        "stability": {
            "ultra": 0.08,  # 8 credits
            "core": 0.03,  # 3 credits
            "sd3": 0.065,  # 6.5 credits
        },
        "wavespeed": {
            "ideogram-v3-turbo": 0.10,
            "qwen-image": 0.05,
        },
        "huggingface": {
            "default": 0.0,  # Free tier
        },
        "gemini": {
            "default": 0.0,  # Free tier
        }
    }
    # Full model identifiers (as used by the Create Studio provider/model
    # mapping) translated to the short pricing keys above. Without this,
    # "stability-ultra" and friends would silently be priced at 0.0.
    model_aliases = {
        "stability-ultra": "ultra",
        "stability-core": "core",
        "sd3.5-large": "sd3",
    }

    # Get base cost
    provider_costs = base_costs.get(provider, {})
    pricing_key = model if model in provider_costs else model_aliases.get(model, model)
    cost_per_image = provider_costs.get(pricing_key, provider_costs.get("default", 0.0))

    # Calculate total
    total_cost = cost_per_image * num_images
    return {
        "provider": provider,
        "model": model,
        "operation": operation,
        "num_images": num_images,
        "resolution": f"{resolution[0]}x{resolution[1]}" if resolution else "default",
        "cost_per_image": cost_per_image,
        "total_cost": total_cost,
        "currency": "USD",
        "estimated": True,
    }
# ====================
# PLATFORM SPECS
# ====================
def get_platform_specs(self, platform: Platform) -> Dict[str, Any]:
    """Get platform specifications and requirements.

    Args:
        platform: Platform to get specs for.

    Returns:
        A dict with "name", "formats" (each a dict of "name", "ratio" and
        "size"), "file_types" and "max_file_size". Platforms that have no
        entry in the table yield an empty dict.
    """

    def image_format(name, ratio, size):
        # One entry of a platform's "formats" list.
        return {"name": name, "ratio": ratio, "size": size}

    def platform_entry(name, formats, file_types, max_file_size):
        # Assemble one platform record from compact (name, ratio, size) rows.
        return {
            "name": name,
            "formats": [image_format(*row) for row in formats],
            "file_types": list(file_types),
            "max_file_size": max_file_size,
        }

    jpg_png = ("JPG", "PNG")

    catalogue = {
        Platform.INSTAGRAM: platform_entry(
            "Instagram",
            [
                ("Feed Post (Square)", "1:1", "1080x1080"),
                ("Feed Post (Portrait)", "4:5", "1080x1350"),
                ("Story", "9:16", "1080x1920"),
                ("Reel", "9:16", "1080x1920"),
            ],
            jpg_png,
            "30MB",
        ),
        Platform.FACEBOOK: platform_entry(
            "Facebook",
            [
                ("Feed Post", "1.91:1", "1200x630"),
                ("Feed Post (Square)", "1:1", "1080x1080"),
                ("Story", "9:16", "1080x1920"),
                ("Cover Photo", "16:9", "820x312"),
            ],
            jpg_png,
            "30MB",
        ),
        Platform.TWITTER: platform_entry(
            "Twitter/X",
            [
                ("Post", "16:9", "1200x675"),
                ("Card", "2:1", "1200x600"),
                ("Header", "3:1", "1500x500"),
            ],
            ("JPG", "PNG", "GIF"),
            "5MB",
        ),
        Platform.LINKEDIN: platform_entry(
            "LinkedIn",
            [
                ("Feed Post", "1.91:1", "1200x628"),
                ("Feed Post (Square)", "1:1", "1080x1080"),
                ("Article", "2:1", "1200x627"),
                ("Company Cover", "4:1", "1128x191"),
            ],
            jpg_png,
            "8MB",
        ),
        Platform.YOUTUBE: platform_entry(
            "YouTube",
            [
                ("Thumbnail", "16:9", "1280x720"),
                ("Channel Art", "16:9", "2560x1440"),
            ],
            jpg_png,
            "2MB",
        ),
        Platform.PINTEREST: platform_entry(
            "Pinterest",
            [
                ("Pin", "2:3", "1000x1500"),
                ("Story Pin", "9:16", "1080x1920"),
            ],
            jpg_png,
            "20MB",
        ),
        Platform.TIKTOK: platform_entry(
            "TikTok",
            [
                ("Video Cover", "9:16", "1080x1920"),
            ],
            jpg_png,
            "10MB",
        ),
    }

    return catalogue.get(platform, {})

View File

@@ -0,0 +1,555 @@
"""Template system for Image Studio with platform-specific presets."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Literal
from enum import Enum
class Platform(str, Enum):
    """Publishing targets that templates can be tailored to.

    Members are also plain strings (the class mixes in ``str``), so a member
    compares equal to its lowercase value.
    """

    # Social networks
    INSTAGRAM = "instagram"
    FACEBOOK = "facebook"
    TWITTER = "twitter"
    LINKEDIN = "linkedin"
    YOUTUBE = "youtube"
    PINTEREST = "pinterest"
    TIKTOK = "tiktok"

    # Owned channels
    BLOG = "blog"
    EMAIL = "email"
    WEBSITE = "website"
class TemplateCategory(str, Enum):
    """Content categories used to group image templates.

    Members are also plain strings (the class mixes in ``str``).
    """

    SOCIAL_MEDIA = "social_media"
    BLOG_CONTENT = "blog_content"
    AD_CREATIVE = "ad_creative"
    PRODUCT = "product"
    BRAND_ASSETS = "brand_assets"
    EMAIL_MARKETING = "email_marketing"
@dataclass
class AspectRatio:
    """A named aspect ratio together with concrete pixel dimensions."""

    # Ratio label as text, e.g. "1:1" or "16:9".
    ratio: str
    # Pixel dimensions.
    width: int
    height: int
    # Human-readable name, e.g. "Square" or "Widescreen".
    label: str
@dataclass
class ImageTemplate:
    """Image generation template.

    Bundles the target platform, output dimensions and generation defaults
    for one preset.
    """

    # Unique template identifier, e.g. "instagram_feed_square".
    id: str
    # Display name, e.g. "Instagram Feed Post (Square)".
    name: str
    category: TemplateCategory
    # Target platform; typed Optional, so a template may be platform-agnostic.
    platform: Optional[Platform]
    aspect_ratio: AspectRatio
    description: str
    # Provider hint stored as a free-form string.
    recommended_provider: str
    # Style preset name, e.g. "photographic", "digital-art", "cinematic".
    style_preset: str
    quality: Literal["draft", "standard", "premium"]
    prompt_template: Optional[str] = None
    negative_prompt_template: Optional[str] = None
    # The default is None, so the annotation must be Optional. Consumers
    # already guard with `use_cases or []`, so the runtime default is kept.
    use_cases: Optional[List[str]] = None
class PlatformTemplates:
    """Catalogue of built-in templates, grouped by platform.

    Each ``_<platform>_templates`` classmethod lists its presets as compact
    rows of ``(id, name, aspect_ratio, description, style_preset, quality,
    use_cases)``; ``_build`` expands the rows into ``ImageTemplate`` objects.
    """

    # Shared aspect-ratio presets.
    # NOTE(review): the ratio labels are nominal. 820x312, 1200x627 and
    # 1128x191 do not reduce to the stated 16:9, 2:1 and 4:1 - confirm
    # whether the label or the pixel size is authoritative.
    SQUARE_1_1 = AspectRatio(ratio="1:1", width=1080, height=1080, label="Square")
    PORTRAIT_4_5 = AspectRatio(ratio="4:5", width=1080, height=1350, label="Portrait")
    STORY_9_16 = AspectRatio(ratio="9:16", width=1080, height=1920, label="Story/Reel")
    LANDSCAPE_16_9 = AspectRatio(ratio="16:9", width=1920, height=1080, label="Landscape")
    WIDE_21_9 = AspectRatio(ratio="21:9", width=2560, height=1080, label="Ultra Wide")
    TWITTER_2_1 = AspectRatio(ratio="2:1", width=1200, height=600, label="Twitter Card")
    TWITTER_3_1 = AspectRatio(ratio="3:1", width=1500, height=500, label="Twitter Header")
    FACEBOOK_1_91_1 = AspectRatio(ratio="1.91:1", width=1200, height=630, label="Facebook Feed")
    LINKEDIN_1_91_1 = AspectRatio(ratio="1.91:1", width=1200, height=628, label="LinkedIn Feed")
    LINKEDIN_2_1 = AspectRatio(ratio="2:1", width=1200, height=627, label="LinkedIn Article")
    LINKEDIN_4_1 = AspectRatio(ratio="4:1", width=1128, height=191, label="LinkedIn Cover")
    PINTEREST_2_3 = AspectRatio(ratio="2:3", width=1000, height=1500, label="Pinterest Pin")
    YOUTUBE_16_9 = AspectRatio(ratio="16:9", width=1280, height=720, label="YouTube Thumbnail")
    FACEBOOK_COVER_16_9 = AspectRatio(ratio="16:9", width=820, height=312, label="Facebook Cover")

    @staticmethod
    def _build(platform, category, rows) -> List[ImageTemplate]:
        """Expand compact rows into ImageTemplate objects for one platform."""
        return [
            ImageTemplate(
                id=template_id,
                name=name,
                category=category,
                platform=platform,
                aspect_ratio=aspect_ratio,
                description=description,
                recommended_provider="ideogram",
                style_preset=style_preset,
                quality=quality,
                use_cases=use_cases,
            )
            for (
                template_id,
                name,
                aspect_ratio,
                description,
                style_preset,
                quality,
                use_cases,
            ) in rows
        ]

    @classmethod
    def get_platform_templates(cls) -> Dict[Platform, List[ImageTemplate]]:
        """Return a freshly built mapping of platform -> list of templates."""
        builders = (
            (Platform.INSTAGRAM, cls._instagram_templates),
            (Platform.FACEBOOK, cls._facebook_templates),
            (Platform.TWITTER, cls._twitter_templates),
            (Platform.LINKEDIN, cls._linkedin_templates),
            (Platform.YOUTUBE, cls._youtube_templates),
            (Platform.PINTEREST, cls._pinterest_templates),
            (Platform.TIKTOK, cls._tiktok_templates),
            (Platform.BLOG, cls._blog_templates),
            (Platform.EMAIL, cls._email_templates),
            (Platform.WEBSITE, cls._website_templates),
        )
        return {platform: build() for platform, build in builders}

    @classmethod
    def _instagram_templates(cls) -> List[ImageTemplate]:
        """Instagram presets."""
        return cls._build(Platform.INSTAGRAM, TemplateCategory.SOCIAL_MEDIA, [
            ("instagram_feed_square", "Instagram Feed Post (Square)", cls.SQUARE_1_1,
             "Perfect for Instagram feed posts with maximum visibility",
             "photographic", "premium",
             ["Product showcase", "Lifestyle posts", "Brand content"]),
            ("instagram_feed_portrait", "Instagram Feed Post (Portrait)", cls.PORTRAIT_4_5,
             "Vertical format for maximum feed real estate",
             "photographic", "premium",
             ["Fashion", "Food", "Product photography"]),
            ("instagram_story", "Instagram Story", cls.STORY_9_16,
             "Full-screen vertical stories",
             "digital-art", "standard",
             ["Behind-the-scenes", "Announcements", "Quick updates"]),
            ("instagram_reel_cover", "Instagram Reel Cover", cls.STORY_9_16,
             "Eye-catching reel cover images",
             "cinematic", "premium",
             ["Video covers", "Thumbnails", "Highlights"]),
        ])

    @classmethod
    def _facebook_templates(cls) -> List[ImageTemplate]:
        """Facebook presets."""
        return cls._build(Platform.FACEBOOK, TemplateCategory.SOCIAL_MEDIA, [
            ("facebook_feed", "Facebook Feed Post", cls.FACEBOOK_1_91_1,
             "Optimized for Facebook news feed",
             "photographic", "standard",
             ["Page posts", "Shared content", "Community posts"]),
            ("facebook_feed_square", "Facebook Feed Post (Square)", cls.SQUARE_1_1,
             "Square format for feed posts",
             "photographic", "standard",
             ["Page posts", "Product highlights"]),
            ("facebook_story", "Facebook Story", cls.STORY_9_16,
             "Full-screen vertical stories",
             "digital-art", "standard",
             ["Quick updates", "Promotions", "Events"]),
            ("facebook_cover", "Facebook Cover Photo", cls.FACEBOOK_COVER_16_9,
             "Wide cover photo for pages",
             "photographic", "premium",
             ["Page branding", "Events", "Seasonal updates"]),
        ])

    @classmethod
    def _twitter_templates(cls) -> List[ImageTemplate]:
        """Twitter/X presets."""
        return cls._build(Platform.TWITTER, TemplateCategory.SOCIAL_MEDIA, [
            ("twitter_post", "Twitter/X Post", cls.LANDSCAPE_16_9,
             "Optimized for Twitter feed",
             "photographic", "standard",
             ["Tweets", "News", "Updates"]),
            ("twitter_card", "Twitter Card", cls.TWITTER_2_1,
             "Twitter card with link preview",
             "digital-art", "standard",
             ["Link sharing", "Articles", "Blog posts"]),
            ("twitter_header", "Twitter Header", cls.TWITTER_3_1,
             "Profile header image",
             "photographic", "premium",
             ["Profile branding", "Personal brand", "Business identity"]),
        ])

    @classmethod
    def _linkedin_templates(cls) -> List[ImageTemplate]:
        """LinkedIn presets."""
        return cls._build(Platform.LINKEDIN, TemplateCategory.SOCIAL_MEDIA, [
            ("linkedin_post", "LinkedIn Post", cls.LINKEDIN_1_91_1,
             "Professional feed posts",
             "photographic", "premium",
             ["Professional content", "Industry news", "Thought leadership"]),
            ("linkedin_post_square", "LinkedIn Post (Square)", cls.SQUARE_1_1,
             "Square format for LinkedIn feed",
             "photographic", "premium",
             ["Quick tips", "Infographics", "Quotes"]),
            ("linkedin_article", "LinkedIn Article Header", cls.LINKEDIN_2_1,
             "Article header images",
             "photographic", "premium",
             ["Long-form content", "Articles", "Newsletters"]),
            ("linkedin_cover", "LinkedIn Company Cover", cls.LINKEDIN_4_1,
             "Company page cover photo",
             "photographic", "premium",
             ["Company branding", "Recruitment", "Brand identity"]),
        ])

    @classmethod
    def _youtube_templates(cls) -> List[ImageTemplate]:
        """YouTube presets."""
        return cls._build(Platform.YOUTUBE, TemplateCategory.SOCIAL_MEDIA, [
            ("youtube_thumbnail", "YouTube Thumbnail", cls.YOUTUBE_16_9,
             "Eye-catching video thumbnails",
             "cinematic", "premium",
             ["Video thumbnails", "Channel branding", "Playlists"]),
            ("youtube_channel_art", "YouTube Channel Art", cls.LANDSCAPE_16_9,
             "Channel banner art",
             "photographic", "premium",
             ["Channel branding", "Personal brand", "Business identity"]),
        ])

    @classmethod
    def _pinterest_templates(cls) -> List[ImageTemplate]:
        """Pinterest presets."""
        return cls._build(Platform.PINTEREST, TemplateCategory.SOCIAL_MEDIA, [
            ("pinterest_pin", "Pinterest Pin", cls.PINTEREST_2_3,
             "Vertical pin format",
             "photographic", "premium",
             ["Product pins", "DIY guides", "Recipes", "Inspiration"]),
            ("pinterest_story", "Pinterest Story Pin", cls.STORY_9_16,
             "Full-screen story pins",
             "digital-art", "standard",
             ["Step-by-step guides", "Tutorials", "Quick tips"]),
        ])

    @classmethod
    def _tiktok_templates(cls) -> List[ImageTemplate]:
        """TikTok presets."""
        return cls._build(Platform.TIKTOK, TemplateCategory.SOCIAL_MEDIA, [
            ("tiktok_video_cover", "TikTok Video Cover", cls.STORY_9_16,
             "Vertical video cover",
             "cinematic", "premium",
             ["Video covers", "Thumbnails", "Profile highlights"]),
        ])

    @classmethod
    def _blog_templates(cls) -> List[ImageTemplate]:
        """Blog content presets."""
        return cls._build(Platform.BLOG, TemplateCategory.BLOG_CONTENT, [
            ("blog_header", "Blog Header", cls.LANDSCAPE_16_9,
             "Blog post featured image",
             "photographic", "premium",
             ["Featured images", "Article headers", "Post thumbnails"]),
            ("blog_header_wide", "Blog Header (Wide)", cls.WIDE_21_9,
             "Ultra-wide blog header",
             "photographic", "premium",
             ["Hero sections", "Wide headers", "Landing pages"]),
        ])

    @classmethod
    def _email_templates(cls) -> List[ImageTemplate]:
        """Email marketing presets."""
        return cls._build(Platform.EMAIL, TemplateCategory.EMAIL_MARKETING, [
            ("email_banner", "Email Banner", cls.LANDSCAPE_16_9,
             "Email header banner",
             "photographic", "standard",
             ["Email headers", "Newsletter banners", "Promotions"]),
            ("email_product", "Email Product Image", cls.SQUARE_1_1,
             "Product showcase for emails",
             "photographic", "premium",
             ["Product highlights", "Promotions", "Offers"]),
        ])

    @classmethod
    def _website_templates(cls) -> List[ImageTemplate]:
        """Website presets."""
        return cls._build(Platform.WEBSITE, TemplateCategory.BRAND_ASSETS, [
            ("website_hero", "Website Hero Image", cls.WIDE_21_9,
             "Hero section background",
             "photographic", "premium",
             ["Hero sections", "Landing pages", "Home page banners"]),
            ("website_banner", "Website Banner", cls.LANDSCAPE_16_9,
             "Section banners",
             "photographic", "premium",
             ["Section headers", "Category pages", "Feature sections"]),
        ])
class TemplateManager:
    """Lookup, search and recommendation over the built-in image templates."""

    def __init__(self):
        """Load the per-platform template mapping; the flat list is built lazily."""
        self.templates = PlatformTemplates.get_platform_templates()
        self._all_templates: Optional[List[ImageTemplate]] = None

    def get_all_templates(self) -> List[ImageTemplate]:
        """Return every template across all platforms (flattened once, then cached)."""
        if self._all_templates is None:
            self._all_templates = [
                template
                for group in self.templates.values()
                for template in group
            ]
        return self._all_templates

    def get_by_platform(self, platform: Platform) -> List[ImageTemplate]:
        """Return the templates registered for ``platform`` (empty list if none)."""
        return self.templates.get(platform, [])

    def get_by_category(self, category: TemplateCategory) -> List[ImageTemplate]:
        """Return the templates whose category equals ``category``."""
        return [
            template
            for template in self.get_all_templates()
            if template.category == category
        ]

    def get_by_id(self, template_id: str) -> Optional[ImageTemplate]:
        """Return the first template whose id matches, or None."""
        candidates = (
            template
            for template in self.get_all_templates()
            if template.id == template_id
        )
        return next(candidates, None)

    def search(self, query: str) -> List[ImageTemplate]:
        """Case-insensitive substring search over name, description and use cases."""
        needle = query.lower()

        def haystack(template: ImageTemplate) -> str:
            # Name, description and use cases, lowercased and space-separated.
            return " ".join((
                template.name.lower(),
                template.description.lower(),
                " ".join(template.use_cases or []).lower(),
            ))

        return [
            template
            for template in self.get_all_templates()
            if needle in haystack(template)
        ]

    def recommend_for_use_case(self, use_case: str, platform: Optional[Platform] = None) -> List[ImageTemplate]:
        """Return templates with a use case containing ``use_case`` (case-insensitive).

        When ``platform`` is given, only that platform's templates are considered.
        """
        needle = use_case.lower()
        candidates = self.get_all_templates()

        if platform:
            candidates = [t for t in candidates if t.platform == platform]

        # A template is listed once, however many of its use cases match.
        return [
            template
            for template in candidates
            if template.use_cases
            and any(needle in case.lower() for case in template.use_cases)
        ]

    def get_aspect_ratio_options(self) -> List[AspectRatio]:
        """Return all aspect-ratio presets defined on PlatformTemplates."""
        presets = PlatformTemplates
        return [
            presets.SQUARE_1_1,
            presets.PORTRAIT_4_5,
            presets.STORY_9_16,
            presets.LANDSCAPE_16_9,
            presets.WIDE_21_9,
            presets.TWITTER_2_1,
            presets.TWITTER_3_1,
            presets.FACEBOOK_1_91_1,
            presets.LINKEDIN_1_91_1,
            presets.LINKEDIN_2_1,
            presets.LINKEDIN_4_1,
            presets.PINTEREST_2_3,
            presets.YOUTUBE_16_9,
            presets.FACEBOOK_COVER_16_9,
        ]

View File

@@ -0,0 +1,154 @@
import base64
import io
from dataclasses import dataclass
from typing import Literal, Optional, Dict, Any
from fastapi import HTTPException
from PIL import Image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
# Module-level logger for the Upscale Studio service.
logger = get_service_logger("image_studio.upscale")
# Modes a request may ask for; "auto" is resolved to a concrete mode by
# UpscaleStudioService._resolve_mode before the provider is called.
UpscaleMode = Literal["fast", "conservative", "creative", "auto"]
@dataclass
class UpscaleStudioRequest:
    """Input payload for an Upscale Studio run."""

    # Source image as base64, either raw or as a "data:...;base64,..." URI.
    image_base64: str
    # Requested mode; "auto" lets the service choose from the target size.
    mode: UpscaleMode = "auto"
    # Optional output dimensions, forwarded to the provider only when set.
    target_width: Optional[int] = None
    target_height: Optional[int] = None
    # e.g., web/print/social; echoed back in the response metadata.
    preset: Optional[str] = None
    # Used for conservative/creative modes.
    prompt: Optional[str] = None
class UpscaleStudioService:
    """Runs image upscaling requests against the Stability AI service."""

    def __init__(self):
        logger.info("[Upscale Studio] Service initialized")

    async def process_upscale(
        self,
        request: UpscaleStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Upscale the request image and return it base64-encoded with its size.

        Args:
            request: Upscale parameters, including the source image as base64.
            user_id: When truthy, subscription pre-flight validation runs first.

        Returns:
            Dict with "success", the resolved "mode", "image_base64" (data
            URI), "width", "height" and a "metadata" echo of request options.

        Raises:
            ValueError: Missing or undecodable image, or unsupported mode.
            HTTPException: 502 when no image is found in the provider response.
        """
        if user_id:
            self._run_preflight_validation(user_id)

        source_bytes = self._decode_base64(request.image_base64)
        if not source_bytes:
            raise ValueError("Primary image is required for upscaling")

        mode = self._resolve_mode(request)

        async with StabilityAIService() as stability_service:
            logger.info("[Upscale Studio] Running '%s' upscale for user=%s", mode, user_id)

            # Only forward the target dimensions the caller actually set.
            size_kwargs = {}
            if request.target_width is not None:
                size_kwargs["target_width"] = request.target_width
            if request.target_height is not None:
                size_kwargs["target_height"] = request.target_height

            if mode == "fast":
                upscale_call = stability_service.upscale_fast
                call_kwargs = {"image": source_bytes}
            elif mode == "conservative":
                upscale_call = stability_service.upscale_conservative
                call_kwargs = {
                    "image": source_bytes,
                    "prompt": request.prompt or "High fidelity upscale preserving original details",
                }
            elif mode == "creative":
                upscale_call = stability_service.upscale_creative
                call_kwargs = {
                    "image": source_bytes,
                    "prompt": request.prompt or "Creative upscale with enhanced artistic details",
                }
            else:
                raise ValueError(f"Unsupported upscale mode: {mode}")

            result = await upscale_call(**call_kwargs, **size_kwargs)

        upscaled_bytes = self._extract_image_bytes(result)
        dimensions = self._image_metadata(upscaled_bytes)

        request_echo = {
            "preset": request.preset,
            "target_width": request.target_width,
            "target_height": request.target_height,
            "prompt": request.prompt,
        }
        return {
            "success": True,
            "mode": mode,
            "image_base64": self._to_base64(upscaled_bytes),
            "width": dimensions["width"],
            "height": dimensions["height"],
            "metadata": request_echo,
        }

    @staticmethod
    def _run_preflight_validation(user_id: str) -> None:
        """Run the subscription pre-flight check for ``user_id``; always closes the DB session."""
        # Imported lazily, so these modules load only when a user id is supplied.
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_upscale_operations

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info("[Upscale Studio] 🛂 Running pre-flight validation for user %s", user_id)
            validate_image_upscale_operations(pricing_service=pricing_service, user_id=user_id)
        finally:
            db.close()

    @staticmethod
    def _decode_base64(value: Optional[str]) -> Optional[bytes]:
        """Decode a raw base64 string or data URI; returns None for empty input.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None
        try:
            # A data URI carries the payload after the first comma; a data URI
            # without a comma fails the unpacking and is reported as invalid.
            _header, payload = value.split(",", 1) if value.startswith("data:") else ("", value)
            return base64.b64decode(payload)
        except Exception as exc:
            logger.error("[Upscale Studio] Failed to decode base64 image: %s", exc)
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _to_base64(image_bytes: bytes) -> str:
        """Encode bytes as a data URI (always labelled image/png)."""
        encoded = base64.b64encode(image_bytes).decode('utf-8')
        return f"data:image/png;base64,{encoded}"

    @staticmethod
    def _image_metadata(image_bytes: bytes) -> Dict[str, int]:
        """Read pixel width and height from encoded image bytes via PIL."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            width, height = img.width, img.height
        return {"width": width, "height": height}

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Pull image bytes out of a provider response.

        Accepts raw bytes, or a dict whose "artifacts", "data" or "images"
        list holds dicts with a "base64" or "b64_json" field.

        Raises:
            HTTPException: 502 when no image can be found.
        """
        if isinstance(result, bytes):
            return result

        if isinstance(result, dict):
            candidates = result.get("artifacts") or result.get("data") or result.get("images") or []
            for item in candidates:
                if not isinstance(item, dict):
                    continue
                for key in ("base64", "b64_json"):
                    encoded = item.get(key)
                    if encoded:
                        return base64.b64decode(encoded)

        raise HTTPException(status_code=502, detail="Unable to extract image from provider response")

    @staticmethod
    def _resolve_mode(request: UpscaleStudioRequest) -> UpscaleMode:
        """Return the explicit mode, or pick one for "auto" from the target size."""
        if request.mode != "auto":
            return request.mode

        # Simple heuristic: a target side of 3000px or more -> conservative, else fast.
        wants_large_output = any(
            side and side >= 3000
            for side in (request.target_width, request.target_height)
        )
        return "conservative" if wants_large_output else "fast"