AI Image Studio Phase 1
This commit is contained in:
@@ -52,6 +52,7 @@ from routers.linkedin import router as linkedin_router
|
||||
from api.linkedin_image_generation import router as linkedin_image_router
|
||||
from api.brainstorm import router as brainstorm_router
|
||||
from api.images import router as images_router
|
||||
from routers.image_studio import router as image_studio_router
|
||||
|
||||
# Import hallucination detector router
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
@@ -296,6 +297,7 @@ async def batch_analyze_urls_endpoint(urls: list[str]):
|
||||
from routers.platform_analytics import router as platform_analytics_router
|
||||
app.include_router(platform_analytics_router)
|
||||
app.include_router(images_router)
|
||||
app.include_router(image_studio_router)
|
||||
|
||||
# Include research configuration router
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
|
||||
593
backend/routers/image_studio.py
Normal file
593
backend/routers/image_studio.py
Normal file
@@ -0,0 +1,593 @@
|
||||
"""API endpoints for Image Studio operations."""
|
||||
|
||||
import base64
|
||||
from typing import Optional, List, Dict, Any, Literal
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.image_studio import (
|
||||
ImageStudioManager,
|
||||
CreateStudioRequest,
|
||||
EditStudioRequest,
|
||||
)
|
||||
from services.image_studio.upscale_service import UpscaleStudioRequest
|
||||
from services.image_studio.templates import Platform, TemplateCategory
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(prefix="/api/image-studio", tags=["image-studio"])
|
||||
|
||||
|
||||
# ====================
|
||||
# REQUEST MODELS
|
||||
# ====================
|
||||
|
||||
class CreateImageRequest(BaseModel):
|
||||
"""Request model for image generation."""
|
||||
prompt: str = Field(..., description="Image generation prompt")
|
||||
template_id: Optional[str] = Field(None, description="Template ID to use")
|
||||
provider: Optional[str] = Field("auto", description="Provider: auto, stability, wavespeed, huggingface, gemini")
|
||||
model: Optional[str] = Field(None, description="Specific model to use")
|
||||
width: Optional[int] = Field(None, description="Image width in pixels")
|
||||
height: Optional[int] = Field(None, description="Image height in pixels")
|
||||
aspect_ratio: Optional[str] = Field(None, description="Aspect ratio (e.g., '1:1', '16:9')")
|
||||
style_preset: Optional[str] = Field(None, description="Style preset")
|
||||
quality: str = Field("standard", description="Quality: draft, standard, premium")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt")
|
||||
guidance_scale: Optional[float] = Field(None, description="Guidance scale")
|
||||
steps: Optional[int] = Field(None, description="Number of inference steps")
|
||||
seed: Optional[int] = Field(None, description="Random seed")
|
||||
num_variations: int = Field(1, ge=1, le=10, description="Number of variations (1-10)")
|
||||
enhance_prompt: bool = Field(True, description="Enhance prompt with AI")
|
||||
use_persona: bool = Field(False, description="Use persona for brand consistency")
|
||||
persona_id: Optional[str] = Field(None, description="Persona ID")
|
||||
|
||||
|
||||
class CostEstimationRequest(BaseModel):
|
||||
"""Request model for cost estimation."""
|
||||
provider: str = Field(..., description="Provider name")
|
||||
model: Optional[str] = Field(None, description="Model name")
|
||||
operation: str = Field("generate", description="Operation type")
|
||||
num_images: int = Field(1, ge=1, description="Number of images")
|
||||
width: Optional[int] = Field(None, description="Image width")
|
||||
height: Optional[int] = Field(None, description="Image height")
|
||||
|
||||
|
||||
class EditImageRequest(BaseModel):
|
||||
"""Request payload for Edit Studio."""
|
||||
|
||||
image_base64: str = Field(..., description="Primary image payload (base64 or data URL)")
|
||||
operation: Literal[
|
||||
"remove_background",
|
||||
"inpaint",
|
||||
"outpaint",
|
||||
"search_replace",
|
||||
"search_recolor",
|
||||
"general_edit",
|
||||
] = Field(..., description="Edit operation to perform")
|
||||
prompt: Optional[str] = Field(None, description="Primary prompt/instruction")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt for providers that support it")
|
||||
mask_base64: Optional[str] = Field(None, description="Optional mask image in base64")
|
||||
search_prompt: Optional[str] = Field(None, description="Search prompt for replace operations")
|
||||
select_prompt: Optional[str] = Field(None, description="Select prompt for recolor operations")
|
||||
background_image_base64: Optional[str] = Field(None, description="Reference background image")
|
||||
lighting_image_base64: Optional[str] = Field(None, description="Reference lighting image")
|
||||
expand_left: Optional[int] = Field(0, description="Outpaint expansion in pixels (left)")
|
||||
expand_right: Optional[int] = Field(0, description="Outpaint expansion in pixels (right)")
|
||||
expand_up: Optional[int] = Field(0, description="Outpaint expansion in pixels (up)")
|
||||
expand_down: Optional[int] = Field(0, description="Outpaint expansion in pixels (down)")
|
||||
provider: Optional[str] = Field(None, description="Explicit provider override")
|
||||
model: Optional[str] = Field(None, description="Explicit model override")
|
||||
style_preset: Optional[str] = Field(None, description="Style preset for Stability helpers")
|
||||
guidance_scale: Optional[float] = Field(None, description="Guidance scale for general edits")
|
||||
steps: Optional[int] = Field(None, description="Inference steps")
|
||||
seed: Optional[int] = Field(None, description="Random seed for reproducibility")
|
||||
output_format: str = Field("png", description="Output format for edited image")
|
||||
options: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Advanced provider-specific options (e.g., grow_mask)",
|
||||
)
|
||||
|
||||
|
||||
class EditImageResponse(BaseModel):
|
||||
success: bool
|
||||
operation: str
|
||||
provider: str
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class EditOperationsResponse(BaseModel):
|
||||
operations: Dict[str, Dict[str, Any]]
|
||||
|
||||
|
||||
class UpscaleImageRequest(BaseModel):
|
||||
image_base64: str
|
||||
mode: Literal["fast", "conservative", "creative", "auto"] = "auto"
|
||||
target_width: Optional[int] = Field(None, description="Target width in pixels")
|
||||
target_height: Optional[int] = Field(None, description="Target height in pixels")
|
||||
preset: Optional[str] = Field(None, description="Named preset (web, print, social)")
|
||||
prompt: Optional[str] = Field(None, description="Prompt for conservative/creative modes")
|
||||
|
||||
|
||||
class UpscaleImageResponse(BaseModel):
|
||||
success: bool
|
||||
mode: str
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
# ====================
|
||||
# DEPENDENCY
|
||||
# ====================
|
||||
|
||||
def get_studio_manager() -> ImageStudioManager:
|
||||
"""Get Image Studio Manager instance."""
|
||||
return ImageStudioManager()
|
||||
|
||||
|
||||
def _require_user_id(current_user: Dict[str, Any], operation: str) -> str:
|
||||
"""Ensure user_id is available for protected operations."""
|
||||
user_id = current_user.get("sub") or current_user.get("user_id")
|
||||
if not user_id:
|
||||
logger.error(
|
||||
"[Image Studio] ❌ Missing user_id for %s operation - blocking request",
|
||||
operation,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authenticated user required for image operations.",
|
||||
)
|
||||
return user_id
|
||||
|
||||
|
||||
# ====================
|
||||
# CREATE STUDIO ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.post("/create", summary="Generate Image")
|
||||
async def create_image(
|
||||
request: CreateImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Generate image(s) using Create Studio.
|
||||
|
||||
This endpoint supports:
|
||||
- Multiple AI providers (Stability AI, WaveSpeed, HuggingFace, Gemini)
|
||||
- Template-based generation
|
||||
- Custom dimensions and aspect ratios
|
||||
- Style presets and quality levels
|
||||
- Multiple variations
|
||||
- Prompt enhancement
|
||||
|
||||
Returns:
|
||||
Dictionary with generation results including image data
|
||||
"""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image generation")
|
||||
logger.info(f"[Create Image] Request from user {user_id}: {request.prompt[:100]}")
|
||||
|
||||
# Convert request to CreateStudioRequest
|
||||
studio_request = CreateStudioRequest(
|
||||
prompt=request.prompt,
|
||||
template_id=request.template_id,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
width=request.width,
|
||||
height=request.height,
|
||||
aspect_ratio=request.aspect_ratio,
|
||||
style_preset=request.style_preset,
|
||||
quality=request.quality,
|
||||
negative_prompt=request.negative_prompt,
|
||||
guidance_scale=request.guidance_scale,
|
||||
steps=request.steps,
|
||||
seed=request.seed,
|
||||
num_variations=request.num_variations,
|
||||
enhance_prompt=request.enhance_prompt,
|
||||
use_persona=request.use_persona,
|
||||
persona_id=request.persona_id,
|
||||
)
|
||||
|
||||
# Generate images
|
||||
result = await studio_manager.create_image(studio_request, user_id=user_id)
|
||||
|
||||
# Convert image bytes to base64 for JSON response
|
||||
for idx, img_result in enumerate(result["results"]):
|
||||
if "image_bytes" in img_result:
|
||||
img_result["image_base64"] = base64.b64encode(img_result["image_bytes"]).decode("utf-8")
|
||||
# Remove bytes from response
|
||||
del img_result["image_bytes"]
|
||||
|
||||
logger.info(f"[Create Image] ✅ Success: {result['total_generated']} images generated")
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Create Image] ❌ Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
logger.error(f"[Create Image] ❌ Generation error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Create Image] ❌ Unexpected error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
# ====================
|
||||
# TEMPLATE ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.get("/templates", summary="Get Templates")
|
||||
async def get_templates(
|
||||
platform: Optional[Platform] = None,
|
||||
category: Optional[TemplateCategory] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Get available image templates.
|
||||
|
||||
Templates provide pre-configured settings for common use cases:
|
||||
- Platform-specific dimensions and formats
|
||||
- Recommended providers and models
|
||||
- Style presets and quality settings
|
||||
|
||||
Args:
|
||||
platform: Filter by platform (instagram, facebook, twitter, etc.)
|
||||
category: Filter by category (social_media, blog_content, ad_creative, etc.)
|
||||
|
||||
Returns:
|
||||
List of templates
|
||||
"""
|
||||
try:
|
||||
templates = studio_manager.get_templates(platform=platform, category=category)
|
||||
|
||||
# Convert to dict for JSON response
|
||||
templates_dict = [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"category": t.category.value,
|
||||
"platform": t.platform.value if t.platform else None,
|
||||
"aspect_ratio": {
|
||||
"ratio": t.aspect_ratio.ratio,
|
||||
"width": t.aspect_ratio.width,
|
||||
"height": t.aspect_ratio.height,
|
||||
"label": t.aspect_ratio.label,
|
||||
},
|
||||
"description": t.description,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"style_preset": t.style_preset,
|
||||
"quality": t.quality,
|
||||
"use_cases": t.use_cases or [],
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
|
||||
return {"templates": templates_dict, "total": len(templates_dict)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Get Templates] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/templates/search", summary="Search Templates")
|
||||
async def search_templates(
|
||||
query: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Search templates by query.
|
||||
|
||||
Searches in template names, descriptions, and use cases.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
List of matching templates
|
||||
"""
|
||||
try:
|
||||
templates = studio_manager.search_templates(query)
|
||||
|
||||
templates_dict = [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"category": t.category.value,
|
||||
"platform": t.platform.value if t.platform else None,
|
||||
"aspect_ratio": {
|
||||
"ratio": t.aspect_ratio.ratio,
|
||||
"width": t.aspect_ratio.width,
|
||||
"height": t.aspect_ratio.height,
|
||||
"label": t.aspect_ratio.label,
|
||||
},
|
||||
"description": t.description,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"style_preset": t.style_preset,
|
||||
"quality": t.quality,
|
||||
"use_cases": t.use_cases or [],
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
|
||||
return {"templates": templates_dict, "total": len(templates_dict), "query": query}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Search Templates] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/templates/recommend", summary="Recommend Templates")
|
||||
async def recommend_templates(
|
||||
use_case: str,
|
||||
platform: Optional[Platform] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Recommend templates based on use case.
|
||||
|
||||
Args:
|
||||
use_case: Description of use case (e.g., "product showcase", "blog header")
|
||||
platform: Optional platform filter
|
||||
|
||||
Returns:
|
||||
List of recommended templates
|
||||
"""
|
||||
try:
|
||||
templates = studio_manager.recommend_templates(use_case, platform=platform)
|
||||
|
||||
templates_dict = [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"category": t.category.value,
|
||||
"platform": t.platform.value if t.platform else None,
|
||||
"aspect_ratio": {
|
||||
"ratio": t.aspect_ratio.ratio,
|
||||
"width": t.aspect_ratio.width,
|
||||
"height": t.aspect_ratio.height,
|
||||
"label": t.aspect_ratio.label,
|
||||
},
|
||||
"description": t.description,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"style_preset": t.style_preset,
|
||||
"quality": t.quality,
|
||||
"use_cases": t.use_cases or [],
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
|
||||
return {"templates": templates_dict, "total": len(templates_dict), "use_case": use_case}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Recommend Templates] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# PROVIDER ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.get("/providers", summary="Get Providers")
|
||||
async def get_providers(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Get available AI providers and their capabilities.
|
||||
|
||||
Returns information about:
|
||||
- Available models
|
||||
- Capabilities
|
||||
- Maximum resolution
|
||||
- Cost estimates
|
||||
|
||||
Returns:
|
||||
Dictionary of providers
|
||||
"""
|
||||
try:
|
||||
providers = studio_manager.get_providers()
|
||||
return {"providers": providers}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Get Providers] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# COST ESTIMATION ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.post("/estimate-cost", summary="Estimate Cost")
|
||||
async def estimate_cost(
|
||||
request: CostEstimationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Estimate cost for image generation operations.
|
||||
|
||||
Provides cost estimates before generation to help users make informed decisions.
|
||||
|
||||
Args:
|
||||
request: Cost estimation request
|
||||
|
||||
Returns:
|
||||
Cost estimation details
|
||||
"""
|
||||
try:
|
||||
resolution = None
|
||||
if request.width and request.height:
|
||||
resolution = (request.width, request.height)
|
||||
|
||||
estimate = studio_manager.estimate_cost(
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
operation=request.operation,
|
||||
num_images=request.num_images,
|
||||
resolution=resolution
|
||||
)
|
||||
|
||||
return estimate
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Estimate Cost] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# EDIT STUDIO ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.post("/edit/process", response_model=EditImageResponse, summary="Process Edit Studio request")
|
||||
async def process_edit_image(
|
||||
request: EditImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Perform Edit Studio operations such as remove background, inpaint, or recolor."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image editing")
|
||||
logger.info(f"[Edit Image] Request from user {user_id}: operation={request.operation}")
|
||||
|
||||
edit_request = EditStudioRequest(
|
||||
image_base64=request.image_base64,
|
||||
operation=request.operation,
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
mask_base64=request.mask_base64,
|
||||
search_prompt=request.search_prompt,
|
||||
select_prompt=request.select_prompt,
|
||||
background_image_base64=request.background_image_base64,
|
||||
lighting_image_base64=request.lighting_image_base64,
|
||||
expand_left=request.expand_left,
|
||||
expand_right=request.expand_right,
|
||||
expand_up=request.expand_up,
|
||||
expand_down=request.expand_down,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
style_preset=request.style_preset,
|
||||
guidance_scale=request.guidance_scale,
|
||||
steps=request.steps,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
options=request.options or {},
|
||||
)
|
||||
|
||||
result = await studio_manager.edit_image(edit_request, user_id=user_id)
|
||||
return EditImageResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Edit Image] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image editing failed: {e}")
|
||||
|
||||
|
||||
@router.get("/edit/operations", response_model=EditOperationsResponse, summary="List Edit Studio operations")
|
||||
async def get_edit_operations(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Return metadata for supported Edit Studio operations."""
|
||||
try:
|
||||
operations = studio_manager.get_edit_operations()
|
||||
return EditOperationsResponse(operations=operations)
|
||||
except Exception as e:
|
||||
logger.error(f"[Edit Operations] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to load edit operations")
|
||||
|
||||
|
||||
# ====================
|
||||
# UPSCALE STUDIO ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.post("/upscale", response_model=UpscaleImageResponse, summary="Upscale Image")
|
||||
async def upscale_image(
|
||||
request: UpscaleImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Upscale an image using Stability AI pipelines."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image upscaling")
|
||||
upscale_request = UpscaleStudioRequest(
|
||||
image_base64=request.image_base64,
|
||||
mode=request.mode,
|
||||
target_width=request.target_width,
|
||||
target_height=request.target_height,
|
||||
preset=request.preset,
|
||||
prompt=request.prompt,
|
||||
)
|
||||
result = await studio_manager.upscale_image(upscale_request, user_id=user_id)
|
||||
return UpscaleImageResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Upscale Image] ❌ Error: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image upscaling failed: {e}")
|
||||
|
||||
|
||||
# ====================
|
||||
# PLATFORM SPECS ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.get("/platform-specs/{platform}", summary="Get Platform Specifications")
|
||||
async def get_platform_specs(
|
||||
platform: Platform,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Get specifications and requirements for a specific platform.
|
||||
|
||||
Returns:
|
||||
- Supported formats and dimensions
|
||||
- File type requirements
|
||||
- Maximum file size
|
||||
- Best practices
|
||||
|
||||
Args:
|
||||
platform: Platform name
|
||||
|
||||
Returns:
|
||||
Platform specifications
|
||||
"""
|
||||
try:
|
||||
specs = studio_manager.get_platform_specs(platform)
|
||||
if not specs:
|
||||
raise HTTPException(status_code=404, detail=f"Specifications not found for platform: {platform}")
|
||||
|
||||
return specs
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Get Platform Specs] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# HEALTH CHECK
|
||||
# ====================
|
||||
|
||||
@router.get("/health", summary="Health Check")
|
||||
async def health_check():
|
||||
"""Health check endpoint for Image Studio.
|
||||
|
||||
Returns:
|
||||
Health status
|
||||
"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "image_studio",
|
||||
"version": "1.0.0",
|
||||
"modules": {
|
||||
"create_studio": "available",
|
||||
"templates": "available",
|
||||
"providers": "available",
|
||||
}
|
||||
}
|
||||
|
||||
20
backend/services/image_studio/__init__.py
Normal file
20
backend/services/image_studio/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Image Studio service package for centralized image operations."""
|
||||
|
||||
from .studio_manager import ImageStudioManager
|
||||
from .create_service import CreateStudioService, CreateStudioRequest
|
||||
from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .templates import PlatformTemplates, TemplateManager
|
||||
|
||||
__all__ = [
|
||||
"ImageStudioManager",
|
||||
"CreateStudioService",
|
||||
"CreateStudioRequest",
|
||||
"EditStudioService",
|
||||
"EditStudioRequest",
|
||||
"UpscaleStudioService",
|
||||
"UpscaleStudioRequest",
|
||||
"PlatformTemplates",
|
||||
"TemplateManager",
|
||||
]
|
||||
|
||||
458
backend/services/image_studio/create_service.py
Normal file
458
backend/services/image_studio/create_service.py
Normal file
@@ -0,0 +1,458 @@
|
||||
"""Create Studio service for AI-powered image generation."""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any, List, Literal
|
||||
from dataclasses import dataclass
|
||||
|
||||
from services.llm_providers.image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
HuggingFaceImageProvider,
|
||||
GeminiImageProvider,
|
||||
StabilityImageProvider,
|
||||
WaveSpeedImageProvider,
|
||||
)
|
||||
from .templates import TemplateManager, ImageTemplate, Platform, TemplateCategory
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_studio.create")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateStudioRequest:
|
||||
"""Request for image generation in Create Studio."""
|
||||
prompt: str
|
||||
template_id: Optional[str] = None
|
||||
provider: Optional[str] = None # "auto", "stability", "wavespeed", "huggingface", "gemini"
|
||||
model: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
aspect_ratio: Optional[str] = None # e.g., "1:1", "16:9"
|
||||
style_preset: Optional[str] = None
|
||||
quality: Literal["draft", "standard", "premium"] = "standard"
|
||||
negative_prompt: Optional[str] = None
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
num_variations: int = 1
|
||||
enhance_prompt: bool = True
|
||||
use_persona: bool = False
|
||||
persona_id: Optional[str] = None
|
||||
|
||||
|
||||
class CreateStudioService:
|
||||
"""Service for Create Studio image generation operations."""
|
||||
|
||||
# Provider-to-model mapping for smart recommendations
|
||||
PROVIDER_MODELS = {
|
||||
"stability": {
|
||||
"ultra": "stability-ultra", # Best quality, 8 credits
|
||||
"core": "stability-core", # Fast & affordable, 3 credits
|
||||
"sd3": "sd3.5-large", # SD3.5 model
|
||||
},
|
||||
"wavespeed": {
|
||||
"ideogram-v3-turbo": "ideogram-v3-turbo", # Photorealistic, text rendering
|
||||
"qwen-image": "qwen-image", # Fast generation
|
||||
},
|
||||
"huggingface": {
|
||||
"flux": "black-forest-labs/FLUX.1-Krea-dev",
|
||||
},
|
||||
"gemini": {
|
||||
"imagen": "imagen-3.0-generate-001",
|
||||
}
|
||||
}
|
||||
|
||||
# Quality-to-provider mapping
|
||||
QUALITY_PROVIDERS = {
|
||||
"draft": ["huggingface", "wavespeed:qwen-image"], # Fast, low cost
|
||||
"standard": ["stability:core", "wavespeed:ideogram-v3-turbo"], # Balanced
|
||||
"premium": ["wavespeed:ideogram-v3-turbo", "stability:ultra"], # Best quality
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Create Studio service."""
|
||||
self.template_manager = TemplateManager()
|
||||
logger.info("[Create Studio] Initialized with template manager")
|
||||
|
||||
def _get_provider_instance(self, provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get provider instance by name.
|
||||
|
||||
Args:
|
||||
provider_name: Name of the provider
|
||||
api_key: Optional API key (uses env vars if not provided)
|
||||
|
||||
Returns:
|
||||
Provider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
if provider_name == "stability":
|
||||
return StabilityImageProvider(api_key=api_key or os.getenv("STABILITY_API_KEY"))
|
||||
elif provider_name == "wavespeed":
|
||||
return WaveSpeedImageProvider(api_key=api_key or os.getenv("WAVESPEED_API_KEY"))
|
||||
elif provider_name == "huggingface":
|
||||
return HuggingFaceImageProvider(api_token=api_key or os.getenv("HF_API_KEY"))
|
||||
elif provider_name == "gemini":
|
||||
return GeminiImageProvider(api_key=api_key or os.getenv("GEMINI_API_KEY"))
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider_name}")
|
||||
|
||||
def _select_provider_and_model(
|
||||
self,
|
||||
request: CreateStudioRequest,
|
||||
template: Optional[ImageTemplate] = None
|
||||
) -> tuple[str, Optional[str]]:
|
||||
"""Smart provider and model selection.
|
||||
|
||||
Args:
|
||||
request: Create studio request
|
||||
template: Optional template with recommendations
|
||||
|
||||
Returns:
|
||||
Tuple of (provider_name, model_name)
|
||||
"""
|
||||
# Explicit provider selection
|
||||
if request.provider and request.provider != "auto":
|
||||
provider = request.provider
|
||||
model = request.model
|
||||
logger.info("[Provider Selection] User specified: %s (model: %s)", provider, model)
|
||||
return provider, model
|
||||
|
||||
# Template recommendation
|
||||
if template and template.recommended_provider:
|
||||
provider = template.recommended_provider
|
||||
logger.info("[Provider Selection] Template recommends: %s", provider)
|
||||
|
||||
# Map provider to specific model if not specified
|
||||
if not request.model:
|
||||
if provider == "ideogram":
|
||||
return "wavespeed", "ideogram-v3-turbo"
|
||||
elif provider == "qwen":
|
||||
return "wavespeed", "qwen-image"
|
||||
elif provider == "stability":
|
||||
# Choose based on quality
|
||||
if request.quality == "premium":
|
||||
return "stability", "stability-ultra"
|
||||
elif request.quality == "draft":
|
||||
return "stability", "stability-core"
|
||||
else:
|
||||
return "stability", "stability-core"
|
||||
|
||||
return provider, request.model
|
||||
|
||||
# Quality-based selection
|
||||
quality_options = self.QUALITY_PROVIDERS.get(request.quality, self.QUALITY_PROVIDERS["standard"])
|
||||
selected = quality_options[0] # Pick first option
|
||||
|
||||
if ":" in selected:
|
||||
provider, model = selected.split(":", 1)
|
||||
else:
|
||||
provider = selected
|
||||
model = None
|
||||
|
||||
logger.info("[Provider Selection] Quality-based (%s): %s (model: %s)",
|
||||
request.quality, provider, model)
|
||||
return provider, model
|
||||
|
||||
def _enhance_prompt(self, prompt: str, style_preset: Optional[str] = None) -> str:
|
||||
"""Enhance prompt with style and quality descriptors.
|
||||
|
||||
Args:
|
||||
prompt: Original prompt
|
||||
style_preset: Style preset to apply
|
||||
|
||||
Returns:
|
||||
Enhanced prompt
|
||||
"""
|
||||
enhanced = prompt
|
||||
|
||||
# Add style-specific enhancements
|
||||
style_enhancements = {
|
||||
"photographic": ", professional photography, high quality, detailed, sharp focus, natural lighting",
|
||||
"digital-art": ", digital art, vibrant colors, detailed, high quality, artstation trending",
|
||||
"cinematic": ", cinematic lighting, dramatic, film grain, high quality, professional",
|
||||
"3d-model": ", 3D render, octane render, unreal engine, high quality, detailed",
|
||||
"anime": ", anime style, vibrant colors, detailed, high quality",
|
||||
"line-art": ", clean line art, detailed linework, high contrast, professional",
|
||||
}
|
||||
|
||||
if style_preset and style_preset in style_enhancements:
|
||||
enhanced += style_enhancements[style_preset]
|
||||
|
||||
logger.info("[Prompt Enhancement] Original: %s", prompt[:100])
|
||||
logger.info("[Prompt Enhancement] Enhanced: %s", enhanced[:100])
|
||||
|
||||
return enhanced
|
||||
|
||||
def _apply_template(self, request: CreateStudioRequest, template: ImageTemplate) -> CreateStudioRequest:
|
||||
"""Apply template settings to request.
|
||||
|
||||
Args:
|
||||
request: Original request
|
||||
template: Template to apply
|
||||
|
||||
Returns:
|
||||
Modified request
|
||||
"""
|
||||
# Apply template dimensions if not specified
|
||||
if not request.width and not request.height:
|
||||
request.width = template.aspect_ratio.width
|
||||
request.height = template.aspect_ratio.height
|
||||
|
||||
# Apply template style if not specified
|
||||
if not request.style_preset:
|
||||
request.style_preset = template.style_preset
|
||||
|
||||
# Apply template quality if not specified
|
||||
if request.quality == "standard":
|
||||
request.quality = template.quality
|
||||
|
||||
logger.info("[Template Applied] %s -> %dx%d, style=%s, quality=%s",
|
||||
template.name, request.width, request.height,
|
||||
request.style_preset, request.quality)
|
||||
|
||||
return request
|
||||
|
||||
def _calculate_dimensions(
|
||||
self,
|
||||
width: Optional[int],
|
||||
height: Optional[int],
|
||||
aspect_ratio: Optional[str]
|
||||
) -> tuple[int, int]:
|
||||
"""Calculate image dimensions from width/height or aspect ratio.
|
||||
|
||||
Args:
|
||||
width: Explicit width
|
||||
height: Explicit height
|
||||
aspect_ratio: Aspect ratio string (e.g., "16:9")
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
# Both dimensions specified
|
||||
if width and height:
|
||||
return width, height
|
||||
|
||||
# Aspect ratio specified
|
||||
if aspect_ratio:
|
||||
try:
|
||||
w_ratio, h_ratio = map(int, aspect_ratio.split(":"))
|
||||
|
||||
# Use width if specified
|
||||
if width:
|
||||
height = int(width * h_ratio / w_ratio)
|
||||
return width, height
|
||||
|
||||
# Use height if specified
|
||||
if height:
|
||||
width = int(height * w_ratio / h_ratio)
|
||||
return width, height
|
||||
|
||||
# Default size based on aspect ratio
|
||||
# Use 1080p as base
|
||||
if w_ratio >= h_ratio:
|
||||
# Landscape or square
|
||||
width = 1920
|
||||
height = int(1920 * h_ratio / w_ratio)
|
||||
else:
|
||||
# Portrait
|
||||
height = 1920
|
||||
width = int(1920 * w_ratio / h_ratio)
|
||||
|
||||
return width, height
|
||||
except ValueError:
|
||||
logger.warning("[Dimensions] Invalid aspect ratio: %s", aspect_ratio)
|
||||
|
||||
# Default dimensions
|
||||
return 1024, 1024
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
request: CreateStudioRequest,
|
||||
user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate image(s) using Create Studio.
|
||||
|
||||
Args:
|
||||
request: Create studio request
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
Dictionary with generation results
|
||||
|
||||
Raises:
|
||||
ValueError: If request is invalid
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
logger.info("[Create Studio] Starting generation: prompt=%s, template=%s",
|
||||
request.prompt[:100], request.template_id)
|
||||
|
||||
# Pre-flight validation: Check subscription and usage limits
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info(f"[Create Studio] 🛂 Running pre-flight validation for user {user_id}")
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=request.num_variations
|
||||
)
|
||||
logger.info(f"[Create Studio] ✅ Pre-flight validation passed - proceeding with generation")
|
||||
except HTTPException as http_ex:
|
||||
logger.error(f"[Create Studio] ❌ Pre-flight validation failed - blocking generation")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning("[Create Studio] ⚠️ No user_id provided - skipping pre-flight validation")
|
||||
|
||||
# Load template if specified
|
||||
template = None
|
||||
if request.template_id:
|
||||
template = self.template_manager.get_by_id(request.template_id)
|
||||
if not template:
|
||||
raise ValueError(f"Template not found: {request.template_id}")
|
||||
|
||||
# Apply template settings
|
||||
request = self._apply_template(request, template)
|
||||
|
||||
# Calculate dimensions
|
||||
width, height = self._calculate_dimensions(
|
||||
request.width, request.height, request.aspect_ratio
|
||||
)
|
||||
|
||||
# Enhance prompt if requested
|
||||
prompt = request.prompt
|
||||
if request.enhance_prompt:
|
||||
prompt = self._enhance_prompt(prompt, request.style_preset)
|
||||
|
||||
# Select provider and model
|
||||
provider_name, model = self._select_provider_and_model(request, template)
|
||||
|
||||
# Get provider instance
|
||||
try:
|
||||
provider = self._get_provider_instance(provider_name)
|
||||
except Exception as e:
|
||||
logger.error("[Create Studio] ❌ Failed to initialize provider %s: %s",
|
||||
provider_name, str(e))
|
||||
raise RuntimeError(f"Provider initialization failed: {str(e)}")
|
||||
|
||||
# Generate images
|
||||
results = []
|
||||
for i in range(request.num_variations):
|
||||
logger.info("[Create Studio] Generating variation %d/%d",
|
||||
i + 1, request.num_variations)
|
||||
|
||||
try:
|
||||
# Prepare options
|
||||
options = ImageGenerationOptions(
|
||||
prompt=prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
guidance_scale=request.guidance_scale,
|
||||
steps=request.steps,
|
||||
seed=request.seed + i if request.seed else None,
|
||||
model=model,
|
||||
extra={"style_preset": request.style_preset} if request.style_preset else {}
|
||||
)
|
||||
|
||||
# Generate image
|
||||
result: ImageGenerationResult = provider.generate(options)
|
||||
|
||||
results.append({
|
||||
"image_bytes": result.image_bytes,
|
||||
"width": result.width,
|
||||
"height": result.height,
|
||||
"provider": result.provider,
|
||||
"model": result.model,
|
||||
"seed": result.seed,
|
||||
"metadata": result.metadata,
|
||||
"variation": i + 1,
|
||||
})
|
||||
|
||||
logger.info("[Create Studio] ✅ Variation %d generated successfully", i + 1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Create Studio] ❌ Failed to generate variation %d: %s",
|
||||
i + 1, str(e), exc_info=True)
|
||||
results.append({
|
||||
"error": str(e),
|
||||
"variation": i + 1,
|
||||
})
|
||||
|
||||
# Return results
|
||||
return {
|
||||
"success": True,
|
||||
"request": {
|
||||
"prompt": request.prompt,
|
||||
"enhanced_prompt": prompt if request.enhance_prompt else None,
|
||||
"template_id": request.template_id,
|
||||
"template_name": template.name if template else None,
|
||||
"provider": provider_name,
|
||||
"model": model,
|
||||
"dimensions": f"{width}x{height}",
|
||||
"quality": request.quality,
|
||||
"num_variations": request.num_variations,
|
||||
},
|
||||
"results": results,
|
||||
"total_generated": sum(1 for r in results if "image_bytes" in r),
|
||||
"total_failed": sum(1 for r in results if "error" in r),
|
||||
}
|
||||
|
||||
def get_templates(
|
||||
self,
|
||||
platform: Optional[Platform] = None,
|
||||
category: Optional[TemplateCategory] = None
|
||||
) -> List[ImageTemplate]:
|
||||
"""Get available templates.
|
||||
|
||||
Args:
|
||||
platform: Filter by platform
|
||||
category: Filter by category
|
||||
|
||||
Returns:
|
||||
List of templates
|
||||
"""
|
||||
if platform:
|
||||
return self.template_manager.get_by_platform(platform)
|
||||
elif category:
|
||||
return self.template_manager.get_by_category(category)
|
||||
else:
|
||||
return self.template_manager.get_all_templates()
|
||||
|
||||
def search_templates(self, query: str) -> List[ImageTemplate]:
|
||||
"""Search templates by query.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
List of matching templates
|
||||
"""
|
||||
return self.template_manager.search(query)
|
||||
|
||||
def recommend_templates(
|
||||
self,
|
||||
use_case: str,
|
||||
platform: Optional[Platform] = None
|
||||
) -> List[ImageTemplate]:
|
||||
"""Recommend templates based on use case.
|
||||
|
||||
Args:
|
||||
use_case: Description of use case
|
||||
platform: Optional platform filter
|
||||
|
||||
Returns:
|
||||
List of recommended templates
|
||||
"""
|
||||
return self.template_manager.recommend_for_use_case(use_case, platform)
|
||||
|
||||
458
backend/services/image_studio/edit_service.py
Normal file
458
backend/services/image_studio/edit_service.py
Normal file
@@ -0,0 +1,458 @@
|
||||
"""Edit Studio service for AI-powered image editing and transformations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
|
||||
from services.stability_service import StabilityAIService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_studio.edit")
|
||||
|
||||
|
||||
EditOperationType = Literal[
|
||||
"remove_background",
|
||||
"inpaint",
|
||||
"outpaint",
|
||||
"search_replace",
|
||||
"search_recolor",
|
||||
"relight",
|
||||
"general_edit",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditStudioRequest:
|
||||
"""Normalized request payload for Edit Studio operations."""
|
||||
|
||||
image_base64: str
|
||||
operation: EditOperationType
|
||||
prompt: Optional[str] = None
|
||||
negative_prompt: Optional[str] = None
|
||||
mask_base64: Optional[str] = None
|
||||
search_prompt: Optional[str] = None
|
||||
select_prompt: Optional[str] = None
|
||||
background_image_base64: Optional[str] = None
|
||||
lighting_image_base64: Optional[str] = None
|
||||
expand_left: Optional[int] = None
|
||||
expand_right: Optional[int] = None
|
||||
expand_up: Optional[int] = None
|
||||
expand_down: Optional[int] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
style_preset: Optional[str] = None
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
output_format: str = "png"
|
||||
options: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class EditStudioService:
|
||||
"""Service layer orchestrating Edit Studio operations."""
|
||||
|
||||
SUPPORTED_OPERATIONS: Dict[EditOperationType, Dict[str, Any]] = {
|
||||
"remove_background": {
|
||||
"label": "Remove Background",
|
||||
"description": "Isolate the main subject and remove the background.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": False,
|
||||
"mask": False,
|
||||
"negative_prompt": False,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"inpaint": {
|
||||
"label": "Inpaint & Fix",
|
||||
"description": "Edit specific regions using prompts and optional masks.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": True,
|
||||
"negative_prompt": True,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"outpaint": {
|
||||
"label": "Outpaint",
|
||||
"description": "Extend the canvas in any direction with smart fill.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": False,
|
||||
"mask": False,
|
||||
"negative_prompt": True,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": True,
|
||||
},
|
||||
},
|
||||
"search_replace": {
|
||||
"label": "Search & Replace",
|
||||
"description": "Locate objects via search prompt and replace them.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": False,
|
||||
"negative_prompt": False,
|
||||
"search_prompt": True,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"search_recolor": {
|
||||
"label": "Search & Recolor",
|
||||
"description": "Select elements via prompt and recolor them.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": False,
|
||||
"negative_prompt": False,
|
||||
"search_prompt": False,
|
||||
"select_prompt": True,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"relight": {
|
||||
"label": "Replace Background & Relight",
|
||||
"description": "Swap backgrounds and relight using reference images.",
|
||||
"provider": "stability",
|
||||
"async": True,
|
||||
"fields": {
|
||||
"prompt": False,
|
||||
"mask": False,
|
||||
"negative_prompt": False,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": True,
|
||||
"lighting": True,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
"general_edit": {
|
||||
"label": "Prompt-based Edit",
|
||||
"description": "Free-form editing powered by Hugging Face image-to-image models.",
|
||||
"provider": "huggingface",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": False,
|
||||
"negative_prompt": True,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
"background": False,
|
||||
"lighting": False,
|
||||
"expansion": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[Edit Studio] Initialized edit service")
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
|
||||
"""Decode a base64 (or data URL) string to bytes."""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Handle data URLs (data:image/png;base64,...)
|
||||
if value.startswith("data:"):
|
||||
_, b64data = value.split(",", 1)
|
||||
else:
|
||||
b64data = value
|
||||
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Edit Studio] Failed to decode base64 image: {exc}")
|
||||
raise ValueError("Invalid base64 image payload") from exc
|
||||
|
||||
@staticmethod
|
||||
def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
|
||||
"""Extract width/height metadata from image bytes."""
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
return {
|
||||
"width": img.width,
|
||||
"height": img.height,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
|
||||
"""Convert raw bytes to base64 data URL."""
|
||||
b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return f"data:image/{output_format};base64,{b64}"
|
||||
|
||||
def list_operations(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Expose supported operations for UI rendering."""
|
||||
return self.SUPPORTED_OPERATIONS
|
||||
|
||||
async def process_edit(
|
||||
self,
|
||||
request: EditStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process edit request and return normalized response."""
|
||||
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_editing_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info(f"[Edit Studio] 🛂 Running pre-flight validation for user {user_id}")
|
||||
validate_image_editing_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
)
|
||||
logger.info("[Edit Studio] ✅ Pre-flight validation passed")
|
||||
except HTTPException:
|
||||
logger.error("[Edit Studio] ❌ Pre-flight validation failed")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning("[Edit Studio] ⚠️ No user_id provided - skipping pre-flight validation")
|
||||
|
||||
image_bytes = self._decode_base64_image(request.image_base64)
|
||||
if not image_bytes:
|
||||
raise ValueError("Primary image payload is required")
|
||||
|
||||
mask_bytes = self._decode_base64_image(request.mask_base64)
|
||||
background_bytes = self._decode_base64_image(request.background_image_base64)
|
||||
lighting_bytes = self._decode_base64_image(request.lighting_image_base64)
|
||||
|
||||
operation = request.operation
|
||||
logger.info("[Edit Studio] Processing operation='%s' for user=%s", operation, user_id)
|
||||
|
||||
if operation not in self.SUPPORTED_OPERATIONS:
|
||||
raise ValueError(f"Unsupported edit operation: {operation}")
|
||||
|
||||
if operation in {"remove_background", "inpaint", "outpaint", "search_replace", "search_recolor", "relight"}:
|
||||
image_bytes = await self._handle_stability_edit(
|
||||
operation=operation,
|
||||
request=request,
|
||||
image_bytes=image_bytes,
|
||||
mask_bytes=mask_bytes,
|
||||
background_bytes=background_bytes,
|
||||
lighting_bytes=lighting_bytes,
|
||||
)
|
||||
else:
|
||||
image_bytes = await self._handle_general_edit(
|
||||
request=request,
|
||||
image_bytes=image_bytes,
|
||||
mask_bytes=mask_bytes,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
metadata = self._image_bytes_to_metadata(image_bytes)
|
||||
metadata.update(
|
||||
{
|
||||
"operation": operation,
|
||||
"style_preset": request.style_preset,
|
||||
"provider": self.SUPPORTED_OPERATIONS[operation]["provider"],
|
||||
}
|
||||
)
|
||||
|
||||
response = {
|
||||
"success": True,
|
||||
"operation": operation,
|
||||
"provider": metadata["provider"],
|
||||
"image_base64": self._bytes_to_base64(image_bytes, request.output_format),
|
||||
"width": metadata["width"],
|
||||
"height": metadata["height"],
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
logger.info("[Edit Studio] ✅ Operation '%s' completed", operation)
|
||||
return response
|
||||
|
||||
async def _handle_stability_edit(
|
||||
self,
|
||||
operation: EditOperationType,
|
||||
request: EditStudioRequest,
|
||||
image_bytes: bytes,
|
||||
mask_bytes: Optional[bytes],
|
||||
background_bytes: Optional[bytes],
|
||||
lighting_bytes: Optional[bytes],
|
||||
) -> bytes:
|
||||
"""Execute Stability AI edit workflows."""
|
||||
stability_service = StabilityAIService()
|
||||
|
||||
async with stability_service:
|
||||
if operation == "remove_background":
|
||||
result = await stability_service.remove_background(
|
||||
image=image_bytes,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
elif operation == "inpaint":
|
||||
if not request.prompt:
|
||||
raise ValueError("Prompt is required for inpainting")
|
||||
result = await stability_service.inpaint(
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
mask=mask_bytes,
|
||||
negative_prompt=request.negative_prompt,
|
||||
output_format=request.output_format,
|
||||
style_preset=request.style_preset,
|
||||
grow_mask=request.options.get("grow_mask", 5),
|
||||
)
|
||||
elif operation == "outpaint":
|
||||
result = await stability_service.outpaint(
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
output_format=request.output_format,
|
||||
left=request.expand_left or 0,
|
||||
right=request.expand_right or 0,
|
||||
up=request.expand_up or 0,
|
||||
down=request.expand_down or 0,
|
||||
style_preset=request.style_preset,
|
||||
)
|
||||
elif operation == "search_replace":
|
||||
if not (request.prompt and request.search_prompt):
|
||||
raise ValueError("Both prompt and search_prompt are required for search & replace")
|
||||
result = await stability_service.search_and_replace(
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
search_prompt=request.search_prompt,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
elif operation == "search_recolor":
|
||||
if not (request.prompt and request.select_prompt):
|
||||
raise ValueError("Both prompt and select_prompt are required for search & recolor")
|
||||
result = await stability_service.search_and_recolor(
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
select_prompt=request.select_prompt,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
elif operation == "relight":
|
||||
if not background_bytes and not lighting_bytes:
|
||||
raise ValueError("At least one reference (background or lighting) is required for relight")
|
||||
result = await stability_service.replace_background_and_relight(
|
||||
subject_image=image_bytes,
|
||||
background_reference=background_bytes,
|
||||
light_reference=lighting_bytes,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
if isinstance(result, dict) and result.get("id"):
|
||||
result = await self._poll_stability_result(
|
||||
stability_service,
|
||||
generation_id=result["id"],
|
||||
output_format=request.output_format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported Stability operation: {operation}")
|
||||
|
||||
return self._extract_image_bytes(result)
|
||||
|
||||
async def _handle_general_edit(
|
||||
self,
|
||||
request: EditStudioRequest,
|
||||
image_bytes: bytes,
|
||||
mask_bytes: Optional[bytes],
|
||||
user_id: Optional[str],
|
||||
) -> bytes:
|
||||
"""Execute Hugging Face powered general editing (synchronous API)."""
|
||||
if not request.prompt:
|
||||
raise ValueError("Prompt is required for general edits")
|
||||
|
||||
options = {
|
||||
"provider": request.provider or "huggingface",
|
||||
"model": request.model,
|
||||
"guidance_scale": request.guidance_scale,
|
||||
"steps": request.steps,
|
||||
"seed": request.seed,
|
||||
}
|
||||
|
||||
# huggingface edit is synchronous - run in thread
|
||||
result = await asyncio.to_thread(
|
||||
huggingface_edit_image,
|
||||
image_bytes,
|
||||
request.prompt,
|
||||
options,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return result.image_bytes
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_bytes(result: Any) -> bytes:
|
||||
"""Normalize Stability responses into raw image bytes."""
|
||||
if isinstance(result, bytes):
|
||||
return result
|
||||
|
||||
if isinstance(result, dict):
|
||||
artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
|
||||
for artifact in artifacts:
|
||||
if isinstance(artifact, dict):
|
||||
if artifact.get("base64"):
|
||||
return base64.b64decode(artifact["base64"])
|
||||
if artifact.get("b64_json"):
|
||||
return base64.b64decode(artifact["b64_json"])
|
||||
|
||||
raise RuntimeError("Unable to extract image bytes from provider response")
|
||||
|
||||
async def _poll_stability_result(
|
||||
self,
|
||||
stability_service: StabilityAIService,
|
||||
generation_id: str,
|
||||
output_format: str,
|
||||
timeout_seconds: int = 240,
|
||||
interval_seconds: float = 2.0,
|
||||
) -> bytes:
|
||||
"""Poll Stability async endpoint until result is ready."""
|
||||
elapsed = 0.0
|
||||
while elapsed < timeout_seconds:
|
||||
result = await stability_service.get_generation_result(
|
||||
generation_id=generation_id,
|
||||
accept_type="*/*",
|
||||
)
|
||||
|
||||
if isinstance(result, bytes):
|
||||
return result
|
||||
|
||||
if isinstance(result, dict):
|
||||
state = (result.get("state") or result.get("status") or "").lower()
|
||||
if state in {"succeeded", "success", "ready", "completed"}:
|
||||
return self._extract_image_bytes(result)
|
||||
if state in {"failed", "error"}:
|
||||
raise RuntimeError(f"Stability generation failed: {result}")
|
||||
|
||||
await asyncio.sleep(interval_seconds)
|
||||
elapsed += interval_seconds
|
||||
|
||||
raise RuntimeError("Timed out waiting for Stability generation result")
|
||||
|
||||
|
||||
304
backend/services/image_studio/studio_manager.py
Normal file
304
backend/services/image_studio/studio_manager.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""Image Studio Manager - Main orchestration service for all image operations."""
|
||||
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from .create_service import CreateStudioService, CreateStudioRequest
|
||||
from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .templates import Platform, TemplateCategory, ImageTemplate
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_studio.manager")
|
||||
|
||||
|
||||
class ImageStudioManager:
|
||||
"""Main manager for Image Studio operations."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Image Studio Manager."""
|
||||
self.create_service = CreateStudioService()
|
||||
self.edit_service = EditStudioService()
|
||||
self.upscale_service = UpscaleStudioService()
|
||||
logger.info("[Image Studio Manager] Initialized successfully")
|
||||
|
||||
# ====================
|
||||
# CREATE STUDIO
|
||||
# ====================
|
||||
|
||||
async def create_image(
|
||||
self,
|
||||
request: CreateStudioRequest,
|
||||
user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create/generate image using Create Studio.
|
||||
|
||||
Args:
|
||||
request: Create studio request
|
||||
user_id: User ID for validation
|
||||
|
||||
Returns:
|
||||
Dictionary with generation results
|
||||
"""
|
||||
logger.info("[Image Studio] Create image request from user: %s", user_id)
|
||||
return await self.create_service.generate(request, user_id=user_id)
|
||||
|
||||
# ====================
|
||||
# EDIT STUDIO
|
||||
# ====================
|
||||
|
||||
async def edit_image(
|
||||
self,
|
||||
request: EditStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run Edit Studio operations."""
|
||||
logger.info("[Image Studio] Edit image request from user: %s", user_id)
|
||||
return await self.edit_service.process_edit(request, user_id=user_id)
|
||||
|
||||
def get_edit_operations(self) -> Dict[str, Any]:
|
||||
"""Expose edit operations for UI."""
|
||||
return self.edit_service.list_operations()
|
||||
|
||||
# ====================
|
||||
# UPSCALE STUDIO
|
||||
# ====================
|
||||
|
||||
async def upscale_image(
|
||||
self,
|
||||
request: UpscaleStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run Upscale Studio operations."""
|
||||
logger.info("[Image Studio] Upscale request from user: %s", user_id)
|
||||
return await self.upscale_service.process_upscale(request, user_id=user_id)
|
||||
|
||||
def get_templates(
|
||||
self,
|
||||
platform: Optional[Platform] = None,
|
||||
category: Optional[TemplateCategory] = None
|
||||
) -> List[ImageTemplate]:
|
||||
"""Get available templates.
|
||||
|
||||
Args:
|
||||
platform: Filter by platform
|
||||
category: Filter by category
|
||||
|
||||
Returns:
|
||||
List of templates
|
||||
"""
|
||||
return self.create_service.get_templates(platform=platform, category=category)
|
||||
|
||||
def search_templates(self, query: str) -> List[ImageTemplate]:
|
||||
"""Search templates by query.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
List of matching templates
|
||||
"""
|
||||
return self.create_service.search_templates(query)
|
||||
|
||||
def recommend_templates(
|
||||
self,
|
||||
use_case: str,
|
||||
platform: Optional[Platform] = None
|
||||
) -> List[ImageTemplate]:
|
||||
"""Recommend templates based on use case.
|
||||
|
||||
Args:
|
||||
use_case: Use case description
|
||||
platform: Optional platform filter
|
||||
|
||||
Returns:
|
||||
List of recommended templates
|
||||
"""
|
||||
return self.create_service.recommend_templates(use_case, platform)
|
||||
|
||||
def get_providers(self) -> Dict[str, Any]:
|
||||
"""Get available image providers and their capabilities.
|
||||
|
||||
Returns:
|
||||
Dictionary of providers with capabilities
|
||||
"""
|
||||
return {
|
||||
"stability": {
|
||||
"name": "Stability AI",
|
||||
"models": ["ultra", "core", "sd3.5-large"],
|
||||
"capabilities": ["text-to-image", "editing", "upscaling", "control", "3d"],
|
||||
"max_resolution": (2048, 2048),
|
||||
"cost_range": "3-8 credits per image",
|
||||
},
|
||||
"wavespeed": {
|
||||
"name": "WaveSpeed AI",
|
||||
"models": ["ideogram-v3-turbo", "qwen-image"],
|
||||
"capabilities": ["text-to-image", "photorealistic", "fast-generation"],
|
||||
"max_resolution": (1024, 1024),
|
||||
"cost_range": "$0.05-$0.10 per image",
|
||||
},
|
||||
"huggingface": {
|
||||
"name": "HuggingFace",
|
||||
"models": ["FLUX.1-Krea-dev", "RunwayML"],
|
||||
"capabilities": ["text-to-image", "image-to-image"],
|
||||
"max_resolution": (1024, 1024),
|
||||
"cost_range": "Free tier available",
|
||||
},
|
||||
"gemini": {
|
||||
"name": "Google Gemini",
|
||||
"models": ["imagen-3.0"],
|
||||
"capabilities": ["text-to-image", "conversational-editing"],
|
||||
"max_resolution": (1024, 1024),
|
||||
"cost_range": "Free tier available",
|
||||
}
|
||||
}
|
||||
|
||||
# ====================
|
||||
# COST ESTIMATION
|
||||
# ====================
|
||||
|
||||
def estimate_cost(
|
||||
self,
|
||||
provider: str,
|
||||
model: Optional[str],
|
||||
operation: str,
|
||||
num_images: int = 1,
|
||||
resolution: Optional[tuple[int, int]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate cost for image operations.
|
||||
|
||||
Args:
|
||||
provider: Provider name
|
||||
model: Model name
|
||||
operation: Operation type (generate, edit, upscale, etc.)
|
||||
num_images: Number of images
|
||||
resolution: Image resolution (width, height)
|
||||
|
||||
Returns:
|
||||
Cost estimation details
|
||||
"""
|
||||
# Base costs (adjust based on actual pricing)
|
||||
base_costs = {
|
||||
"stability": {
|
||||
"ultra": 0.08, # 8 credits
|
||||
"core": 0.03, # 3 credits
|
||||
"sd3": 0.065, # 6.5 credits
|
||||
},
|
||||
"wavespeed": {
|
||||
"ideogram-v3-turbo": 0.10,
|
||||
"qwen-image": 0.05,
|
||||
},
|
||||
"huggingface": {
|
||||
"default": 0.0, # Free tier
|
||||
},
|
||||
"gemini": {
|
||||
"default": 0.0, # Free tier
|
||||
}
|
||||
}
|
||||
|
||||
# Get base cost
|
||||
provider_costs = base_costs.get(provider, {})
|
||||
cost_per_image = provider_costs.get(model, provider_costs.get("default", 0.0))
|
||||
|
||||
# Calculate total
|
||||
total_cost = cost_per_image * num_images
|
||||
|
||||
return {
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"operation": operation,
|
||||
"num_images": num_images,
|
||||
"resolution": f"{resolution[0]}x{resolution[1]}" if resolution else "default",
|
||||
"cost_per_image": cost_per_image,
|
||||
"total_cost": total_cost,
|
||||
"currency": "USD",
|
||||
"estimated": True,
|
||||
}
|
||||
|
||||
# ====================
|
||||
# PLATFORM SPECS
|
||||
# ====================
|
||||
|
||||
def get_platform_specs(self, platform: Platform) -> Dict[str, Any]:
|
||||
"""Get platform specifications and requirements.
|
||||
|
||||
Args:
|
||||
platform: Platform to get specs for
|
||||
|
||||
Returns:
|
||||
Platform specifications
|
||||
"""
|
||||
specs = {
|
||||
Platform.INSTAGRAM: {
|
||||
"name": "Instagram",
|
||||
"formats": [
|
||||
{"name": "Feed Post (Square)", "ratio": "1:1", "size": "1080x1080"},
|
||||
{"name": "Feed Post (Portrait)", "ratio": "4:5", "size": "1080x1350"},
|
||||
{"name": "Story", "ratio": "9:16", "size": "1080x1920"},
|
||||
{"name": "Reel", "ratio": "9:16", "size": "1080x1920"},
|
||||
],
|
||||
"file_types": ["JPG", "PNG"],
|
||||
"max_file_size": "30MB",
|
||||
},
|
||||
Platform.FACEBOOK: {
|
||||
"name": "Facebook",
|
||||
"formats": [
|
||||
{"name": "Feed Post", "ratio": "1.91:1", "size": "1200x630"},
|
||||
{"name": "Feed Post (Square)", "ratio": "1:1", "size": "1080x1080"},
|
||||
{"name": "Story", "ratio": "9:16", "size": "1080x1920"},
|
||||
{"name": "Cover Photo", "ratio": "16:9", "size": "820x312"},
|
||||
],
|
||||
"file_types": ["JPG", "PNG"],
|
||||
"max_file_size": "30MB",
|
||||
},
|
||||
Platform.TWITTER: {
|
||||
"name": "Twitter/X",
|
||||
"formats": [
|
||||
{"name": "Post", "ratio": "16:9", "size": "1200x675"},
|
||||
{"name": "Card", "ratio": "2:1", "size": "1200x600"},
|
||||
{"name": "Header", "ratio": "3:1", "size": "1500x500"},
|
||||
],
|
||||
"file_types": ["JPG", "PNG", "GIF"],
|
||||
"max_file_size": "5MB",
|
||||
},
|
||||
Platform.LINKEDIN: {
|
||||
"name": "LinkedIn",
|
||||
"formats": [
|
||||
{"name": "Feed Post", "ratio": "1.91:1", "size": "1200x628"},
|
||||
{"name": "Feed Post (Square)", "ratio": "1:1", "size": "1080x1080"},
|
||||
{"name": "Article", "ratio": "2:1", "size": "1200x627"},
|
||||
{"name": "Company Cover", "ratio": "4:1", "size": "1128x191"},
|
||||
],
|
||||
"file_types": ["JPG", "PNG"],
|
||||
"max_file_size": "8MB",
|
||||
},
|
||||
Platform.YOUTUBE: {
|
||||
"name": "YouTube",
|
||||
"formats": [
|
||||
{"name": "Thumbnail", "ratio": "16:9", "size": "1280x720"},
|
||||
{"name": "Channel Art", "ratio": "16:9", "size": "2560x1440"},
|
||||
],
|
||||
"file_types": ["JPG", "PNG"],
|
||||
"max_file_size": "2MB",
|
||||
},
|
||||
Platform.PINTEREST: {
|
||||
"name": "Pinterest",
|
||||
"formats": [
|
||||
{"name": "Pin", "ratio": "2:3", "size": "1000x1500"},
|
||||
{"name": "Story Pin", "ratio": "9:16", "size": "1080x1920"},
|
||||
],
|
||||
"file_types": ["JPG", "PNG"],
|
||||
"max_file_size": "20MB",
|
||||
},
|
||||
Platform.TIKTOK: {
|
||||
"name": "TikTok",
|
||||
"formats": [
|
||||
{"name": "Video Cover", "ratio": "9:16", "size": "1080x1920"},
|
||||
],
|
||||
"file_types": ["JPG", "PNG"],
|
||||
"max_file_size": "10MB",
|
||||
},
|
||||
}
|
||||
|
||||
return specs.get(platform, {})
|
||||
|
||||
555
backend/services/image_studio/templates.py
Normal file
555
backend/services/image_studio/templates.py
Normal file
@@ -0,0 +1,555 @@
|
||||
"""Template system for Image Studio with platform-specific presets."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Literal
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Platform(str, Enum):
|
||||
"""Supported social media platforms."""
|
||||
INSTAGRAM = "instagram"
|
||||
FACEBOOK = "facebook"
|
||||
TWITTER = "twitter"
|
||||
LINKEDIN = "linkedin"
|
||||
YOUTUBE = "youtube"
|
||||
PINTEREST = "pinterest"
|
||||
TIKTOK = "tiktok"
|
||||
BLOG = "blog"
|
||||
EMAIL = "email"
|
||||
WEBSITE = "website"
|
||||
|
||||
|
||||
class TemplateCategory(str, Enum):
|
||||
"""Template categories."""
|
||||
SOCIAL_MEDIA = "social_media"
|
||||
BLOG_CONTENT = "blog_content"
|
||||
AD_CREATIVE = "ad_creative"
|
||||
PRODUCT = "product"
|
||||
BRAND_ASSETS = "brand_assets"
|
||||
EMAIL_MARKETING = "email_marketing"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AspectRatio:
|
||||
"""Aspect ratio configuration."""
|
||||
ratio: str # e.g., "1:1", "16:9"
|
||||
width: int
|
||||
height: int
|
||||
label: str # e.g., "Square", "Widescreen"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageTemplate:
|
||||
"""Image generation template."""
|
||||
id: str
|
||||
name: str
|
||||
category: TemplateCategory
|
||||
platform: Optional[Platform]
|
||||
aspect_ratio: AspectRatio
|
||||
description: str
|
||||
recommended_provider: str
|
||||
style_preset: str
|
||||
quality: Literal["draft", "standard", "premium"]
|
||||
prompt_template: Optional[str] = None
|
||||
negative_prompt_template: Optional[str] = None
|
||||
use_cases: List[str] = None
|
||||
|
||||
|
||||
class PlatformTemplates:
|
||||
"""Platform-specific template definitions."""
|
||||
|
||||
# Aspect Ratios
|
||||
SQUARE_1_1 = AspectRatio("1:1", 1080, 1080, "Square")
|
||||
PORTRAIT_4_5 = AspectRatio("4:5", 1080, 1350, "Portrait")
|
||||
STORY_9_16 = AspectRatio("9:16", 1080, 1920, "Story/Reel")
|
||||
LANDSCAPE_16_9 = AspectRatio("16:9", 1920, 1080, "Landscape")
|
||||
WIDE_21_9 = AspectRatio("21:9", 2560, 1080, "Ultra Wide")
|
||||
TWITTER_2_1 = AspectRatio("2:1", 1200, 600, "Twitter Card")
|
||||
TWITTER_3_1 = AspectRatio("3:1", 1500, 500, "Twitter Header")
|
||||
FACEBOOK_1_91_1 = AspectRatio("1.91:1", 1200, 630, "Facebook Feed")
|
||||
LINKEDIN_1_91_1 = AspectRatio("1.91:1", 1200, 628, "LinkedIn Feed")
|
||||
LINKEDIN_2_1 = AspectRatio("2:1", 1200, 627, "LinkedIn Article")
|
||||
LINKEDIN_4_1 = AspectRatio("4:1", 1128, 191, "LinkedIn Cover")
|
||||
PINTEREST_2_3 = AspectRatio("2:3", 1000, 1500, "Pinterest Pin")
|
||||
YOUTUBE_16_9 = AspectRatio("16:9", 1280, 720, "YouTube Thumbnail")
|
||||
FACEBOOK_COVER_16_9 = AspectRatio("16:9", 820, 312, "Facebook Cover")
|
||||
|
||||
@classmethod
|
||||
def get_platform_templates(cls) -> Dict[Platform, List[ImageTemplate]]:
|
||||
"""Get all platform-specific templates."""
|
||||
return {
|
||||
Platform.INSTAGRAM: cls._instagram_templates(),
|
||||
Platform.FACEBOOK: cls._facebook_templates(),
|
||||
Platform.TWITTER: cls._twitter_templates(),
|
||||
Platform.LINKEDIN: cls._linkedin_templates(),
|
||||
Platform.YOUTUBE: cls._youtube_templates(),
|
||||
Platform.PINTEREST: cls._pinterest_templates(),
|
||||
Platform.TIKTOK: cls._tiktok_templates(),
|
||||
Platform.BLOG: cls._blog_templates(),
|
||||
Platform.EMAIL: cls._email_templates(),
|
||||
Platform.WEBSITE: cls._website_templates(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _instagram_templates(cls) -> List[ImageTemplate]:
|
||||
"""Instagram templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="instagram_feed_square",
|
||||
name="Instagram Feed Post (Square)",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.INSTAGRAM,
|
||||
aspect_ratio=cls.SQUARE_1_1,
|
||||
description="Perfect for Instagram feed posts with maximum visibility",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Product showcase", "Lifestyle posts", "Brand content"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="instagram_feed_portrait",
|
||||
name="Instagram Feed Post (Portrait)",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.INSTAGRAM,
|
||||
aspect_ratio=cls.PORTRAIT_4_5,
|
||||
description="Vertical format for maximum feed real estate",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Fashion", "Food", "Product photography"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="instagram_story",
|
||||
name="Instagram Story",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.INSTAGRAM,
|
||||
aspect_ratio=cls.STORY_9_16,
|
||||
description="Full-screen vertical stories",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="digital-art",
|
||||
quality="standard",
|
||||
use_cases=["Behind-the-scenes", "Announcements", "Quick updates"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="instagram_reel_cover",
|
||||
name="Instagram Reel Cover",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.INSTAGRAM,
|
||||
aspect_ratio=cls.STORY_9_16,
|
||||
description="Eye-catching reel cover images",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="cinematic",
|
||||
quality="premium",
|
||||
use_cases=["Video covers", "Thumbnails", "Highlights"]
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _facebook_templates(cls) -> List[ImageTemplate]:
|
||||
"""Facebook templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="facebook_feed",
|
||||
name="Facebook Feed Post",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.FACEBOOK,
|
||||
aspect_ratio=cls.FACEBOOK_1_91_1,
|
||||
description="Optimized for Facebook news feed",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="standard",
|
||||
use_cases=["Page posts", "Shared content", "Community posts"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="facebook_feed_square",
|
||||
name="Facebook Feed Post (Square)",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.FACEBOOK,
|
||||
aspect_ratio=cls.SQUARE_1_1,
|
||||
description="Square format for feed posts",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="standard",
|
||||
use_cases=["Page posts", "Product highlights"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="facebook_story",
|
||||
name="Facebook Story",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.FACEBOOK,
|
||||
aspect_ratio=cls.STORY_9_16,
|
||||
description="Full-screen vertical stories",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="digital-art",
|
||||
quality="standard",
|
||||
use_cases=["Quick updates", "Promotions", "Events"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="facebook_cover",
|
||||
name="Facebook Cover Photo",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.FACEBOOK,
|
||||
aspect_ratio=cls.FACEBOOK_COVER_16_9,
|
||||
description="Wide cover photo for pages",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Page branding", "Events", "Seasonal updates"]
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _twitter_templates(cls) -> List[ImageTemplate]:
|
||||
"""Twitter/X templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="twitter_post",
|
||||
name="Twitter/X Post",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.TWITTER,
|
||||
aspect_ratio=cls.LANDSCAPE_16_9,
|
||||
description="Optimized for Twitter feed",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="standard",
|
||||
use_cases=["Tweets", "News", "Updates"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="twitter_card",
|
||||
name="Twitter Card",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.TWITTER,
|
||||
aspect_ratio=cls.TWITTER_2_1,
|
||||
description="Twitter card with link preview",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="digital-art",
|
||||
quality="standard",
|
||||
use_cases=["Link sharing", "Articles", "Blog posts"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="twitter_header",
|
||||
name="Twitter Header",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.TWITTER,
|
||||
aspect_ratio=cls.TWITTER_3_1,
|
||||
description="Profile header image",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Profile branding", "Personal brand", "Business identity"]
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _linkedin_templates(cls) -> List[ImageTemplate]:
|
||||
"""LinkedIn templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="linkedin_post",
|
||||
name="LinkedIn Post",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.LINKEDIN,
|
||||
aspect_ratio=cls.LINKEDIN_1_91_1,
|
||||
description="Professional feed posts",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Professional content", "Industry news", "Thought leadership"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="linkedin_post_square",
|
||||
name="LinkedIn Post (Square)",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.LINKEDIN,
|
||||
aspect_ratio=cls.SQUARE_1_1,
|
||||
description="Square format for LinkedIn feed",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Quick tips", "Infographics", "Quotes"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="linkedin_article",
|
||||
name="LinkedIn Article Header",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.LINKEDIN,
|
||||
aspect_ratio=cls.LINKEDIN_2_1,
|
||||
description="Article header images",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Long-form content", "Articles", "Newsletters"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="linkedin_cover",
|
||||
name="LinkedIn Company Cover",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.LINKEDIN,
|
||||
aspect_ratio=cls.LINKEDIN_4_1,
|
||||
description="Company page cover photo",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Company branding", "Recruitment", "Brand identity"]
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _youtube_templates(cls) -> List[ImageTemplate]:
|
||||
"""YouTube templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="youtube_thumbnail",
|
||||
name="YouTube Thumbnail",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.YOUTUBE,
|
||||
aspect_ratio=cls.YOUTUBE_16_9,
|
||||
description="Eye-catching video thumbnails",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="cinematic",
|
||||
quality="premium",
|
||||
use_cases=["Video thumbnails", "Channel branding", "Playlists"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="youtube_channel_art",
|
||||
name="YouTube Channel Art",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.YOUTUBE,
|
||||
aspect_ratio=cls.LANDSCAPE_16_9,
|
||||
description="Channel banner art",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Channel branding", "Personal brand", "Business identity"]
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _pinterest_templates(cls) -> List[ImageTemplate]:
|
||||
"""Pinterest templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="pinterest_pin",
|
||||
name="Pinterest Pin",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.PINTEREST,
|
||||
aspect_ratio=cls.PINTEREST_2_3,
|
||||
description="Vertical pin format",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Product pins", "DIY guides", "Recipes", "Inspiration"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="pinterest_story",
|
||||
name="Pinterest Story Pin",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.PINTEREST,
|
||||
aspect_ratio=cls.STORY_9_16,
|
||||
description="Full-screen story pins",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="digital-art",
|
||||
quality="standard",
|
||||
use_cases=["Step-by-step guides", "Tutorials", "Quick tips"]
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _tiktok_templates(cls) -> List[ImageTemplate]:
|
||||
"""TikTok templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="tiktok_video_cover",
|
||||
name="TikTok Video Cover",
|
||||
category=TemplateCategory.SOCIAL_MEDIA,
|
||||
platform=Platform.TIKTOK,
|
||||
aspect_ratio=cls.STORY_9_16,
|
||||
description="Vertical video cover",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="cinematic",
|
||||
quality="premium",
|
||||
use_cases=["Video covers", "Thumbnails", "Profile highlights"]
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _blog_templates(cls) -> List[ImageTemplate]:
|
||||
"""Blog content templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="blog_header",
|
||||
name="Blog Header",
|
||||
category=TemplateCategory.BLOG_CONTENT,
|
||||
platform=Platform.BLOG,
|
||||
aspect_ratio=cls.LANDSCAPE_16_9,
|
||||
description="Blog post featured image",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Featured images", "Article headers", "Post thumbnails"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="blog_header_wide",
|
||||
name="Blog Header (Wide)",
|
||||
category=TemplateCategory.BLOG_CONTENT,
|
||||
platform=Platform.BLOG,
|
||||
aspect_ratio=cls.WIDE_21_9,
|
||||
description="Ultra-wide blog header",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Hero sections", "Wide headers", "Landing pages"]
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _email_templates(cls) -> List[ImageTemplate]:
|
||||
"""Email marketing templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="email_banner",
|
||||
name="Email Banner",
|
||||
category=TemplateCategory.EMAIL_MARKETING,
|
||||
platform=Platform.EMAIL,
|
||||
aspect_ratio=cls.LANDSCAPE_16_9,
|
||||
description="Email header banner",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="standard",
|
||||
use_cases=["Email headers", "Newsletter banners", "Promotions"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="email_product",
|
||||
name="Email Product Image",
|
||||
category=TemplateCategory.EMAIL_MARKETING,
|
||||
platform=Platform.EMAIL,
|
||||
aspect_ratio=cls.SQUARE_1_1,
|
||||
description="Product showcase for emails",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Product highlights", "Promotions", "Offers"]
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _website_templates(cls) -> List[ImageTemplate]:
|
||||
"""Website templates."""
|
||||
return [
|
||||
ImageTemplate(
|
||||
id="website_hero",
|
||||
name="Website Hero Image",
|
||||
category=TemplateCategory.BRAND_ASSETS,
|
||||
platform=Platform.WEBSITE,
|
||||
aspect_ratio=cls.WIDE_21_9,
|
||||
description="Hero section background",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Hero sections", "Landing pages", "Home page banners"]
|
||||
),
|
||||
ImageTemplate(
|
||||
id="website_banner",
|
||||
name="Website Banner",
|
||||
category=TemplateCategory.BRAND_ASSETS,
|
||||
platform=Platform.WEBSITE,
|
||||
aspect_ratio=cls.LANDSCAPE_16_9,
|
||||
description="Section banners",
|
||||
recommended_provider="ideogram",
|
||||
style_preset="photographic",
|
||||
quality="premium",
|
||||
use_cases=["Section headers", "Category pages", "Feature sections"]
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TemplateManager:
|
||||
"""Manager for image templates with search and recommendation."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize template manager."""
|
||||
self.templates = PlatformTemplates.get_platform_templates()
|
||||
self._all_templates: Optional[List[ImageTemplate]] = None
|
||||
|
||||
def get_all_templates(self) -> List[ImageTemplate]:
|
||||
"""Get all templates across all platforms."""
|
||||
if self._all_templates is None:
|
||||
self._all_templates = []
|
||||
for platform_templates in self.templates.values():
|
||||
self._all_templates.extend(platform_templates)
|
||||
return self._all_templates
|
||||
|
||||
def get_by_platform(self, platform: Platform) -> List[ImageTemplate]:
|
||||
"""Get templates for a specific platform."""
|
||||
return self.templates.get(platform, [])
|
||||
|
||||
def get_by_category(self, category: TemplateCategory) -> List[ImageTemplate]:
|
||||
"""Get templates by category."""
|
||||
all_templates = self.get_all_templates()
|
||||
return [t for t in all_templates if t.category == category]
|
||||
|
||||
def get_by_id(self, template_id: str) -> Optional[ImageTemplate]:
|
||||
"""Get template by ID."""
|
||||
all_templates = self.get_all_templates()
|
||||
for template in all_templates:
|
||||
if template.id == template_id:
|
||||
return template
|
||||
return None
|
||||
|
||||
def search(self, query: str) -> List[ImageTemplate]:
|
||||
"""Search templates by query."""
|
||||
query = query.lower()
|
||||
all_templates = self.get_all_templates()
|
||||
results = []
|
||||
|
||||
for template in all_templates:
|
||||
# Search in name, description, and use cases
|
||||
searchable = (
|
||||
template.name.lower() + " " +
|
||||
template.description.lower() + " " +
|
||||
" ".join(template.use_cases or []).lower()
|
||||
)
|
||||
if query in searchable:
|
||||
results.append(template)
|
||||
|
||||
return results
|
||||
|
||||
def recommend_for_use_case(self, use_case: str, platform: Optional[Platform] = None) -> List[ImageTemplate]:
|
||||
"""Recommend templates based on use case and platform."""
|
||||
use_case_lower = use_case.lower()
|
||||
all_templates = self.get_all_templates()
|
||||
|
||||
# Filter by platform if specified
|
||||
if platform:
|
||||
all_templates = [t for t in all_templates if t.platform == platform]
|
||||
|
||||
# Find matching templates
|
||||
matches = []
|
||||
for template in all_templates:
|
||||
if template.use_cases:
|
||||
for case in template.use_cases:
|
||||
if use_case_lower in case.lower():
|
||||
matches.append(template)
|
||||
break
|
||||
|
||||
return matches
|
||||
|
||||
def get_aspect_ratio_options(self) -> List[AspectRatio]:
|
||||
"""Get all available aspect ratios."""
|
||||
return [
|
||||
PlatformTemplates.SQUARE_1_1,
|
||||
PlatformTemplates.PORTRAIT_4_5,
|
||||
PlatformTemplates.STORY_9_16,
|
||||
PlatformTemplates.LANDSCAPE_16_9,
|
||||
PlatformTemplates.WIDE_21_9,
|
||||
PlatformTemplates.TWITTER_2_1,
|
||||
PlatformTemplates.TWITTER_3_1,
|
||||
PlatformTemplates.FACEBOOK_1_91_1,
|
||||
PlatformTemplates.LINKEDIN_1_91_1,
|
||||
PlatformTemplates.LINKEDIN_2_1,
|
||||
PlatformTemplates.LINKEDIN_4_1,
|
||||
PlatformTemplates.PINTEREST_2_3,
|
||||
PlatformTemplates.YOUTUBE_16_9,
|
||||
PlatformTemplates.FACEBOOK_COVER_16_9,
|
||||
]
|
||||
|
||||
154
backend/services/image_studio/upscale_service.py
Normal file
154
backend/services/image_studio/upscale_service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Dict, Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from PIL import Image
|
||||
|
||||
from services.stability_service import StabilityAIService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_studio.upscale")
|
||||
|
||||
|
||||
UpscaleMode = Literal["fast", "conservative", "creative", "auto"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpscaleStudioRequest:
|
||||
image_base64: str
|
||||
mode: UpscaleMode = "auto"
|
||||
target_width: Optional[int] = None
|
||||
target_height: Optional[int] = None
|
||||
preset: Optional[str] = None # e.g., web/print/social
|
||||
prompt: Optional[str] = None # used for conservative/creative modes
|
||||
|
||||
|
||||
class UpscaleStudioService:
|
||||
"""Handles image upscaling workflows."""
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[Upscale Studio] Service initialized")
|
||||
|
||||
async def process_upscale(
|
||||
self,
|
||||
request: UpscaleStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_upscale_operations
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info("[Upscale Studio] 🛂 Running pre-flight validation for user %s", user_id)
|
||||
validate_image_upscale_operations(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
image_bytes = self._decode_base64(request.image_base64)
|
||||
if not image_bytes:
|
||||
raise ValueError("Primary image is required for upscaling")
|
||||
|
||||
mode = self._resolve_mode(request)
|
||||
|
||||
async with StabilityAIService() as stability_service:
|
||||
logger.info("[Upscale Studio] Running '%s' upscale for user=%s", mode, user_id)
|
||||
|
||||
params = {
|
||||
"target_width": request.target_width,
|
||||
"target_height": request.target_height,
|
||||
}
|
||||
# remove None values
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
if mode == "fast":
|
||||
result = await stability_service.upscale_fast(
|
||||
image=image_bytes,
|
||||
**params,
|
||||
)
|
||||
elif mode == "conservative":
|
||||
prompt = request.prompt or "High fidelity upscale preserving original details"
|
||||
result = await stability_service.upscale_conservative(
|
||||
image=image_bytes,
|
||||
prompt=prompt,
|
||||
**params,
|
||||
)
|
||||
elif mode == "creative":
|
||||
prompt = request.prompt or "Creative upscale with enhanced artistic details"
|
||||
result = await stability_service.upscale_creative(
|
||||
image=image_bytes,
|
||||
prompt=prompt,
|
||||
**params,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported upscale mode: {mode}")
|
||||
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_metadata(image_bytes)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"mode": mode,
|
||||
"image_base64": self._to_base64(image_bytes),
|
||||
"width": metadata["width"],
|
||||
"height": metadata["height"],
|
||||
"metadata": {
|
||||
"preset": request.preset,
|
||||
"target_width": request.target_width,
|
||||
"target_height": request.target_height,
|
||||
"prompt": request.prompt,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64(value: Optional[str]) -> Optional[bytes]:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
if value.startswith("data:"):
|
||||
_, b64data = value.split(",", 1)
|
||||
else:
|
||||
b64data = value
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as exc:
|
||||
logger.error("[Upscale Studio] Failed to decode base64 image: %s", exc)
|
||||
raise ValueError("Invalid base64 image payload") from exc
|
||||
|
||||
@staticmethod
|
||||
def _to_base64(image_bytes: bytes) -> str:
|
||||
return f"data:image/png;base64,{base64.b64encode(image_bytes).decode('utf-8')}"
|
||||
|
||||
@staticmethod
|
||||
def _image_metadata(image_bytes: bytes) -> Dict[str, int]:
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
return {"width": img.width, "height": img.height}
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_bytes(result: Any) -> bytes:
|
||||
if isinstance(result, bytes):
|
||||
return result
|
||||
if isinstance(result, dict):
|
||||
artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
|
||||
for artifact in artifacts:
|
||||
if isinstance(artifact, dict):
|
||||
if artifact.get("base64"):
|
||||
return base64.b64decode(artifact["base64"])
|
||||
if artifact.get("b64_json"):
|
||||
return base64.b64decode(artifact["b64_json"])
|
||||
raise HTTPException(status_code=502, detail="Unable to extract image from provider response")
|
||||
|
||||
@staticmethod
|
||||
def _resolve_mode(request: UpscaleStudioRequest) -> UpscaleMode:
|
||||
if request.mode != "auto":
|
||||
return request.mode
|
||||
# simple heuristic: if target >= 3000px, use conservative, else fast
|
||||
if (request.target_width and request.target_width >= 3000) or (
|
||||
request.target_height and request.target_height >= 3000
|
||||
):
|
||||
return "conservative"
|
||||
return "fast"
|
||||
|
||||
@@ -2,6 +2,7 @@ from .base import ImageGenerationOptions, ImageGenerationResult, ImageGeneration
|
||||
from .hf_provider import HuggingFaceImageProvider
|
||||
from .gemini_provider import GeminiImageProvider
|
||||
from .stability_provider import StabilityImageProvider
|
||||
from .wavespeed_provider import WaveSpeedImageProvider
|
||||
|
||||
__all__ = [
|
||||
"ImageGenerationOptions",
|
||||
@@ -10,6 +11,7 @@ __all__ = [
|
||||
"HuggingFaceImageProvider",
|
||||
"GeminiImageProvider",
|
||||
"StabilityImageProvider",
|
||||
"WaveSpeedImageProvider",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
"""WaveSpeed AI image generation provider (Ideogram V3 Turbo & Qwen Image)."""
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
|
||||
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("wavespeed.image_provider")
|
||||
|
||||
|
||||
class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
"ideogram-v3-turbo": {
|
||||
"name": "Ideogram V3 Turbo",
|
||||
"description": "Photorealistic generation with superior text rendering",
|
||||
"cost_per_image": 0.10, # Estimated, adjust based on actual pricing
|
||||
"max_resolution": (1024, 1024),
|
||||
"default_steps": 20,
|
||||
},
|
||||
"qwen-image": {
|
||||
"name": "Qwen Image",
|
||||
"description": "Fast, high-quality text-to-image generation",
|
||||
"cost_per_image": 0.05, # Estimated, adjust based on actual pricing
|
||||
"max_resolution": (1024, 1024),
|
||||
"default_steps": 15,
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
"""Initialize WaveSpeed image provider.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key (falls back to env var if not provided)
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")
|
||||
|
||||
self.client = WaveSpeedClient(api_key=self.api_key)
|
||||
logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
|
||||
list(self.SUPPORTED_MODELS.keys()))
|
||||
|
||||
def _validate_options(self, options: ImageGenerationOptions) -> None:
|
||||
"""Validate generation options.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
model = options.model or "ideogram-v3-turbo"
|
||||
|
||||
if model not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model: {model}. "
|
||||
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
max_width, max_height = model_info["max_resolution"]
|
||||
|
||||
if options.width > max_width or options.height > max_height:
|
||||
raise ValueError(
|
||||
f"Resolution {options.width}x{options.height} exceeds maximum "
|
||||
f"{max_width}x{max_height} for model {model}"
|
||||
)
|
||||
|
||||
if not options.prompt or len(options.prompt.strip()) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
def _generate_ideogram_v3(self, options: ImageGenerationOptions) -> bytes:
|
||||
"""Generate image using Ideogram V3 Turbo.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
"""
|
||||
logger.info("[Ideogram V3] Starting image generation: %s", options.prompt[:100])
|
||||
|
||||
try:
|
||||
# Prepare parameters for WaveSpeed Ideogram V3 API
|
||||
# Note: Adjust these based on actual WaveSpeed API documentation
|
||||
params = {
|
||||
"model": "ideogram-v3-turbo",
|
||||
"prompt": options.prompt,
|
||||
"width": options.width,
|
||||
"height": options.height,
|
||||
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["ideogram-v3-turbo"]["default_steps"],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if options.negative_prompt:
|
||||
params["negative_prompt"] = options.negative_prompt
|
||||
|
||||
if options.guidance_scale:
|
||||
params["guidance_scale"] = options.guidance_scale
|
||||
|
||||
if options.seed:
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API (using generic image generation method)
|
||||
# This will need to be adjusted based on actual WaveSpeed client implementation
|
||||
result = self.client.generate_image(**params)
|
||||
|
||||
# Extract image bytes from result
|
||||
# Adjust based on actual WaveSpeed API response format
|
||||
if isinstance(result, bytes):
|
||||
image_bytes = result
|
||||
elif isinstance(result, dict) and "image" in result:
|
||||
image_bytes = result["image"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
||||
|
||||
logger.info("[Ideogram V3] ✅ Successfully generated image: %d bytes", len(image_bytes))
|
||||
return image_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Ideogram V3] ❌ Error generating image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"Ideogram V3 generation failed: {str(e)}")
|
||||
|
||||
def _generate_qwen_image(self, options: ImageGenerationOptions) -> bytes:
|
||||
"""Generate image using Qwen Image.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
"""
|
||||
logger.info("[Qwen Image] Starting image generation: %s", options.prompt[:100])
|
||||
|
||||
try:
|
||||
# Prepare parameters for WaveSpeed Qwen Image API
|
||||
params = {
|
||||
"model": "qwen-image",
|
||||
"prompt": options.prompt,
|
||||
"width": options.width,
|
||||
"height": options.height,
|
||||
"num_inference_steps": options.steps or self.SUPPORTED_MODELS["qwen-image"]["default_steps"],
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if options.negative_prompt:
|
||||
params["negative_prompt"] = options.negative_prompt
|
||||
|
||||
if options.guidance_scale:
|
||||
params["guidance_scale"] = options.guidance_scale
|
||||
|
||||
if options.seed:
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API
|
||||
result = self.client.generate_image(**params)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
image_bytes = result
|
||||
elif isinstance(result, dict) and "image" in result:
|
||||
image_bytes = result["image"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
||||
|
||||
logger.info("[Qwen Image] ✅ Successfully generated image: %d bytes", len(image_bytes))
|
||||
return image_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Qwen Image] ❌ Error generating image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"Qwen Image generation failed: {str(e)}")
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
"""Generate image using WaveSpeed AI models.
|
||||
|
||||
Args:
|
||||
options: Image generation options
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with generated image
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
# Validate options
|
||||
self._validate_options(options)
|
||||
|
||||
# Determine model
|
||||
model = options.model or "ideogram-v3-turbo"
|
||||
|
||||
# Generate based on model
|
||||
if model == "ideogram-v3-turbo":
|
||||
image_bytes = self._generate_ideogram_v3(options)
|
||||
elif model == "qwen-image":
|
||||
image_bytes = self._generate_qwen_image(options)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
|
||||
# Load image to get dimensions
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
width, height = image.size
|
||||
|
||||
# Calculate estimated cost
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
estimated_cost = model_info["cost_per_image"]
|
||||
|
||||
# Return result
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=width,
|
||||
height=height,
|
||||
provider="wavespeed",
|
||||
model=model,
|
||||
seed=options.seed,
|
||||
metadata={
|
||||
"provider": "wavespeed",
|
||||
"model": model,
|
||||
"model_name": model_info["name"],
|
||||
"prompt": options.prompt,
|
||||
"negative_prompt": options.negative_prompt,
|
||||
"steps": options.steps or model_info["default_steps"],
|
||||
"guidance_scale": options.guidance_scale,
|
||||
"estimated_cost": estimated_cost,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict:
|
||||
"""Get available models and their information.
|
||||
|
||||
Returns:
|
||||
Dictionary of available models
|
||||
"""
|
||||
return cls.SUPPORTED_MODELS
|
||||
|
||||
@@ -240,20 +240,23 @@ def validate_exa_research_operations(
|
||||
|
||||
def validate_image_generation_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str
|
||||
user_id: str,
|
||||
num_images: int = 1
|
||||
) -> None:
|
||||
"""
|
||||
Validate image generation operation before making API calls.
|
||||
Validate image generation operation(s) before making API calls.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
num_images: Number of images to generate (for multiple variations)
|
||||
|
||||
Returns:
|
||||
(can_proceed, error_message, error_details)
|
||||
If can_proceed is False, raises HTTPException with 429 status
|
||||
None
|
||||
If validation fails, raises HTTPException with 429 status
|
||||
"""
|
||||
try:
|
||||
# Create validation operations for each image
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.STABILITY,
|
||||
@@ -261,8 +264,11 @@ def validate_image_generation_operations(
|
||||
'actual_provider_name': 'stability',
|
||||
'operation_type': 'image_generation'
|
||||
}
|
||||
for _ in range(num_images)
|
||||
]
|
||||
|
||||
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image generation(s) for user {user_id}")
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
@@ -289,6 +295,54 @@ def validate_image_generation_operations(
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
|
||||
def validate_image_upscale_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str,
|
||||
num_images: int = 1
|
||||
) -> None:
|
||||
"""
|
||||
Validate image upscaling before making API calls.
|
||||
"""
|
||||
try:
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.STABILITY,
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'stability',
|
||||
'operation_type': 'image_upscale'
|
||||
}
|
||||
for _ in range(num_images)
|
||||
]
|
||||
|
||||
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image upscale request(s) for user {user_id}")
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.error(f"[Pre-flight Validator] Image upscale blocked for user {user_id}: {message}")
|
||||
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
|
||||
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ Image upscale validated for user {user_id}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating image generation: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
|
||||
@@ -312,6 +312,175 @@ class WaveSpeedClient:
|
||||
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
|
||||
return optimized_prompt
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 120,
|
||||
**kwargs
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image using WaveSpeed AI models (Ideogram V3 or Qwen Image).
|
||||
|
||||
Args:
|
||||
model: Model to use ("ideogram-v3-turbo" or "qwen-image")
|
||||
prompt: Text prompt for image generation
|
||||
width: Image width (default: 1024)
|
||||
height: Image height (default: 1024)
|
||||
num_inference_steps: Number of inference steps
|
||||
guidance_scale: Guidance scale for generation
|
||||
negative_prompt: Negative prompt (what to avoid)
|
||||
seed: Random seed for reproducibility
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 120)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
bytes: Generated image bytes
|
||||
"""
|
||||
# Map model names to WaveSpeed API paths
|
||||
model_paths = {
|
||||
"ideogram-v3-turbo": "ideogram-ai/ideogram-v3-turbo",
|
||||
"qwen-image": "wavespeed-ai/qwen-image/text-to-image",
|
||||
}
|
||||
|
||||
model_path = model_paths.get(model)
|
||||
if not model_path:
|
||||
raise ValueError(f"Unsupported image model: {model}. Supported: {list(model_paths.keys())}")
|
||||
|
||||
url = f"{self.BASE_URL}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if num_inference_steps is not None:
|
||||
payload["num_inference_steps"] = num_inference_steps
|
||||
if guidance_scale is not None:
|
||||
payload["guidance_scale"] = guidance_scale
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
# Add any extra parameters
|
||||
for key, value in kwargs.items():
|
||||
if key not in payload:
|
||||
payload[key] = value
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating image via {url} (model={model}, prompt_length={len(prompt)})")
|
||||
response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Image generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
outputs = data.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator returned no outputs",
|
||||
)
|
||||
|
||||
# Extract image URL from outputs
|
||||
image_url = None
|
||||
if isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("output")
|
||||
|
||||
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
logger.error(f"[WaveSpeed] Invalid image URL in outputs: {outputs}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
# Fetch image bytes from URL
|
||||
logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
if image_response.status_code == 200:
|
||||
image_bytes = image_response.content
|
||||
logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} bytes)")
|
||||
return image_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch image from URL: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated image from WaveSpeed URL",
|
||||
)
|
||||
|
||||
# Async mode - poll for result
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id for async mode",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
result = self.poll_until_complete(prediction_id, timeout_seconds=240, interval_seconds=1.0)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed image generator returned no outputs")
|
||||
|
||||
# Extract image URL and fetch
|
||||
image_url = None
|
||||
if isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("output")
|
||||
|
||||
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
# Fetch image bytes
|
||||
logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
if image_response.status_code == 200:
|
||||
image_bytes = image_response.content
|
||||
logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} bytes)")
|
||||
return image_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch image from URL: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated image from WaveSpeed URL",
|
||||
)
|
||||
|
||||
def generate_speech(
|
||||
self,
|
||||
text: str,
|
||||
|
||||
Reference in New Issue
Block a user