AI Image Studio Progress Review
- Added new router for content assets - Added new service for content assets - Added new model for content assets - Added new utils for content assets - Added new docs for content assets - Added new tests for content assets - Added new examples for content assets - Added new guides for content assets
This commit is contained in:
2
backend/api/content_assets/__init__.py
Normal file
2
backend/api/content_assets/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Content Assets API Module
|
||||
|
||||
258
backend/api/content_assets/router.py
Normal file
258
backend/api/content_assets/router.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Content Assets API Router
|
||||
API endpoints for managing unified content assets across all modules.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.content_asset_service import ContentAssetService
|
||||
from models.content_asset_models import AssetType, AssetSource
|
||||
|
||||
router = APIRouter(prefix="/api/content-assets", tags=["Content Assets"])
|
||||
|
||||
|
||||
class AssetResponse(BaseModel):
|
||||
"""Response model for asset data."""
|
||||
id: int
|
||||
user_id: str
|
||||
asset_type: str
|
||||
source_module: str
|
||||
filename: str
|
||||
file_url: str
|
||||
file_path: Optional[str] = None
|
||||
file_size: Optional[int] = None
|
||||
mime_type: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
tags: List[str] = []
|
||||
metadata: Dict[str, Any] = {}
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
cost: float = 0.0
|
||||
generation_time: Optional[float] = None
|
||||
is_favorite: bool = False
|
||||
download_count: int = 0
|
||||
share_count: int = 0
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AssetListResponse(BaseModel):
|
||||
"""Response model for asset list."""
|
||||
assets: List[AssetResponse]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
@router.get("/", response_model=AssetListResponse)
|
||||
async def get_assets(
|
||||
asset_type: Optional[str] = Query(None, description="Filter by asset type"),
|
||||
source_module: Optional[str] = Query(None, description="Filter by source module"),
|
||||
search: Optional[str] = Query(None, description="Search query"),
|
||||
tags: Optional[str] = Query(None, description="Comma-separated tags"),
|
||||
favorites_only: bool = Query(False, description="Only favorites"),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get user's content assets with optional filtering."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
|
||||
# Parse filters
|
||||
asset_type_enum = None
|
||||
if asset_type:
|
||||
try:
|
||||
asset_type_enum = AssetType(asset_type.lower())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid asset type: {asset_type}")
|
||||
|
||||
source_module_enum = None
|
||||
if source_module:
|
||||
try:
|
||||
source_module_enum = AssetSource(source_module.lower())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid source module: {source_module}")
|
||||
|
||||
tags_list = None
|
||||
if tags:
|
||||
tags_list = [tag.strip() for tag in tags.split(",")]
|
||||
|
||||
assets, total = service.get_user_assets(
|
||||
user_id=user_id,
|
||||
asset_type=asset_type_enum,
|
||||
source_module=source_module_enum,
|
||||
search_query=search,
|
||||
tags=tags_list,
|
||||
favorites_only=favorites_only,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return AssetListResponse(
|
||||
assets=[AssetResponse.model_validate(asset) for asset in assets],
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching assets: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{asset_id}/favorite", response_model=Dict[str, Any])
|
||||
async def toggle_favorite(
|
||||
asset_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Toggle favorite status of an asset."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
is_favorite = service.toggle_favorite(asset_id, user_id)
|
||||
|
||||
return {"asset_id": asset_id, "is_favorite": is_favorite}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error toggling favorite: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{asset_id}", response_model=Dict[str, Any])
|
||||
async def delete_asset(
|
||||
asset_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Delete an asset."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
success = service.delete_asset(asset_id, user_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
return {"asset_id": asset_id, "deleted": True}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error deleting asset: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{asset_id}/usage", response_model=Dict[str, Any])
|
||||
async def track_usage(
|
||||
asset_id: int,
|
||||
action: str = Query(..., description="Action: download, share, or access"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Track asset usage (download, share, access)."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
if action not in ["download", "share", "access"]:
|
||||
raise HTTPException(status_code=400, detail="Invalid action")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
service.update_asset_usage(asset_id, user_id, action)
|
||||
|
||||
return {"asset_id": asset_id, "action": action, "tracked": True}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error tracking usage: {str(e)}")
|
||||
|
||||
|
||||
class AssetUpdateRequest(BaseModel):
|
||||
"""Request model for updating asset metadata."""
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
@router.put("/{asset_id}", response_model=AssetResponse)
|
||||
async def update_asset(
|
||||
asset_id: int,
|
||||
update_data: AssetUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Update asset metadata."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
|
||||
asset = service.update_asset(
|
||||
asset_id=asset_id,
|
||||
user_id=user_id,
|
||||
title=update_data.title,
|
||||
description=update_data.description,
|
||||
tags=update_data.tags,
|
||||
)
|
||||
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
return AssetResponse.model_validate(asset)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error updating asset: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/statistics", response_model=Dict[str, Any])
|
||||
async def get_statistics(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get asset statistics for the current user."""
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = ContentAssetService(db)
|
||||
stats = service.get_asset_statistics(user_id)
|
||||
|
||||
return stats
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching statistics: {str(e)}")
|
||||
|
||||
@@ -299,6 +299,10 @@ app.include_router(platform_analytics_router)
|
||||
app.include_router(images_router)
|
||||
app.include_router(image_studio_router)
|
||||
|
||||
# Include content assets router
|
||||
from api.content_assets.router import router as content_assets_router
|
||||
app.include_router(content_assets_router)
|
||||
|
||||
# Include research configuration router
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
|
||||
|
||||
145
backend/models/content_asset_models.py
Normal file
145
backend/models/content_asset_models.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Content Asset Models
|
||||
Unified database models for tracking all AI-generated content assets across all modules.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, Text, ForeignKey, Enum, Index, func
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
# Use the same Base as subscription models for consistency
|
||||
from models.subscription_models import Base
|
||||
|
||||
|
||||
class AssetType(enum.Enum):
|
||||
"""Types of content assets."""
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
|
||||
|
||||
class AssetSource(enum.Enum):
|
||||
"""Source module/tool that generated the asset - covers ALL ALwrity tools."""
|
||||
# Image Studio modules
|
||||
IMAGE_STUDIO_CREATE = "image_studio_create"
|
||||
IMAGE_STUDIO_EDIT = "image_studio_edit"
|
||||
IMAGE_STUDIO_UPSCALE = "image_studio_upscale"
|
||||
IMAGE_STUDIO_TRANSFORM = "image_studio_transform"
|
||||
IMAGE_STUDIO_CONTROL = "image_studio_control"
|
||||
IMAGE_STUDIO_SOCIAL = "image_studio_social"
|
||||
IMAGE_STUDIO_BATCH = "image_studio_batch"
|
||||
|
||||
# Content Writers
|
||||
STORY_WRITER = "story_writer"
|
||||
BLOG_WRITER = "blog_writer"
|
||||
LINKEDIN_WRITER = "linkedin_writer"
|
||||
FACEBOOK_WRITER = "facebook_writer"
|
||||
|
||||
# Content Planning
|
||||
CONTENT_PLANNING = "content_planning"
|
||||
CONTENT_STRATEGY = "content_strategy"
|
||||
|
||||
# SEO Tools
|
||||
SEO_DASHBOARD = "seo_dashboard"
|
||||
SEO_TOOLS = "seo_tools"
|
||||
|
||||
# Research
|
||||
RESEARCH = "research"
|
||||
|
||||
# Scheduler
|
||||
SCHEDULER = "scheduler"
|
||||
|
||||
# Main Generation (legacy/fallback)
|
||||
MAIN_TEXT_GENERATION = "main_text_generation"
|
||||
MAIN_IMAGE_GENERATION = "main_image_generation"
|
||||
MAIN_VIDEO_GENERATION = "main_video_generation"
|
||||
MAIN_AUDIO_GENERATION = "main_audio_generation"
|
||||
|
||||
|
||||
class ContentAsset(Base):
|
||||
"""
|
||||
Unified model for tracking all AI-generated content assets.
|
||||
Similar to subscription tracking, this provides a centralized way to manage all content.
|
||||
"""
|
||||
|
||||
__tablename__ = "content_assets"
|
||||
|
||||
# Primary fields
|
||||
id = Column(Integer, primary_key=True)
|
||||
user_id = Column(String(255), nullable=False, index=True) # Clerk user ID
|
||||
|
||||
# Asset identification
|
||||
asset_type = Column(Enum(AssetType), nullable=False, index=True)
|
||||
source_module = Column(Enum(AssetSource), nullable=False, index=True)
|
||||
|
||||
# File information
|
||||
filename = Column(String(500), nullable=False)
|
||||
file_path = Column(String(1000), nullable=True) # Server file path
|
||||
file_url = Column(String(1000), nullable=False) # Public URL
|
||||
file_size = Column(Integer, nullable=True) # Size in bytes
|
||||
mime_type = Column(String(100), nullable=True) # MIME type
|
||||
|
||||
# Asset metadata
|
||||
title = Column(String(500), nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
prompt = Column(Text, nullable=True) # Original prompt used for generation
|
||||
tags = Column(JSON, nullable=True) # Array of tags for search/filtering
|
||||
metadata = Column(JSON, nullable=True) # Additional module-specific metadata
|
||||
|
||||
# Generation details
|
||||
provider = Column(String(100), nullable=True, index=True) # AI provider used (e.g., "stability", "gemini")
|
||||
model = Column(String(200), nullable=True, index=True) # Model used (full model path/name)
|
||||
cost = Column(Float, nullable=True, default=0.0) # Generation cost in USD
|
||||
generation_time = Column(Float, nullable=True) # Time taken in seconds
|
||||
|
||||
# Status tracking
|
||||
status = Column(String(50), default='completed', index=True) # completed, processing, failed, pending
|
||||
error_message = Column(Text, nullable=True) # Error details if failed
|
||||
|
||||
# Organization
|
||||
is_favorite = Column(Boolean, default=False, index=True)
|
||||
collection_id = Column(Integer, ForeignKey('asset_collections.id'), nullable=True)
|
||||
|
||||
# Usage tracking
|
||||
download_count = Column(Integer, default=0)
|
||||
share_count = Column(Integer, default=0)
|
||||
last_accessed = Column(DateTime, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
collection = relationship("AssetCollection", back_populates="assets", cascade="all, delete-orphan")
|
||||
|
||||
# Composite indexes for common query patterns
|
||||
__table_args__ = (
|
||||
Index('idx_user_type_source', 'user_id', 'asset_type', 'source_module'),
|
||||
Index('idx_user_favorite_created', 'user_id', 'is_favorite', 'created_at'),
|
||||
Index('idx_user_tags', 'user_id', 'tags'),
|
||||
)
|
||||
|
||||
|
||||
class AssetCollection(Base):
|
||||
"""
|
||||
Collections/albums for organizing assets.
|
||||
"""
|
||||
|
||||
__tablename__ = "asset_collections"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
is_public = Column(Boolean, default=False)
|
||||
cover_asset_id = Column(Integer, ForeignKey('content_assets.id'), nullable=True)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
assets = relationship("ContentAsset", back_populates="collection")
|
||||
|
||||
@@ -9,6 +9,8 @@ from services.image_studio import (
|
||||
ImageStudioManager,
|
||||
CreateStudioRequest,
|
||||
EditStudioRequest,
|
||||
ControlStudioRequest,
|
||||
SocialOptimizerRequest,
|
||||
)
|
||||
from services.image_studio.upscale_service import UpscaleStudioRequest
|
||||
from services.image_studio.templates import Platform, TemplateCategory
|
||||
@@ -531,6 +533,197 @@ async def upscale_image(
|
||||
raise HTTPException(status_code=500, detail=f"Image upscaling failed: {e}")
|
||||
|
||||
|
||||
# ====================
|
||||
# CONTROL STUDIO ENDPOINTS
|
||||
# ====================
|
||||
|
||||
class ControlImageRequest(BaseModel):
|
||||
"""Request payload for Control Studio."""
|
||||
|
||||
control_image_base64: str = Field(..., description="Control image (sketch/structure/style) in base64")
|
||||
operation: Literal["sketch", "structure", "style", "style_transfer"] = Field(..., description="Control operation")
|
||||
prompt: str = Field(..., description="Text prompt for generation")
|
||||
style_image_base64: Optional[str] = Field(None, description="Style reference image (for style_transfer only)")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt")
|
||||
control_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Control strength (sketch/structure)")
|
||||
fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style fidelity (style operation)")
|
||||
style_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style strength (style_transfer)")
|
||||
composition_fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Composition fidelity (style_transfer)")
|
||||
change_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Change strength (style_transfer)")
|
||||
aspect_ratio: Optional[str] = Field(None, description="Aspect ratio (style operation)")
|
||||
style_preset: Optional[str] = Field(None, description="Style preset")
|
||||
seed: Optional[int] = Field(None, description="Random seed")
|
||||
output_format: str = Field("png", description="Output format")
|
||||
|
||||
|
||||
class ControlImageResponse(BaseModel):
|
||||
success: bool
|
||||
operation: str
|
||||
provider: str
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class ControlOperationsResponse(BaseModel):
|
||||
operations: Dict[str, Dict[str, Any]]
|
||||
|
||||
|
||||
@router.post("/control/process", response_model=ControlImageResponse, summary="Process Control Studio request")
|
||||
async def process_control_image(
|
||||
request: ControlImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Perform Control Studio operations such as sketch-to-image, structure control, style control, and style transfer."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image control")
|
||||
logger.info(f"[Control Image] Request from user {user_id}: operation={request.operation}")
|
||||
|
||||
control_request = ControlStudioRequest(
|
||||
operation=request.operation,
|
||||
prompt=request.prompt,
|
||||
control_image_base64=request.control_image_base64,
|
||||
style_image_base64=request.style_image_base64,
|
||||
negative_prompt=request.negative_prompt,
|
||||
control_strength=request.control_strength,
|
||||
fidelity=request.fidelity,
|
||||
style_strength=request.style_strength,
|
||||
composition_fidelity=request.composition_fidelity,
|
||||
change_strength=request.change_strength,
|
||||
aspect_ratio=request.aspect_ratio,
|
||||
style_preset=request.style_preset,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
|
||||
result = await studio_manager.control_image(control_request, user_id=user_id)
|
||||
return ControlImageResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Control Image] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image control failed: {e}")
|
||||
|
||||
|
||||
@router.get("/control/operations", response_model=ControlOperationsResponse, summary="List Control Studio operations")
|
||||
async def get_control_operations(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Return metadata for supported Control Studio operations."""
|
||||
try:
|
||||
operations = studio_manager.get_control_operations()
|
||||
return ControlOperationsResponse(operations=operations)
|
||||
except Exception as e:
|
||||
logger.error(f"[Control Operations] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to load control operations")
|
||||
|
||||
|
||||
# ====================
|
||||
# SOCIAL OPTIMIZER ENDPOINTS
|
||||
# ====================
|
||||
|
||||
class SocialOptimizeRequest(BaseModel):
|
||||
"""Request payload for Social Optimizer."""
|
||||
image_base64: str = Field(..., description="Source image in base64 or data URL")
|
||||
platforms: List[str] = Field(..., description="List of platforms to optimize for")
|
||||
format_names: Optional[Dict[str, str]] = Field(None, description="Specific format per platform")
|
||||
show_safe_zones: bool = Field(False, description="Include safe zone overlay in output")
|
||||
crop_mode: str = Field("smart", description="Crop mode: smart, center, or fit")
|
||||
focal_point: Optional[Dict[str, float]] = Field(None, description="Focal point for smart crop (x, y as 0-1)")
|
||||
output_format: str = Field("png", description="Output format (png or jpg)")
|
||||
|
||||
|
||||
class SocialOptimizeResponse(BaseModel):
|
||||
success: bool
|
||||
results: List[Dict[str, Any]]
|
||||
total_optimized: int
|
||||
|
||||
|
||||
class PlatformFormatsResponse(BaseModel):
|
||||
formats: List[Dict[str, Any]]
|
||||
|
||||
|
||||
@router.post("/social/optimize", response_model=SocialOptimizeResponse, summary="Optimize image for social platforms")
|
||||
async def optimize_for_social(
|
||||
request: SocialOptimizeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Optimize an image for multiple social media platforms with smart cropping and safe zones."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "social optimization")
|
||||
logger.info(f"[Social Optimizer] Request from user {user_id}: platforms={request.platforms}")
|
||||
|
||||
# Convert platform strings to Platform enum
|
||||
from services.image_studio.templates import Platform
|
||||
platforms = []
|
||||
for platform_str in request.platforms:
|
||||
try:
|
||||
platforms.append(Platform(platform_str.lower()))
|
||||
except ValueError:
|
||||
logger.warning(f"[Social Optimizer] Invalid platform: {platform_str}")
|
||||
continue
|
||||
|
||||
if not platforms:
|
||||
raise HTTPException(status_code=400, detail="No valid platforms provided")
|
||||
|
||||
# Convert format_names dict keys to Platform enum
|
||||
format_names = None
|
||||
if request.format_names:
|
||||
format_names = {}
|
||||
for platform_str, format_name in request.format_names.items():
|
||||
try:
|
||||
platform = Platform(platform_str.lower())
|
||||
format_names[platform] = format_name
|
||||
except ValueError:
|
||||
logger.warning(f"[Social Optimizer] Invalid platform in format_names: {platform_str}")
|
||||
|
||||
social_request = SocialOptimizerRequest(
|
||||
image_base64=request.image_base64,
|
||||
platforms=platforms,
|
||||
format_names=format_names,
|
||||
show_safe_zones=request.show_safe_zones,
|
||||
crop_mode=request.crop_mode,
|
||||
focal_point=request.focal_point,
|
||||
output_format=request.output_format,
|
||||
options={},
|
||||
)
|
||||
|
||||
result = await studio_manager.optimize_for_social(social_request, user_id=user_id)
|
||||
return SocialOptimizeResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Social Optimizer] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Social optimization failed: {e}")
|
||||
|
||||
|
||||
@router.get("/social/platforms/{platform}/formats", response_model=PlatformFormatsResponse, summary="Get platform formats")
|
||||
async def get_platform_formats(
|
||||
platform: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get available formats for a social media platform."""
|
||||
try:
|
||||
from services.image_studio.templates import Platform
|
||||
try:
|
||||
platform_enum = Platform(platform.lower())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid platform: {platform}")
|
||||
|
||||
formats = studio_manager.get_social_platform_formats(platform_enum)
|
||||
return PlatformFormatsResponse(formats=formats)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Platform Formats] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load platform formats: {e}")
|
||||
|
||||
|
||||
# ====================
|
||||
# PLATFORM SPECS ENDPOINTS
|
||||
# ====================
|
||||
|
||||
322
backend/services/content_asset_service.py
Normal file
322
backend/services/content_asset_service.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Content Asset Service
|
||||
Service for managing and tracking all AI-generated content assets.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, func, desc
|
||||
from datetime import datetime
|
||||
from models.content_asset_models import (
|
||||
ContentAsset,
|
||||
AssetCollection,
|
||||
AssetType,
|
||||
AssetSource
|
||||
)
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentAssetService:
|
||||
"""Service for managing content assets across all modules."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create_asset(
|
||||
self,
|
||||
user_id: str,
|
||||
asset_type: AssetType,
|
||||
source_module: AssetSource,
|
||||
filename: str,
|
||||
file_url: str,
|
||||
file_path: Optional[str] = None,
|
||||
file_size: Optional[int] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
cost: Optional[float] = None,
|
||||
generation_time: Optional[float] = None,
|
||||
) -> ContentAsset:
|
||||
"""
|
||||
Create a new content asset record.
|
||||
|
||||
Args:
|
||||
user_id: Clerk user ID
|
||||
asset_type: Type of asset (text, image, video, audio)
|
||||
source_module: Source module that generated it
|
||||
filename: Original filename
|
||||
file_url: Public URL to access the asset
|
||||
file_path: Server file path (optional)
|
||||
file_size: File size in bytes (optional)
|
||||
mime_type: MIME type (optional)
|
||||
title: Asset title (optional)
|
||||
description: Asset description (optional)
|
||||
prompt: Generation prompt (optional)
|
||||
tags: List of tags (optional)
|
||||
metadata: Additional metadata (optional)
|
||||
provider: AI provider used (optional)
|
||||
model: Model used (optional)
|
||||
cost: Generation cost (optional)
|
||||
generation_time: Generation time in seconds (optional)
|
||||
|
||||
Returns:
|
||||
Created ContentAsset instance
|
||||
"""
|
||||
try:
|
||||
asset = ContentAsset(
|
||||
user_id=user_id,
|
||||
asset_type=asset_type,
|
||||
source_module=source_module,
|
||||
filename=filename,
|
||||
file_url=file_url,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
mime_type=mime_type,
|
||||
title=title,
|
||||
description=description,
|
||||
prompt=prompt,
|
||||
tags=tags or [],
|
||||
metadata=metadata or {},
|
||||
provider=provider,
|
||||
model=model,
|
||||
cost=cost or 0.0,
|
||||
generation_time=generation_time,
|
||||
)
|
||||
|
||||
self.db.add(asset)
|
||||
self.db.commit()
|
||||
self.db.refresh(asset)
|
||||
|
||||
logger.info(f"Created asset {asset.id} for user {user_id} from {source_module.value}")
|
||||
return asset
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error creating asset: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_user_assets(
|
||||
self,
|
||||
user_id: str,
|
||||
asset_type: Optional[AssetType] = None,
|
||||
source_module: Optional[AssetSource] = None,
|
||||
search_query: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
favorites_only: bool = False,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> Tuple[List[ContentAsset], int]:
|
||||
"""
|
||||
Get assets for a user with optional filtering.
|
||||
|
||||
Args:
|
||||
user_id: Clerk user ID
|
||||
asset_type: Filter by asset type (optional)
|
||||
source_module: Filter by source module (optional)
|
||||
search_query: Search in title, description, prompt (optional)
|
||||
tags: Filter by tags (optional)
|
||||
favorites_only: Only return favorites (optional)
|
||||
limit: Maximum number of results
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of ContentAsset instances
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ContentAsset).filter(
|
||||
ContentAsset.user_id == user_id
|
||||
)
|
||||
|
||||
if asset_type:
|
||||
query = query.filter(ContentAsset.asset_type == asset_type)
|
||||
|
||||
if source_module:
|
||||
query = query.filter(ContentAsset.source_module == source_module)
|
||||
|
||||
if favorites_only:
|
||||
query = query.filter(ContentAsset.is_favorite == True)
|
||||
|
||||
if search_query:
|
||||
search_filter = or_(
|
||||
ContentAsset.title.ilike(f"%{search_query}%"),
|
||||
ContentAsset.description.ilike(f"%{search_query}%"),
|
||||
ContentAsset.prompt.ilike(f"%{search_query}%"),
|
||||
ContentAsset.filename.ilike(f"%{search_query}%"),
|
||||
)
|
||||
query = query.filter(search_filter)
|
||||
|
||||
if tags:
|
||||
# Filter by tags (JSON array contains any of the tags)
|
||||
tag_filters = [ContentAsset.tags.contains([tag]) for tag in tags]
|
||||
query = query.filter(or_(*tag_filters))
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Apply ordering and pagination
|
||||
query = query.order_by(desc(ContentAsset.created_at))
|
||||
query = query.limit(limit).offset(offset)
|
||||
|
||||
return query.all(), total_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching assets: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_asset_by_id(self, asset_id: int, user_id: str) -> Optional[ContentAsset]:
|
||||
"""Get a specific asset by ID."""
|
||||
try:
|
||||
return self.db.query(ContentAsset).filter(
|
||||
and_(
|
||||
ContentAsset.id == asset_id,
|
||||
ContentAsset.user_id == user_id
|
||||
)
|
||||
).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching asset {asset_id}: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def toggle_favorite(self, asset_id: int, user_id: str) -> bool:
|
||||
"""Toggle favorite status of an asset."""
|
||||
try:
|
||||
asset = self.get_asset_by_id(asset_id, user_id)
|
||||
if not asset:
|
||||
return False
|
||||
|
||||
asset.is_favorite = not asset.is_favorite
|
||||
self.db.commit()
|
||||
return asset.is_favorite
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error toggling favorite: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def delete_asset(self, asset_id: int, user_id: str) -> bool:
|
||||
"""Delete an asset."""
|
||||
try:
|
||||
asset = self.get_asset_by_id(asset_id, user_id)
|
||||
if not asset:
|
||||
return False
|
||||
|
||||
self.db.delete(asset)
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error deleting asset: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def update_asset(
|
||||
self,
|
||||
asset_id: int,
|
||||
user_id: str,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[ContentAsset]:
|
||||
"""Update asset metadata."""
|
||||
try:
|
||||
asset = self.get_asset_by_id(asset_id, user_id)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
if title is not None:
|
||||
asset.title = title
|
||||
if description is not None:
|
||||
asset.description = description
|
||||
if tags is not None:
|
||||
asset.tags = tags
|
||||
if metadata is not None:
|
||||
asset.metadata = {**(asset.metadata or {}), **metadata}
|
||||
|
||||
asset.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
self.db.refresh(asset)
|
||||
|
||||
logger.info(f"Updated asset {asset_id} for user {user_id}")
|
||||
return asset
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error updating asset: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def update_asset_usage(self, asset_id: int, user_id: str, action: str = "access"):
|
||||
"""Update asset usage statistics."""
|
||||
try:
|
||||
asset = self.get_asset_by_id(asset_id, user_id)
|
||||
if not asset:
|
||||
return
|
||||
|
||||
if action == "download":
|
||||
asset.download_count += 1
|
||||
elif action == "share":
|
||||
asset.share_count += 1
|
||||
|
||||
asset.last_accessed = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error updating asset usage: {str(e)}", exc_info=True)
|
||||
|
||||
def get_asset_statistics(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get statistics about user's assets."""
|
||||
try:
|
||||
total = self.db.query(func.count(ContentAsset.id)).filter(
|
||||
ContentAsset.user_id == user_id
|
||||
).scalar() or 0
|
||||
|
||||
by_type = self.db.query(
|
||||
ContentAsset.asset_type,
|
||||
func.count(ContentAsset.id)
|
||||
).filter(
|
||||
ContentAsset.user_id == user_id
|
||||
).group_by(ContentAsset.asset_type).all()
|
||||
|
||||
by_source = self.db.query(
|
||||
ContentAsset.source_module,
|
||||
func.count(ContentAsset.id)
|
||||
).filter(
|
||||
ContentAsset.user_id == user_id
|
||||
).group_by(ContentAsset.source_module).all()
|
||||
|
||||
total_cost = self.db.query(func.sum(ContentAsset.cost)).filter(
|
||||
ContentAsset.user_id == user_id
|
||||
).scalar() or 0.0
|
||||
|
||||
favorites_count = self.db.query(func.count(ContentAsset.id)).filter(
|
||||
and_(
|
||||
ContentAsset.user_id == user_id,
|
||||
ContentAsset.is_favorite == True
|
||||
)
|
||||
).scalar() or 0
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"by_type": {str(t): c for t, c in by_type},
|
||||
"by_source": {str(s): c for s, c in by_source},
|
||||
"total_cost": float(total_cost),
|
||||
"favorites_count": favorites_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting asset statistics: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"total": 0,
|
||||
"by_type": {},
|
||||
"by_source": {},
|
||||
"total_cost": 0.0,
|
||||
"favorites_count": 0,
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ from models.monitoring_models import Base as MonitoringBase
|
||||
from models.persona_models import Base as PersonaBase
|
||||
from models.subscription_models import Base as SubscriptionBase
|
||||
from models.user_business_info import Base as UserBusinessInfoBase
|
||||
from models.content_asset_models import Base as ContentAssetBase
|
||||
|
||||
# Database configuration
|
||||
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./alwrity.db')
|
||||
@@ -74,7 +75,8 @@ def init_database():
|
||||
PersonaBase.metadata.create_all(bind=engine)
|
||||
SubscriptionBase.metadata.create_all(bind=engine)
|
||||
UserBusinessInfoBase.metadata.create_all(bind=engine)
|
||||
logger.info("Database initialized successfully with all models including subscription system and business info")
|
||||
ContentAssetBase.metadata.create_all(bind=engine)
|
||||
logger.info("Database initialized successfully with all models including subscription system, business info, and content assets")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error initializing database: {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -4,6 +4,8 @@ from .studio_manager import ImageStudioManager
|
||||
from .create_service import CreateStudioService, CreateStudioRequest
|
||||
from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .templates import PlatformTemplates, TemplateManager
|
||||
|
||||
__all__ = [
|
||||
@@ -14,6 +16,10 @@ __all__ = [
|
||||
"EditStudioRequest",
|
||||
"UpscaleStudioService",
|
||||
"UpscaleStudioRequest",
|
||||
"ControlStudioService",
|
||||
"ControlStudioRequest",
|
||||
"SocialOptimizerService",
|
||||
"SocialOptimizerRequest",
|
||||
"PlatformTemplates",
|
||||
"TemplateManager",
|
||||
]
|
||||
|
||||
277
backend/services/image_studio/control_service.py
Normal file
277
backend/services/image_studio/control_service.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Control Studio service for AI-powered controlled image generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from services.stability_service import StabilityAIService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_studio.control")
|
||||
|
||||
|
||||
ControlOperationType = Literal[
|
||||
"sketch",
|
||||
"structure",
|
||||
"style",
|
||||
"style_transfer",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlStudioRequest:
|
||||
"""Normalized request payload for Control Studio operations."""
|
||||
|
||||
operation: ControlOperationType
|
||||
prompt: str
|
||||
control_image_base64: str # Sketch, structure, or style reference
|
||||
style_image_base64: Optional[str] = None # For style_transfer only
|
||||
negative_prompt: Optional[str] = None
|
||||
control_strength: Optional[float] = None # For sketch/structure
|
||||
fidelity: Optional[float] = None # For style
|
||||
style_strength: Optional[float] = None # For style_transfer
|
||||
composition_fidelity: Optional[float] = None # For style_transfer
|
||||
change_strength: Optional[float] = None # For style_transfer
|
||||
aspect_ratio: Optional[str] = None # For style
|
||||
style_preset: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
output_format: str = "png"
|
||||
|
||||
|
||||
class ControlStudioService:
|
||||
"""Service layer orchestrating Control Studio operations."""
|
||||
|
||||
SUPPORTED_OPERATIONS: Dict[ControlOperationType, Dict[str, Any]] = {
|
||||
"sketch": {
|
||||
"label": "Sketch to Image",
|
||||
"description": "Transform sketches into refined images with precise control.",
|
||||
"provider": "stability",
|
||||
"fields": {
|
||||
"control_image": True,
|
||||
"style_image": False,
|
||||
"control_strength": True,
|
||||
"fidelity": False,
|
||||
"style_strength": False,
|
||||
"aspect_ratio": False,
|
||||
},
|
||||
},
|
||||
"structure": {
|
||||
"label": "Structure Control",
|
||||
"description": "Generate images maintaining the structure of an input image.",
|
||||
"provider": "stability",
|
||||
"fields": {
|
||||
"control_image": True,
|
||||
"style_image": False,
|
||||
"control_strength": True,
|
||||
"fidelity": False,
|
||||
"style_strength": False,
|
||||
"aspect_ratio": False,
|
||||
},
|
||||
},
|
||||
"style": {
|
||||
"label": "Style Control",
|
||||
"description": "Generate images using style from a reference image.",
|
||||
"provider": "stability",
|
||||
"fields": {
|
||||
"control_image": True,
|
||||
"style_image": False,
|
||||
"control_strength": False,
|
||||
"fidelity": True,
|
||||
"style_strength": False,
|
||||
"aspect_ratio": True,
|
||||
},
|
||||
},
|
||||
"style_transfer": {
|
||||
"label": "Style Transfer",
|
||||
"description": "Apply visual characteristics from a style image to a target image.",
|
||||
"provider": "stability",
|
||||
"fields": {
|
||||
"control_image": True, # init_image
|
||||
"style_image": True,
|
||||
"control_strength": False,
|
||||
"fidelity": False,
|
||||
"style_strength": True,
|
||||
"aspect_ratio": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[Control Studio] Initialized control service")
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
|
||||
"""Decode a base64 (or data URL) string to bytes."""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Handle data URLs (data:image/png;base64,...)
|
||||
if value.startswith("data:"):
|
||||
_, b64data = value.split(",", 1)
|
||||
else:
|
||||
b64data = value
|
||||
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Control Studio] Failed to decode base64 image: {exc}")
|
||||
raise ValueError("Invalid base64 image payload") from exc
|
||||
|
||||
@staticmethod
|
||||
def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
|
||||
"""Extract width/height metadata from image bytes."""
|
||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||
return {
|
||||
"width": img.width,
|
||||
"height": img.height,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
|
||||
"""Convert raw bytes to base64 data URL."""
|
||||
b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return f"data:image/{output_format};base64,{b64}"
|
||||
|
||||
def list_operations(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Expose supported operations for UI rendering."""
|
||||
return self.SUPPORTED_OPERATIONS
|
||||
|
||||
async def process_control(
|
||||
self,
|
||||
request: ControlStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process control request and return normalized response."""
|
||||
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_control_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
logger.info(f"[Control Studio] 🛂 Running pre-flight validation for user {user_id}")
|
||||
validate_image_control_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=1,
|
||||
)
|
||||
logger.info("[Control Studio] ✅ Pre-flight validation passed")
|
||||
except HTTPException:
|
||||
logger.error("[Control Studio] ❌ Pre-flight validation failed")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning("[Control Studio] ⚠️ No user_id provided - skipping pre-flight validation")
|
||||
|
||||
control_image_bytes = self._decode_base64_image(request.control_image_base64)
|
||||
if not control_image_bytes:
|
||||
raise ValueError("Control image payload is required")
|
||||
|
||||
style_image_bytes = self._decode_base64_image(request.style_image_base64)
|
||||
|
||||
operation = request.operation
|
||||
logger.info("[Control Studio] Processing operation='%s' for user=%s", operation, user_id)
|
||||
|
||||
if operation not in self.SUPPORTED_OPERATIONS:
|
||||
raise ValueError(f"Unsupported control operation: {operation}")
|
||||
|
||||
stability_service = StabilityAIService()
|
||||
async with stability_service:
|
||||
if operation == "sketch":
|
||||
result = await stability_service.control_sketch(
|
||||
image=control_image_bytes,
|
||||
prompt=request.prompt,
|
||||
control_strength=request.control_strength or 0.7,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
style_preset=request.style_preset,
|
||||
)
|
||||
elif operation == "structure":
|
||||
result = await stability_service.control_structure(
|
||||
image=control_image_bytes,
|
||||
prompt=request.prompt,
|
||||
control_strength=request.control_strength or 0.7,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
style_preset=request.style_preset,
|
||||
)
|
||||
elif operation == "style":
|
||||
result = await stability_service.control_style(
|
||||
image=control_image_bytes,
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
aspect_ratio=request.aspect_ratio or "1:1",
|
||||
fidelity=request.fidelity or 0.5,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
style_preset=request.style_preset,
|
||||
)
|
||||
elif operation == "style_transfer":
|
||||
if not style_image_bytes:
|
||||
raise ValueError("Style image is required for style transfer")
|
||||
result = await stability_service.control_style_transfer(
|
||||
init_image=control_image_bytes,
|
||||
style_image=style_image_bytes,
|
||||
prompt=request.prompt or "",
|
||||
negative_prompt=request.negative_prompt,
|
||||
style_strength=request.style_strength or 1.0,
|
||||
composition_fidelity=request.composition_fidelity or 0.9,
|
||||
change_strength=request.change_strength or 0.9,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported control operation: {operation}")
|
||||
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_bytes_to_metadata(image_bytes)
|
||||
metadata.update(
|
||||
{
|
||||
"operation": operation,
|
||||
"style_preset": request.style_preset,
|
||||
"provider": self.SUPPORTED_OPERATIONS[operation]["provider"],
|
||||
}
|
||||
)
|
||||
|
||||
response = {
|
||||
"success": True,
|
||||
"operation": operation,
|
||||
"provider": metadata["provider"],
|
||||
"image_base64": self._bytes_to_base64(image_bytes, request.output_format),
|
||||
"width": metadata["width"],
|
||||
"height": metadata["height"],
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
logger.info("[Control Studio] ✅ Operation '%s' completed", operation)
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_bytes(result: Any) -> bytes:
|
||||
"""Normalize Stability responses into raw image bytes."""
|
||||
if isinstance(result, bytes):
|
||||
return result
|
||||
|
||||
if isinstance(result, dict):
|
||||
artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
|
||||
for artifact in artifacts:
|
||||
if isinstance(artifact, dict):
|
||||
if artifact.get("base64"):
|
||||
return base64.b64decode(artifact["base64"])
|
||||
if artifact.get("b64_json"):
|
||||
return base64.b64decode(artifact["b64_json"])
|
||||
|
||||
raise RuntimeError("Unable to extract image bytes from provider response")
|
||||
|
||||
@@ -110,12 +110,12 @@ class EditStudioService:
|
||||
},
|
||||
"search_replace": {
|
||||
"label": "Search & Replace",
|
||||
"description": "Locate objects via search prompt and replace them.",
|
||||
"description": "Locate objects via search prompt and replace them. Optional mask for precise control.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": False,
|
||||
"mask": True, # Optional mask for precise region selection
|
||||
"negative_prompt": False,
|
||||
"search_prompt": True,
|
||||
"select_prompt": False,
|
||||
@@ -126,12 +126,12 @@ class EditStudioService:
|
||||
},
|
||||
"search_recolor": {
|
||||
"label": "Search & Recolor",
|
||||
"description": "Select elements via prompt and recolor them.",
|
||||
"description": "Select elements via prompt and recolor them. Optional mask for exact region selection.",
|
||||
"provider": "stability",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": False,
|
||||
"mask": True, # Optional mask for precise region selection
|
||||
"negative_prompt": False,
|
||||
"search_prompt": False,
|
||||
"select_prompt": True,
|
||||
@@ -158,12 +158,12 @@ class EditStudioService:
|
||||
},
|
||||
"general_edit": {
|
||||
"label": "Prompt-based Edit",
|
||||
"description": "Free-form editing powered by Hugging Face image-to-image models.",
|
||||
"description": "Free-form editing powered by Hugging Face image-to-image models. Optional mask for selective editing.",
|
||||
"provider": "huggingface",
|
||||
"async": False,
|
||||
"fields": {
|
||||
"prompt": True,
|
||||
"mask": False,
|
||||
"mask": True, # Optional mask for selective region editing
|
||||
"negative_prompt": True,
|
||||
"search_prompt": False,
|
||||
"select_prompt": False,
|
||||
@@ -346,6 +346,7 @@ class EditStudioService:
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
search_prompt=request.search_prompt,
|
||||
mask=mask_bytes, # Optional mask for precise region selection
|
||||
output_format=request.output_format,
|
||||
)
|
||||
elif operation == "search_recolor":
|
||||
@@ -355,6 +356,7 @@ class EditStudioService:
|
||||
image=image_bytes,
|
||||
prompt=request.prompt,
|
||||
select_prompt=request.select_prompt,
|
||||
mask=mask_bytes, # Optional mask for precise region selection
|
||||
output_format=request.output_format,
|
||||
)
|
||||
elif operation == "relight":
|
||||
@@ -403,6 +405,7 @@ class EditStudioService:
|
||||
request.prompt,
|
||||
options,
|
||||
user_id,
|
||||
mask_bytes, # Optional mask for selective editing
|
||||
)
|
||||
|
||||
return result.image_bytes
|
||||
|
||||
502
backend/services/image_studio/social_optimizer_service.py
Normal file
502
backend/services/image_studio/social_optimizer_service.py
Normal file
@@ -0,0 +1,502 @@
|
||||
"""Social Optimizer service for platform-specific image optimization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from .templates import Platform
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_studio.social_optimizer")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SafeZone:
|
||||
"""Safe zone configuration for text overlay."""
|
||||
top: float = 0.1 # Percentage from top
|
||||
bottom: float = 0.1 # Percentage from bottom
|
||||
left: float = 0.1 # Percentage from left
|
||||
right: float = 0.1 # Percentage from right
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformFormat:
|
||||
"""Platform format specification."""
|
||||
name: str
|
||||
width: int
|
||||
height: int
|
||||
ratio: str
|
||||
safe_zone: SafeZone
|
||||
file_type: str = "PNG"
|
||||
max_size_mb: float = 5.0
|
||||
|
||||
|
||||
# Platform format definitions with safe zones
|
||||
PLATFORM_FORMATS: Dict[Platform, List[PlatformFormat]] = {
|
||||
Platform.INSTAGRAM: [
|
||||
PlatformFormat(
|
||||
name="Feed Post (Square)",
|
||||
width=1080,
|
||||
height=1080,
|
||||
ratio="1:1",
|
||||
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Feed Post (Portrait)",
|
||||
width=1080,
|
||||
height=1350,
|
||||
ratio="4:5",
|
||||
safe_zone=SafeZone(top=0.2, bottom=0.2, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Story",
|
||||
width=1080,
|
||||
height=1920,
|
||||
ratio="9:16",
|
||||
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Reel",
|
||||
width=1080,
|
||||
height=1920,
|
||||
ratio="9:16",
|
||||
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
],
|
||||
Platform.FACEBOOK: [
|
||||
PlatformFormat(
|
||||
name="Feed Post",
|
||||
width=1200,
|
||||
height=630,
|
||||
ratio="1.91:1",
|
||||
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Feed Post (Square)",
|
||||
width=1080,
|
||||
height=1080,
|
||||
ratio="1:1",
|
||||
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Story",
|
||||
width=1080,
|
||||
height=1920,
|
||||
ratio="9:16",
|
||||
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Cover Photo",
|
||||
width=820,
|
||||
height=312,
|
||||
ratio="16:9",
|
||||
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
|
||||
),
|
||||
],
|
||||
Platform.TWITTER: [
|
||||
PlatformFormat(
|
||||
name="Post",
|
||||
width=1200,
|
||||
height=675,
|
||||
ratio="16:9",
|
||||
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Card",
|
||||
width=1200,
|
||||
height=600,
|
||||
ratio="2:1",
|
||||
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Header",
|
||||
width=1500,
|
||||
height=500,
|
||||
ratio="3:1",
|
||||
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
|
||||
),
|
||||
],
|
||||
Platform.LINKEDIN: [
|
||||
PlatformFormat(
|
||||
name="Feed Post",
|
||||
width=1200,
|
||||
height=628,
|
||||
ratio="1.91:1",
|
||||
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Feed Post (Square)",
|
||||
width=1080,
|
||||
height=1080,
|
||||
ratio="1:1",
|
||||
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Article",
|
||||
width=1200,
|
||||
height=627,
|
||||
ratio="2:1",
|
||||
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Company Cover",
|
||||
width=1128,
|
||||
height=191,
|
||||
ratio="4:1",
|
||||
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
|
||||
),
|
||||
],
|
||||
Platform.YOUTUBE: [
|
||||
PlatformFormat(
|
||||
name="Thumbnail",
|
||||
width=1280,
|
||||
height=720,
|
||||
ratio="16:9",
|
||||
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Channel Art",
|
||||
width=2560,
|
||||
height=1440,
|
||||
ratio="16:9",
|
||||
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
|
||||
),
|
||||
],
|
||||
Platform.PINTEREST: [
|
||||
PlatformFormat(
|
||||
name="Pin",
|
||||
width=1000,
|
||||
height=1500,
|
||||
ratio="2:3",
|
||||
safe_zone=SafeZone(top=0.2, bottom=0.2, left=0.1, right=0.1),
|
||||
),
|
||||
PlatformFormat(
|
||||
name="Story Pin",
|
||||
width=1080,
|
||||
height=1920,
|
||||
ratio="9:16",
|
||||
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
],
|
||||
Platform.TIKTOK: [
|
||||
PlatformFormat(
|
||||
name="Video Cover",
|
||||
width=1080,
|
||||
height=1920,
|
||||
ratio="9:16",
|
||||
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SocialOptimizerRequest:
|
||||
"""Request payload for social optimization."""
|
||||
|
||||
image_base64: str
|
||||
platforms: List[Platform] # List of platforms to optimize for
|
||||
format_names: Optional[Dict[Platform, str]] = None # Specific format per platform
|
||||
show_safe_zones: bool = False # Include safe zone overlay in output
|
||||
crop_mode: str = "smart" # "smart", "center", "fit"
|
||||
focal_point: Optional[Dict[str, float]] = None # {"x": 0.5, "y": 0.5} for smart crop
|
||||
output_format: str = "png"
|
||||
options: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class SocialOptimizerService:
|
||||
"""Service for optimizing images for social media platforms."""
|
||||
|
||||
def __init__(self):
|
||||
logger.info("[Social Optimizer] Initialized service")
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_image(value: str) -> bytes:
|
||||
"""Decode a base64 (or data URL) string to bytes."""
|
||||
try:
|
||||
if value.startswith("data:"):
|
||||
_, b64data = value.split(",", 1)
|
||||
else:
|
||||
b64data = value
|
||||
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Social Optimizer] Failed to decode base64 image: {exc}")
|
||||
raise ValueError("Invalid base64 image payload") from exc
|
||||
|
||||
@staticmethod
|
||||
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
|
||||
"""Convert raw bytes to base64 data URL."""
|
||||
b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return f"data:image/{output_format};base64,{b64}"
|
||||
|
||||
@staticmethod
|
||||
def _smart_crop(
|
||||
image: Image.Image,
|
||||
target_width: int,
|
||||
target_height: int,
|
||||
focal_point: Optional[Dict[str, float]] = None,
|
||||
) -> Image.Image:
|
||||
"""Smart crop image to target dimensions, preserving important content."""
|
||||
img_width, img_height = image.size
|
||||
target_ratio = target_width / target_height
|
||||
img_ratio = img_width / img_height
|
||||
|
||||
# If focal point is provided, use it for cropping
|
||||
if focal_point:
|
||||
focal_x = int(focal_point["x"] * img_width)
|
||||
focal_y = int(focal_point["y"] * img_height)
|
||||
else:
|
||||
# Default to center
|
||||
focal_x = img_width // 2
|
||||
focal_y = img_height // 2
|
||||
|
||||
if img_ratio > target_ratio:
|
||||
# Image is wider than target - crop width
|
||||
new_width = int(img_height * target_ratio)
|
||||
left = max(0, min(focal_x - new_width // 2, img_width - new_width))
|
||||
right = left + new_width
|
||||
cropped = image.crop((left, 0, right, img_height))
|
||||
else:
|
||||
# Image is taller than target - crop height
|
||||
new_height = int(img_width / target_ratio)
|
||||
top = max(0, min(focal_y - new_height // 2, img_height - new_height))
|
||||
bottom = top + new_height
|
||||
cropped = image.crop((0, top, img_width, bottom))
|
||||
|
||||
# Resize to exact target dimensions
|
||||
return cropped.resize((target_width, target_height), Image.Resampling.LANCZOS)
|
||||
|
||||
@staticmethod
|
||||
def _fit_image(
|
||||
image: Image.Image,
|
||||
target_width: int,
|
||||
target_height: int,
|
||||
) -> Image.Image:
|
||||
"""Fit image to target dimensions while maintaining aspect ratio (adds padding if needed)."""
|
||||
img_width, img_height = image.size
|
||||
target_ratio = target_width / target_height
|
||||
img_ratio = img_width / img_height
|
||||
|
||||
if img_ratio > target_ratio:
|
||||
# Image is wider - fit to height, pad width
|
||||
new_height = target_height
|
||||
new_width = int(img_width * (target_height / img_height))
|
||||
resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
# Create new image with target size and paste centered
|
||||
result = Image.new("RGB", (target_width, target_height), (255, 255, 255))
|
||||
paste_x = (target_width - new_width) // 2
|
||||
result.paste(resized, (paste_x, 0))
|
||||
return result
|
||||
else:
|
||||
# Image is taller - fit to width, pad height
|
||||
new_width = target_width
|
||||
new_height = int(img_height * (target_width / img_width))
|
||||
resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
# Create new image with target size and paste centered
|
||||
result = Image.new("RGB", (target_width, target_height), (255, 255, 255))
|
||||
paste_y = (target_height - new_height) // 2
|
||||
result.paste(resized, (0, paste_y))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _center_crop(
|
||||
image: Image.Image,
|
||||
target_width: int,
|
||||
target_height: int,
|
||||
) -> Image.Image:
|
||||
"""Center crop image to target dimensions."""
|
||||
img_width, img_height = image.size
|
||||
target_ratio = target_width / target_height
|
||||
img_ratio = img_width / img_height
|
||||
|
||||
if img_ratio > target_ratio:
|
||||
# Image is wider - crop width
|
||||
new_width = int(img_height * target_ratio)
|
||||
left = (img_width - new_width) // 2
|
||||
cropped = image.crop((left, 0, left + new_width, img_height))
|
||||
else:
|
||||
# Image is taller - crop height
|
||||
new_height = int(img_width / target_ratio)
|
||||
top = (img_height - new_height) // 2
|
||||
cropped = image.crop((0, top, img_width, top + new_height))
|
||||
|
||||
return cropped.resize((target_width, target_height), Image.Resampling.LANCZOS)
|
||||
|
||||
@staticmethod
|
||||
def _draw_safe_zone(
|
||||
image: Image.Image,
|
||||
safe_zone: SafeZone,
|
||||
) -> Image.Image:
|
||||
"""Draw safe zone overlay on image."""
|
||||
draw = ImageDraw.Draw(image)
|
||||
width, height = image.size
|
||||
|
||||
# Calculate safe zone boundaries
|
||||
top = int(height * safe_zone.top)
|
||||
bottom = int(height * (1 - safe_zone.bottom))
|
||||
left = int(width * safe_zone.left)
|
||||
right = int(width * (1 - safe_zone.right))
|
||||
|
||||
# Draw semi-transparent overlay outside safe zone
|
||||
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
|
||||
overlay_draw = ImageDraw.Draw(overlay)
|
||||
|
||||
# Top area
|
||||
overlay_draw.rectangle([(0, 0), (width, top)], fill=(0, 0, 0, 100))
|
||||
# Bottom area
|
||||
overlay_draw.rectangle([(0, bottom), (width, height)], fill=(0, 0, 0, 100))
|
||||
# Left area
|
||||
overlay_draw.rectangle([(0, top), (left, bottom)], fill=(0, 0, 0, 100))
|
||||
# Right area
|
||||
overlay_draw.rectangle([(right, top), (width, bottom)], fill=(0, 0, 0, 100))
|
||||
|
||||
# Draw safe zone border
|
||||
border_color = (255, 255, 0, 200) # Yellow with transparency
|
||||
overlay_draw.rectangle(
|
||||
[(left, top), (right, bottom)],
|
||||
outline=border_color,
|
||||
width=2,
|
||||
)
|
||||
|
||||
# Composite overlay onto image
|
||||
if image.mode != "RGBA":
|
||||
image = image.convert("RGBA")
|
||||
image = Image.alpha_composite(image, overlay)
|
||||
|
||||
return image
|
||||
|
||||
def get_platform_formats(self, platform: Platform) -> List[Dict[str, Any]]:
|
||||
"""Get available formats for a platform."""
|
||||
formats = PLATFORM_FORMATS.get(platform, [])
|
||||
return [
|
||||
{
|
||||
"name": fmt.name,
|
||||
"width": fmt.width,
|
||||
"height": fmt.height,
|
||||
"ratio": fmt.ratio,
|
||||
"safe_zone": {
|
||||
"top": fmt.safe_zone.top,
|
||||
"bottom": fmt.safe_zone.bottom,
|
||||
"left": fmt.safe_zone.left,
|
||||
"right": fmt.safe_zone.right,
|
||||
},
|
||||
"file_type": fmt.file_type,
|
||||
"max_size_mb": fmt.max_size_mb,
|
||||
}
|
||||
for fmt in formats
|
||||
]
|
||||
|
||||
def optimize_image(
|
||||
self,
|
||||
request: SocialOptimizerRequest,
|
||||
) -> Dict[str, Any]:
|
||||
"""Optimize image for specified platforms."""
|
||||
logger.info(
|
||||
f"[Social Optimizer] Processing optimization for {len(request.platforms)} platform(s)"
|
||||
)
|
||||
|
||||
# Decode input image
|
||||
image_bytes = self._decode_base64_image(request.image_base64)
|
||||
original_image = Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
# Convert to RGB if needed
|
||||
if original_image.mode in ("RGBA", "LA", "P"):
|
||||
if original_image.mode == "P":
|
||||
original_image = original_image.convert("RGBA")
|
||||
background = Image.new("RGB", original_image.size, (255, 255, 255))
|
||||
if original_image.mode == "RGBA":
|
||||
background.paste(original_image, mask=original_image.split()[-1])
|
||||
else:
|
||||
background.paste(original_image)
|
||||
original_image = background
|
||||
elif original_image.mode != "RGB":
|
||||
original_image = original_image.convert("RGB")
|
||||
|
||||
results = []
|
||||
|
||||
for platform in request.platforms:
|
||||
formats = PLATFORM_FORMATS.get(platform, [])
|
||||
if not formats:
|
||||
logger.warning(f"[Social Optimizer] No formats found for platform: {platform}")
|
||||
continue
|
||||
|
||||
# Get format (use specified format or default to first)
|
||||
format_name = None
|
||||
if request.format_names and platform in request.format_names:
|
||||
format_name = request.format_names[platform]
|
||||
|
||||
platform_format = None
|
||||
for fmt in formats:
|
||||
if format_name and fmt.name == format_name:
|
||||
platform_format = fmt
|
||||
break
|
||||
if not platform_format:
|
||||
platform_format = formats[0] # Default to first format
|
||||
|
||||
# Crop/resize image based on mode
|
||||
if request.crop_mode == "smart":
|
||||
optimized_image = self._smart_crop(
|
||||
original_image,
|
||||
platform_format.width,
|
||||
platform_format.height,
|
||||
request.focal_point,
|
||||
)
|
||||
elif request.crop_mode == "fit":
|
||||
optimized_image = self._fit_image(
|
||||
original_image,
|
||||
platform_format.width,
|
||||
platform_format.height,
|
||||
)
|
||||
else: # center
|
||||
optimized_image = self._center_crop(
|
||||
original_image,
|
||||
platform_format.width,
|
||||
platform_format.height,
|
||||
)
|
||||
|
||||
# Add safe zone overlay if requested
|
||||
if request.show_safe_zones:
|
||||
optimized_image = self._draw_safe_zone(optimized_image, platform_format.safe_zone)
|
||||
|
||||
# Convert to bytes
|
||||
output_buffer = io.BytesIO()
|
||||
output_format = request.output_format.lower()
|
||||
if output_format == "jpg" or output_format == "jpeg":
|
||||
optimized_image = optimized_image.convert("RGB")
|
||||
optimized_image.save(output_buffer, format="JPEG", quality=95)
|
||||
else:
|
||||
optimized_image.save(output_buffer, format="PNG")
|
||||
output_bytes = output_buffer.getvalue()
|
||||
|
||||
results.append(
|
||||
{
|
||||
"platform": platform.value,
|
||||
"format": platform_format.name,
|
||||
"width": platform_format.width,
|
||||
"height": platform_format.height,
|
||||
"ratio": platform_format.ratio,
|
||||
"image_base64": self._bytes_to_base64(output_bytes, request.output_format),
|
||||
"safe_zone": {
|
||||
"top": platform_format.safe_zone.top,
|
||||
"bottom": platform_format.safe_zone.bottom,
|
||||
"left": platform_format.safe_zone.left,
|
||||
"right": platform_format.safe_zone.right,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Social Optimizer] ✅ Generated {len(results)} optimized images")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"results": results,
|
||||
"total_optimized": len(results),
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ from typing import Optional, Dict, Any, List
|
||||
from .create_service import CreateStudioService, CreateStudioRequest
|
||||
from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .templates import Platform, TemplateCategory, ImageTemplate
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
@@ -20,6 +22,8 @@ class ImageStudioManager:
|
||||
self.create_service = CreateStudioService()
|
||||
self.edit_service = EditStudioService()
|
||||
self.upscale_service = UpscaleStudioService()
|
||||
self.control_service = ControlStudioService()
|
||||
self.social_optimizer_service = SocialOptimizerService()
|
||||
logger.info("[Image Studio Manager] Initialized successfully")
|
||||
|
||||
# ====================
|
||||
@@ -215,6 +219,40 @@ class ImageStudioManager:
|
||||
"estimated": True,
|
||||
}
|
||||
|
||||
# ====================
|
||||
# CONTROL STUDIO
|
||||
# ====================
|
||||
|
||||
async def control_image(
|
||||
self,
|
||||
request: ControlStudioRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run Control Studio operations."""
|
||||
logger.info("[Image Studio] Control request from user: %s", user_id)
|
||||
return await self.control_service.process_control(request, user_id=user_id)
|
||||
|
||||
def get_control_operations(self) -> Dict[str, Any]:
|
||||
"""Expose control operations for UI."""
|
||||
return self.control_service.list_operations()
|
||||
|
||||
# ====================
|
||||
# SOCIAL OPTIMIZER
|
||||
# ====================
|
||||
|
||||
async def optimize_for_social(
|
||||
self,
|
||||
request: SocialOptimizerRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Optimize image for social media platforms."""
|
||||
logger.info("[Image Studio] Social optimization request from user: %s", user_id)
|
||||
return self.social_optimizer_service.optimize_image(request)
|
||||
|
||||
def get_social_platform_formats(self, platform: Platform) -> List[Dict[str, Any]]:
|
||||
"""Get available formats for a social platform."""
|
||||
return self.social_optimizer_service.get_platform_formats(platform)
|
||||
|
||||
# ====================
|
||||
# PLATFORM SPECS
|
||||
# ====================
|
||||
|
||||
@@ -54,7 +54,8 @@ def edit_image(
|
||||
input_image_bytes: bytes,
|
||||
prompt: str,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
user_id: Optional[str] = None,
|
||||
mask_bytes: Optional[bytes] = None,
|
||||
) -> ImageGenerationResult:
|
||||
"""Edit image with pre-flight validation.
|
||||
|
||||
@@ -63,6 +64,7 @@ def edit_image(
|
||||
prompt: Natural language prompt describing desired edits (e.g., "Turn the cat into a tiger")
|
||||
options: Image editing options (provider, model, etc.)
|
||||
user_id: User ID for subscription checking (optional, but required for validation)
|
||||
mask_bytes: Optional mask image bytes for selective editing (grayscale, white=edit, black=preserve)
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image bytes and metadata
|
||||
@@ -72,6 +74,8 @@ def edit_image(
|
||||
- Describe what should change and what should remain
|
||||
- Examples: "Turn the cat into a tiger", "Change background to forest",
|
||||
"Make it look like a watercolor painting"
|
||||
|
||||
Note: Mask support depends on the specific model. Some models may ignore the mask parameter.
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image editing before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
@@ -128,14 +132,33 @@ def edit_image(
|
||||
width = input_image.width
|
||||
height = input_image.height
|
||||
|
||||
# Convert mask bytes to PIL Image if provided
|
||||
mask_image = None
|
||||
if mask_bytes:
|
||||
try:
|
||||
mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L") # Convert to grayscale
|
||||
# Ensure mask dimensions match input image
|
||||
if mask_image.size != input_image.size:
|
||||
logger.warning(f"[Image Editing] Mask size {mask_image.size} doesn't match image size {input_image.size}, resizing mask")
|
||||
mask_image = mask_image.resize(input_image.size, Image.Resampling.LANCZOS)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Image Editing] Failed to process mask image: {e}, continuing without mask")
|
||||
mask_image = None
|
||||
|
||||
# Use image_to_image method from Hugging Face InferenceClient
|
||||
# This follows the pattern from the Hugging Face documentation
|
||||
# Docs: https://huggingface.co/docs/inference-providers/en/guides/image-editor
|
||||
# Note: Mask support depends on the model - some models may ignore it
|
||||
call_params = params.copy()
|
||||
if mask_image:
|
||||
call_params["mask_image"] = mask_image
|
||||
logger.info("[Image Editing] Using mask for selective editing")
|
||||
|
||||
edited_image: Image.Image = client.image_to_image(
|
||||
image=input_image,
|
||||
prompt=prompt.strip(),
|
||||
model=model,
|
||||
**params,
|
||||
**call_params,
|
||||
)
|
||||
|
||||
# Convert edited image back to bytes
|
||||
|
||||
@@ -397,6 +397,7 @@ class StabilityAIService:
|
||||
image: Union[UploadFile, bytes],
|
||||
prompt: str,
|
||||
search_prompt: str,
|
||||
mask: Optional[Union[UploadFile, bytes]] = None,
|
||||
**kwargs
|
||||
) -> Union[bytes, Dict[str, Any]]:
|
||||
"""Replace objects in image using search prompt.
|
||||
@@ -405,6 +406,7 @@ class StabilityAIService:
|
||||
image: Input image
|
||||
prompt: Text prompt for replacement
|
||||
search_prompt: What to search for
|
||||
mask: Optional mask image for precise region selection
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
@@ -414,6 +416,8 @@ class StabilityAIService:
|
||||
data.update({k: v for k, v in kwargs.items() if v is not None})
|
||||
|
||||
files = {"image": await self._prepare_image_file(image)}
|
||||
if mask:
|
||||
files["mask"] = await self._prepare_image_file(mask)
|
||||
|
||||
return await self._make_request(
|
||||
method="POST",
|
||||
@@ -427,6 +431,7 @@ class StabilityAIService:
|
||||
image: Union[UploadFile, bytes],
|
||||
prompt: str,
|
||||
select_prompt: str,
|
||||
mask: Optional[Union[UploadFile, bytes]] = None,
|
||||
**kwargs
|
||||
) -> Union[bytes, Dict[str, Any]]:
|
||||
"""Recolor objects in image using select prompt.
|
||||
@@ -435,6 +440,7 @@ class StabilityAIService:
|
||||
image: Input image
|
||||
prompt: Text prompt for recoloring
|
||||
select_prompt: What to select for recoloring
|
||||
mask: Optional mask image for precise region selection
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
@@ -444,6 +450,8 @@ class StabilityAIService:
|
||||
data.update({k: v for k, v in kwargs.items() if v is not None})
|
||||
|
||||
files = {"image": await self._prepare_image_file(image)}
|
||||
if mask:
|
||||
files["mask"] = await self._prepare_image_file(mask)
|
||||
|
||||
return await self._make_request(
|
||||
method="POST",
|
||||
|
||||
@@ -415,6 +415,75 @@ def validate_image_editing_operations(
|
||||
)
|
||||
|
||||
|
||||
def validate_image_control_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str,
|
||||
num_images: int = 1
|
||||
) -> None:
|
||||
"""
|
||||
Validate image control operations (sketch-to-image, structure control, style transfer) before making API calls.
|
||||
|
||||
Control operations use Stability AI for image generation with control inputs, so they use
|
||||
the same validation as image generation operations.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
num_images: Number of images to generate (for multiple variations)
|
||||
|
||||
Returns:
|
||||
None - raises HTTPException with 429 status if validation fails
|
||||
"""
|
||||
try:
|
||||
# Control operations use Stability AI, same as image generation
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.STABILITY,
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'stability',
|
||||
'operation_type': 'image_generation' # Control ops use image generation limits
|
||||
}
|
||||
for _ in range(num_images)
|
||||
]
|
||||
|
||||
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image control operation(s) for user {user_id}")
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.error(f"[Pre-flight Validator] Image control blocked for user {user_id}: {message}")
|
||||
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
|
||||
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ Image control validated for user {user_id}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating image control: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
'error': f"Failed to validate image control: {str(e)}",
|
||||
'message': f"Failed to validate image control: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_video_generation_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str
|
||||
|
||||
158
backend/utils/asset_tracker.py
Normal file
158
backend/utils/asset_tracker.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Asset Tracker Utility
|
||||
Helper utility for modules to easily save generated content to the unified asset library.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from services.content_asset_service import ContentAssetService
|
||||
from models.content_asset_models import AssetType, AssetSource
|
||||
import logging
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum file size (100MB)
|
||||
MAX_FILE_SIZE = 100 * 1024 * 1024
|
||||
|
||||
# Allowed URL schemes
|
||||
ALLOWED_URL_SCHEMES = ['http', 'https', '/'] # Allow relative paths starting with /
|
||||
|
||||
|
||||
def validate_file_url(file_url: str) -> bool:
|
||||
"""Validate file URL format."""
|
||||
if not file_url or not isinstance(file_url, str):
|
||||
return False
|
||||
|
||||
# Allow relative paths
|
||||
if file_url.startswith('/'):
|
||||
return True
|
||||
|
||||
# Validate absolute URLs
|
||||
try:
|
||||
parsed = urlparse(file_url)
|
||||
return parsed.scheme in ALLOWED_URL_SCHEMES and parsed.netloc
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def save_asset_to_library(
|
||||
db: Session,
|
||||
user_id: str,
|
||||
asset_type: str,
|
||||
source_module: str,
|
||||
filename: str,
|
||||
file_url: str,
|
||||
file_path: Optional[str] = None,
|
||||
file_size: Optional[int] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
tags: Optional[list] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
cost: Optional[float] = None,
|
||||
generation_time: Optional[float] = None,
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
Helper function to save a generated asset to the unified asset library.
|
||||
|
||||
This can be called from any module (story writer, image studio, etc.)
|
||||
to automatically track generated content.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: Clerk user ID
|
||||
asset_type: 'text', 'image', 'video', or 'audio'
|
||||
source_module: 'story_writer', 'image_studio', 'main_text_generation', etc.
|
||||
filename: Original filename
|
||||
file_url: Public URL to access the asset
|
||||
file_path: Server file path (optional)
|
||||
file_size: File size in bytes (optional)
|
||||
mime_type: MIME type (optional)
|
||||
title: Asset title (optional)
|
||||
description: Asset description (optional)
|
||||
prompt: Generation prompt (optional)
|
||||
tags: List of tags (optional)
|
||||
metadata: Additional metadata (optional)
|
||||
provider: AI provider used (optional)
|
||||
model: Model used (optional)
|
||||
cost: Generation cost (optional)
|
||||
generation_time: Generation time in seconds (optional)
|
||||
|
||||
Returns:
|
||||
Asset ID if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Validate inputs
|
||||
if not user_id or not isinstance(user_id, str):
|
||||
logger.error("Invalid user_id provided")
|
||||
return None
|
||||
|
||||
if not filename or not isinstance(filename, str):
|
||||
logger.error("Invalid filename provided")
|
||||
return None
|
||||
|
||||
if not validate_file_url(file_url):
|
||||
logger.error(f"Invalid file_url format: {file_url}")
|
||||
return None
|
||||
|
||||
if file_size and file_size > MAX_FILE_SIZE:
|
||||
logger.warning(f"File size {file_size} exceeds maximum {MAX_FILE_SIZE}")
|
||||
# Don't fail, just log warning
|
||||
|
||||
# Convert string enums to enum types
|
||||
try:
|
||||
asset_type_enum = AssetType(asset_type.lower())
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid asset type: {asset_type}, defaulting to 'text'")
|
||||
asset_type_enum = AssetType.TEXT
|
||||
|
||||
try:
|
||||
source_module_enum = AssetSource(source_module.lower())
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid source module: {source_module}, defaulting to 'story_writer'")
|
||||
source_module_enum = AssetSource.STORY_WRITER
|
||||
|
||||
# Sanitize filename (remove path traversal attempts)
|
||||
filename = re.sub(r'[^\w\s\-_\.]', '', filename.split('/')[-1])
|
||||
if not filename:
|
||||
filename = f"asset_{asset_type}_{source_module}.{asset_type}"
|
||||
|
||||
# Generate title from filename if not provided
|
||||
if not title:
|
||||
title = filename.replace('_', ' ').replace('-', ' ').title()
|
||||
# Limit title length
|
||||
if len(title) > 200:
|
||||
title = title[:197] + '...'
|
||||
|
||||
service = ContentAssetService(db)
|
||||
asset = service.create_asset(
|
||||
user_id=user_id,
|
||||
asset_type=asset_type_enum,
|
||||
source_module=source_module_enum,
|
||||
filename=filename,
|
||||
file_url=file_url,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
mime_type=mime_type,
|
||||
title=title,
|
||||
description=description,
|
||||
prompt=prompt,
|
||||
tags=tags or [],
|
||||
metadata=metadata or {},
|
||||
provider=provider,
|
||||
model=model,
|
||||
cost=cost,
|
||||
generation_time=generation_time,
|
||||
)
|
||||
|
||||
logger.info(f"✅ Asset saved to library: {asset.id} ({asset_type} from {source_module})")
|
||||
return asset.id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error saving asset to library: {str(e)}", exc_info=True)
|
||||
return None
|
||||
Reference in New Issue
Block a user