AI Image Studio Progress Review

- Added new router for content assets
- Added new service for content assets
- Added new model for content assets
- Added new utils for content assets
- Added new docs for content assets
- Added new tests for content assets
- Added new examples for content assets
- Added new guides for content assets
This commit is contained in:
ajaysi
2025-11-23 09:21:11 +05:30
parent eede21ad42
commit 77d7c0cde6
38 changed files with 5939 additions and 37 deletions

View File

@@ -0,0 +1,322 @@
"""
Content Asset Service
Service for managing and tracking all AI-generated content assets.
"""
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func, desc
from datetime import datetime
from models.content_asset_models import (
ContentAsset,
AssetCollection,
AssetType,
AssetSource
)
import logging
logger = logging.getLogger(__name__)
class ContentAssetService:
"""Service for managing content assets across all modules."""
def __init__(self, db: Session):
self.db = db
def create_asset(
self,
user_id: str,
asset_type: AssetType,
source_module: AssetSource,
filename: str,
file_url: str,
file_path: Optional[str] = None,
file_size: Optional[int] = None,
mime_type: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
prompt: Optional[str] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
provider: Optional[str] = None,
model: Optional[str] = None,
cost: Optional[float] = None,
generation_time: Optional[float] = None,
) -> ContentAsset:
"""
Create a new content asset record.
Args:
user_id: Clerk user ID
asset_type: Type of asset (text, image, video, audio)
source_module: Source module that generated it
filename: Original filename
file_url: Public URL to access the asset
file_path: Server file path (optional)
file_size: File size in bytes (optional)
mime_type: MIME type (optional)
title: Asset title (optional)
description: Asset description (optional)
prompt: Generation prompt (optional)
tags: List of tags (optional)
metadata: Additional metadata (optional)
provider: AI provider used (optional)
model: Model used (optional)
cost: Generation cost (optional)
generation_time: Generation time in seconds (optional)
Returns:
Created ContentAsset instance
"""
try:
asset = ContentAsset(
user_id=user_id,
asset_type=asset_type,
source_module=source_module,
filename=filename,
file_url=file_url,
file_path=file_path,
file_size=file_size,
mime_type=mime_type,
title=title,
description=description,
prompt=prompt,
tags=tags or [],
metadata=metadata or {},
provider=provider,
model=model,
cost=cost or 0.0,
generation_time=generation_time,
)
self.db.add(asset)
self.db.commit()
self.db.refresh(asset)
logger.info(f"Created asset {asset.id} for user {user_id} from {source_module.value}")
return asset
except Exception as e:
self.db.rollback()
logger.error(f"Error creating asset: {str(e)}", exc_info=True)
raise
def get_user_assets(
self,
user_id: str,
asset_type: Optional[AssetType] = None,
source_module: Optional[AssetSource] = None,
search_query: Optional[str] = None,
tags: Optional[List[str]] = None,
favorites_only: bool = False,
limit: int = 100,
offset: int = 0,
) -> Tuple[List[ContentAsset], int]:
"""
Get assets for a user with optional filtering.
Args:
user_id: Clerk user ID
asset_type: Filter by asset type (optional)
source_module: Filter by source module (optional)
search_query: Search in title, description, prompt (optional)
tags: Filter by tags (optional)
favorites_only: Only return favorites (optional)
limit: Maximum number of results
offset: Offset for pagination
Returns:
List of ContentAsset instances
"""
try:
query = self.db.query(ContentAsset).filter(
ContentAsset.user_id == user_id
)
if asset_type:
query = query.filter(ContentAsset.asset_type == asset_type)
if source_module:
query = query.filter(ContentAsset.source_module == source_module)
if favorites_only:
query = query.filter(ContentAsset.is_favorite == True)
if search_query:
search_filter = or_(
ContentAsset.title.ilike(f"%{search_query}%"),
ContentAsset.description.ilike(f"%{search_query}%"),
ContentAsset.prompt.ilike(f"%{search_query}%"),
ContentAsset.filename.ilike(f"%{search_query}%"),
)
query = query.filter(search_filter)
if tags:
# Filter by tags (JSON array contains any of the tags)
tag_filters = [ContentAsset.tags.contains([tag]) for tag in tags]
query = query.filter(or_(*tag_filters))
# Get total count before pagination
total_count = query.count()
# Apply ordering and pagination
query = query.order_by(desc(ContentAsset.created_at))
query = query.limit(limit).offset(offset)
return query.all(), total_count
except Exception as e:
logger.error(f"Error fetching assets: {str(e)}", exc_info=True)
raise
def get_asset_by_id(self, asset_id: int, user_id: str) -> Optional[ContentAsset]:
"""Get a specific asset by ID."""
try:
return self.db.query(ContentAsset).filter(
and_(
ContentAsset.id == asset_id,
ContentAsset.user_id == user_id
)
).first()
except Exception as e:
logger.error(f"Error fetching asset {asset_id}: {str(e)}", exc_info=True)
return None
def toggle_favorite(self, asset_id: int, user_id: str) -> bool:
"""Toggle favorite status of an asset."""
try:
asset = self.get_asset_by_id(asset_id, user_id)
if not asset:
return False
asset.is_favorite = not asset.is_favorite
self.db.commit()
return asset.is_favorite
except Exception as e:
self.db.rollback()
logger.error(f"Error toggling favorite: {str(e)}", exc_info=True)
return False
def delete_asset(self, asset_id: int, user_id: str) -> bool:
"""Delete an asset."""
try:
asset = self.get_asset_by_id(asset_id, user_id)
if not asset:
return False
self.db.delete(asset)
self.db.commit()
return True
except Exception as e:
self.db.rollback()
logger.error(f"Error deleting asset: {str(e)}", exc_info=True)
return False
def update_asset(
self,
asset_id: int,
user_id: str,
title: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Optional[ContentAsset]:
"""Update asset metadata."""
try:
asset = self.get_asset_by_id(asset_id, user_id)
if not asset:
return None
if title is not None:
asset.title = title
if description is not None:
asset.description = description
if tags is not None:
asset.tags = tags
if metadata is not None:
asset.metadata = {**(asset.metadata or {}), **metadata}
asset.updated_at = datetime.utcnow()
self.db.commit()
self.db.refresh(asset)
logger.info(f"Updated asset {asset_id} for user {user_id}")
return asset
except Exception as e:
self.db.rollback()
logger.error(f"Error updating asset: {str(e)}", exc_info=True)
return None
def update_asset_usage(self, asset_id: int, user_id: str, action: str = "access"):
"""Update asset usage statistics."""
try:
asset = self.get_asset_by_id(asset_id, user_id)
if not asset:
return
if action == "download":
asset.download_count += 1
elif action == "share":
asset.share_count += 1
asset.last_accessed = datetime.utcnow()
self.db.commit()
except Exception as e:
self.db.rollback()
logger.error(f"Error updating asset usage: {str(e)}", exc_info=True)
def get_asset_statistics(self, user_id: str) -> Dict[str, Any]:
"""Get statistics about user's assets."""
try:
total = self.db.query(func.count(ContentAsset.id)).filter(
ContentAsset.user_id == user_id
).scalar() or 0
by_type = self.db.query(
ContentAsset.asset_type,
func.count(ContentAsset.id)
).filter(
ContentAsset.user_id == user_id
).group_by(ContentAsset.asset_type).all()
by_source = self.db.query(
ContentAsset.source_module,
func.count(ContentAsset.id)
).filter(
ContentAsset.user_id == user_id
).group_by(ContentAsset.source_module).all()
total_cost = self.db.query(func.sum(ContentAsset.cost)).filter(
ContentAsset.user_id == user_id
).scalar() or 0.0
favorites_count = self.db.query(func.count(ContentAsset.id)).filter(
and_(
ContentAsset.user_id == user_id,
ContentAsset.is_favorite == True
)
).scalar() or 0
return {
"total": total,
"by_type": {str(t): c for t, c in by_type},
"by_source": {str(s): c for s, c in by_source},
"total_cost": float(total_cost),
"favorites_count": favorites_count,
}
except Exception as e:
logger.error(f"Error getting asset statistics: {str(e)}", exc_info=True)
return {
"total": 0,
"by_type": {},
"by_source": {},
"total_cost": 0.0,
"favorites_count": 0,
}

View File

@@ -20,6 +20,7 @@ from models.monitoring_models import Base as MonitoringBase
from models.persona_models import Base as PersonaBase
from models.subscription_models import Base as SubscriptionBase
from models.user_business_info import Base as UserBusinessInfoBase
from models.content_asset_models import Base as ContentAssetBase
# Database configuration
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./alwrity.db')
@@ -74,7 +75,8 @@ def init_database():
PersonaBase.metadata.create_all(bind=engine)
SubscriptionBase.metadata.create_all(bind=engine)
UserBusinessInfoBase.metadata.create_all(bind=engine)
logger.info("Database initialized successfully with all models including subscription system and business info")
ContentAssetBase.metadata.create_all(bind=engine)
logger.info("Database initialized successfully with all models including subscription system, business info, and content assets")
except SQLAlchemyError as e:
logger.error(f"Error initializing database: {str(e)}")
raise

View File

@@ -4,6 +4,8 @@ from .studio_manager import ImageStudioManager
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .templates import PlatformTemplates, TemplateManager
__all__ = [
@@ -14,6 +16,10 @@ __all__ = [
"EditStudioRequest",
"UpscaleStudioService",
"UpscaleStudioRequest",
"ControlStudioService",
"ControlStudioRequest",
"SocialOptimizerService",
"SocialOptimizerRequest",
"PlatformTemplates",
"TemplateManager",
]

View File

@@ -0,0 +1,277 @@
"""Control Studio service for AI-powered controlled image generation."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional
from PIL import Image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.control")
ControlOperationType = Literal[
"sketch",
"structure",
"style",
"style_transfer",
]
@dataclass
class ControlStudioRequest:
"""Normalized request payload for Control Studio operations."""
operation: ControlOperationType
prompt: str
control_image_base64: str # Sketch, structure, or style reference
style_image_base64: Optional[str] = None # For style_transfer only
negative_prompt: Optional[str] = None
control_strength: Optional[float] = None # For sketch/structure
fidelity: Optional[float] = None # For style
style_strength: Optional[float] = None # For style_transfer
composition_fidelity: Optional[float] = None # For style_transfer
change_strength: Optional[float] = None # For style_transfer
aspect_ratio: Optional[str] = None # For style
style_preset: Optional[str] = None
seed: Optional[int] = None
output_format: str = "png"
class ControlStudioService:
"""Service layer orchestrating Control Studio operations."""
SUPPORTED_OPERATIONS: Dict[ControlOperationType, Dict[str, Any]] = {
"sketch": {
"label": "Sketch to Image",
"description": "Transform sketches into refined images with precise control.",
"provider": "stability",
"fields": {
"control_image": True,
"style_image": False,
"control_strength": True,
"fidelity": False,
"style_strength": False,
"aspect_ratio": False,
},
},
"structure": {
"label": "Structure Control",
"description": "Generate images maintaining the structure of an input image.",
"provider": "stability",
"fields": {
"control_image": True,
"style_image": False,
"control_strength": True,
"fidelity": False,
"style_strength": False,
"aspect_ratio": False,
},
},
"style": {
"label": "Style Control",
"description": "Generate images using style from a reference image.",
"provider": "stability",
"fields": {
"control_image": True,
"style_image": False,
"control_strength": False,
"fidelity": True,
"style_strength": False,
"aspect_ratio": True,
},
},
"style_transfer": {
"label": "Style Transfer",
"description": "Apply visual characteristics from a style image to a target image.",
"provider": "stability",
"fields": {
"control_image": True, # init_image
"style_image": True,
"control_strength": False,
"fidelity": False,
"style_strength": True,
"aspect_ratio": False,
},
},
}
def __init__(self):
logger.info("[Control Studio] Initialized control service")
@staticmethod
def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
"""Decode a base64 (or data URL) string to bytes."""
if not value:
return None
try:
# Handle data URLs (data:image/png;base64,...)
if value.startswith("data:"):
_, b64data = value.split(",", 1)
else:
b64data = value
return base64.b64decode(b64data)
except Exception as exc:
logger.error(f"[Control Studio] Failed to decode base64 image: {exc}")
raise ValueError("Invalid base64 image payload") from exc
@staticmethod
def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
"""Extract width/height metadata from image bytes."""
with Image.open(io.BytesIO(image_bytes)) as img:
return {
"width": img.width,
"height": img.height,
}
@staticmethod
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
"""Convert raw bytes to base64 data URL."""
b64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:image/{output_format};base64,{b64}"
def list_operations(self) -> Dict[str, Dict[str, Any]]:
"""Expose supported operations for UI rendering."""
return self.SUPPORTED_OPERATIONS
async def process_control(
self,
request: ControlStudioRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Process control request and return normalized response."""
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_control_operations
from fastapi import HTTPException
db = next(get_db())
try:
pricing_service = PricingService(db)
logger.info(f"[Control Studio] 🛂 Running pre-flight validation for user {user_id}")
validate_image_control_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=1,
)
logger.info("[Control Studio] ✅ Pre-flight validation passed")
except HTTPException:
logger.error("[Control Studio] ❌ Pre-flight validation failed")
raise
finally:
db.close()
else:
logger.warning("[Control Studio] ⚠️ No user_id provided - skipping pre-flight validation")
control_image_bytes = self._decode_base64_image(request.control_image_base64)
if not control_image_bytes:
raise ValueError("Control image payload is required")
style_image_bytes = self._decode_base64_image(request.style_image_base64)
operation = request.operation
logger.info("[Control Studio] Processing operation='%s' for user=%s", operation, user_id)
if operation not in self.SUPPORTED_OPERATIONS:
raise ValueError(f"Unsupported control operation: {operation}")
stability_service = StabilityAIService()
async with stability_service:
if operation == "sketch":
result = await stability_service.control_sketch(
image=control_image_bytes,
prompt=request.prompt,
control_strength=request.control_strength or 0.7,
negative_prompt=request.negative_prompt,
seed=request.seed,
output_format=request.output_format,
style_preset=request.style_preset,
)
elif operation == "structure":
result = await stability_service.control_structure(
image=control_image_bytes,
prompt=request.prompt,
control_strength=request.control_strength or 0.7,
negative_prompt=request.negative_prompt,
seed=request.seed,
output_format=request.output_format,
style_preset=request.style_preset,
)
elif operation == "style":
result = await stability_service.control_style(
image=control_image_bytes,
prompt=request.prompt,
negative_prompt=request.negative_prompt,
aspect_ratio=request.aspect_ratio or "1:1",
fidelity=request.fidelity or 0.5,
seed=request.seed,
output_format=request.output_format,
style_preset=request.style_preset,
)
elif operation == "style_transfer":
if not style_image_bytes:
raise ValueError("Style image is required for style transfer")
result = await stability_service.control_style_transfer(
init_image=control_image_bytes,
style_image=style_image_bytes,
prompt=request.prompt or "",
negative_prompt=request.negative_prompt,
style_strength=request.style_strength or 1.0,
composition_fidelity=request.composition_fidelity or 0.9,
change_strength=request.change_strength or 0.9,
seed=request.seed,
output_format=request.output_format,
)
else:
raise ValueError(f"Unsupported control operation: {operation}")
image_bytes = self._extract_image_bytes(result)
metadata = self._image_bytes_to_metadata(image_bytes)
metadata.update(
{
"operation": operation,
"style_preset": request.style_preset,
"provider": self.SUPPORTED_OPERATIONS[operation]["provider"],
}
)
response = {
"success": True,
"operation": operation,
"provider": metadata["provider"],
"image_base64": self._bytes_to_base64(image_bytes, request.output_format),
"width": metadata["width"],
"height": metadata["height"],
"metadata": metadata,
}
logger.info("[Control Studio] ✅ Operation '%s' completed", operation)
return response
@staticmethod
def _extract_image_bytes(result: Any) -> bytes:
"""Normalize Stability responses into raw image bytes."""
if isinstance(result, bytes):
return result
if isinstance(result, dict):
artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
for artifact in artifacts:
if isinstance(artifact, dict):
if artifact.get("base64"):
return base64.b64decode(artifact["base64"])
if artifact.get("b64_json"):
return base64.b64decode(artifact["b64_json"])
raise RuntimeError("Unable to extract image bytes from provider response")

View File

@@ -110,12 +110,12 @@ class EditStudioService:
},
"search_replace": {
"label": "Search & Replace",
"description": "Locate objects via search prompt and replace them.",
"description": "Locate objects via search prompt and replace them. Optional mask for precise control.",
"provider": "stability",
"async": False,
"fields": {
"prompt": True,
"mask": False,
"mask": True, # Optional mask for precise region selection
"negative_prompt": False,
"search_prompt": True,
"select_prompt": False,
@@ -126,12 +126,12 @@ class EditStudioService:
},
"search_recolor": {
"label": "Search & Recolor",
"description": "Select elements via prompt and recolor them.",
"description": "Select elements via prompt and recolor them. Optional mask for exact region selection.",
"provider": "stability",
"async": False,
"fields": {
"prompt": True,
"mask": False,
"mask": True, # Optional mask for precise region selection
"negative_prompt": False,
"search_prompt": False,
"select_prompt": True,
@@ -158,12 +158,12 @@ class EditStudioService:
},
"general_edit": {
"label": "Prompt-based Edit",
"description": "Free-form editing powered by Hugging Face image-to-image models.",
"description": "Free-form editing powered by Hugging Face image-to-image models. Optional mask for selective editing.",
"provider": "huggingface",
"async": False,
"fields": {
"prompt": True,
"mask": False,
"mask": True, # Optional mask for selective region editing
"negative_prompt": True,
"search_prompt": False,
"select_prompt": False,
@@ -346,6 +346,7 @@ class EditStudioService:
image=image_bytes,
prompt=request.prompt,
search_prompt=request.search_prompt,
mask=mask_bytes, # Optional mask for precise region selection
output_format=request.output_format,
)
elif operation == "search_recolor":
@@ -355,6 +356,7 @@ class EditStudioService:
image=image_bytes,
prompt=request.prompt,
select_prompt=request.select_prompt,
mask=mask_bytes, # Optional mask for precise region selection
output_format=request.output_format,
)
elif operation == "relight":
@@ -403,6 +405,7 @@ class EditStudioService:
request.prompt,
options,
user_id,
mask_bytes, # Optional mask for selective editing
)
return result.image_bytes

View File

@@ -0,0 +1,502 @@
"""Social Optimizer service for platform-specific image optimization."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from PIL import Image, ImageDraw, ImageFont
from .templates import Platform
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.social_optimizer")
@dataclass
class SafeZone:
"""Safe zone configuration for text overlay."""
top: float = 0.1 # Percentage from top
bottom: float = 0.1 # Percentage from bottom
left: float = 0.1 # Percentage from left
right: float = 0.1 # Percentage from right
@dataclass
class PlatformFormat:
"""Platform format specification."""
name: str
width: int
height: int
ratio: str
safe_zone: SafeZone
file_type: str = "PNG"
max_size_mb: float = 5.0
# Platform format definitions with safe zones
PLATFORM_FORMATS: Dict[Platform, List[PlatformFormat]] = {
Platform.INSTAGRAM: [
PlatformFormat(
name="Feed Post (Square)",
width=1080,
height=1080,
ratio="1:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Feed Post (Portrait)",
width=1080,
height=1350,
ratio="4:5",
safe_zone=SafeZone(top=0.2, bottom=0.2, left=0.1, right=0.1),
),
PlatformFormat(
name="Story",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Reel",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
],
Platform.FACEBOOK: [
PlatformFormat(
name="Feed Post",
width=1200,
height=630,
ratio="1.91:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Feed Post (Square)",
width=1080,
height=1080,
ratio="1:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Story",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Cover Photo",
width=820,
height=312,
ratio="16:9",
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
),
],
Platform.TWITTER: [
PlatformFormat(
name="Post",
width=1200,
height=675,
ratio="16:9",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Card",
width=1200,
height=600,
ratio="2:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Header",
width=1500,
height=500,
ratio="3:1",
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
),
],
Platform.LINKEDIN: [
PlatformFormat(
name="Feed Post",
width=1200,
height=628,
ratio="1.91:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Feed Post (Square)",
width=1080,
height=1080,
ratio="1:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Article",
width=1200,
height=627,
ratio="2:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Company Cover",
width=1128,
height=191,
ratio="4:1",
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
),
],
Platform.YOUTUBE: [
PlatformFormat(
name="Thumbnail",
width=1280,
height=720,
ratio="16:9",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Channel Art",
width=2560,
height=1440,
ratio="16:9",
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
),
],
Platform.PINTEREST: [
PlatformFormat(
name="Pin",
width=1000,
height=1500,
ratio="2:3",
safe_zone=SafeZone(top=0.2, bottom=0.2, left=0.1, right=0.1),
),
PlatformFormat(
name="Story Pin",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
],
Platform.TIKTOK: [
PlatformFormat(
name="Video Cover",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
],
}
@dataclass
class SocialOptimizerRequest:
"""Request payload for social optimization."""
image_base64: str
platforms: List[Platform] # List of platforms to optimize for
format_names: Optional[Dict[Platform, str]] = None # Specific format per platform
show_safe_zones: bool = False # Include safe zone overlay in output
crop_mode: str = "smart" # "smart", "center", "fit"
focal_point: Optional[Dict[str, float]] = None # {"x": 0.5, "y": 0.5} for smart crop
output_format: str = "png"
options: Dict[str, Any] = field(default_factory=dict)
class SocialOptimizerService:
"""Service for optimizing images for social media platforms."""
def __init__(self):
logger.info("[Social Optimizer] Initialized service")
@staticmethod
def _decode_base64_image(value: str) -> bytes:
"""Decode a base64 (or data URL) string to bytes."""
try:
if value.startswith("data:"):
_, b64data = value.split(",", 1)
else:
b64data = value
return base64.b64decode(b64data)
except Exception as exc:
logger.error(f"[Social Optimizer] Failed to decode base64 image: {exc}")
raise ValueError("Invalid base64 image payload") from exc
@staticmethod
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
"""Convert raw bytes to base64 data URL."""
b64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:image/{output_format};base64,{b64}"
@staticmethod
def _smart_crop(
image: Image.Image,
target_width: int,
target_height: int,
focal_point: Optional[Dict[str, float]] = None,
) -> Image.Image:
"""Smart crop image to target dimensions, preserving important content."""
img_width, img_height = image.size
target_ratio = target_width / target_height
img_ratio = img_width / img_height
# If focal point is provided, use it for cropping
if focal_point:
focal_x = int(focal_point["x"] * img_width)
focal_y = int(focal_point["y"] * img_height)
else:
# Default to center
focal_x = img_width // 2
focal_y = img_height // 2
if img_ratio > target_ratio:
# Image is wider than target - crop width
new_width = int(img_height * target_ratio)
left = max(0, min(focal_x - new_width // 2, img_width - new_width))
right = left + new_width
cropped = image.crop((left, 0, right, img_height))
else:
# Image is taller than target - crop height
new_height = int(img_width / target_ratio)
top = max(0, min(focal_y - new_height // 2, img_height - new_height))
bottom = top + new_height
cropped = image.crop((0, top, img_width, bottom))
# Resize to exact target dimensions
return cropped.resize((target_width, target_height), Image.Resampling.LANCZOS)
@staticmethod
def _fit_image(
image: Image.Image,
target_width: int,
target_height: int,
) -> Image.Image:
"""Fit image to target dimensions while maintaining aspect ratio (adds padding if needed)."""
img_width, img_height = image.size
target_ratio = target_width / target_height
img_ratio = img_width / img_height
if img_ratio > target_ratio:
# Image is wider - fit to height, pad width
new_height = target_height
new_width = int(img_width * (target_height / img_height))
resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Create new image with target size and paste centered
result = Image.new("RGB", (target_width, target_height), (255, 255, 255))
paste_x = (target_width - new_width) // 2
result.paste(resized, (paste_x, 0))
return result
else:
# Image is taller - fit to width, pad height
new_width = target_width
new_height = int(img_height * (target_width / img_width))
resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Create new image with target size and paste centered
result = Image.new("RGB", (target_width, target_height), (255, 255, 255))
paste_y = (target_height - new_height) // 2
result.paste(resized, (0, paste_y))
return result
@staticmethod
def _center_crop(
image: Image.Image,
target_width: int,
target_height: int,
) -> Image.Image:
"""Center crop image to target dimensions."""
img_width, img_height = image.size
target_ratio = target_width / target_height
img_ratio = img_width / img_height
if img_ratio > target_ratio:
# Image is wider - crop width
new_width = int(img_height * target_ratio)
left = (img_width - new_width) // 2
cropped = image.crop((left, 0, left + new_width, img_height))
else:
# Image is taller - crop height
new_height = int(img_width / target_ratio)
top = (img_height - new_height) // 2
cropped = image.crop((0, top, img_width, top + new_height))
return cropped.resize((target_width, target_height), Image.Resampling.LANCZOS)
@staticmethod
def _draw_safe_zone(
image: Image.Image,
safe_zone: SafeZone,
) -> Image.Image:
"""Draw safe zone overlay on image."""
draw = ImageDraw.Draw(image)
width, height = image.size
# Calculate safe zone boundaries
top = int(height * safe_zone.top)
bottom = int(height * (1 - safe_zone.bottom))
left = int(width * safe_zone.left)
right = int(width * (1 - safe_zone.right))
# Draw semi-transparent overlay outside safe zone
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
overlay_draw = ImageDraw.Draw(overlay)
# Top area
overlay_draw.rectangle([(0, 0), (width, top)], fill=(0, 0, 0, 100))
# Bottom area
overlay_draw.rectangle([(0, bottom), (width, height)], fill=(0, 0, 0, 100))
# Left area
overlay_draw.rectangle([(0, top), (left, bottom)], fill=(0, 0, 0, 100))
# Right area
overlay_draw.rectangle([(right, top), (width, bottom)], fill=(0, 0, 0, 100))
# Draw safe zone border
border_color = (255, 255, 0, 200) # Yellow with transparency
overlay_draw.rectangle(
[(left, top), (right, bottom)],
outline=border_color,
width=2,
)
# Composite overlay onto image
if image.mode != "RGBA":
image = image.convert("RGBA")
image = Image.alpha_composite(image, overlay)
return image
def get_platform_formats(self, platform: Platform) -> List[Dict[str, Any]]:
"""Get available formats for a platform."""
formats = PLATFORM_FORMATS.get(platform, [])
return [
{
"name": fmt.name,
"width": fmt.width,
"height": fmt.height,
"ratio": fmt.ratio,
"safe_zone": {
"top": fmt.safe_zone.top,
"bottom": fmt.safe_zone.bottom,
"left": fmt.safe_zone.left,
"right": fmt.safe_zone.right,
},
"file_type": fmt.file_type,
"max_size_mb": fmt.max_size_mb,
}
for fmt in formats
]
def optimize_image(
self,
request: SocialOptimizerRequest,
) -> Dict[str, Any]:
"""Optimize image for specified platforms."""
logger.info(
f"[Social Optimizer] Processing optimization for {len(request.platforms)} platform(s)"
)
# Decode input image
image_bytes = self._decode_base64_image(request.image_base64)
original_image = Image.open(io.BytesIO(image_bytes))
# Convert to RGB if needed
if original_image.mode in ("RGBA", "LA", "P"):
if original_image.mode == "P":
original_image = original_image.convert("RGBA")
background = Image.new("RGB", original_image.size, (255, 255, 255))
if original_image.mode == "RGBA":
background.paste(original_image, mask=original_image.split()[-1])
else:
background.paste(original_image)
original_image = background
elif original_image.mode != "RGB":
original_image = original_image.convert("RGB")
results = []
for platform in request.platforms:
formats = PLATFORM_FORMATS.get(platform, [])
if not formats:
logger.warning(f"[Social Optimizer] No formats found for platform: {platform}")
continue
# Get format (use specified format or default to first)
format_name = None
if request.format_names and platform in request.format_names:
format_name = request.format_names[platform]
platform_format = None
for fmt in formats:
if format_name and fmt.name == format_name:
platform_format = fmt
break
if not platform_format:
platform_format = formats[0] # Default to first format
# Crop/resize image based on mode
if request.crop_mode == "smart":
optimized_image = self._smart_crop(
original_image,
platform_format.width,
platform_format.height,
request.focal_point,
)
elif request.crop_mode == "fit":
optimized_image = self._fit_image(
original_image,
platform_format.width,
platform_format.height,
)
else: # center
optimized_image = self._center_crop(
original_image,
platform_format.width,
platform_format.height,
)
# Add safe zone overlay if requested
if request.show_safe_zones:
optimized_image = self._draw_safe_zone(optimized_image, platform_format.safe_zone)
# Convert to bytes
output_buffer = io.BytesIO()
output_format = request.output_format.lower()
if output_format == "jpg" or output_format == "jpeg":
optimized_image = optimized_image.convert("RGB")
optimized_image.save(output_buffer, format="JPEG", quality=95)
else:
optimized_image.save(output_buffer, format="PNG")
output_bytes = output_buffer.getvalue()
results.append(
{
"platform": platform.value,
"format": platform_format.name,
"width": platform_format.width,
"height": platform_format.height,
"ratio": platform_format.ratio,
"image_base64": self._bytes_to_base64(output_bytes, request.output_format),
"safe_zone": {
"top": platform_format.safe_zone.top,
"bottom": platform_format.safe_zone.bottom,
"left": platform_format.safe_zone.left,
"right": platform_format.safe_zone.right,
},
}
)
logger.info(f"[Social Optimizer] ✅ Generated {len(results)} optimized images")
return {
"success": True,
"results": results,
"total_optimized": len(results),
}

View File

@@ -5,6 +5,8 @@ from typing import Optional, Dict, Any, List
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .templates import Platform, TemplateCategory, ImageTemplate
from utils.logger_utils import get_service_logger
@@ -20,6 +22,8 @@ class ImageStudioManager:
self.create_service = CreateStudioService()
self.edit_service = EditStudioService()
self.upscale_service = UpscaleStudioService()
self.control_service = ControlStudioService()
self.social_optimizer_service = SocialOptimizerService()
logger.info("[Image Studio Manager] Initialized successfully")
# ====================
@@ -215,6 +219,40 @@ class ImageStudioManager:
"estimated": True,
}
# ====================
# CONTROL STUDIO
# ====================
async def control_image(
self,
request: ControlStudioRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Run Control Studio operations."""
logger.info("[Image Studio] Control request from user: %s", user_id)
return await self.control_service.process_control(request, user_id=user_id)
def get_control_operations(self) -> Dict[str, Any]:
"""Expose control operations for UI."""
return self.control_service.list_operations()
# ====================
# SOCIAL OPTIMIZER
# ====================
async def optimize_for_social(
self,
request: SocialOptimizerRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Optimize image for social media platforms."""
logger.info("[Image Studio] Social optimization request from user: %s", user_id)
return self.social_optimizer_service.optimize_image(request)
def get_social_platform_formats(self, platform: Platform) -> List[Dict[str, Any]]:
"""Get available formats for a social platform."""
return self.social_optimizer_service.get_platform_formats(platform)
# ====================
# PLATFORM SPECS
# ====================

View File

@@ -54,7 +54,8 @@ def edit_image(
input_image_bytes: bytes,
prompt: str,
options: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None
user_id: Optional[str] = None,
mask_bytes: Optional[bytes] = None,
) -> ImageGenerationResult:
"""Edit image with pre-flight validation.
@@ -63,6 +64,7 @@ def edit_image(
prompt: Natural language prompt describing desired edits (e.g., "Turn the cat into a tiger")
options: Image editing options (provider, model, etc.)
user_id: User ID for subscription checking (optional, but required for validation)
mask_bytes: Optional mask image bytes for selective editing (grayscale, white=edit, black=preserve)
Returns:
ImageGenerationResult with edited image bytes and metadata
@@ -72,6 +74,8 @@ def edit_image(
- Describe what should change and what should remain
- Examples: "Turn the cat into a tiger", "Change background to forest",
"Make it look like a watercolor painting"
Note: Mask support depends on the specific model. Some models may ignore the mask parameter.
"""
# PRE-FLIGHT VALIDATION: Validate image editing before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
@@ -128,14 +132,33 @@ def edit_image(
width = input_image.width
height = input_image.height
# Convert mask bytes to PIL Image if provided
mask_image = None
if mask_bytes:
try:
mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L") # Convert to grayscale
# Ensure mask dimensions match input image
if mask_image.size != input_image.size:
logger.warning(f"[Image Editing] Mask size {mask_image.size} doesn't match image size {input_image.size}, resizing mask")
mask_image = mask_image.resize(input_image.size, Image.Resampling.LANCZOS)
except Exception as e:
logger.warning(f"[Image Editing] Failed to process mask image: {e}, continuing without mask")
mask_image = None
# Use image_to_image method from Hugging Face InferenceClient
# This follows the pattern from the Hugging Face documentation
# Docs: https://huggingface.co/docs/inference-providers/en/guides/image-editor
# Note: Mask support depends on the model - some models may ignore it
call_params = params.copy()
if mask_image:
call_params["mask_image"] = mask_image
logger.info("[Image Editing] Using mask for selective editing")
edited_image: Image.Image = client.image_to_image(
image=input_image,
prompt=prompt.strip(),
model=model,
**params,
**call_params,
)
# Convert edited image back to bytes

View File

@@ -397,6 +397,7 @@ class StabilityAIService:
image: Union[UploadFile, bytes],
prompt: str,
search_prompt: str,
mask: Optional[Union[UploadFile, bytes]] = None,
**kwargs
) -> Union[bytes, Dict[str, Any]]:
"""Replace objects in image using search prompt.
@@ -405,6 +406,7 @@ class StabilityAIService:
image: Input image
prompt: Text prompt for replacement
search_prompt: What to search for
mask: Optional mask image for precise region selection
**kwargs: Additional parameters
Returns:
@@ -414,6 +416,8 @@ class StabilityAIService:
data.update({k: v for k, v in kwargs.items() if v is not None})
files = {"image": await self._prepare_image_file(image)}
if mask:
files["mask"] = await self._prepare_image_file(mask)
return await self._make_request(
method="POST",
@@ -427,6 +431,7 @@ class StabilityAIService:
image: Union[UploadFile, bytes],
prompt: str,
select_prompt: str,
mask: Optional[Union[UploadFile, bytes]] = None,
**kwargs
) -> Union[bytes, Dict[str, Any]]:
"""Recolor objects in image using select prompt.
@@ -435,6 +440,7 @@ class StabilityAIService:
image: Input image
prompt: Text prompt for recoloring
select_prompt: What to select for recoloring
mask: Optional mask image for precise region selection
**kwargs: Additional parameters
Returns:
@@ -444,6 +450,8 @@ class StabilityAIService:
data.update({k: v for k, v in kwargs.items() if v is not None})
files = {"image": await self._prepare_image_file(image)}
if mask:
files["mask"] = await self._prepare_image_file(mask)
return await self._make_request(
method="POST",

View File

@@ -415,6 +415,75 @@ def validate_image_editing_operations(
)
def validate_image_control_operations(
pricing_service: PricingService,
user_id: str,
num_images: int = 1
) -> None:
"""
Validate image control operations (sketch-to-image, structure control, style transfer) before making API calls.
Control operations use Stability AI for image generation with control inputs, so they use
the same validation as image generation operations.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
num_images: Number of images to generate (for multiple variations)
Returns:
None - raises HTTPException with 429 status if validation fails
"""
try:
# Control operations use Stability AI, same as image generation
operations_to_validate = [
{
'provider': APIProvider.STABILITY,
'tokens_requested': 0,
'actual_provider_name': 'stability',
'operation_type': 'image_generation' # Control ops use image generation limits
}
for _ in range(num_images)
]
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image control operation(s) for user {user_id}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Image control blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Image control validated for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating image control: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate image control: {str(e)}",
'message': f"Failed to validate image control: {str(e)}"
}
)
def validate_video_generation_operations(
pricing_service: PricingService,
user_id: str