AI Image Studio, AI podcast Maker, AI product Marketing

This commit is contained in:
ajaysi
2025-11-28 14:33:52 +05:30
parent 77d7c0cde6
commit 49e2131715
122 changed files with 22311 additions and 4331 deletions

View File

@@ -10,6 +10,9 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from loguru import logger
from middleware.auth_middleware import get_current_user
from sqlalchemy.orm import Session
from services.database import get_db as get_db_dependency
from utils.text_asset_tracker import save_and_track_text_content
from models.blog_models import (
BlogResearchRequest,
@@ -41,6 +44,10 @@ router = APIRouter(prefix="/api/blog", tags=["AI Blog Writer"])
service = BlogWriterService()
recommendation_applier = BlogSEORecommendationApplier()
# Use the proper database dependency from services.database
get_db = get_db_dependency
# ---------------------------
# SEO Recommendation Endpoints
# ---------------------------
@@ -272,10 +279,41 @@ async def rebalance_outline(outline_data: Dict[str, Any], target_words: int = 15
# Content Generation Endpoints
@router.post("/section/generate", response_model=BlogSectionResponse)
async def generate_section(request: BlogSectionRequest) -> BlogSectionResponse:
async def generate_section(
request: BlogSectionRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> BlogSectionResponse:
"""Generate content for a specific section."""
try:
return await service.generate_section(request)
response = await service.generate_section(request)
# Save and track text content (non-blocking)
if response.markdown:
try:
user_id = str(current_user.get('id', '')) if current_user else None
if user_id:
section_heading = getattr(request, 'section_heading', getattr(request, 'heading', 'Section'))
save_and_track_text_content(
db=db,
user_id=user_id,
content=response.markdown,
source_module="blog_writer",
title=f"Blog Section: {section_heading[:60]}",
description=f"Blog section content",
prompt=f"Section: {section_heading}\nKeywords: {getattr(request, 'keywords', [])}",
tags=["blog", "section", "content"],
asset_metadata={
"section_id": getattr(request, 'section_id', None),
"word_count": len(response.markdown.split()),
},
subdirectory="sections",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track blog section asset: {track_error}")
return response
except Exception as e:
logger.error(f"Failed to generate section: {e}")
raise HTTPException(status_code=500, detail=str(e))
@@ -321,13 +359,48 @@ async def start_content_generation(
@router.get("/content/status/{task_id}")
async def content_generation_status(task_id: str) -> Dict[str, Any]:
async def content_generation_status(
task_id: str,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Poll status for content generation task."""
try:
status = await task_manager.get_task_status(task_id)
if status is None:
raise HTTPException(status_code=404, detail="Task not found")
# Track blog content when task completes (non-blocking)
if status.get('status') == 'completed' and status.get('result'):
try:
result = status.get('result', {})
if result.get('sections') and len(result.get('sections', [])) > 0:
user_id = str(current_user.get('id', '')) if current_user else None
if user_id:
# Combine all sections into full blog content
blog_content = f"# {result.get('title', 'Untitled Blog')}\n\n"
for section in result.get('sections', []):
blog_content += f"\n## {section.get('heading', 'Section')}\n\n{section.get('content', '')}\n\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=blog_content,
source_module="blog_writer",
title=f"Blog: {result.get('title', 'Untitled Blog')[:60]}",
description=f"Complete blog post with {len(result.get('sections', []))} sections",
prompt=f"Title: {result.get('title', 'Untitled')}\nSections: {len(result.get('sections', []))}",
tags=["blog", "complete", "content"],
asset_metadata={
"section_count": len(result.get('sections', [])),
"model": result.get('model'),
},
subdirectory="complete",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track blog content asset: {track_error}")
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
error_data = status.get('error_data', {}) or {}
@@ -420,10 +493,40 @@ async def analyze_flow_advanced(request: Dict[str, Any]) -> Dict[str, Any]:
@router.post("/section/optimize", response_model=BlogOptimizeResponse)
async def optimize_section(request: BlogOptimizeRequest) -> BlogOptimizeResponse:
async def optimize_section(
request: BlogOptimizeRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> BlogOptimizeResponse:
"""Optimize a specific section for better quality and engagement."""
try:
return await service.optimize_section(request)
response = await service.optimize_section(request)
# Save and track text content (non-blocking)
if response.optimized:
try:
user_id = str(current_user.get('id', '')) if current_user else None
if user_id:
save_and_track_text_content(
db=db,
user_id=user_id,
content=response.optimized,
source_module="blog_writer",
title=f"Optimized Blog Section",
description=f"Optimized blog section content",
prompt=f"Original Content: {request.content[:200]}\nGoals: {request.goals}",
tags=["blog", "section", "optimized"],
asset_metadata={
"optimization_goals": request.goals,
"word_count": len(response.optimized.split()),
},
subdirectory="sections/optimized",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track optimized blog section asset: {track_error}")
return response
except Exception as e:
logger.error(f"Failed to optimize section: {e}")
raise HTTPException(status_code=500, detail=str(e))
@@ -591,13 +694,49 @@ async def start_medium_generation(
@router.get("/generate/medium/status/{task_id}")
async def medium_generation_status(task_id: str):
async def medium_generation_status(
task_id: str,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Poll status for medium blog generation task."""
try:
status = await task_manager.get_task_status(task_id)
if status is None:
raise HTTPException(status_code=404, detail="Task not found")
# Track blog content when task completes (non-blocking)
if status.get('status') == 'completed' and status.get('result'):
try:
result = status.get('result', {})
if result.get('sections') and len(result.get('sections', [])) > 0:
user_id = str(current_user.get('id', '')) if current_user else None
if user_id:
# Combine all sections into full blog content
blog_content = f"# {result.get('title', 'Untitled Blog')}\n\n"
for section in result.get('sections', []):
blog_content += f"\n## {section.get('heading', 'Section')}\n\n{section.get('content', '')}\n\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=blog_content,
source_module="blog_writer",
title=f"Medium Blog: {result.get('title', 'Untitled Blog')[:60]}",
description=f"Medium-length blog post with {len(result.get('sections', []))} sections",
prompt=f"Title: {result.get('title', 'Untitled')}\nSections: {len(result.get('sections', []))}",
tags=["blog", "medium", "complete"],
asset_metadata={
"section_count": len(result.get('sections', [])),
"model": result.get('model'),
"generation_time_ms": result.get('generation_time_ms'),
},
subdirectory="medium",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track medium blog asset: {track_error}")
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
error_data = status.get('error_data', {}) or {}
@@ -677,7 +816,8 @@ async def rewrite_status(task_id: str):
@router.post("/titles/generate-seo")
async def generate_seo_titles(
request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user)
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Generate 5 SEO-optimized blog titles using research and outline data."""
try:
@@ -722,6 +862,30 @@ async def generate_seo_titles(
user_id=user_id
)
# Save and track titles (non-blocking)
if titles and len(titles) > 0:
try:
titles_content = "# SEO Blog Titles\n\n" + "\n".join([f"{i+1}. {title}" for i, title in enumerate(titles)])
save_and_track_text_content(
db=db,
user_id=user_id,
content=titles_content,
source_module="blog_writer",
title=f"SEO Blog Titles: {primary_keywords[0] if primary_keywords else 'Blog'}",
description=f"SEO-optimized blog title suggestions",
prompt=f"Primary Keywords: {primary_keywords}\nSearch Intent: {search_intent}\nWord Count: {word_count}",
tags=["blog", "titles", "seo"],
asset_metadata={
"title_count": len(titles),
"primary_keywords": primary_keywords,
"search_intent": search_intent,
},
subdirectory="titles",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track SEO titles asset: {track_error}")
return {
"success": True,
"titles": titles
@@ -736,7 +900,8 @@ async def generate_seo_titles(
@router.post("/introductions/generate")
async def generate_introductions(
request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user)
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Generate 3 varied blog introductions using research, outline, and content."""
try:
@@ -781,6 +946,33 @@ async def generate_introductions(
user_id=user_id
)
# Save and track introductions (non-blocking)
if introductions and len(introductions) > 0:
try:
intro_content = f"# Blog Introductions for: {blog_title}\n\n"
for i, intro in enumerate(introductions, 1):
intro_content += f"## Introduction {i}\n\n{intro}\n\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=intro_content,
source_module="blog_writer",
title=f"Blog Introductions: {blog_title[:60]}",
description=f"Blog introduction variations",
prompt=f"Blog Title: {blog_title}\nPrimary Keywords: {primary_keywords}\nSearch Intent: {search_intent}",
tags=["blog", "introductions"],
asset_metadata={
"introduction_count": len(introductions),
"blog_title": blog_title,
"search_intent": search_intent,
},
subdirectory="introductions",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track blog introductions asset: {track_error}")
return {
"success": True,
"introductions": introductions

View File

@@ -21,6 +21,7 @@ from models.blog_models import (
)
from services.blog_writer.blog_service import BlogWriterService
from services.blog_writer.database_task_manager import DatabaseTaskManager
from utils.text_asset_tracker import save_and_track_text_content
class TaskManager:
@@ -281,6 +282,9 @@ class TaskManager:
self.task_storage[task_id]["status"] = "completed"
self.task_storage[task_id]["result"] = result.dict()
await self.update_progress(task_id, f"✅ Generated {len(result.sections)} sections successfully.")
# Note: Blog content tracking is handled in the status endpoint
# to ensure we have proper database session and user context
except HTTPException as http_error:
# Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend

View File

@@ -32,7 +32,7 @@ class AssetResponse(BaseModel):
description: Optional[str] = None
prompt: Optional[str] = None
tags: List[str] = []
metadata: Dict[str, Any] = {}
asset_metadata: Dict[str, Any] = {}
provider: Optional[str] = None
model: Optional[str] = None
cost: float = 0.0

View File

@@ -1,11 +1,15 @@
"""FastAPI router for Facebook Writer endpoints."""
from fastapi import APIRouter, HTTPException, Depends
from typing import Dict, Any
from typing import Dict, Any, Optional
import logging
from sqlalchemy.orm import Session
from ..models import *
from ..services import *
from middleware.auth_middleware import get_current_user
from services.database import get_db as get_db_dependency
from utils.text_asset_tracker import save_and_track_text_content
# Configure logging
logger = logging.getLogger(__name__)
@@ -115,9 +119,17 @@ async def get_available_tools():
return {"tools": tools, "total_count": len(tools)}
# Use the proper database dependency from services.database
get_db = get_db_dependency
# Content Creation Endpoints
@router.post("/post/generate", response_model=FacebookPostResponse)
async def generate_facebook_post(request: FacebookPostRequest):
async def generate_facebook_post(
request: FacebookPostRequest,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Generate a Facebook post with engagement optimization."""
try:
logger.info(f"Generating Facebook post for business: {request.business_type}")
@@ -126,6 +138,37 @@ async def generate_facebook_post(request: FacebookPostRequest):
if not response.success:
raise HTTPException(status_code=400, detail=response.error)
# Save and track text content (non-blocking)
if response.content:
try:
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if user_id:
text_content = response.content
if response.analytics:
text_content += f"\n\n## Analytics\nExpected Reach: {response.analytics.expected_reach}\nExpected Engagement: {response.analytics.expected_engagement}\nBest Time to Post: {response.analytics.best_time_to_post}"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="facebook_writer",
title=f"Facebook Post: {request.business_type[:60]}",
description=f"Facebook post for {request.business_type}",
prompt=f"Business Type: {request.business_type}\nTarget Audience: {request.target_audience}\nGoal: {request.post_goal.value if hasattr(request.post_goal, 'value') else request.post_goal}\nTone: {request.post_tone.value if hasattr(request.post_tone, 'value') else request.post_tone}",
tags=["facebook", "post", request.business_type.lower().replace(' ', '_')],
asset_metadata={
"post_goal": request.post_goal.value if hasattr(request.post_goal, 'value') else str(request.post_goal),
"post_tone": request.post_tone.value if hasattr(request.post_tone, 'value') else str(request.post_tone),
"media_type": request.media_type.value if hasattr(request.media_type, 'value') else str(request.media_type)
},
subdirectory="posts"
)
except Exception as track_error:
logger.warning(f"Failed to track Facebook post asset: {track_error}")
return response
except Exception as e:
@@ -134,7 +177,11 @@ async def generate_facebook_post(request: FacebookPostRequest):
@router.post("/story/generate", response_model=FacebookStoryResponse)
async def generate_facebook_story(request: FacebookStoryRequest):
async def generate_facebook_story(
request: FacebookStoryRequest,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Generate a Facebook story with visual suggestions."""
try:
logger.info(f"Generating Facebook story for business: {request.business_type}")
@@ -143,6 +190,31 @@ async def generate_facebook_story(request: FacebookStoryRequest):
if not response.success:
raise HTTPException(status_code=400, detail=response.error)
# Save and track text content (non-blocking)
if response.content:
try:
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if user_id:
save_and_track_text_content(
db=db,
user_id=user_id,
content=response.content,
source_module="facebook_writer",
title=f"Facebook Story: {request.business_type[:60]}",
description=f"Facebook story for {request.business_type}",
prompt=f"Business Type: {request.business_type}\nStory Type: {request.story_type.value if hasattr(request.story_type, 'value') else request.story_type}",
tags=["facebook", "story", request.business_type.lower().replace(' ', '_')],
asset_metadata={
"story_type": request.story_type.value if hasattr(request.story_type, 'value') else str(request.story_type)
},
subdirectory="stories"
)
except Exception as track_error:
logger.warning(f"Failed to track Facebook story asset: {track_error}")
return response
except Exception as e:
@@ -151,7 +223,11 @@ async def generate_facebook_story(request: FacebookStoryRequest):
@router.post("/reel/generate", response_model=FacebookReelResponse)
async def generate_facebook_reel(request: FacebookReelRequest):
async def generate_facebook_reel(
request: FacebookReelRequest,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Generate a Facebook reel script with music suggestions."""
try:
logger.info(f"Generating Facebook reel for business: {request.business_type}")
@@ -160,6 +236,42 @@ async def generate_facebook_reel(request: FacebookReelRequest):
if not response.success:
raise HTTPException(status_code=400, detail=response.error)
# Save and track text content (non-blocking)
if response.script:
try:
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if user_id:
text_content = f"# Facebook Reel Script\n\n## Script\n{response.script}\n"
if response.scene_breakdown:
text_content += f"\n## Scene Breakdown\n" + "\n".join([f"{i+1}. {scene}" for i, scene in enumerate(response.scene_breakdown)]) + "\n"
if response.music_suggestions:
text_content += f"\n## Music Suggestions\n" + "\n".join(response.music_suggestions) + "\n"
if response.hashtag_suggestions:
text_content += f"\n## Hashtag Suggestions\n" + " ".join([f"#{tag}" for tag in response.hashtag_suggestions]) + "\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="facebook_writer",
title=f"Facebook Reel: {request.topic[:60]}",
description=f"Facebook reel script for {request.business_type}",
prompt=f"Business Type: {request.business_type}\nTopic: {request.topic}\nReel Type: {request.reel_type.value if hasattr(request.reel_type, 'value') else request.reel_type}\nLength: {request.reel_length.value if hasattr(request.reel_length, 'value') else request.reel_length}",
tags=["facebook", "reel", request.business_type.lower().replace(' ', '_')],
asset_metadata={
"reel_type": request.reel_type.value if hasattr(request.reel_type, 'value') else str(request.reel_type),
"reel_length": request.reel_length.value if hasattr(request.reel_length, 'value') else str(request.reel_length),
"reel_style": request.reel_style.value if hasattr(request.reel_style, 'value') else str(request.reel_style)
},
subdirectory="reels",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track Facebook reel asset: {track_error}")
return response
except Exception as e:
@@ -168,7 +280,11 @@ async def generate_facebook_reel(request: FacebookReelRequest):
@router.post("/carousel/generate", response_model=FacebookCarouselResponse)
async def generate_facebook_carousel(request: FacebookCarouselRequest):
async def generate_facebook_carousel(
request: FacebookCarouselRequest,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Generate a Facebook carousel post with multiple slides."""
try:
logger.info(f"Generating Facebook carousel for business: {request.business_type}")
@@ -177,6 +293,44 @@ async def generate_facebook_carousel(request: FacebookCarouselRequest):
if not response.success:
raise HTTPException(status_code=400, detail=response.error)
# Save and track text content (non-blocking)
if response.main_caption and response.slides:
try:
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if user_id:
text_content = f"# Facebook Carousel\n\n## Main Caption\n{response.main_caption}\n\n"
text_content += "## Slides\n"
for i, slide in enumerate(response.slides, 1):
text_content += f"\n### Slide {i}: {slide.title}\n{slide.content}\n"
if slide.image_description:
text_content += f"Image Description: {slide.image_description}\n"
if response.hashtag_suggestions:
text_content += f"\n## Hashtag Suggestions\n" + " ".join([f"#{tag}" for tag in response.hashtag_suggestions]) + "\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="facebook_writer",
title=f"Facebook Carousel: {request.topic[:60]}",
description=f"Facebook carousel for {request.business_type}",
prompt=f"Business Type: {request.business_type}\nTopic: {request.topic}\nCarousel Type: {request.carousel_type.value if hasattr(request.carousel_type, 'value') else request.carousel_type}\nSlides: {request.num_slides}",
tags=["facebook", "carousel", request.business_type.lower().replace(' ', '_')],
asset_metadata={
"carousel_type": request.carousel_type.value if hasattr(request.carousel_type, 'value') else str(request.carousel_type),
"num_slides": request.num_slides,
"has_cta": request.include_cta
},
subdirectory="carousels",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track Facebook carousel asset: {track_error}")
return response
except Exception as e:
@@ -186,7 +340,11 @@ async def generate_facebook_carousel(request: FacebookCarouselRequest):
# Business Tools Endpoints
@router.post("/event/generate", response_model=FacebookEventResponse)
async def generate_facebook_event(request: FacebookEventRequest):
async def generate_facebook_event(
request: FacebookEventRequest,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Generate a Facebook event description."""
try:
logger.info(f"Generating Facebook event: {request.event_name}")
@@ -195,6 +353,36 @@ async def generate_facebook_event(request: FacebookEventRequest):
if not response.success:
raise HTTPException(status_code=400, detail=response.error)
# Save and track text content (non-blocking)
if response.description:
try:
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if user_id:
text_content = f"# Facebook Event: {request.event_name}\n\n## Description\n{response.description}\n"
if hasattr(response, 'details') and response.details:
text_content += f"\n## Details\n{response.details}\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="facebook_writer",
title=f"Facebook Event: {request.event_name[:60]}",
description=f"Facebook event description for {request.event_name}",
prompt=f"Event Name: {request.event_name}\nEvent Type: {getattr(request, 'event_type', 'N/A')}\nDate: {getattr(request, 'event_date', 'N/A')}",
tags=["facebook", "event", request.event_name.lower().replace(' ', '_')[:20]],
asset_metadata={
"event_name": request.event_name,
"event_type": getattr(request, 'event_type', None)
},
subdirectory="events"
)
except Exception as track_error:
logger.warning(f"Failed to track Facebook event asset: {track_error}")
return response
except Exception as e:
@@ -203,7 +391,11 @@ async def generate_facebook_event(request: FacebookEventRequest):
@router.post("/group-post/generate", response_model=FacebookGroupPostResponse)
async def generate_facebook_group_post(request: FacebookGroupPostRequest):
async def generate_facebook_group_post(
request: FacebookGroupPostRequest,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Generate a Facebook group post following community guidelines."""
try:
logger.info(f"Generating Facebook group post for: {request.group_name}")
@@ -212,6 +404,32 @@ async def generate_facebook_group_post(request: FacebookGroupPostRequest):
if not response.success:
raise HTTPException(status_code=400, detail=response.error)
# Save and track text content (non-blocking)
if response.content:
try:
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if user_id:
save_and_track_text_content(
db=db,
user_id=user_id,
content=response.content,
source_module="facebook_writer",
title=f"Facebook Group Post: {request.group_name[:60]}",
description=f"Facebook group post for {request.group_name}",
prompt=f"Group Name: {request.group_name}\nTopic: {getattr(request, 'topic', 'N/A')}",
tags=["facebook", "group_post", request.group_name.lower().replace(' ', '_')[:20]],
asset_metadata={
"group_name": request.group_name,
"group_type": getattr(request, 'group_type', None)
},
subdirectory="group_posts"
)
except Exception as track_error:
logger.warning(f"Failed to track Facebook group post asset: {track_error}")
return response
except Exception as e:
@@ -220,7 +438,11 @@ async def generate_facebook_group_post(request: FacebookGroupPostRequest):
@router.post("/page-about/generate", response_model=FacebookPageAboutResponse)
async def generate_facebook_page_about(request: FacebookPageAboutRequest):
async def generate_facebook_page_about(
request: FacebookPageAboutRequest,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Generate a Facebook page about section."""
try:
logger.info(f"Generating Facebook page about for: {request.business_name}")
@@ -229,6 +451,32 @@ async def generate_facebook_page_about(request: FacebookPageAboutRequest):
if not response.success:
raise HTTPException(status_code=400, detail=response.error)
# Save and track text content (non-blocking)
if response.about_section:
try:
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if user_id:
save_and_track_text_content(
db=db,
user_id=user_id,
content=response.about_section,
source_module="facebook_writer",
title=f"Facebook Page About: {request.business_name[:60]}",
description=f"Facebook page about section for {request.business_name}",
prompt=f"Business Name: {request.business_name}\nBusiness Type: {getattr(request, 'business_type', 'N/A')}",
tags=["facebook", "page_about", request.business_name.lower().replace(' ', '_')[:20]],
asset_metadata={
"business_name": request.business_name,
"business_type": getattr(request, 'business_type', None)
},
subdirectory="page_about"
)
except Exception as track_error:
logger.warning(f"Failed to track Facebook page about asset: {track_error}")
return response
except Exception as e:
@@ -238,7 +486,11 @@ async def generate_facebook_page_about(request: FacebookPageAboutRequest):
# Marketing Tools Endpoints
@router.post("/ad-copy/generate", response_model=FacebookAdCopyResponse)
async def generate_facebook_ad_copy(request: FacebookAdCopyRequest):
async def generate_facebook_ad_copy(
request: FacebookAdCopyRequest,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Generate Facebook ad copy with targeting suggestions."""
try:
logger.info(f"Generating Facebook ad copy for: {request.business_type}")
@@ -247,6 +499,41 @@ async def generate_facebook_ad_copy(request: FacebookAdCopyRequest):
if not response.success:
raise HTTPException(status_code=400, detail=response.error)
# Save and track text content (non-blocking)
if response.ad_copy:
try:
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if user_id:
text_content = f"# Facebook Ad Copy\n\n## Ad Copy\n{response.ad_copy}\n"
if hasattr(response, 'headline') and response.headline:
text_content += f"\n## Headline\n{response.headline}\n"
if hasattr(response, 'description') and response.description:
text_content += f"\n## Description\n{response.description}\n"
if hasattr(response, 'targeting_suggestions') and response.targeting_suggestions:
text_content += f"\n## Targeting Suggestions\n" + "\n".join(response.targeting_suggestions) + "\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="facebook_writer",
title=f"Facebook Ad Copy: {request.business_type[:60]}",
description=f"Facebook ad copy for {request.business_type}",
prompt=f"Business Type: {request.business_type}\nAd Objective: {getattr(request, 'ad_objective', 'N/A')}\nTarget Audience: {getattr(request, 'target_audience', 'N/A')}",
tags=["facebook", "ad_copy", request.business_type.lower().replace(' ', '_')],
asset_metadata={
"ad_objective": getattr(request, 'ad_objective', None),
"budget": getattr(request, 'budget', None)
},
subdirectory="ad_copy",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track Facebook ad copy asset: {track_error}")
return response
except Exception as e:

View File

@@ -2,10 +2,14 @@ from __future__ import annotations
import base64
import os
import uuid
from typing import Optional, Dict, Any
from datetime import datetime
from pathlib import Path
from sqlalchemy.orm import Session
from fastapi import APIRouter, HTTPException, Depends
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from services.llm_providers.main_image_generation import generate_image
@@ -16,6 +20,8 @@ from middleware.auth_middleware import get_current_user
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from models.subscription_models import APIProvider, UsageSummary
from utils.asset_tracker import save_asset_to_library
from utils.file_storage import save_file_safely, generate_unique_filename, sanitize_filename
router = APIRouter(prefix="/api/images", tags=["images"])
@@ -37,6 +43,7 @@ class ImageGenerateRequest(BaseModel):
class ImageGenerateResponse(BaseModel):
success: bool = True
image_base64: str
image_url: Optional[str] = None # URL to saved image file
width: int
height: int
provider: str
@@ -47,7 +54,8 @@ class ImageGenerateResponse(BaseModel):
@router.post("/generate", response_model=ImageGenerateResponse)
def generate(
req: ImageGenerateRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> ImageGenerateResponse:
"""Generate image with subscription checking."""
try:
@@ -80,6 +88,78 @@ def generate(
)
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
# Save image to disk and track in asset library
image_url = None
image_filename = None
image_path = None
try:
# Create output directory for image studio images
base_dir = Path(__file__).parent.parent
output_dir = base_dir / "image_studio_images"
# Generate safe filename from prompt
clean_prompt = sanitize_filename(req.prompt[:50], max_length=50)
image_filename = generate_unique_filename(
prefix=f"img_{clean_prompt}",
extension=".png",
include_uuid=True
)
# Save file safely
image_path, save_error = save_file_safely(
content=result.image_bytes,
directory=output_dir,
filename=image_filename,
max_file_size=50 * 1024 * 1024 # 50MB for images
)
if image_path and not save_error:
# Generate file URL (will be served via API endpoint)
image_url = f"/api/images/image-studio/images/{image_path.name}"
logger.info(f"[images.generate] Saved image to: {image_path} ({len(result.image_bytes)} bytes)")
# Save to asset library (non-blocking)
try:
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
asset_type="image",
source_module="image_studio",
filename=image_path.name,
file_url=image_url,
file_path=str(image_path),
file_size=len(result.image_bytes),
mime_type="image/png",
title=req.prompt[:100] if len(req.prompt) <= 100 else req.prompt[:97] + "...",
description=f"Generated image: {req.prompt[:200]}" if len(req.prompt) > 200 else req.prompt,
prompt=req.prompt,
tags=["image_studio", "generated", result.provider] if result.provider else ["image_studio", "generated"],
provider=result.provider,
model=result.model,
asset_metadata={
"width": result.width,
"height": result.height,
"seed": result.seed,
"status": "completed",
"negative_prompt": req.negative_prompt
}
)
if asset_id:
logger.info(f"[images.generate] ✅ Asset saved to library: ID={asset_id}, filename={image_path.name}")
else:
logger.warning(f"[images.generate] Asset tracking returned None (may have failed silently)")
except Exception as asset_error:
logger.error(f"[images.generate] Failed to save asset to library: {asset_error}", exc_info=True)
# Don't fail the request if asset tracking fails
else:
logger.warning(f"[images.generate] Failed to save image to disk: {save_error}")
# Continue without failing the request - base64 is still available
except Exception as save_error:
logger.error(f"[images.generate] Unexpected error saving image: {save_error}", exc_info=True)
# Continue without failing the request
# TRACK USAGE after successful image generation
if result:
logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
@@ -168,6 +248,7 @@ def generate(
return ImageGenerateResponse(
image_base64=image_b64,
image_url=image_url,
width=result.width,
height=result.height,
provider=result.provider,
@@ -226,6 +307,7 @@ class ImageEditRequest(BaseModel):
class ImageEditResponse(BaseModel):
success: bool = True
image_base64: str
image_url: Optional[str] = None # URL to saved edited image file
width: int
height: int
provider: str
@@ -358,7 +440,8 @@ def suggest_prompts(
@router.post("/edit", response_model=ImageEditResponse)
def edit(
req: ImageEditRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> ImageEditResponse:
"""Edit image with subscription checking."""
try:
@@ -391,6 +474,78 @@ def edit(
)
edited_image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
# Save edited image to disk and track in asset library
image_url = None
image_filename = None
image_path = None
try:
# Create output directory for image studio edited images
base_dir = Path(__file__).parent.parent
output_dir = base_dir / "image_studio_images" / "edited"
# Generate safe filename from prompt
clean_prompt = sanitize_filename(req.prompt[:50], max_length=50)
image_filename = generate_unique_filename(
prefix=f"edited_{clean_prompt}",
extension=".png",
include_uuid=True
)
# Save file safely
image_path, save_error = save_file_safely(
content=result.image_bytes,
directory=output_dir,
filename=image_filename,
max_file_size=50 * 1024 * 1024 # 50MB for images
)
if image_path and not save_error:
# Generate file URL
image_url = f"/api/images/image-studio/images/edited/{image_path.name}"
logger.info(f"[images.edit] Saved edited image to: {image_path} ({len(result.image_bytes)} bytes)")
# Save to asset library (non-blocking)
try:
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
asset_type="image",
source_module="image_studio",
filename=image_path.name,
file_url=image_url,
file_path=str(image_path),
file_size=len(result.image_bytes),
mime_type="image/png",
title=f"Edited: {req.prompt[:100]}" if len(req.prompt) <= 100 else f"Edited: {req.prompt[:97]}...",
description=f"Edited image with prompt: {req.prompt[:200]}" if len(req.prompt) > 200 else f"Edited image with prompt: {req.prompt}",
prompt=req.prompt,
tags=["image_studio", "edited", result.provider] if result.provider else ["image_studio", "edited"],
provider=result.provider,
model=result.model,
asset_metadata={
"width": result.width,
"height": result.height,
"seed": result.seed,
"status": "completed",
"operation": "edit"
}
)
if asset_id:
logger.info(f"[images.edit] ✅ Asset saved to library: ID={asset_id}, filename={image_path.name}")
else:
logger.warning(f"[images.edit] Asset tracking returned None (may have failed silently)")
except Exception as asset_error:
logger.error(f"[images.edit] Failed to save asset to library: {asset_error}", exc_info=True)
# Don't fail the request if asset tracking fails
else:
logger.warning(f"[images.edit] Failed to save edited image to disk: {save_error}")
# Continue without failing the request - base64 is still available
except Exception as save_error:
logger.error(f"[images.edit] Unexpected error saving edited image: {save_error}", exc_info=True)
# Continue without failing the request
# TRACK USAGE after successful image editing
if result:
logger.info(f"[images.edit] ✅ Image editing successful, tracking usage for user {user_id}")
@@ -478,6 +633,7 @@ def edit(
return ImageEditResponse(
image_base64=edited_image_b64,
image_url=image_url,
width=result.width,
height=result.height,
provider=result.provider,
@@ -494,3 +650,55 @@ def edit(
detail="Image editing service is temporarily unavailable or the connection was reset. Please try again."
)
# ---------------------------
# Image Serving Endpoints
# ---------------------------
@router.get("/image-studio/images/{image_filename:path}")
async def serve_image_studio_image(
image_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Serve a generated or edited image from Image Studio."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Determine if it's an edited image or regular image
base_dir = Path(__file__).parent.parent
image_studio_dir = (base_dir / "image_studio_images").resolve()
if image_filename.startswith("edited/"):
# Remove "edited/" prefix and serve from edited directory
actual_filename = image_filename.replace("edited/", "", 1)
image_path = (image_studio_dir / "edited" / actual_filename).resolve()
base_subdir = (image_studio_dir / "edited").resolve()
else:
image_path = (image_studio_dir / image_filename).resolve()
base_subdir = image_studio_dir
# Security: Prevent directory traversal attacks
# Ensure the resolved path is within the intended directory
try:
image_path.relative_to(base_subdir)
except ValueError:
raise HTTPException(
status_code=403,
detail="Access denied: Invalid image path"
)
if not image_path.exists():
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(
path=str(image_path),
media_type="image/png",
filename=image_path.name
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[images] Failed to serve image: {e}")
raise HTTPException(status_code=500, detail=str(e))

File diff suppressed because it is too large Load Diff

View File

@@ -4,18 +4,20 @@ Each module focuses on a related set of routes to keep the primary
`router.py` concise and easier to maintain.
"""
from . import story_setup
from . import story_content
from . import story_tasks
from . import media_generation
from . import video_generation
from . import cache_routes
from . import media_generation
from . import scene_animation
from . import story_content
from . import story_setup
from . import story_tasks
from . import video_generation
__all__ = [
"story_setup",
"story_content",
"story_tasks",
"media_generation",
"video_generation",
"cache_routes",
"media_generation",
"scene_animation",
"story_content",
"story_setup",
"story_tasks",
"video_generation",
]

View File

@@ -1,8 +1,10 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from loguru import logger
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from models.story_models import (
@@ -18,8 +20,10 @@ from models.story_models import (
GenerateAIAudioResponse,
StoryScene,
)
from services.database import get_db
from services.story_writer.image_generation_service import StoryImageGenerationService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from utils.asset_tracker import save_asset_to_library
from ..utils.auth import require_authenticated_user
from ..utils.media_utils import resolve_media_file
@@ -34,6 +38,7 @@ audio_service = StoryAudioGenerationService()
async def generate_scene_images(
request: StoryImageGenerationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StoryImageGenerationResponse:
"""Generate images for story scenes."""
try:
@@ -70,6 +75,37 @@ async def generate_scene_images(
for result in image_results
]
# Save assets to library
for result in image_results:
if not result.get("error") and result.get("image_url"):
try:
scene_number = result.get("scene_number", 0)
# Safely get prompt from scenes_data with bounds checking
prompt = None
if scene_number > 0 and scene_number <= len(scenes_data):
prompt = scenes_data[scene_number - 1].get("image_prompt")
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="image",
source_module="story_writer",
filename=result.get("image_filename", ""),
file_url=result.get("image_url", ""),
file_path=result.get("image_path"),
file_size=result.get("file_size"),
mime_type="image/png",
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
description=f"Story scene image for scene {scene_number}",
prompt=prompt,
tags=["story_writer", "scene", f"scene_{scene_number}"],
provider=result.get("provider"),
model=result.get("model"),
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
)
except Exception as e:
logger.warning(f"[StoryWriter] Failed to save image asset to library: {e}")
return StoryImageGenerationResponse(images=image_models, success=True)
except HTTPException:
@@ -163,6 +199,7 @@ async def serve_scene_image(
async def generate_scene_audio(
request: StoryAudioGenerationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StoryAudioGenerationResponse:
"""Generate audio narration for story scenes."""
try:
@@ -185,18 +222,52 @@ async def generate_scene_audio(
audio_models: List[StoryAudioResult] = []
for result in audio_results:
audio_url = result.get("audio_url") or ""
audio_filename = result.get("audio_filename") or ""
audio_models.append(
StoryAudioResult(
scene_number=result.get("scene_number", 0),
scene_title=result.get("scene_title", "Untitled"),
audio_filename=result.get("audio_filename") or "",
audio_url=result.get("audio_url") or "",
audio_filename=audio_filename,
audio_url=audio_url,
provider=result.get("provider", "unknown"),
file_size=result.get("file_size", 0),
error=result.get("error"),
)
)
# Save assets to library
if not result.get("error") and audio_url:
try:
scene_number = result.get("scene_number", 0)
# Safely get prompt from scenes_data with bounds checking
prompt = None
if scene_number > 0 and scene_number <= len(scenes_data):
prompt = scenes_data[scene_number - 1].get("text")
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="audio",
source_module="story_writer",
filename=audio_filename,
file_url=audio_url,
file_path=result.get("audio_path"),
file_size=result.get("file_size"),
mime_type="audio/mpeg",
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
description=f"Story scene audio narration for scene {scene_number}",
prompt=prompt,
tags=["story_writer", "audio", "narration", f"scene_{scene_number}"],
provider=result.get("provider"),
model=result.get("model"),
cost=result.get("cost"),
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
)
except Exception as e:
logger.warning(f"[StoryWriter] Failed to save audio asset to library: {e}")
return StoryAudioGenerationResponse(audio_files=audio_models, success=True)
except HTTPException:
@@ -287,3 +358,59 @@ async def serve_scene_audio(
raise HTTPException(status_code=500, detail=str(exc))
class PromptOptimizeRequest(BaseModel):
text: str = Field(..., description="The prompt text to optimize")
mode: Optional[str] = Field(default="image", pattern="^(image|video)$", description="Optimization mode: 'image' or 'video'")
style: Optional[str] = Field(
default="default",
pattern="^(default|artistic|photographic|technical|anime|realistic)$",
description="Style: 'default', 'artistic', 'photographic', 'technical', 'anime', or 'realistic'"
)
image: Optional[str] = Field(None, description="Base64-encoded image for context (optional)")
class PromptOptimizeResponse(BaseModel):
optimized_prompt: str
success: bool
@router.post("/optimize-prompt", response_model=PromptOptimizeResponse)
async def optimize_prompt(
request: PromptOptimizeRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> PromptOptimizeResponse:
"""Optimize an image prompt using WaveSpeed prompt optimizer."""
try:
user_id = require_authenticated_user(current_user)
if not request.text or not request.text.strip():
raise HTTPException(status_code=400, detail="Prompt text is required")
logger.info(f"[StoryWriter] Optimizing prompt for user {user_id} (mode={request.mode}, style={request.style})")
from services.wavespeed.client import WaveSpeedClient
client = WaveSpeedClient()
optimized_prompt = client.optimize_prompt(
text=request.text.strip(),
mode=request.mode or "image",
style=request.style or "default",
image=request.image, # Optional base64 image
enable_sync_mode=True,
timeout=30
)
logger.info(f"[StoryWriter] Prompt optimized successfully for user {user_id}")
return PromptOptimizeResponse(
optimized_prompt=optimized_prompt,
success=True
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to optimize prompt: {exc}")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,484 @@
"""
Scene Animation Routes
Handles scene animation endpoints using WaveSpeed Kling and InfiniteTalk.
"""
import mimetypes
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import quote
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from loguru import logger
from sqlalchemy.orm import Session
from middleware.auth_middleware import get_current_user
from models.story_models import (
AnimateSceneRequest,
AnimateSceneResponse,
AnimateSceneVoiceoverRequest,
ResumeSceneAnimationRequest,
)
from services.database import get_db
from services.llm_providers.main_video_generation import track_video_usage
from services.story_writer.video_generation_service import StoryVideoGenerationService
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_scene_animation_operation
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
from services.wavespeed.kling_animation import animate_scene_image, resume_scene_animation
from utils.asset_tracker import save_asset_to_library
from utils.logger_utils import get_service_logger
from ..task_manager import task_manager
from ..utils.auth import require_authenticated_user
from ..utils.media_utils import load_story_audio_bytes, load_story_image_bytes
router = APIRouter()
scene_logger = get_service_logger("api.story_writer.scene_animation")
AI_VIDEO_SUBDIR = Path("AI_Videos")
def _build_authenticated_media_url(request: Request, path: str) -> str:
"""Append the caller's auth token to a media URL so <video>/<img> tags can access it."""
if not path:
return path
token: Optional[str] = None
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.replace("Bearer ", "").strip()
elif "token" in request.query_params:
token = request.query_params["token"]
if token:
separator = "&" if "?" in path else "?"
path = f"{path}{separator}token={quote(token)}"
return path
def _guess_mime_from_url(url: str, fallback: str) -> str:
"""Guess MIME type from URL."""
if not url:
return fallback
mime, _ = mimetypes.guess_type(url)
return mime or fallback
@router.post("/animate-scene-preview", response_model=AnimateSceneResponse)
async def animate_scene_preview(
request_obj: Request,
request: AnimateSceneRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> AnimateSceneResponse:
"""
Animate a single scene image using WaveSpeed Kling v2.5 Turbo Std.
"""
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get("id", ""))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
duration = request.duration or 5
if duration not in (5, 10):
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds.")
scene_logger.info(
"[AnimateScene] User=%s scene=%s duration=%s image_url=%s",
user_id,
request.scene_number,
duration,
request.image_url,
)
image_bytes = load_story_image_bytes(request.image_url)
if not image_bytes:
scene_logger.warning("[AnimateScene] Missing image bytes for user=%s scene=%s", user_id, request.scene_number)
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
db = next(get_db())
try:
pricing_service = PricingService(db)
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
finally:
db.close()
animation_result = animate_scene_image(
image_bytes=image_bytes,
scene_data=request.scene_data,
story_context=request.story_context,
user_id=user_id,
duration=duration,
)
base_dir = Path(__file__).parent.parent.parent.parent
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
ai_video_dir.mkdir(parents=True, exist_ok=True)
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
save_result = video_service.save_scene_video(
video_bytes=animation_result["video_bytes"],
scene_number=request.scene_number,
user_id=user_id,
)
video_filename = save_result["video_filename"]
video_url = _build_authenticated_media_url(
request_obj, f"/api/story/videos/ai/{video_filename}"
)
usage_info = track_video_usage(
user_id=user_id,
provider=animation_result["provider"],
model_name=animation_result["model_name"],
prompt=animation_result["prompt"],
video_bytes=animation_result["video_bytes"],
cost_override=animation_result["cost"],
)
if usage_info:
scene_logger.warning(
"[AnimateScene] Video usage tracked user=%s: %s%s / %s (cost +$%.2f, total=$%.2f)",
user_id,
usage_info.get("previous_calls"),
usage_info.get("current_calls"),
usage_info.get("video_limit_display"),
usage_info.get("cost_per_video", 0.0),
usage_info.get("total_video_cost", 0.0),
)
scene_logger.info(
"[AnimateScene] ✅ Completed user=%s scene=%s duration=%s cost=$%.2f video=%s",
user_id,
request.scene_number,
animation_result["duration"],
animation_result["cost"],
video_url,
)
# Save video asset to library
db = next(get_db())
try:
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="video",
source_module="story_writer",
filename=video_filename,
file_url=video_url,
file_path=str(ai_video_dir / video_filename),
file_size=len(animation_result["video_bytes"]),
mime_type="video/mp4",
title=f"Scene {request.scene_number} Animation",
description=f"Animated scene {request.scene_number} from story",
prompt=animation_result["prompt"],
tags=["story_writer", "video", "animation", f"scene_{request.scene_number}"],
provider=animation_result["provider"],
model=animation_result.get("model_name"),
cost=animation_result["cost"],
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
)
except Exception as e:
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
finally:
db.close()
return AnimateSceneResponse(
success=True,
scene_number=request.scene_number,
video_filename=video_filename,
video_url=video_url,
duration=animation_result["duration"],
cost=animation_result["cost"],
prompt_used=animation_result["prompt"],
provider=animation_result["provider"],
prediction_id=animation_result.get("prediction_id"),
)
@router.post("/animate-scene-resume", response_model=AnimateSceneResponse)
async def resume_scene_animation_endpoint(
request_obj: Request,
request: ResumeSceneAnimationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> AnimateSceneResponse:
"""Resume downloading a WaveSpeed animation when the initial call timed out."""
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get("id", ""))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
scene_logger.info(
"[AnimateScene] Resume requested user=%s scene=%s prediction=%s",
user_id,
request.scene_number,
request.prediction_id,
)
animation_result = resume_scene_animation(
prediction_id=request.prediction_id,
duration=request.duration or 5,
user_id=user_id,
)
base_dir = Path(__file__).parent.parent.parent.parent
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
ai_video_dir.mkdir(parents=True, exist_ok=True)
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
save_result = video_service.save_scene_video(
video_bytes=animation_result["video_bytes"],
scene_number=request.scene_number,
user_id=user_id,
)
video_filename = save_result["video_filename"]
video_url = _build_authenticated_media_url(
request_obj, f"/api/story/videos/ai/{video_filename}"
)
usage_info = track_video_usage(
user_id=user_id,
provider=animation_result["provider"],
model_name=animation_result["model_name"],
prompt=animation_result["prompt"],
video_bytes=animation_result["video_bytes"],
cost_override=animation_result["cost"],
)
if usage_info:
scene_logger.warning(
"[AnimateScene] (Resume) Video usage tracked user=%s: %s%s / %s (cost +$%.2f, total=$%.2f)",
user_id,
usage_info.get("previous_calls"),
usage_info.get("current_calls"),
usage_info.get("video_limit_display"),
usage_info.get("cost_per_video", 0.0),
usage_info.get("total_video_cost", 0.0),
)
scene_logger.info(
"[AnimateScene] ✅ Resume completed user=%s scene=%s prediction=%s video=%s",
user_id,
request.scene_number,
request.prediction_id,
video_url,
)
return AnimateSceneResponse(
success=True,
scene_number=request.scene_number,
video_filename=video_filename,
video_url=video_url,
duration=animation_result["duration"],
cost=animation_result["cost"],
prompt_used=animation_result["prompt"],
provider=animation_result["provider"],
prediction_id=animation_result.get("prediction_id"),
)
@router.post("/animate-scene-voiceover", response_model=Dict[str, Any])
async def animate_scene_voiceover_endpoint(
request_obj: Request,
request: AnimateSceneVoiceoverRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""
Animate a scene using WaveSpeed InfiniteTalk (image + audio) asynchronously.
Returns task_id for polling since InfiniteTalk can take up to 10 minutes.
"""
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get("id", ""))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
scene_logger.info(
"[AnimateSceneVoiceover] User=%s scene=%s resolution=%s (async)",
user_id,
request.scene_number,
request.resolution or "720p",
)
image_bytes = load_story_image_bytes(request.image_url)
if not image_bytes:
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
audio_bytes = load_story_audio_bytes(request.audio_url)
if not audio_bytes:
raise HTTPException(status_code=404, detail="Scene audio not found. Generate audio first.")
db = next(get_db())
try:
pricing_service = PricingService(db)
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
finally:
db.close()
# Extract token for authenticated URL building (if needed)
auth_token = None
auth_header = request_obj.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.replace("Bearer ", "").strip()
# Create async task
task_id = task_manager.create_task("scene_voiceover_animation")
background_tasks.add_task(
_execute_voiceover_animation_task,
task_id=task_id,
request=request,
user_id=user_id,
image_bytes=image_bytes,
audio_bytes=audio_bytes,
auth_token=auth_token,
)
return {
"task_id": task_id,
"status": "pending",
"message": "InfiniteTalk animation started. This may take up to 10 minutes.",
}
def _execute_voiceover_animation_task(
task_id: str,
request: AnimateSceneVoiceoverRequest,
user_id: str,
image_bytes: bytes,
audio_bytes: bytes,
auth_token: Optional[str] = None,
):
"""Background task to generate InfiniteTalk video with progress updates."""
try:
task_manager.update_task_status(
task_id, "processing", progress=5.0, message="Submitting to WaveSpeed InfiniteTalk..."
)
animation_result = animate_scene_with_voiceover(
image_bytes=image_bytes,
audio_bytes=audio_bytes,
scene_data=request.scene_data,
story_context=request.story_context,
user_id=user_id,
resolution=request.resolution or "720p",
prompt_override=request.prompt,
image_mime=_guess_mime_from_url(request.image_url, "image/png"),
audio_mime=_guess_mime_from_url(request.audio_url, "audio/mpeg"),
)
task_manager.update_task_status(
task_id, "processing", progress=80.0, message="Saving video file..."
)
base_dir = Path(__file__).parent.parent.parent.parent
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
ai_video_dir.mkdir(parents=True, exist_ok=True)
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
save_result = video_service.save_scene_video(
video_bytes=animation_result["video_bytes"],
scene_number=request.scene_number,
user_id=user_id,
)
video_filename = save_result["video_filename"]
# Build authenticated URL if token provided, otherwise return plain URL
video_url = f"/api/story/videos/ai/{video_filename}"
if auth_token:
video_url = f"{video_url}?token={quote(auth_token)}"
usage_info = track_video_usage(
user_id=user_id,
provider=animation_result["provider"],
model_name=animation_result["model_name"],
prompt=animation_result["prompt"],
video_bytes=animation_result["video_bytes"],
cost_override=animation_result["cost"],
)
if usage_info:
scene_logger.warning(
"[AnimateSceneVoiceover] Video usage tracked user=%s: %s%s / %s (cost +$%.2f, total=$%.2f)",
user_id,
usage_info.get("previous_calls"),
usage_info.get("current_calls"),
usage_info.get("video_limit_display"),
usage_info.get("cost_per_video", 0.0),
usage_info.get("total_video_cost", 0.0),
)
scene_logger.info(
"[AnimateSceneVoiceover] ✅ Completed user=%s scene=%s cost=$%.2f video=%s",
user_id,
request.scene_number,
animation_result["cost"],
video_url,
)
# Save video asset to library
db = next(get_db())
try:
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="video",
source_module="story_writer",
filename=video_filename,
file_url=video_url,
file_path=str(ai_video_dir / video_filename),
file_size=len(animation_result["video_bytes"]),
mime_type="video/mp4",
title=f"Scene {request.scene_number} Animation (Voiceover)",
description=f"Animated scene {request.scene_number} with voiceover from story",
prompt=animation_result["prompt"],
tags=["story_writer", "video", "animation", "voiceover", f"scene_{request.scene_number}"],
provider=animation_result["provider"],
model=animation_result.get("model_name"),
cost=animation_result["cost"],
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
)
except Exception as e:
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
finally:
db.close()
result = AnimateSceneResponse(
success=True,
scene_number=request.scene_number,
video_filename=video_filename,
video_url=video_url,
duration=animation_result["duration"],
cost=animation_result["cost"],
prompt_used=animation_result["prompt"],
provider=animation_result["provider"],
prediction_id=animation_result.get("prediction_id"),
)
task_manager.update_task_status(
task_id,
"completed",
progress=100.0,
message="InfiniteTalk animation complete!",
result=result.dict(),
)
except HTTPException as exc:
error_msg = str(exc.detail) if isinstance(exc.detail, str) else exc.detail.get("error", "Animation failed") if isinstance(exc.detail, dict) else "Animation failed"
scene_logger.error(f"[AnimateSceneVoiceover] Failed: {error_msg}")
task_manager.update_task_status(
task_id,
"failed",
error=error_msg,
message=f"InfiniteTalk animation failed: {error_msg}",
)
except Exception as exc:
error_msg = str(exc)
scene_logger.error(f"[AnimateSceneVoiceover] Error: {error_msg}", exc_info=True)
task_manager.update_task_status(
task_id,
"failed",
error=error_msg,
message=f"InfiniteTalk animation error: {error_msg}",
)

View File

@@ -1,7 +1,9 @@
from typing import Any, Dict, List
from datetime import datetime
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from pydantic import BaseModel
from middleware.auth_middleware import get_current_user
from models.story_models import (
@@ -18,6 +20,7 @@ from ..utils.auth import require_authenticated_user
router = APIRouter()
story_service = StoryWriterService()
scene_approval_store: Dict[str, Dict[str, Any]] = {}
@router.post("/generate-start", response_model=StoryContentResponse)
@@ -193,3 +196,45 @@ async def continue_story(
raise HTTPException(status_code=500, detail=str(exc))
class SceneApprovalRequest(BaseModel):
project_id: str
scene_id: str
approved: bool = True
notes: Optional[str] = None
@router.post("/script/approve")
async def approve_script_scene(
request: SceneApprovalRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""Persist scene approval metadata for auditing."""
try:
user_id = require_authenticated_user(current_user)
approvals = scene_approval_store.setdefault(request.project_id, {})
approvals[request.scene_id] = {
"approved": request.approved,
"notes": request.notes,
"user_id": user_id,
"timestamp": datetime.utcnow().isoformat(),
}
logger.info(
"[StoryWriter] Scene approval recorded user=%s project=%s scene=%s approved=%s",
user_id,
request.project_id,
request.scene_id,
request.approved,
)
return {
"success": True,
"project_id": request.project_id,
"scene_id": request.scene_id,
"approved": request.approved,
}
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to approve scene: {exc}")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -509,3 +509,30 @@ async def serve_story_video(
raise HTTPException(status_code=500, detail=str(exc))
@router.get("/videos/ai/{video_filename}")
async def serve_ai_story_video(
video_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Serve a generated AI scene animation video."""
try:
require_authenticated_user(current_user)
base_dir = Path(__file__).parent.parent.parent.parent
ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()
video_service_ai = StoryVideoGenerationService(output_dir=str(ai_video_dir))
video_path = resolve_media_file(video_service_ai.output_dir, video_filename)
return FileResponse(
path=str(video_path),
media_type="video/mp4",
filename=video_filename
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to serve AI video: {exc}")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -53,6 +53,7 @@ from api.linkedin_image_generation import router as linkedin_image_router
from api.brainstorm import router as brainstorm_router
from api.images import router as images_router
from routers.image_studio import router as image_studio_router
from routers.product_marketing import router as product_marketing_router
# Import hallucination detector router
from api.hallucination_detector import router as hallucination_detector_router
@@ -298,6 +299,7 @@ from routers.platform_analytics import router as platform_analytics_router
app.include_router(platform_analytics_router)
app.include_router(images_router)
app.include_router(image_studio_router)
app.include_router(product_marketing_router)
# Include content assets router
from api.content_assets.router import router as content_assets_router

View File

@@ -0,0 +1,264 @@
# Asset Tracking Implementation Guide
## Overview
This document describes the production-ready implementation of asset tracking across all ALwrity modules. The unified Content Asset Library automatically tracks all AI-generated content (images, videos, audio, text) for easy management and organization.
## Architecture
### Core Components
1. **Database Models** (`backend/models/content_asset_models.py`)
- `ContentAsset`: Main model for tracking assets
- `AssetCollection`: Collections/albums for organizing assets
- `AssetType`: Enum (text, image, video, audio)
- `AssetSource`: Enum (all ALwrity modules)
2. **Service Layer** (`backend/services/content_asset_service.py`)
- CRUD operations for assets
- Search, filter, pagination
- Usage tracking
3. **Utility Functions**
- `backend/utils/asset_tracker.py`: `save_asset_to_library()` helper
- `backend/utils/file_storage.py`: Robust file saving utilities
## Implementation Status
### ✅ Completed Integrations
#### 1. Story Writer (`backend/api/story_writer/router.py`)
- **Images**: Tracks all scene images with metadata
- **Audio**: Tracks all scene audio files with narration details
- **Videos**: Tracks individual scene videos and complete story videos
- **Location**: After generation in `/generate-images`, `/generate-audio`, `/generate-video`, `/generate-complete-video`
- **Metadata**: Includes prompts, scene numbers, providers, models, costs, status
#### 2. Image Studio (`backend/api/images.py`)
- **Image Generation**: Tracks all generated images
- **Image Editing**: Tracks all edited images
- **Location**: After generation in `/api/images/generate` and `/api/images/edit`
- **Features**:
- Robust file saving with validation
- Atomic file writes
- Proper error handling (non-blocking)
- File serving endpoint at `/api/images/image-studio/images/{filename}`
### 📝 Notes on Other Modules
#### Main Generation Services
- **Text Generation** (`main_text_generation.py`): Returns strings, not files. If text content needs tracking, save to `.txt` or `.md` files first.
- **Video Generation** (`main_video_generation.py`): Already integrated via Story Writer
- **Audio Generation** (`main_audio_generation.py`): Already integrated via Story Writer
#### Social Writers
- **LinkedIn Writer**: Generates text content (posts, articles). No file generation currently.
- **Facebook Writer**: Generates text content (posts, stories, reels). No file generation currently.
- **Blog Writer**: Generates blog content. May generate images in future.
**Note**: If these modules generate files in the future, follow the integration pattern below.
## Integration Pattern
### For Image Generation
```python
from utils.asset_tracker import save_asset_to_library
from utils.file_storage import save_file_safely, generate_unique_filename
from sqlalchemy.orm import Session
from pathlib import Path
# After successful image generation
try:
base_dir = Path(__file__).parent.parent
output_dir = base_dir / "module_images"
image_filename = generate_unique_filename(
prefix="img_prompt",
extension=".png",
include_uuid=True
)
# Save file safely
image_path, save_error = save_file_safely(
content=result.image_bytes,
directory=output_dir,
filename=image_filename,
max_file_size=50 * 1024 * 1024 # 50MB
)
if image_path and not save_error:
image_url = f"/api/module/images/{image_path.name}"
# Track in asset library (non-blocking)
try:
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
asset_type="image",
source_module="module_name",
filename=image_path.name,
file_url=image_url,
file_path=str(image_path),
file_size=len(result.image_bytes),
mime_type="image/png",
title="Image Title",
description="Image description",
prompt=prompt,
tags=["tag1", "tag2"],
provider=result.provider,
model=result.model,
metadata={"status": "completed"}
)
logger.info(f"✅ Asset saved: ID={asset_id}")
except Exception as e:
logger.error(f"Asset tracking failed: {e}", exc_info=True)
# Don't fail the request
except Exception as e:
logger.error(f"File save failed: {e}", exc_info=True)
# Continue - base64 is still available
```
### For Video Generation
```python
# After successful video generation
try:
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
asset_type="video",
source_module="module_name",
filename=video_filename,
file_url=video_url,
file_path=str(video_path),
file_size=file_size,
mime_type="video/mp4",
title="Video Title",
description="Video description",
prompt=prompt,
tags=["video", "tag"],
provider=provider,
model=model,
cost=cost,
metadata={"duration": duration, "status": "completed"}
)
except Exception as e:
logger.error(f"Asset tracking failed: {e}", exc_info=True)
```
### For Audio Generation
```python
# After successful audio generation
try:
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
asset_type="audio",
source_module="module_name",
filename=audio_filename,
file_url=audio_url,
file_path=str(audio_path),
file_size=file_size,
mime_type="audio/mpeg",
title="Audio Title",
description="Audio description",
prompt=text,
tags=["audio", "tag"],
provider=provider,
model=model,
cost=cost,
metadata={"status": "completed"}
)
except Exception as e:
logger.error(f"Asset tracking failed: {e}", exc_info=True)
```
## Best Practices
### 1. Error Handling
- **Always non-blocking**: Asset tracking failures should never break the main request
- **Log errors**: Use `logger.error()` with `exc_info=True` for debugging
- **Graceful degradation**: Continue with base64/file response even if tracking fails
### 2. File Management
- **Use `save_file_safely()`**: Handles validation, atomic writes, directory creation
- **Sanitize filenames**: Use `sanitize_filename()` to prevent path traversal
- **Unique filenames**: Use `generate_unique_filename()` with UUIDs
- **File size limits**: Enforce reasonable limits (50MB for images, 100MB for videos)
### 3. Database Sessions
- **Pass session explicitly**: Use `db: Session = Depends(get_db)` in endpoints
- **Handle session lifecycle**: Let FastAPI manage session cleanup
- **Background tasks**: Get new session in background tasks
### 4. Metadata
- **Rich metadata**: Include provider, model, dimensions, cost, status
- **Searchable tags**: Use consistent tag naming (e.g., "image_studio", "generated")
- **Status tracking**: Always include `"status": "completed"` in metadata
### 5. File URLs
- **Consistent patterns**: Use `/api/{module}/images/{filename}` format
- **Serving endpoints**: Create corresponding GET endpoints to serve files
- **Authentication**: Protect file serving endpoints with `get_current_user`
## File Storage Utilities
### `save_file_safely()`
- Validates file size
- Creates directories automatically
- Atomic writes (temp file + rename)
- Returns `(file_path, error_message)` tuple
### `sanitize_filename()`
- Removes dangerous characters
- Prevents path traversal
- Limits filename length
- Handles empty filenames
### `generate_unique_filename()`
- Creates unique filenames with UUIDs
- Sanitizes prefix
- Handles extensions properly
## Testing Checklist
- [ ] Images are saved to disk correctly
- [ ] Files are accessible via serving endpoints
- [ ] Asset tracking works (check database)
- [ ] Errors don't break main requests
- [ ] File size limits are enforced
- [ ] Filenames are sanitized properly
- [ ] Metadata is complete and accurate
- [ ] Asset Library UI displays assets correctly
## Future Enhancements
1. **Text Content Tracking**: Save text content as files when needed
2. **Batch Operations**: Track multiple assets in single transaction
3. **File Cleanup**: Automatic cleanup of orphaned files
4. **Storage Backends**: Support S3, GCS for production
5. **Thumbnail Generation**: Auto-generate thumbnails for videos/images
6. **Compression**: Compress large files before storage
## Troubleshooting
### Assets not appearing in library
1. Check database: `SELECT * FROM content_assets WHERE user_id = '...'`
2. Check logs for asset tracking errors
3. Verify `save_asset_to_library()` returns asset ID
4. Check file URLs are correct
### File serving fails
1. Verify file exists on disk
2. Check serving endpoint is registered
3. Verify authentication is working
4. Check file permissions
### Performance issues
1. Use background tasks for heavy operations
2. Batch database operations
3. Consider async file I/O for large files
4. Monitor database query performance

View File

@@ -0,0 +1,143 @@
# Text Asset Tracking Implementation
## Overview
Text content tracking has been successfully implemented across LinkedIn Writer and Facebook Writer endpoints. All generated text content is automatically saved as files and tracked in the unified Content Asset Library.
## Implementation Status
### ✅ Completed Integrations
#### 1. LinkedIn Writer (`backend/routers/linkedin.py`)
- **Post Generation**: Tracks LinkedIn posts with content, hashtags, and CTAs
- **Article Generation**: Tracks LinkedIn articles with full content, sections, and SEO metadata
- **Carousel Generation**: Tracks LinkedIn carousels with all slides
- **Video Script Generation**: Tracks LinkedIn video scripts with hooks, scenes, captions
- **Comment Response Generation**: Tracks LinkedIn comment responses
**File Format**: Markdown (`.md`) for articles, carousels, video scripts, comment responses; Text (`.txt`) for posts
**Storage Location**: `backend/linkedinwriter_text/{subdirectory}/`
- `posts/` - LinkedIn posts
- `articles/` - LinkedIn articles
- `carousels/` - LinkedIn carousels
- `video_scripts/` - LinkedIn video scripts
- `comment_responses/` - LinkedIn comment responses
#### 2. Facebook Writer (`backend/api/facebook_writer/routers/facebook_router.py`)
- **Post Generation**: Tracks Facebook posts with content and analytics
- **Story Generation**: Tracks Facebook stories
**File Format**: Text (`.txt`)
**Storage Location**: `backend/facebookwriter_text/{subdirectory}/`
- `posts/` - Facebook posts
- `stories/` - Facebook stories
### 📝 Pending Integrations
#### Facebook Writer (Additional Endpoints)
- Reel Generation
- Carousel Generation
- Event Generation
- Group Post Generation
- Page About Generation
- Ad Copy Generation
- Hashtag Generation
#### Blog Writer (`backend/api/blog_writer/router.py`)
- Blog content generation endpoints
- Medium blog generation
- Blog section generation
## Architecture
### Core Components
1. **Text Asset Tracker** (`backend/utils/text_asset_tracker.py`)
- `save_and_track_text_content()`: Main function for saving and tracking text
- Handles file saving, URL generation, and asset library tracking
- Non-blocking error handling
2. **File Storage Utilities** (`backend/utils/file_storage.py`)
- `save_text_file_safely()`: Safely saves text files with validation
- `sanitize_filename()`: Prevents path traversal
- `generate_unique_filename()`: Creates unique filenames
3. **Asset Tracker** (`backend/utils/asset_tracker.py`)
- `save_asset_to_library()`: Saves asset metadata to database
## Integration Pattern
### Basic Integration
```python
from utils.text_asset_tracker import save_and_track_text_content
from sqlalchemy.orm import Session
from middleware.auth_middleware import get_current_user
@router.post("/generate-content")
async def generate_content(
request: ContentRequest,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
db: Session = Depends(get_db)
):
# Generate content
response = await service.generate(request)
# Save and track text content (non-blocking)
if response.content:
try:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if user_id:
save_and_track_text_content(
db=db,
user_id=user_id,
content=response.content,
source_module="module_name",
title=f"Content Title: {request.topic[:60]}",
description=f"Content description",
prompt=f"Topic: {request.topic}",
tags=["tag1", "tag2"],
metadata={"key": "value"},
subdirectory="content_type"
)
except Exception as track_error:
logger.warning(f"Failed to track text asset: {track_error}")
return response
```
## File Serving
Text files are saved with URLs like `/api/text-assets/{module}/{subdirectory}/{filename}`. A serving endpoint should be created in `backend/app.py`:
```python
@router.get("/api/text-assets/{file_path:path}")
async def serve_text_asset(
file_path: str,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Serve text assets with authentication."""
# Implementation needed
pass
```
## Best Practices
1. **Non-blocking**: Text tracking failures should never break the main request
2. **Error Handling**: Use try/except around tracking calls
3. **User ID Extraction**: Support both `current_user` dependency and header-based extraction
4. **Content Formatting**: Combine related content (e.g., post + hashtags + CTA)
5. **Metadata**: Include rich metadata for search and filtering
6. **File Organization**: Use subdirectories to organize by content type
## Next Steps
1. Add text tracking to remaining Facebook Writer endpoints
2. Add text tracking to Blog Writer endpoints
3. Create text asset serving endpoint
4. Add text preview in Asset Library UI
5. Support text file downloads

View File

@@ -22,41 +22,31 @@ class AssetType(enum.Enum):
class AssetSource(enum.Enum):
"""Source module/tool that generated the asset - covers ALL ALwrity tools."""
# Image Studio modules
IMAGE_STUDIO_CREATE = "image_studio_create"
IMAGE_STUDIO_EDIT = "image_studio_edit"
IMAGE_STUDIO_UPSCALE = "image_studio_upscale"
IMAGE_STUDIO_TRANSFORM = "image_studio_transform"
IMAGE_STUDIO_CONTROL = "image_studio_control"
IMAGE_STUDIO_SOCIAL = "image_studio_social"
IMAGE_STUDIO_BATCH = "image_studio_batch"
# Content Writers
"""Source module/tool that generated the asset."""
# Core Content Generation
STORY_WRITER = "story_writer"
BLOG_WRITER = "blog_writer"
LINKEDIN_WRITER = "linkedin_writer"
FACEBOOK_WRITER = "facebook_writer"
# Content Planning
CONTENT_PLANNING = "content_planning"
CONTENT_STRATEGY = "content_strategy"
# SEO Tools
SEO_DASHBOARD = "seo_dashboard"
SEO_TOOLS = "seo_tools"
# Research
RESEARCH = "research"
# Scheduler
SCHEDULER = "scheduler"
# Main Generation (legacy/fallback)
IMAGE_STUDIO = "image_studio"
MAIN_TEXT_GENERATION = "main_text_generation"
MAIN_IMAGE_GENERATION = "main_image_generation"
MAIN_VIDEO_GENERATION = "main_video_generation"
MAIN_AUDIO_GENERATION = "main_audio_generation"
# Social Media Writers
BLOG_WRITER = "blog_writer"
LINKEDIN_WRITER = "linkedin_writer"
FACEBOOK_WRITER = "facebook_writer"
# SEO & Content Tools
SEO_TOOLS = "seo_tools"
CONTENT_PLANNING = "content_planning"
WRITING_ASSISTANT = "writing_assistant"
# Research & Strategy
RESEARCH_TOOLS = "research_tools"
CONTENT_STRATEGY = "content_strategy"
# Product Marketing Suite
PRODUCT_MARKETING = "product_marketing"
class ContentAsset(Base):
@@ -87,18 +77,14 @@ class ContentAsset(Base):
description = Column(Text, nullable=True)
prompt = Column(Text, nullable=True) # Original prompt used for generation
tags = Column(JSON, nullable=True) # Array of tags for search/filtering
metadata = Column(JSON, nullable=True) # Additional module-specific metadata
asset_metadata = Column(JSON, nullable=True) # Additional module-specific metadata (renamed from 'metadata' to avoid SQLAlchemy conflict)
# Generation details
provider = Column(String(100), nullable=True, index=True) # AI provider used (e.g., "stability", "gemini")
model = Column(String(200), nullable=True, index=True) # Model used (full model path/name)
provider = Column(String(100), nullable=True) # AI provider used (e.g., "stability", "gemini")
model = Column(String(100), nullable=True) # Model used
cost = Column(Float, nullable=True, default=0.0) # Generation cost in USD
generation_time = Column(Float, nullable=True) # Time taken in seconds
# Status tracking
status = Column(String(50), default='completed', index=True) # completed, processing, failed, pending
error_message = Column(Text, nullable=True) # Error details if failed
# Organization
is_favorite = Column(Boolean, default=False, index=True)
collection_id = Column(Integer, ForeignKey('asset_collections.id'), nullable=True)
@@ -113,7 +99,11 @@ class ContentAsset(Base):
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
collection = relationship("AssetCollection", back_populates="assets", cascade="all, delete-orphan")
collection = relationship(
"AssetCollection",
back_populates="assets",
foreign_keys=[collection_id]
)
# Composite indexes for common query patterns
__table_args__ = (
@@ -141,5 +131,15 @@ class AssetCollection(Base):
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
assets = relationship("ContentAsset", back_populates="collection")
assets = relationship(
"ContentAsset",
back_populates="collection",
foreign_keys="[ContentAsset.collection_id]",
cascade="all, delete-orphan" # Cascade delete on the "one" side (one-to-many)
)
cover_asset = relationship(
"ContentAsset",
foreign_keys=[cover_asset_id],
uselist=False
)

View File

@@ -0,0 +1,155 @@
"""
Product Asset Models
Database models for storing product-specific assets (separate from campaign assets).
These models are for the Product Marketing Suite (product asset creation).
"""
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, Text, ForeignKey, Index
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from models.subscription_models import Base
class ProductAssetType(enum.Enum):
"""Product asset type enum."""
IMAGE = "image"
VIDEO = "video"
AUDIO = "audio"
ANIMATION = "animation"
class ProductImageStyle(enum.Enum):
"""Product image style enum."""
STUDIO = "studio"
LIFESTYLE = "lifestyle"
OUTDOOR = "outdoor"
MINIMALIST = "minimalist"
LUXURY = "luxury"
TECHNICAL = "technical"
class ProductAsset(Base):
"""
Product asset model.
Stores product-specific assets (images, videos, audio) generated for product marketing.
"""
__tablename__ = "product_assets"
# Primary fields
id = Column(Integer, primary_key=True, autoincrement=True)
product_id = Column(String(255), nullable=False, index=True) # User-defined product ID
user_id = Column(String(255), nullable=False, index=True) # Clerk user ID
# Product information
product_name = Column(String(500), nullable=False)
product_description = Column(Text, nullable=True)
# Asset details
asset_type = Column(String(50), nullable=False, index=True) # image, video, audio, animation
variant = Column(String(100), nullable=True) # color, size, angle, etc.
style = Column(String(50), nullable=True) # studio, lifestyle, minimalist, etc.
environment = Column(String(50), nullable=True) # studio, lifestyle, outdoor, etc.
# Link to ContentAsset (unified asset library)
content_asset_id = Column(Integer, ForeignKey('content_assets.id', ondelete='SET NULL'), nullable=True, index=True)
# Generation details
provider = Column(String(100), nullable=True)
model = Column(String(100), nullable=True)
cost = Column(Float, default=0.0)
generation_time = Column(Float, nullable=True)
prompt_used = Column(Text, nullable=True)
# E-commerce integration
ecommerce_exported = Column(Boolean, default=False)
exported_to = Column(JSON, nullable=True) # Array of platform names
# Status
status = Column(String(50), default="completed", nullable=False) # completed, processing, failed
# Metadata
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Additional metadata
metadata = Column(JSON, nullable=True) # Additional product-specific metadata
# Composite indexes
__table_args__ = (
Index('idx_user_product', 'user_id', 'product_id'),
Index('idx_user_type', 'user_id', 'asset_type'),
Index('idx_product_type', 'product_id', 'asset_type'),
)
class ProductStyleTemplate(Base):
"""
Brand style template for products.
Stores reusable brand style configurations for product asset generation.
"""
__tablename__ = "product_style_templates"
# Primary fields
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(String(255), nullable=False, index=True)
template_name = Column(String(255), nullable=False)
# Style configuration
color_palette = Column(JSON, nullable=True) # Array of brand colors
background_style = Column(String(50), nullable=True) # white, transparent, lifestyle, branded
lighting_preset = Column(String(50), nullable=True) # natural, studio, dramatic, soft
preferred_style = Column(String(50), nullable=True) # photorealistic, minimalist, luxury, technical
preferred_environment = Column(String(50), nullable=True) # studio, lifestyle, outdoor
# Brand integration
use_brand_colors = Column(Boolean, default=True)
use_brand_logo = Column(Boolean, default=False)
# Metadata
is_default = Column(Boolean, default=False) # Default template for user
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Composite indexes
__table_args__ = (
Index('idx_user_template', 'user_id', 'template_name'),
)
class EcommerceExport(Base):
"""
E-commerce platform export tracking.
Tracks product asset exports to e-commerce platforms.
"""
__tablename__ = "product_ecommerce_exports"
# Primary fields
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(String(255), nullable=False, index=True)
product_id = Column(String(255), nullable=False, index=True)
# Platform information
platform = Column(String(50), nullable=False) # shopify, amazon, woocommerce
platform_product_id = Column(String(255), nullable=True) # Product ID on the platform
# Export details
exported_assets = Column(JSON, nullable=False) # Array of asset IDs exported
export_status = Column(String(50), default="pending", nullable=False) # pending, completed, failed
error_message = Column(Text, nullable=True)
# Metadata
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
exported_at = Column(DateTime, nullable=True)
# Composite indexes
__table_args__ = (
Index('idx_user_platform', 'user_id', 'platform'),
Index('idx_product_platform', 'product_id', 'platform'),
)

View File

@@ -0,0 +1,162 @@
"""
Product Marketing Campaign Models
Database models for storing campaign blueprints and asset proposals.
"""
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, Text, ForeignKey, Index, func
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from models.subscription_models import Base
class CampaignStatus(enum.Enum):
"""Campaign status enum."""
DRAFT = "draft"
GENERATING = "generating"
READY = "ready"
PUBLISHED = "published"
ARCHIVED = "archived"
class AssetNodeStatus(enum.Enum):
"""Asset node status enum."""
DRAFT = "draft"
PROPOSED = "proposed"
GENERATING = "generating"
READY = "ready"
APPROVED = "approved"
REJECTED = "rejected"
class Campaign(Base):
"""
Campaign blueprint model.
Stores campaign information, phases, and asset nodes.
"""
__tablename__ = "product_marketing_campaigns"
# Primary fields
id = Column(Integer, primary_key=True, autoincrement=True)
campaign_id = Column(String(255), unique=True, nullable=False, index=True)
user_id = Column(String(255), nullable=False, index=True) # Clerk user ID
# Campaign details
campaign_name = Column(String(500), nullable=False)
goal = Column(String(100), nullable=False) # product_launch, awareness, conversion, etc.
kpi = Column(String(500), nullable=True)
status = Column(String(50), default="draft", nullable=False, index=True)
# Campaign structure
phases = Column(JSON, nullable=True) # Array of phase objects
channels = Column(JSON, nullable=False) # Array of channel strings
asset_nodes = Column(JSON, nullable=True) # Array of asset node objects
# Product context
product_context = Column(JSON, nullable=True) # Product information
# Metadata
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
proposals = relationship("CampaignProposal", back_populates="campaign", cascade="all, delete-orphan")
generated_assets = relationship("CampaignAsset", back_populates="campaign", cascade="all, delete-orphan")
# Composite indexes
__table_args__ = (
Index('idx_user_status', 'user_id', 'status'),
Index('idx_user_created', 'user_id', 'created_at'),
)
class CampaignProposal(Base):
"""
Asset proposals for a campaign.
Stores AI-generated proposals for each asset node.
"""
__tablename__ = "product_marketing_proposals"
id = Column(Integer, primary_key=True, autoincrement=True)
campaign_id = Column(String(255), ForeignKey('product_marketing_campaigns.campaign_id', ondelete='CASCADE'), nullable=False, index=True)
user_id = Column(String(255), nullable=False, index=True)
# Asset node reference
asset_node_id = Column(String(255), nullable=False, index=True)
asset_type = Column(String(50), nullable=False) # image, text, video, audio
channel = Column(String(50), nullable=False)
# Proposal details
proposed_prompt = Column(Text, nullable=False)
recommended_template = Column(String(255), nullable=True)
recommended_provider = Column(String(100), nullable=True)
recommended_model = Column(String(100), nullable=True)
cost_estimate = Column(Float, default=0.0)
concept_summary = Column(Text, nullable=True)
# Status
status = Column(String(50), default="proposed", nullable=False) # proposed, approved, rejected, generating
approved_at = Column(DateTime, nullable=True)
# Metadata
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
campaign = relationship("Campaign", back_populates="proposals")
generated_asset = relationship("CampaignAsset", back_populates="proposal", uselist=False)
# Composite indexes
__table_args__ = (
Index('idx_campaign_node', 'campaign_id', 'asset_node_id'),
Index('idx_user_status', 'user_id', 'status'),
)
class CampaignAsset(Base):
"""
Generated assets for a campaign.
Links to ContentAsset and stores campaign-specific metadata.
"""
__tablename__ = "product_marketing_assets"
id = Column(Integer, primary_key=True, autoincrement=True)
campaign_id = Column(String(255), ForeignKey('product_marketing_campaigns.campaign_id', ondelete='CASCADE'), nullable=False, index=True)
proposal_id = Column(Integer, ForeignKey('product_marketing_proposals.id', ondelete='SET NULL'), nullable=True)
user_id = Column(String(255), nullable=False, index=True)
# Asset node reference
asset_node_id = Column(String(255), nullable=False, index=True)
# Link to ContentAsset
content_asset_id = Column(Integer, ForeignKey('content_assets.id', ondelete='SET NULL'), nullable=True)
# Generation details
provider = Column(String(100), nullable=True)
model = Column(String(100), nullable=True)
cost = Column(Float, default=0.0)
generation_time = Column(Float, nullable=True)
# Status
status = Column(String(50), default="generating", nullable=False) # generating, ready, approved, published
approved_at = Column(DateTime, nullable=True)
published_at = Column(DateTime, nullable=True)
# Metadata
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Relationships
campaign = relationship("Campaign", back_populates="generated_assets")
proposal = relationship("CampaignProposal", back_populates="generated_asset")
# Composite indexes
__table_args__ = (
Index('idx_campaign_node', 'campaign_id', 'asset_node_id'),
Index('idx_user_status', 'user_id', 'status'),
)

View File

@@ -1,8 +1,10 @@
"""API endpoints for Image Studio operations."""
import base64
from pathlib import Path
from typing import Optional, List, Dict, Any, Literal
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from services.image_studio import (
@@ -11,10 +13,12 @@ from services.image_studio import (
EditStudioRequest,
ControlStudioRequest,
SocialOptimizerRequest,
TransformImageToVideoRequest,
TalkingAvatarRequest,
)
from services.image_studio.upscale_service import UpscaleStudioRequest
from services.image_studio.templates import Platform, TemplateCategory
from middleware.auth_middleware import get_current_user
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from utils.logger_utils import get_service_logger
@@ -136,7 +140,12 @@ def get_studio_manager() -> ImageStudioManager:
def _require_user_id(current_user: Dict[str, Any], operation: str) -> str:
"""Ensure user_id is available for protected operations."""
user_id = current_user.get("sub") or current_user.get("user_id")
user_id = (
current_user.get("sub")
or current_user.get("user_id")
or current_user.get("id")
or current_user.get("clerk_user_id")
)
if not user_id:
logger.error(
"[Image Studio] ❌ Missing user_id for %s operation - blocking request",
@@ -762,6 +771,244 @@ async def get_platform_specs(
raise HTTPException(status_code=500, detail=str(e))
# ====================
# TRANSFORM STUDIO ENDPOINTS
# ====================
class TransformImageToVideoRequestModel(BaseModel):
"""Request model for image-to-video transformation."""
image_base64: str = Field(..., description="Image in base64 or data URL format")
prompt: str = Field(..., description="Text prompt describing the video")
audio_base64: Optional[str] = Field(None, description="Optional audio file (wav/mp3, 3-30s, ≤15MB)")
resolution: Literal["480p", "720p", "1080p"] = Field("720p", description="Output resolution")
duration: Literal[5, 10] = Field(5, description="Video duration in seconds")
negative_prompt: Optional[str] = Field(None, description="Negative prompt")
seed: Optional[int] = Field(None, description="Random seed for reproducibility")
enable_prompt_expansion: bool = Field(True, description="Enable prompt optimizer")
class TalkingAvatarRequestModel(BaseModel):
"""Request model for talking avatar generation."""
image_base64: str = Field(..., description="Person image in base64 or data URL")
audio_base64: str = Field(..., description="Audio file in base64 or data URL (wav/mp3, max 10 minutes)")
resolution: Literal["480p", "720p"] = Field("720p", description="Output resolution")
prompt: Optional[str] = Field(None, description="Optional prompt for expression/style")
mask_image_base64: Optional[str] = Field(None, description="Optional mask for animatable regions")
seed: Optional[int] = Field(None, description="Random seed")
class TransformVideoResponse(BaseModel):
"""Response model for video generation."""
success: bool
video_url: Optional[str] = None
video_base64: Optional[str] = None
duration: float
resolution: str
width: int
height: int
file_size: int
cost: float
provider: str
model: str
metadata: Dict[str, Any]
class TransformCostEstimateRequest(BaseModel):
"""Request model for cost estimation."""
operation: Literal["image-to-video", "talking-avatar"] = Field(..., description="Operation type")
resolution: str = Field(..., description="Output resolution")
duration: Optional[int] = Field(None, description="Video duration in seconds (for image-to-video)")
class TransformCostEstimateResponse(BaseModel):
"""Response model for cost estimation."""
estimated_cost: float
breakdown: Dict[str, Any]
currency: str
provider: str
model: str
@router.post("/transform/image-to-video", response_model=TransformVideoResponse, summary="Transform Image to Video")
async def transform_image_to_video(
request: TransformImageToVideoRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Transform an image into a video using WAN 2.5.
This endpoint generates a video from an image and text prompt, with optional audio synchronization.
Supports resolutions of 480p, 720p, and 1080p, with durations of 5 or 10 seconds.
Returns:
Video generation result with URL and metadata
"""
try:
user_id = _require_user_id(current_user, "image-to-video transformation")
logger.info(f"[Transform Studio] Image-to-video request from user {user_id}: resolution={request.resolution}, duration={request.duration}s")
# Convert request to service request
transform_request = TransformImageToVideoRequest(
image_base64=request.image_base64,
prompt=request.prompt,
audio_base64=request.audio_base64,
resolution=request.resolution,
duration=request.duration,
negative_prompt=request.negative_prompt,
seed=request.seed,
enable_prompt_expansion=request.enable_prompt_expansion,
)
# Generate video
result = await studio_manager.transform_image_to_video(transform_request, user_id=user_id)
logger.info(f"[Transform Studio] ✅ Image-to-video completed: cost=${result['cost']:.2f}")
return TransformVideoResponse(**result)
except ValueError as e:
logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")
@router.post("/transform/talking-avatar", response_model=TransformVideoResponse, summary="Create Talking Avatar")
async def create_talking_avatar(
request: TalkingAvatarRequestModel,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Create a talking avatar video using InfiniteTalk.
This endpoint generates a video with precise lip-sync from an image and audio file.
Supports resolutions of 480p and 720p, with videos up to 10 minutes long.
Returns:
Video generation result with URL and metadata
"""
try:
user_id = _require_user_id(current_user, "talking avatar generation")
logger.info(f"[Transform Studio] Talking avatar request from user {user_id}: resolution={request.resolution}")
# Convert request to service request
avatar_request = TalkingAvatarRequest(
image_base64=request.image_base64,
audio_base64=request.audio_base64,
resolution=request.resolution,
prompt=request.prompt,
mask_image_base64=request.mask_image_base64,
seed=request.seed,
)
# Generate video
result = await studio_manager.create_talking_avatar(avatar_request, user_id=user_id)
logger.info(f"[Transform Studio] ✅ Talking avatar completed: cost=${result['cost']:.2f}")
return TransformVideoResponse(**result)
except ValueError as e:
logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Talking avatar generation failed: {str(e)}")
@router.post("/transform/estimate-cost", response_model=TransformCostEstimateResponse, summary="Estimate Transform Cost")
async def estimate_transform_cost(
request: TransformCostEstimateRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
"""Estimate cost for transform operations.
Provides cost estimates before generation to help users make informed decisions.
Returns:
Cost estimation details
"""
try:
estimate = studio_manager.estimate_transform_cost(
operation=request.operation,
resolution=request.resolution,
duration=request.duration,
)
return TransformCostEstimateResponse(**estimate)
except ValueError as e:
logger.error(f"[Transform Studio] ❌ Cost estimation error: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"[Transform Studio] ❌ Error: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/videos/{user_id}/{video_filename:path}", summary="Serve Transform Studio Video")
async def serve_transform_video(
user_id: str,
video_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
"""Serve a generated Transform Studio video file.
Args:
user_id: User ID from URL path
video_filename: Video filename
current_user: Authenticated user
Returns:
Video file response
"""
try:
# Verify user has access (must be the owner)
authenticated_user_id = _require_user_id(current_user, "video access")
if authenticated_user_id != user_id:
raise HTTPException(
status_code=403,
detail="Access denied: You can only access your own videos"
)
# Resolve video path
# __file__ is: backend/routers/image_studio.py
# We need: backend/transform_videos
base_dir = Path(__file__).parent.parent.parent
transform_videos_dir = base_dir / "transform_videos"
video_path = transform_videos_dir / user_id / video_filename
# Security: Ensure path is within transform_videos directory
# Prevent directory traversal attacks
try:
resolved_video_path = video_path.resolve()
resolved_base = transform_videos_dir.resolve()
# Check if video path is within base directory
resolved_video_path.relative_to(resolved_base)
except ValueError:
raise HTTPException(
status_code=403,
detail="Invalid video path: path traversal detected"
)
if not video_path.exists():
raise HTTPException(status_code=404, detail="Video not found")
return FileResponse(
path=str(video_path),
media_type="video/mp4",
filename=video_filename
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Transform Studio] Failed to serve video: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# ====================
# HEALTH CHECK
# ====================

View File

@@ -8,9 +8,10 @@ proper error handling, monitoring, and documentation.
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks, Request
from fastapi.responses import JSONResponse
from typing import Dict, Any
from typing import Dict, Any, Optional
import time
from loguru import logger
from pathlib import Path
from models.linkedin_models import (
LinkedInPostRequest, LinkedInArticleRequest, LinkedInCarouselRequest,
@@ -19,11 +20,13 @@ from models.linkedin_models import (
LinkedInVideoScriptResponse, LinkedInCommentResponseResult
)
from services.linkedin_service import LinkedInService
from middleware.auth_middleware import get_current_user
from utils.text_asset_tracker import save_and_track_text_content
# Initialize the LinkedIn service instance
linkedin_service = LinkedInService()
from services.subscription.monitoring_middleware import DatabaseAPIMonitor
from services.database import get_db_session
from services.database import get_db as get_db_dependency
from sqlalchemy.orm import Session
# Initialize router
@@ -41,14 +44,8 @@ router = APIRouter(
monitor = DatabaseAPIMonitor()
def get_db():
"""Dependency to get database session."""
db = get_db_session()
try:
yield db
finally:
if db:
db.close()
# Use the proper database dependency from services.database
get_db = get_db_dependency
async def log_api_request(request: Request, db: Session, duration: float, status_code: int):
@@ -104,7 +101,8 @@ async def generate_post(
request: LinkedInPostRequest,
background_tasks: BackgroundTasks,
http_request: Request,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
):
"""Generate a LinkedIn post based on the provided parameters."""
start_time = time.time()
@@ -119,6 +117,13 @@ async def generate_post(
if not request.industry.strip():
raise HTTPException(status_code=422, detail="Industry cannot be empty")
# Extract user_id
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if not user_id:
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
# Generate post content
response = await linkedin_service.generate_linkedin_post(request)
@@ -131,6 +136,38 @@ async def generate_post(
if not response.success:
raise HTTPException(status_code=500, detail=response.error)
# Save and track text content (non-blocking)
if user_id and response.data and response.data.content:
try:
# Combine all text content
text_content = response.data.content
if response.data.call_to_action:
text_content += f"\n\nCall to Action: {response.data.call_to_action}"
if response.data.hashtags:
hashtag_text = " ".join([f"#{h.hashtag}" if isinstance(h, dict) else f"#{h.get('hashtag', '')}" for h in response.data.hashtags])
text_content += f"\n\nHashtags: {hashtag_text}"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="linkedin_writer",
title=f"LinkedIn Post: {request.topic[:80]}",
description=f"LinkedIn post for {request.industry} industry",
prompt=f"Topic: {request.topic}\nIndustry: {request.industry}\nTone: {request.tone}",
tags=["linkedin", "post", request.industry.lower().replace(' ', '_')],
asset_metadata={
"post_type": request.post_type.value if hasattr(request.post_type, 'value') else str(request.post_type),
"tone": request.tone.value if hasattr(request.tone, 'value') else str(request.tone),
"character_count": response.data.character_count,
"hashtag_count": len(response.data.hashtags),
"grounding_enabled": response.data.grounding_enabled if hasattr(response.data, 'grounding_enabled') else False
},
subdirectory="posts"
)
except Exception as track_error:
logger.warning(f"Failed to track LinkedIn post asset: {track_error}")
logger.info(f"Successfully generated LinkedIn post in {duration:.2f} seconds")
return response
@@ -174,7 +211,8 @@ async def generate_article(
request: LinkedInArticleRequest,
background_tasks: BackgroundTasks,
http_request: Request,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
):
"""Generate a LinkedIn article based on the provided parameters."""
start_time = time.time()
@@ -189,6 +227,13 @@ async def generate_article(
if not request.industry.strip():
raise HTTPException(status_code=422, detail="Industry cannot be empty")
# Extract user_id
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if not user_id:
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
# Generate article content
response = await linkedin_service.generate_linkedin_article(request)
@@ -201,6 +246,44 @@ async def generate_article(
if not response.success:
raise HTTPException(status_code=500, detail=response.error)
# Save and track text content (non-blocking)
if user_id and response.data:
try:
# Combine article content
text_content = f"# {response.data.title}\n\n"
text_content += response.data.content
if response.data.sections:
text_content += "\n\n## Sections:\n"
for section in response.data.sections:
if isinstance(section, dict):
text_content += f"\n### {section.get('heading', 'Section')}\n{section.get('content', '')}\n"
if response.data.seo_metadata:
text_content += f"\n\n## SEO Metadata\n{response.data.seo_metadata}\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="linkedin_writer",
title=f"LinkedIn Article: {response.data.title[:80] if response.data.title else request.topic[:80]}",
description=f"LinkedIn article for {request.industry} industry",
prompt=f"Topic: {request.topic}\nIndustry: {request.industry}\nTone: {request.tone}\nWord Count: {request.word_count}",
tags=["linkedin", "article", request.industry.lower().replace(' ', '_')],
asset_metadata={
"tone": request.tone.value if hasattr(request.tone, 'value') else str(request.tone),
"word_count": response.data.word_count,
"reading_time": response.data.reading_time,
"section_count": len(response.data.sections) if response.data.sections else 0,
"grounding_enabled": response.data.grounding_enabled if hasattr(response.data, 'grounding_enabled') else False
},
subdirectory="articles",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track LinkedIn article asset: {track_error}")
logger.info(f"Successfully generated LinkedIn article in {duration:.2f} seconds")
return response
@@ -243,7 +326,8 @@ async def generate_carousel(
request: LinkedInCarouselRequest,
background_tasks: BackgroundTasks,
http_request: Request,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
):
"""Generate a LinkedIn carousel based on the provided parameters."""
start_time = time.time()
@@ -261,6 +345,13 @@ async def generate_carousel(
if request.slide_count < 3 or request.slide_count > 15:
raise HTTPException(status_code=422, detail="Slide count must be between 3 and 15")
# Extract user_id
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if not user_id:
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
# Generate carousel content
response = await linkedin_service.generate_linkedin_carousel(request)
@@ -273,6 +364,36 @@ async def generate_carousel(
if not response.success:
raise HTTPException(status_code=500, detail=response.error)
# Save and track text content (non-blocking)
if user_id and response.data:
try:
# Combine carousel content
text_content = f"# {response.data.title}\n\n"
for slide in response.data.slides:
text_content += f"\n## Slide {slide.slide_number}: {slide.title}\n{slide.content}\n"
if slide.visual_elements:
text_content += f"\nVisual Elements: {', '.join(slide.visual_elements)}\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="linkedin_writer",
title=f"LinkedIn Carousel: {response.data.title[:80] if response.data.title else request.topic[:80]}",
description=f"LinkedIn carousel for {request.industry} industry",
prompt=f"Topic: {request.topic}\nIndustry: {request.industry}\nSlides: {getattr(request, 'number_of_slides', request.slide_count if hasattr(request, 'slide_count') else 5)}",
tags=["linkedin", "carousel", request.industry.lower().replace(' ', '_')],
asset_metadata={
"slide_count": len(response.data.slides),
"has_cover": response.data.cover_slide is not None,
"has_cta": response.data.cta_slide is not None
},
subdirectory="carousels",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track LinkedIn carousel asset: {track_error}")
logger.info(f"Successfully generated LinkedIn carousel in {duration:.2f} seconds")
return response
@@ -315,7 +436,8 @@ async def generate_video_script(
request: LinkedInVideoScriptRequest,
background_tasks: BackgroundTasks,
http_request: Request,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
):
"""Generate a LinkedIn video script based on the provided parameters."""
start_time = time.time()
@@ -330,9 +452,17 @@ async def generate_video_script(
if not request.industry.strip():
raise HTTPException(status_code=422, detail="Industry cannot be empty")
if request.video_length < 15 or request.video_length > 300:
video_duration = getattr(request, 'video_duration', getattr(request, 'video_length', 60))
if video_duration < 15 or video_duration > 300:
raise HTTPException(status_code=422, detail="Video length must be between 15 and 300 seconds")
# Extract user_id
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if not user_id:
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
# Generate video script content
response = await linkedin_service.generate_linkedin_video_script(request)
@@ -345,6 +475,47 @@ async def generate_video_script(
if not response.success:
raise HTTPException(status_code=500, detail=response.error)
# Save and track text content (non-blocking)
if user_id and response.data:
try:
# Combine video script content
text_content = f"# Video Script: {request.topic}\n\n"
text_content += f"## Hook\n{response.data.hook}\n\n"
text_content += "## Main Content\n"
for scene in response.data.main_content:
if isinstance(scene, dict):
text_content += f"\n### Scene {scene.get('scene_number', '')}\n"
text_content += f"{scene.get('content', '')}\n"
if scene.get('duration'):
text_content += f"Duration: {scene.get('duration')}s\n"
if scene.get('visual_notes'):
text_content += f"Visual Notes: {scene.get('visual_notes')}\n"
text_content += f"\n## Conclusion\n{response.data.conclusion}\n"
if response.data.captions:
text_content += f"\n## Captions\n" + "\n".join(response.data.captions) + "\n"
if response.data.thumbnail_suggestions:
text_content += f"\n## Thumbnail Suggestions\n" + "\n".join(response.data.thumbnail_suggestions) + "\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="linkedin_writer",
title=f"LinkedIn Video Script: {request.topic[:80]}",
description=f"LinkedIn video script for {request.industry} industry",
prompt=f"Topic: {request.topic}\nIndustry: {request.industry}\nDuration: {video_duration}s",
tags=["linkedin", "video_script", request.industry.lower().replace(' ', '_')],
asset_metadata={
"video_duration": video_duration,
"scene_count": len(response.data.main_content),
"has_captions": bool(response.data.captions)
},
subdirectory="video_scripts",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track LinkedIn video script asset: {track_error}")
logger.info(f"Successfully generated LinkedIn video script in {duration:.2f} seconds")
return response
@@ -387,7 +558,8 @@ async def generate_comment_response(
request: LinkedInCommentResponseRequest,
background_tasks: BackgroundTasks,
http_request: Request,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
):
"""Generate a LinkedIn comment response based on the provided parameters."""
start_time = time.time()
@@ -396,11 +568,21 @@ async def generate_comment_response(
logger.info("Received LinkedIn comment response generation request")
# Validate request
if not request.original_post.strip():
raise HTTPException(status_code=422, detail="Original post cannot be empty")
original_comment = getattr(request, 'original_comment', getattr(request, 'comment', ''))
post_context = getattr(request, 'post_context', getattr(request, 'original_post', ''))
if not request.comment.strip():
raise HTTPException(status_code=422, detail="Comment cannot be empty")
if not original_comment.strip():
raise HTTPException(status_code=422, detail="Original comment cannot be empty")
if not post_context.strip():
raise HTTPException(status_code=422, detail="Post context cannot be empty")
# Extract user_id
user_id = None
if current_user:
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
if not user_id:
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
# Generate comment response
response = await linkedin_service.generate_linkedin_comment_response(request)
@@ -414,6 +596,38 @@ async def generate_comment_response(
if not response.success:
raise HTTPException(status_code=500, detail=response.error)
# Save and track text content (non-blocking)
if user_id and hasattr(response, 'response') and response.response:
try:
text_content = f"# Comment Response\n\n"
text_content += f"## Original Comment\n{original_comment}\n\n"
text_content += f"## Post Context\n{post_context}\n\n"
text_content += f"## Generated Response\n{response.response}\n"
if hasattr(response, 'alternatives') and response.alternatives:
text_content += f"\n## Alternative Responses\n"
for i, alt in enumerate(response.alternatives, 1):
text_content += f"\n### Alternative {i}\n{alt}\n"
save_and_track_text_content(
db=db,
user_id=user_id,
content=text_content,
source_module="linkedin_writer",
title=f"LinkedIn Comment Response: {original_comment[:60]}",
description=f"LinkedIn comment response for {request.industry} industry",
prompt=f"Original Comment: {original_comment}\nPost Context: {post_context}\nIndustry: {request.industry}",
tags=["linkedin", "comment_response", request.industry.lower().replace(' ', '_')],
asset_metadata={
"response_length": getattr(request, 'response_length', 'medium'),
"tone": request.tone.value if hasattr(request.tone, 'value') else str(request.tone),
"has_alternatives": hasattr(response, 'alternatives') and bool(response.alternatives)
},
subdirectory="comment_responses",
file_extension=".md"
)
except Exception as track_error:
logger.warning(f"Failed to track LinkedIn comment response asset: {track_error}")
logger.info(f"Successfully generated LinkedIn comment response in {duration:.2f} seconds")
return response

View File

@@ -0,0 +1,640 @@
"""API endpoints for Product Marketing Suite."""
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from services.product_marketing import (
ProductMarketingOrchestrator,
BrandDNASyncService,
AssetAuditService,
ChannelPackService,
)
from services.product_marketing.campaign_storage import CampaignStorageService
from services.product_marketing.product_image_service import ProductImageService, ProductImageRequest
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger
from services.database import get_db
from sqlalchemy.orm import Session
logger = get_service_logger("api.product_marketing")
router = APIRouter(prefix="/api/product-marketing", tags=["product-marketing"])
# ====================
# REQUEST MODELS
# ====================
class CampaignCreateRequest(BaseModel):
"""Request to create a new campaign blueprint."""
campaign_name: str = Field(..., description="Campaign name")
goal: str = Field(..., description="Campaign goal (product_launch, awareness, conversion, etc.)")
kpi: Optional[str] = Field(None, description="Key performance indicator")
channels: List[str] = Field(..., description="Target channels (instagram, linkedin, tiktok, etc.)")
product_context: Optional[Dict[str, Any]] = Field(None, description="Product information")
class AssetProposalRequest(BaseModel):
"""Request to generate asset proposals."""
campaign_id: str = Field(..., description="Campaign ID")
product_context: Optional[Dict[str, Any]] = Field(None, description="Product information")
class AssetGenerateRequest(BaseModel):
"""Request to generate a specific asset."""
asset_proposal: Dict[str, Any] = Field(..., description="Asset proposal from generate_proposals")
product_context: Optional[Dict[str, Any]] = Field(None, description="Product information")
class AssetAuditRequest(BaseModel):
"""Request to audit uploaded assets."""
image_base64: str = Field(..., description="Base64 encoded image")
asset_metadata: Optional[Dict[str, Any]] = Field(None, description="Asset metadata")
# ====================
# DEPENDENCY
# ====================
def get_orchestrator() -> ProductMarketingOrchestrator:
"""Get Product Marketing Orchestrator instance."""
return ProductMarketingOrchestrator()
def get_campaign_storage() -> CampaignStorageService:
"""Get Campaign Storage Service instance."""
return CampaignStorageService()
def _require_user_id(current_user: Dict[str, Any], operation: str) -> str:
"""Ensure user_id is available for protected operations."""
user_id = current_user.get("sub") or current_user.get("user_id") or current_user.get("id")
if not user_id:
logger.error(
"[Product Marketing] ❌ Missing user_id for %s operation - blocking request",
operation,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authenticated user required for product marketing operations.",
)
return str(user_id)
# ====================
# CAMPAIGN ENDPOINTS
# ====================
@router.post("/campaigns/validate-preflight", summary="Validate Campaign Pre-flight")
async def validate_campaign_preflight(
request: CampaignCreateRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
orchestrator: ProductMarketingOrchestrator = Depends(get_orchestrator)
):
"""Validate campaign blueprint against subscription limits before creation.
This endpoint:
- Creates a temporary blueprint to estimate costs
- Validates subscription limits
- Returns cost estimates and validation results
- Does NOT save anything to database
"""
try:
user_id = _require_user_id(current_user, "campaign pre-flight validation")
logger.info(f"[Product Marketing] Pre-flight validation for user {user_id}")
# Create temporary blueprint for validation (not saved)
campaign_data = {
"campaign_name": request.campaign_name or "Temporary Campaign",
"goal": request.goal,
"kpi": request.kpi,
"channels": request.channels,
}
blueprint = orchestrator.create_campaign_blueprint(user_id, campaign_data)
# Run pre-flight validation
validation_result = orchestrator.validate_campaign_preflight(user_id, blueprint)
logger.info(f"[Product Marketing] ✅ Pre-flight validation completed: can_proceed={validation_result.get('can_proceed')}")
return validation_result
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error in pre-flight validation: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Pre-flight validation failed: {str(e)}")
@router.post("/campaigns/create-blueprint", summary="Create Campaign Blueprint")
async def create_campaign_blueprint(
request: CampaignCreateRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
orchestrator: ProductMarketingOrchestrator = Depends(get_orchestrator)
):
"""Create a campaign blueprint with personalized asset nodes.
This endpoint:
- Uses onboarding data to personalize the blueprint
- Generates campaign phases (teaser, launch, nurture)
- Creates asset nodes for each phase and channel
- Returns blueprint ready for AI proposal generation
"""
try:
user_id = _require_user_id(current_user, "campaign blueprint creation")
logger.info(f"[Product Marketing] Creating blueprint for user {user_id}: {request.campaign_name}")
campaign_data = {
"campaign_name": request.campaign_name,
"goal": request.goal,
"kpi": request.kpi,
"channels": request.channels,
}
blueprint = orchestrator.create_campaign_blueprint(user_id, campaign_data)
# Convert blueprint to dict for JSON response
blueprint_dict = {
"campaign_id": blueprint.campaign_id,
"campaign_name": blueprint.campaign_name,
"goal": blueprint.goal,
"kpi": blueprint.kpi,
"phases": blueprint.phases,
"asset_nodes": [
{
"asset_id": node.asset_id,
"asset_type": node.asset_type,
"channel": node.channel,
"status": node.status,
}
for node in blueprint.asset_nodes
],
"channels": blueprint.channels,
"status": blueprint.status,
}
# Save to database
campaign_storage = get_campaign_storage()
campaign_storage.save_campaign(user_id, blueprint_dict)
logger.info(f"[Product Marketing] ✅ Blueprint created and saved: {blueprint.campaign_id}")
return blueprint_dict
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error creating blueprint: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Campaign blueprint creation failed: {str(e)}")
@router.post("/campaigns/{campaign_id}/generate-proposals", summary="Generate Asset Proposals")
async def generate_asset_proposals(
campaign_id: str,
request: AssetProposalRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
orchestrator: ProductMarketingOrchestrator = Depends(get_orchestrator)
):
"""Generate AI proposals for all assets in a campaign blueprint.
This endpoint:
- Uses specialized marketing prompts with brand DNA
- Recommends templates, providers, and settings
- Provides cost estimates
- Returns proposals ready for user approval
"""
try:
user_id = _require_user_id(current_user, "asset proposal generation")
logger.info(f"[Product Marketing] Generating proposals for campaign {campaign_id}")
# Fetch blueprint from database
campaign_storage = get_campaign_storage()
campaign = campaign_storage.get_campaign(user_id, campaign_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
# Reconstruct blueprint from database
from services.product_marketing.orchestrator import CampaignBlueprint, CampaignAssetNode
asset_nodes = []
if campaign.asset_nodes:
for node_data in campaign.asset_nodes:
asset_nodes.append(CampaignAssetNode(
asset_id=node_data.get('asset_id'),
asset_type=node_data.get('asset_type'),
channel=node_data.get('channel'),
status=node_data.get('status', 'draft'),
))
blueprint = CampaignBlueprint(
campaign_id=campaign.campaign_id,
campaign_name=campaign.campaign_name,
goal=campaign.goal,
kpi=campaign.kpi,
channels=campaign.channels or [],
asset_nodes=asset_nodes,
)
proposals = orchestrator.generate_asset_proposals(
user_id=user_id,
blueprint=blueprint,
product_context=request.product_context,
)
# Save proposals to database
try:
campaign_storage.save_proposals(user_id, campaign_id, proposals)
logger.info(f"[Product Marketing] ✅ Saved {proposals['total_assets']} proposals to database")
except Exception as save_error:
logger.error(f"[Product Marketing] ⚠️ Failed to save proposals to database: {str(save_error)}")
# Continue even if save fails - proposals are still returned to user
# This allows the workflow to continue, but proposals won't persist
logger.info(f"[Product Marketing] ✅ Generated {proposals['total_assets']} proposals")
return proposals
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error generating proposals: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Asset proposal generation failed: {str(e)}")
@router.post("/assets/generate", summary="Generate Asset")
async def generate_asset(
request: AssetGenerateRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
orchestrator: ProductMarketingOrchestrator = Depends(get_orchestrator)
):
"""Generate a single asset using Image Studio APIs.
This endpoint:
- Reuses existing Image Studio APIs
- Applies specialized marketing prompts
- Automatically tracks assets in Asset Library
- Validates subscription limits
"""
try:
user_id = _require_user_id(current_user, "asset generation")
logger.info(f"[Product Marketing] Generating asset for user {user_id}")
result = await orchestrator.generate_asset(
user_id=user_id,
asset_proposal=request.asset_proposal,
product_context=request.product_context,
)
logger.info(f"[Product Marketing] ✅ Asset generated successfully")
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error generating asset: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Asset generation failed: {str(e)}")
# ====================
# BRAND DNA ENDPOINTS
# ====================
@router.get("/brand-dna", summary="Get Brand DNA Tokens")
async def get_brand_dna(
current_user: Dict[str, Any] = Depends(get_current_user),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Get brand DNA tokens for the authenticated user.
Returns normalized brand DNA from onboarding and persona data.
"""
try:
user_id = _require_user_id(current_user, "brand DNA retrieval")
brand_tokens = brand_dna_sync.get_brand_dna_tokens(user_id)
return {"brand_dna": brand_tokens}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error getting brand DNA: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/brand-dna/channel/{channel}", summary="Get Channel-Specific Brand DNA")
async def get_channel_brand_dna(
channel: str,
current_user: Dict[str, Any] = Depends(get_current_user),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Get channel-specific brand DNA adaptations."""
try:
user_id = _require_user_id(current_user, "channel brand DNA retrieval")
channel_dna = brand_dna_sync.get_channel_specific_dna(user_id, channel)
return {"channel": channel, "brand_dna": channel_dna}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error getting channel DNA: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# ====================
# ASSET AUDIT ENDPOINTS
# ====================
@router.post("/assets/audit", summary="Audit Asset")
async def audit_asset(
request: AssetAuditRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
asset_audit: AssetAuditService = Depends(lambda: AssetAuditService())
):
"""Audit an uploaded asset and get enhancement recommendations."""
try:
user_id = _require_user_id(current_user, "asset audit")
audit_result = asset_audit.audit_asset(
request.image_base64,
request.asset_metadata,
)
return audit_result
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error auditing asset: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# ====================
# CHANNEL PACK ENDPOINTS
# ====================
@router.get("/channels/{channel}/pack", summary="Get Channel Pack")
async def get_channel_pack(
channel: str,
asset_type: str = "social_post",
current_user: Dict[str, Any] = Depends(get_current_user),
channel_pack: ChannelPackService = Depends(lambda: ChannelPackService())
):
"""Get channel-specific pack configuration with templates and optimization tips."""
try:
pack = channel_pack.get_channel_pack(channel, asset_type)
return pack
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error getting channel pack: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# ====================
# CAMPAIGN LISTING & RETRIEVAL
# ====================
@router.get("/campaigns", summary="List Campaigns")
async def list_campaigns(
status: Optional[str] = None,
current_user: Dict[str, Any] = Depends(get_current_user),
campaign_storage: CampaignStorageService = Depends(get_campaign_storage)
):
"""List all campaigns for the authenticated user."""
try:
user_id = _require_user_id(current_user, "list campaigns")
campaigns = campaign_storage.list_campaigns(user_id, status=status)
return {
"campaigns": [
{
"campaign_id": c.campaign_id,
"campaign_name": c.campaign_name,
"goal": c.goal,
"kpi": c.kpi,
"status": c.status,
"channels": c.channels,
"phases": c.phases,
"asset_nodes": c.asset_nodes,
"created_at": c.created_at.isoformat() if c.created_at else None,
"updated_at": c.updated_at.isoformat() if c.updated_at else None,
}
for c in campaigns
],
"total": len(campaigns),
}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error listing campaigns: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/campaigns/{campaign_id}", summary="Get Campaign")
async def get_campaign(
campaign_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
campaign_storage: CampaignStorageService = Depends(get_campaign_storage)
):
"""Get a specific campaign by ID."""
try:
user_id = _require_user_id(current_user, "get campaign")
campaign = campaign_storage.get_campaign(user_id, campaign_id)
if not campaign:
raise HTTPException(status_code=404, detail="Campaign not found")
return {
"campaign_id": campaign.campaign_id,
"campaign_name": campaign.campaign_name,
"goal": campaign.goal,
"kpi": campaign.kpi,
"status": campaign.status,
"channels": campaign.channels,
"phases": campaign.phases,
"asset_nodes": campaign.asset_nodes,
"product_context": campaign.product_context,
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
"updated_at": campaign.updated_at.isoformat() if campaign.updated_at else None,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error getting campaign: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/campaigns/{campaign_id}/proposals", summary="Get Campaign Proposals")
async def get_campaign_proposals(
campaign_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
campaign_storage: CampaignStorageService = Depends(get_campaign_storage)
):
"""Get proposals for a campaign."""
try:
user_id = _require_user_id(current_user, "get proposals")
proposals = campaign_storage.get_proposals(user_id, campaign_id)
proposals_dict = {}
for proposal in proposals:
proposals_dict[proposal.asset_node_id] = {
"asset_id": proposal.asset_node_id,
"asset_type": proposal.asset_type,
"channel": proposal.channel,
"proposed_prompt": proposal.proposed_prompt,
"recommended_template": proposal.recommended_template,
"recommended_provider": proposal.recommended_provider,
"cost_estimate": proposal.cost_estimate,
"concept_summary": proposal.concept_summary,
"status": proposal.status,
}
return {
"proposals": proposals_dict,
"total_assets": len(proposals),
}
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error getting proposals: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# ====================
# PRODUCT ASSET ENDPOINTS (Product Marketing Suite - Product Assets)
# ====================
class ProductPhotoshootRequest(BaseModel):
"""Request for product image photoshoot generation."""
product_name: str = Field(..., description="Product name")
product_description: str = Field(..., description="Product description")
environment: str = Field(default="studio", description="Environment: studio, lifestyle, outdoor, minimalist, luxury")
background_style: str = Field(default="white", description="Background: white, transparent, lifestyle, branded")
lighting: str = Field(default="natural", description="Lighting: natural, studio, dramatic, soft")
product_variant: Optional[str] = Field(None, description="Product variant (color, size, etc.)")
angle: Optional[str] = Field(None, description="Product angle: front, side, top, 360")
style: str = Field(default="photorealistic", description="Style: photorealistic, minimalist, luxury, technical")
resolution: str = Field(default="1024x1024", description="Resolution (e.g., 1024x1024, 1280x720)")
num_variations: int = Field(default=1, description="Number of variations to generate")
brand_colors: Optional[List[str]] = Field(None, description="Brand color palette")
additional_context: Optional[str] = Field(None, description="Additional context for generation")
def get_product_image_service() -> ProductImageService:
"""Get Product Image Service instance."""
return ProductImageService()
@router.post("/products/photoshoot", summary="Generate Product Image")
async def generate_product_image(
request: ProductPhotoshootRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
product_image_service: ProductImageService = Depends(get_product_image_service),
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
):
"""Generate professional product images using AI.
This endpoint:
- Generates product images optimized for e-commerce
- Supports multiple environments and styles
- Integrates with brand DNA for personalization
- Automatically saves to Asset Library
"""
try:
user_id = _require_user_id(current_user, "product image generation")
logger.info(f"[Product Marketing] Generating product image for '{request.product_name}'")
# Get brand DNA for personalization
brand_context = None
try:
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
brand_context = {
"visual_identity": brand_dna.get("visual_identity", {}),
"persona": brand_dna.get("persona", {}),
}
except Exception as brand_error:
logger.warning(f"[Product Marketing] Could not load brand DNA: {str(brand_error)}")
# Convert request to service request
service_request = ProductImageRequest(
product_name=request.product_name,
product_description=request.product_description,
environment=request.environment,
background_style=request.background_style,
lighting=request.lighting,
product_variant=request.product_variant,
angle=request.angle,
style=request.style,
resolution=request.resolution,
num_variations=request.num_variations,
brand_colors=request.brand_colors,
additional_context=request.additional_context,
)
# Generate product image
result = await product_image_service.generate_product_image(
request=service_request,
user_id=user_id,
brand_context=brand_context,
)
if not result.success:
raise HTTPException(status_code=500, detail=result.error or "Product image generation failed")
logger.info(f"[Product Marketing] ✅ Generated product image: {result.asset_id}")
# Return result (image_bytes will be served via separate endpoint)
return {
"success": True,
"product_name": result.product_name,
"image_url": result.image_url,
"asset_id": result.asset_id,
"provider": result.provider,
"model": result.model,
"cost": result.cost,
"generation_time": result.generation_time,
}
except HTTPException:
raise
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error generating product image: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Product image generation failed: {str(e)}")
@router.get("/products/images/{filename}", summary="Serve Product Image")
async def serve_product_image(
filename: str,
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Serve generated product images."""
try:
from fastapi.responses import FileResponse
from pathlib import Path
_require_user_id(current_user, "serving product image")
# Locate image file
base_dir = Path(__file__).parent.parent.parent
image_path = base_dir / "product_images" / filename
if not image_path.exists():
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(
path=str(image_path),
media_type="image/png",
filename=filename
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Product Marketing] ❌ Error serving product image: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# ====================
# HEALTH CHECK
# ====================
@router.get("/health", summary="Health Check")
async def health_check():
"""Health check endpoint for Product Marketing Suite."""
return {
"status": "healthy",
"service": "product_marketing",
"version": "1.0.0",
"modules": {
"orchestrator": "available",
"prompt_builder": "available",
"brand_dna_sync": "available",
"asset_audit": "available",
"channel_pack": "available",
"product_image_service": "available",
}
}

View File

@@ -0,0 +1,88 @@
"""
Database Migration Script for Product Asset Tables
Creates all tables needed for Product Marketing Suite (product asset creation).
These tables are separate from campaign-related tables and focus on product-specific assets.
"""
import sys
import os
from pathlib import Path
# Add the backend directory to Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.orm import sessionmaker
from loguru import logger
import traceback
# Import models - Product Asset models use SubscriptionBase
from models.subscription_models import Base as SubscriptionBase
from models.product_asset_models import ProductAsset, ProductStyleTemplate, EcommerceExport
from services.database import DATABASE_URL
def create_product_asset_tables():
"""Create all product asset tables."""
try:
# Create engine
engine = create_engine(DATABASE_URL, echo=False)
# Create all tables (product asset models share SubscriptionBase)
logger.info("Creating product asset tables for Product Marketing Suite...")
SubscriptionBase.metadata.create_all(bind=engine)
logger.info("✅ Product asset tables created successfully")
# Verify tables were created
with engine.connect() as conn:
# Check if tables exist
inspector = inspect(engine)
tables = inspector.get_table_names()
expected_tables = [
'product_assets',
'product_style_templates',
'product_ecommerce_exports'
]
created_tables = [t for t in expected_tables if t in tables]
missing_tables = [t for t in expected_tables if t not in tables]
if created_tables:
logger.info(f"✅ Created tables: {', '.join(created_tables)}")
if missing_tables:
logger.warning(f"⚠️ Missing tables: {', '.join(missing_tables)}")
else:
logger.info("🎉 All product asset tables verified!")
# Verify indexes were created
with engine.connect() as conn:
inspector = inspect(engine)
# Check ProductAsset indexes
product_asset_indexes = inspector.get_indexes('product_assets')
logger.info(f"✅ ProductAsset indexes: {len(product_asset_indexes)} indexes created")
# Check ProductStyleTemplate indexes
style_template_indexes = inspector.get_indexes('product_style_templates')
logger.info(f"✅ ProductStyleTemplate indexes: {len(style_template_indexes)} indexes created")
# Check EcommerceExport indexes
ecommerce_export_indexes = inspector.get_indexes('product_ecommerce_exports')
logger.info(f"✅ EcommerceExport indexes: {len(ecommerce_export_indexes)} indexes created")
return True
except Exception as e:
logger.error(f"❌ Error creating product asset tables: {e}")
logger.error(traceback.format_exc())
return False
if __name__ == "__main__":
success = create_product_asset_tables()
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,71 @@
"""
Database Migration Script for Product Marketing Suite
Creates all tables needed for campaigns, proposals, and generated assets.
"""
import sys
import os
from pathlib import Path
# Add the backend directory to Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.orm import sessionmaker
from loguru import logger
import traceback
# Import models - Product Marketing uses SubscriptionBase
# Import the Base first, then import product marketing models to register them
from models.subscription_models import Base as SubscriptionBase
from models.product_marketing_models import Campaign, CampaignProposal, CampaignAsset
from services.database import DATABASE_URL
def create_product_marketing_tables():
"""Create all product marketing tables."""
try:
# Create engine
engine = create_engine(DATABASE_URL, echo=False)
# Create all tables (product marketing models share SubscriptionBase)
logger.info("Creating product marketing tables...")
SubscriptionBase.metadata.create_all(bind=engine)
logger.info("✅ Product marketing tables created successfully")
# Verify tables were created
with engine.connect() as conn:
# Check if tables exist
from sqlalchemy import inspect as sqlalchemy_inspect
inspector = sqlalchemy_inspect(engine)
tables = inspector.get_table_names()
expected_tables = [
'product_marketing_campaigns',
'product_marketing_proposals',
'product_marketing_assets'
]
created_tables = [t for t in expected_tables if t in tables]
missing_tables = [t for t in expected_tables if t not in tables]
if created_tables:
logger.info(f"✅ Created tables: {', '.join(created_tables)}")
if missing_tables:
logger.warning(f"⚠️ Missing tables: {', '.join(missing_tables)}")
else:
logger.info("🎉 All product marketing tables verified!")
return True
except Exception as e:
logger.error(f"❌ Error creating product marketing tables: {e}")
logger.error(traceback.format_exc())
return False
if __name__ == "__main__":
success = create_product_marketing_tables()
sys.exit(0 if success else 1)

View File

@@ -38,7 +38,7 @@ class ContentAssetService:
description: Optional[str] = None,
prompt: Optional[str] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
asset_metadata: Optional[Dict[str, Any]] = None,
provider: Optional[str] = None,
model: Optional[str] = None,
cost: Optional[float] = None,
@@ -60,7 +60,7 @@ class ContentAssetService:
description: Asset description (optional)
prompt: Generation prompt (optional)
tags: List of tags (optional)
metadata: Additional metadata (optional)
asset_metadata: Additional metadata (optional)
provider: AI provider used (optional)
model: Model used (optional)
cost: Generation cost (optional)
@@ -83,7 +83,7 @@ class ContentAssetService:
description=description,
prompt=prompt,
tags=tags or [],
metadata=metadata or {},
asset_metadata=asset_metadata or {},
provider=provider,
model=model,
cost=cost or 0.0,
@@ -222,7 +222,7 @@ class ContentAssetService:
title: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
asset_metadata: Optional[Dict[str, Any]] = None,
) -> Optional[ContentAsset]:
"""Update asset metadata."""
try:
@@ -236,8 +236,8 @@ class ContentAssetService:
asset.description = description
if tags is not None:
asset.tags = tags
if metadata is not None:
asset.metadata = {**(asset.metadata or {}), **metadata}
if asset_metadata is not None:
asset.asset_metadata = {**(asset.asset_metadata or {}), **asset_metadata}
asset.updated_at = datetime.utcnow()
self.db.commit()

View File

@@ -21,6 +21,10 @@ from models.persona_models import Base as PersonaBase
from models.subscription_models import Base as SubscriptionBase
from models.user_business_info import Base as UserBusinessInfoBase
from models.content_asset_models import Base as ContentAssetBase
# Product Marketing models use SubscriptionBase, but import to ensure models are registered
from models.product_marketing_models import Campaign, CampaignProposal, CampaignAsset
# Product Asset models (Product Marketing Suite - product assets, not campaigns)
from models.product_asset_models import ProductAsset, ProductStyleTemplate, EcommerceExport
# Database configuration
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./alwrity.db')
@@ -73,10 +77,10 @@ def init_database():
EnhancedStrategyBase.metadata.create_all(bind=engine)
MonitoringBase.metadata.create_all(bind=engine)
PersonaBase.metadata.create_all(bind=engine)
SubscriptionBase.metadata.create_all(bind=engine)
SubscriptionBase.metadata.create_all(bind=engine) # Includes product_marketing models
UserBusinessInfoBase.metadata.create_all(bind=engine)
ContentAssetBase.metadata.create_all(bind=engine)
logger.info("Database initialized successfully with all models including subscription system, business info, and content assets")
logger.info("Database initialized successfully with all models including subscription system, product marketing, business info, and content assets")
except SQLAlchemyError as e:
logger.error(f"Error initializing database: {str(e)}")
raise

View File

@@ -6,6 +6,11 @@ from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .transform_service import (
TransformStudioService,
TransformImageToVideoRequest,
TalkingAvatarRequest,
)
from .templates import PlatformTemplates, TemplateManager
__all__ = [
@@ -20,6 +25,9 @@ __all__ = [
"ControlStudioRequest",
"SocialOptimizerService",
"SocialOptimizerRequest",
"TransformStudioService",
"TransformImageToVideoRequest",
"TalkingAvatarRequest",
"PlatformTemplates",
"TemplateManager",
]

View File

@@ -0,0 +1,155 @@
"""InfiniteTalk adapter for Transform Studio."""
import asyncio
from typing import Any, Dict, Optional
from fastapi import HTTPException
from loguru import logger
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.infinitetalk")
class InfiniteTalkService:
"""Adapter for InfiniteTalk in Transform Studio context."""
def __init__(self, client: Optional[WaveSpeedClient] = None):
"""Initialize InfiniteTalk service adapter."""
self.client = client or WaveSpeedClient()
logger.info("[InfiniteTalk Adapter] Service initialized")
def calculate_cost(self, resolution: str, duration: float) -> float:
"""Calculate cost for InfiniteTalk video.
Args:
resolution: Output resolution (480p or 720p)
duration: Video duration in seconds
Returns:
Cost in USD
"""
# InfiniteTalk pricing: $0.03/s (480p) or $0.06/s (720p)
# Minimum charge: 5 seconds
cost_per_second = 0.03 if resolution == "480p" else 0.06
actual_duration = max(5.0, duration) # Minimum 5 seconds
return cost_per_second * actual_duration
async def create_talking_avatar(
self,
image_base64: str,
audio_base64: str,
resolution: str = "720p",
prompt: Optional[str] = None,
mask_image_base64: Optional[str] = None,
seed: Optional[int] = None,
user_id: str = "transform_studio",
) -> Dict[str, Any]:
"""Create talking avatar video using InfiniteTalk.
Args:
image_base64: Person image in base64 or data URI
audio_base64: Audio file in base64 or data URI
resolution: Output resolution (480p or 720p)
prompt: Optional prompt for expression/style
mask_image_base64: Optional mask for animatable regions
seed: Optional random seed
user_id: User ID for tracking
Returns:
Dictionary with video bytes, metadata, and cost
"""
# Validate resolution
if resolution not in ["480p", "720p"]:
raise HTTPException(
status_code=400,
detail="Resolution must be '480p' or '720p' for InfiniteTalk"
)
# Decode image
import base64
try:
if image_base64.startswith("data:"):
if "," not in image_base64:
raise ValueError("Invalid data URI format: missing comma separator")
header, encoded = image_base64.split(",", 1)
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
image_mime = mime_parts.strip() or "image/png"
image_bytes = base64.b64decode(encoded)
else:
image_bytes = base64.b64decode(image_base64)
image_mime = "image/png"
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to decode image: {str(e)}"
)
# Decode audio
try:
if audio_base64.startswith("data:"):
if "," not in audio_base64:
raise ValueError("Invalid data URI format: missing comma separator")
header, encoded = audio_base64.split(",", 1)
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "audio/mpeg"
audio_mime = mime_parts.strip() or "audio/mpeg"
audio_bytes = base64.b64decode(encoded)
else:
audio_bytes = base64.b64decode(audio_base64)
audio_mime = "audio/mpeg"
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to decode audio: {str(e)}"
)
# Call existing InfiniteTalk function (run in thread since it's synchronous)
# Note: We pass empty dicts for scene_data and story_context since
# Transform Studio doesn't have story context
try:
result = await asyncio.to_thread(
animate_scene_with_voiceover,
image_bytes=image_bytes,
audio_bytes=audio_bytes,
scene_data={}, # Empty for Transform Studio
story_context={}, # Empty for Transform Studio
user_id=user_id,
resolution=resolution,
prompt_override=prompt,
image_mime=image_mime,
audio_mime=audio_mime,
client=self.client,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[InfiniteTalk Adapter] Error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"InfiniteTalk generation failed: {str(e)}"
)
# Calculate actual cost based on duration
actual_cost = self.calculate_cost(resolution, result.get("duration", 5.0))
# Update result with actual cost and additional metadata
result["cost"] = actual_cost
result["resolution"] = resolution
# Get video dimensions from resolution
resolution_dims = {
"480p": (854, 480),
"720p": (1280, 720),
}
width, height = resolution_dims.get(resolution, (1280, 720))
result["width"] = width
result["height"] = height
logger.info(
f"[InfiniteTalk Adapter] ✅ Generated talking avatar: "
f"resolution={resolution}, duration={result.get('duration', 5.0)}s, cost=${actual_cost:.2f}"
)
return result

View File

@@ -7,6 +7,11 @@ from .edit_service import EditStudioService, EditStudioRequest
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
from .control_service import ControlStudioService, ControlStudioRequest
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
from .transform_service import (
TransformStudioService,
TransformImageToVideoRequest,
TalkingAvatarRequest,
)
from .templates import Platform, TemplateCategory, ImageTemplate
from utils.logger_utils import get_service_logger
@@ -24,6 +29,7 @@ class ImageStudioManager:
self.upscale_service = UpscaleStudioService()
self.control_service = ControlStudioService()
self.social_optimizer_service = SocialOptimizerService()
self.transform_service = TransformStudioService()
logger.info("[Image Studio Manager] Initialized successfully")
# ====================
@@ -339,4 +345,35 @@ class ImageStudioManager:
}
return specs.get(platform, {})
# ====================
# TRANSFORM STUDIO
# ====================
async def transform_image_to_video(
self,
request: TransformImageToVideoRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Transform image to video using WAN 2.5."""
logger.info("[Image Studio] Transform image-to-video request from user: %s", user_id)
return await self.transform_service.transform_image_to_video(request, user_id=user_id or "anonymous")
async def create_talking_avatar(
self,
request: TalkingAvatarRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Create talking avatar using InfiniteTalk."""
logger.info("[Image Studio] Talking avatar request from user: %s", user_id)
return await self.transform_service.create_talking_avatar(request, user_id=user_id or "anonymous")
def estimate_transform_cost(
self,
operation: str,
resolution: str,
duration: Optional[int] = None,
) -> Dict[str, Any]:
"""Estimate cost for transform operation."""
return self.transform_service.estimate_cost(operation, resolution, duration)

View File

@@ -0,0 +1,379 @@
"""Transform Studio service for image-to-video and talking avatar generation."""
import os
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from dataclasses import dataclass
from fastapi import HTTPException
from loguru import logger
from .wan25_service import WAN25Service
from .infinitetalk_adapter import InfiniteTalkService
from services.llm_providers.main_video_generation import track_video_usage
from utils.logger_utils import get_service_logger
from utils.file_storage import save_file_safely, sanitize_filename
logger = get_service_logger("image_studio.transform")
@dataclass
class TransformImageToVideoRequest:
"""Request for WAN 2.5 image-to-video."""
image_base64: str
prompt: str
audio_base64: Optional[str] = None
resolution: str = "720p" # 480p, 720p, 1080p
duration: int = 5 # 5 or 10 seconds
negative_prompt: Optional[str] = None
seed: Optional[int] = None
enable_prompt_expansion: bool = True
@dataclass
class TalkingAvatarRequest:
"""Request for InfiniteTalk talking avatar."""
image_base64: str
audio_base64: str
resolution: str = "720p" # 480p or 720p
prompt: Optional[str] = None
mask_image_base64: Optional[str] = None
seed: Optional[int] = None
class TransformStudioService:
"""Service for Transform Studio operations."""
def __init__(self):
"""Initialize Transform Studio service."""
self.wan25_service = WAN25Service()
self.infinitetalk_service = InfiniteTalkService()
# Video output directory
# __file__ is: backend/services/image_studio/transform_service.py
# We need: backend/transform_videos
base_dir = Path(__file__).parent.parent.parent.parent
self.output_dir = base_dir / "transform_videos"
self.output_dir.mkdir(parents=True, exist_ok=True)
# Verify directory was created
if not self.output_dir.exists():
raise RuntimeError(f"Failed to create transform_videos directory: {self.output_dir}")
logger.info(f"[Transform Studio] Initialized with output directory: {self.output_dir}")
def _save_video_file(
self,
video_bytes: bytes,
operation_type: str,
user_id: str,
) -> Dict[str, Any]:
"""Save video file to disk.
Args:
video_bytes: Video content as bytes
operation_type: Type of operation (e.g., "image-to-video", "talking-avatar")
user_id: User ID for directory organization
Returns:
Dictionary with filename, file_path, and file_url
"""
# Create user-specific directory
user_dir = self.output_dir / user_id
user_dir.mkdir(parents=True, exist_ok=True)
# Generate filename
filename = f"{operation_type}_{uuid.uuid4().hex[:8]}.mp4"
filename = sanitize_filename(filename)
# Save file
file_path, error = save_file_safely(
content=video_bytes,
directory=user_dir,
filename=filename,
max_file_size=500 * 1024 * 1024 # 500MB max for videos
)
if error:
raise HTTPException(
status_code=500,
detail=f"Failed to save video file: {error}"
)
file_url = f"/api/image-studio/videos/{user_id}/{filename}"
return {
"filename": filename,
"file_path": str(file_path),
"file_url": file_url,
"file_size": len(video_bytes),
}
async def transform_image_to_video(
self,
request: TransformImageToVideoRequest,
user_id: str,
) -> Dict[str, Any]:
"""Transform image to video using WAN 2.5.
Args:
request: Transform request
user_id: User ID for tracking and file organization
Returns:
Dictionary with video URL, metadata, and cost
"""
logger.info(
f"[Transform Studio] Image-to-video request from user {user_id}: "
f"resolution={request.resolution}, duration={request.duration}s"
)
# Generate video using WAN 2.5
result = await self.wan25_service.generate_video(
image_base64=request.image_base64,
prompt=request.prompt,
audio_base64=request.audio_base64,
resolution=request.resolution,
duration=request.duration,
negative_prompt=request.negative_prompt,
seed=request.seed,
enable_prompt_expansion=request.enable_prompt_expansion,
)
# Save video to disk
save_result = self._save_video_file(
video_bytes=result["video_bytes"],
operation_type="image-to-video",
user_id=user_id,
)
# Track usage
try:
usage_info = track_video_usage(
user_id=user_id,
provider=result["provider"],
model_name=result["model_name"],
prompt=result["prompt"],
video_bytes=result["video_bytes"],
cost_override=result["cost"],
)
logger.info(
f"[Transform Studio] Usage tracked: {usage_info.get('current_calls', 0)} / "
f"{usage_info.get('video_limit_display', '')} videos, "
f"cost=${result['cost']:.2f}"
)
except Exception as e:
logger.warning(f"[Transform Studio] Failed to track usage: {e}")
# Save to asset library
try:
from services.database import get_db
from utils.asset_tracker import save_asset_to_library
db = next(get_db())
try:
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="video",
source_module="image_studio",
filename=save_result["filename"],
file_url=save_result["file_url"],
file_path=save_result["file_path"],
file_size=save_result["file_size"],
mime_type="video/mp4",
title=f"Transform: Image-to-Video ({request.resolution})",
description=f"Generated video using WAN 2.5: {request.prompt[:100]}",
prompt=result["prompt"],
tags=["image_studio", "transform", "video", "image-to-video", request.resolution],
provider=result["provider"],
model=result["model_name"],
cost=result["cost"],
asset_metadata={
"resolution": request.resolution,
"duration": result["duration"],
"operation": "image-to-video",
"width": result["width"],
"height": result["height"],
}
)
logger.info(f"[Transform Studio] Video saved to asset library")
finally:
db.close()
except Exception as e:
logger.warning(f"[Transform Studio] Failed to save to asset library: {e}")
return {
"success": True,
"video_url": save_result["file_url"],
"video_base64": None, # Don't include base64 for large videos
"duration": result["duration"],
"resolution": result["resolution"],
"width": result["width"],
"height": result["height"],
"file_size": save_result["file_size"],
"cost": result["cost"],
"provider": result["provider"],
"model": result["model_name"],
"metadata": result.get("metadata", {}),
}
async def create_talking_avatar(
self,
request: TalkingAvatarRequest,
user_id: str,
) -> Dict[str, Any]:
"""Create talking avatar using InfiniteTalk.
Args:
request: Talking avatar request
user_id: User ID for tracking and file organization
Returns:
Dictionary with video URL, metadata, and cost
"""
logger.info(
f"[Transform Studio] Talking avatar request from user {user_id}: "
f"resolution={request.resolution}"
)
# Generate video using InfiniteTalk
result = await self.infinitetalk_service.create_talking_avatar(
image_base64=request.image_base64,
audio_base64=request.audio_base64,
resolution=request.resolution,
prompt=request.prompt,
mask_image_base64=request.mask_image_base64,
seed=request.seed,
user_id=user_id,
)
# Save video to disk
save_result = self._save_video_file(
video_bytes=result["video_bytes"],
operation_type="talking-avatar",
user_id=user_id,
)
# Track usage
try:
usage_info = track_video_usage(
user_id=user_id,
provider=result["provider"],
model_name=result["model_name"],
prompt=result.get("prompt", ""),
video_bytes=result["video_bytes"],
cost_override=result["cost"],
)
logger.info(
f"[Transform Studio] Usage tracked: {usage_info.get('current_calls', 0)} / "
f"{usage_info.get('video_limit_display', '')} videos, "
f"cost=${result['cost']:.2f}"
)
except Exception as e:
logger.warning(f"[Transform Studio] Failed to track usage: {e}")
# Save to asset library
try:
from services.database import get_db
from utils.asset_tracker import save_asset_to_library
db = next(get_db())
try:
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="video",
source_module="image_studio",
filename=save_result["filename"],
file_url=save_result["file_url"],
file_path=save_result["file_path"],
file_size=save_result["file_size"],
mime_type="video/mp4",
title=f"Transform: Talking Avatar ({request.resolution})",
description="Generated talking avatar video using InfiniteTalk",
prompt=result.get("prompt", ""),
tags=["image_studio", "transform", "video", "talking-avatar", request.resolution],
provider=result["provider"],
model=result["model_name"],
cost=result["cost"],
asset_metadata={
"resolution": request.resolution,
"duration": result.get("duration", 5.0),
"operation": "talking-avatar",
"width": result.get("width", 1280),
"height": result.get("height", 720),
}
)
logger.info(f"[Transform Studio] Video saved to asset library")
finally:
db.close()
except Exception as e:
logger.warning(f"[Transform Studio] Failed to save to asset library: {e}")
return {
"success": True,
"video_url": save_result["file_url"],
"video_base64": None, # Don't include base64 for large videos
"duration": result.get("duration", 5.0),
"resolution": result.get("resolution", request.resolution),
"width": result.get("width", 1280),
"height": result.get("height", 720),
"file_size": save_result["file_size"],
"cost": result["cost"],
"provider": result["provider"],
"model": result["model_name"],
"metadata": result.get("metadata", {}),
}
def estimate_cost(
self,
operation: str,
resolution: str,
duration: Optional[int] = None,
) -> Dict[str, Any]:
"""Estimate cost for transform operation.
Args:
operation: Operation type ("image-to-video" or "talking-avatar")
resolution: Output resolution
duration: Video duration in seconds (for image-to-video)
Returns:
Cost estimation details
"""
if operation == "image-to-video":
if duration is None:
duration = 5
cost = self.wan25_service.calculate_cost(resolution, duration)
return {
"estimated_cost": cost,
"breakdown": {
"base_cost": 0.0,
"per_second": self.wan25_service.calculate_cost(resolution, 1),
"duration": duration,
"total": cost,
},
"currency": "USD",
"provider": "wavespeed",
"model": "alibaba/wan-2.5/image-to-video",
}
elif operation == "talking-avatar":
# InfiniteTalk minimum is 5 seconds
estimated_duration = duration or 5.0
cost = self.infinitetalk_service.calculate_cost(resolution, estimated_duration)
return {
"estimated_cost": cost,
"breakdown": {
"base_cost": 0.0,
"per_second": self.infinitetalk_service.calculate_cost(resolution, 1.0),
"duration": estimated_duration,
"total": cost,
},
"currency": "USD",
"provider": "wavespeed",
"model": "wavespeed-ai/infinitetalk",
}
else:
raise ValueError(f"Unknown operation: {operation}")

View File

@@ -0,0 +1,295 @@
"""WAN 2.5 service for Alibaba image-to-video generation via WaveSpeed."""
import base64
import asyncio
from typing import Any, Dict, Optional
import requests
from fastapi import HTTPException
from loguru import logger
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.wan25")
WAN25_MODEL_PATH = "alibaba/wan-2.5/image-to-video"
WAN25_MODEL_NAME = "alibaba/wan-2.5/image-to-video"
# Pricing per second (from WaveSpeed docs)
PRICING = {
"480p": 0.05, # $0.05 per second
"720p": 0.10, # $0.10 per second
"1080p": 0.15, # $0.15 per second
}
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10MB (recommended)
MAX_AUDIO_BYTES = 15 * 1024 * 1024 # 15MB (API limit)
MIN_AUDIO_DURATION = 3 # seconds
MAX_AUDIO_DURATION = 30 # seconds
def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
"""Convert bytes to data URI."""
encoded = base64.b64encode(content_bytes).decode("utf-8")
return f"data:{mime_type};base64,{encoded}"
def _decode_base64_image(image_base64: str) -> tuple[bytes, str]:
"""Decode base64 image, handling data URIs."""
if image_base64.startswith("data:"):
# Extract mime type and base64 data
if "," not in image_base64:
raise ValueError("Invalid data URI format: missing comma separator")
header, encoded = image_base64.split(",", 1)
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
mime_type = mime_parts.strip()
if not mime_type:
mime_type = "image/png"
image_bytes = base64.b64decode(encoded)
else:
# Assume it's raw base64
image_bytes = base64.b64decode(image_base64)
mime_type = "image/png" # Default
return image_bytes, mime_type
def _decode_base64_audio(audio_base64: str) -> tuple[bytes, str]:
"""Decode base64 audio, handling data URIs."""
if audio_base64.startswith("data:"):
if "," not in audio_base64:
raise ValueError("Invalid data URI format: missing comma separator")
header, encoded = audio_base64.split(",", 1)
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "audio/mpeg"
mime_type = mime_parts.strip()
if not mime_type:
mime_type = "audio/mpeg"
audio_bytes = base64.b64decode(encoded)
else:
audio_bytes = base64.b64decode(audio_base64)
mime_type = "audio/mpeg" # Default
return audio_bytes, mime_type
class WAN25Service:
"""Service for Alibaba WAN 2.5 image-to-video generation."""
def __init__(self, client: Optional[WaveSpeedClient] = None):
"""Initialize WAN 2.5 service."""
self.client = client or WaveSpeedClient()
logger.info("[WAN 2.5] Service initialized")
def calculate_cost(self, resolution: str, duration: int) -> float:
"""Calculate cost for video generation.
Args:
resolution: Output resolution (480p, 720p, 1080p)
duration: Video duration in seconds (5 or 10)
Returns:
Cost in USD
"""
cost_per_second = PRICING.get(resolution, PRICING["720p"])
return cost_per_second * duration
async def generate_video(
self,
image_base64: str,
prompt: str,
audio_base64: Optional[str] = None,
resolution: str = "720p",
duration: int = 5,
negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
enable_prompt_expansion: bool = True,
) -> Dict[str, Any]:
"""Generate video using WAN 2.5.
Args:
image_base64: Image in base64 or data URI format
prompt: Text prompt describing the video
audio_base64: Optional audio file (wav/mp3, 3-30s, ≤15MB)
resolution: Output resolution (480p, 720p, 1080p)
duration: Video duration in seconds (5 or 10)
negative_prompt: Optional negative prompt
seed: Optional random seed for reproducibility
enable_prompt_expansion: Enable prompt optimizer
Returns:
Dictionary with video bytes, metadata, and cost
"""
# Validate resolution
if resolution not in PRICING:
raise HTTPException(
status_code=400,
detail=f"Invalid resolution: {resolution}. Must be one of: {list(PRICING.keys())}"
)
# Validate duration
if duration not in [5, 10]:
raise HTTPException(
status_code=400,
detail=f"Invalid duration: {duration}. Must be 5 or 10 seconds"
)
# Validate prompt
if not prompt or not prompt.strip():
raise HTTPException(
status_code=400,
detail="Prompt is required and cannot be empty"
)
# Decode image
try:
image_bytes, image_mime = _decode_base64_image(image_base64)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to decode image: {str(e)}"
)
# Validate image size
if len(image_bytes) > MAX_IMAGE_BYTES:
raise HTTPException(
status_code=400,
detail=f"Image exceeds {MAX_IMAGE_BYTES / (1024*1024):.0f}MB limit"
)
# Build payload
payload = {
"image": _as_data_uri(image_bytes, image_mime),
"prompt": prompt,
"resolution": resolution,
"duration": duration,
"enable_prompt_expansion": enable_prompt_expansion,
}
# Add optional audio
if audio_base64:
try:
audio_bytes, audio_mime = _decode_base64_audio(audio_base64)
# Validate audio size
if len(audio_bytes) > MAX_AUDIO_BYTES:
raise HTTPException(
status_code=400,
detail=f"Audio exceeds {MAX_AUDIO_BYTES / (1024*1024):.0f}MB limit"
)
# Note: Audio duration validation would require audio analysis
# For now, we rely on API to handle it (API keeps first 5s/10s if longer)
payload["audio"] = _as_data_uri(audio_bytes, audio_mime)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to decode audio: {str(e)}"
)
# Add optional parameters
if negative_prompt:
payload["negative_prompt"] = negative_prompt
if seed is not None:
payload["seed"] = seed
# Submit to WaveSpeed
logger.info(
f"[WAN 2.5] Submitting video generation request: resolution={resolution}, duration={duration}s"
)
try:
prediction_id = self.client.submit_image_to_video(
WAN25_MODEL_PATH,
payload,
timeout=60
)
except HTTPException as e:
logger.error(f"[WAN 2.5] Submission failed: {e.detail}")
raise
# Poll for completion
logger.info(f"[WAN 2.5] Polling for completion: prediction_id={prediction_id}")
try:
# WAN 2.5 typically takes 1-2 minutes
result = self.client.poll_until_complete(
prediction_id,
timeout_seconds=180, # 3 minutes max
interval_seconds=2.0
)
except HTTPException as e:
detail = e.detail or {}
if isinstance(detail, dict):
detail.setdefault("prediction_id", prediction_id)
detail.setdefault("resume_available", True)
raise HTTPException(status_code=e.status_code, detail=detail)
# Extract video URL
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(
status_code=502,
detail="WAN 2.5 completed but returned no outputs"
)
video_url = outputs[0]
if not isinstance(video_url, str) or not video_url.startswith("http"):
raise HTTPException(
status_code=502,
detail=f"Invalid video URL format: {video_url}"
)
# Download video (run synchronous request in thread)
logger.info(f"[WAN 2.5] Downloading video from: {video_url}")
video_response = await asyncio.to_thread(
requests.get,
video_url,
timeout=180
)
if video_response.status_code != 200:
raise HTTPException(
status_code=502,
detail={
"error": "Failed to download WAN 2.5 video",
"status_code": video_response.status_code,
"response": video_response.text[:200],
}
)
video_bytes = video_response.content
metadata = result.get("metadata") or {}
# Calculate cost
cost = self.calculate_cost(resolution, duration)
# Get video dimensions from resolution
resolution_dims = {
"480p": (854, 480),
"720p": (1280, 720),
"1080p": (1920, 1080),
}
width, height = resolution_dims.get(resolution, (1280, 720))
logger.info(
f"[WAN 2.5] ✅ Generated video: {len(video_bytes)} bytes, "
f"resolution={resolution}, duration={duration}s, cost=${cost:.2f}"
)
return {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": float(duration),
"model_name": WAN25_MODEL_NAME,
"cost": cost,
"provider": "wavespeed",
"source_video_url": video_url,
"prediction_id": prediction_id,
"resolution": resolution,
"width": width,
"height": height,
"metadata": metadata,
}

View File

@@ -0,0 +1,20 @@
"""Product Marketing Suite service package."""
from .orchestrator import ProductMarketingOrchestrator
from .brand_dna_sync import BrandDNASyncService
from .prompt_builder import ProductMarketingPromptBuilder
from .asset_audit import AssetAuditService
from .channel_pack import ChannelPackService
from .campaign_storage import CampaignStorageService
from .product_image_service import ProductImageService
__all__ = [
"ProductMarketingOrchestrator",
"BrandDNASyncService",
"ProductMarketingPromptBuilder",
"AssetAuditService",
"ChannelPackService",
"CampaignStorageService",
"ProductImageService",
]

View File

@@ -0,0 +1,205 @@
"""
Asset Audit Service
Analyzes uploaded assets and recommends enhancement operations.
"""
from typing import Dict, Any, List, Optional
from loguru import logger
import base64
from io import BytesIO
from PIL import Image
class AssetAuditService:
"""Service to audit assets and recommend enhancements."""
def __init__(self):
"""Initialize Asset Audit Service."""
self.logger = logger
logger.info("[Asset Audit] Service initialized")
def audit_asset(
self,
image_base64: str,
asset_metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Audit an uploaded asset and recommend enhancement operations.
Args:
image_base64: Base64 encoded image
asset_metadata: Optional metadata about the asset
Returns:
Audit results with recommendations
"""
try:
# Decode image
image_bytes = self._decode_base64(image_base64)
if not image_bytes:
raise ValueError("Invalid image data")
# Analyze image
image = Image.open(BytesIO(image_bytes))
width, height = image.size
format_type = image.format or "PNG"
mode = image.mode
# Basic quality checks
quality_score = self._assess_quality(image, width, height)
# Generate recommendations
recommendations = []
# Resolution recommendations
if width < 1080 or height < 1080:
recommendations.append({
"operation": "upscale",
"priority": "high",
"reason": f"Image resolution ({width}x{height}) is below recommended 1080p for social media",
"suggested_mode": "fast" if width < 512 else "conservative",
})
# Background recommendations
if mode == "RGBA" and self._has_transparency(image):
recommendations.append({
"operation": "remove_background",
"priority": "low",
"reason": "Image already has transparency, background removal may not be needed",
})
else:
recommendations.append({
"operation": "remove_background",
"priority": "medium",
"reason": "Background removal can create versatile product images",
})
# Enhancement recommendations based on quality
if quality_score < 0.7:
recommendations.append({
"operation": "enhance",
"priority": "high",
"reason": f"Image quality score ({quality_score:.2f}) suggests enhancement needed",
"suggested_operations": ["upscale", "general_edit"],
})
# Format recommendations
if format_type not in ["PNG", "JPEG"]:
recommendations.append({
"operation": "convert",
"priority": "low",
"reason": f"Format {format_type} may not be optimal for web/social media",
"suggested_format": "PNG" if mode == "RGBA" else "JPEG",
})
audit_result = {
"asset_info": {
"width": width,
"height": height,
"format": format_type,
"mode": mode,
"quality_score": quality_score,
},
"recommendations": recommendations,
"status": "usable" if quality_score > 0.6 else "needs_enhancement",
}
logger.info(f"[Asset Audit] Audited asset: {width}x{height}, quality: {quality_score:.2f}")
return audit_result
except Exception as e:
logger.error(f"[Asset Audit] Error auditing asset: {str(e)}")
return {
"asset_info": {},
"recommendations": [],
"status": "error",
"error": str(e),
}
def _decode_base64(self, image_base64: str) -> Optional[bytes]:
"""Decode base64 image data."""
try:
if image_base64.startswith("data:"):
_, b64data = image_base64.split(",", 1)
else:
b64data = image_base64
return base64.b64decode(b64data)
except Exception as e:
logger.error(f"[Asset Audit] Error decoding base64: {str(e)}")
return None
def _has_transparency(self, image: Image.Image) -> bool:
"""Check if image has transparency."""
if image.mode in ("RGBA", "LA"):
alpha = image.split()[-1]
return any(pixel < 255 for pixel in alpha.getdata())
return False
def _assess_quality(self, image: Image.Image, width: int, height: int) -> float:
"""
Assess image quality score (0.0 to 1.0).
Simple heuristic based on resolution and format.
"""
score = 0.5 # Base score
# Resolution scoring
min_dimension = min(width, height)
if min_dimension >= 1080:
score += 0.3
elif min_dimension >= 512:
score += 0.2
elif min_dimension >= 256:
score += 0.1
# Format scoring
if image.format in ["PNG", "JPEG"]:
score += 0.1
# Mode scoring
if image.mode in ["RGB", "RGBA"]:
score += 0.1
return min(score, 1.0)
def batch_audit_assets(
self,
assets: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Audit multiple assets in batch.
Args:
assets: List of asset dictionaries with 'image_base64' and optional 'metadata'
Returns:
Batch audit results
"""
results = []
for asset in assets:
audit_result = self.audit_asset(
asset.get('image_base64'),
asset.get('metadata')
)
results.append({
"asset_id": asset.get('id'),
"audit": audit_result,
})
# Summary statistics
total_assets = len(results)
usable_count = sum(1 for r in results if r["audit"]["status"] == "usable")
needs_enhancement_count = sum(
1 for r in results if r["audit"]["status"] == "needs_enhancement"
)
return {
"results": results,
"summary": {
"total_assets": total_assets,
"usable": usable_count,
"needs_enhancement": needs_enhancement_count,
"error": total_assets - usable_count - needs_enhancement_count,
},
}

View File

@@ -0,0 +1,176 @@
"""
Brand DNA Sync Service
Normalizes persona data and onboarding information into reusable brand tokens.
"""
from typing import Dict, Any, Optional
from loguru import logger
from services.onboarding import OnboardingDatabaseService
from services.database import SessionLocal
class BrandDNASyncService:
"""Service to sync and normalize brand DNA from onboarding and persona data."""
def __init__(self):
"""Initialize Brand DNA Sync Service."""
self.logger = logger
logger.info("[Brand DNA Sync] Service initialized")
def get_brand_dna_tokens(self, user_id: str) -> Dict[str, Any]:
"""
Extract and normalize brand DNA tokens from onboarding and persona data.
Args:
user_id: User ID to fetch data for
Returns:
Dictionary of brand DNA tokens ready for prompt injection
"""
try:
db = SessionLocal()
try:
onboarding_db = OnboardingDatabaseService(db)
website_analysis = onboarding_db.get_website_analysis(user_id, db)
persona_data = onboarding_db.get_persona_data(user_id, db)
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
finally:
db.close()
brand_tokens = {
"writing_style": {},
"target_audience": {},
"visual_identity": {},
"persona": {},
"competitive_positioning": {},
}
# Extract writing style from website analysis
if website_analysis:
writing_style = website_analysis.get('writing_style') or {}
target_audience = website_analysis.get('target_audience') or {}
brand_analysis = website_analysis.get('brand_analysis') or {}
style_guidelines = website_analysis.get('style_guidelines') or {}
# Ensure writing_style is a dict before accessing
if isinstance(writing_style, dict):
brand_tokens["writing_style"] = {
"tone": writing_style.get('tone', 'professional'),
"voice": writing_style.get('voice', 'authoritative'),
"complexity": writing_style.get('complexity', 'intermediate'),
"engagement_level": writing_style.get('engagement_level', 'moderate'),
}
# Ensure target_audience is a dict before accessing
if isinstance(target_audience, dict):
brand_tokens["target_audience"] = {
"demographics": target_audience.get('demographics', []),
"industry_focus": target_audience.get('industry_focus', 'general'),
"expertise_level": target_audience.get('expertise_level', 'intermediate'),
}
# Ensure brand_analysis is a dict before accessing
if isinstance(brand_analysis, dict) and brand_analysis:
brand_tokens["visual_identity"] = {
"color_palette": brand_analysis.get('color_palette', []),
"brand_values": brand_analysis.get('brand_values', []),
"positioning": brand_analysis.get('positioning', ''),
}
# Add style_guidelines if available and visual_identity exists
if style_guidelines and isinstance(style_guidelines, dict):
if "visual_identity" not in brand_tokens:
brand_tokens["visual_identity"] = {}
brand_tokens["visual_identity"]["style_guidelines"] = style_guidelines
# Extract persona data
if persona_data:
core_persona = persona_data.get('corePersona') or {}
platform_personas = persona_data.get('platformPersonas') or {}
# Ensure core_persona is a dict before accessing
if isinstance(core_persona, dict) and core_persona:
brand_tokens["persona"] = {
"persona_name": core_persona.get('persona_name', ''),
"archetype": core_persona.get('archetype', ''),
"core_belief": core_persona.get('core_belief', ''),
"linguistic_fingerprint": core_persona.get('linguistic_fingerprint', {}),
}
# Ensure persona dict exists before setting platform_personas
if "persona" not in brand_tokens:
brand_tokens["persona"] = {}
# Only set platform_personas if it's a valid dict
if isinstance(platform_personas, dict):
brand_tokens["persona"]["platform_personas"] = platform_personas
# Extract competitive positioning
if competitor_analyses and isinstance(competitor_analyses, list) and len(competitor_analyses) > 0:
# Extract differentiation points
brand_tokens["competitive_positioning"] = {
"differentiators": [],
"unique_value_props": [],
}
for competitor in competitor_analyses[:3]: # Top 3 competitors
if not isinstance(competitor, dict):
continue
analysis_data = competitor.get('analysis_data') or {}
if isinstance(analysis_data, dict) and analysis_data:
competitive_insights = analysis_data.get('competitive_analysis') or {}
if isinstance(competitive_insights, dict) and competitive_insights:
differentiators = competitive_insights.get('differentiators', [])
if isinstance(differentiators, list) and differentiators:
brand_tokens["competitive_positioning"]["differentiators"].extend(
differentiators[:2]
)
logger.info(f"[Brand DNA Sync] Extracted brand tokens for user {user_id}")
return brand_tokens
except Exception as e:
logger.error(f"[Brand DNA Sync] Error extracting brand tokens: {str(e)}")
return {
"writing_style": {"tone": "professional", "voice": "authoritative"},
"target_audience": {"demographics": [], "expertise_level": "intermediate"},
"visual_identity": {},
"persona": {},
"competitive_positioning": {},
}
def get_channel_specific_dna(
self,
user_id: str,
channel: str
) -> Dict[str, Any]:
"""
Get channel-specific brand DNA adaptations.
Args:
user_id: User ID
channel: Target channel (instagram, linkedin, tiktok, etc.)
Returns:
Channel-specific brand DNA tokens
"""
brand_tokens = self.get_brand_dna_tokens(user_id)
channel_dna = brand_tokens.copy()
# Get platform-specific persona if available
persona = brand_tokens.get("persona") or {}
platform_personas = persona.get("platform_personas") or {}
if isinstance(platform_personas, dict) and channel in platform_personas:
platform_persona = platform_personas[channel]
if isinstance(platform_persona, dict):
channel_dna["platform_adaptation"] = {
"content_format_rules": platform_persona.get('content_format_rules') or {},
"engagement_patterns": platform_persona.get('engagement_patterns') or {},
"visual_identity": platform_persona.get('visual_identity') or {},
}
return channel_dna

View File

@@ -0,0 +1,222 @@
"""
Campaign Storage Service
Handles database persistence for campaigns, proposals, and assets.
"""
from typing import Dict, Any, List, Optional
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy import desc
from models.product_marketing_models import Campaign, CampaignProposal, CampaignAsset, CampaignStatus
from services.database import SessionLocal
class CampaignStorageService:
"""Service for storing and retrieving campaigns from database."""
def __init__(self):
"""Initialize Campaign Storage Service."""
self.logger = logger
logger.info("[Campaign Storage] Service initialized")
def save_campaign(
self,
user_id: str,
campaign_data: Dict[str, Any]
) -> Campaign:
"""
Save campaign blueprint to database.
Args:
user_id: User ID
campaign_data: Campaign blueprint data
Returns:
Saved Campaign object
"""
db = SessionLocal()
try:
campaign_id = campaign_data.get('campaign_id')
# Check if campaign exists
existing = db.query(Campaign).filter(
Campaign.campaign_id == campaign_id,
Campaign.user_id == user_id
).first()
if existing:
# Update existing campaign
existing.campaign_name = campaign_data.get('campaign_name', existing.campaign_name)
existing.goal = campaign_data.get('goal', existing.goal)
existing.kpi = campaign_data.get('kpi', existing.kpi)
existing.status = campaign_data.get('status', existing.status)
existing.phases = campaign_data.get('phases', existing.phases)
existing.channels = campaign_data.get('channels', existing.channels)
existing.asset_nodes = campaign_data.get('asset_nodes', existing.asset_nodes)
existing.product_context = campaign_data.get('product_context', existing.product_context)
db.commit()
db.refresh(existing)
logger.info(f"[Campaign Storage] Updated campaign {campaign_id}")
return existing
else:
# Create new campaign
campaign = Campaign(
campaign_id=campaign_id,
user_id=user_id,
campaign_name=campaign_data.get('campaign_name'),
goal=campaign_data.get('goal'),
kpi=campaign_data.get('kpi'),
status=campaign_data.get('status', 'draft'),
phases=campaign_data.get('phases'),
channels=campaign_data.get('channels', []),
asset_nodes=campaign_data.get('asset_nodes', []),
product_context=campaign_data.get('product_context'),
)
db.add(campaign)
db.commit()
db.refresh(campaign)
logger.info(f"[Campaign Storage] Saved new campaign {campaign_id}")
return campaign
except Exception as e:
db.rollback()
logger.error(f"[Campaign Storage] Error saving campaign: {str(e)}")
raise
finally:
db.close()
def get_campaign(
self,
user_id: str,
campaign_id: str
) -> Optional[Campaign]:
"""Get campaign by ID."""
db = SessionLocal()
try:
campaign = db.query(Campaign).filter(
Campaign.campaign_id == campaign_id,
Campaign.user_id == user_id
).first()
return campaign
except Exception as e:
logger.error(f"[Campaign Storage] Error getting campaign: {str(e)}")
return None
finally:
db.close()
def list_campaigns(
self,
user_id: str,
status: Optional[str] = None,
limit: int = 50
) -> List[Campaign]:
"""List campaigns for user."""
db = SessionLocal()
try:
query = db.query(Campaign).filter(Campaign.user_id == user_id)
if status:
query = query.filter(Campaign.status == status)
campaigns = query.order_by(desc(Campaign.created_at)).limit(limit).all()
return campaigns
except Exception as e:
logger.error(f"[Campaign Storage] Error listing campaigns: {str(e)}")
return []
finally:
db.close()
def save_proposals(
self,
user_id: str,
campaign_id: str,
proposals: Dict[str, Any]
) -> List[CampaignProposal]:
"""Save asset proposals for a campaign."""
db = SessionLocal()
try:
# Delete existing proposals for this campaign
db.query(CampaignProposal).filter(
CampaignProposal.campaign_id == campaign_id,
CampaignProposal.user_id == user_id
).delete()
# Create new proposals
saved_proposals = []
for asset_id, proposal_data in proposals.get('proposals', {}).items():
proposal = CampaignProposal(
campaign_id=campaign_id,
user_id=user_id,
asset_node_id=asset_id,
asset_type=proposal_data.get('asset_type'),
channel=proposal_data.get('channel'),
proposed_prompt=proposal_data.get('proposed_prompt'),
recommended_template=proposal_data.get('recommended_template'),
recommended_provider=proposal_data.get('recommended_provider'),
recommended_model=proposal_data.get('recommended_model'),
cost_estimate=proposal_data.get('cost_estimate', 0.0),
concept_summary=proposal_data.get('concept_summary'),
status='proposed',
)
db.add(proposal)
saved_proposals.append(proposal)
db.commit()
for proposal in saved_proposals:
db.refresh(proposal)
logger.info(f"[Campaign Storage] Saved {len(saved_proposals)} proposals for campaign {campaign_id}")
return saved_proposals
except Exception as e:
db.rollback()
logger.error(f"[Campaign Storage] Error saving proposals: {str(e)}")
raise
finally:
db.close()
def get_proposals(
self,
user_id: str,
campaign_id: str
) -> List[CampaignProposal]:
"""Get proposals for a campaign."""
db = SessionLocal()
try:
proposals = db.query(CampaignProposal).filter(
CampaignProposal.campaign_id == campaign_id,
CampaignProposal.user_id == user_id
).all()
return proposals
except Exception as e:
logger.error(f"[Campaign Storage] Error getting proposals: {str(e)}")
return []
finally:
db.close()
def update_campaign_status(
self,
user_id: str,
campaign_id: str,
status: str
) -> bool:
"""Update campaign status."""
db = SessionLocal()
try:
campaign = db.query(Campaign).filter(
Campaign.campaign_id == campaign_id,
Campaign.user_id == user_id
).first()
if campaign:
campaign.status = status
db.commit()
logger.info(f"[Campaign Storage] Updated campaign {campaign_id} status to {status}")
return True
return False
except Exception as e:
db.rollback()
logger.error(f"[Campaign Storage] Error updating status: {str(e)}")
return False
finally:
db.close()

View File

@@ -0,0 +1,180 @@
"""
Channel Pack Service
Maps channels to templates, copy frameworks, and platform-specific optimizations.
"""
from typing import Dict, Any, List, Optional
from loguru import logger
from services.image_studio.templates import Platform, TemplateManager
from services.image_studio.social_optimizer_service import SocialOptimizerService
class ChannelPackService:
"""Service to build channel-specific asset packs."""
def __init__(self):
"""Initialize Channel Pack Service."""
self.template_manager = TemplateManager()
self.social_optimizer = SocialOptimizerService()
self.logger = logger
logger.info("[Channel Pack] Service initialized")
def get_channel_pack(
self,
channel: str,
asset_type: str = "social_post"
) -> Dict[str, Any]:
"""
Get channel-specific pack configuration.
Args:
channel: Target channel (instagram, linkedin, tiktok, facebook, twitter, pinterest, youtube)
asset_type: Type of asset (social_post, story, reel, cover, etc.)
Returns:
Channel pack configuration with templates, dimensions, copy frameworks
"""
try:
# Map channel string to Platform enum
platform_map = {
'instagram': Platform.INSTAGRAM,
'linkedin': Platform.LINKEDIN,
'tiktok': Platform.TIKTOK,
'facebook': Platform.FACEBOOK,
'twitter': Platform.TWITTER,
'pinterest': Platform.PINTEREST,
'youtube': Platform.YOUTUBE,
}
platform = platform_map.get(channel.lower())
if not platform:
raise ValueError(f"Unsupported channel: {channel}")
# Get templates for this platform
templates = self.template_manager.get_platform_templates().get(platform, [])
# Get platform formats
formats = self.social_optimizer.get_platform_formats(platform)
# Build channel pack
pack = {
"channel": channel,
"platform": platform.value,
"asset_type": asset_type,
"templates": [
{
"id": t.id,
"name": t.name,
"dimensions": f"{t.aspect_ratio.width}x{t.aspect_ratio.height}",
"aspect_ratio": t.aspect_ratio.ratio,
"recommended_provider": t.recommended_provider,
"quality": t.quality,
}
for t in templates
],
"formats": formats,
"copy_framework": self._get_copy_framework(channel, asset_type),
"optimization_tips": self._get_optimization_tips(channel),
}
logger.info(f"[Channel Pack] Built pack for {channel} ({asset_type})")
return pack
except Exception as e:
logger.error(f"[Channel Pack] Error building pack: {str(e)}")
return {
"channel": channel,
"error": str(e),
}
def _get_copy_framework(
self,
channel: str,
asset_type: str
) -> Dict[str, Any]:
"""Get copy framework for channel and asset type."""
frameworks = {
"instagram": {
"social_post": {
"caption_length": "125-150 words optimal",
"hashtags": "5-10 relevant hashtags",
"cta": "Clear call-to-action in first line",
"emoji": "Use 1-3 emojis strategically",
},
"story": {
"text_overlay": "Keep text minimal, readable at small size",
"cta": "Swipe-up or link sticker",
},
},
"linkedin": {
"social_post": {
"length": "150-300 words for maximum engagement",
"hashtags": "3-5 professional hashtags",
"tone": "Professional, thought-leadership focused",
"cta": "Engage with question or call-to-action",
},
},
"tiktok": {
"video": {
"hook": "Strong hook in first 3 seconds",
"caption": "Short, engaging, use trending hashtags",
"hashtags": "3-5 trending hashtags",
},
},
}
return frameworks.get(channel, {}).get(asset_type, {})
def _get_optimization_tips(self, channel: str) -> List[str]:
"""Get optimization tips for channel."""
tips = {
"instagram": [
"Use square (1:1) or portrait (4:5) for feed posts",
"Include text overlay safe zones (15% top/bottom, 10% left/right)",
"Optimize for mobile viewing",
],
"linkedin": {
"Use landscape (1.91:1) for feed posts",
"Professional photography style",
"Include clear value proposition",
},
"tiktok": {
"Vertical format (9:16) required",
"Eye-catching first frame",
"Fast-paced, engaging content",
},
}
return tips.get(channel, [])
def build_multi_channel_pack(
self,
channels: List[str],
source_image_base64: str
) -> Dict[str, Any]:
"""
Build optimized asset pack for multiple channels from single source.
Args:
channels: List of target channels
source_image_base64: Source image to optimize
Returns:
Multi-channel pack with optimized variants
"""
pack_results = []
for channel in channels:
pack = self.get_channel_pack(channel)
pack_results.append({
"channel": channel,
"pack": pack,
})
return {
"source_image": "provided",
"channels": pack_results,
"total_variants": len(channels),
}

View File

@@ -0,0 +1,469 @@
"""
Product Marketing Orchestrator
Main service that orchestrates campaign workflows and asset generation.
"""
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from loguru import logger
from services.image_studio import ImageStudioManager, CreateStudioRequest
from .prompt_builder import ProductMarketingPromptBuilder
from .brand_dna_sync import BrandDNASyncService
from .asset_audit import AssetAuditService
from .channel_pack import ChannelPackService
from services.database import SessionLocal
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
@dataclass
class CampaignAssetNode:
"""Represents an asset node in the campaign graph."""
asset_id: str
asset_type: str # image, video, text, audio
channel: str
status: str # draft, generating, ready, approved
prompt: Optional[str] = None
template_id: Optional[str] = None
provider: Optional[str] = None
cost_estimate: Optional[float] = None
generated_asset_id: Optional[int] = None # Asset Library ID
@dataclass
class CampaignBlueprint:
"""Campaign blueprint with phases and asset nodes."""
campaign_id: str
campaign_name: str
goal: str
kpi: Optional[str] = None
phases: List[Dict[str, Any]] = None # teaser, launch, nurture
asset_nodes: List[CampaignAssetNode] = None
channels: List[str] = None
status: str = "draft" # draft, generating, ready, published
class ProductMarketingOrchestrator:
"""Main orchestrator for Product Marketing Suite."""
def __init__(self):
"""Initialize Product Marketing Orchestrator."""
self.image_studio = ImageStudioManager()
self.prompt_builder = ProductMarketingPromptBuilder()
self.brand_dna_sync = BrandDNASyncService()
self.asset_audit = AssetAuditService()
self.channel_pack = ChannelPackService()
self.logger = logger
logger.info("[Product Marketing Orchestrator] Initialized")
def create_campaign_blueprint(
self,
user_id: str,
campaign_data: Dict[str, Any]
) -> CampaignBlueprint:
"""
Create campaign blueprint from user input and onboarding data.
Args:
user_id: User ID
campaign_data: Campaign information (name, goal, channels, etc.)
Returns:
Campaign blueprint with asset nodes
"""
try:
import time
campaign_id = campaign_data.get('campaign_id') or f"campaign_{user_id}_{int(time.time())}"
campaign_name = campaign_data.get('campaign_name', 'New Campaign')
goal = campaign_data.get('goal', 'product_launch')
channels = campaign_data.get('channels', [])
# Get brand DNA for personalization
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
# Build campaign phases
phases = self._build_campaign_phases(goal, channels)
# Generate asset nodes for each phase and channel
asset_nodes = []
for phase in phases:
phase_name = phase.get('name')
for channel in channels:
# Determine required assets for this phase + channel
required_assets = self._get_required_assets(phase_name, channel)
for asset_type in required_assets:
asset_node = CampaignAssetNode(
asset_id=f"{campaign_id}_{phase_name}_{channel}_{asset_type}",
asset_type=asset_type,
channel=channel,
status="draft",
)
asset_nodes.append(asset_node)
blueprint = CampaignBlueprint(
campaign_id=campaign_id,
campaign_name=campaign_name,
goal=goal,
kpi=campaign_data.get('kpi'),
phases=phases,
asset_nodes=asset_nodes,
channels=channels,
status="draft",
)
logger.info(f"[Orchestrator] Created blueprint for campaign {campaign_id} with {len(asset_nodes)} assets")
return blueprint
except Exception as e:
logger.error(f"[Orchestrator] Error creating blueprint: {str(e)}")
raise
def generate_asset_proposals(
self,
user_id: str,
blueprint: CampaignBlueprint,
product_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Generate AI proposals for each asset node in the blueprint.
Args:
user_id: User ID
blueprint: Campaign blueprint
product_context: Product information
Returns:
Dictionary with proposals for each asset node
"""
try:
proposals = {}
for asset_node in blueprint.asset_nodes:
# Build specialized prompt based on asset type and channel
if asset_node.asset_type == "image":
base_prompt = product_context.get('product_description', 'Product image') if product_context else 'Marketing image'
enhanced_prompt = self.prompt_builder.build_marketing_image_prompt(
base_prompt=base_prompt,
user_id=user_id,
channel=asset_node.channel,
asset_type="hero_image",
product_context=product_context,
)
# Get channel pack for template recommendations
channel_pack = self.channel_pack.get_channel_pack(asset_node.channel)
recommended_template = channel_pack.get('templates', [{}])[0] if channel_pack.get('templates') else None
# Estimate cost
cost_estimate = self._estimate_asset_cost("image", asset_node.channel)
proposals[asset_node.asset_id] = {
"asset_id": asset_node.asset_id,
"asset_type": asset_node.asset_type,
"channel": asset_node.channel,
"proposed_prompt": enhanced_prompt,
"recommended_template": recommended_template.get('id') if recommended_template else None,
"recommended_provider": recommended_template.get('recommended_provider', 'wavespeed') if recommended_template else 'wavespeed',
"cost_estimate": cost_estimate,
"concept_summary": self._generate_concept_summary(enhanced_prompt),
}
elif asset_node.asset_type == "text":
base_request = f"Write {asset_node.channel} {asset_node.asset_type} for product launch"
enhanced_prompt = self.prompt_builder.build_marketing_copy_prompt(
base_request=base_request,
user_id=user_id,
channel=asset_node.channel,
content_type="caption",
product_context=product_context,
)
proposals[asset_node.asset_id] = {
"asset_id": asset_node.asset_id,
"asset_type": asset_node.asset_type,
"channel": asset_node.channel,
"proposed_prompt": enhanced_prompt,
"cost_estimate": 0.0, # Text generation cost is minimal
"concept_summary": "Marketing copy optimized for channel and persona",
}
logger.info(f"[Orchestrator] Generated {len(proposals)} asset proposals")
return {"proposals": proposals, "total_assets": len(proposals)}
except Exception as e:
logger.error(f"[Orchestrator] Error generating proposals: {str(e)}")
raise
async def generate_asset(
self,
user_id: str,
asset_proposal: Dict[str, Any],
product_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Generate a single asset using Image Studio APIs.
Args:
user_id: User ID
asset_proposal: Asset proposal from generate_asset_proposals
product_context: Product information
Returns:
Generated asset result
"""
try:
asset_type = asset_proposal.get('asset_type')
if asset_type == "image":
# Build CreateStudioRequest
create_request = CreateStudioRequest(
prompt=asset_proposal.get('proposed_prompt'),
template_id=asset_proposal.get('recommended_template'),
provider=asset_proposal.get('recommended_provider', 'wavespeed'),
quality="premium",
enhance_prompt=True,
use_persona=True,
num_variations=1,
)
# Generate image using Image Studio
result = await self.image_studio.create_image(create_request, user_id=user_id)
# Asset is automatically tracked in Asset Library via Image Studio
return {
"success": True,
"asset_type": "image",
"result": result,
"asset_library_ids": [
r.get('asset_id') for r in result.get('results', [])
if r.get('asset_id')
],
}
elif asset_type == "text":
# Import text generation service and tracker
import asyncio
from services.llm_providers.main_text_generation import llm_text_gen
from utils.text_asset_tracker import save_and_track_text_content
from services.database import SessionLocal
# Get enhanced prompt from proposal
text_prompt = asset_proposal.get('proposed_prompt', '')
channel = asset_proposal.get('channel', 'social')
asset_id = asset_proposal.get('asset_id', '')
# Extract campaign_id - try from asset_proposal first, then from asset_id
# asset_id format: {campaign_id}_{phase}_{channel}_{type}
campaign_id = asset_proposal.get('campaign_id')
if not campaign_id and asset_id and '_' in asset_id:
# Try to extract: asset_id might be "campaign_user123_1234567890_teaser_instagram_text"
# We need to find where phase_name starts (common phases: teaser, launch, nurture)
parts = asset_id.split('_')
# Find phase indicator (usually one of: teaser, launch, nurture)
phase_indicators = ['teaser', 'launch', 'nurture', 'prelaunch', 'postlaunch']
phase_idx = None
for i, part in enumerate(parts):
if part.lower() in phase_indicators:
phase_idx = i
break
if phase_idx and phase_idx > 0:
# Campaign ID is everything before the phase
campaign_id = '_'.join(parts[:phase_idx])
# If still not found, use None (metadata will work without it)
if not campaign_id:
logger.warning(f"[Orchestrator] Could not extract campaign_id from asset_id: {asset_id}")
# Build system prompt for marketing copy
system_prompt = f"""You are an expert marketing copywriter specializing in {channel} content.
Generate compelling, on-brand marketing copy that:
- Is optimized for {channel} platform best practices
- Includes a clear call-to-action
- Uses appropriate tone and style for the platform
- Is concise and engaging
- Aligns with the product marketing context provided
Return only the final copy text without explanations or markdown formatting."""
# Run synchronous llm_text_gen in thread pool
logger.info(f"[Orchestrator] Generating text asset for channel: {channel}")
generated_text = await asyncio.to_thread(
llm_text_gen,
prompt=text_prompt,
system_prompt=system_prompt,
user_id=user_id
)
if not generated_text or not generated_text.strip():
raise ValueError("Text generation returned empty content")
# Save to Asset Library
db = SessionLocal()
asset_library_id = None
try:
asset_library_id = save_and_track_text_content(
db=db,
user_id=user_id,
content=generated_text.strip(),
source_module="product_marketing",
title=f"{channel.title()} Copy: {asset_id.split('_')[-1] if '_' in asset_id else 'Marketing Copy'}",
description=f"Marketing copy for {channel} platform generated from campaign proposal",
prompt=text_prompt,
tags=["product_marketing", channel.lower(), "text", "copy"],
asset_metadata={
"campaign_id": campaign_id,
"asset_id": asset_id,
"asset_type": "text",
"channel": channel,
"concept_summary": asset_proposal.get('concept_summary'),
},
subdirectory="campaigns",
file_extension=".txt"
)
if asset_library_id:
logger.info(f"[Orchestrator] ✅ Text asset saved to library: ID={asset_library_id}")
else:
logger.warning(f"[Orchestrator] ⚠️ Text asset tracking returned None")
except Exception as save_error:
logger.error(f"[Orchestrator] ⚠️ Failed to save text asset to library: {str(save_error)}")
# Continue even if save fails - text is still generated
finally:
db.close()
return {
"success": True,
"asset_type": "text",
"content": generated_text.strip(),
"asset_library_id": asset_library_id,
"channel": channel,
}
else:
raise ValueError(f"Unsupported asset type: {asset_type}")
except Exception as e:
logger.error(f"[Orchestrator] Error generating asset: {str(e)}")
raise
def validate_campaign_preflight(
self,
user_id: str,
blueprint: CampaignBlueprint
) -> Dict[str, Any]:
"""
Validate campaign blueprint against subscription limits before generation.
Args:
user_id: User ID
blueprint: Campaign blueprint
Returns:
Pre-flight validation results
"""
try:
db = SessionLocal()
try:
pricing_service = PricingService(db)
# Count operations needed
image_count = sum(1 for node in blueprint.asset_nodes if node.asset_type == "image")
text_count = sum(1 for node in blueprint.asset_nodes if node.asset_type == "text")
# Estimate total cost
total_cost = 0.0
for node in blueprint.asset_nodes:
if node.cost_estimate:
total_cost += node.cost_estimate
# Validate image generation limits
operations = []
if image_count > 0:
operations.append({
'provider': 'stability', # Default provider
'tokens_requested': 0,
'actual_provider_name': 'wavespeed',
'operation_type': 'image_generation',
})
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations * image_count if operations else []
)
return {
"can_proceed": can_proceed,
"message": message,
"error_details": error_details,
"summary": {
"total_assets": len(blueprint.asset_nodes),
"image_count": image_count,
"text_count": text_count,
"estimated_cost": total_cost,
},
}
finally:
db.close()
except Exception as e:
logger.error(f"[Orchestrator] Error in pre-flight validation: {str(e)}")
return {
"can_proceed": False,
"message": f"Validation error: {str(e)}",
"error_details": {},
}
def _build_campaign_phases(
self,
goal: str,
channels: List[str]
) -> List[Dict[str, Any]]:
"""Build campaign phases based on goal."""
if goal == "product_launch":
return [
{"name": "teaser", "duration_days": 7, "purpose": "Build anticipation"},
{"name": "launch", "duration_days": 3, "purpose": "Official launch"},
{"name": "nurture", "duration_days": 14, "purpose": "Sustain engagement"},
]
else:
return [
{"name": "campaign", "duration_days": 30, "purpose": "Campaign execution"},
]
def _get_required_assets(
self,
phase: str,
channel: str
) -> List[str]:
"""Get required asset types for phase and channel."""
# Default: image for all phases and channels
assets = ["image"]
# Add text/copy for social channels
if channel in ["instagram", "linkedin", "facebook", "twitter"]:
assets.append("text")
return assets
def _estimate_asset_cost(
self,
asset_type: str,
channel: str
) -> float:
"""Estimate cost for asset generation."""
if asset_type == "image":
# Premium quality image: ~5-6 credits
return 5.0
elif asset_type == "text":
return 0.0 # Text generation is typically included
else:
return 0.0
def _generate_concept_summary(self, prompt: str) -> str:
"""Generate a brief concept summary from prompt."""
# Simple extraction: take first 100 chars
return prompt[:100] + "..." if len(prompt) > 100 else prompt

View File

@@ -0,0 +1,634 @@
"""
Product Image Service
Specialized service for generating product-focused images using AI models.
Optimized for e-commerce product photography, product showcases, and product marketing assets.
"""
import hashlib
import time
import os
import shutil
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from pathlib import Path
from loguru import logger
from services.wavespeed.client import WaveSpeedClient
from utils.asset_tracker import save_asset_to_library
from services.database import SessionLocal
from fastapi import HTTPException
class ProductImageServiceError(Exception):
"""Base exception for Product Image Service errors."""
pass
class ValidationError(ProductImageServiceError):
"""Validation error for invalid requests."""
pass
class ImageGenerationError(ProductImageServiceError):
"""Error during image generation."""
pass
class StorageError(ProductImageServiceError):
"""Error saving image to storage."""
pass
@dataclass
class ProductImageRequest:
"""Request for product image generation."""
product_name: str
product_description: str
environment: str = "studio" # studio, lifestyle, outdoor, minimalist, luxury
background_style: str = "white" # white, transparent, lifestyle, branded
lighting: str = "natural" # natural, studio, dramatic, soft
product_variant: Optional[str] = None # color, size, etc.
angle: Optional[str] = None # front, side, top, 360, etc.
style: str = "photorealistic" # photorealistic, minimalist, luxury, technical
resolution: str = "1024x1024" # 1024x1024, 1280x720, etc.
num_variations: int = 1
brand_colors: Optional[List[str]] = None # Brand color palette
additional_context: Optional[str] = None
@dataclass
class ProductImageResult:
"""Result from product image generation."""
success: bool
product_name: str
image_url: Optional[str] = None
image_bytes: Optional[bytes] = None
asset_id: Optional[int] = None # Asset Library ID
provider: Optional[str] = None
model: Optional[str] = None
cost: float = 0.0
generation_time: float = 0.0
error: Optional[str] = None
class ProductImageService:
"""Service for generating product marketing images."""
# Product photography style presets
ENVIRONMENT_PROMPTS = {
"studio": "professional studio photography, clean white background, even lighting",
"lifestyle": "lifestyle photography, product in use, natural environment, relatable setting",
"outdoor": "outdoor photography, natural lighting, outdoor environment, dynamic setting",
"minimalist": "minimalist product photography, simple composition, clean aesthetic",
"luxury": "luxury product photography, premium aesthetic, sophisticated lighting, high-end",
}
BACKGROUND_STYLES = {
"white": "clean white background",
"transparent": "transparent background, isolated product",
"lifestyle": "lifestyle background, contextual environment",
"branded": "branded background with brand colors",
}
LIGHTING_STYLES = {
"natural": "natural lighting, soft shadows, balanced exposure",
"studio": "professional studio lighting, even illumination, no harsh shadows",
"dramatic": "dramatic lighting, high contrast, artistic shadows",
"soft": "soft diffused lighting, gentle shadows, elegant",
}
# Valid values for request parameters
VALID_ENVIRONMENTS = {"studio", "lifestyle", "outdoor", "minimalist", "luxury"}
VALID_BACKGROUND_STYLES = {"white", "transparent", "lifestyle", "branded"}
VALID_LIGHTING_STYLES = {"natural", "studio", "dramatic", "soft"}
VALID_STYLES = {"photorealistic", "minimalist", "luxury", "technical"}
VALID_ANGLES = {"front", "side", "top", "360"}
# Maximum values
MAX_RESOLUTION = (4096, 4096)
MIN_RESOLUTION = (256, 256)
MAX_NUM_VARIATIONS = 10
MAX_PRODUCT_NAME_LENGTH = 500
MAX_PRODUCT_DESCRIPTION_LENGTH = 2000
def __init__(self):
"""Initialize Product Image Service."""
try:
self.wavespeed_client = WaveSpeedClient()
logger.info("[Product Image Service] Initialized")
except Exception as e:
logger.error(f"[Product Image Service] Failed to initialize WaveSpeed client: {str(e)}")
raise ProductImageServiceError(f"Failed to initialize service: {str(e)}") from e
def validate_request(self, request: ProductImageRequest) -> None:
"""
Validate product image generation request.
Args:
request: Product image generation request
Raises:
ValidationError: If request is invalid
"""
errors = []
# Validate product_name
if not request.product_name or not request.product_name.strip():
errors.append("Product name is required")
elif len(request.product_name) > self.MAX_PRODUCT_NAME_LENGTH:
errors.append(f"Product name must be <= {self.MAX_PRODUCT_NAME_LENGTH} characters")
# Validate product_description
if request.product_description and len(request.product_description) > self.MAX_PRODUCT_DESCRIPTION_LENGTH:
errors.append(f"Product description must be <= {self.MAX_PRODUCT_DESCRIPTION_LENGTH} characters")
# Validate environment
if request.environment not in self.VALID_ENVIRONMENTS:
errors.append(f"Invalid environment: {request.environment}. Valid: {', '.join(self.VALID_ENVIRONMENTS)}")
# Validate background_style
if request.background_style not in self.VALID_BACKGROUND_STYLES:
errors.append(f"Invalid background_style: {request.background_style}. Valid: {', '.join(self.VALID_BACKGROUND_STYLES)}")
# Validate lighting
if request.lighting not in self.VALID_LIGHTING_STYLES:
errors.append(f"Invalid lighting: {request.lighting}. Valid: {', '.join(self.VALID_LIGHTING_STYLES)}")
# Validate style
if request.style not in self.VALID_STYLES:
errors.append(f"Invalid style: {request.style}. Valid: {', '.join(self.VALID_STYLES)}")
# Validate angle
if request.angle and request.angle not in self.VALID_ANGLES:
errors.append(f"Invalid angle: {request.angle}. Valid: {', '.join(self.VALID_ANGLES)}")
# Validate num_variations
if request.num_variations < 1:
errors.append("num_variations must be >= 1")
elif request.num_variations > self.MAX_NUM_VARIATIONS:
errors.append(f"num_variations must be <= {self.MAX_NUM_VARIATIONS}")
# Validate resolution
try:
width, height = self._parse_resolution(request.resolution)
if width < self.MIN_RESOLUTION[0] or height < self.MIN_RESOLUTION[1]:
errors.append(f"Resolution must be >= {self.MIN_RESOLUTION[0]}x{self.MIN_RESOLUTION[1]}")
if width > self.MAX_RESOLUTION[0] or height > self.MAX_RESOLUTION[1]:
errors.append(f"Resolution must be <= {self.MAX_RESOLUTION[0]}x{self.MAX_RESOLUTION[1]}")
except Exception as e:
errors.append(f"Invalid resolution format: {request.resolution}. Error: {str(e)}")
if errors:
raise ValidationError(f"Validation failed: {'; '.join(errors)}")
def build_product_prompt(
self,
request: ProductImageRequest,
brand_context: Optional[Dict[str, Any]] = None
) -> str:
"""
Build optimized prompt for product image generation.
Args:
request: Product image generation request
brand_context: Optional brand DNA context for personalization
Returns:
Optimized prompt string
"""
prompt_parts = []
# Base product description
prompt_parts.append(f"Professional product photography of {request.product_name}")
if request.product_description:
prompt_parts.append(f": {request.product_description}")
# Product variant
if request.product_variant:
prompt_parts.append(f", {request.product_variant}")
# Environment and style
env_prompt = self.ENVIRONMENT_PROMPTS.get(request.environment, self.ENVIRONMENT_PROMPTS["studio"])
prompt_parts.append(f", {env_prompt}")
# Background
bg_prompt = self.BACKGROUND_STYLES.get(request.background_style, self.BACKGROUND_STYLES["white"])
if request.background_style == "branded" and request.brand_colors:
bg_prompt += f", using brand colors: {', '.join(request.brand_colors)}"
prompt_parts.append(f", {bg_prompt}")
# Lighting
lighting_prompt = self.LIGHTING_STYLES.get(request.lighting, self.LIGHTING_STYLES["natural"])
prompt_parts.append(f", {lighting_prompt}")
# Angle/view
if request.angle:
angle_map = {
"front": "front view, centered composition",
"side": "side profile view, showing depth",
"top": "top-down view, flat lay style",
"360": "3/4 angle view, showing multiple sides",
}
angle_prompt = angle_map.get(request.angle, request.angle)
prompt_parts.append(f", {angle_prompt}")
# Style
style_map = {
"photorealistic": "photorealistic, highly detailed, professional photography",
"minimalist": "minimalist aesthetic, clean composition, simple and elegant",
"luxury": "luxury aesthetic, premium quality, sophisticated and refined",
"technical": "technical product photography, detailed features, professional documentation style",
}
style_prompt = style_map.get(request.style, style_map["photorealistic"])
prompt_parts.append(f", {style_prompt}")
# Additional context
if request.additional_context:
prompt_parts.append(f", {request.additional_context}")
# Brand DNA integration (if available)
if brand_context:
brand_tone = brand_context.get("visual_identity", {}).get("style_guidelines")
if brand_tone:
prompt_parts.append(f", brand style: {brand_tone}")
# Quality keywords
prompt_parts.append(", high resolution, professional quality, sharp focus, commercial photography")
full_prompt = " ".join(prompt_parts)
logger.debug(f"[Product Image Service] Built prompt: {full_prompt[:200]}...")
return full_prompt
def _generate_image_with_retry(
self,
model: str,
prompt: str,
width: int,
height: int,
max_retries: int = 3,
retry_delay: float = 2.0
) -> bytes:
"""
Generate image with retry logic for transient failures.
Args:
model: Model to use
prompt: Generation prompt
width: Image width
height: Image height
max_retries: Maximum number of retries
retry_delay: Delay between retries in seconds
Returns:
Generated image bytes
Raises:
ImageGenerationError: If generation fails after retries
"""
last_error = None
for attempt in range(max_retries):
try:
logger.info(f"[Product Image Service] Image generation attempt {attempt + 1}/{max_retries}")
image_bytes = self.wavespeed_client.generate_image(
model=model,
prompt=prompt,
width=width,
height=height,
enable_sync_mode=True,
timeout=120,
)
if not image_bytes:
raise ValueError("Image generation returned empty result")
if len(image_bytes) < 100: # Sanity check: image should be at least 100 bytes
raise ValueError(f"Generated image too small: {len(image_bytes)} bytes")
logger.info(f"[Product Image Service] ✅ Image generated successfully: {len(image_bytes)} bytes")
return image_bytes
except Exception as e:
last_error = e
error_msg = str(e)
logger.warning(f"[Product Image Service] Attempt {attempt + 1} failed: {error_msg}")
# Don't retry on validation errors or client errors (4xx)
if "4" in error_msg or "validation" in error_msg.lower() or "invalid" in error_msg.lower():
logger.error(f"[Product Image Service] Non-retryable error: {error_msg}")
raise ImageGenerationError(f"Image generation failed: {error_msg}") from e
# Retry on transient errors
if attempt < max_retries - 1:
logger.info(f"[Product Image Service] Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
retry_delay *= 1.5 # Exponential backoff
else:
logger.error(f"[Product Image Service] All retry attempts failed")
raise ImageGenerationError(f"Image generation failed after {max_retries} attempts: {str(last_error)}") from last_error
async def generate_product_image(
self,
request: ProductImageRequest,
user_id: str,
brand_context: Optional[Dict[str, Any]] = None
) -> ProductImageResult:
"""
Generate product image using AI models.
Args:
request: Product image generation request
user_id: User ID for tracking
brand_context: Optional brand DNA for personalization
Returns:
ProductImageResult with generated image
"""
start_time = time.time()
try:
# Validate request
self.validate_request(request)
# Validate user_id
if not user_id or not user_id.strip():
raise ValidationError("user_id is required")
# Build optimized prompt
prompt = self.build_product_prompt(request, brand_context)
# Parse resolution
width, height = self._parse_resolution(request.resolution)
# Select model based on style/quality needs
model = "ideogram-v3-turbo" # Default to Ideogram V3 for photorealistic products
if request.style == "minimalist":
model = "ideogram-v3-turbo" # Still use Ideogram for quality
elif request.style == "technical":
model = "ideogram-v3-turbo"
logger.info(f"[Product Image Service] Generating product image for '{request.product_name}' using {model}")
# Generate image using WaveSpeed with retry logic
try:
image_bytes = self._generate_image_with_retry(
model=model,
prompt=prompt,
width=width,
height=height,
max_retries=3,
retry_delay=2.0
)
except ImageGenerationError as e:
logger.error(f"[Product Image Service] Image generation failed: {str(e)}")
generation_time = time.time() - start_time
return ProductImageResult(
success=False,
product_name=request.product_name,
error=f"Image generation failed: {str(e)}",
generation_time=generation_time,
)
# Save image to file and Asset Library
asset_id = None
image_url = None
try:
asset_id, image_url = self._save_product_image(
image_bytes=image_bytes,
request=request,
user_id=user_id,
prompt=prompt,
model=model,
start_time=start_time
)
except StorageError as storage_error:
logger.error(f"[Product Image Service] Storage failed: {str(storage_error)}", exc_info=True)
# Continue with generation result even if storage fails
# The image_bytes is still available in the result
except Exception as save_error:
logger.error(f"[Product Image Service] Unexpected error saving image: {str(save_error)}", exc_info=True)
# Continue even if save fails
generation_time = time.time() - start_time
return ProductImageResult(
success=True,
product_name=request.product_name,
image_url=image_url,
image_bytes=image_bytes,
asset_id=asset_id,
provider="wavespeed",
model=model,
cost=0.10,
generation_time=generation_time,
)
except ValidationError as ve:
logger.error(f"[Product Image Service] Validation error: {str(ve)}")
generation_time = time.time() - start_time
return ProductImageResult(
success=False,
product_name=request.product_name if hasattr(request, 'product_name') else "unknown",
error=f"Validation error: {str(ve)}",
generation_time=generation_time,
)
except Exception as e:
logger.error(f"[Product Image Service] ❌ Unexpected error generating product image: {str(e)}", exc_info=True)
generation_time = time.time() - start_time
return ProductImageResult(
success=False,
product_name=request.product_name if hasattr(request, 'product_name') else "unknown",
error=f"Unexpected error: {str(e)}",
generation_time=generation_time,
)
def _save_product_image(
self,
image_bytes: bytes,
request: ProductImageRequest,
user_id: str,
prompt: str,
model: str,
start_time: float
) -> tuple[Optional[int], Optional[str]]:
"""
Save product image to disk and Asset Library.
Args:
image_bytes: Generated image bytes
request: Product image generation request
user_id: User ID
prompt: Generation prompt
model: Model used
start_time: Generation start time
Returns:
Tuple of (asset_id, image_url)
Raises:
StorageError: If saving fails
"""
db = None
asset_id = None
image_url = None
image_path = None
try:
# Generate filename
product_hash = hashlib.md5(request.product_name.encode()).hexdigest()[:8]
timestamp = int(start_time)
filename = f"product_{product_hash}_{timestamp}.png"
# Determine base directory and create product_images folder
base_dir = Path(__file__).parent.parent.parent
product_images_dir = base_dir / "product_images"
# Create directory with error handling
try:
product_images_dir.mkdir(parents=True, exist_ok=True)
except PermissionError as pe:
raise StorageError(f"Permission denied creating directory: {str(pe)}") from pe
except OSError as oe:
raise StorageError(f"Failed to create directory: {str(oe)}") from oe
# Check disk space (rough estimate - at least 10MB free)
try:
stat = shutil.disk_usage(product_images_dir)
free_space_mb = stat.free / (1024 * 1024)
if free_space_mb < 10:
raise StorageError(f"Insufficient disk space: {free_space_mb:.1f}MB free (need at least 10MB)")
except OSError as oe:
logger.warning(f"[Product Image Service] Could not check disk space: {str(oe)}")
# Save image to disk
image_path = product_images_dir / filename
try:
with open(image_path, "wb") as f:
f.write(image_bytes)
# Verify file was written
if not image_path.exists() or image_path.stat().st_size == 0:
raise StorageError("Image file was not written correctly")
except PermissionError as pe:
raise StorageError(f"Permission denied writing file: {str(pe)}") from pe
except OSError as oe:
raise StorageError(f"Failed to write file: {str(oe)}") from oe
file_size = len(image_bytes)
image_url = f"/api/product-marketing/images/{filename}"
# Save to Asset Library
db = SessionLocal()
try:
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
asset_type="image",
source_module="product_marketing",
filename=filename,
file_url=image_url,
file_path=str(image_path),
file_size=file_size,
mime_type="image/png",
title=f"{request.product_name} - Product Image",
description=f"Product image: {request.product_description or request.product_name}",
prompt=prompt,
tags=["product_marketing", "product_image", request.environment, request.style],
provider="wavespeed",
model=model,
cost=0.10, # Estimated cost for Ideogram V3
asset_metadata={
"product_name": request.product_name,
"product_description": request.product_description,
"environment": request.environment,
"background_style": request.background_style,
"lighting": request.lighting,
"style": request.style,
"variant": request.product_variant,
"angle": request.angle,
},
)
if asset_id:
logger.info(f"[Product Image Service] ✅ Saved product image to Asset Library: ID={asset_id}")
else:
logger.warning(f"[Product Image Service] ⚠️ Asset Library save returned None (file saved but not tracked)")
except Exception as db_error:
logger.error(f"[Product Image Service] Database error saving to Asset Library: {str(db_error)}", exc_info=True)
# File is saved, but database tracking failed
# This is not critical - image is still accessible
raise StorageError(f"Failed to save to Asset Library: {str(db_error)}") from db_error
finally:
if db:
try:
db.close()
except Exception as close_error:
logger.warning(f"[Product Image Service] Error closing database: {str(close_error)}")
return (asset_id, image_url)
except StorageError:
# Clean up partial files on storage error
if image_path and image_path.exists():
try:
image_path.unlink()
logger.info(f"[Product Image Service] Cleaned up partial file: {image_path}")
except Exception as cleanup_error:
logger.warning(f"[Product Image Service] Failed to cleanup partial file: {str(cleanup_error)}")
raise
def _parse_resolution(self, resolution: str) -> tuple[int, int]:
"""
Parse resolution string to width, height tuple.
Args:
resolution: Resolution string (e.g., "1024x1024", "square", "landscape")
Returns:
Tuple of (width, height)
"""
try:
resolution = resolution.strip().lower()
if "x" in resolution:
parts = resolution.split("x")
if len(parts) != 2:
raise ValueError(f"Invalid resolution format: {resolution}")
width = int(parts[0].strip())
height = int(parts[1].strip())
# Validate resolution values
if width < 1 or height < 1:
raise ValueError(f"Resolution dimensions must be positive: {width}x{height}")
return (width, height)
elif resolution == "square":
return (1024, 1024)
elif resolution == "landscape":
return (1280, 720)
elif resolution == "portrait":
return (720, 1280)
else:
# Try to parse as single number (assume square)
try:
size = int(resolution)
return (size, size)
except ValueError:
# Default to square
logger.warning(f"[Product Image Service] Could not parse resolution '{resolution}', defaulting to 1024x1024")
return (1024, 1024)
except Exception as e:
logger.warning(f"[Product Image Service] Error parsing resolution '{resolution}': {str(e)}, defaulting to 1024x1024")
return (1024, 1024)
def estimate_cost(self, request: ProductImageRequest) -> float:
"""Estimate cost for product image generation."""
# Ideogram V3 Turbo: ~$0.10 per image
# Multiply by number of variations
base_cost = 0.10
return base_cost * request.num_variations

View File

@@ -0,0 +1,304 @@
"""
Product Marketing Prompt Builder
Extends AIPromptOptimizer with marketing-specific prompt enhancement.
"""
from typing import Dict, Any, Optional
from loguru import logger
from services.ai_prompt_optimizer import AIPromptOptimizer
from services.onboarding import OnboardingDataService
from services.onboarding.database_service import OnboardingDatabaseService
from services.persona_data_service import PersonaDataService
from services.database import SessionLocal
class ProductMarketingPromptBuilder(AIPromptOptimizer):
"""Specialized prompt builder for marketing assets with onboarding data integration."""
def __init__(self):
"""Initialize Product Marketing Prompt Builder."""
super().__init__()
self.onboarding_data_service = OnboardingDataService()
self.logger = logger
logger.info("[Product Marketing Prompt Builder] Initialized")
def build_marketing_image_prompt(
self,
base_prompt: str,
user_id: str,
channel: Optional[str] = None,
asset_type: str = "hero_image",
product_context: Optional[Dict[str, Any]] = None
) -> str:
"""
Build enhanced marketing image prompt with brand DNA and persona data.
Args:
base_prompt: Base product description or image concept
user_id: User ID to fetch onboarding data
channel: Target channel (instagram, linkedin, tiktok, etc.)
asset_type: Type of asset (hero_image, product_photo, lifestyle, etc.)
product_context: Additional product information
Returns:
Enhanced prompt with brand DNA, persona style, and marketing context
"""
try:
# Get onboarding data
db = SessionLocal()
try:
onboarding_db = OnboardingDatabaseService(db)
website_analysis = onboarding_db.get_website_analysis(user_id, db)
persona_data = onboarding_db.get_persona_data(user_id, db)
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
finally:
db.close()
# Build prompt layers
enhanced_prompt = base_prompt
# Layer 1: Brand DNA (from website_analysis)
if website_analysis:
writing_style = website_analysis.get('writing_style', {})
target_audience = website_analysis.get('target_audience', {})
brand_analysis = website_analysis.get('brand_analysis', {})
style_guidelines = website_analysis.get('style_guidelines', {})
# Add brand tone and style
tone = writing_style.get('tone', 'professional')
voice = writing_style.get('voice', 'authoritative')
brand_enhancement = f", {tone} tone, {voice} voice"
# Add target audience context
demographics = target_audience.get('demographics', [])
if demographics:
audience_context = f", targeting {', '.join(demographics[:2])}"
enhanced_prompt += audience_context
# Add brand visual identity if available
if brand_analysis:
color_palette = brand_analysis.get('color_palette', [])
if color_palette:
colors = ', '.join(color_palette[:3])
enhanced_prompt += f", brand colors: {colors}"
# Layer 2: Persona Visual Style (from persona_data)
if persona_data:
core_persona = persona_data.get('corePersona', {})
platform_personas = persona_data.get('platformPersonas', {})
if core_persona:
persona_name = core_persona.get('persona_name', '')
archetype = core_persona.get('archetype', '')
if persona_name:
enhanced_prompt += f", {persona_name} style"
# Channel-specific persona adaptation
if channel and platform_personas:
platform_persona = platform_personas.get(channel, {})
if platform_persona:
visual_identity = platform_persona.get('visual_identity', {})
if visual_identity:
aesthetic = visual_identity.get('aesthetic_preferences', '')
if aesthetic:
enhanced_prompt += f", {aesthetic} aesthetic"
# Layer 3: Channel Optimization
channel_enhancements = {
'instagram': ', Instagram-optimized composition, vibrant colors, engaging visual',
'linkedin': ', professional photography, clean composition, business-focused',
'tiktok': ', dynamic composition, eye-catching, vertical format optimized',
'facebook': ', social media optimized, engaging, shareable visual',
'twitter': ', Twitter card optimized, clear focal point, readable at small size',
'pinterest': ', Pinterest-optimized, vertical format, detailed and informative',
}
if channel and channel.lower() in channel_enhancements:
enhanced_prompt += channel_enhancements[channel.lower()]
# Layer 4: Asset Type Specific
asset_type_enhancements = {
'hero_image': ', hero image style, prominent product placement, professional photography',
'product_photo': ', product photography, clean background, detailed product showcase',
'lifestyle': ', lifestyle photography, natural setting, authentic scene',
'social_post': ', social media post, engaging composition, optimized for engagement',
}
if asset_type in asset_type_enhancements:
enhanced_prompt += asset_type_enhancements[asset_type]
# Layer 5: Competitive Differentiation
if competitor_analyses and len(competitor_analyses) > 0:
# Extract unique positioning from competitor analysis
enhanced_prompt += ", unique positioning, differentiated visual style"
# Layer 6: Quality Descriptors
enhanced_prompt += ", professional photography, high quality, detailed, sharp focus, natural lighting"
# Layer 7: Marketing Context
if product_context:
marketing_goal = product_context.get('marketing_goal', '')
if marketing_goal:
enhanced_prompt += f", {marketing_goal} focused"
logger.info(f"[Marketing Prompt] Enhanced prompt for user {user_id}: {enhanced_prompt[:200]}...")
return enhanced_prompt
except Exception as e:
logger.error(f"[Marketing Prompt] Error building prompt: {str(e)}")
# Return base prompt with minimal enhancement if error
return f"{base_prompt}, professional photography, high quality"
def build_marketing_copy_prompt(
self,
base_request: str,
user_id: str,
channel: Optional[str] = None,
content_type: str = "caption",
product_context: Optional[Dict[str, Any]] = None
) -> str:
"""
Build enhanced marketing copy prompt with persona linguistic fingerprint.
Args:
base_request: Base content request (e.g., "Write Instagram caption for product launch")
user_id: User ID to fetch onboarding data
channel: Target channel (instagram, linkedin, etc.)
content_type: Type of content (caption, cta, email, ad_copy, etc.)
product_context: Additional product information
Returns:
Enhanced prompt with persona style, brand voice, and marketing context
"""
try:
# Get onboarding data
db = SessionLocal()
try:
onboarding_db = OnboardingDatabaseService(db)
website_analysis = onboarding_db.get_website_analysis(user_id, db)
persona_data = onboarding_db.get_persona_data(user_id, db)
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
finally:
db.close()
# Build enhanced prompt
enhanced_prompt = base_request
# Add persona linguistic fingerprint
if persona_data:
core_persona = persona_data.get('corePersona', {})
platform_personas = persona_data.get('platformPersonas', {})
if core_persona:
persona_name = core_persona.get('persona_name', '')
linguistic_fingerprint = core_persona.get('linguistic_fingerprint', {})
if persona_name:
enhanced_prompt += f"\n\nFollow {persona_name} persona style:"
if linguistic_fingerprint:
sentence_metrics = linguistic_fingerprint.get('sentence_metrics', {})
lexical_features = linguistic_fingerprint.get('lexical_features', {})
if sentence_metrics:
avg_length = sentence_metrics.get('average_sentence_length_words', '')
if avg_length:
enhanced_prompt += f"\n- Average sentence length: {avg_length} words"
if lexical_features:
go_to_words = lexical_features.get('go_to_words', [])
avoid_words = lexical_features.get('avoid_words', [])
vocabulary_level = lexical_features.get('vocabulary_level', '')
if go_to_words:
enhanced_prompt += f"\n- Use these words: {', '.join(go_to_words[:5])}"
if avoid_words:
enhanced_prompt += f"\n- Avoid these words: {', '.join(avoid_words[:5])}"
if vocabulary_level:
enhanced_prompt += f"\n- Vocabulary level: {vocabulary_level}"
# Channel-specific persona adaptation
if channel and platform_personas:
platform_persona = platform_personas.get(channel, {})
if platform_persona:
content_format_rules = platform_persona.get('content_format_rules', {})
engagement_patterns = platform_persona.get('engagement_patterns', {})
if content_format_rules:
char_limit = content_format_rules.get('character_limit', '')
hashtag_strategy = content_format_rules.get('hashtag_strategy', '')
if char_limit:
enhanced_prompt += f"\n- Character limit: {char_limit}"
if hashtag_strategy:
enhanced_prompt += f"\n- Hashtag strategy: {hashtag_strategy}"
# Add brand voice
if website_analysis:
writing_style = website_analysis.get('writing_style', {})
target_audience = website_analysis.get('target_audience', {})
tone = writing_style.get('tone', 'professional')
voice = writing_style.get('voice', 'authoritative')
enhanced_prompt += f"\n- Brand tone: {tone}, Brand voice: {voice}"
demographics = target_audience.get('demographics', [])
expertise_level = target_audience.get('expertise_level', 'intermediate')
if demographics:
enhanced_prompt += f"\n- Target audience: {', '.join(demographics[:2])}, {expertise_level} level"
# Add competitive positioning
if competitor_analyses and len(competitor_analyses) > 0:
enhanced_prompt += "\n- Differentiate from competitors, highlight unique value propositions"
# Add marketing context
if product_context:
marketing_goal = product_context.get('marketing_goal', '')
if marketing_goal:
enhanced_prompt += f"\n- Marketing goal: {marketing_goal}"
logger.info(f"[Marketing Copy Prompt] Enhanced for user {user_id}: {enhanced_prompt[:200]}...")
return enhanced_prompt
except Exception as e:
logger.error(f"[Marketing Copy Prompt] Error building prompt: {str(e)}")
return base_request
def optimize_marketing_prompt(
self,
prompt_type: str,
base_prompt: str,
user_id: str,
context: Optional[Dict[str, Any]] = None
) -> str:
"""
Main entry point for marketing prompt optimization.
Args:
prompt_type: Type of prompt (image, copy, video_script, etc.)
base_prompt: Base prompt to enhance
user_id: User ID for personalization
context: Additional context (channel, asset_type, product_context, etc.)
Returns:
Optimized marketing prompt
"""
context = context or {}
channel = context.get('channel')
asset_type = context.get('asset_type', 'hero_image')
content_type = context.get('content_type', 'caption')
product_context = context.get('product_context')
if prompt_type == 'image':
return self.build_marketing_image_prompt(
base_prompt, user_id, channel, asset_type, product_context
)
elif prompt_type in ['copy', 'caption', 'cta', 'email', 'ad_copy']:
return self.build_marketing_copy_prompt(
base_prompt, user_id, channel, content_type, product_context
)
else:
# Default: minimal enhancement
return f"{base_prompt}, professional quality, marketing optimized"

View File

@@ -51,7 +51,7 @@ def save_asset_to_library(
description: Optional[str] = None,
prompt: Optional[str] = None,
tags: Optional[list] = None,
metadata: Optional[Dict[str, Any]] = None,
asset_metadata: Optional[Dict[str, Any]] = None,
provider: Optional[str] = None,
model: Optional[str] = None,
cost: Optional[float] = None,
@@ -77,7 +77,7 @@ def save_asset_to_library(
description: Asset description (optional)
prompt: Generation prompt (optional)
tags: List of tags (optional)
metadata: Additional metadata (optional)
asset_metadata: Additional metadata (optional)
provider: AI provider used (optional)
model: Model used (optional)
cost: Generation cost (optional)
@@ -143,7 +143,7 @@ def save_asset_to_library(
description=description,
prompt=prompt,
tags=tags or [],
metadata=metadata or {},
asset_metadata=asset_metadata or {},
provider=provider,
model=model,
cost=cost,

View File

@@ -0,0 +1,246 @@
"""
File Storage Utility
Robust file storage helper for saving generated content assets.
"""
import os
import uuid
from pathlib import Path
from typing import Optional, Tuple
import logging
logger = logging.getLogger(__name__)
# Maximum filename length
MAX_FILENAME_LENGTH = 255
# Allowed characters in filenames (alphanumeric, dash, underscore, dot)
ALLOWED_FILENAME_CHARS = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.')
def sanitize_filename(filename: str, max_length: int = 100) -> str:
"""
Sanitize filename to be filesystem-safe.
Args:
filename: Original filename
max_length: Maximum length for filename
Returns:
Sanitized filename
"""
if not filename:
return f"file_{uuid.uuid4().hex[:8]}"
# Remove path separators and other dangerous characters
sanitized = "".join(c if c in ALLOWED_FILENAME_CHARS else '_' for c in filename)
# Remove leading/trailing dots and spaces
sanitized = sanitized.strip('. ')
# Ensure it's not empty
if not sanitized:
sanitized = f"file_{uuid.uuid4().hex[:8]}"
# Truncate if too long
if len(sanitized) > max_length:
name, ext = os.path.splitext(sanitized)
max_name_length = max_length - len(ext) - 1
sanitized = name[:max_name_length] + ext
return sanitized
def ensure_directory_exists(directory: Path) -> bool:
"""
Ensure directory exists, creating it if necessary.
Args:
directory: Path to directory
Returns:
True if directory exists or was created, False otherwise
"""
try:
directory.mkdir(parents=True, exist_ok=True)
return True
except Exception as e:
logger.error(f"Failed to create directory {directory}: {e}")
return False
def save_file_safely(
content: bytes,
directory: Path,
filename: str,
max_file_size: int = 100 * 1024 * 1024 # 100MB default
) -> Tuple[Optional[Path], Optional[str]]:
"""
Safely save file content to disk.
Args:
content: File content as bytes
directory: Directory to save file in
filename: Filename (will be sanitized)
max_file_size: Maximum allowed file size in bytes
Returns:
Tuple of (file_path, error_message). file_path is None on error.
"""
try:
# Validate file size
if len(content) > max_file_size:
return None, f"File size {len(content)} exceeds maximum {max_file_size}"
if len(content) == 0:
return None, "File content is empty"
# Ensure directory exists
if not ensure_directory_exists(directory):
return None, f"Failed to create directory: {directory}"
# Sanitize filename
safe_filename = sanitize_filename(filename)
# Construct full path
file_path = directory / safe_filename
# Check if file already exists (unlikely with UUID, but check anyway)
if file_path.exists():
# Add UUID to make it unique
name, ext = os.path.splitext(safe_filename)
safe_filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
file_path = directory / safe_filename
# Write file atomically (write to temp file first, then rename)
temp_path = file_path.with_suffix(file_path.suffix + '.tmp')
try:
with open(temp_path, 'wb') as f:
f.write(content)
# Atomic rename
temp_path.replace(file_path)
logger.info(f"Successfully saved file: {file_path} ({len(content)} bytes)")
return file_path, None
except Exception as write_error:
# Clean up temp file if it exists
if temp_path.exists():
try:
temp_path.unlink()
except:
pass
raise write_error
except Exception as e:
logger.error(f"Error saving file: {e}", exc_info=True)
return None, str(e)
def generate_unique_filename(
prefix: str,
extension: str = ".png",
include_uuid: bool = True
) -> str:
"""
Generate a unique filename.
Args:
prefix: Filename prefix
extension: File extension (with or without dot)
include_uuid: Whether to include UUID in filename
Returns:
Unique filename
"""
if not extension.startswith('.'):
extension = '.' + extension
prefix = sanitize_filename(prefix, max_length=50)
if include_uuid:
unique_id = uuid.uuid4().hex[:8]
return f"{prefix}_{unique_id}{extension}"
else:
return f"{prefix}{extension}"
def save_text_file_safely(
content: str,
directory: Path,
filename: str,
encoding: str = 'utf-8',
max_file_size: int = 10 * 1024 * 1024 # 10MB default for text
) -> Tuple[Optional[Path], Optional[str]]:
"""
Safely save text content to disk.
Args:
content: Text content as string
directory: Directory to save file in
filename: Filename (will be sanitized)
encoding: Text encoding (default: utf-8)
max_file_size: Maximum allowed file size in bytes
Returns:
Tuple of (file_path, error_message). file_path is None on error.
"""
try:
# Validate content
if not content or not isinstance(content, str):
return None, "Content must be a non-empty string"
# Convert to bytes for size check
content_bytes = content.encode(encoding)
# Validate file size
if len(content_bytes) > max_file_size:
return None, f"File size {len(content_bytes)} exceeds maximum {max_file_size}"
# Ensure directory exists
if not ensure_directory_exists(directory):
return None, f"Failed to create directory: {directory}"
# Sanitize filename
safe_filename = sanitize_filename(filename)
# Ensure .txt extension if not present
if not safe_filename.endswith(('.txt', '.md', '.json')):
safe_filename = os.path.splitext(safe_filename)[0] + '.txt'
# Construct full path
file_path = directory / safe_filename
# Check if file already exists
if file_path.exists():
# Add UUID to make it unique
name, ext = os.path.splitext(safe_filename)
safe_filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
file_path = directory / safe_filename
# Write file atomically (write to temp file first, then rename)
temp_path = file_path.with_suffix(file_path.suffix + '.tmp')
try:
with open(temp_path, 'w', encoding=encoding) as f:
f.write(content)
# Atomic rename
temp_path.replace(file_path)
logger.info(f"Successfully saved text file: {file_path} ({len(content_bytes)} bytes, {len(content)} chars)")
return file_path, None
except Exception as write_error:
# Clean up temp file if it exists
if temp_path.exists():
try:
temp_path.unlink()
except:
pass
raise write_error
except Exception as e:
logger.error(f"Error saving text file: {e}", exc_info=True)
return None, str(e)

View File

@@ -0,0 +1,133 @@
"""
Text Asset Tracker Utility
Helper utility for saving and tracking text content as files in the asset library.
"""
from typing import Dict, Any, Optional
from pathlib import Path
from sqlalchemy.orm import Session
from utils.asset_tracker import save_asset_to_library
from utils.file_storage import save_text_file_safely, generate_unique_filename, sanitize_filename
import logging
logger = logging.getLogger(__name__)
def save_and_track_text_content(
db: Session,
user_id: str,
content: str,
source_module: str,
title: str,
description: Optional[str] = None,
prompt: Optional[str] = None,
tags: Optional[list] = None,
asset_metadata: Optional[Dict[str, Any]] = None,
base_dir: Optional[Path] = None,
subdirectory: Optional[str] = None,
file_extension: str = ".txt"
) -> Optional[int]:
"""
Save text content to disk and track it in the asset library.
Args:
db: Database session
user_id: Clerk user ID
content: Text content to save
source_module: Source module name (e.g., "linkedin_writer", "facebook_writer")
title: Title for the asset
description: Description of the content
prompt: Original prompt used for generation
tags: List of tags for search/filtering
asset_metadata: Additional metadata
base_dir: Base directory for file storage (defaults to backend/{module}_text)
subdirectory: Optional subdirectory (e.g., "posts", "articles")
file_extension: File extension (.txt, .md, etc.)
Returns:
Asset ID if successful, None otherwise
"""
try:
if not content or not isinstance(content, str) or len(content.strip()) == 0:
logger.warning("Empty or invalid content provided")
return None
if not user_id or not isinstance(user_id, str):
logger.error("Invalid user_id provided")
return None
# Determine output directory
if base_dir is None:
# Default to backend/{module}_text
base_dir = Path(__file__).parent.parent
module_name = source_module.replace('_', '')
output_dir = base_dir / f"{module_name}_text"
else:
output_dir = base_dir
# Add subdirectory if specified
if subdirectory:
output_dir = output_dir / subdirectory
# Generate safe filename from title
safe_title = sanitize_filename(title, max_length=80)
filename = generate_unique_filename(
prefix=safe_title,
extension=file_extension,
include_uuid=True
)
# Save text file
file_path, save_error = save_text_file_safely(
content=content,
directory=output_dir,
filename=filename,
encoding='utf-8',
max_file_size=10 * 1024 * 1024 # 10MB for text
)
if not file_path or save_error:
logger.error(f"Failed to save text file: {save_error}")
return None
# Generate file URL
relative_path = file_path.relative_to(base_dir)
file_url = f"/api/text-assets/{relative_path.as_posix()}"
# Prepare metadata
final_metadata = asset_metadata or {}
final_metadata.update({
"status": "completed",
"character_count": len(content),
"word_count": len(content.split())
})
# Save to asset library
asset_id = save_asset_to_library(
db=db,
user_id=user_id,
asset_type="text",
source_module=source_module,
filename=filename,
file_url=file_url,
file_path=str(file_path),
file_size=len(content.encode('utf-8')),
mime_type="text/plain" if file_extension == ".txt" else "text/markdown",
title=title,
description=description or f"Generated {source_module.replace('_', ' ')} content",
prompt=prompt,
tags=tags or [source_module, "text"],
asset_metadata=final_metadata
)
if asset_id:
logger.info(f"✅ Text asset saved to library: ID={asset_id}, filename={filename}")
else:
logger.warning(f"Asset tracking returned None for {filename}")
return asset_id
except Exception as e:
logger.error(f"❌ Error saving and tracking text content: {str(e)}", exc_info=True)
return None