AI Image Studio Progress Review

- Added new router for content assets
- Added new service for content assets
- Added new model for content assets
- Added new utils for content assets
- Added new docs for content assets
- Added new tests for content assets
- Added new examples for content assets
- Added new guides for content assets
This commit is contained in:
ajaysi
2025-11-23 09:21:11 +05:30
parent eede21ad42
commit 77d7c0cde6
38 changed files with 5939 additions and 37 deletions

View File

@@ -0,0 +1,2 @@
# Content Assets API Module

View File

@@ -0,0 +1,258 @@
"""
Content Assets API Router
API endpoints for managing unified content assets across all modules.
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from datetime import datetime
from services.database import get_db
from middleware.auth_middleware import get_current_user
from services.content_asset_service import ContentAssetService
from models.content_asset_models import AssetType, AssetSource
router = APIRouter(prefix="/api/content-assets", tags=["Content Assets"])
class AssetResponse(BaseModel):
"""Response model for asset data."""
id: int
user_id: str
asset_type: str
source_module: str
filename: str
file_url: str
file_path: Optional[str] = None
file_size: Optional[int] = None
mime_type: Optional[str] = None
title: Optional[str] = None
description: Optional[str] = None
prompt: Optional[str] = None
tags: List[str] = []
metadata: Dict[str, Any] = {}
provider: Optional[str] = None
model: Optional[str] = None
cost: float = 0.0
generation_time: Optional[float] = None
is_favorite: bool = False
download_count: int = 0
share_count: int = 0
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class AssetListResponse(BaseModel):
"""Response model for asset list."""
assets: List[AssetResponse]
total: int
limit: int
offset: int
@router.get("/", response_model=AssetListResponse)
async def get_assets(
asset_type: Optional[str] = Query(None, description="Filter by asset type"),
source_module: Optional[str] = Query(None, description="Filter by source module"),
search: Optional[str] = Query(None, description="Search query"),
tags: Optional[str] = Query(None, description="Comma-separated tags"),
favorites_only: bool = Query(False, description="Only favorites"),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get user's content assets with optional filtering."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
# Parse filters
asset_type_enum = None
if asset_type:
try:
asset_type_enum = AssetType(asset_type.lower())
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid asset type: {asset_type}")
source_module_enum = None
if source_module:
try:
source_module_enum = AssetSource(source_module.lower())
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid source module: {source_module}")
tags_list = None
if tags:
tags_list = [tag.strip() for tag in tags.split(",")]
assets, total = service.get_user_assets(
user_id=user_id,
asset_type=asset_type_enum,
source_module=source_module_enum,
search_query=search,
tags=tags_list,
favorites_only=favorites_only,
limit=limit,
offset=offset,
)
return AssetListResponse(
assets=[AssetResponse.model_validate(asset) for asset in assets],
total=total,
limit=limit,
offset=offset,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching assets: {str(e)}")
@router.post("/{asset_id}/favorite", response_model=Dict[str, Any])
async def toggle_favorite(
asset_id: int,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Toggle favorite status of an asset."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
is_favorite = service.toggle_favorite(asset_id, user_id)
return {"asset_id": asset_id, "is_favorite": is_favorite}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling favorite: {str(e)}")
@router.delete("/{asset_id}", response_model=Dict[str, Any])
async def delete_asset(
asset_id: int,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Delete an asset."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
success = service.delete_asset(asset_id, user_id)
if not success:
raise HTTPException(status_code=404, detail="Asset not found")
return {"asset_id": asset_id, "deleted": True}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error deleting asset: {str(e)}")
@router.post("/{asset_id}/usage", response_model=Dict[str, Any])
async def track_usage(
asset_id: int,
action: str = Query(..., description="Action: download, share, or access"),
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Track asset usage (download, share, access)."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
if action not in ["download", "share", "access"]:
raise HTTPException(status_code=400, detail="Invalid action")
service = ContentAssetService(db)
service.update_asset_usage(asset_id, user_id, action)
return {"asset_id": asset_id, "action": action, "tracked": True}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error tracking usage: {str(e)}")
class AssetUpdateRequest(BaseModel):
"""Request model for updating asset metadata."""
title: Optional[str] = None
description: Optional[str] = None
tags: Optional[List[str]] = None
@router.put("/{asset_id}", response_model=AssetResponse)
async def update_asset(
asset_id: int,
update_data: AssetUpdateRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Update asset metadata."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
asset = service.update_asset(
asset_id=asset_id,
user_id=user_id,
title=update_data.title,
description=update_data.description,
tags=update_data.tags,
)
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
return AssetResponse.model_validate(asset)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error updating asset: {str(e)}")
@router.get("/statistics", response_model=Dict[str, Any])
async def get_statistics(
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Get asset statistics for the current user."""
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = ContentAssetService(db)
stats = service.get_asset_statistics(user_id)
return stats
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching statistics: {str(e)}")

View File

@@ -299,6 +299,10 @@ app.include_router(platform_analytics_router)
app.include_router(images_router)
app.include_router(image_studio_router)
# Include content assets router
from api.content_assets.router import router as content_assets_router
app.include_router(content_assets_router)
# Include research configuration router
app.include_router(research_config_router, prefix="/api/research", tags=["research"])

View File

@@ -0,0 +1,145 @@
"""
Content Asset Models
Unified database models for tracking all AI-generated content assets across all modules.
"""
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, Text, ForeignKey, Enum, Index, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
# Use the same Base as subscription models for consistency
from models.subscription_models import Base
class AssetType(enum.Enum):
"""Types of content assets."""
TEXT = "text"
IMAGE = "image"
VIDEO = "video"
AUDIO = "audio"
class AssetSource(enum.Enum):
"""Source module/tool that generated the asset - covers ALL ALwrity tools."""
# Image Studio modules
IMAGE_STUDIO_CREATE = "image_studio_create"
IMAGE_STUDIO_EDIT = "image_studio_edit"
IMAGE_STUDIO_UPSCALE = "image_studio_upscale"
IMAGE_STUDIO_TRANSFORM = "image_studio_transform"
IMAGE_STUDIO_CONTROL = "image_studio_control"
IMAGE_STUDIO_SOCIAL = "image_studio_social"
IMAGE_STUDIO_BATCH = "image_studio_batch"
# Content Writers
STORY_WRITER = "story_writer"
BLOG_WRITER = "blog_writer"
LINKEDIN_WRITER = "linkedin_writer"
FACEBOOK_WRITER = "facebook_writer"
# Content Planning
CONTENT_PLANNING = "content_planning"
CONTENT_STRATEGY = "content_strategy"
# SEO Tools
SEO_DASHBOARD = "seo_dashboard"
SEO_TOOLS = "seo_tools"
# Research
RESEARCH = "research"
# Scheduler
SCHEDULER = "scheduler"
# Main Generation (legacy/fallback)
MAIN_TEXT_GENERATION = "main_text_generation"
MAIN_IMAGE_GENERATION = "main_image_generation"
MAIN_VIDEO_GENERATION = "main_video_generation"
MAIN_AUDIO_GENERATION = "main_audio_generation"
class ContentAsset(Base):
"""
Unified model for tracking all AI-generated content assets.
Similar to subscription tracking, this provides a centralized way to manage all content.
"""
__tablename__ = "content_assets"
# Primary fields
id = Column(Integer, primary_key=True)
user_id = Column(String(255), nullable=False, index=True) # Clerk user ID
# Asset identification
asset_type = Column(Enum(AssetType), nullable=False, index=True)
source_module = Column(Enum(AssetSource), nullable=False, index=True)
# File information
filename = Column(String(500), nullable=False)
file_path = Column(String(1000), nullable=True) # Server file path
file_url = Column(String(1000), nullable=False) # Public URL
file_size = Column(Integer, nullable=True) # Size in bytes
mime_type = Column(String(100), nullable=True) # MIME type
# Asset metadata
title = Column(String(500), nullable=True)
description = Column(Text, nullable=True)
prompt = Column(Text, nullable=True) # Original prompt used for generation
tags = Column(JSON, nullable=True) # Array of tags for search/filtering
metadata = Column(JSON, nullable=True) # Additional module-specific metadata
# Generation details
provider = Column(String(100), nullable=True, index=True) # AI provider used (e.g., "stability", "gemini")
model = Column(String(200), nullable=True, index=True) # Model used (full model path/name)
cost = Column(Float, nullable=True, default=0.0) # Generation cost in USD
generation_time = Column(Float, nullable=True) # Time taken in seconds
# Status tracking
status = Column(String(50), default='completed', index=True) # completed, processing, failed, pending
error_message = Column(Text, nullable=True) # Error details if failed
# Organization
is_favorite = Column(Boolean, default=False, index=True)
collection_id = Column(Integer, ForeignKey('asset_collections.id'), nullable=True)
# Usage tracking
download_count = Column(Integer, default=0)
share_count = Column(Integer, default=0)
last_accessed = Column(DateTime, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
collection = relationship("AssetCollection", back_populates="assets", cascade="all, delete-orphan")
# Composite indexes for common query patterns
__table_args__ = (
Index('idx_user_type_source', 'user_id', 'asset_type', 'source_module'),
Index('idx_user_favorite_created', 'user_id', 'is_favorite', 'created_at'),
Index('idx_user_tags', 'user_id', 'tags'),
)
class AssetCollection(Base):
"""
Collections/albums for organizing assets.
"""
__tablename__ = "asset_collections"
id = Column(Integer, primary_key=True)
user_id = Column(String(255), nullable=False, index=True)
name = Column(String(255), nullable=False)
description = Column(Text, nullable=True)
is_public = Column(Boolean, default=False)
cover_asset_id = Column(Integer, ForeignKey('content_assets.id'), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
assets = relationship("ContentAsset", back_populates="collection")

View File

@@ -9,6 +9,8 @@ from services.image_studio import (
ImageStudioManager,
CreateStudioRequest,
EditStudioRequest,
ControlStudioRequest,
SocialOptimizerRequest,
)
from services.image_studio.upscale_service import UpscaleStudioRequest
from services.image_studio.templates import Platform, TemplateCategory
@@ -531,6 +533,197 @@ async def upscale_image(
raise HTTPException(status_code=500, detail=f"Image upscaling failed: {e}")
# ====================
# CONTROL STUDIO ENDPOINTS
# ====================
class ControlImageRequest(BaseModel):
"""Request payload for Control Studio."""
control_image_base64: str = Field(..., description="Control image (sketch/structure/style) in base64")
operation: Literal["sketch", "structure", "style", "style_transfer"] = Field(..., description="Control operation")
prompt: str = Field(..., description="Text prompt for generation")
style_image_base64: Optional[str] = Field(None, description="Style reference image (for style_transfer only)")
negative_prompt: Optional[str] = Field(None, description="Negative prompt")
control_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Control strength (sketch/structure)")
fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style fidelity (style operation)")
style_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style strength (style_transfer)")
composition_fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Composition fidelity (style_transfer)")
change_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Change strength (style_transfer)")
aspect_ratio: Optional[str] = Field(None, description="Aspect ratio (style operation)")
style_preset: Optional[str] = Field(None, description="Style preset")
seed: Optional[int] = Field(None, description="Random seed")
output_format: str = Field("png", description="Output format")
class ControlImageResponse(BaseModel):
success: bool
operation: str
provider: str
image_base64: str
width: int
height: int
metadata: Dict[str, Any]
class ControlOperationsResponse(BaseModel):
operations: Dict[str, Dict[str, Any]]
@router.post("/control/process", response_model=ControlImageResponse, summary="Process Control Studio request")
async def process_control_image(
request: ControlImageRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Perform Control Studio operations such as sketch-to-image, structure control, style control, and style transfer."""
try:
user_id = _require_user_id(current_user, "image control")
logger.info(f"[Control Image] Request from user {user_id}: operation={request.operation}")
control_request = ControlStudioRequest(
operation=request.operation,
prompt=request.prompt,
control_image_base64=request.control_image_base64,
style_image_base64=request.style_image_base64,
negative_prompt=request.negative_prompt,
control_strength=request.control_strength,
fidelity=request.fidelity,
style_strength=request.style_strength,
composition_fidelity=request.composition_fidelity,
change_strength=request.change_strength,
aspect_ratio=request.aspect_ratio,
style_preset=request.style_preset,
seed=request.seed,
output_format=request.output_format,
)
result = await studio_manager.control_image(control_request, user_id=user_id)
return ControlImageResponse(**result)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Control Image] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Image control failed: {e}")
@router.get("/control/operations", response_model=ControlOperationsResponse, summary="List Control Studio operations")
async def get_control_operations(
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Return metadata for supported Control Studio operations."""
try:
operations = studio_manager.get_control_operations()
return ControlOperationsResponse(operations=operations)
except Exception as e:
logger.error(f"[Control Operations] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to load control operations")
# ====================
# SOCIAL OPTIMIZER ENDPOINTS
# ====================
class SocialOptimizeRequest(BaseModel):
"""Request payload for Social Optimizer."""
image_base64: str = Field(..., description="Source image in base64 or data URL")
platforms: List[str] = Field(..., description="List of platforms to optimize for")
format_names: Optional[Dict[str, str]] = Field(None, description="Specific format per platform")
show_safe_zones: bool = Field(False, description="Include safe zone overlay in output")
crop_mode: str = Field("smart", description="Crop mode: smart, center, or fit")
focal_point: Optional[Dict[str, float]] = Field(None, description="Focal point for smart crop (x, y as 0-1)")
output_format: str = Field("png", description="Output format (png or jpg)")
class SocialOptimizeResponse(BaseModel):
success: bool
results: List[Dict[str, Any]]
total_optimized: int
class PlatformFormatsResponse(BaseModel):
formats: List[Dict[str, Any]]
@router.post("/social/optimize", response_model=SocialOptimizeResponse, summary="Optimize image for social platforms")
async def optimize_for_social(
request: SocialOptimizeRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Optimize an image for multiple social media platforms with smart cropping and safe zones."""
try:
user_id = _require_user_id(current_user, "social optimization")
logger.info(f"[Social Optimizer] Request from user {user_id}: platforms={request.platforms}")
# Convert platform strings to Platform enum
from services.image_studio.templates import Platform
platforms = []
for platform_str in request.platforms:
try:
platforms.append(Platform(platform_str.lower()))
except ValueError:
logger.warning(f"[Social Optimizer] Invalid platform: {platform_str}")
continue
if not platforms:
raise HTTPException(status_code=400, detail="No valid platforms provided")
# Convert format_names dict keys to Platform enum
format_names = None
if request.format_names:
format_names = {}
for platform_str, format_name in request.format_names.items():
try:
platform = Platform(platform_str.lower())
format_names[platform] = format_name
except ValueError:
logger.warning(f"[Social Optimizer] Invalid platform in format_names: {platform_str}")
social_request = SocialOptimizerRequest(
image_base64=request.image_base64,
platforms=platforms,
format_names=format_names,
show_safe_zones=request.show_safe_zones,
crop_mode=request.crop_mode,
focal_point=request.focal_point,
output_format=request.output_format,
options={},
)
result = await studio_manager.optimize_for_social(social_request, user_id=user_id)
return SocialOptimizeResponse(**result)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Social Optimizer] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Social optimization failed: {e}")
@router.get("/social/platforms/{platform}/formats", response_model=PlatformFormatsResponse, summary="Get platform formats")
async def get_platform_formats(
platform: str,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Get available formats for a social media platform."""
try:
from services.image_studio.templates import Platform
try:
platform_enum = Platform(platform.lower())
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid platform: {platform}")
formats = studio_manager.get_social_platform_formats(platform_enum)
return PlatformFormatsResponse(formats=formats)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Platform Formats] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to load platform formats: {e}")
# ====================
# PLATFORM SPECS ENDPOINTS
# ====================

View File

@@ -0,0 +1,322 @@
"""
Content Asset Service
Service for managing and tracking all AI-generated content assets.
"""
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func, desc
from datetime import datetime
from models.content_asset_models import (
ContentAsset,
AssetCollection,
AssetType,
AssetSource
)
import logging
logger = logging.getLogger(__name__)
class ContentAssetService:
"""Service for managing content assets across all modules."""
def __init__(self, db: Session):
self.db = db
def create_asset(
self,
user_id: str,
asset_type: AssetType,
source_module: AssetSource,
filename: str,
file_url: str,
file_path: Optional[str] = None,
file_size: Optional[int] = None,
mime_type: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
prompt: Optional[str] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
provider: Optional[str] = None,
model: Optional[str] = None,
cost: Optional[float] = None,
generation_time: Optional[float] = None,
) -> ContentAsset:
"""
Create a new content asset record.
Args:
user_id: Clerk user ID
asset_type: Type of asset (text, image, video, audio)
source_module: Source module that generated it
filename: Original filename
file_url: Public URL to access the asset
file_path: Server file path (optional)
file_size: File size in bytes (optional)
mime_type: MIME type (optional)
title: Asset title (optional)
description: Asset description (optional)
prompt: Generation prompt (optional)
tags: List of tags (optional)
metadata: Additional metadata (optional)
provider: AI provider used (optional)
model: Model used (optional)
cost: Generation cost (optional)
generation_time: Generation time in seconds (optional)
Returns:
Created ContentAsset instance
"""
try:
asset = ContentAsset(
user_id=user_id,
asset_type=asset_type,
source_module=source_module,
filename=filename,
file_url=file_url,
file_path=file_path,
file_size=file_size,
mime_type=mime_type,
title=title,
description=description,
prompt=prompt,
tags=tags or [],
metadata=metadata or {},
provider=provider,
model=model,
cost=cost or 0.0,
generation_time=generation_time,
)
self.db.add(asset)
self.db.commit()
self.db.refresh(asset)
logger.info(f"Created asset {asset.id} for user {user_id} from {source_module.value}")
return asset
except Exception as e:
self.db.rollback()
logger.error(f"Error creating asset: {str(e)}", exc_info=True)
raise
def get_user_assets(
self,
user_id: str,
asset_type: Optional[AssetType] = None,
source_module: Optional[AssetSource] = None,
search_query: Optional[str] = None,
tags: Optional[List[str]] = None,
favorites_only: bool = False,
limit: int = 100,
offset: int = 0,
) -> Tuple[List[ContentAsset], int]:
"""
Get assets for a user with optional filtering.
Args:
user_id: Clerk user ID
asset_type: Filter by asset type (optional)
source_module: Filter by source module (optional)
search_query: Search in title, description, prompt (optional)
tags: Filter by tags (optional)
favorites_only: Only return favorites (optional)
limit: Maximum number of results
offset: Offset for pagination
Returns:
List of ContentAsset instances
"""
try:
query = self.db.query(ContentAsset).filter(
ContentAsset.user_id == user_id
)
if asset_type:
query = query.filter(ContentAsset.asset_type == asset_type)
if source_module:
query = query.filter(ContentAsset.source_module == source_module)
if favorites_only:
query = query.filter(ContentAsset.is_favorite == True)
if search_query:
search_filter = or_(
ContentAsset.title.ilike(f"%{search_query}%"),
ContentAsset.description.ilike(f"%{search_query}%"),
ContentAsset.prompt.ilike(f"%{search_query}%"),
ContentAsset.filename.ilike(f"%{search_query}%"),
)
query = query.filter(search_filter)
if tags:
# Filter by tags (JSON array contains any of the tags)
tag_filters = [ContentAsset.tags.contains([tag]) for tag in tags]
query = query.filter(or_(*tag_filters))
# Get total count before pagination
total_count = query.count()
# Apply ordering and pagination
query = query.order_by(desc(ContentAsset.created_at))
query = query.limit(limit).offset(offset)
return query.all(), total_count
except Exception as e:
logger.error(f"Error fetching assets: {str(e)}", exc_info=True)
raise
def get_asset_by_id(self, asset_id: int, user_id: str) -> Optional[ContentAsset]:
"""Get a specific asset by ID."""
try:
return self.db.query(ContentAsset).filter(
and_(
ContentAsset.id == asset_id,
ContentAsset.user_id == user_id
)
).first()
except Exception as e:
logger.error(f"Error fetching asset {asset_id}: {str(e)}", exc_info=True)
return None
def toggle_favorite(self, asset_id: int, user_id: str) -> bool:
"""Toggle favorite status of an asset."""
try:
asset = self.get_asset_by_id(asset_id, user_id)
if not asset:
return False
asset.is_favorite = not asset.is_favorite
self.db.commit()
return asset.is_favorite
except Exception as e:
self.db.rollback()
logger.error(f"Error toggling favorite: {str(e)}", exc_info=True)
return False
def delete_asset(self, asset_id: int, user_id: str) -> bool:
"""Delete an asset."""
try:
asset = self.get_asset_by_id(asset_id, user_id)
if not asset:
return False
self.db.delete(asset)
self.db.commit()
return True
except Exception as e:
self.db.rollback()
logger.error(f"Error deleting asset: {str(e)}", exc_info=True)
return False
def update_asset(
self,
asset_id: int,
user_id: str,
title: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Optional[ContentAsset]:
"""Update asset metadata."""
try:
asset = self.get_asset_by_id(asset_id, user_id)
if not asset:
return None
if title is not None:
asset.title = title
if description is not None:
asset.description = description
if tags is not None:
asset.tags = tags
if metadata is not None:
asset.metadata = {**(asset.metadata or {}), **metadata}
asset.updated_at = datetime.utcnow()
self.db.commit()
self.db.refresh(asset)
logger.info(f"Updated asset {asset_id} for user {user_id}")
return asset
except Exception as e:
self.db.rollback()
logger.error(f"Error updating asset: {str(e)}", exc_info=True)
return None
def update_asset_usage(self, asset_id: int, user_id: str, action: str = "access"):
"""Update asset usage statistics."""
try:
asset = self.get_asset_by_id(asset_id, user_id)
if not asset:
return
if action == "download":
asset.download_count += 1
elif action == "share":
asset.share_count += 1
asset.last_accessed = datetime.utcnow()
self.db.commit()
except Exception as e:
self.db.rollback()
logger.error(f"Error updating asset usage: {str(e)}", exc_info=True)
def get_asset_statistics(self, user_id: str) -> Dict[str, Any]:
"""Get statistics about user's assets."""
try:
total = self.db.query(func.count(ContentAsset.id)).filter(
ContentAsset.user_id == user_id
).scalar() or 0
by_type = self.db.query(
ContentAsset.asset_type,
func.count(ContentAsset.id)
).filter(
ContentAsset.user_id == user_id
).group_by(ContentAsset.asset_type).all()
by_source = self.db.query(
ContentAsset.source_module,
func.count(ContentAsset.id)
).filter(
ContentAsset.user_id == user_id
).group_by(ContentAsset.source_module).all()
total_cost = self.db.query(func.sum(ContentAsset.cost)).filter(
ContentAsset.user_id == user_id
).scalar() or 0.0
favorites_count = self.db.query(func.count(ContentAsset.id)).filter(
and_(
ContentAsset.user_id == user_id,
ContentAsset.is_favorite == True
)
).scalar() or 0
return {
"total": total,
"by_type": {str(t): c for t, c in by_type},
"by_source": {str(s): c for s, c in by_source},
"total_cost": float(total_cost),
"favorites_count": favorites_count,
}
except Exception as e:
logger.error(f"Error getting asset statistics: {str(e)}", exc_info=True)
return {
"total": 0,
"by_type": {},
"by_source": {},
"total_cost": 0.0,
"favorites_count": 0,
}

View File

@@ -20,6 +20,7 @@ from models.monitoring_models import Base as MonitoringBase
from models.persona_models import Base as PersonaBase
from models.subscription_models import Base as SubscriptionBase
from models.user_business_info import Base as UserBusinessInfoBase
from models.content_asset_models import Base as ContentAssetBase
# Database configuration
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./alwrity.db')
@@ -74,7 +75,8 @@ def init_database():
PersonaBase.metadata.create_all(bind=engine)
SubscriptionBase.metadata.create_all(bind=engine)
UserBusinessInfoBase.metadata.create_all(bind=engine)
logger.info("Database initialized successfully with all models including subscription system and business info")
ContentAssetBase.metadata.create_all(bind=engine)
logger.info("Database initialized successfully with all models including subscription system, business info, and content assets")
except SQLAlchemyError as e:
logger.error(f"Error initializing database: {str(e)}")
raise

View File

@@ -4,6 +4,8 @@ from .studio_manager import ImageStudioManager
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .templates import PlatformTemplates, TemplateManager
__all__ = [
@@ -14,6 +16,10 @@ __all__ = [
"EditStudioRequest",
"UpscaleStudioService",
"UpscaleStudioRequest",
"ControlStudioService",
"ControlStudioRequest",
"SocialOptimizerService",
"SocialOptimizerRequest",
"PlatformTemplates",
"TemplateManager",
]

View File

@@ -0,0 +1,277 @@
"""Control Studio service for AI-powered controlled image generation."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional
from PIL import Image
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.control")
ControlOperationType = Literal[
"sketch",
"structure",
"style",
"style_transfer",
]
@dataclass
class ControlStudioRequest:
"""Normalized request payload for Control Studio operations."""
operation: ControlOperationType
prompt: str
control_image_base64: str # Sketch, structure, or style reference
style_image_base64: Optional[str] = None # For style_transfer only
negative_prompt: Optional[str] = None
control_strength: Optional[float] = None # For sketch/structure
fidelity: Optional[float] = None # For style
style_strength: Optional[float] = None # For style_transfer
composition_fidelity: Optional[float] = None # For style_transfer
change_strength: Optional[float] = None # For style_transfer
aspect_ratio: Optional[str] = None # For style
style_preset: Optional[str] = None
seed: Optional[int] = None
output_format: str = "png"
class ControlStudioService:
"""Service layer orchestrating Control Studio operations."""
SUPPORTED_OPERATIONS: Dict[ControlOperationType, Dict[str, Any]] = {
"sketch": {
"label": "Sketch to Image",
"description": "Transform sketches into refined images with precise control.",
"provider": "stability",
"fields": {
"control_image": True,
"style_image": False,
"control_strength": True,
"fidelity": False,
"style_strength": False,
"aspect_ratio": False,
},
},
"structure": {
"label": "Structure Control",
"description": "Generate images maintaining the structure of an input image.",
"provider": "stability",
"fields": {
"control_image": True,
"style_image": False,
"control_strength": True,
"fidelity": False,
"style_strength": False,
"aspect_ratio": False,
},
},
"style": {
"label": "Style Control",
"description": "Generate images using style from a reference image.",
"provider": "stability",
"fields": {
"control_image": True,
"style_image": False,
"control_strength": False,
"fidelity": True,
"style_strength": False,
"aspect_ratio": True,
},
},
"style_transfer": {
"label": "Style Transfer",
"description": "Apply visual characteristics from a style image to a target image.",
"provider": "stability",
"fields": {
"control_image": True, # init_image
"style_image": True,
"control_strength": False,
"fidelity": False,
"style_strength": True,
"aspect_ratio": False,
},
},
}
def __init__(self):
logger.info("[Control Studio] Initialized control service")
@staticmethod
def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
"""Decode a base64 (or data URL) string to bytes."""
if not value:
return None
try:
# Handle data URLs (data:image/png;base64,...)
if value.startswith("data:"):
_, b64data = value.split(",", 1)
else:
b64data = value
return base64.b64decode(b64data)
except Exception as exc:
logger.error(f"[Control Studio] Failed to decode base64 image: {exc}")
raise ValueError("Invalid base64 image payload") from exc
@staticmethod
def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
"""Extract width/height metadata from image bytes."""
with Image.open(io.BytesIO(image_bytes)) as img:
return {
"width": img.width,
"height": img.height,
}
@staticmethod
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
"""Convert raw bytes to base64 data URL."""
b64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:image/{output_format};base64,{b64}"
def list_operations(self) -> Dict[str, Dict[str, Any]]:
"""Expose supported operations for UI rendering."""
return self.SUPPORTED_OPERATIONS
async def process_control(
self,
request: ControlStudioRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Process control request and return normalized response."""
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_control_operations
from fastapi import HTTPException
db = next(get_db())
try:
pricing_service = PricingService(db)
logger.info(f"[Control Studio] 🛂 Running pre-flight validation for user {user_id}")
validate_image_control_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=1,
)
logger.info("[Control Studio] ✅ Pre-flight validation passed")
except HTTPException:
logger.error("[Control Studio] ❌ Pre-flight validation failed")
raise
finally:
db.close()
else:
logger.warning("[Control Studio] ⚠️ No user_id provided - skipping pre-flight validation")
control_image_bytes = self._decode_base64_image(request.control_image_base64)
if not control_image_bytes:
raise ValueError("Control image payload is required")
style_image_bytes = self._decode_base64_image(request.style_image_base64)
operation = request.operation
logger.info("[Control Studio] Processing operation='%s' for user=%s", operation, user_id)
if operation not in self.SUPPORTED_OPERATIONS:
raise ValueError(f"Unsupported control operation: {operation}")
stability_service = StabilityAIService()
async with stability_service:
if operation == "sketch":
result = await stability_service.control_sketch(
image=control_image_bytes,
prompt=request.prompt,
control_strength=request.control_strength or 0.7,
negative_prompt=request.negative_prompt,
seed=request.seed,
output_format=request.output_format,
style_preset=request.style_preset,
)
elif operation == "structure":
result = await stability_service.control_structure(
image=control_image_bytes,
prompt=request.prompt,
control_strength=request.control_strength or 0.7,
negative_prompt=request.negative_prompt,
seed=request.seed,
output_format=request.output_format,
style_preset=request.style_preset,
)
elif operation == "style":
result = await stability_service.control_style(
image=control_image_bytes,
prompt=request.prompt,
negative_prompt=request.negative_prompt,
aspect_ratio=request.aspect_ratio or "1:1",
fidelity=request.fidelity or 0.5,
seed=request.seed,
output_format=request.output_format,
style_preset=request.style_preset,
)
elif operation == "style_transfer":
if not style_image_bytes:
raise ValueError("Style image is required for style transfer")
result = await stability_service.control_style_transfer(
init_image=control_image_bytes,
style_image=style_image_bytes,
prompt=request.prompt or "",
negative_prompt=request.negative_prompt,
style_strength=request.style_strength or 1.0,
composition_fidelity=request.composition_fidelity or 0.9,
change_strength=request.change_strength or 0.9,
seed=request.seed,
output_format=request.output_format,
)
else:
raise ValueError(f"Unsupported control operation: {operation}")
image_bytes = self._extract_image_bytes(result)
metadata = self._image_bytes_to_metadata(image_bytes)
metadata.update(
{
"operation": operation,
"style_preset": request.style_preset,
"provider": self.SUPPORTED_OPERATIONS[operation]["provider"],
}
)
response = {
"success": True,
"operation": operation,
"provider": metadata["provider"],
"image_base64": self._bytes_to_base64(image_bytes, request.output_format),
"width": metadata["width"],
"height": metadata["height"],
"metadata": metadata,
}
logger.info("[Control Studio] ✅ Operation '%s' completed", operation)
return response
@staticmethod
def _extract_image_bytes(result: Any) -> bytes:
"""Normalize Stability responses into raw image bytes."""
if isinstance(result, bytes):
return result
if isinstance(result, dict):
artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
for artifact in artifacts:
if isinstance(artifact, dict):
if artifact.get("base64"):
return base64.b64decode(artifact["base64"])
if artifact.get("b64_json"):
return base64.b64decode(artifact["b64_json"])
raise RuntimeError("Unable to extract image bytes from provider response")

View File

@@ -110,12 +110,12 @@ class EditStudioService:
},
"search_replace": {
"label": "Search & Replace",
"description": "Locate objects via search prompt and replace them.",
"description": "Locate objects via search prompt and replace them. Optional mask for precise control.",
"provider": "stability",
"async": False,
"fields": {
"prompt": True,
"mask": False,
"mask": True, # Optional mask for precise region selection
"negative_prompt": False,
"search_prompt": True,
"select_prompt": False,
@@ -126,12 +126,12 @@ class EditStudioService:
},
"search_recolor": {
"label": "Search & Recolor",
"description": "Select elements via prompt and recolor them.",
"description": "Select elements via prompt and recolor them. Optional mask for exact region selection.",
"provider": "stability",
"async": False,
"fields": {
"prompt": True,
"mask": False,
"mask": True, # Optional mask for precise region selection
"negative_prompt": False,
"search_prompt": False,
"select_prompt": True,
@@ -158,12 +158,12 @@ class EditStudioService:
},
"general_edit": {
"label": "Prompt-based Edit",
"description": "Free-form editing powered by Hugging Face image-to-image models.",
"description": "Free-form editing powered by Hugging Face image-to-image models. Optional mask for selective editing.",
"provider": "huggingface",
"async": False,
"fields": {
"prompt": True,
"mask": False,
"mask": True, # Optional mask for selective region editing
"negative_prompt": True,
"search_prompt": False,
"select_prompt": False,
@@ -346,6 +346,7 @@ class EditStudioService:
image=image_bytes,
prompt=request.prompt,
search_prompt=request.search_prompt,
mask=mask_bytes, # Optional mask for precise region selection
output_format=request.output_format,
)
elif operation == "search_recolor":
@@ -355,6 +356,7 @@ class EditStudioService:
image=image_bytes,
prompt=request.prompt,
select_prompt=request.select_prompt,
mask=mask_bytes, # Optional mask for precise region selection
output_format=request.output_format,
)
elif operation == "relight":
@@ -403,6 +405,7 @@ class EditStudioService:
request.prompt,
options,
user_id,
mask_bytes, # Optional mask for selective editing
)
return result.image_bytes

View File

@@ -0,0 +1,502 @@
"""Social Optimizer service for platform-specific image optimization."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from PIL import Image, ImageDraw, ImageFont
from .templates import Platform
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.social_optimizer")
@dataclass
class SafeZone:
"""Safe zone configuration for text overlay."""
top: float = 0.1 # Percentage from top
bottom: float = 0.1 # Percentage from bottom
left: float = 0.1 # Percentage from left
right: float = 0.1 # Percentage from right
@dataclass
class PlatformFormat:
"""Platform format specification."""
name: str
width: int
height: int
ratio: str
safe_zone: SafeZone
file_type: str = "PNG"
max_size_mb: float = 5.0
# Platform format definitions with safe zones
PLATFORM_FORMATS: Dict[Platform, List[PlatformFormat]] = {
Platform.INSTAGRAM: [
PlatformFormat(
name="Feed Post (Square)",
width=1080,
height=1080,
ratio="1:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Feed Post (Portrait)",
width=1080,
height=1350,
ratio="4:5",
safe_zone=SafeZone(top=0.2, bottom=0.2, left=0.1, right=0.1),
),
PlatformFormat(
name="Story",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Reel",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
],
Platform.FACEBOOK: [
PlatformFormat(
name="Feed Post",
width=1200,
height=630,
ratio="1.91:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Feed Post (Square)",
width=1080,
height=1080,
ratio="1:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Story",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Cover Photo",
width=820,
height=312,
ratio="16:9",
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
),
],
Platform.TWITTER: [
PlatformFormat(
name="Post",
width=1200,
height=675,
ratio="16:9",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Card",
width=1200,
height=600,
ratio="2:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Header",
width=1500,
height=500,
ratio="3:1",
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
),
],
Platform.LINKEDIN: [
PlatformFormat(
name="Feed Post",
width=1200,
height=628,
ratio="1.91:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Feed Post (Square)",
width=1080,
height=1080,
ratio="1:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Article",
width=1200,
height=627,
ratio="2:1",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Company Cover",
width=1128,
height=191,
ratio="4:1",
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
),
],
Platform.YOUTUBE: [
PlatformFormat(
name="Thumbnail",
width=1280,
height=720,
ratio="16:9",
safe_zone=SafeZone(top=0.15, bottom=0.15, left=0.1, right=0.1),
),
PlatformFormat(
name="Channel Art",
width=2560,
height=1440,
ratio="16:9",
safe_zone=SafeZone(top=0.2, bottom=0.1, left=0.15, right=0.15),
),
],
Platform.PINTEREST: [
PlatformFormat(
name="Pin",
width=1000,
height=1500,
ratio="2:3",
safe_zone=SafeZone(top=0.2, bottom=0.2, left=0.1, right=0.1),
),
PlatformFormat(
name="Story Pin",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
],
Platform.TIKTOK: [
PlatformFormat(
name="Video Cover",
width=1080,
height=1920,
ratio="9:16",
safe_zone=SafeZone(top=0.25, bottom=0.15, left=0.1, right=0.1),
),
],
}
@dataclass
class SocialOptimizerRequest:
"""Request payload for social optimization."""
image_base64: str
platforms: List[Platform] # List of platforms to optimize for
format_names: Optional[Dict[Platform, str]] = None # Specific format per platform
show_safe_zones: bool = False # Include safe zone overlay in output
crop_mode: str = "smart" # "smart", "center", "fit"
focal_point: Optional[Dict[str, float]] = None # {"x": 0.5, "y": 0.5} for smart crop
output_format: str = "png"
options: Dict[str, Any] = field(default_factory=dict)
class SocialOptimizerService:
"""Service for optimizing images for social media platforms."""
def __init__(self):
logger.info("[Social Optimizer] Initialized service")
@staticmethod
def _decode_base64_image(value: str) -> bytes:
"""Decode a base64 (or data URL) string to bytes."""
try:
if value.startswith("data:"):
_, b64data = value.split(",", 1)
else:
b64data = value
return base64.b64decode(b64data)
except Exception as exc:
logger.error(f"[Social Optimizer] Failed to decode base64 image: {exc}")
raise ValueError("Invalid base64 image payload") from exc
@staticmethod
def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
"""Convert raw bytes to base64 data URL."""
b64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:image/{output_format};base64,{b64}"
@staticmethod
def _smart_crop(
image: Image.Image,
target_width: int,
target_height: int,
focal_point: Optional[Dict[str, float]] = None,
) -> Image.Image:
"""Smart crop image to target dimensions, preserving important content."""
img_width, img_height = image.size
target_ratio = target_width / target_height
img_ratio = img_width / img_height
# If focal point is provided, use it for cropping
if focal_point:
focal_x = int(focal_point["x"] * img_width)
focal_y = int(focal_point["y"] * img_height)
else:
# Default to center
focal_x = img_width // 2
focal_y = img_height // 2
if img_ratio > target_ratio:
# Image is wider than target - crop width
new_width = int(img_height * target_ratio)
left = max(0, min(focal_x - new_width // 2, img_width - new_width))
right = left + new_width
cropped = image.crop((left, 0, right, img_height))
else:
# Image is taller than target - crop height
new_height = int(img_width / target_ratio)
top = max(0, min(focal_y - new_height // 2, img_height - new_height))
bottom = top + new_height
cropped = image.crop((0, top, img_width, bottom))
# Resize to exact target dimensions
return cropped.resize((target_width, target_height), Image.Resampling.LANCZOS)
@staticmethod
def _fit_image(
image: Image.Image,
target_width: int,
target_height: int,
) -> Image.Image:
"""Fit image to target dimensions while maintaining aspect ratio (adds padding if needed)."""
img_width, img_height = image.size
target_ratio = target_width / target_height
img_ratio = img_width / img_height
if img_ratio > target_ratio:
# Image is wider - fit to height, pad width
new_height = target_height
new_width = int(img_width * (target_height / img_height))
resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Create new image with target size and paste centered
result = Image.new("RGB", (target_width, target_height), (255, 255, 255))
paste_x = (target_width - new_width) // 2
result.paste(resized, (paste_x, 0))
return result
else:
# Image is taller - fit to width, pad height
new_width = target_width
new_height = int(img_height * (target_width / img_width))
resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Create new image with target size and paste centered
result = Image.new("RGB", (target_width, target_height), (255, 255, 255))
paste_y = (target_height - new_height) // 2
result.paste(resized, (0, paste_y))
return result
@staticmethod
def _center_crop(
image: Image.Image,
target_width: int,
target_height: int,
) -> Image.Image:
"""Center crop image to target dimensions."""
img_width, img_height = image.size
target_ratio = target_width / target_height
img_ratio = img_width / img_height
if img_ratio > target_ratio:
# Image is wider - crop width
new_width = int(img_height * target_ratio)
left = (img_width - new_width) // 2
cropped = image.crop((left, 0, left + new_width, img_height))
else:
# Image is taller - crop height
new_height = int(img_width / target_ratio)
top = (img_height - new_height) // 2
cropped = image.crop((0, top, img_width, top + new_height))
return cropped.resize((target_width, target_height), Image.Resampling.LANCZOS)
@staticmethod
def _draw_safe_zone(
image: Image.Image,
safe_zone: SafeZone,
) -> Image.Image:
"""Draw safe zone overlay on image."""
draw = ImageDraw.Draw(image)
width, height = image.size
# Calculate safe zone boundaries
top = int(height * safe_zone.top)
bottom = int(height * (1 - safe_zone.bottom))
left = int(width * safe_zone.left)
right = int(width * (1 - safe_zone.right))
# Draw semi-transparent overlay outside safe zone
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
overlay_draw = ImageDraw.Draw(overlay)
# Top area
overlay_draw.rectangle([(0, 0), (width, top)], fill=(0, 0, 0, 100))
# Bottom area
overlay_draw.rectangle([(0, bottom), (width, height)], fill=(0, 0, 0, 100))
# Left area
overlay_draw.rectangle([(0, top), (left, bottom)], fill=(0, 0, 0, 100))
# Right area
overlay_draw.rectangle([(right, top), (width, bottom)], fill=(0, 0, 0, 100))
# Draw safe zone border
border_color = (255, 255, 0, 200) # Yellow with transparency
overlay_draw.rectangle(
[(left, top), (right, bottom)],
outline=border_color,
width=2,
)
# Composite overlay onto image
if image.mode != "RGBA":
image = image.convert("RGBA")
image = Image.alpha_composite(image, overlay)
return image
def get_platform_formats(self, platform: Platform) -> List[Dict[str, Any]]:
"""Get available formats for a platform."""
formats = PLATFORM_FORMATS.get(platform, [])
return [
{
"name": fmt.name,
"width": fmt.width,
"height": fmt.height,
"ratio": fmt.ratio,
"safe_zone": {
"top": fmt.safe_zone.top,
"bottom": fmt.safe_zone.bottom,
"left": fmt.safe_zone.left,
"right": fmt.safe_zone.right,
},
"file_type": fmt.file_type,
"max_size_mb": fmt.max_size_mb,
}
for fmt in formats
]
def optimize_image(
self,
request: SocialOptimizerRequest,
) -> Dict[str, Any]:
"""Optimize image for specified platforms."""
logger.info(
f"[Social Optimizer] Processing optimization for {len(request.platforms)} platform(s)"
)
# Decode input image
image_bytes = self._decode_base64_image(request.image_base64)
original_image = Image.open(io.BytesIO(image_bytes))
# Convert to RGB if needed
if original_image.mode in ("RGBA", "LA", "P"):
if original_image.mode == "P":
original_image = original_image.convert("RGBA")
background = Image.new("RGB", original_image.size, (255, 255, 255))
if original_image.mode == "RGBA":
background.paste(original_image, mask=original_image.split()[-1])
else:
background.paste(original_image)
original_image = background
elif original_image.mode != "RGB":
original_image = original_image.convert("RGB")
results = []
for platform in request.platforms:
formats = PLATFORM_FORMATS.get(platform, [])
if not formats:
logger.warning(f"[Social Optimizer] No formats found for platform: {platform}")
continue
# Get format (use specified format or default to first)
format_name = None
if request.format_names and platform in request.format_names:
format_name = request.format_names[platform]
platform_format = None
for fmt in formats:
if format_name and fmt.name == format_name:
platform_format = fmt
break
if not platform_format:
platform_format = formats[0] # Default to first format
# Crop/resize image based on mode
if request.crop_mode == "smart":
optimized_image = self._smart_crop(
original_image,
platform_format.width,
platform_format.height,
request.focal_point,
)
elif request.crop_mode == "fit":
optimized_image = self._fit_image(
original_image,
platform_format.width,
platform_format.height,
)
else: # center
optimized_image = self._center_crop(
original_image,
platform_format.width,
platform_format.height,
)
# Add safe zone overlay if requested
if request.show_safe_zones:
optimized_image = self._draw_safe_zone(optimized_image, platform_format.safe_zone)
# Convert to bytes
output_buffer = io.BytesIO()
output_format = request.output_format.lower()
if output_format == "jpg" or output_format == "jpeg":
optimized_image = optimized_image.convert("RGB")
optimized_image.save(output_buffer, format="JPEG", quality=95)
else:
optimized_image.save(output_buffer, format="PNG")
output_bytes = output_buffer.getvalue()
results.append(
{
"platform": platform.value,
"format": platform_format.name,
"width": platform_format.width,
"height": platform_format.height,
"ratio": platform_format.ratio,
"image_base64": self._bytes_to_base64(output_bytes, request.output_format),
"safe_zone": {
"top": platform_format.safe_zone.top,
"bottom": platform_format.safe_zone.bottom,
"left": platform_format.safe_zone.left,
"right": platform_format.safe_zone.right,
},
}
)
logger.info(f"[Social Optimizer] ✅ Generated {len(results)} optimized images")
return {
"success": True,
"results": results,
"total_optimized": len(results),
}

View File

@@ -5,6 +5,8 @@ from typing import Optional, Dict, Any, List
from .create_service import CreateStudioService, CreateStudioRequest
from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .templates import Platform, TemplateCategory, ImageTemplate
from utils.logger_utils import get_service_logger
@@ -20,6 +22,8 @@ class ImageStudioManager:
self.create_service = CreateStudioService()
self.edit_service = EditStudioService()
self.upscale_service = UpscaleStudioService()
self.control_service = ControlStudioService()
self.social_optimizer_service = SocialOptimizerService()
logger.info("[Image Studio Manager] Initialized successfully")
# ====================
@@ -215,6 +219,40 @@ class ImageStudioManager:
"estimated": True,
}
# ====================
# CONTROL STUDIO
# ====================
async def control_image(
self,
request: ControlStudioRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Run Control Studio operations."""
logger.info("[Image Studio] Control request from user: %s", user_id)
return await self.control_service.process_control(request, user_id=user_id)
def get_control_operations(self) -> Dict[str, Any]:
"""Expose control operations for UI."""
return self.control_service.list_operations()
# ====================
# SOCIAL OPTIMIZER
# ====================
async def optimize_for_social(
self,
request: SocialOptimizerRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Optimize image for social media platforms."""
logger.info("[Image Studio] Social optimization request from user: %s", user_id)
return self.social_optimizer_service.optimize_image(request)
def get_social_platform_formats(self, platform: Platform) -> List[Dict[str, Any]]:
"""Get available formats for a social platform."""
return self.social_optimizer_service.get_platform_formats(platform)
# ====================
# PLATFORM SPECS
# ====================

View File

@@ -54,7 +54,8 @@ def edit_image(
input_image_bytes: bytes,
prompt: str,
options: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None
user_id: Optional[str] = None,
mask_bytes: Optional[bytes] = None,
) -> ImageGenerationResult:
"""Edit image with pre-flight validation.
@@ -63,6 +64,7 @@ def edit_image(
prompt: Natural language prompt describing desired edits (e.g., "Turn the cat into a tiger")
options: Image editing options (provider, model, etc.)
user_id: User ID for subscription checking (optional, but required for validation)
mask_bytes: Optional mask image bytes for selective editing (grayscale, white=edit, black=preserve)
Returns:
ImageGenerationResult with edited image bytes and metadata
@@ -72,6 +74,8 @@ def edit_image(
- Describe what should change and what should remain
- Examples: "Turn the cat into a tiger", "Change background to forest",
"Make it look like a watercolor painting"
Note: Mask support depends on the specific model. Some models may ignore the mask parameter.
"""
# PRE-FLIGHT VALIDATION: Validate image editing before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
@@ -128,14 +132,33 @@ def edit_image(
width = input_image.width
height = input_image.height
# Convert mask bytes to PIL Image if provided
mask_image = None
if mask_bytes:
try:
mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L") # Convert to grayscale
# Ensure mask dimensions match input image
if mask_image.size != input_image.size:
logger.warning(f"[Image Editing] Mask size {mask_image.size} doesn't match image size {input_image.size}, resizing mask")
mask_image = mask_image.resize(input_image.size, Image.Resampling.LANCZOS)
except Exception as e:
logger.warning(f"[Image Editing] Failed to process mask image: {e}, continuing without mask")
mask_image = None
# Use image_to_image method from Hugging Face InferenceClient
# This follows the pattern from the Hugging Face documentation
# Docs: https://huggingface.co/docs/inference-providers/en/guides/image-editor
# Note: Mask support depends on the model - some models may ignore it
call_params = params.copy()
if mask_image:
call_params["mask_image"] = mask_image
logger.info("[Image Editing] Using mask for selective editing")
edited_image: Image.Image = client.image_to_image(
image=input_image,
prompt=prompt.strip(),
model=model,
**params,
**call_params,
)
# Convert edited image back to bytes

View File

@@ -397,6 +397,7 @@ class StabilityAIService:
image: Union[UploadFile, bytes],
prompt: str,
search_prompt: str,
mask: Optional[Union[UploadFile, bytes]] = None,
**kwargs
) -> Union[bytes, Dict[str, Any]]:
"""Replace objects in image using search prompt.
@@ -405,6 +406,7 @@ class StabilityAIService:
image: Input image
prompt: Text prompt for replacement
search_prompt: What to search for
mask: Optional mask image for precise region selection
**kwargs: Additional parameters
Returns:
@@ -414,6 +416,8 @@ class StabilityAIService:
data.update({k: v for k, v in kwargs.items() if v is not None})
files = {"image": await self._prepare_image_file(image)}
if mask:
files["mask"] = await self._prepare_image_file(mask)
return await self._make_request(
method="POST",
@@ -427,6 +431,7 @@ class StabilityAIService:
image: Union[UploadFile, bytes],
prompt: str,
select_prompt: str,
mask: Optional[Union[UploadFile, bytes]] = None,
**kwargs
) -> Union[bytes, Dict[str, Any]]:
"""Recolor objects in image using select prompt.
@@ -435,6 +440,7 @@ class StabilityAIService:
image: Input image
prompt: Text prompt for recoloring
select_prompt: What to select for recoloring
mask: Optional mask image for precise region selection
**kwargs: Additional parameters
Returns:
@@ -444,6 +450,8 @@ class StabilityAIService:
data.update({k: v for k, v in kwargs.items() if v is not None})
files = {"image": await self._prepare_image_file(image)}
if mask:
files["mask"] = await self._prepare_image_file(mask)
return await self._make_request(
method="POST",

View File

@@ -415,6 +415,75 @@ def validate_image_editing_operations(
)
def validate_image_control_operations(
pricing_service: PricingService,
user_id: str,
num_images: int = 1
) -> None:
"""
Validate image control operations (sketch-to-image, structure control, style transfer) before making API calls.
Control operations use Stability AI for image generation with control inputs, so they use
the same validation as image generation operations.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
num_images: Number of images to generate (for multiple variations)
Returns:
None - raises HTTPException with 429 status if validation fails
"""
try:
# Control operations use Stability AI, same as image generation
operations_to_validate = [
{
'provider': APIProvider.STABILITY,
'tokens_requested': 0,
'actual_provider_name': 'stability',
'operation_type': 'image_generation' # Control ops use image generation limits
}
for _ in range(num_images)
]
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image control operation(s) for user {user_id}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Image control blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Image control validated for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating image control: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate image control: {str(e)}",
'message': f"Failed to validate image control: {str(e)}"
}
)
def validate_video_generation_operations(
pricing_service: PricingService,
user_id: str

View File

@@ -0,0 +1,158 @@
"""
Asset Tracker Utility
Helper utility for modules to easily save generated content to the unified asset library.
"""
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session
from services.content_asset_service import ContentAssetService
from models.content_asset_models import AssetType, AssetSource
import logging
import re
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
# Maximum file size (100MB)
MAX_FILE_SIZE = 100 * 1024 * 1024
# Allowed URL schemes
ALLOWED_URL_SCHEMES = ['http', 'https', '/'] # Allow relative paths starting with /
def validate_file_url(file_url: str) -> bool:
"""Validate file URL format."""
if not file_url or not isinstance(file_url, str):
return False
# Allow relative paths
if file_url.startswith('/'):
return True
# Validate absolute URLs
try:
parsed = urlparse(file_url)
return parsed.scheme in ALLOWED_URL_SCHEMES and parsed.netloc
except Exception:
return False
def save_asset_to_library(
db: Session,
user_id: str,
asset_type: str,
source_module: str,
filename: str,
file_url: str,
file_path: Optional[str] = None,
file_size: Optional[int] = None,
mime_type: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
prompt: Optional[str] = None,
tags: Optional[list] = None,
metadata: Optional[Dict[str, Any]] = None,
provider: Optional[str] = None,
model: Optional[str] = None,
cost: Optional[float] = None,
generation_time: Optional[float] = None,
) -> Optional[int]:
"""
Helper function to save a generated asset to the unified asset library.
This can be called from any module (story writer, image studio, etc.)
to automatically track generated content.
Args:
db: Database session
user_id: Clerk user ID
asset_type: 'text', 'image', 'video', or 'audio'
source_module: 'story_writer', 'image_studio', 'main_text_generation', etc.
filename: Original filename
file_url: Public URL to access the asset
file_path: Server file path (optional)
file_size: File size in bytes (optional)
mime_type: MIME type (optional)
title: Asset title (optional)
description: Asset description (optional)
prompt: Generation prompt (optional)
tags: List of tags (optional)
metadata: Additional metadata (optional)
provider: AI provider used (optional)
model: Model used (optional)
cost: Generation cost (optional)
generation_time: Generation time in seconds (optional)
Returns:
Asset ID if successful, None otherwise
"""
try:
# Validate inputs
if not user_id or not isinstance(user_id, str):
logger.error("Invalid user_id provided")
return None
if not filename or not isinstance(filename, str):
logger.error("Invalid filename provided")
return None
if not validate_file_url(file_url):
logger.error(f"Invalid file_url format: {file_url}")
return None
if file_size and file_size > MAX_FILE_SIZE:
logger.warning(f"File size {file_size} exceeds maximum {MAX_FILE_SIZE}")
# Don't fail, just log warning
# Convert string enums to enum types
try:
asset_type_enum = AssetType(asset_type.lower())
except ValueError:
logger.warning(f"Invalid asset type: {asset_type}, defaulting to 'text'")
asset_type_enum = AssetType.TEXT
try:
source_module_enum = AssetSource(source_module.lower())
except ValueError:
logger.warning(f"Invalid source module: {source_module}, defaulting to 'story_writer'")
source_module_enum = AssetSource.STORY_WRITER
# Sanitize filename (remove path traversal attempts)
filename = re.sub(r'[^\w\s\-_\.]', '', filename.split('/')[-1])
if not filename:
filename = f"asset_{asset_type}_{source_module}.{asset_type}"
# Generate title from filename if not provided
if not title:
title = filename.replace('_', ' ').replace('-', ' ').title()
# Limit title length
if len(title) > 200:
title = title[:197] + '...'
service = ContentAssetService(db)
asset = service.create_asset(
user_id=user_id,
asset_type=asset_type_enum,
source_module=source_module_enum,
filename=filename,
file_url=file_url,
file_path=file_path,
file_size=file_size,
mime_type=mime_type,
title=title,
description=description,
prompt=prompt,
tags=tags or [],
metadata=metadata or {},
provider=provider,
model=model,
cost=cost,
generation_time=generation_time,
)
logger.info(f"✅ Asset saved to library: {asset.id} ({asset_type} from {source_module})")
return asset.id
except Exception as e:
logger.error(f"❌ Error saving asset to library: {str(e)}", exc_info=True)
return None