AI Image Studio Phase 1

This commit is contained in:
ajaysi
2025-11-20 09:06:00 +05:30
parent e96525347b
commit eede21ad42
58 changed files with 12951 additions and 8 deletions

View File

@@ -52,6 +52,7 @@ from routers.linkedin import router as linkedin_router
from api.linkedin_image_generation import router as linkedin_image_router
from api.brainstorm import router as brainstorm_router
from api.images import router as images_router
from routers.image_studio import router as image_studio_router
# Import hallucination detector router
from api.hallucination_detector import router as hallucination_detector_router
@@ -296,6 +297,7 @@ async def batch_analyze_urls_endpoint(urls: list[str]):
from routers.platform_analytics import router as platform_analytics_router
app.include_router(platform_analytics_router)
app.include_router(images_router)
app.include_router(image_studio_router)
# Include research configuration router
app.include_router(research_config_router, prefix="/api/research", tags=["research"])

View File

@@ -0,0 +1,593 @@
"""API endpoints for Image Studio operations."""
import base64
from typing import Optional, List, Dict, Any, Literal
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from services.image_studio import (
ImageStudioManager,
CreateStudioRequest,
EditStudioRequest,
)
from services.image_studio.upscale_service import UpscaleStudioRequest
from services.image_studio.templates import Platform, TemplateCategory
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger
logger = get_service_logger("api.image_studio")
router = APIRouter(prefix="/api/image-studio", tags=["image-studio"])
# ====================
# REQUEST MODELS
# ====================
class CreateImageRequest(BaseModel):
    """Request model for image generation."""

    # --- What to generate -------------------------------------------------
    prompt: str = Field(..., description="Image generation prompt")
    template_id: Optional[str] = Field(default=None, description="Template ID to use")

    # --- Which backend to use ---------------------------------------------
    provider: Optional[str] = Field(
        default="auto",
        description="Provider: auto, stability, wavespeed, huggingface, gemini",
    )
    model: Optional[str] = Field(default=None, description="Specific model to use")

    # --- Output geometry ----------------------------------------------------
    width: Optional[int] = Field(default=None, description="Image width in pixels")
    height: Optional[int] = Field(default=None, description="Image height in pixels")
    aspect_ratio: Optional[str] = Field(
        default=None,
        description="Aspect ratio (e.g., '1:1', '16:9')",
    )

    # --- Look and sampling controls -----------------------------------------
    style_preset: Optional[str] = Field(default=None, description="Style preset")
    quality: str = Field(default="standard", description="Quality: draft, standard, premium")
    negative_prompt: Optional[str] = Field(default=None, description="Negative prompt")
    guidance_scale: Optional[float] = Field(default=None, description="Guidance scale")
    steps: Optional[int] = Field(default=None, description="Number of inference steps")
    seed: Optional[int] = Field(default=None, description="Random seed")

    # --- Batch size and prompt/persona behaviour ------------------------------
    num_variations: int = Field(default=1, ge=1, le=10, description="Number of variations (1-10)")
    enhance_prompt: bool = Field(default=True, description="Enhance prompt with AI")
    use_persona: bool = Field(default=False, description="Use persona for brand consistency")
    persona_id: Optional[str] = Field(default=None, description="Persona ID")
class CostEstimationRequest(BaseModel):
    """Request model for cost estimation."""

    # Backend whose pricing should be used, and optionally a specific model.
    provider: str = Field(..., description="Provider name")
    model: Optional[str] = Field(default=None, description="Model name")

    # What is being priced and how many outputs are requested.
    operation: str = Field(default="generate", description="Operation type")
    num_images: int = Field(default=1, ge=1, description="Number of images")

    # Optional resolution; the endpoint only forwards it when both are set.
    width: Optional[int] = Field(default=None, description="Image width")
    height: Optional[int] = Field(default=None, description="Image height")
class EditImageRequest(BaseModel):
    """Request payload for Edit Studio."""

    # --- Source image and requested operation -------------------------------
    image_base64: str = Field(..., description="Primary image payload (base64 or data URL)")
    operation: Literal[
        "remove_background",
        "inpaint",
        "outpaint",
        "search_replace",
        "search_recolor",
        "general_edit",
    ] = Field(..., description="Edit operation to perform")

    # --- Prompts --------------------------------------------------------------
    prompt: Optional[str] = Field(default=None, description="Primary prompt/instruction")
    negative_prompt: Optional[str] = Field(
        default=None,
        description="Negative prompt for providers that support it",
    )

    # --- Auxiliary inputs used by individual operations ---------------------
    mask_base64: Optional[str] = Field(default=None, description="Optional mask image in base64")
    search_prompt: Optional[str] = Field(default=None, description="Search prompt for replace operations")
    select_prompt: Optional[str] = Field(default=None, description="Select prompt for recolor operations")
    background_image_base64: Optional[str] = Field(default=None, description="Reference background image")
    lighting_image_base64: Optional[str] = Field(default=None, description="Reference lighting image")

    # --- Outpaint margins (pixels per side) ---------------------------------
    expand_left: Optional[int] = Field(default=0, description="Outpaint expansion in pixels (left)")
    expand_right: Optional[int] = Field(default=0, description="Outpaint expansion in pixels (right)")
    expand_up: Optional[int] = Field(default=0, description="Outpaint expansion in pixels (up)")
    expand_down: Optional[int] = Field(default=0, description="Outpaint expansion in pixels (down)")

    # --- Provider / sampling overrides --------------------------------------
    provider: Optional[str] = Field(default=None, description="Explicit provider override")
    model: Optional[str] = Field(default=None, description="Explicit model override")
    style_preset: Optional[str] = Field(default=None, description="Style preset for Stability helpers")
    guidance_scale: Optional[float] = Field(default=None, description="Guidance scale for general edits")
    steps: Optional[int] = Field(default=None, description="Inference steps")
    seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")

    # --- Output -------------------------------------------------------------
    output_format: str = Field(default="png", description="Output format for edited image")
    options: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Advanced provider-specific options (e.g., grow_mask)",
    )
# Response body of POST /edit/process. The endpoint builds it by unpacking
# the dictionary returned from the studio manager's edit_image call.
class EditImageResponse(BaseModel):
    # Outcome flag plus the operation/provider that produced the image.
    success: bool
    operation: str
    provider: str

    # Resulting image and its pixel dimensions.
    image_base64: str
    width: int
    height: int

    # Free-form extra information supplied by the service layer.
    metadata: Dict[str, Any]
# Response body of GET /edit/operations.
class EditOperationsResponse(BaseModel):
    # Mapping of operation name -> metadata dictionary for that operation.
    operations: Dict[str, Dict[str, Any]]
# Request body of POST /upscale.
class UpscaleImageRequest(BaseModel):
    # Source image and the upscaling mode ("auto" when the client omits it).
    image_base64: str
    mode: Literal["fast", "conservative", "creative", "auto"] = "auto"

    # Optional sizing hints and prompt, forwarded unchanged to the service.
    target_width: Optional[int] = Field(default=None, description="Target width in pixels")
    target_height: Optional[int] = Field(default=None, description="Target height in pixels")
    preset: Optional[str] = Field(default=None, description="Named preset (web, print, social)")
    prompt: Optional[str] = Field(default=None, description="Prompt for conservative/creative modes")
# Response body of POST /upscale. The endpoint builds it by unpacking the
# dictionary returned from the studio manager's upscale_image call.
class UpscaleImageResponse(BaseModel):
    # Outcome flag and the mode reported by the service.
    success: bool
    mode: str

    # Upscaled image and its pixel dimensions.
    image_base64: str
    width: int
    height: int

    # Free-form extra information supplied by the service layer.
    metadata: Dict[str, Any]
# ====================
# DEPENDENCY
# ====================
def get_studio_manager() -> ImageStudioManager:
    """Dependency provider: build a fresh ImageStudioManager for the request."""
    manager = ImageStudioManager()
    return manager
def _require_user_id(current_user: Dict[str, Any], operation: str) -> str:
"""Ensure user_id is available for protected operations."""
user_id = current_user.get("sub") or current_user.get("user_id")
if not user_id:
logger.error(
"[Image Studio] ❌ Missing user_id for %s operation - blocking request",
operation,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authenticated user required for image operations.",
)
return user_id
# ====================
# CREATE STUDIO ENDPOINTS
# ====================
@router.post("/create", summary="Generate Image")
async def create_image(
    request: CreateImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Generate image(s) using Create Studio.

    This endpoint supports:
    - Multiple AI providers (Stability AI, WaveSpeed, HuggingFace, Gemini)
    - Template-based generation
    - Custom dimensions and aspect ratios
    - Style presets and quality levels
    - Multiple variations
    - Prompt enhancement

    Returns:
        Dictionary with generation results including image data. Raw
        ``image_bytes`` entries are replaced by ``image_base64`` strings so
        the payload is JSON-serializable.

    Raises:
        HTTPException: propagated unchanged when raised while handling the
            request (e.g. the 401 from ``_require_user_id``); 400 for
            ``ValueError``; 500 for ``RuntimeError`` or any other error.
    """
    try:
        user_id = _require_user_id(current_user, "image generation")
        logger.info(f"[Create Image] Request from user {user_id}: {request.prompt[:100]}")

        # Convert request to CreateStudioRequest
        studio_request = CreateStudioRequest(
            prompt=request.prompt,
            template_id=request.template_id,
            provider=request.provider,
            model=request.model,
            width=request.width,
            height=request.height,
            aspect_ratio=request.aspect_ratio,
            style_preset=request.style_preset,
            quality=request.quality,
            negative_prompt=request.negative_prompt,
            guidance_scale=request.guidance_scale,
            steps=request.steps,
            seed=request.seed,
            num_variations=request.num_variations,
            enhance_prompt=request.enhance_prompt,
            use_persona=request.use_persona,
            persona_id=request.persona_id,
        )

        # Generate images
        result = await studio_manager.create_image(studio_request, user_id=user_id)

        # Convert image bytes to base64 for JSON response; pop() removes the
        # raw bytes from the entry at the same time.
        for img_result in result["results"]:
            if "image_bytes" in img_result:
                raw_bytes = img_result.pop("image_bytes")
                img_result["image_base64"] = base64.b64encode(raw_bytes).decode("utf-8")

        logger.info(f"[Create Image] ✅ Success: {result['total_generated']} images generated")
        return result
    except HTTPException:
        # Keep deliberate HTTP errors (401 and similar) intact instead of
        # letting the generic handler below rewrite them as 500s. This matches
        # the edit and upscale endpoints in this router.
        raise
    except ValueError as e:
        logger.error(f"[Create Image] ❌ Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except RuntimeError as e:
        logger.error(f"[Create Image] ❌ Generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
    except Exception as e:
        logger.error(f"[Create Image] ❌ Unexpected error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# ====================
# TEMPLATE ENDPOINTS
# ====================
@router.get("/templates", summary="Get Templates")
async def get_templates(
    platform: Optional[Platform] = None,
    category: Optional[TemplateCategory] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Get available image templates.

    Templates provide pre-configured settings for common use cases:
    - Platform-specific dimensions and formats
    - Recommended providers and models
    - Style presets and quality settings

    Args:
        platform: Filter by platform (instagram, facebook, twitter, etc.)
        category: Filter by category (social_media, blog_content, ad_creative, etc.)

    Returns:
        List of templates
    """

    def _as_json(tpl) -> Dict[str, Any]:
        # Flatten one template object into plain JSON-friendly values.
        ratio = tpl.aspect_ratio
        return {
            "id": tpl.id,
            "name": tpl.name,
            "category": tpl.category.value,
            "platform": tpl.platform.value if tpl.platform else None,
            "aspect_ratio": {
                "ratio": ratio.ratio,
                "width": ratio.width,
                "height": ratio.height,
                "label": ratio.label,
            },
            "description": tpl.description,
            "recommended_provider": tpl.recommended_provider,
            "style_preset": tpl.style_preset,
            "quality": tpl.quality,
            "use_cases": tpl.use_cases or [],
        }

    try:
        found = studio_manager.get_templates(platform=platform, category=category)
        serialized = [_as_json(tpl) for tpl in found]
        return {"templates": serialized, "total": len(serialized)}
    except Exception as e:
        logger.error(f"[Get Templates] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/templates/search", summary="Search Templates")
async def search_templates(
    query: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Search templates by query.

    Searches in template names, descriptions, and use cases.

    Args:
        query: Search query

    Returns:
        List of matching templates
    """
    try:
        serialized: List[Dict[str, Any]] = []
        for match in studio_manager.search_templates(query):
            # Build the nested aspect-ratio object first, then the entry.
            dims = match.aspect_ratio
            aspect_payload = {
                "ratio": dims.ratio,
                "width": dims.width,
                "height": dims.height,
                "label": dims.label,
            }
            platform_value = match.platform.value if match.platform else None
            serialized.append(
                {
                    "id": match.id,
                    "name": match.name,
                    "category": match.category.value,
                    "platform": platform_value,
                    "aspect_ratio": aspect_payload,
                    "description": match.description,
                    "recommended_provider": match.recommended_provider,
                    "style_preset": match.style_preset,
                    "quality": match.quality,
                    "use_cases": match.use_cases or [],
                }
            )
        return {"templates": serialized, "total": len(serialized), "query": query}
    except Exception as e:
        logger.error(f"[Search Templates] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/templates/recommend", summary="Recommend Templates")
async def recommend_templates(
    use_case: str,
    platform: Optional[Platform] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Recommend templates based on use case.

    Args:
        use_case: Description of use case (e.g., "product showcase", "blog header")
        platform: Optional platform filter

    Returns:
        List of recommended templates
    """

    def _describe(candidate) -> Dict[str, Any]:
        # One recommended template rendered as a JSON-ready dictionary.
        return {
            "id": candidate.id,
            "name": candidate.name,
            "category": candidate.category.value,
            "platform": candidate.platform.value if candidate.platform else None,
            "aspect_ratio": {
                "ratio": candidate.aspect_ratio.ratio,
                "width": candidate.aspect_ratio.width,
                "height": candidate.aspect_ratio.height,
                "label": candidate.aspect_ratio.label,
            },
            "description": candidate.description,
            "recommended_provider": candidate.recommended_provider,
            "style_preset": candidate.style_preset,
            "quality": candidate.quality,
            "use_cases": candidate.use_cases or [],
        }

    try:
        recommended = studio_manager.recommend_templates(use_case, platform=platform)
        serialized = list(map(_describe, recommended))
        response = {"templates": serialized, "total": len(serialized), "use_case": use_case}
        return response
    except Exception as e:
        logger.error(f"[Recommend Templates] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# ====================
# PROVIDER ENDPOINTS
# ====================
@router.get("/providers", summary="Get Providers")
async def get_providers(
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Get available AI providers and their capabilities.

    Returns information about:
    - Available models
    - Capabilities
    - Maximum resolution
    - Cost estimates

    Returns:
        Dictionary of providers
    """
    try:
        # The manager's answer is wrapped under a single "providers" key.
        return {"providers": studio_manager.get_providers()}
    except Exception as e:
        logger.error(f"[Get Providers] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# ====================
# COST ESTIMATION ENDPOINTS
# ====================
@router.post("/estimate-cost", summary="Estimate Cost")
async def estimate_cost(
    request: CostEstimationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Estimate cost for image generation operations.

    Provides cost estimates before generation to help users make informed decisions.

    Args:
        request: Cost estimation request

    Returns:
        Cost estimation details
    """
    try:
        # A resolution is only forwarded when both sides are present (and
        # non-zero); otherwise the manager receives None.
        has_both_sides = request.width and request.height
        resolution = (request.width, request.height) if has_both_sides else None

        return studio_manager.estimate_cost(
            provider=request.provider,
            model=request.model,
            operation=request.operation,
            num_images=request.num_images,
            resolution=resolution
        )
    except Exception as e:
        logger.error(f"[Estimate Cost] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# ====================
# EDIT STUDIO ENDPOINTS
# ====================
@router.post("/edit/process", response_model=EditImageResponse, summary="Process Edit Studio request")
async def process_edit_image(
    request: EditImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Perform Edit Studio operations such as remove background, inpaint, or recolor."""
    try:
        user_id = _require_user_id(current_user, "image editing")
        logger.info(f"[Edit Image] Request from user {user_id}: operation={request.operation}")

        # Copy the API payload field-by-field into the service-layer request;
        # a missing options mapping is normalized to an empty dict.
        service_kwargs: Dict[str, Any] = {
            "image_base64": request.image_base64,
            "operation": request.operation,
            "prompt": request.prompt,
            "negative_prompt": request.negative_prompt,
            "mask_base64": request.mask_base64,
            "search_prompt": request.search_prompt,
            "select_prompt": request.select_prompt,
            "background_image_base64": request.background_image_base64,
            "lighting_image_base64": request.lighting_image_base64,
            "expand_left": request.expand_left,
            "expand_right": request.expand_right,
            "expand_up": request.expand_up,
            "expand_down": request.expand_down,
            "provider": request.provider,
            "model": request.model,
            "style_preset": request.style_preset,
            "guidance_scale": request.guidance_scale,
            "steps": request.steps,
            "seed": request.seed,
            "output_format": request.output_format,
            "options": request.options or {},
        }
        edit_request = EditStudioRequest(**service_kwargs)

        outcome = await studio_manager.edit_image(edit_request, user_id=user_id)
        return EditImageResponse(**outcome)
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        logger.error(f"[Edit Image] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Image editing failed: {e}")
@router.get("/edit/operations", response_model=EditOperationsResponse, summary="List Edit Studio operations")
async def get_edit_operations(
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Return metadata for supported Edit Studio operations."""
    try:
        return EditOperationsResponse(operations=studio_manager.get_edit_operations())
    except Exception as e:
        # The client gets a fixed message; the actual error goes to the log.
        logger.error(f"[Edit Operations] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to load edit operations")
# ====================
# UPSCALE STUDIO ENDPOINTS
# ====================
@router.post("/upscale", response_model=UpscaleImageResponse, summary="Upscale Image")
async def upscale_image(
    request: UpscaleImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Upscale an image using Stability AI pipelines."""
    try:
        user_id = _require_user_id(current_user, "image upscaling")

        # Translate the API payload into the service-layer request object.
        service_kwargs: Dict[str, Any] = {
            "image_base64": request.image_base64,
            "mode": request.mode,
            "target_width": request.target_width,
            "target_height": request.target_height,
            "preset": request.preset,
            "prompt": request.prompt,
        }
        upscale_request = UpscaleStudioRequest(**service_kwargs)

        outcome = await studio_manager.upscale_image(upscale_request, user_id=user_id)
        return UpscaleImageResponse(**outcome)
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        logger.error(f"[Upscale Image] ❌ Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Image upscaling failed: {e}")
# ====================
# PLATFORM SPECS ENDPOINTS
# ====================
@router.get("/platform-specs/{platform}", summary="Get Platform Specifications")
async def get_platform_specs(
    platform: Platform,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Get specifications and requirements for a specific platform.

    Returns:
    - Supported formats and dimensions
    - File type requirements
    - Maximum file size
    - Best practices

    Args:
        platform: Platform name

    Returns:
        Platform specifications
    """
    try:
        specs = studio_manager.get_platform_specs(platform)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Get Platform Specs] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

    # An empty/None lookup result is reported as "not found".
    if specs:
        return specs
    raise HTTPException(status_code=404, detail=f"Specifications not found for platform: {platform}")
# ====================
# HEALTH CHECK
# ====================
@router.get("/health", summary="Health Check")
async def health_check():
    """Health check endpoint for Image Studio.

    Returns:
        Health status
    """
    # Static report: nothing is probed, every listed module is "available".
    module_status = {
        name: "available"
        for name in ("create_studio", "templates", "providers")
    }
    report = {
        "status": "healthy",
        "service": "image_studio",
        "version": "1.0.0",
        "modules": module_status,
    }
    return report

View File

@@ -0,0 +1,20 @@
"""Image Studio service package for centralized image operations."""
from .studio_manager import ImageStudioManager
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .templates import PlatformTemplates, TemplateManager
__all__ = [
"ImageStudioManager",
"CreateStudioService",
"CreateStudioRequest",
"EditStudioService",
"EditStudioRequest",
"UpscaleStudioService",
"UpscaleStudioRequest",
"PlatformTemplates",
"TemplateManager",
]

View File

@@ -0,0 +1,458 @@
"""Create Studio service for AI-powered image generation."""
import os
from typing import Optional, Dict, Any, List, Literal
from dataclasses import dataclass
from services.llm_providers.image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
HuggingFaceImageProvider,
GeminiImageProvider,
StabilityImageProvider,
WaveSpeedImageProvider,
)
from .templates import TemplateManager, ImageTemplate, Platform, TemplateCategory
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.create")
@dataclass
class CreateStudioRequest:
    """Request for image generation in Create Studio."""

    # Text prompt; the only required field.
    prompt: str

    # Template / backend selection.
    template_id: Optional[str] = None
    provider: Optional[str] = None  # "auto", "stability", "wavespeed", "huggingface", "gemini"
    model: Optional[str] = None

    # Output geometry.
    width: Optional[int] = None
    height: Optional[int] = None
    aspect_ratio: Optional[str] = None  # e.g., "1:1", "16:9"

    # Look and sampling controls.
    style_preset: Optional[str] = None
    quality: Literal["draft", "standard", "premium"] = "standard"
    negative_prompt: Optional[str] = None
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None

    # Batch size and prompt/persona behaviour.
    num_variations: int = 1
    enhance_prompt: bool = True
    use_persona: bool = False
    persona_id: Optional[str] = None
class CreateStudioService:
"""Service for Create Studio image generation operations."""
# Provider-to-model mapping for smart recommendations
PROVIDER_MODELS = {
"stability": {
"ultra": "stability-ultra", # Best quality, 8 credits
"core": "stability-core", # Fast & affordable, 3 credits
"sd3": "sd3.5-large", # SD3.5 model
},
"wavespeed": {
"ideogram-v3-turbo": "ideogram-v3-turbo", # Photorealistic, text rendering
"qwen-image": "qwen-image", # Fast generation
},
"huggingface": {
"flux": "black-forest-labs/FLUX.1-Krea-dev",
},
"gemini": {
"imagen": "imagen-3.0-generate-001",
}
}
# Quality-to-provider mapping
QUALITY_PROVIDERS = {
"draft": ["huggingface", "wavespeed:qwen-image"], # Fast, low cost
"standard": ["stability:core", "wavespeed:ideogram-v3-turbo"], # Balanced
"premium": ["wavespeed:ideogram-v3-turbo", "stability:ultra"], # Best quality
}
def __init__(self):
"""Initialize Create Studio service."""
self.template_manager = TemplateManager()
logger.info("[Create Studio] Initialized with template manager")
def _get_provider_instance(self, provider_name: str, api_key: Optional[str] = None):
"""Get provider instance by name.
Args:
provider_name: Name of the provider
api_key: Optional API key (uses env vars if not provided)
Returns:
Provider instance
Raises:
ValueError: If provider is not supported
"""
if provider_name == "stability":
return StabilityImageProvider(api_key=api_key or os.getenv("STABILITY_API_KEY"))
elif provider_name == "wavespeed":
return WaveSpeedImageProvider(api_key=api_key or os.getenv("WAVESPEED_API_KEY"))
elif provider_name == "huggingface":
return HuggingFaceImageProvider(api_token=api_key or os.getenv("HF_API_KEY"))
elif provider_name == "gemini":
return GeminiImageProvider(api_key=api_key or os.getenv("GEMINI_API_KEY"))
else:
raise ValueError(f"Unsupported provider: {provider_name}")
def _select_provider_and_model(
self,
request: CreateStudioRequest,
template: Optional[ImageTemplate] = None
) -> tuple[str, Optional[str]]:
"""Smart provider and model selection.
Args:
request: Create studio request
template: Optional template with recommendations
Returns:
Tuple of (provider_name, model_name)
"""
# Explicit provider selection
if request.provider and request.provider != "auto":
provider = request.provider
model = request.model
logger.info("[Provider Selection] User specified: %s (model: %s)", provider, model)
return provider, model
# Template recommendation
if template and template.recommended_provider:
provider = template.recommended_provider
logger.info("[Provider Selection] Template recommends: %s", provider)
# Map provider to specific model if not specified
if not request.model:
if provider == "ideogram":
return "wavespeed", "ideogram-v3-turbo"
elif provider == "qwen":
return "wavespeed", "qwen-image"
elif provider == "stability":
# Choose based on quality
if request.quality == "premium":
return "stability", "stability-ultra"
elif request.quality == "draft":
return "stability", "stability-core"
else:
return "stability", "stability-core"
return provider, request.model
# Quality-based selection
quality_options = self.QUALITY_PROVIDERS.get(request.quality, self.QUALITY_PROVIDERS["standard"])
selected = quality_options[0] # Pick first option
if ":" in selected:
provider, model = selected.split(":", 1)
else:
provider = selected
model = None
logger.info("[Provider Selection] Quality-based (%s): %s (model: %s)",
request.quality, provider, model)
return provider, model
def _enhance_prompt(self, prompt: str, style_preset: Optional[str] = None) -> str:
"""Enhance prompt with style and quality descriptors.
Args:
prompt: Original prompt
style_preset: Style preset to apply
Returns:
Enhanced prompt
"""
enhanced = prompt
# Add style-specific enhancements
style_enhancements = {
"photographic": ", professional photography, high quality, detailed, sharp focus, natural lighting",
"digital-art": ", digital art, vibrant colors, detailed, high quality, artstation trending",
"cinematic": ", cinematic lighting, dramatic, film grain, high quality, professional",
"3d-model": ", 3D render, octane render, unreal engine, high quality, detailed",
"anime": ", anime style, vibrant colors, detailed, high quality",
"line-art": ", clean line art, detailed linework, high contrast, professional",
}
if style_preset and style_preset in style_enhancements:
enhanced += style_enhancements[style_preset]
logger.info("[Prompt Enhancement] Original: %s", prompt[:100])
logger.info("[Prompt Enhancement] Enhanced: %s", enhanced[:100])
return enhanced
def _apply_template(self, request: CreateStudioRequest, template: ImageTemplate) -> CreateStudioRequest:
"""Apply template settings to request.
Args:
request: Original request
template: Template to apply
Returns:
Modified request
"""
# Apply template dimensions if not specified
if not request.width and not request.height:
request.width = template.aspect_ratio.width
request.height = template.aspect_ratio.height
# Apply template style if not specified
if not request.style_preset:
request.style_preset = template.style_preset
# Apply template quality if not specified
if request.quality == "standard":
request.quality = template.quality
logger.info("[Template Applied] %s -> %dx%d, style=%s, quality=%s",
template.name, request.width, request.height,
request.style_preset, request.quality)
return request
def _calculate_dimensions(
self,
width: Optional[int],
height: Optional[int],
aspect_ratio: Optional[str]
) -> tuple[int, int]:
"""Calculate image dimensions from width/height or aspect ratio.
Args:
width: Explicit width
height: Explicit height
aspect_ratio: Aspect ratio string (e.g., "16:9")
Returns:
Tuple of (width, height)
"""
# Both dimensions specified
if width and height:
return width, height
# Aspect ratio specified
if aspect_ratio:
try:
w_ratio, h_ratio = map(int, aspect_ratio.split(":"))
# Use width if specified
if width:
height = int(width * h_ratio / w_ratio)
return width, height
# Use height if specified
if height:
width = int(height * w_ratio / h_ratio)
return width, height
# Default size based on aspect ratio
# Use 1080p as base
if w_ratio >= h_ratio:
# Landscape or square
width = 1920
height = int(1920 * h_ratio / w_ratio)
else:
# Portrait
height = 1920
width = int(1920 * w_ratio / h_ratio)
return width, height
except ValueError:
logger.warning("[Dimensions] Invalid aspect ratio: %s", aspect_ratio)
# Default dimensions
return 1024, 1024
async def generate(
self,
request: CreateStudioRequest,
user_id: Optional[str] = None
) -> Dict[str, Any]:
"""Generate image(s) using Create Studio.
Args:
request: Create studio request
user_id: User ID for validation and tracking
Returns:
Dictionary with generation results
Raises:
ValueError: If request is invalid
RuntimeError: If generation fails
"""
logger.info("[Create Studio] Starting generation: prompt=%s, template=%s",
request.prompt[:100], request.template_id)
# Pre-flight validation: Check subscription and usage limits
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
db = next(get_db())
try:
pricing_service = PricingService(db)
logger.info(f"[Create Studio] 🛂 Running pre-flight validation for user {user_id}")
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=request.num_variations
)
logger.info(f"[Create Studio] ✅ Pre-flight validation passed - proceeding with generation")
except HTTPException as http_ex:
logger.error(f"[Create Studio] ❌ Pre-flight validation failed - blocking generation")
raise
finally:
db.close()
else:
logger.warning("[Create Studio] ⚠️ No user_id provided - skipping pre-flight validation")
# Load template if specified
template = None
if request.template_id:
template = self.template_manager.get_by_id(request.template_id)
if not template:
raise ValueError(f"Template not found: {request.template_id}")
# Apply template settings
request = self._apply_template(request, template)
# Calculate dimensions
width, height = self._calculate_dimensions(
request.width, request.height, request.aspect_ratio
)
# Enhance prompt if requested
prompt = request.prompt
if request.enhance_prompt:
prompt = self._enhance_prompt(prompt, request.style_preset)
# Select provider and model
provider_name, model = self._select_provider_and_model(request, template)
# Get provider instance
try:
provider = self._get_provider_instance(provider_name)
except Exception as e:
logger.error("[Create Studio] ❌ Failed to initialize provider %s: %s",
provider_name, str(e))
raise RuntimeError(f"Provider initialization failed: {str(e)}")
# Generate images
results = []
for i in range(request.num_variations):
logger.info("[Create Studio] Generating variation %d/%d",
i + 1, request.num_variations)
try:
# Prepare options
options = ImageGenerationOptions(
prompt=prompt,
negative_prompt=request.negative_prompt,
width=width,
height=height,
guidance_scale=request.guidance_scale,
steps=request.steps,
seed=request.seed + i if request.seed else None,
model=model,
extra={"style_preset": request.style_preset} if request.style_preset else {}
)
# Generate image
result: ImageGenerationResult = provider.generate(options)
results.append({
"image_bytes": result.image_bytes,
"width": result.width,
"height": result.height,
"provider": result.provider,
"model": result.model,
"seed": result.seed,
"metadata": result.metadata,
"variation": i + 1,
})
logger.info("[Create Studio] ✅ Variation %d generated successfully", i + 1)
except Exception as e:
logger.error("[Create Studio] ❌ Failed to generate variation %d: %s",
i + 1, str(e), exc_info=True)
results.append({
"error": str(e),
"variation": i + 1,
})
# Return results
return {
"success": True,
"request": {
"prompt": request.prompt,
"enhanced_prompt": prompt if request.enhance_prompt else None,
"template_id": request.template_id,
"template_name": template.name if template else None,
"provider": provider_name,
"model": model,
"dimensions": f"{width}x{height}",
"quality": request.quality,
"num_variations": request.num_variations,
},
"results": results,
"total_generated": sum(1 for r in results if "image_bytes" in r),
"total_failed": sum(1 for r in results if "error" in r),
}
def get_templates(
self,
platform: Optional[Platform] = None,
category: Optional[TemplateCategory] = None
) -> List[ImageTemplate]:
"""Get available templates.
Args:
platform: Filter by platform
category: Filter by category
Returns:
List of templates
"""
if platform:
return self.template_manager.get_by_platform(platform)
elif category:
return self.template_manager.get_by_category(category)
else:
return self.template_manager.get_all_templates()
def search_templates(self, query: str) -> List[ImageTemplate]:
    """Look up templates matching a free-text query.

    Args:
        query: Search text, passed unchanged to the template manager.

    Returns:
        Whatever the template manager's ``search`` returns for ``query``.
    """
    matches = self.template_manager.search(query)
    return matches
def recommend_templates(
    self,
    use_case: str,
    platform: Optional[Platform] = None
) -> List[ImageTemplate]:
    """Suggest templates for a described use case.

    Args:
        use_case: Free-text description of what the image is for.
        platform: Optional platform to restrict the recommendation to.

    Returns:
        The template manager's recommendations for the given arguments.
    """
    manager = self.template_manager
    recommendations = manager.recommend_for_use_case(use_case, platform)
    return recommendations

View File

@@ -0,0 +1,458 @@
"""Edit Studio service for AI-powered image editing and transformations."""
from __future__ import annotations
import asyncio
import base64
import io
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from PIL import Image
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.edit")
# Closed set of operation identifiers understood by Edit Studio. Each value is
# also a key of EditStudioService.SUPPORTED_OPERATIONS, which describes the
# provider and the input fields that operation uses.
EditOperationType = Literal[
    "remove_background",
    "inpaint",
    "outpaint",
    "search_replace",
    "search_recolor",
    "relight",
    "general_edit",
]
@dataclass
class EditStudioRequest:
    """Normalized request payload for Edit Studio operations.

    Only ``image_base64`` and ``operation`` are mandatory; which of the
    optional fields are actually consulted depends on ``operation``.
    """

    # Source image: raw base64 or a ``data:...;base64,`` URL.
    image_base64: str
    # Which edit to run (see EditOperationType).
    operation: EditOperationType
    # Text prompt; required for inpaint, search_replace, search_recolor and
    # general_edit, and also forwarded for outpaint.
    prompt: Optional[str] = None
    # Forwarded to the inpaint and outpaint calls.
    negative_prompt: Optional[str] = None
    # Mask image (same encodings as image_base64), forwarded for inpaint.
    mask_base64: Optional[str] = None
    # What to look for; required for search_replace.
    search_prompt: Optional[str] = None
    # What to select; required for search_recolor.
    select_prompt: Optional[str] = None
    # Reference images for relight; at least one of the two must be given.
    background_image_base64: Optional[str] = None
    lighting_image_base64: Optional[str] = None
    # Outpaint expansion per side; a missing value is sent as 0.
    # NOTE(review): units are presumably pixels — confirm against the provider.
    expand_left: Optional[int] = None
    expand_right: Optional[int] = None
    expand_up: Optional[int] = None
    expand_down: Optional[int] = None
    # Passed through as options for general_edit (provider defaults to
    # "huggingface" when unset).
    provider: Optional[str] = None
    model: Optional[str] = None
    # Forwarded for inpaint/outpaint and echoed in the response metadata.
    style_preset: Optional[str] = None
    # Passed through as options for general_edit.
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None
    # Requested output format; also used as the image subtype of the data URL
    # returned to the caller.
    output_format: str = "png"
    # Free-form extras; "grow_mask" is read for inpaint (default 5).
    options: Dict[str, Any] = field(default_factory=dict)
# Input fields an operation may expose to the UI, in the order they are
# reported inside each SUPPORTED_OPERATIONS entry.
_EDIT_FIELD_NAMES = (
    "prompt",
    "mask",
    "negative_prompt",
    "search_prompt",
    "select_prompt",
    "background",
    "lighting",
    "expansion",
)


def _operation_spec(
    label: str,
    description: str,
    provider: str,
    *enabled_fields: str,
    is_async: bool = False,
) -> Dict[str, Any]:
    """Build one SUPPORTED_OPERATIONS entry.

    Every name in ``_EDIT_FIELD_NAMES`` appears in the resulting ``fields``
    mapping; it is ``True`` only for the names listed in ``enabled_fields``.
    """
    unknown = set(enabled_fields) - set(_EDIT_FIELD_NAMES)
    if unknown:
        raise ValueError(f"Unknown edit field(s): {sorted(unknown)}")
    return {
        "label": label,
        "description": description,
        "provider": provider,
        "async": is_async,
        "fields": {name: name in enabled_fields for name in _EDIT_FIELD_NAMES},
    }


class EditStudioService:
    """Service layer orchestrating Edit Studio operations."""

    # Operation catalogue: UI label/description, the provider that executes
    # the operation, whether the provider call is asynchronous, and which
    # input fields the operation uses.
    SUPPORTED_OPERATIONS: Dict[EditOperationType, Dict[str, Any]] = {
        "remove_background": _operation_spec(
            "Remove Background",
            "Isolate the main subject and remove the background.",
            "stability",
        ),
        "inpaint": _operation_spec(
            "Inpaint & Fix",
            "Edit specific regions using prompts and optional masks.",
            "stability",
            "prompt", "mask", "negative_prompt",
        ),
        "outpaint": _operation_spec(
            "Outpaint",
            "Extend the canvas in any direction with smart fill.",
            "stability",
            "negative_prompt", "expansion",
        ),
        "search_replace": _operation_spec(
            "Search & Replace",
            "Locate objects via search prompt and replace them.",
            "stability",
            "prompt", "search_prompt",
        ),
        "search_recolor": _operation_spec(
            "Search & Recolor",
            "Select elements via prompt and recolor them.",
            "stability",
            "prompt", "select_prompt",
        ),
        "relight": _operation_spec(
            "Replace Background & Relight",
            "Swap backgrounds and relight using reference images.",
            "stability",
            "background", "lighting",
            is_async=True,
        ),
        "general_edit": _operation_spec(
            "Prompt-based Edit",
            "Free-form editing powered by Hugging Face image-to-image models.",
            "huggingface",
            "prompt", "negative_prompt",
        ),
    }

    def __init__(self):
        logger.info("[Edit Studio] Initialized edit service")

    @staticmethod
    def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
        """Decode a base64 (or data URL) string to bytes.

        Returns ``None`` for an empty/missing value.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None
        try:
            # Handle data URLs (data:image/png;base64,...)
            if value.startswith("data:"):
                _, b64data = value.split(",", 1)
            else:
                b64data = value
            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error(f"[Edit Studio] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
        """Extract width/height metadata from image bytes."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {
                "width": img.width,
                "height": img.height,
            }

    @staticmethod
    def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
        """Convert raw bytes to base64 data URL."""
        b64 = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/{output_format};base64,{b64}"

    def list_operations(self) -> Dict[str, Dict[str, Any]]:
        """Expose supported operations for UI rendering."""
        return self.SUPPORTED_OPERATIONS

    async def process_edit(
        self,
        request: EditStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Process edit request and return normalized response.

        Args:
            request: The edit request (image, operation and its inputs).
            user_id: When given, subscription pre-flight validation runs
                first; when omitted, validation is skipped with a warning.

        Returns:
            Dict with ``success``, ``operation``, ``provider``,
            ``image_base64`` (data URL), ``width``, ``height`` and
            ``metadata``.

        Raises:
            HTTPException: Propagated from pre-flight validation.
            ValueError: For a missing/invalid image, an unsupported
                operation, or missing operation-specific inputs.
        """
        if user_id:
            from services.database import get_db
            from services.subscription import PricingService
            from services.subscription.preflight_validator import validate_image_editing_operations
            from fastapi import HTTPException

            db = next(get_db())
            try:
                pricing_service = PricingService(db)
                logger.info(f"[Edit Studio] 🛂 Running pre-flight validation for user {user_id}")
                validate_image_editing_operations(
                    pricing_service=pricing_service,
                    user_id=user_id,
                )
                logger.info("[Edit Studio] ✅ Pre-flight validation passed")
            except HTTPException:
                logger.error("[Edit Studio] ❌ Pre-flight validation failed")
                raise
            finally:
                db.close()
        else:
            logger.warning("[Edit Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        image_bytes = self._decode_base64_image(request.image_base64)
        if not image_bytes:
            raise ValueError("Primary image payload is required")

        mask_bytes = self._decode_base64_image(request.mask_base64)
        background_bytes = self._decode_base64_image(request.background_image_base64)
        lighting_bytes = self._decode_base64_image(request.lighting_image_base64)

        operation = request.operation
        logger.info("[Edit Studio] Processing operation='%s' for user=%s", operation, user_id)

        if operation not in self.SUPPORTED_OPERATIONS:
            raise ValueError(f"Unsupported edit operation: {operation}")

        # Route on the provider declared in the catalogue so that the table
        # above stays the single source of truth for dispatching.
        provider = self.SUPPORTED_OPERATIONS[operation]["provider"]
        if provider == "stability":
            image_bytes = await self._handle_stability_edit(
                operation=operation,
                request=request,
                image_bytes=image_bytes,
                mask_bytes=mask_bytes,
                background_bytes=background_bytes,
                lighting_bytes=lighting_bytes,
            )
        else:
            image_bytes = await self._handle_general_edit(
                request=request,
                image_bytes=image_bytes,
                mask_bytes=mask_bytes,
                user_id=user_id,
            )

        metadata = self._image_bytes_to_metadata(image_bytes)
        metadata.update(
            {
                "operation": operation,
                "style_preset": request.style_preset,
                "provider": provider,
            }
        )

        response = {
            "success": True,
            "operation": operation,
            "provider": metadata["provider"],
            "image_base64": self._bytes_to_base64(image_bytes, request.output_format),
            "width": metadata["width"],
            "height": metadata["height"],
            "metadata": metadata,
        }
        logger.info("[Edit Studio] ✅ Operation '%s' completed", operation)
        return response

    async def _handle_stability_edit(
        self,
        operation: EditOperationType,
        request: EditStudioRequest,
        image_bytes: bytes,
        mask_bytes: Optional[bytes],
        background_bytes: Optional[bytes],
        lighting_bytes: Optional[bytes],
    ) -> bytes:
        """Execute Stability AI edit workflows.

        Validates the inputs each operation needs, calls the matching
        ``StabilityAIService`` method and normalizes the response to raw
        image bytes.

        Raises:
            ValueError: If required inputs for ``operation`` are missing or
                the operation is not a Stability operation.
        """
        stability_service = StabilityAIService()
        async with stability_service:
            if operation == "remove_background":
                result = await stability_service.remove_background(
                    image=image_bytes,
                    output_format=request.output_format,
                )
            elif operation == "inpaint":
                if not request.prompt:
                    raise ValueError("Prompt is required for inpainting")
                result = await stability_service.inpaint(
                    image=image_bytes,
                    prompt=request.prompt,
                    mask=mask_bytes,
                    negative_prompt=request.negative_prompt,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                    grow_mask=request.options.get("grow_mask", 5),
                )
            elif operation == "outpaint":
                left = request.expand_left or 0
                right = request.expand_right or 0
                up = request.expand_up or 0
                down = request.expand_down or 0
                # An outpaint with no expansion in any direction cannot change
                # the canvas; reject it here like the other missing-input
                # cases instead of spending a provider round trip on it.
                if not any((left, right, up, down)):
                    raise ValueError("At least one expansion direction is required for outpaint")
                result = await stability_service.outpaint(
                    image=image_bytes,
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    output_format=request.output_format,
                    left=left,
                    right=right,
                    up=up,
                    down=down,
                    style_preset=request.style_preset,
                )
            elif operation == "search_replace":
                if not (request.prompt and request.search_prompt):
                    raise ValueError("Both prompt and search_prompt are required for search & replace")
                result = await stability_service.search_and_replace(
                    image=image_bytes,
                    prompt=request.prompt,
                    search_prompt=request.search_prompt,
                    output_format=request.output_format,
                )
            elif operation == "search_recolor":
                if not (request.prompt and request.select_prompt):
                    raise ValueError("Both prompt and select_prompt are required for search & recolor")
                result = await stability_service.search_and_recolor(
                    image=image_bytes,
                    prompt=request.prompt,
                    select_prompt=request.select_prompt,
                    output_format=request.output_format,
                )
            elif operation == "relight":
                if not background_bytes and not lighting_bytes:
                    raise ValueError("At least one reference (background or lighting) is required for relight")
                result = await stability_service.replace_background_and_relight(
                    subject_image=image_bytes,
                    background_reference=background_bytes,
                    light_reference=lighting_bytes,
                    output_format=request.output_format,
                )
                # A dict carrying an "id" is treated as a handle to a
                # still-running generation: poll until it resolves.
                if isinstance(result, dict) and result.get("id"):
                    result = await self._poll_stability_result(
                        stability_service,
                        generation_id=result["id"],
                        output_format=request.output_format,
                    )
            else:
                raise ValueError(f"Unsupported Stability operation: {operation}")

        return self._extract_image_bytes(result)

    async def _handle_general_edit(
        self,
        request: EditStudioRequest,
        image_bytes: bytes,
        mask_bytes: Optional[bytes],
        user_id: Optional[str],
    ) -> bytes:
        """Execute Hugging Face powered general editing (synchronous API).

        ``mask_bytes`` is accepted for signature parity with the Stability
        handler but is not used by this path.

        Raises:
            ValueError: If the request has no prompt.
        """
        if not request.prompt:
            raise ValueError("Prompt is required for general edits")

        options = {
            "provider": request.provider or "huggingface",
            "model": request.model,
            "guidance_scale": request.guidance_scale,
            "steps": request.steps,
            "seed": request.seed,
        }

        # huggingface edit is synchronous - run in thread
        result = await asyncio.to_thread(
            huggingface_edit_image,
            image_bytes,
            request.prompt,
            options,
            user_id,
        )
        return result.image_bytes

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Normalize Stability responses into raw image bytes.

        Accepts raw ``bytes`` or a dict whose ``artifacts``/``data``/``images``
        list contains an item with a ``base64`` or ``b64_json`` payload.

        Raises:
            RuntimeError: If no image payload can be found.
        """
        if isinstance(result, bytes):
            return result

        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if isinstance(artifact, dict):
                    if artifact.get("base64"):
                        return base64.b64decode(artifact["base64"])
                    if artifact.get("b64_json"):
                        return base64.b64decode(artifact["b64_json"])

        raise RuntimeError("Unable to extract image bytes from provider response")

    async def _poll_stability_result(
        self,
        stability_service: StabilityAIService,
        generation_id: str,
        output_format: str,
        timeout_seconds: int = 240,
        interval_seconds: float = 2.0,
    ) -> bytes:
        """Poll Stability async endpoint until result is ready.

        The timeout is enforced against the event loop's monotonic clock, so
        time spent waiting on each poll request counts toward it (not only
        the sleeps between polls), and a zero interval cannot loop forever.

        Args:
            stability_service: Open service used to fetch the result.
            generation_id: Identifier returned by the async submission.
            output_format: Requested output format (currently unused here).
            timeout_seconds: Overall time budget for polling.
            interval_seconds: Pause between consecutive polls.

        Raises:
            RuntimeError: If the generation fails or the budget is exhausted.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout_seconds

        while loop.time() < deadline:
            result = await stability_service.get_generation_result(
                generation_id=generation_id,
                accept_type="*/*",
            )

            if isinstance(result, bytes):
                return result

            if isinstance(result, dict):
                state = (result.get("state") or result.get("status") or "").lower()
                if state in {"succeeded", "success", "ready", "completed"}:
                    return self._extract_image_bytes(result)
                if state in {"failed", "error"}:
                    raise RuntimeError(f"Stability generation failed: {result}")

            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            # Never sleep past the deadline.
            await asyncio.sleep(max(0.0, min(interval_seconds, remaining)))

        raise RuntimeError("Timed out waiting for Stability generation result")

View File

@@ -0,0 +1,304 @@
"""Image Studio Manager - Main orchestration service for all image operations."""
from typing import Optional, Dict, Any, List
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .templates import Platform, TemplateCategory, ImageTemplate
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.manager")
class ImageStudioManager:
    """Main manager for Image Studio operations.

    Thin facade over the create, edit and upscale services, plus static
    provider, cost and platform reference data.
    """

    # Per-image base cost by provider and model (adjust based on actual
    # pricing). "default" is used for models not listed under a provider.
    _BASE_COSTS: Dict[str, Dict[str, float]] = {
        "stability": {
            "ultra": 0.08,          # 8 credits
            "core": 0.03,           # 3 credits
            "sd3": 0.065,           # 6.5 credits (legacy key, kept for callers)
            "sd3.5-large": 0.065,   # same model under the name get_providers() advertises
        },
        "wavespeed": {
            "ideogram-v3-turbo": 0.10,
            "qwen-image": 0.05,
        },
        "huggingface": {
            "default": 0.0,  # Free tier
        },
        "gemini": {
            "default": 0.0,  # Free tier
        },
    }

    def __init__(self):
        """Initialize Image Studio Manager."""
        self.create_service = CreateStudioService()
        self.edit_service = EditStudioService()
        self.upscale_service = UpscaleStudioService()
        logger.info("[Image Studio Manager] Initialized successfully")

    # ====================
    # CREATE STUDIO
    # ====================

    async def create_image(
        self,
        request: CreateStudioRequest,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Create/generate image using Create Studio.

        Args:
            request: Create studio request
            user_id: User ID, forwarded unchanged to the create service

        Returns:
            Dictionary with generation results
        """
        logger.info("[Image Studio] Create image request from user: %s", user_id)
        return await self.create_service.generate(request, user_id=user_id)

    # ====================
    # EDIT STUDIO
    # ====================

    async def edit_image(
        self,
        request: EditStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run Edit Studio operations."""
        logger.info("[Image Studio] Edit image request from user: %s", user_id)
        return await self.edit_service.process_edit(request, user_id=user_id)

    def get_edit_operations(self) -> Dict[str, Any]:
        """Expose edit operations for UI."""
        return self.edit_service.list_operations()

    # ====================
    # UPSCALE STUDIO
    # ====================

    async def upscale_image(
        self,
        request: UpscaleStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run Upscale Studio operations."""
        logger.info("[Image Studio] Upscale request from user: %s", user_id)
        return await self.upscale_service.process_upscale(request, user_id=user_id)

    def get_templates(
        self,
        platform: Optional[Platform] = None,
        category: Optional[TemplateCategory] = None
    ) -> List[ImageTemplate]:
        """Get available templates.

        Args:
            platform: Filter by platform
            category: Filter by category

        Returns:
            List of templates
        """
        return self.create_service.get_templates(platform=platform, category=category)

    def search_templates(self, query: str) -> List[ImageTemplate]:
        """Search templates by query.

        Args:
            query: Search query

        Returns:
            List of matching templates
        """
        return self.create_service.search_templates(query)

    def recommend_templates(
        self,
        use_case: str,
        platform: Optional[Platform] = None
    ) -> List[ImageTemplate]:
        """Recommend templates based on use case.

        Args:
            use_case: Use case description
            platform: Optional platform filter

        Returns:
            List of recommended templates
        """
        return self.create_service.recommend_templates(use_case, platform)

    def get_providers(self) -> Dict[str, Any]:
        """Get available image providers and their capabilities.

        Returns:
            Dictionary of providers with capabilities. A fresh dictionary is
            built on every call, so callers may modify the result.
        """
        return {
            "stability": {
                "name": "Stability AI",
                "models": ["ultra", "core", "sd3.5-large"],
                "capabilities": ["text-to-image", "editing", "upscaling", "control", "3d"],
                "max_resolution": (2048, 2048),
                "cost_range": "3-8 credits per image",
            },
            "wavespeed": {
                "name": "WaveSpeed AI",
                "models": ["ideogram-v3-turbo", "qwen-image"],
                "capabilities": ["text-to-image", "photorealistic", "fast-generation"],
                "max_resolution": (1024, 1024),
                "cost_range": "$0.05-$0.10 per image",
            },
            "huggingface": {
                "name": "HuggingFace",
                "models": ["FLUX.1-Krea-dev", "RunwayML"],
                "capabilities": ["text-to-image", "image-to-image"],
                "max_resolution": (1024, 1024),
                "cost_range": "Free tier available",
            },
            "gemini": {
                "name": "Google Gemini",
                "models": ["imagen-3.0"],
                "capabilities": ["text-to-image", "conversational-editing"],
                "max_resolution": (1024, 1024),
                "cost_range": "Free tier available",
            }
        }

    # ====================
    # COST ESTIMATION
    # ====================

    def estimate_cost(
        self,
        provider: str,
        model: Optional[str],
        operation: str,
        num_images: int = 1,
        resolution: Optional[tuple[int, int]] = None
    ) -> Dict[str, Any]:
        """Estimate cost for image operations.

        The estimate is ``cost_per_image * num_images``; ``operation`` and
        ``resolution`` are echoed back but do not influence the price.

        Args:
            provider: Provider name
            model: Model name
            operation: Operation type (generate, edit, upscale, etc.)
            num_images: Number of images
            resolution: Image resolution (width, height)

        Returns:
            Cost estimation details. Unknown providers, and unknown models
            of a provider without a "default" price, are estimated at 0.0.
        """
        provider_costs = self._BASE_COSTS.get(provider, {})
        cost_per_image = provider_costs.get(model, provider_costs.get("default", 0.0))

        # Round away binary floating-point noise (e.g. 0.1 * 3).
        total_cost = round(cost_per_image * num_images, 6)

        return {
            "provider": provider,
            "model": model,
            "operation": operation,
            "num_images": num_images,
            "resolution": f"{resolution[0]}x{resolution[1]}" if resolution else "default",
            "cost_per_image": cost_per_image,
            "total_cost": total_cost,
            "currency": "USD",
            "estimated": True,
        }

    # ====================
    # PLATFORM SPECS
    # ====================

    def get_platform_specs(self, platform: Platform) -> Dict[str, Any]:
        """Get platform specifications and requirements.

        Args:
            platform: Platform to get specs for

        Returns:
            Platform specifications, or an empty dict for a platform that has
            no entry in the table below (e.g. BLOG, EMAIL, WEBSITE).
        """
        # NOTE(review): a few "ratio" labels do not match their "size"
        # (Facebook Cover 820x312, LinkedIn Article 1200x627, LinkedIn Company
        # Cover 1128x191). Left untouched because they mirror the AspectRatio
        # constants in the templates module; fix both places together.
        specs = {
            Platform.INSTAGRAM: {
                "name": "Instagram",
                "formats": [
                    {"name": "Feed Post (Square)", "ratio": "1:1", "size": "1080x1080"},
                    {"name": "Feed Post (Portrait)", "ratio": "4:5", "size": "1080x1350"},
                    {"name": "Story", "ratio": "9:16", "size": "1080x1920"},
                    {"name": "Reel", "ratio": "9:16", "size": "1080x1920"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "30MB",
            },
            Platform.FACEBOOK: {
                "name": "Facebook",
                "formats": [
                    {"name": "Feed Post", "ratio": "1.91:1", "size": "1200x630"},
                    {"name": "Feed Post (Square)", "ratio": "1:1", "size": "1080x1080"},
                    {"name": "Story", "ratio": "9:16", "size": "1080x1920"},
                    {"name": "Cover Photo", "ratio": "16:9", "size": "820x312"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "30MB",
            },
            Platform.TWITTER: {
                "name": "Twitter/X",
                "formats": [
                    {"name": "Post", "ratio": "16:9", "size": "1200x675"},
                    {"name": "Card", "ratio": "2:1", "size": "1200x600"},
                    {"name": "Header", "ratio": "3:1", "size": "1500x500"},
                ],
                "file_types": ["JPG", "PNG", "GIF"],
                "max_file_size": "5MB",
            },
            Platform.LINKEDIN: {
                "name": "LinkedIn",
                "formats": [
                    {"name": "Feed Post", "ratio": "1.91:1", "size": "1200x628"},
                    {"name": "Feed Post (Square)", "ratio": "1:1", "size": "1080x1080"},
                    {"name": "Article", "ratio": "2:1", "size": "1200x627"},
                    {"name": "Company Cover", "ratio": "4:1", "size": "1128x191"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "8MB",
            },
            Platform.YOUTUBE: {
                "name": "YouTube",
                "formats": [
                    {"name": "Thumbnail", "ratio": "16:9", "size": "1280x720"},
                    {"name": "Channel Art", "ratio": "16:9", "size": "2560x1440"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "2MB",
            },
            Platform.PINTEREST: {
                "name": "Pinterest",
                "formats": [
                    {"name": "Pin", "ratio": "2:3", "size": "1000x1500"},
                    {"name": "Story Pin", "ratio": "9:16", "size": "1080x1920"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "20MB",
            },
            Platform.TIKTOK: {
                "name": "TikTok",
                "formats": [
                    {"name": "Video Cover", "ratio": "9:16", "size": "1080x1920"},
                ],
                "file_types": ["JPG", "PNG"],
                "max_file_size": "10MB",
            },
        }
        return specs.get(platform, {})

View File

@@ -0,0 +1,555 @@
"""Template system for Image Studio with platform-specific presets."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Literal
from enum import Enum
class Platform(str, Enum):
    """Supported social media platforms.

    Members are also ``str`` instances, so each compares equal to its
    lowercase value. Besides social networks the set includes generic
    publishing targets (blog, email, website).
    """

    # Social networks
    INSTAGRAM = "instagram"
    FACEBOOK = "facebook"
    TWITTER = "twitter"
    LINKEDIN = "linkedin"
    YOUTUBE = "youtube"
    PINTEREST = "pinterest"
    TIKTOK = "tiktok"
    # Generic publishing targets
    BLOG = "blog"
    EMAIL = "email"
    WEBSITE = "website"
class TemplateCategory(str, Enum):
    """Template categories.

    Members are also ``str`` instances, so each compares equal to its
    snake_case value.
    """

    SOCIAL_MEDIA = "social_media"
    BLOG_CONTENT = "blog_content"
    AD_CREATIVE = "ad_creative"
    PRODUCT = "product"
    BRAND_ASSETS = "brand_assets"
    EMAIL_MARKETING = "email_marketing"
@dataclass
class AspectRatio:
    """Aspect ratio configuration.

    Pairs a display ratio string with concrete dimensions. Nothing here
    checks that ``ratio`` actually equals ``width / height``.
    """

    ratio: str  # e.g., "1:1", "16:9"
    width: int  # NOTE(review): presumably pixels — confirm
    height: int  # NOTE(review): presumably pixels — confirm
    label: str  # e.g., "Square", "Widescreen"
@dataclass
class ImageTemplate:
    """Image generation template.

    Bundles the preset choices (dimensions, provider, style, quality) that
    describe one kind of image for one platform.
    """

    id: str  # Stable identifier, e.g. "instagram_feed_square"
    name: str  # Human-readable name
    category: TemplateCategory
    platform: Optional[Platform]  # None is allowed for platform-independent templates
    aspect_ratio: AspectRatio  # Target dimensions and ratio label
    description: str
    recommended_provider: str  # Provider name suggested for this template
    style_preset: str  # e.g. "photographic", "digital-art", "cinematic"
    quality: Literal["draft", "standard", "premium"]
    prompt_template: Optional[str] = None
    negative_prompt_template: Optional[str] = None
    # May be None when not supplied; check before iterating.
    use_cases: Optional[List[str]] = None
class PlatformTemplates:
"""Platform-specific template definitions."""
# Aspect Ratios
SQUARE_1_1 = AspectRatio("1:1", 1080, 1080, "Square")
PORTRAIT_4_5 = AspectRatio("4:5", 1080, 1350, "Portrait")
STORY_9_16 = AspectRatio("9:16", 1080, 1920, "Story/Reel")
LANDSCAPE_16_9 = AspectRatio("16:9", 1920, 1080, "Landscape")
WIDE_21_9 = AspectRatio("21:9", 2560, 1080, "Ultra Wide")
TWITTER_2_1 = AspectRatio("2:1", 1200, 600, "Twitter Card")
TWITTER_3_1 = AspectRatio("3:1", 1500, 500, "Twitter Header")
FACEBOOK_1_91_1 = AspectRatio("1.91:1", 1200, 630, "Facebook Feed")
LINKEDIN_1_91_1 = AspectRatio("1.91:1", 1200, 628, "LinkedIn Feed")
LINKEDIN_2_1 = AspectRatio("2:1", 1200, 627, "LinkedIn Article")
LINKEDIN_4_1 = AspectRatio("4:1", 1128, 191, "LinkedIn Cover")
PINTEREST_2_3 = AspectRatio("2:3", 1000, 1500, "Pinterest Pin")
YOUTUBE_16_9 = AspectRatio("16:9", 1280, 720, "YouTube Thumbnail")
FACEBOOK_COVER_16_9 = AspectRatio("16:9", 820, 312, "Facebook Cover")
@classmethod
def get_platform_templates(cls) -> Dict[Platform, List[ImageTemplate]]:
"""Get all platform-specific templates."""
return {
Platform.INSTAGRAM: cls._instagram_templates(),
Platform.FACEBOOK: cls._facebook_templates(),
Platform.TWITTER: cls._twitter_templates(),
Platform.LINKEDIN: cls._linkedin_templates(),
Platform.YOUTUBE: cls._youtube_templates(),
Platform.PINTEREST: cls._pinterest_templates(),
Platform.TIKTOK: cls._tiktok_templates(),
Platform.BLOG: cls._blog_templates(),
Platform.EMAIL: cls._email_templates(),
Platform.WEBSITE: cls._website_templates(),
}
@classmethod
def _instagram_templates(cls) -> List[ImageTemplate]:
"""Instagram templates."""
return [
ImageTemplate(
id="instagram_feed_square",
name="Instagram Feed Post (Square)",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.INSTAGRAM,
aspect_ratio=cls.SQUARE_1_1,
description="Perfect for Instagram feed posts with maximum visibility",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Product showcase", "Lifestyle posts", "Brand content"]
),
ImageTemplate(
id="instagram_feed_portrait",
name="Instagram Feed Post (Portrait)",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.INSTAGRAM,
aspect_ratio=cls.PORTRAIT_4_5,
description="Vertical format for maximum feed real estate",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Fashion", "Food", "Product photography"]
),
ImageTemplate(
id="instagram_story",
name="Instagram Story",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.INSTAGRAM,
aspect_ratio=cls.STORY_9_16,
description="Full-screen vertical stories",
recommended_provider="ideogram",
style_preset="digital-art",
quality="standard",
use_cases=["Behind-the-scenes", "Announcements", "Quick updates"]
),
ImageTemplate(
id="instagram_reel_cover",
name="Instagram Reel Cover",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.INSTAGRAM,
aspect_ratio=cls.STORY_9_16,
description="Eye-catching reel cover images",
recommended_provider="ideogram",
style_preset="cinematic",
quality="premium",
use_cases=["Video covers", "Thumbnails", "Highlights"]
),
]
@classmethod
def _facebook_templates(cls) -> List[ImageTemplate]:
"""Facebook templates."""
return [
ImageTemplate(
id="facebook_feed",
name="Facebook Feed Post",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.FACEBOOK,
aspect_ratio=cls.FACEBOOK_1_91_1,
description="Optimized for Facebook news feed",
recommended_provider="ideogram",
style_preset="photographic",
quality="standard",
use_cases=["Page posts", "Shared content", "Community posts"]
),
ImageTemplate(
id="facebook_feed_square",
name="Facebook Feed Post (Square)",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.FACEBOOK,
aspect_ratio=cls.SQUARE_1_1,
description="Square format for feed posts",
recommended_provider="ideogram",
style_preset="photographic",
quality="standard",
use_cases=["Page posts", "Product highlights"]
),
ImageTemplate(
id="facebook_story",
name="Facebook Story",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.FACEBOOK,
aspect_ratio=cls.STORY_9_16,
description="Full-screen vertical stories",
recommended_provider="ideogram",
style_preset="digital-art",
quality="standard",
use_cases=["Quick updates", "Promotions", "Events"]
),
ImageTemplate(
id="facebook_cover",
name="Facebook Cover Photo",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.FACEBOOK,
aspect_ratio=cls.FACEBOOK_COVER_16_9,
description="Wide cover photo for pages",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Page branding", "Events", "Seasonal updates"]
),
]
@classmethod
def _twitter_templates(cls) -> List[ImageTemplate]:
"""Twitter/X templates."""
return [
ImageTemplate(
id="twitter_post",
name="Twitter/X Post",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.TWITTER,
aspect_ratio=cls.LANDSCAPE_16_9,
description="Optimized for Twitter feed",
recommended_provider="ideogram",
style_preset="photographic",
quality="standard",
use_cases=["Tweets", "News", "Updates"]
),
ImageTemplate(
id="twitter_card",
name="Twitter Card",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.TWITTER,
aspect_ratio=cls.TWITTER_2_1,
description="Twitter card with link preview",
recommended_provider="ideogram",
style_preset="digital-art",
quality="standard",
use_cases=["Link sharing", "Articles", "Blog posts"]
),
ImageTemplate(
id="twitter_header",
name="Twitter Header",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.TWITTER,
aspect_ratio=cls.TWITTER_3_1,
description="Profile header image",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Profile branding", "Personal brand", "Business identity"]
),
]
@classmethod
def _linkedin_templates(cls) -> List[ImageTemplate]:
"""LinkedIn templates."""
return [
ImageTemplate(
id="linkedin_post",
name="LinkedIn Post",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.LINKEDIN,
aspect_ratio=cls.LINKEDIN_1_91_1,
description="Professional feed posts",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Professional content", "Industry news", "Thought leadership"]
),
ImageTemplate(
id="linkedin_post_square",
name="LinkedIn Post (Square)",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.LINKEDIN,
aspect_ratio=cls.SQUARE_1_1,
description="Square format for LinkedIn feed",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Quick tips", "Infographics", "Quotes"]
),
ImageTemplate(
id="linkedin_article",
name="LinkedIn Article Header",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.LINKEDIN,
aspect_ratio=cls.LINKEDIN_2_1,
description="Article header images",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Long-form content", "Articles", "Newsletters"]
),
ImageTemplate(
id="linkedin_cover",
name="LinkedIn Company Cover",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.LINKEDIN,
aspect_ratio=cls.LINKEDIN_4_1,
description="Company page cover photo",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Company branding", "Recruitment", "Brand identity"]
),
]
@classmethod
def _youtube_templates(cls) -> List[ImageTemplate]:
"""YouTube templates."""
return [
ImageTemplate(
id="youtube_thumbnail",
name="YouTube Thumbnail",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.YOUTUBE,
aspect_ratio=cls.YOUTUBE_16_9,
description="Eye-catching video thumbnails",
recommended_provider="ideogram",
style_preset="cinematic",
quality="premium",
use_cases=["Video thumbnails", "Channel branding", "Playlists"]
),
ImageTemplate(
id="youtube_channel_art",
name="YouTube Channel Art",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.YOUTUBE,
aspect_ratio=cls.LANDSCAPE_16_9,
description="Channel banner art",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Channel branding", "Personal brand", "Business identity"]
),
]
@classmethod
def _pinterest_templates(cls) -> List[ImageTemplate]:
"""Pinterest templates."""
return [
ImageTemplate(
id="pinterest_pin",
name="Pinterest Pin",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.PINTEREST,
aspect_ratio=cls.PINTEREST_2_3,
description="Vertical pin format",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Product pins", "DIY guides", "Recipes", "Inspiration"]
),
ImageTemplate(
id="pinterest_story",
name="Pinterest Story Pin",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.PINTEREST,
aspect_ratio=cls.STORY_9_16,
description="Full-screen story pins",
recommended_provider="ideogram",
style_preset="digital-art",
quality="standard",
use_cases=["Step-by-step guides", "Tutorials", "Quick tips"]
),
]
@classmethod
def _tiktok_templates(cls) -> List[ImageTemplate]:
"""TikTok templates."""
return [
ImageTemplate(
id="tiktok_video_cover",
name="TikTok Video Cover",
category=TemplateCategory.SOCIAL_MEDIA,
platform=Platform.TIKTOK,
aspect_ratio=cls.STORY_9_16,
description="Vertical video cover",
recommended_provider="ideogram",
style_preset="cinematic",
quality="premium",
use_cases=["Video covers", "Thumbnails", "Profile highlights"]
),
]
@classmethod
def _blog_templates(cls) -> List[ImageTemplate]:
"""Blog content templates."""
return [
ImageTemplate(
id="blog_header",
name="Blog Header",
category=TemplateCategory.BLOG_CONTENT,
platform=Platform.BLOG,
aspect_ratio=cls.LANDSCAPE_16_9,
description="Blog post featured image",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Featured images", "Article headers", "Post thumbnails"]
),
ImageTemplate(
id="blog_header_wide",
name="Blog Header (Wide)",
category=TemplateCategory.BLOG_CONTENT,
platform=Platform.BLOG,
aspect_ratio=cls.WIDE_21_9,
description="Ultra-wide blog header",
recommended_provider="ideogram",
style_preset="photographic",
quality="premium",
use_cases=["Hero sections", "Wide headers", "Landing pages"]
),
]
@classmethod
def _email_templates(cls) -> List[ImageTemplate]:
    """Return the email marketing templates: a header banner and a product image."""
    # Settings that are identical for every email template.
    shared = {
        "category": TemplateCategory.EMAIL_MARKETING,
        "platform": Platform.EMAIL,
        "recommended_provider": "ideogram",
        "style_preset": "photographic",
    }

    banner = ImageTemplate(
        id="email_banner",
        name="Email Banner",
        aspect_ratio=cls.LANDSCAPE_16_9,
        description="Email header banner",
        quality="standard",
        use_cases=["Email headers", "Newsletter banners", "Promotions"],
        **shared,
    )
    product = ImageTemplate(
        id="email_product",
        name="Email Product Image",
        aspect_ratio=cls.SQUARE_1_1,
        description="Product showcase for emails",
        quality="premium",
        use_cases=["Product highlights", "Promotions", "Offers"],
        **shared,
    )
    return [banner, product]
@classmethod
def _website_templates(cls) -> List[ImageTemplate]:
    """Return the website templates: a hero image and a section banner."""
    # Both website templates share everything except identity, ratio and use cases.
    shared = {
        "category": TemplateCategory.BRAND_ASSETS,
        "platform": Platform.WEBSITE,
        "recommended_provider": "ideogram",
        "style_preset": "photographic",
        "quality": "premium",
    }

    hero = ImageTemplate(
        id="website_hero",
        name="Website Hero Image",
        aspect_ratio=cls.WIDE_21_9,
        description="Hero section background",
        use_cases=["Hero sections", "Landing pages", "Home page banners"],
        **shared,
    )
    banner = ImageTemplate(
        id="website_banner",
        name="Website Banner",
        aspect_ratio=cls.LANDSCAPE_16_9,
        description="Section banners",
        use_cases=["Section headers", "Category pages", "Feature sections"],
        **shared,
    )
    return [hero, banner]
class TemplateManager:
    """Look up, search and recommend image templates across all platforms."""

    def __init__(self):
        """Load the per-platform template mapping; the flat list is built lazily."""
        self.templates = PlatformTemplates.get_platform_templates()
        self._all_templates: Optional[List[ImageTemplate]] = None

    def get_all_templates(self) -> List[ImageTemplate]:
        """Return every template from every platform as one flat, cached list."""
        if self._all_templates is not None:
            return self._all_templates

        flattened: List[ImageTemplate] = []
        self._all_templates = flattened
        for group in self.templates.values():
            flattened.extend(group)
        return self._all_templates

    def get_by_platform(self, platform: Platform) -> List[ImageTemplate]:
        """Return the templates registered for ``platform`` (empty list if none)."""
        return self.templates.get(platform, [])

    def get_by_category(self, category: TemplateCategory) -> List[ImageTemplate]:
        """Return every template whose category equals ``category``."""
        selected = []
        for candidate in self.get_all_templates():
            if candidate.category == category:
                selected.append(candidate)
        return selected

    def get_by_id(self, template_id: str) -> Optional[ImageTemplate]:
        """Return the first template whose id equals ``template_id``, else ``None``."""
        return next(
            (candidate for candidate in self.get_all_templates() if candidate.id == template_id),
            None,
        )

    def search(self, query: str) -> List[ImageTemplate]:
        """Return templates whose name, description or use cases contain ``query``.

        Matching is a case-insensitive substring test.
        """
        needle = query.lower()

        def haystack(candidate) -> str:
            # Name, description and use cases folded into one lowercase string.
            pieces = (
                candidate.name.lower(),
                candidate.description.lower(),
                " ".join(candidate.use_cases or []).lower(),
            )
            return " ".join(pieces)

        return [
            candidate
            for candidate in self.get_all_templates()
            if needle in haystack(candidate)
        ]

    def recommend_for_use_case(self, use_case: str, platform: Optional[Platform] = None) -> List[ImageTemplate]:
        """Return templates having a use case that contains ``use_case``.

        Matching is case-insensitive. When ``platform`` is given, only that
        platform's templates are considered. Each template appears at most once.
        """
        wanted = use_case.lower()
        candidates = self.get_all_templates()

        # Narrow to a single platform when one was requested.
        if platform:
            candidates = [c for c in candidates if c.platform == platform]

        return [
            c
            for c in candidates
            if c.use_cases and any(wanted in case.lower() for case in c.use_cases)
        ]

    def get_aspect_ratio_options(self) -> List[AspectRatio]:
        """Return the aspect-ratio constants defined on ``PlatformTemplates``."""
        ratio_names = (
            "SQUARE_1_1",
            "PORTRAIT_4_5",
            "STORY_9_16",
            "LANDSCAPE_16_9",
            "WIDE_21_9",
            "TWITTER_2_1",
            "TWITTER_3_1",
            "FACEBOOK_1_91_1",
            "LINKEDIN_1_91_1",
            "LINKEDIN_2_1",
            "LINKEDIN_4_1",
            "PINTEREST_2_3",
            "YOUTUBE_16_9",
            "FACEBOOK_COVER_16_9",
        )
        return [getattr(PlatformTemplates, ratio_name) for ratio_name in ratio_names]

View File

@@ -0,0 +1,154 @@
import base64
import io
from dataclasses import dataclass
from typing import Literal, Optional, Dict, Any
from fastapi import HTTPException
from PIL import Image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.upscale")
# Accepted upscale modes; "auto" is resolved to "conservative" or "fast" by
# UpscaleStudioService._resolve_mode based on the requested target size.
UpscaleMode = Literal["fast", "conservative", "creative", "auto"]


@dataclass
class UpscaleStudioRequest:
    """Input for a single upscale operation handled by UpscaleStudioService."""

    # Source image as base64; a leading "data:...," prefix is stripped before decoding.
    image_base64: str
    # Requested upscale mode; "auto" lets the service choose.
    mode: UpscaleMode = "auto"
    # Optional target dimensions; forwarded to the upscale call only when set.
    target_width: Optional[int] = None
    target_height: Optional[int] = None
    # Only echoed back in the response metadata by process_upscale.
    preset: Optional[str] = None  # e.g., web/print/social
    prompt: Optional[str] = None  # used for conservative/creative modes
class UpscaleStudioService:
    """Upscale images through StabilityAIService in fast, conservative or creative mode."""

    def __init__(self):
        logger.info("[Upscale Studio] Service initialized")

    async def process_upscale(
        self,
        request: UpscaleStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Upscale the request image and return it as a base64 data URL with its size.

        Args:
            request: Source image plus mode, optional target size, preset and prompt.
            user_id: When given, subscription pre-flight validation runs first and
                whatever it raises propagates to the caller.

        Returns:
            Dict with ``success``, the resolved ``mode``, ``image_base64``,
            ``width``, ``height`` and a ``metadata`` echo of the request options.

        Raises:
            ValueError: Image missing or not valid base64, or unsupported mode.
            HTTPException: 502 when no image can be extracted from the provider result.
        """
        if user_id:
            # Imported lazily, only when validation is actually required.
            from services.database import get_db
            from services.subscription import PricingService
            from services.subscription.preflight_validator import validate_image_upscale_operations

            session = next(get_db())
            try:
                pricing = PricingService(session)
                logger.info("[Upscale Studio] 🛂 Running pre-flight validation for user %s", user_id)
                validate_image_upscale_operations(pricing_service=pricing, user_id=user_id)
            finally:
                session.close()

        source_bytes = self._decode_base64(request.image_base64)
        if not source_bytes:
            raise ValueError("Primary image is required for upscaling")

        mode = self._resolve_mode(request)

        async with StabilityAIService() as stability_service:
            logger.info("[Upscale Studio] Running '%s' upscale for user=%s", mode, user_id)

            # Forward only the target dimensions the caller actually supplied.
            size_kwargs = {}
            if request.target_width is not None:
                size_kwargs["target_width"] = request.target_width
            if request.target_height is not None:
                size_kwargs["target_height"] = request.target_height

            if mode == "fast":
                result = await stability_service.upscale_fast(
                    image=source_bytes,
                    **size_kwargs,
                )
            elif mode == "conservative":
                result = await stability_service.upscale_conservative(
                    image=source_bytes,
                    prompt=request.prompt or "High fidelity upscale preserving original details",
                    **size_kwargs,
                )
            elif mode == "creative":
                result = await stability_service.upscale_creative(
                    image=source_bytes,
                    prompt=request.prompt or "Creative upscale with enhanced artistic details",
                    **size_kwargs,
                )
            else:
                raise ValueError(f"Unsupported upscale mode: {mode}")

        upscaled_bytes = self._extract_image_bytes(result)
        dimensions = self._image_metadata(upscaled_bytes)

        request_echo = {
            "preset": request.preset,
            "target_width": request.target_width,
            "target_height": request.target_height,
            "prompt": request.prompt,
        }
        return {
            "success": True,
            "mode": mode,
            "image_base64": self._to_base64(upscaled_bytes),
            "width": dimensions["width"],
            "height": dimensions["height"],
            "metadata": request_echo,
        }

    @staticmethod
    def _decode_base64(value: Optional[str]) -> Optional[bytes]:
        """Decode a base64 string (optionally a ``data:`` URL); ``None`` for empty input.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None
        try:
            encoded = value
            if value.startswith("data:"):
                # Drop the "data:<mime>;base64" header and keep what follows the comma.
                _header, encoded = value.split(",", 1)
            return base64.b64decode(encoded)
        except Exception as exc:
            logger.error("[Upscale Studio] Failed to decode base64 image: %s", exc)
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _to_base64(image_bytes: bytes) -> str:
        """Encode image bytes as a ``data:image/png;base64,...`` URL."""
        encoded = base64.b64encode(image_bytes).decode('utf-8')
        return "data:image/png;base64," + encoded

    @staticmethod
    def _image_metadata(image_bytes: bytes) -> Dict[str, int]:
        """Open the image with PIL and report its pixel width and height."""
        with Image.open(io.BytesIO(image_bytes)) as opened:
            pixel_width = opened.width
            pixel_height = opened.height
            return {"width": pixel_width, "height": pixel_height}

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Pull raw image bytes out of a provider result.

        Accepts raw ``bytes``, or a dict whose ``artifacts``/``data``/``images``
        list holds dicts carrying a ``base64`` or ``b64_json`` payload.

        Raises:
            HTTPException: 502 when no image payload is found.
        """
        if isinstance(result, bytes):
            return result

        if isinstance(result, dict):
            candidates = result.get("artifacts") or result.get("data") or result.get("images") or []
            for entry in candidates:
                if not isinstance(entry, dict):
                    continue
                for field_name in ("base64", "b64_json"):
                    encoded = entry.get(field_name)
                    if encoded:
                        return base64.b64decode(encoded)

        raise HTTPException(status_code=502, detail="Unable to extract image from provider response")

    @staticmethod
    def _resolve_mode(request: UpscaleStudioRequest) -> UpscaleMode:
        """Return the explicit mode, or choose one for ``auto`` from the target size."""
        if request.mode != "auto":
            return request.mode

        # simple heuristic: if target >= 3000px, use conservative, else fast
        is_large_target = any(
            dimension and dimension >= 3000
            for dimension in (request.target_width, request.target_height)
        )
        return "conservative" if is_large_target else "fast"

View File

@@ -2,6 +2,7 @@ from .base import ImageGenerationOptions, ImageGenerationResult, ImageGeneration
from .hf_provider import HuggingFaceImageProvider
from .gemini_provider import GeminiImageProvider
from .stability_provider import StabilityImageProvider
from .wavespeed_provider import WaveSpeedImageProvider
__all__ = [
"ImageGenerationOptions",
@@ -10,6 +11,7 @@ __all__ = [
"HuggingFaceImageProvider",
"GeminiImageProvider",
"StabilityImageProvider",
"WaveSpeedImageProvider",
]

View File

@@ -0,0 +1,243 @@
"""WaveSpeed AI image generation provider (Ideogram V3 Turbo & Qwen Image)."""
import io
import os
from typing import Optional
from PIL import Image
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("wavespeed.image_provider")
class WaveSpeedImageProvider(ImageGenerationProvider):
    """WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""

    # Model used whenever ``options.model`` is not set.
    _DEFAULT_MODEL = "ideogram-v3-turbo"

    SUPPORTED_MODELS = {
        "ideogram-v3-turbo": {
            "name": "Ideogram V3 Turbo",
            "description": "Photorealistic generation with superior text rendering",
            "cost_per_image": 0.10,  # Estimated, adjust based on actual pricing
            "max_resolution": (1024, 1024),
            "default_steps": 20,
        },
        "qwen-image": {
            "name": "Qwen Image",
            "description": "Fast, high-quality text-to-image generation",
            "cost_per_image": 0.05,  # Estimated, adjust based on actual pricing
            "max_resolution": (1024, 1024),
            "default_steps": 15,
        }
    }

    def __init__(self, api_key: Optional[str] = None):
        """Initialize WaveSpeed image provider.

        Args:
            api_key: WaveSpeed API key (falls back to the ``WAVESPEED_API_KEY``
                environment variable if not provided)

        Raises:
            ValueError: If no API key is available from either source
        """
        self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
        if not self.api_key:
            raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")

        self.client = WaveSpeedClient(api_key=self.api_key)
        logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
                    list(self.SUPPORTED_MODELS.keys()))

    def _validate_options(self, options: ImageGenerationOptions) -> None:
        """Validate generation options.

        Checks, in order: the model is supported, the requested resolution is
        within the model's maximum, and the prompt is not blank.

        Args:
            options: Image generation options

        Raises:
            ValueError: If options are invalid
        """
        model = options.model or self._DEFAULT_MODEL

        if model not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model: {model}. "
                f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
            )

        max_width, max_height = self.SUPPORTED_MODELS[model]["max_resolution"]
        if options.width > max_width or options.height > max_height:
            raise ValueError(
                f"Resolution {options.width}x{options.height} exceeds maximum "
                f"{max_width}x{max_height} for model {model}"
            )

        if not options.prompt or not options.prompt.strip():
            raise ValueError("Prompt cannot be empty")

    def _generate_with_model(
        self,
        model: str,
        label: str,
        options: ImageGenerationOptions,
    ) -> bytes:
        """Run one generation request against ``model`` and return the image bytes.

        Shared implementation behind the per-model generators.

        Args:
            model: Key in ``SUPPORTED_MODELS`` to send to the WaveSpeed client
            label: Human-readable model label used in log lines and error text
            options: Image generation options

        Returns:
            Image bytes

        Raises:
            RuntimeError: If the client call fails or returns an unexpected
                payload; the original exception is attached as ``__cause__``
        """
        logger.info("[%s] Starting image generation: %s", label, options.prompt[:100])

        try:
            # Note: Adjust these based on actual WaveSpeed API documentation
            params = {
                "model": model,
                "prompt": options.prompt,
                "width": options.width,
                "height": options.height,
                "num_inference_steps": options.steps or self.SUPPORTED_MODELS[model]["default_steps"],
            }

            # Add optional parameters
            if options.negative_prompt:
                params["negative_prompt"] = options.negative_prompt
            # Compare against None: 0 is a legitimate seed / guidance value and
            # must not be dropped by a truthiness test.
            if options.guidance_scale is not None:
                params["guidance_scale"] = options.guidance_scale
            if options.seed is not None:
                params["seed"] = options.seed

            result = self.client.generate_image(**params)

            # Accept raw bytes or a dict carrying them under "image".
            if isinstance(result, bytes):
                image_bytes = result
            elif isinstance(result, dict) and "image" in result:
                image_bytes = result["image"]
            else:
                raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")

            logger.info("[%s] ✅ Successfully generated image: %d bytes", label, len(image_bytes))
            return image_bytes

        except Exception as e:
            logger.error("[%s] ❌ Error generating image: %s", label, str(e), exc_info=True)
            raise RuntimeError(f"{label} generation failed: {str(e)}") from e

    def _generate_ideogram_v3(self, options: ImageGenerationOptions) -> bytes:
        """Generate image using Ideogram V3 Turbo.

        Args:
            options: Image generation options

        Returns:
            Image bytes
        """
        return self._generate_with_model("ideogram-v3-turbo", "Ideogram V3", options)

    def _generate_qwen_image(self, options: ImageGenerationOptions) -> bytes:
        """Generate image using Qwen Image.

        Args:
            options: Image generation options

        Returns:
            Image bytes
        """
        return self._generate_with_model("qwen-image", "Qwen Image", options)

    def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
        """Generate image using WaveSpeed AI models.

        Args:
            options: Image generation options

        Returns:
            ImageGenerationResult with generated image

        Raises:
            ValueError: If options are invalid
            RuntimeError: If generation fails
        """
        self._validate_options(options)

        model = options.model or self._DEFAULT_MODEL

        generators = {
            "ideogram-v3-turbo": self._generate_ideogram_v3,
            "qwen-image": self._generate_qwen_image,
        }
        generator = generators.get(model)
        if generator is None:
            raise ValueError(f"Unsupported model: {model}")
        image_bytes = generator(options)

        # Read the real dimensions from the returned image; the context
        # manager releases the PIL handle once they are known.
        with Image.open(io.BytesIO(image_bytes)) as image:
            width, height = image.size

        model_info = self.SUPPORTED_MODELS[model]
        estimated_cost = model_info["cost_per_image"]

        return ImageGenerationResult(
            image_bytes=image_bytes,
            width=width,
            height=height,
            provider="wavespeed",
            model=model,
            seed=options.seed,
            metadata={
                "provider": "wavespeed",
                "model": model,
                "model_name": model_info["name"],
                "prompt": options.prompt,
                "negative_prompt": options.negative_prompt,
                "steps": options.steps or model_info["default_steps"],
                "guidance_scale": options.guidance_scale,
                "estimated_cost": estimated_cost,
            }
        )

    @classmethod
    def get_available_models(cls) -> dict:
        """Get available models and their information.

        Returns:
            Dictionary of available models (the class-level ``SUPPORTED_MODELS``)
        """
        return cls.SUPPORTED_MODELS

View File

@@ -240,20 +240,23 @@ def validate_exa_research_operations(
def validate_image_generation_operations(
pricing_service: PricingService,
user_id: str
user_id: str,
num_images: int = 1
) -> None:
"""
Validate image generation operation before making API calls.
Validate image generation operation(s) before making API calls.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
num_images: Number of images to generate (for multiple variations)
Returns:
(can_proceed, error_message, error_details)
If can_proceed is False, raises HTTPException with 429 status
None
If validation fails, raises HTTPException with 429 status
"""
try:
# Create validation operations for each image
operations_to_validate = [
{
'provider': APIProvider.STABILITY,
@@ -261,8 +264,11 @@ def validate_image_generation_operations(
'actual_provider_name': 'stability',
'operation_type': 'image_generation'
}
for _ in range(num_images)
]
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image generation(s) for user {user_id}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
@@ -289,6 +295,54 @@ def validate_image_generation_operations(
except HTTPException:
raise
def validate_image_upscale_operations(
pricing_service: PricingService,
user_id: str,
num_images: int = 1
) -> None:
"""
Validate image upscaling before making API calls.
"""
try:
operations_to_validate = [
{
'provider': APIProvider.STABILITY,
'tokens_requested': 0,
'actual_provider_name': 'stability',
'operation_type': 'image_upscale'
}
for _ in range(num_images)
]
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image upscale request(s) for user {user_id}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Image upscale blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Image upscale validated for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating image generation: {e}", exc_info=True)
raise HTTPException(

View File

@@ -312,6 +312,175 @@ class WaveSpeedClient:
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
return optimized_prompt
def generate_image(
    self,
    model: str,
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    negative_prompt: Optional[str] = None,
    seed: Optional[int] = None,
    enable_sync_mode: bool = True,
    timeout: int = 120,
    **kwargs
) -> bytes:
    """
    Generate image using WaveSpeed AI models (Ideogram V3 or Qwen Image).

    Submits the request, obtains the output list (directly in sync mode, or
    by polling the prediction in async mode), then downloads the first
    output URL.

    Args:
        model: Model to use ("ideogram-v3-turbo" or "qwen-image")
        prompt: Text prompt for image generation
        width: Image width (default: 1024)
        height: Image height (default: 1024)
        num_inference_steps: Number of inference steps
        guidance_scale: Guidance scale for generation
        negative_prompt: Negative prompt (what to avoid)
        seed: Random seed for reproducibility
        enable_sync_mode: If True, wait for result and return it directly (default: True)
        timeout: Request timeout in seconds (default: 120); applied to the
            submit request and to the image download
        **kwargs: Additional parameters; added to the payload unless the key
            is already present

    Returns:
        bytes: Generated image bytes

    Raises:
        ValueError: If ``model`` is not a supported image model
        HTTPException: 502 if the API call fails, returns no usable output,
            or the generated image cannot be downloaded
    """
    # Map model names to WaveSpeed API paths
    model_paths = {
        "ideogram-v3-turbo": "ideogram-ai/ideogram-v3-turbo",
        "qwen-image": "wavespeed-ai/qwen-image/text-to-image",
    }

    model_path = model_paths.get(model)
    if not model_path:
        raise ValueError(f"Unsupported image model: {model}. Supported: {list(model_paths.keys())}")

    url = f"{self.BASE_URL}/{model_path}"

    payload = {
        "prompt": prompt,
        "width": width,
        "height": height,
        "enable_sync_mode": enable_sync_mode,
    }

    # Add optional parameters
    if num_inference_steps is not None:
        payload["num_inference_steps"] = num_inference_steps
    if guidance_scale is not None:
        payload["guidance_scale"] = guidance_scale
    if negative_prompt:
        payload["negative_prompt"] = negative_prompt
    if seed is not None:
        payload["seed"] = seed

    # Extra parameters never override the explicit ones set above.
    for key, value in kwargs.items():
        payload.setdefault(key, value)

    logger.info(f"[WaveSpeed] Generating image via {url} (model={model}, prompt_length={len(prompt)})")
    response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)

    if response.status_code != 200:
        logger.error(f"[WaveSpeed] Image generation failed: {response.status_code} {response.text}")
        raise HTTPException(
            status_code=502,
            detail={
                "error": "WaveSpeed image generation failed",
                "status_code": response.status_code,
                "response": response.text,
            },
        )

    response_json = response.json()
    data = response_json.get("data") or response_json

    if enable_sync_mode:
        # Sync mode - result should be directly in outputs
        outputs = data.get("outputs") or []
        if not outputs:
            logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed image generator returned no outputs",
            )
    else:
        # Async mode - poll for result
        prediction_id = data.get("id")
        if not prediction_id:
            logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed response missing prediction id for async mode",
            )

        result = self.poll_until_complete(prediction_id, timeout_seconds=240, interval_seconds=1.0)
        outputs = result.get("outputs") or []
        if not outputs:
            raise HTTPException(status_code=502, detail="WaveSpeed image generator returned no outputs")

    # From here on both modes share the same path: take the first output,
    # which is either a URL string or a dict holding it under "url"/"output".
    image_url = None
    if isinstance(outputs, list) and outputs:
        first_output = outputs[0]
        if isinstance(first_output, str):
            image_url = first_output
        elif isinstance(first_output, dict):
            image_url = first_output.get("url") or first_output.get("output")

    # The isinstance guard turns a non-string value into the intended 502
    # instead of an AttributeError on .startswith().
    if not isinstance(image_url, str) or not image_url.startswith(("http://", "https://")):
        logger.error(f"[WaveSpeed] Invalid image URL in outputs: {outputs}")
        raise HTTPException(
            status_code=502,
            detail="WaveSpeed image generator output format not recognized",
        )

    # Fetch image bytes from URL
    logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}")
    image_response = requests.get(image_url, timeout=timeout)
    if image_response.status_code != 200:
        logger.error(f"[WaveSpeed] Failed to fetch image from URL: {image_response.status_code}")
        raise HTTPException(
            status_code=502,
            detail="Failed to fetch generated image from WaveSpeed URL",
        )

    image_bytes = image_response.content
    logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} bytes)")
    return image_bytes
def generate_speech(
self,
text: str,