AI Image Studio, AI podcast Maker, AI product Marketing
This commit is contained in:
@@ -10,6 +10,9 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from sqlalchemy.orm import Session
|
||||
from services.database import get_db as get_db_dependency
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
|
||||
from models.blog_models import (
|
||||
BlogResearchRequest,
|
||||
@@ -41,6 +44,10 @@ router = APIRouter(prefix="/api/blog", tags=["AI Blog Writer"])
|
||||
|
||||
service = BlogWriterService()
|
||||
recommendation_applier = BlogSEORecommendationApplier()
|
||||
|
||||
|
||||
# Use the proper database dependency from services.database
|
||||
get_db = get_db_dependency
|
||||
# ---------------------------
|
||||
# SEO Recommendation Endpoints
|
||||
# ---------------------------
|
||||
@@ -272,10 +279,41 @@ async def rebalance_outline(outline_data: Dict[str, Any], target_words: int = 15
|
||||
|
||||
# Content Generation Endpoints
|
||||
@router.post("/section/generate", response_model=BlogSectionResponse)
|
||||
async def generate_section(request: BlogSectionRequest) -> BlogSectionResponse:
|
||||
async def generate_section(
|
||||
request: BlogSectionRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> BlogSectionResponse:
|
||||
"""Generate content for a specific section."""
|
||||
try:
|
||||
return await service.generate_section(request)
|
||||
response = await service.generate_section(request)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.markdown:
|
||||
try:
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
if user_id:
|
||||
section_heading = getattr(request, 'section_heading', getattr(request, 'heading', 'Section'))
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.markdown,
|
||||
source_module="blog_writer",
|
||||
title=f"Blog Section: {section_heading[:60]}",
|
||||
description=f"Blog section content",
|
||||
prompt=f"Section: {section_heading}\nKeywords: {getattr(request, 'keywords', [])}",
|
||||
tags=["blog", "section", "content"],
|
||||
asset_metadata={
|
||||
"section_id": getattr(request, 'section_id', None),
|
||||
"word_count": len(response.markdown.split()),
|
||||
},
|
||||
subdirectory="sections",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track blog section asset: {track_error}")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate section: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -321,13 +359,48 @@ async def start_content_generation(
|
||||
|
||||
|
||||
@router.get("/content/status/{task_id}")
|
||||
async def content_generation_status(task_id: str) -> Dict[str, Any]:
|
||||
async def content_generation_status(
|
||||
task_id: str,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Poll status for content generation task."""
|
||||
try:
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Track blog content when task completes (non-blocking)
|
||||
if status.get('status') == 'completed' and status.get('result'):
|
||||
try:
|
||||
result = status.get('result', {})
|
||||
if result.get('sections') and len(result.get('sections', [])) > 0:
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
if user_id:
|
||||
# Combine all sections into full blog content
|
||||
blog_content = f"# {result.get('title', 'Untitled Blog')}\n\n"
|
||||
for section in result.get('sections', []):
|
||||
blog_content += f"\n## {section.get('heading', 'Section')}\n\n{section.get('content', '')}\n\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=blog_content,
|
||||
source_module="blog_writer",
|
||||
title=f"Blog: {result.get('title', 'Untitled Blog')[:60]}",
|
||||
description=f"Complete blog post with {len(result.get('sections', []))} sections",
|
||||
prompt=f"Title: {result.get('title', 'Untitled')}\nSections: {len(result.get('sections', []))}",
|
||||
tags=["blog", "complete", "content"],
|
||||
asset_metadata={
|
||||
"section_count": len(result.get('sections', [])),
|
||||
"model": result.get('model'),
|
||||
},
|
||||
subdirectory="complete",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track blog content asset: {track_error}")
|
||||
|
||||
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
|
||||
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
|
||||
error_data = status.get('error_data', {}) or {}
|
||||
@@ -420,10 +493,40 @@ async def analyze_flow_advanced(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
@router.post("/section/optimize", response_model=BlogOptimizeResponse)
|
||||
async def optimize_section(request: BlogOptimizeRequest) -> BlogOptimizeResponse:
|
||||
async def optimize_section(
|
||||
request: BlogOptimizeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> BlogOptimizeResponse:
|
||||
"""Optimize a specific section for better quality and engagement."""
|
||||
try:
|
||||
return await service.optimize_section(request)
|
||||
response = await service.optimize_section(request)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.optimized:
|
||||
try:
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
if user_id:
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.optimized,
|
||||
source_module="blog_writer",
|
||||
title=f"Optimized Blog Section",
|
||||
description=f"Optimized blog section content",
|
||||
prompt=f"Original Content: {request.content[:200]}\nGoals: {request.goals}",
|
||||
tags=["blog", "section", "optimized"],
|
||||
asset_metadata={
|
||||
"optimization_goals": request.goals,
|
||||
"word_count": len(response.optimized.split()),
|
||||
},
|
||||
subdirectory="sections/optimized",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track optimized blog section asset: {track_error}")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to optimize section: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -591,13 +694,49 @@ async def start_medium_generation(
|
||||
|
||||
|
||||
@router.get("/generate/medium/status/{task_id}")
|
||||
async def medium_generation_status(task_id: str):
|
||||
async def medium_generation_status(
|
||||
task_id: str,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Poll status for medium blog generation task."""
|
||||
try:
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Track blog content when task completes (non-blocking)
|
||||
if status.get('status') == 'completed' and status.get('result'):
|
||||
try:
|
||||
result = status.get('result', {})
|
||||
if result.get('sections') and len(result.get('sections', [])) > 0:
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
if user_id:
|
||||
# Combine all sections into full blog content
|
||||
blog_content = f"# {result.get('title', 'Untitled Blog')}\n\n"
|
||||
for section in result.get('sections', []):
|
||||
blog_content += f"\n## {section.get('heading', 'Section')}\n\n{section.get('content', '')}\n\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=blog_content,
|
||||
source_module="blog_writer",
|
||||
title=f"Medium Blog: {result.get('title', 'Untitled Blog')[:60]}",
|
||||
description=f"Medium-length blog post with {len(result.get('sections', []))} sections",
|
||||
prompt=f"Title: {result.get('title', 'Untitled')}\nSections: {len(result.get('sections', []))}",
|
||||
tags=["blog", "medium", "complete"],
|
||||
asset_metadata={
|
||||
"section_count": len(result.get('sections', [])),
|
||||
"model": result.get('model'),
|
||||
"generation_time_ms": result.get('generation_time_ms'),
|
||||
},
|
||||
subdirectory="medium",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track medium blog asset: {track_error}")
|
||||
|
||||
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
|
||||
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
|
||||
error_data = status.get('error_data', {}) or {}
|
||||
@@ -677,7 +816,8 @@ async def rewrite_status(task_id: str):
|
||||
@router.post("/titles/generate-seo")
|
||||
async def generate_seo_titles(
|
||||
request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate 5 SEO-optimized blog titles using research and outline data."""
|
||||
try:
|
||||
@@ -722,6 +862,30 @@ async def generate_seo_titles(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Save and track titles (non-blocking)
|
||||
if titles and len(titles) > 0:
|
||||
try:
|
||||
titles_content = "# SEO Blog Titles\n\n" + "\n".join([f"{i+1}. {title}" for i, title in enumerate(titles)])
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=titles_content,
|
||||
source_module="blog_writer",
|
||||
title=f"SEO Blog Titles: {primary_keywords[0] if primary_keywords else 'Blog'}",
|
||||
description=f"SEO-optimized blog title suggestions",
|
||||
prompt=f"Primary Keywords: {primary_keywords}\nSearch Intent: {search_intent}\nWord Count: {word_count}",
|
||||
tags=["blog", "titles", "seo"],
|
||||
asset_metadata={
|
||||
"title_count": len(titles),
|
||||
"primary_keywords": primary_keywords,
|
||||
"search_intent": search_intent,
|
||||
},
|
||||
subdirectory="titles",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track SEO titles asset: {track_error}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"titles": titles
|
||||
@@ -736,7 +900,8 @@ async def generate_seo_titles(
|
||||
@router.post("/introductions/generate")
|
||||
async def generate_introductions(
|
||||
request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate 3 varied blog introductions using research, outline, and content."""
|
||||
try:
|
||||
@@ -781,6 +946,33 @@ async def generate_introductions(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Save and track introductions (non-blocking)
|
||||
if introductions and len(introductions) > 0:
|
||||
try:
|
||||
intro_content = f"# Blog Introductions for: {blog_title}\n\n"
|
||||
for i, intro in enumerate(introductions, 1):
|
||||
intro_content += f"## Introduction {i}\n\n{intro}\n\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=intro_content,
|
||||
source_module="blog_writer",
|
||||
title=f"Blog Introductions: {blog_title[:60]}",
|
||||
description=f"Blog introduction variations",
|
||||
prompt=f"Blog Title: {blog_title}\nPrimary Keywords: {primary_keywords}\nSearch Intent: {search_intent}",
|
||||
tags=["blog", "introductions"],
|
||||
asset_metadata={
|
||||
"introduction_count": len(introductions),
|
||||
"blog_title": blog_title,
|
||||
"search_intent": search_intent,
|
||||
},
|
||||
subdirectory="introductions",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track blog introductions asset: {track_error}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"introductions": introductions
|
||||
|
||||
@@ -21,6 +21,7 @@ from models.blog_models import (
|
||||
)
|
||||
from services.blog_writer.blog_service import BlogWriterService
|
||||
from services.blog_writer.database_task_manager import DatabaseTaskManager
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -281,6 +282,9 @@ class TaskManager:
|
||||
self.task_storage[task_id]["status"] = "completed"
|
||||
self.task_storage[task_id]["result"] = result.dict()
|
||||
await self.update_progress(task_id, f"✅ Generated {len(result.sections)} sections successfully.")
|
||||
|
||||
# Note: Blog content tracking is handled in the status endpoint
|
||||
# to ensure we have proper database session and user context
|
||||
|
||||
except HTTPException as http_error:
|
||||
# Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend
|
||||
|
||||
@@ -32,7 +32,7 @@ class AssetResponse(BaseModel):
|
||||
description: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
tags: List[str] = []
|
||||
metadata: Dict[str, Any] = {}
|
||||
asset_metadata: Dict[str, Any] = {}
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
cost: float = 0.0
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""FastAPI router for Facebook Writer endpoints."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import *
|
||||
from ..services import *
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db as get_db_dependency
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -115,9 +119,17 @@ async def get_available_tools():
|
||||
return {"tools": tools, "total_count": len(tools)}
|
||||
|
||||
|
||||
# Use the proper database dependency from services.database
|
||||
get_db = get_db_dependency
|
||||
|
||||
|
||||
# Content Creation Endpoints
|
||||
@router.post("/post/generate", response_model=FacebookPostResponse)
|
||||
async def generate_facebook_post(request: FacebookPostRequest):
|
||||
async def generate_facebook_post(
|
||||
request: FacebookPostRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook post with engagement optimization."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook post for business: {request.business_type}")
|
||||
@@ -126,6 +138,37 @@ async def generate_facebook_post(request: FacebookPostRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.content:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = response.content
|
||||
if response.analytics:
|
||||
text_content += f"\n\n## Analytics\nExpected Reach: {response.analytics.expected_reach}\nExpected Engagement: {response.analytics.expected_engagement}\nBest Time to Post: {response.analytics.best_time_to_post}"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Post: {request.business_type[:60]}",
|
||||
description=f"Facebook post for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nTarget Audience: {request.target_audience}\nGoal: {request.post_goal.value if hasattr(request.post_goal, 'value') else request.post_goal}\nTone: {request.post_tone.value if hasattr(request.post_tone, 'value') else request.post_tone}",
|
||||
tags=["facebook", "post", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"post_goal": request.post_goal.value if hasattr(request.post_goal, 'value') else str(request.post_goal),
|
||||
"post_tone": request.post_tone.value if hasattr(request.post_tone, 'value') else str(request.post_tone),
|
||||
"media_type": request.media_type.value if hasattr(request.media_type, 'value') else str(request.media_type)
|
||||
},
|
||||
subdirectory="posts"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook post asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -134,7 +177,11 @@ async def generate_facebook_post(request: FacebookPostRequest):
|
||||
|
||||
|
||||
@router.post("/story/generate", response_model=FacebookStoryResponse)
|
||||
async def generate_facebook_story(request: FacebookStoryRequest):
|
||||
async def generate_facebook_story(
|
||||
request: FacebookStoryRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook story with visual suggestions."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook story for business: {request.business_type}")
|
||||
@@ -143,6 +190,31 @@ async def generate_facebook_story(request: FacebookStoryRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.content:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Story: {request.business_type[:60]}",
|
||||
description=f"Facebook story for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nStory Type: {request.story_type.value if hasattr(request.story_type, 'value') else request.story_type}",
|
||||
tags=["facebook", "story", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"story_type": request.story_type.value if hasattr(request.story_type, 'value') else str(request.story_type)
|
||||
},
|
||||
subdirectory="stories"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook story asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -151,7 +223,11 @@ async def generate_facebook_story(request: FacebookStoryRequest):
|
||||
|
||||
|
||||
@router.post("/reel/generate", response_model=FacebookReelResponse)
|
||||
async def generate_facebook_reel(request: FacebookReelRequest):
|
||||
async def generate_facebook_reel(
|
||||
request: FacebookReelRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook reel script with music suggestions."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook reel for business: {request.business_type}")
|
||||
@@ -160,6 +236,42 @@ async def generate_facebook_reel(request: FacebookReelRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.script:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = f"# Facebook Reel Script\n\n## Script\n{response.script}\n"
|
||||
if response.scene_breakdown:
|
||||
text_content += f"\n## Scene Breakdown\n" + "\n".join([f"{i+1}. {scene}" for i, scene in enumerate(response.scene_breakdown)]) + "\n"
|
||||
if response.music_suggestions:
|
||||
text_content += f"\n## Music Suggestions\n" + "\n".join(response.music_suggestions) + "\n"
|
||||
if response.hashtag_suggestions:
|
||||
text_content += f"\n## Hashtag Suggestions\n" + " ".join([f"#{tag}" for tag in response.hashtag_suggestions]) + "\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Reel: {request.topic[:60]}",
|
||||
description=f"Facebook reel script for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nTopic: {request.topic}\nReel Type: {request.reel_type.value if hasattr(request.reel_type, 'value') else request.reel_type}\nLength: {request.reel_length.value if hasattr(request.reel_length, 'value') else request.reel_length}",
|
||||
tags=["facebook", "reel", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"reel_type": request.reel_type.value if hasattr(request.reel_type, 'value') else str(request.reel_type),
|
||||
"reel_length": request.reel_length.value if hasattr(request.reel_length, 'value') else str(request.reel_length),
|
||||
"reel_style": request.reel_style.value if hasattr(request.reel_style, 'value') else str(request.reel_style)
|
||||
},
|
||||
subdirectory="reels",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook reel asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -168,7 +280,11 @@ async def generate_facebook_reel(request: FacebookReelRequest):
|
||||
|
||||
|
||||
@router.post("/carousel/generate", response_model=FacebookCarouselResponse)
|
||||
async def generate_facebook_carousel(request: FacebookCarouselRequest):
|
||||
async def generate_facebook_carousel(
|
||||
request: FacebookCarouselRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook carousel post with multiple slides."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook carousel for business: {request.business_type}")
|
||||
@@ -177,6 +293,44 @@ async def generate_facebook_carousel(request: FacebookCarouselRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.main_caption and response.slides:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = f"# Facebook Carousel\n\n## Main Caption\n{response.main_caption}\n\n"
|
||||
text_content += "## Slides\n"
|
||||
for i, slide in enumerate(response.slides, 1):
|
||||
text_content += f"\n### Slide {i}: {slide.title}\n{slide.content}\n"
|
||||
if slide.image_description:
|
||||
text_content += f"Image Description: {slide.image_description}\n"
|
||||
|
||||
if response.hashtag_suggestions:
|
||||
text_content += f"\n## Hashtag Suggestions\n" + " ".join([f"#{tag}" for tag in response.hashtag_suggestions]) + "\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Carousel: {request.topic[:60]}",
|
||||
description=f"Facebook carousel for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nTopic: {request.topic}\nCarousel Type: {request.carousel_type.value if hasattr(request.carousel_type, 'value') else request.carousel_type}\nSlides: {request.num_slides}",
|
||||
tags=["facebook", "carousel", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"carousel_type": request.carousel_type.value if hasattr(request.carousel_type, 'value') else str(request.carousel_type),
|
||||
"num_slides": request.num_slides,
|
||||
"has_cta": request.include_cta
|
||||
},
|
||||
subdirectory="carousels",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook carousel asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -186,7 +340,11 @@ async def generate_facebook_carousel(request: FacebookCarouselRequest):
|
||||
|
||||
# Business Tools Endpoints
|
||||
@router.post("/event/generate", response_model=FacebookEventResponse)
|
||||
async def generate_facebook_event(request: FacebookEventRequest):
|
||||
async def generate_facebook_event(
|
||||
request: FacebookEventRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook event description."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook event: {request.event_name}")
|
||||
@@ -195,6 +353,36 @@ async def generate_facebook_event(request: FacebookEventRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.description:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = f"# Facebook Event: {request.event_name}\n\n## Description\n{response.description}\n"
|
||||
if hasattr(response, 'details') and response.details:
|
||||
text_content += f"\n## Details\n{response.details}\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Event: {request.event_name[:60]}",
|
||||
description=f"Facebook event description for {request.event_name}",
|
||||
prompt=f"Event Name: {request.event_name}\nEvent Type: {getattr(request, 'event_type', 'N/A')}\nDate: {getattr(request, 'event_date', 'N/A')}",
|
||||
tags=["facebook", "event", request.event_name.lower().replace(' ', '_')[:20]],
|
||||
asset_metadata={
|
||||
"event_name": request.event_name,
|
||||
"event_type": getattr(request, 'event_type', None)
|
||||
},
|
||||
subdirectory="events"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook event asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -203,7 +391,11 @@ async def generate_facebook_event(request: FacebookEventRequest):
|
||||
|
||||
|
||||
@router.post("/group-post/generate", response_model=FacebookGroupPostResponse)
|
||||
async def generate_facebook_group_post(request: FacebookGroupPostRequest):
|
||||
async def generate_facebook_group_post(
|
||||
request: FacebookGroupPostRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook group post following community guidelines."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook group post for: {request.group_name}")
|
||||
@@ -212,6 +404,32 @@ async def generate_facebook_group_post(request: FacebookGroupPostRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.content:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Group Post: {request.group_name[:60]}",
|
||||
description=f"Facebook group post for {request.group_name}",
|
||||
prompt=f"Group Name: {request.group_name}\nTopic: {getattr(request, 'topic', 'N/A')}",
|
||||
tags=["facebook", "group_post", request.group_name.lower().replace(' ', '_')[:20]],
|
||||
asset_metadata={
|
||||
"group_name": request.group_name,
|
||||
"group_type": getattr(request, 'group_type', None)
|
||||
},
|
||||
subdirectory="group_posts"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook group post asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -220,7 +438,11 @@ async def generate_facebook_group_post(request: FacebookGroupPostRequest):
|
||||
|
||||
|
||||
@router.post("/page-about/generate", response_model=FacebookPageAboutResponse)
|
||||
async def generate_facebook_page_about(request: FacebookPageAboutRequest):
|
||||
async def generate_facebook_page_about(
|
||||
request: FacebookPageAboutRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook page about section."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook page about for: {request.business_name}")
|
||||
@@ -229,6 +451,32 @@ async def generate_facebook_page_about(request: FacebookPageAboutRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.about_section:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.about_section,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Page About: {request.business_name[:60]}",
|
||||
description=f"Facebook page about section for {request.business_name}",
|
||||
prompt=f"Business Name: {request.business_name}\nBusiness Type: {getattr(request, 'business_type', 'N/A')}",
|
||||
tags=["facebook", "page_about", request.business_name.lower().replace(' ', '_')[:20]],
|
||||
asset_metadata={
|
||||
"business_name": request.business_name,
|
||||
"business_type": getattr(request, 'business_type', None)
|
||||
},
|
||||
subdirectory="page_about"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook page about asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -238,7 +486,11 @@ async def generate_facebook_page_about(request: FacebookPageAboutRequest):
|
||||
|
||||
# Marketing Tools Endpoints
|
||||
@router.post("/ad-copy/generate", response_model=FacebookAdCopyResponse)
|
||||
async def generate_facebook_ad_copy(request: FacebookAdCopyRequest):
|
||||
async def generate_facebook_ad_copy(
|
||||
request: FacebookAdCopyRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate Facebook ad copy with targeting suggestions."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook ad copy for: {request.business_type}")
|
||||
@@ -247,6 +499,41 @@ async def generate_facebook_ad_copy(request: FacebookAdCopyRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.ad_copy:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = f"# Facebook Ad Copy\n\n## Ad Copy\n{response.ad_copy}\n"
|
||||
if hasattr(response, 'headline') and response.headline:
|
||||
text_content += f"\n## Headline\n{response.headline}\n"
|
||||
if hasattr(response, 'description') and response.description:
|
||||
text_content += f"\n## Description\n{response.description}\n"
|
||||
if hasattr(response, 'targeting_suggestions') and response.targeting_suggestions:
|
||||
text_content += f"\n## Targeting Suggestions\n" + "\n".join(response.targeting_suggestions) + "\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Ad Copy: {request.business_type[:60]}",
|
||||
description=f"Facebook ad copy for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nAd Objective: {getattr(request, 'ad_objective', 'N/A')}\nTarget Audience: {getattr(request, 'target_audience', 'N/A')}",
|
||||
tags=["facebook", "ad_copy", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"ad_objective": getattr(request, 'ad_objective', None),
|
||||
"budget": getattr(request, 'budget', None)
|
||||
},
|
||||
subdirectory="ad_copy",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook ad copy asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,10 +2,14 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
@@ -16,6 +20,8 @@ from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import APIProvider, UsageSummary
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from utils.file_storage import save_file_safely, generate_unique_filename, sanitize_filename
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/images", tags=["images"])
|
||||
@@ -37,6 +43,7 @@ class ImageGenerateRequest(BaseModel):
|
||||
class ImageGenerateResponse(BaseModel):
|
||||
success: bool = True
|
||||
image_base64: str
|
||||
image_url: Optional[str] = None # URL to saved image file
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
@@ -47,7 +54,8 @@ class ImageGenerateResponse(BaseModel):
|
||||
@router.post("/generate", response_model=ImageGenerateResponse)
|
||||
def generate(
|
||||
req: ImageGenerateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> ImageGenerateResponse:
|
||||
"""Generate image with subscription checking."""
|
||||
try:
|
||||
@@ -80,6 +88,78 @@ def generate(
|
||||
)
|
||||
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
|
||||
# Save image to disk and track in asset library
|
||||
image_url = None
|
||||
image_filename = None
|
||||
image_path = None
|
||||
|
||||
try:
|
||||
# Create output directory for image studio images
|
||||
base_dir = Path(__file__).parent.parent
|
||||
output_dir = base_dir / "image_studio_images"
|
||||
|
||||
# Generate safe filename from prompt
|
||||
clean_prompt = sanitize_filename(req.prompt[:50], max_length=50)
|
||||
image_filename = generate_unique_filename(
|
||||
prefix=f"img_{clean_prompt}",
|
||||
extension=".png",
|
||||
include_uuid=True
|
||||
)
|
||||
|
||||
# Save file safely
|
||||
image_path, save_error = save_file_safely(
|
||||
content=result.image_bytes,
|
||||
directory=output_dir,
|
||||
filename=image_filename,
|
||||
max_file_size=50 * 1024 * 1024 # 50MB for images
|
||||
)
|
||||
|
||||
if image_path and not save_error:
|
||||
# Generate file URL (will be served via API endpoint)
|
||||
image_url = f"/api/images/image-studio/images/{image_path.name}"
|
||||
|
||||
logger.info(f"[images.generate] Saved image to: {image_path} ({len(result.image_bytes)} bytes)")
|
||||
|
||||
# Save to asset library (non-blocking)
|
||||
try:
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="image_studio",
|
||||
filename=image_path.name,
|
||||
file_url=image_url,
|
||||
file_path=str(image_path),
|
||||
file_size=len(result.image_bytes),
|
||||
mime_type="image/png",
|
||||
title=req.prompt[:100] if len(req.prompt) <= 100 else req.prompt[:97] + "...",
|
||||
description=f"Generated image: {req.prompt[:200]}" if len(req.prompt) > 200 else req.prompt,
|
||||
prompt=req.prompt,
|
||||
tags=["image_studio", "generated", result.provider] if result.provider else ["image_studio", "generated"],
|
||||
provider=result.provider,
|
||||
model=result.model,
|
||||
asset_metadata={
|
||||
"width": result.width,
|
||||
"height": result.height,
|
||||
"seed": result.seed,
|
||||
"status": "completed",
|
||||
"negative_prompt": req.negative_prompt
|
||||
}
|
||||
)
|
||||
if asset_id:
|
||||
logger.info(f"[images.generate] ✅ Asset saved to library: ID={asset_id}, filename={image_path.name}")
|
||||
else:
|
||||
logger.warning(f"[images.generate] Asset tracking returned None (may have failed silently)")
|
||||
except Exception as asset_error:
|
||||
logger.error(f"[images.generate] Failed to save asset to library: {asset_error}", exc_info=True)
|
||||
# Don't fail the request if asset tracking fails
|
||||
else:
|
||||
logger.warning(f"[images.generate] Failed to save image to disk: {save_error}")
|
||||
# Continue without failing the request - base64 is still available
|
||||
except Exception as save_error:
|
||||
logger.error(f"[images.generate] Unexpected error saving image: {save_error}", exc_info=True)
|
||||
# Continue without failing the request
|
||||
|
||||
# TRACK USAGE after successful image generation
|
||||
if result:
|
||||
logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
|
||||
@@ -168,6 +248,7 @@ def generate(
|
||||
|
||||
return ImageGenerateResponse(
|
||||
image_base64=image_b64,
|
||||
image_url=image_url,
|
||||
width=result.width,
|
||||
height=result.height,
|
||||
provider=result.provider,
|
||||
@@ -226,6 +307,7 @@ class ImageEditRequest(BaseModel):
|
||||
class ImageEditResponse(BaseModel):
|
||||
success: bool = True
|
||||
image_base64: str
|
||||
image_url: Optional[str] = None # URL to saved edited image file
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
@@ -358,7 +440,8 @@ def suggest_prompts(
|
||||
@router.post("/edit", response_model=ImageEditResponse)
|
||||
def edit(
|
||||
req: ImageEditRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> ImageEditResponse:
|
||||
"""Edit image with subscription checking."""
|
||||
try:
|
||||
@@ -391,6 +474,78 @@ def edit(
|
||||
)
|
||||
edited_image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
|
||||
# Save edited image to disk and track in asset library
|
||||
image_url = None
|
||||
image_filename = None
|
||||
image_path = None
|
||||
|
||||
try:
|
||||
# Create output directory for image studio edited images
|
||||
base_dir = Path(__file__).parent.parent
|
||||
output_dir = base_dir / "image_studio_images" / "edited"
|
||||
|
||||
# Generate safe filename from prompt
|
||||
clean_prompt = sanitize_filename(req.prompt[:50], max_length=50)
|
||||
image_filename = generate_unique_filename(
|
||||
prefix=f"edited_{clean_prompt}",
|
||||
extension=".png",
|
||||
include_uuid=True
|
||||
)
|
||||
|
||||
# Save file safely
|
||||
image_path, save_error = save_file_safely(
|
||||
content=result.image_bytes,
|
||||
directory=output_dir,
|
||||
filename=image_filename,
|
||||
max_file_size=50 * 1024 * 1024 # 50MB for images
|
||||
)
|
||||
|
||||
if image_path and not save_error:
|
||||
# Generate file URL
|
||||
image_url = f"/api/images/image-studio/images/edited/{image_path.name}"
|
||||
|
||||
logger.info(f"[images.edit] Saved edited image to: {image_path} ({len(result.image_bytes)} bytes)")
|
||||
|
||||
# Save to asset library (non-blocking)
|
||||
try:
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="image_studio",
|
||||
filename=image_path.name,
|
||||
file_url=image_url,
|
||||
file_path=str(image_path),
|
||||
file_size=len(result.image_bytes),
|
||||
mime_type="image/png",
|
||||
title=f"Edited: {req.prompt[:100]}" if len(req.prompt) <= 100 else f"Edited: {req.prompt[:97]}...",
|
||||
description=f"Edited image with prompt: {req.prompt[:200]}" if len(req.prompt) > 200 else f"Edited image with prompt: {req.prompt}",
|
||||
prompt=req.prompt,
|
||||
tags=["image_studio", "edited", result.provider] if result.provider else ["image_studio", "edited"],
|
||||
provider=result.provider,
|
||||
model=result.model,
|
||||
asset_metadata={
|
||||
"width": result.width,
|
||||
"height": result.height,
|
||||
"seed": result.seed,
|
||||
"status": "completed",
|
||||
"operation": "edit"
|
||||
}
|
||||
)
|
||||
if asset_id:
|
||||
logger.info(f"[images.edit] ✅ Asset saved to library: ID={asset_id}, filename={image_path.name}")
|
||||
else:
|
||||
logger.warning(f"[images.edit] Asset tracking returned None (may have failed silently)")
|
||||
except Exception as asset_error:
|
||||
logger.error(f"[images.edit] Failed to save asset to library: {asset_error}", exc_info=True)
|
||||
# Don't fail the request if asset tracking fails
|
||||
else:
|
||||
logger.warning(f"[images.edit] Failed to save edited image to disk: {save_error}")
|
||||
# Continue without failing the request - base64 is still available
|
||||
except Exception as save_error:
|
||||
logger.error(f"[images.edit] Unexpected error saving edited image: {save_error}", exc_info=True)
|
||||
# Continue without failing the request
|
||||
|
||||
# TRACK USAGE after successful image editing
|
||||
if result:
|
||||
logger.info(f"[images.edit] ✅ Image editing successful, tracking usage for user {user_id}")
|
||||
@@ -478,6 +633,7 @@ def edit(
|
||||
|
||||
return ImageEditResponse(
|
||||
image_base64=edited_image_b64,
|
||||
image_url=image_url,
|
||||
width=result.width,
|
||||
height=result.height,
|
||||
provider=result.provider,
|
||||
@@ -494,3 +650,55 @@ def edit(
|
||||
detail="Image editing service is temporarily unavailable or the connection was reset. Please try again."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Image Serving Endpoints
|
||||
# ---------------------------
|
||||
|
||||
@router.get("/image-studio/images/{image_filename:path}")
|
||||
async def serve_image_studio_image(
|
||||
image_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Serve a generated or edited image from Image Studio."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Determine if it's an edited image or regular image
|
||||
base_dir = Path(__file__).parent.parent
|
||||
image_studio_dir = (base_dir / "image_studio_images").resolve()
|
||||
|
||||
if image_filename.startswith("edited/"):
|
||||
# Remove "edited/" prefix and serve from edited directory
|
||||
actual_filename = image_filename.replace("edited/", "", 1)
|
||||
image_path = (image_studio_dir / "edited" / actual_filename).resolve()
|
||||
base_subdir = (image_studio_dir / "edited").resolve()
|
||||
else:
|
||||
image_path = (image_studio_dir / image_filename).resolve()
|
||||
base_subdir = image_studio_dir
|
||||
|
||||
# Security: Prevent directory traversal attacks
|
||||
# Ensure the resolved path is within the intended directory
|
||||
try:
|
||||
image_path.relative_to(base_subdir)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied: Invalid image path"
|
||||
)
|
||||
|
||||
if not image_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
media_type="image/png",
|
||||
filename=image_path.name
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[images] Failed to serve image: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,18 +4,20 @@ Each module focuses on a related set of routes to keep the primary
|
||||
`router.py` concise and easier to maintain.
|
||||
"""
|
||||
|
||||
from . import story_setup
|
||||
from . import story_content
|
||||
from . import story_tasks
|
||||
from . import media_generation
|
||||
from . import video_generation
|
||||
from . import cache_routes
|
||||
from . import media_generation
|
||||
from . import scene_animation
|
||||
from . import story_content
|
||||
from . import story_setup
|
||||
from . import story_tasks
|
||||
from . import video_generation
|
||||
|
||||
__all__ = [
|
||||
"story_setup",
|
||||
"story_content",
|
||||
"story_tasks",
|
||||
"media_generation",
|
||||
"video_generation",
|
||||
"cache_routes",
|
||||
"media_generation",
|
||||
"scene_animation",
|
||||
"story_content",
|
||||
"story_setup",
|
||||
"story_tasks",
|
||||
"video_generation",
|
||||
]
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from models.story_models import (
|
||||
@@ -18,8 +20,10 @@ from models.story_models import (
|
||||
GenerateAIAudioResponse,
|
||||
StoryScene,
|
||||
)
|
||||
from services.database import get_db
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
from ..utils.auth import require_authenticated_user
|
||||
from ..utils.media_utils import resolve_media_file
|
||||
@@ -34,6 +38,7 @@ audio_service = StoryAudioGenerationService()
|
||||
async def generate_scene_images(
|
||||
request: StoryImageGenerationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StoryImageGenerationResponse:
|
||||
"""Generate images for story scenes."""
|
||||
try:
|
||||
@@ -70,6 +75,37 @@ async def generate_scene_images(
|
||||
for result in image_results
|
||||
]
|
||||
|
||||
# Save assets to library
|
||||
for result in image_results:
|
||||
if not result.get("error") and result.get("image_url"):
|
||||
try:
|
||||
scene_number = result.get("scene_number", 0)
|
||||
# Safely get prompt from scenes_data with bounds checking
|
||||
prompt = None
|
||||
if scene_number > 0 and scene_number <= len(scenes_data):
|
||||
prompt = scenes_data[scene_number - 1].get("image_prompt")
|
||||
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="story_writer",
|
||||
filename=result.get("image_filename", ""),
|
||||
file_url=result.get("image_url", ""),
|
||||
file_path=result.get("image_path"),
|
||||
file_size=result.get("file_size"),
|
||||
mime_type="image/png",
|
||||
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
|
||||
description=f"Story scene image for scene {scene_number}",
|
||||
prompt=prompt,
|
||||
tags=["story_writer", "scene", f"scene_{scene_number}"],
|
||||
provider=result.get("provider"),
|
||||
model=result.get("model"),
|
||||
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save image asset to library: {e}")
|
||||
|
||||
return StoryImageGenerationResponse(images=image_models, success=True)
|
||||
|
||||
except HTTPException:
|
||||
@@ -163,6 +199,7 @@ async def serve_scene_image(
|
||||
async def generate_scene_audio(
|
||||
request: StoryAudioGenerationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StoryAudioGenerationResponse:
|
||||
"""Generate audio narration for story scenes."""
|
||||
try:
|
||||
@@ -185,18 +222,52 @@ async def generate_scene_audio(
|
||||
|
||||
audio_models: List[StoryAudioResult] = []
|
||||
for result in audio_results:
|
||||
audio_url = result.get("audio_url") or ""
|
||||
audio_filename = result.get("audio_filename") or ""
|
||||
|
||||
audio_models.append(
|
||||
StoryAudioResult(
|
||||
scene_number=result.get("scene_number", 0),
|
||||
scene_title=result.get("scene_title", "Untitled"),
|
||||
audio_filename=result.get("audio_filename") or "",
|
||||
audio_url=result.get("audio_url") or "",
|
||||
audio_filename=audio_filename,
|
||||
audio_url=audio_url,
|
||||
provider=result.get("provider", "unknown"),
|
||||
file_size=result.get("file_size", 0),
|
||||
error=result.get("error"),
|
||||
)
|
||||
)
|
||||
|
||||
# Save assets to library
|
||||
if not result.get("error") and audio_url:
|
||||
try:
|
||||
scene_number = result.get("scene_number", 0)
|
||||
# Safely get prompt from scenes_data with bounds checking
|
||||
prompt = None
|
||||
if scene_number > 0 and scene_number <= len(scenes_data):
|
||||
prompt = scenes_data[scene_number - 1].get("text")
|
||||
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="audio",
|
||||
source_module="story_writer",
|
||||
filename=audio_filename,
|
||||
file_url=audio_url,
|
||||
file_path=result.get("audio_path"),
|
||||
file_size=result.get("file_size"),
|
||||
mime_type="audio/mpeg",
|
||||
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
|
||||
description=f"Story scene audio narration for scene {scene_number}",
|
||||
prompt=prompt,
|
||||
tags=["story_writer", "audio", "narration", f"scene_{scene_number}"],
|
||||
provider=result.get("provider"),
|
||||
model=result.get("model"),
|
||||
cost=result.get("cost"),
|
||||
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save audio asset to library: {e}")
|
||||
|
||||
return StoryAudioGenerationResponse(audio_files=audio_models, success=True)
|
||||
|
||||
except HTTPException:
|
||||
@@ -287,3 +358,59 @@ async def serve_scene_audio(
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
class PromptOptimizeRequest(BaseModel):
|
||||
text: str = Field(..., description="The prompt text to optimize")
|
||||
mode: Optional[str] = Field(default="image", pattern="^(image|video)$", description="Optimization mode: 'image' or 'video'")
|
||||
style: Optional[str] = Field(
|
||||
default="default",
|
||||
pattern="^(default|artistic|photographic|technical|anime|realistic)$",
|
||||
description="Style: 'default', 'artistic', 'photographic', 'technical', 'anime', or 'realistic'"
|
||||
)
|
||||
image: Optional[str] = Field(None, description="Base64-encoded image for context (optional)")
|
||||
|
||||
|
||||
class PromptOptimizeResponse(BaseModel):
|
||||
optimized_prompt: str
|
||||
success: bool
|
||||
|
||||
|
||||
@router.post("/optimize-prompt", response_model=PromptOptimizeResponse)
|
||||
async def optimize_prompt(
|
||||
request: PromptOptimizeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> PromptOptimizeResponse:
|
||||
"""Optimize an image prompt using WaveSpeed prompt optimizer."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.text or not request.text.strip():
|
||||
raise HTTPException(status_code=400, detail="Prompt text is required")
|
||||
|
||||
logger.info(f"[StoryWriter] Optimizing prompt for user {user_id} (mode={request.mode}, style={request.style})")
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
|
||||
client = WaveSpeedClient()
|
||||
optimized_prompt = client.optimize_prompt(
|
||||
text=request.text.strip(),
|
||||
mode=request.mode or "image",
|
||||
style=request.style or "default",
|
||||
image=request.image, # Optional base64 image
|
||||
enable_sync_mode=True,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] Prompt optimized successfully for user {user_id}")
|
||||
|
||||
return PromptOptimizeResponse(
|
||||
optimized_prompt=optimized_prompt,
|
||||
success=True
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to optimize prompt: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
||||
484
backend/api/story_writer/routes/scene_animation.py
Normal file
484
backend/api/story_writer/routes/scene_animation.py
Normal file
@@ -0,0 +1,484 @@
|
||||
"""
|
||||
Scene Animation Routes
|
||||
|
||||
Handles scene animation endpoints using WaveSpeed Kling and InfiniteTalk.
|
||||
"""
|
||||
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
AnimateSceneRequest,
|
||||
AnimateSceneResponse,
|
||||
AnimateSceneVoiceoverRequest,
|
||||
ResumeSceneAnimationRequest,
|
||||
)
|
||||
from services.database import get_db
|
||||
from services.llm_providers.main_video_generation import track_video_usage
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_scene_animation_operation
|
||||
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
|
||||
from services.wavespeed.kling_animation import animate_scene_image, resume_scene_animation
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from ..task_manager import task_manager
|
||||
from ..utils.auth import require_authenticated_user
|
||||
from ..utils.media_utils import load_story_audio_bytes, load_story_image_bytes
|
||||
|
||||
router = APIRouter()
|
||||
scene_logger = get_service_logger("api.story_writer.scene_animation")
|
||||
AI_VIDEO_SUBDIR = Path("AI_Videos")
|
||||
|
||||
|
||||
def _build_authenticated_media_url(request: Request, path: str) -> str:
|
||||
"""Append the caller's auth token to a media URL so <video>/<img> tags can access it."""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
token: Optional[str] = None
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header.replace("Bearer ", "").strip()
|
||||
elif "token" in request.query_params:
|
||||
token = request.query_params["token"]
|
||||
|
||||
if token:
|
||||
separator = "&" if "?" in path else "?"
|
||||
path = f"{path}{separator}token={quote(token)}"
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def _guess_mime_from_url(url: str, fallback: str) -> str:
|
||||
"""Guess MIME type from URL."""
|
||||
if not url:
|
||||
return fallback
|
||||
mime, _ = mimetypes.guess_type(url)
|
||||
return mime or fallback
|
||||
|
||||
|
||||
@router.post("/animate-scene-preview", response_model=AnimateSceneResponse)
|
||||
async def animate_scene_preview(
|
||||
request_obj: Request,
|
||||
request: AnimateSceneRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> AnimateSceneResponse:
|
||||
"""
|
||||
Animate a single scene image using WaveSpeed Kling v2.5 Turbo Std.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
duration = request.duration or 5
|
||||
if duration not in (5, 10):
|
||||
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds.")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] User=%s scene=%s duration=%s image_url=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
duration,
|
||||
request.image_url,
|
||||
)
|
||||
|
||||
image_bytes = load_story_image_bytes(request.image_url)
|
||||
if not image_bytes:
|
||||
scene_logger.warning("[AnimateScene] Missing image bytes for user=%s scene=%s", user_id, request.scene_number)
|
||||
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
animation_result = animate_scene_image(
|
||||
image_bytes=image_bytes,
|
||||
scene_data=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
user_id=user_id,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
video_url = _build_authenticated_media_url(
|
||||
request_obj, f"/api/story/videos/ai/{video_filename}"
|
||||
)
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateScene] Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] ✅ Completed user=%s scene=%s duration=%s cost=$%.2f video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
animation_result["duration"],
|
||||
animation_result["cost"],
|
||||
video_url,
|
||||
)
|
||||
|
||||
# Save video asset to library
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="story_writer",
|
||||
filename=video_filename,
|
||||
file_url=video_url,
|
||||
file_path=str(ai_video_dir / video_filename),
|
||||
file_size=len(animation_result["video_bytes"]),
|
||||
mime_type="video/mp4",
|
||||
title=f"Scene {request.scene_number} Animation",
|
||||
description=f"Animated scene {request.scene_number} from story",
|
||||
prompt=animation_result["prompt"],
|
||||
tags=["story_writer", "video", "animation", f"scene_{request.scene_number}"],
|
||||
provider=animation_result["provider"],
|
||||
model=animation_result.get("model_name"),
|
||||
cost=animation_result["cost"],
|
||||
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/animate-scene-resume", response_model=AnimateSceneResponse)
|
||||
async def resume_scene_animation_endpoint(
|
||||
request_obj: Request,
|
||||
request: ResumeSceneAnimationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> AnimateSceneResponse:
|
||||
"""Resume downloading a WaveSpeed animation when the initial call timed out."""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] Resume requested user=%s scene=%s prediction=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
request.prediction_id,
|
||||
)
|
||||
|
||||
animation_result = resume_scene_animation(
|
||||
prediction_id=request.prediction_id,
|
||||
duration=request.duration or 5,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
video_url = _build_authenticated_media_url(
|
||||
request_obj, f"/api/story/videos/ai/{video_filename}"
|
||||
)
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateScene] (Resume) Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] ✅ Resume completed user=%s scene=%s prediction=%s video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
request.prediction_id,
|
||||
video_url,
|
||||
)
|
||||
|
||||
return AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/animate-scene-voiceover", response_model=Dict[str, Any])
|
||||
async def animate_scene_voiceover_endpoint(
|
||||
request_obj: Request,
|
||||
request: AnimateSceneVoiceoverRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Animate a scene using WaveSpeed InfiniteTalk (image + audio) asynchronously.
|
||||
Returns task_id for polling since InfiniteTalk can take up to 10 minutes.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateSceneVoiceover] User=%s scene=%s resolution=%s (async)",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
request.resolution or "720p",
|
||||
)
|
||||
|
||||
image_bytes = load_story_image_bytes(request.image_url)
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
|
||||
|
||||
audio_bytes = load_story_audio_bytes(request.audio_url)
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=404, detail="Scene audio not found. Generate audio first.")
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Extract token for authenticated URL building (if needed)
|
||||
auth_token = None
|
||||
auth_header = request_obj.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.replace("Bearer ", "").strip()
|
||||
|
||||
# Create async task
|
||||
task_id = task_manager.create_task("scene_voiceover_animation")
|
||||
background_tasks.add_task(
|
||||
_execute_voiceover_animation_task,
|
||||
task_id=task_id,
|
||||
request=request,
|
||||
user_id=user_id,
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
auth_token=auth_token,
|
||||
)
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "InfiniteTalk animation started. This may take up to 10 minutes.",
|
||||
}
|
||||
|
||||
|
||||
def _execute_voiceover_animation_task(
|
||||
task_id: str,
|
||||
request: AnimateSceneVoiceoverRequest,
|
||||
user_id: str,
|
||||
image_bytes: bytes,
|
||||
audio_bytes: bytes,
|
||||
auth_token: Optional[str] = None,
|
||||
):
|
||||
"""Background task to generate InfiniteTalk video with progress updates."""
|
||||
try:
|
||||
task_manager.update_task_status(
|
||||
task_id, "processing", progress=5.0, message="Submitting to WaveSpeed InfiniteTalk..."
|
||||
)
|
||||
|
||||
animation_result = animate_scene_with_voiceover(
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
scene_data=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
user_id=user_id,
|
||||
resolution=request.resolution or "720p",
|
||||
prompt_override=request.prompt,
|
||||
image_mime=_guess_mime_from_url(request.image_url, "image/png"),
|
||||
audio_mime=_guess_mime_from_url(request.audio_url, "audio/mpeg"),
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id, "processing", progress=80.0, message="Saving video file..."
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
# Build authenticated URL if token provided, otherwise return plain URL
|
||||
video_url = f"/api/story/videos/ai/{video_filename}"
|
||||
if auth_token:
|
||||
video_url = f"{video_url}?token={quote(auth_token)}"
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateSceneVoiceover] Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateSceneVoiceover] ✅ Completed user=%s scene=%s cost=$%.2f video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
animation_result["cost"],
|
||||
video_url,
|
||||
)
|
||||
|
||||
# Save video asset to library
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="story_writer",
|
||||
filename=video_filename,
|
||||
file_url=video_url,
|
||||
file_path=str(ai_video_dir / video_filename),
|
||||
file_size=len(animation_result["video_bytes"]),
|
||||
mime_type="video/mp4",
|
||||
title=f"Scene {request.scene_number} Animation (Voiceover)",
|
||||
description=f"Animated scene {request.scene_number} with voiceover from story",
|
||||
prompt=animation_result["prompt"],
|
||||
tags=["story_writer", "video", "animation", "voiceover", f"scene_{request.scene_number}"],
|
||||
provider=animation_result["provider"],
|
||||
model=animation_result.get("model_name"),
|
||||
cost=animation_result["cost"],
|
||||
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
result = AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="InfiniteTalk animation complete!",
|
||||
result=result.dict(),
|
||||
)
|
||||
except HTTPException as exc:
|
||||
error_msg = str(exc.detail) if isinstance(exc.detail, str) else exc.detail.get("error", "Animation failed") if isinstance(exc.detail, dict) else "Animation failed"
|
||||
scene_logger.error(f"[AnimateSceneVoiceover] Failed: {error_msg}")
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=error_msg,
|
||||
message=f"InfiniteTalk animation failed: {error_msg}",
|
||||
)
|
||||
except Exception as exc:
|
||||
error_msg = str(exc)
|
||||
scene_logger.error(f"[AnimateSceneVoiceover] Error: {error_msg}", exc_info=True)
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=error_msg,
|
||||
message=f"InfiniteTalk animation error: {error_msg}",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Any, Dict, List
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
@@ -18,6 +20,7 @@ from ..utils.auth import require_authenticated_user
|
||||
|
||||
router = APIRouter()
|
||||
story_service = StoryWriterService()
|
||||
scene_approval_store: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
@router.post("/generate-start", response_model=StoryContentResponse)
|
||||
@@ -193,3 +196,45 @@ async def continue_story(
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
class SceneApprovalRequest(BaseModel):
|
||||
project_id: str
|
||||
scene_id: str
|
||||
approved: bool = True
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/script/approve")
|
||||
async def approve_script_scene(
|
||||
request: SceneApprovalRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Persist scene approval metadata for auditing."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
approvals = scene_approval_store.setdefault(request.project_id, {})
|
||||
approvals[request.scene_id] = {
|
||||
"approved": request.approved,
|
||||
"notes": request.notes,
|
||||
"user_id": user_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
logger.info(
|
||||
"[StoryWriter] Scene approval recorded user=%s project=%s scene=%s approved=%s",
|
||||
user_id,
|
||||
request.project_id,
|
||||
request.scene_id,
|
||||
request.approved,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"project_id": request.project_id,
|
||||
"scene_id": request.scene_id,
|
||||
"approved": request.approved,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to approve scene: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
||||
@@ -509,3 +509,30 @@ async def serve_story_video(
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/videos/ai/{video_filename}")
|
||||
async def serve_ai_story_video(
|
||||
video_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Serve a generated AI scene animation video."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()
|
||||
video_service_ai = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
video_path = resolve_media_file(video_service_ai.output_dir, video_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to serve AI video: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ from api.linkedin_image_generation import router as linkedin_image_router
|
||||
from api.brainstorm import router as brainstorm_router
|
||||
from api.images import router as images_router
|
||||
from routers.image_studio import router as image_studio_router
|
||||
from routers.product_marketing import router as product_marketing_router
|
||||
|
||||
# Import hallucination detector router
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
@@ -298,6 +299,7 @@ from routers.platform_analytics import router as platform_analytics_router
|
||||
app.include_router(platform_analytics_router)
|
||||
app.include_router(images_router)
|
||||
app.include_router(image_studio_router)
|
||||
app.include_router(product_marketing_router)
|
||||
|
||||
# Include content assets router
|
||||
from api.content_assets.router import router as content_assets_router
|
||||
|
||||
264
backend/docs/ASSET_TRACKING_IMPLEMENTATION.md
Normal file
264
backend/docs/ASSET_TRACKING_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,264 @@
|
||||
# Asset Tracking Implementation Guide
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the production-ready implementation of asset tracking across all ALwrity modules. The unified Content Asset Library automatically tracks all AI-generated content (images, videos, audio, text) for easy management and organization.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **Database Models** (`backend/models/content_asset_models.py`)
|
||||
- `ContentAsset`: Main model for tracking assets
|
||||
- `AssetCollection`: Collections/albums for organizing assets
|
||||
- `AssetType`: Enum (text, image, video, audio)
|
||||
- `AssetSource`: Enum (all ALwrity modules)
|
||||
|
||||
2. **Service Layer** (`backend/services/content_asset_service.py`)
|
||||
- CRUD operations for assets
|
||||
- Search, filter, pagination
|
||||
- Usage tracking
|
||||
|
||||
3. **Utility Functions**
|
||||
- `backend/utils/asset_tracker.py`: `save_asset_to_library()` helper
|
||||
- `backend/utils/file_storage.py`: Robust file saving utilities
|
||||
|
||||
## Implementation Status
|
||||
|
||||
### ✅ Completed Integrations
|
||||
|
||||
#### 1. Story Writer (`backend/api/story_writer/router.py`)
|
||||
- **Images**: Tracks all scene images with metadata
|
||||
- **Audio**: Tracks all scene audio files with narration details
|
||||
- **Videos**: Tracks individual scene videos and complete story videos
|
||||
- **Location**: After generation in `/generate-images`, `/generate-audio`, `/generate-video`, `/generate-complete-video`
|
||||
- **Metadata**: Includes prompts, scene numbers, providers, models, costs, status
|
||||
|
||||
#### 2. Image Studio (`backend/api/images.py`)
|
||||
- **Image Generation**: Tracks all generated images
|
||||
- **Image Editing**: Tracks all edited images
|
||||
- **Location**: After generation in `/api/images/generate` and `/api/images/edit`
|
||||
- **Features**:
|
||||
- Robust file saving with validation
|
||||
- Atomic file writes
|
||||
- Proper error handling (non-blocking)
|
||||
- File serving endpoint at `/api/images/image-studio/images/{filename}`
|
||||
|
||||
### 📝 Notes on Other Modules
|
||||
|
||||
#### Main Generation Services
|
||||
- **Text Generation** (`main_text_generation.py`): Returns strings, not files. If text content needs tracking, save to `.txt` or `.md` files first.
|
||||
- **Video Generation** (`main_video_generation.py`): Already integrated via Story Writer
|
||||
- **Audio Generation** (`main_audio_generation.py`): Already integrated via Story Writer
|
||||
|
||||
#### Social Writers
|
||||
- **LinkedIn Writer**: Generates text content (posts, articles). No file generation currently.
|
||||
- **Facebook Writer**: Generates text content (posts, stories, reels). No file generation currently.
|
||||
- **Blog Writer**: Generates blog content. May generate images in future.
|
||||
|
||||
**Note**: If these modules generate files in the future, follow the integration pattern below.
|
||||
|
||||
## Integration Pattern
|
||||
|
||||
### For Image Generation
|
||||
|
||||
```python
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from utils.file_storage import save_file_safely, generate_unique_filename
|
||||
from sqlalchemy.orm import Session
|
||||
from pathlib import Path
|
||||
|
||||
# After successful image generation
|
||||
try:
|
||||
base_dir = Path(__file__).parent.parent
|
||||
output_dir = base_dir / "module_images"
|
||||
|
||||
image_filename = generate_unique_filename(
|
||||
prefix="img_prompt",
|
||||
extension=".png",
|
||||
include_uuid=True
|
||||
)
|
||||
|
||||
# Save file safely
|
||||
image_path, save_error = save_file_safely(
|
||||
content=result.image_bytes,
|
||||
directory=output_dir,
|
||||
filename=image_filename,
|
||||
max_file_size=50 * 1024 * 1024 # 50MB
|
||||
)
|
||||
|
||||
if image_path and not save_error:
|
||||
image_url = f"/api/module/images/{image_path.name}"
|
||||
|
||||
# Track in asset library (non-blocking)
|
||||
try:
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="module_name",
|
||||
filename=image_path.name,
|
||||
file_url=image_url,
|
||||
file_path=str(image_path),
|
||||
file_size=len(result.image_bytes),
|
||||
mime_type="image/png",
|
||||
title="Image Title",
|
||||
description="Image description",
|
||||
prompt=prompt,
|
||||
tags=["tag1", "tag2"],
|
||||
provider=result.provider,
|
||||
model=result.model,
|
||||
metadata={"status": "completed"}
|
||||
)
|
||||
logger.info(f"✅ Asset saved: ID={asset_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Asset tracking failed: {e}", exc_info=True)
|
||||
# Don't fail the request
|
||||
except Exception as e:
|
||||
logger.error(f"File save failed: {e}", exc_info=True)
|
||||
# Continue - base64 is still available
|
||||
```
|
||||
|
||||
### For Video Generation
|
||||
|
||||
```python
|
||||
# After successful video generation
|
||||
try:
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="module_name",
|
||||
filename=video_filename,
|
||||
file_url=video_url,
|
||||
file_path=str(video_path),
|
||||
file_size=file_size,
|
||||
mime_type="video/mp4",
|
||||
title="Video Title",
|
||||
description="Video description",
|
||||
prompt=prompt,
|
||||
tags=["video", "tag"],
|
||||
provider=provider,
|
||||
model=model,
|
||||
cost=cost,
|
||||
metadata={"duration": duration, "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Asset tracking failed: {e}", exc_info=True)
|
||||
```
|
||||
|
||||
### For Audio Generation
|
||||
|
||||
```python
|
||||
# After successful audio generation
|
||||
try:
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="audio",
|
||||
source_module="module_name",
|
||||
filename=audio_filename,
|
||||
file_url=audio_url,
|
||||
file_path=str(audio_path),
|
||||
file_size=file_size,
|
||||
mime_type="audio/mpeg",
|
||||
title="Audio Title",
|
||||
description="Audio description",
|
||||
prompt=text,
|
||||
tags=["audio", "tag"],
|
||||
provider=provider,
|
||||
model=model,
|
||||
cost=cost,
|
||||
metadata={"status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Asset tracking failed: {e}", exc_info=True)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Error Handling
|
||||
- **Always non-blocking**: Asset tracking failures should never break the main request
|
||||
- **Log errors**: Use `logger.error()` with `exc_info=True` for debugging
|
||||
- **Graceful degradation**: Continue with base64/file response even if tracking fails
|
||||
|
||||
### 2. File Management
|
||||
- **Use `save_file_safely()`**: Handles validation, atomic writes, directory creation
|
||||
- **Sanitize filenames**: Use `sanitize_filename()` to prevent path traversal
|
||||
- **Unique filenames**: Use `generate_unique_filename()` with UUIDs
|
||||
- **File size limits**: Enforce reasonable limits (50MB for images, 100MB for videos)
|
||||
|
||||
### 3. Database Sessions
|
||||
- **Pass session explicitly**: Use `db: Session = Depends(get_db)` in endpoints
|
||||
- **Handle session lifecycle**: Let FastAPI manage session cleanup
|
||||
- **Background tasks**: Get new session in background tasks
|
||||
|
||||
### 4. Metadata
|
||||
- **Rich metadata**: Include provider, model, dimensions, cost, status
|
||||
- **Searchable tags**: Use consistent tag naming (e.g., "image_studio", "generated")
|
||||
- **Status tracking**: Always include `"status": "completed"` in metadata
|
||||
|
||||
### 5. File URLs
|
||||
- **Consistent patterns**: Use `/api/{module}/images/{filename}` format
|
||||
- **Serving endpoints**: Create corresponding GET endpoints to serve files
|
||||
- **Authentication**: Protect file serving endpoints with `get_current_user`
|
||||
|
||||
## File Storage Utilities
|
||||
|
||||
### `save_file_safely()`
|
||||
- Validates file size
|
||||
- Creates directories automatically
|
||||
- Atomic writes (temp file + rename)
|
||||
- Returns `(file_path, error_message)` tuple
|
||||
|
||||
### `sanitize_filename()`
|
||||
- Removes dangerous characters
|
||||
- Prevents path traversal
|
||||
- Limits filename length
|
||||
- Handles empty filenames
|
||||
|
||||
### `generate_unique_filename()`
|
||||
- Creates unique filenames with UUIDs
|
||||
- Sanitizes prefix
|
||||
- Handles extensions properly
|
||||
|
||||
## Testing Checklist
|
||||
|
||||
- [ ] Images are saved to disk correctly
|
||||
- [ ] Files are accessible via serving endpoints
|
||||
- [ ] Asset tracking works (check database)
|
||||
- [ ] Errors don't break main requests
|
||||
- [ ] File size limits are enforced
|
||||
- [ ] Filenames are sanitized properly
|
||||
- [ ] Metadata is complete and accurate
|
||||
- [ ] Asset Library UI displays assets correctly
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
1. **Text Content Tracking**: Save text content as files when needed
|
||||
2. **Batch Operations**: Track multiple assets in single transaction
|
||||
3. **File Cleanup**: Automatic cleanup of orphaned files
|
||||
4. **Storage Backends**: Support S3, GCS for production
|
||||
5. **Thumbnail Generation**: Auto-generate thumbnails for videos/images
|
||||
6. **Compression**: Compress large files before storage
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Assets not appearing in library
|
||||
1. Check database: `SELECT * FROM content_assets WHERE user_id = '...'`
|
||||
2. Check logs for asset tracking errors
|
||||
3. Verify `save_asset_to_library()` returns asset ID
|
||||
4. Check file URLs are correct
|
||||
|
||||
### File serving fails
|
||||
1. Verify file exists on disk
|
||||
2. Check serving endpoint is registered
|
||||
3. Verify authentication is working
|
||||
4. Check file permissions
|
||||
|
||||
### Performance issues
|
||||
1. Use background tasks for heavy operations
|
||||
2. Batch database operations
|
||||
3. Consider async file I/O for large files
|
||||
4. Monitor database query performance
|
||||
|
||||
143
backend/docs/TEXT_ASSET_TRACKING_IMPLEMENTATION.md
Normal file
143
backend/docs/TEXT_ASSET_TRACKING_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# Text Asset Tracking Implementation
|
||||
|
||||
## Overview
|
||||
|
||||
Text content tracking has been successfully implemented across LinkedIn Writer and Facebook Writer endpoints. All generated text content is automatically saved as files and tracked in the unified Content Asset Library.
|
||||
|
||||
## Implementation Status
|
||||
|
||||
### ✅ Completed Integrations
|
||||
|
||||
#### 1. LinkedIn Writer (`backend/routers/linkedin.py`)
|
||||
- **Post Generation**: Tracks LinkedIn posts with content, hashtags, and CTAs
|
||||
- **Article Generation**: Tracks LinkedIn articles with full content, sections, and SEO metadata
|
||||
- **Carousel Generation**: Tracks LinkedIn carousels with all slides
|
||||
- **Video Script Generation**: Tracks LinkedIn video scripts with hooks, scenes, captions
|
||||
- **Comment Response Generation**: Tracks LinkedIn comment responses
|
||||
|
||||
**File Format**: Markdown (`.md`) for articles, carousels, video scripts, comment responses; Text (`.txt`) for posts
|
||||
|
||||
**Storage Location**: `backend/linkedinwriter_text/{subdirectory}/`
|
||||
- `posts/` - LinkedIn posts
|
||||
- `articles/` - LinkedIn articles
|
||||
- `carousels/` - LinkedIn carousels
|
||||
- `video_scripts/` - LinkedIn video scripts
|
||||
- `comment_responses/` - LinkedIn comment responses
|
||||
|
||||
#### 2. Facebook Writer (`backend/api/facebook_writer/routers/facebook_router.py`)
|
||||
- **Post Generation**: Tracks Facebook posts with content and analytics
|
||||
- **Story Generation**: Tracks Facebook stories
|
||||
|
||||
**File Format**: Text (`.txt`)
|
||||
|
||||
**Storage Location**: `backend/facebookwriter_text/{subdirectory}/`
|
||||
- `posts/` - Facebook posts
|
||||
- `stories/` - Facebook stories
|
||||
|
||||
### 📝 Pending Integrations
|
||||
|
||||
#### Facebook Writer (Additional Endpoints)
|
||||
- Reel Generation
|
||||
- Carousel Generation
|
||||
- Event Generation
|
||||
- Group Post Generation
|
||||
- Page About Generation
|
||||
- Ad Copy Generation
|
||||
- Hashtag Generation
|
||||
|
||||
#### Blog Writer (`backend/api/blog_writer/router.py`)
|
||||
- Blog content generation endpoints
|
||||
- Medium blog generation
|
||||
- Blog section generation
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **Text Asset Tracker** (`backend/utils/text_asset_tracker.py`)
|
||||
- `save_and_track_text_content()`: Main function for saving and tracking text
|
||||
- Handles file saving, URL generation, and asset library tracking
|
||||
- Non-blocking error handling
|
||||
|
||||
2. **File Storage Utilities** (`backend/utils/file_storage.py`)
|
||||
- `save_text_file_safely()`: Safely saves text files with validation
|
||||
- `sanitize_filename()`: Prevents path traversal
|
||||
- `generate_unique_filename()`: Creates unique filenames
|
||||
|
||||
3. **Asset Tracker** (`backend/utils/asset_tracker.py`)
|
||||
- `save_asset_to_library()`: Saves asset metadata to database
|
||||
|
||||
## Integration Pattern
|
||||
|
||||
### Basic Integration
|
||||
|
||||
```python
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
from sqlalchemy.orm import Session
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
@router.post("/generate-content")
|
||||
async def generate_content(
|
||||
request: ContentRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
# Generate content
|
||||
response = await service.generate(request)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.content:
|
||||
try:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.content,
|
||||
source_module="module_name",
|
||||
title=f"Content Title: {request.topic[:60]}",
|
||||
description=f"Content description",
|
||||
prompt=f"Topic: {request.topic}",
|
||||
tags=["tag1", "tag2"],
|
||||
metadata={"key": "value"},
|
||||
subdirectory="content_type"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track text asset: {track_error}")
|
||||
|
||||
return response
|
||||
```
|
||||
|
||||
## File Serving
|
||||
|
||||
Text files are saved with URLs like `/api/text-assets/{module}/{subdirectory}/{filename}`. A serving endpoint should be created in `backend/app.py`:
|
||||
|
||||
```python
|
||||
@router.get("/api/text-assets/{file_path:path}")
|
||||
async def serve_text_asset(
|
||||
file_path: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Serve text assets with authentication."""
|
||||
# Implementation needed
|
||||
pass
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Non-blocking**: Text tracking failures should never break the main request
|
||||
2. **Error Handling**: Use try/except around tracking calls
|
||||
3. **User ID Extraction**: Support both `current_user` dependency and header-based extraction
|
||||
4. **Content Formatting**: Combine related content (e.g., post + hashtags + CTA)
|
||||
5. **Metadata**: Include rich metadata for search and filtering
|
||||
6. **File Organization**: Use subdirectories to organize by content type
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Add text tracking to remaining Facebook Writer endpoints
|
||||
2. Add text tracking to Blog Writer endpoints
|
||||
3. Create text asset serving endpoint
|
||||
4. Add text preview in Asset Library UI
|
||||
5. Support text file downloads
|
||||
|
||||
@@ -22,41 +22,31 @@ class AssetType(enum.Enum):
|
||||
|
||||
|
||||
class AssetSource(enum.Enum):
|
||||
"""Source module/tool that generated the asset - covers ALL ALwrity tools."""
|
||||
# Image Studio modules
|
||||
IMAGE_STUDIO_CREATE = "image_studio_create"
|
||||
IMAGE_STUDIO_EDIT = "image_studio_edit"
|
||||
IMAGE_STUDIO_UPSCALE = "image_studio_upscale"
|
||||
IMAGE_STUDIO_TRANSFORM = "image_studio_transform"
|
||||
IMAGE_STUDIO_CONTROL = "image_studio_control"
|
||||
IMAGE_STUDIO_SOCIAL = "image_studio_social"
|
||||
IMAGE_STUDIO_BATCH = "image_studio_batch"
|
||||
|
||||
# Content Writers
|
||||
"""Source module/tool that generated the asset."""
|
||||
# Core Content Generation
|
||||
STORY_WRITER = "story_writer"
|
||||
BLOG_WRITER = "blog_writer"
|
||||
LINKEDIN_WRITER = "linkedin_writer"
|
||||
FACEBOOK_WRITER = "facebook_writer"
|
||||
|
||||
# Content Planning
|
||||
CONTENT_PLANNING = "content_planning"
|
||||
CONTENT_STRATEGY = "content_strategy"
|
||||
|
||||
# SEO Tools
|
||||
SEO_DASHBOARD = "seo_dashboard"
|
||||
SEO_TOOLS = "seo_tools"
|
||||
|
||||
# Research
|
||||
RESEARCH = "research"
|
||||
|
||||
# Scheduler
|
||||
SCHEDULER = "scheduler"
|
||||
|
||||
# Main Generation (legacy/fallback)
|
||||
IMAGE_STUDIO = "image_studio"
|
||||
MAIN_TEXT_GENERATION = "main_text_generation"
|
||||
MAIN_IMAGE_GENERATION = "main_image_generation"
|
||||
MAIN_VIDEO_GENERATION = "main_video_generation"
|
||||
MAIN_AUDIO_GENERATION = "main_audio_generation"
|
||||
|
||||
# Social Media Writers
|
||||
BLOG_WRITER = "blog_writer"
|
||||
LINKEDIN_WRITER = "linkedin_writer"
|
||||
FACEBOOK_WRITER = "facebook_writer"
|
||||
|
||||
# SEO & Content Tools
|
||||
SEO_TOOLS = "seo_tools"
|
||||
CONTENT_PLANNING = "content_planning"
|
||||
WRITING_ASSISTANT = "writing_assistant"
|
||||
|
||||
# Research & Strategy
|
||||
RESEARCH_TOOLS = "research_tools"
|
||||
CONTENT_STRATEGY = "content_strategy"
|
||||
|
||||
# Product Marketing Suite
|
||||
PRODUCT_MARKETING = "product_marketing"
|
||||
|
||||
|
||||
class ContentAsset(Base):
|
||||
@@ -87,18 +77,14 @@ class ContentAsset(Base):
|
||||
description = Column(Text, nullable=True)
|
||||
prompt = Column(Text, nullable=True) # Original prompt used for generation
|
||||
tags = Column(JSON, nullable=True) # Array of tags for search/filtering
|
||||
metadata = Column(JSON, nullable=True) # Additional module-specific metadata
|
||||
asset_metadata = Column(JSON, nullable=True) # Additional module-specific metadata (renamed from 'metadata' to avoid SQLAlchemy conflict)
|
||||
|
||||
# Generation details
|
||||
provider = Column(String(100), nullable=True, index=True) # AI provider used (e.g., "stability", "gemini")
|
||||
model = Column(String(200), nullable=True, index=True) # Model used (full model path/name)
|
||||
provider = Column(String(100), nullable=True) # AI provider used (e.g., "stability", "gemini")
|
||||
model = Column(String(100), nullable=True) # Model used
|
||||
cost = Column(Float, nullable=True, default=0.0) # Generation cost in USD
|
||||
generation_time = Column(Float, nullable=True) # Time taken in seconds
|
||||
|
||||
# Status tracking
|
||||
status = Column(String(50), default='completed', index=True) # completed, processing, failed, pending
|
||||
error_message = Column(Text, nullable=True) # Error details if failed
|
||||
|
||||
# Organization
|
||||
is_favorite = Column(Boolean, default=False, index=True)
|
||||
collection_id = Column(Integer, ForeignKey('asset_collections.id'), nullable=True)
|
||||
@@ -113,7 +99,11 @@ class ContentAsset(Base):
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
collection = relationship("AssetCollection", back_populates="assets", cascade="all, delete-orphan")
|
||||
collection = relationship(
|
||||
"AssetCollection",
|
||||
back_populates="assets",
|
||||
foreign_keys=[collection_id]
|
||||
)
|
||||
|
||||
# Composite indexes for common query patterns
|
||||
__table_args__ = (
|
||||
@@ -141,5 +131,15 @@ class AssetCollection(Base):
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
assets = relationship("ContentAsset", back_populates="collection")
|
||||
assets = relationship(
|
||||
"ContentAsset",
|
||||
back_populates="collection",
|
||||
foreign_keys="[ContentAsset.collection_id]",
|
||||
cascade="all, delete-orphan" # Cascade delete on the "one" side (one-to-many)
|
||||
)
|
||||
cover_asset = relationship(
|
||||
"ContentAsset",
|
||||
foreign_keys=[cover_asset_id],
|
||||
uselist=False
|
||||
)
|
||||
|
||||
|
||||
155
backend/models/product_asset_models.py
Normal file
155
backend/models/product_asset_models.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
Product Asset Models
|
||||
Database models for storing product-specific assets (separate from campaign assets).
|
||||
These models are for the Product Marketing Suite (product asset creation).
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, Text, ForeignKey, Index
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from models.subscription_models import Base
|
||||
|
||||
|
||||
class ProductAssetType(enum.Enum):
|
||||
"""Product asset type enum."""
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
ANIMATION = "animation"
|
||||
|
||||
|
||||
class ProductImageStyle(enum.Enum):
|
||||
"""Product image style enum."""
|
||||
STUDIO = "studio"
|
||||
LIFESTYLE = "lifestyle"
|
||||
OUTDOOR = "outdoor"
|
||||
MINIMALIST = "minimalist"
|
||||
LUXURY = "luxury"
|
||||
TECHNICAL = "technical"
|
||||
|
||||
|
||||
class ProductAsset(Base):
|
||||
"""
|
||||
Product asset model.
|
||||
Stores product-specific assets (images, videos, audio) generated for product marketing.
|
||||
"""
|
||||
|
||||
__tablename__ = "product_assets"
|
||||
|
||||
# Primary fields
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
product_id = Column(String(255), nullable=False, index=True) # User-defined product ID
|
||||
user_id = Column(String(255), nullable=False, index=True) # Clerk user ID
|
||||
|
||||
# Product information
|
||||
product_name = Column(String(500), nullable=False)
|
||||
product_description = Column(Text, nullable=True)
|
||||
|
||||
# Asset details
|
||||
asset_type = Column(String(50), nullable=False, index=True) # image, video, audio, animation
|
||||
variant = Column(String(100), nullable=True) # color, size, angle, etc.
|
||||
style = Column(String(50), nullable=True) # studio, lifestyle, minimalist, etc.
|
||||
environment = Column(String(50), nullable=True) # studio, lifestyle, outdoor, etc.
|
||||
|
||||
# Link to ContentAsset (unified asset library)
|
||||
content_asset_id = Column(Integer, ForeignKey('content_assets.id', ondelete='SET NULL'), nullable=True, index=True)
|
||||
|
||||
# Generation details
|
||||
provider = Column(String(100), nullable=True)
|
||||
model = Column(String(100), nullable=True)
|
||||
cost = Column(Float, default=0.0)
|
||||
generation_time = Column(Float, nullable=True)
|
||||
prompt_used = Column(Text, nullable=True)
|
||||
|
||||
# E-commerce integration
|
||||
ecommerce_exported = Column(Boolean, default=False)
|
||||
exported_to = Column(JSON, nullable=True) # Array of platform names
|
||||
|
||||
# Status
|
||||
status = Column(String(50), default="completed", nullable=False) # completed, processing, failed
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Additional metadata
|
||||
metadata = Column(JSON, nullable=True) # Additional product-specific metadata
|
||||
|
||||
# Composite indexes
|
||||
__table_args__ = (
|
||||
Index('idx_user_product', 'user_id', 'product_id'),
|
||||
Index('idx_user_type', 'user_id', 'asset_type'),
|
||||
Index('idx_product_type', 'product_id', 'asset_type'),
|
||||
)
|
||||
|
||||
|
||||
class ProductStyleTemplate(Base):
|
||||
"""
|
||||
Brand style template for products.
|
||||
Stores reusable brand style configurations for product asset generation.
|
||||
"""
|
||||
|
||||
__tablename__ = "product_style_templates"
|
||||
|
||||
# Primary fields
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
template_name = Column(String(255), nullable=False)
|
||||
|
||||
# Style configuration
|
||||
color_palette = Column(JSON, nullable=True) # Array of brand colors
|
||||
background_style = Column(String(50), nullable=True) # white, transparent, lifestyle, branded
|
||||
lighting_preset = Column(String(50), nullable=True) # natural, studio, dramatic, soft
|
||||
preferred_style = Column(String(50), nullable=True) # photorealistic, minimalist, luxury, technical
|
||||
preferred_environment = Column(String(50), nullable=True) # studio, lifestyle, outdoor
|
||||
|
||||
# Brand integration
|
||||
use_brand_colors = Column(Boolean, default=True)
|
||||
use_brand_logo = Column(Boolean, default=False)
|
||||
|
||||
# Metadata
|
||||
is_default = Column(Boolean, default=False) # Default template for user
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Composite indexes
|
||||
__table_args__ = (
|
||||
Index('idx_user_template', 'user_id', 'template_name'),
|
||||
)
|
||||
|
||||
|
||||
class EcommerceExport(Base):
|
||||
"""
|
||||
E-commerce platform export tracking.
|
||||
Tracks product asset exports to e-commerce platforms.
|
||||
"""
|
||||
|
||||
__tablename__ = "product_ecommerce_exports"
|
||||
|
||||
# Primary fields
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
product_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Platform information
|
||||
platform = Column(String(50), nullable=False) # shopify, amazon, woocommerce
|
||||
platform_product_id = Column(String(255), nullable=True) # Product ID on the platform
|
||||
|
||||
# Export details
|
||||
exported_assets = Column(JSON, nullable=False) # Array of asset IDs exported
|
||||
export_status = Column(String(50), default="pending", nullable=False) # pending, completed, failed
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
exported_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Composite indexes
|
||||
__table_args__ = (
|
||||
Index('idx_user_platform', 'user_id', 'platform'),
|
||||
Index('idx_product_platform', 'product_id', 'platform'),
|
||||
)
|
||||
|
||||
162
backend/models/product_marketing_models.py
Normal file
162
backend/models/product_marketing_models.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Product Marketing Campaign Models
|
||||
Database models for storing campaign blueprints and asset proposals.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, Text, ForeignKey, Index, func
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from models.subscription_models import Base
|
||||
|
||||
|
||||
class CampaignStatus(enum.Enum):
|
||||
"""Campaign status enum."""
|
||||
DRAFT = "draft"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
PUBLISHED = "published"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class AssetNodeStatus(enum.Enum):
|
||||
"""Asset node status enum."""
|
||||
DRAFT = "draft"
|
||||
PROPOSED = "proposed"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
|
||||
|
||||
class Campaign(Base):
|
||||
"""
|
||||
Campaign blueprint model.
|
||||
Stores campaign information, phases, and asset nodes.
|
||||
"""
|
||||
|
||||
__tablename__ = "product_marketing_campaigns"
|
||||
|
||||
# Primary fields
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
campaign_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True) # Clerk user ID
|
||||
|
||||
# Campaign details
|
||||
campaign_name = Column(String(500), nullable=False)
|
||||
goal = Column(String(100), nullable=False) # product_launch, awareness, conversion, etc.
|
||||
kpi = Column(String(500), nullable=True)
|
||||
status = Column(String(50), default="draft", nullable=False, index=True)
|
||||
|
||||
# Campaign structure
|
||||
phases = Column(JSON, nullable=True) # Array of phase objects
|
||||
channels = Column(JSON, nullable=False) # Array of channel strings
|
||||
asset_nodes = Column(JSON, nullable=True) # Array of asset node objects
|
||||
|
||||
# Product context
|
||||
product_context = Column(JSON, nullable=True) # Product information
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
proposals = relationship("CampaignProposal", back_populates="campaign", cascade="all, delete-orphan")
|
||||
generated_assets = relationship("CampaignAsset", back_populates="campaign", cascade="all, delete-orphan")
|
||||
|
||||
# Composite indexes
|
||||
__table_args__ = (
|
||||
Index('idx_user_status', 'user_id', 'status'),
|
||||
Index('idx_user_created', 'user_id', 'created_at'),
|
||||
)
|
||||
|
||||
|
||||
class CampaignProposal(Base):
|
||||
"""
|
||||
Asset proposals for a campaign.
|
||||
Stores AI-generated proposals for each asset node.
|
||||
"""
|
||||
|
||||
__tablename__ = "product_marketing_proposals"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
campaign_id = Column(String(255), ForeignKey('product_marketing_campaigns.campaign_id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Asset node reference
|
||||
asset_node_id = Column(String(255), nullable=False, index=True)
|
||||
asset_type = Column(String(50), nullable=False) # image, text, video, audio
|
||||
channel = Column(String(50), nullable=False)
|
||||
|
||||
# Proposal details
|
||||
proposed_prompt = Column(Text, nullable=False)
|
||||
recommended_template = Column(String(255), nullable=True)
|
||||
recommended_provider = Column(String(100), nullable=True)
|
||||
recommended_model = Column(String(100), nullable=True)
|
||||
cost_estimate = Column(Float, default=0.0)
|
||||
concept_summary = Column(Text, nullable=True)
|
||||
|
||||
# Status
|
||||
status = Column(String(50), default="proposed", nullable=False) # proposed, approved, rejected, generating
|
||||
approved_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
campaign = relationship("Campaign", back_populates="proposals")
|
||||
generated_asset = relationship("CampaignAsset", back_populates="proposal", uselist=False)
|
||||
|
||||
# Composite indexes
|
||||
__table_args__ = (
|
||||
Index('idx_campaign_node', 'campaign_id', 'asset_node_id'),
|
||||
Index('idx_user_status', 'user_id', 'status'),
|
||||
)
|
||||
|
||||
|
||||
class CampaignAsset(Base):
|
||||
"""
|
||||
Generated assets for a campaign.
|
||||
Links to ContentAsset and stores campaign-specific metadata.
|
||||
"""
|
||||
|
||||
__tablename__ = "product_marketing_assets"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
campaign_id = Column(String(255), ForeignKey('product_marketing_campaigns.campaign_id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
proposal_id = Column(Integer, ForeignKey('product_marketing_proposals.id', ondelete='SET NULL'), nullable=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Asset node reference
|
||||
asset_node_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Link to ContentAsset
|
||||
content_asset_id = Column(Integer, ForeignKey('content_assets.id', ondelete='SET NULL'), nullable=True)
|
||||
|
||||
# Generation details
|
||||
provider = Column(String(100), nullable=True)
|
||||
model = Column(String(100), nullable=True)
|
||||
cost = Column(Float, default=0.0)
|
||||
generation_time = Column(Float, nullable=True)
|
||||
|
||||
# Status
|
||||
status = Column(String(50), default="generating", nullable=False) # generating, ready, approved, published
|
||||
approved_at = Column(DateTime, nullable=True)
|
||||
published_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
campaign = relationship("Campaign", back_populates="generated_assets")
|
||||
proposal = relationship("CampaignProposal", back_populates="generated_asset")
|
||||
|
||||
# Composite indexes
|
||||
__table_args__ = (
|
||||
Index('idx_campaign_node', 'campaign_id', 'asset_node_id'),
|
||||
Index('idx_user_status', 'user_id', 'status'),
|
||||
)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""API endpoints for Image Studio operations."""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, Literal
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.image_studio import (
|
||||
@@ -11,10 +13,12 @@ from services.image_studio import (
|
||||
EditStudioRequest,
|
||||
ControlStudioRequest,
|
||||
SocialOptimizerRequest,
|
||||
TransformImageToVideoRequest,
|
||||
TalkingAvatarRequest,
|
||||
)
|
||||
from services.image_studio.upscale_service import UpscaleStudioRequest
|
||||
from services.image_studio.templates import Platform, TemplateCategory
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
@@ -136,7 +140,12 @@ def get_studio_manager() -> ImageStudioManager:
|
||||
|
||||
def _require_user_id(current_user: Dict[str, Any], operation: str) -> str:
|
||||
"""Ensure user_id is available for protected operations."""
|
||||
user_id = current_user.get("sub") or current_user.get("user_id")
|
||||
user_id = (
|
||||
current_user.get("sub")
|
||||
or current_user.get("user_id")
|
||||
or current_user.get("id")
|
||||
or current_user.get("clerk_user_id")
|
||||
)
|
||||
if not user_id:
|
||||
logger.error(
|
||||
"[Image Studio] ❌ Missing user_id for %s operation - blocking request",
|
||||
@@ -762,6 +771,244 @@ async def get_platform_specs(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# TRANSFORM STUDIO ENDPOINTS
|
||||
# ====================
|
||||
|
||||
class TransformImageToVideoRequestModel(BaseModel):
|
||||
"""Request model for image-to-video transformation."""
|
||||
image_base64: str = Field(..., description="Image in base64 or data URL format")
|
||||
prompt: str = Field(..., description="Text prompt describing the video")
|
||||
audio_base64: Optional[str] = Field(None, description="Optional audio file (wav/mp3, 3-30s, ≤15MB)")
|
||||
resolution: Literal["480p", "720p", "1080p"] = Field("720p", description="Output resolution")
|
||||
duration: Literal[5, 10] = Field(5, description="Video duration in seconds")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt")
|
||||
seed: Optional[int] = Field(None, description="Random seed for reproducibility")
|
||||
enable_prompt_expansion: bool = Field(True, description="Enable prompt optimizer")
|
||||
|
||||
|
||||
class TalkingAvatarRequestModel(BaseModel):
|
||||
"""Request model for talking avatar generation."""
|
||||
image_base64: str = Field(..., description="Person image in base64 or data URL")
|
||||
audio_base64: str = Field(..., description="Audio file in base64 or data URL (wav/mp3, max 10 minutes)")
|
||||
resolution: Literal["480p", "720p"] = Field("720p", description="Output resolution")
|
||||
prompt: Optional[str] = Field(None, description="Optional prompt for expression/style")
|
||||
mask_image_base64: Optional[str] = Field(None, description="Optional mask for animatable regions")
|
||||
seed: Optional[int] = Field(None, description="Random seed")
|
||||
|
||||
|
||||
class TransformVideoResponse(BaseModel):
|
||||
"""Response model for video generation."""
|
||||
success: bool
|
||||
video_url: Optional[str] = None
|
||||
video_base64: Optional[str] = None
|
||||
duration: float
|
||||
resolution: str
|
||||
width: int
|
||||
height: int
|
||||
file_size: int
|
||||
cost: float
|
||||
provider: str
|
||||
model: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class TransformCostEstimateRequest(BaseModel):
|
||||
"""Request model for cost estimation."""
|
||||
operation: Literal["image-to-video", "talking-avatar"] = Field(..., description="Operation type")
|
||||
resolution: str = Field(..., description="Output resolution")
|
||||
duration: Optional[int] = Field(None, description="Video duration in seconds (for image-to-video)")
|
||||
|
||||
|
||||
class TransformCostEstimateResponse(BaseModel):
|
||||
"""Response model for cost estimation."""
|
||||
estimated_cost: float
|
||||
breakdown: Dict[str, Any]
|
||||
currency: str
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
|
||||
@router.post("/transform/image-to-video", response_model=TransformVideoResponse, summary="Transform Image to Video")
|
||||
async def transform_image_to_video(
|
||||
request: TransformImageToVideoRequestModel,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Transform an image into a video using WAN 2.5.
|
||||
|
||||
This endpoint generates a video from an image and text prompt, with optional audio synchronization.
|
||||
Supports resolutions of 480p, 720p, and 1080p, with durations of 5 or 10 seconds.
|
||||
|
||||
Returns:
|
||||
Video generation result with URL and metadata
|
||||
"""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image-to-video transformation")
|
||||
logger.info(f"[Transform Studio] Image-to-video request from user {user_id}: resolution={request.resolution}, duration={request.duration}s")
|
||||
|
||||
# Convert request to service request
|
||||
transform_request = TransformImageToVideoRequest(
|
||||
image_base64=request.image_base64,
|
||||
prompt=request.prompt,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
enable_prompt_expansion=request.enable_prompt_expansion,
|
||||
)
|
||||
|
||||
# Generate video
|
||||
result = await studio_manager.transform_image_to_video(transform_request, user_id=user_id)
|
||||
|
||||
logger.info(f"[Transform Studio] ✅ Image-to-video completed: cost=${result['cost']:.2f}")
|
||||
return TransformVideoResponse(**result)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/transform/talking-avatar", response_model=TransformVideoResponse, summary="Create Talking Avatar")
|
||||
async def create_talking_avatar(
|
||||
request: TalkingAvatarRequestModel,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Create a talking avatar video using InfiniteTalk.
|
||||
|
||||
This endpoint generates a video with precise lip-sync from an image and audio file.
|
||||
Supports resolutions of 480p and 720p, with videos up to 10 minutes long.
|
||||
|
||||
Returns:
|
||||
Video generation result with URL and metadata
|
||||
"""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "talking avatar generation")
|
||||
logger.info(f"[Transform Studio] Talking avatar request from user {user_id}: resolution={request.resolution}")
|
||||
|
||||
# Convert request to service request
|
||||
avatar_request = TalkingAvatarRequest(
|
||||
image_base64=request.image_base64,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
prompt=request.prompt,
|
||||
mask_image_base64=request.mask_image_base64,
|
||||
seed=request.seed,
|
||||
)
|
||||
|
||||
# Generate video
|
||||
result = await studio_manager.create_talking_avatar(avatar_request, user_id=user_id)
|
||||
|
||||
logger.info(f"[Transform Studio] ✅ Talking avatar completed: cost=${result['cost']:.2f}")
|
||||
return TransformVideoResponse(**result)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Talking avatar generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/transform/estimate-cost", response_model=TransformCostEstimateResponse, summary="Estimate Transform Cost")
|
||||
async def estimate_transform_cost(
|
||||
request: TransformCostEstimateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Estimate cost for transform operations.
|
||||
|
||||
Provides cost estimates before generation to help users make informed decisions.
|
||||
|
||||
Returns:
|
||||
Cost estimation details
|
||||
"""
|
||||
try:
|
||||
estimate = studio_manager.estimate_transform_cost(
|
||||
operation=request.operation,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
)
|
||||
return TransformCostEstimateResponse(**estimate)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Transform Studio] ❌ Cost estimation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/videos/{user_id}/{video_filename:path}", summary="Serve Transform Studio Video")
|
||||
async def serve_transform_video(
|
||||
user_id: str,
|
||||
video_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
|
||||
):
|
||||
"""Serve a generated Transform Studio video file.
|
||||
|
||||
Args:
|
||||
user_id: User ID from URL path
|
||||
video_filename: Video filename
|
||||
current_user: Authenticated user
|
||||
|
||||
Returns:
|
||||
Video file response
|
||||
"""
|
||||
try:
|
||||
# Verify user has access (must be the owner)
|
||||
authenticated_user_id = _require_user_id(current_user, "video access")
|
||||
if authenticated_user_id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied: You can only access your own videos"
|
||||
)
|
||||
|
||||
# Resolve video path
|
||||
# __file__ is: backend/routers/image_studio.py
|
||||
# We need: backend/transform_videos
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
transform_videos_dir = base_dir / "transform_videos"
|
||||
video_path = transform_videos_dir / user_id / video_filename
|
||||
|
||||
# Security: Ensure path is within transform_videos directory
|
||||
# Prevent directory traversal attacks
|
||||
try:
|
||||
resolved_video_path = video_path.resolve()
|
||||
resolved_base = transform_videos_dir.resolve()
|
||||
# Check if video path is within base directory
|
||||
resolved_video_path.relative_to(resolved_base)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Invalid video path: path traversal detected"
|
||||
)
|
||||
|
||||
if not video_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Video not found")
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] Failed to serve video: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# HEALTH CHECK
|
||||
# ====================
|
||||
|
||||
@@ -8,9 +8,10 @@ proper error handling, monitoring, and documentation.
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
import time
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
|
||||
from models.linkedin_models import (
|
||||
LinkedInPostRequest, LinkedInArticleRequest, LinkedInCarouselRequest,
|
||||
@@ -19,11 +20,13 @@ from models.linkedin_models import (
|
||||
LinkedInVideoScriptResponse, LinkedInCommentResponseResult
|
||||
)
|
||||
from services.linkedin_service import LinkedInService
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
|
||||
# Initialize the LinkedIn service instance
|
||||
linkedin_service = LinkedInService()
|
||||
from services.subscription.monitoring_middleware import DatabaseAPIMonitor
|
||||
from services.database import get_db_session
|
||||
from services.database import get_db as get_db_dependency
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Initialize router
|
||||
@@ -41,14 +44,8 @@ router = APIRouter(
|
||||
monitor = DatabaseAPIMonitor()
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency to get database session."""
|
||||
db = get_db_session()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
# Use the proper database dependency from services.database
|
||||
get_db = get_db_dependency
|
||||
|
||||
|
||||
async def log_api_request(request: Request, db: Session, duration: float, status_code: int):
|
||||
@@ -104,7 +101,8 @@ async def generate_post(
|
||||
request: LinkedInPostRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
|
||||
):
|
||||
"""Generate a LinkedIn post based on the provided parameters."""
|
||||
start_time = time.time()
|
||||
@@ -119,6 +117,13 @@ async def generate_post(
|
||||
if not request.industry.strip():
|
||||
raise HTTPException(status_code=422, detail="Industry cannot be empty")
|
||||
|
||||
# Extract user_id
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
if not user_id:
|
||||
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
|
||||
|
||||
# Generate post content
|
||||
response = await linkedin_service.generate_linkedin_post(request)
|
||||
|
||||
@@ -131,6 +136,38 @@ async def generate_post(
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=500, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if user_id and response.data and response.data.content:
|
||||
try:
|
||||
# Combine all text content
|
||||
text_content = response.data.content
|
||||
if response.data.call_to_action:
|
||||
text_content += f"\n\nCall to Action: {response.data.call_to_action}"
|
||||
if response.data.hashtags:
|
||||
hashtag_text = " ".join([f"#{h.hashtag}" if isinstance(h, dict) else f"#{h.get('hashtag', '')}" for h in response.data.hashtags])
|
||||
text_content += f"\n\nHashtags: {hashtag_text}"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="linkedin_writer",
|
||||
title=f"LinkedIn Post: {request.topic[:80]}",
|
||||
description=f"LinkedIn post for {request.industry} industry",
|
||||
prompt=f"Topic: {request.topic}\nIndustry: {request.industry}\nTone: {request.tone}",
|
||||
tags=["linkedin", "post", request.industry.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"post_type": request.post_type.value if hasattr(request.post_type, 'value') else str(request.post_type),
|
||||
"tone": request.tone.value if hasattr(request.tone, 'value') else str(request.tone),
|
||||
"character_count": response.data.character_count,
|
||||
"hashtag_count": len(response.data.hashtags),
|
||||
"grounding_enabled": response.data.grounding_enabled if hasattr(response.data, 'grounding_enabled') else False
|
||||
},
|
||||
subdirectory="posts"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track LinkedIn post asset: {track_error}")
|
||||
|
||||
logger.info(f"Successfully generated LinkedIn post in {duration:.2f} seconds")
|
||||
return response
|
||||
|
||||
@@ -174,7 +211,8 @@ async def generate_article(
|
||||
request: LinkedInArticleRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
|
||||
):
|
||||
"""Generate a LinkedIn article based on the provided parameters."""
|
||||
start_time = time.time()
|
||||
@@ -189,6 +227,13 @@ async def generate_article(
|
||||
if not request.industry.strip():
|
||||
raise HTTPException(status_code=422, detail="Industry cannot be empty")
|
||||
|
||||
# Extract user_id
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
if not user_id:
|
||||
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
|
||||
|
||||
# Generate article content
|
||||
response = await linkedin_service.generate_linkedin_article(request)
|
||||
|
||||
@@ -201,6 +246,44 @@ async def generate_article(
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=500, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if user_id and response.data:
|
||||
try:
|
||||
# Combine article content
|
||||
text_content = f"# {response.data.title}\n\n"
|
||||
text_content += response.data.content
|
||||
|
||||
if response.data.sections:
|
||||
text_content += "\n\n## Sections:\n"
|
||||
for section in response.data.sections:
|
||||
if isinstance(section, dict):
|
||||
text_content += f"\n### {section.get('heading', 'Section')}\n{section.get('content', '')}\n"
|
||||
|
||||
if response.data.seo_metadata:
|
||||
text_content += f"\n\n## SEO Metadata\n{response.data.seo_metadata}\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="linkedin_writer",
|
||||
title=f"LinkedIn Article: {response.data.title[:80] if response.data.title else request.topic[:80]}",
|
||||
description=f"LinkedIn article for {request.industry} industry",
|
||||
prompt=f"Topic: {request.topic}\nIndustry: {request.industry}\nTone: {request.tone}\nWord Count: {request.word_count}",
|
||||
tags=["linkedin", "article", request.industry.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"tone": request.tone.value if hasattr(request.tone, 'value') else str(request.tone),
|
||||
"word_count": response.data.word_count,
|
||||
"reading_time": response.data.reading_time,
|
||||
"section_count": len(response.data.sections) if response.data.sections else 0,
|
||||
"grounding_enabled": response.data.grounding_enabled if hasattr(response.data, 'grounding_enabled') else False
|
||||
},
|
||||
subdirectory="articles",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track LinkedIn article asset: {track_error}")
|
||||
|
||||
logger.info(f"Successfully generated LinkedIn article in {duration:.2f} seconds")
|
||||
return response
|
||||
|
||||
@@ -243,7 +326,8 @@ async def generate_carousel(
|
||||
request: LinkedInCarouselRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
|
||||
):
|
||||
"""Generate a LinkedIn carousel based on the provided parameters."""
|
||||
start_time = time.time()
|
||||
@@ -261,6 +345,13 @@ async def generate_carousel(
|
||||
if request.slide_count < 3 or request.slide_count > 15:
|
||||
raise HTTPException(status_code=422, detail="Slide count must be between 3 and 15")
|
||||
|
||||
# Extract user_id
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
if not user_id:
|
||||
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
|
||||
|
||||
# Generate carousel content
|
||||
response = await linkedin_service.generate_linkedin_carousel(request)
|
||||
|
||||
@@ -273,6 +364,36 @@ async def generate_carousel(
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=500, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if user_id and response.data:
|
||||
try:
|
||||
# Combine carousel content
|
||||
text_content = f"# {response.data.title}\n\n"
|
||||
for slide in response.data.slides:
|
||||
text_content += f"\n## Slide {slide.slide_number}: {slide.title}\n{slide.content}\n"
|
||||
if slide.visual_elements:
|
||||
text_content += f"\nVisual Elements: {', '.join(slide.visual_elements)}\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="linkedin_writer",
|
||||
title=f"LinkedIn Carousel: {response.data.title[:80] if response.data.title else request.topic[:80]}",
|
||||
description=f"LinkedIn carousel for {request.industry} industry",
|
||||
prompt=f"Topic: {request.topic}\nIndustry: {request.industry}\nSlides: {getattr(request, 'number_of_slides', request.slide_count if hasattr(request, 'slide_count') else 5)}",
|
||||
tags=["linkedin", "carousel", request.industry.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"slide_count": len(response.data.slides),
|
||||
"has_cover": response.data.cover_slide is not None,
|
||||
"has_cta": response.data.cta_slide is not None
|
||||
},
|
||||
subdirectory="carousels",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track LinkedIn carousel asset: {track_error}")
|
||||
|
||||
logger.info(f"Successfully generated LinkedIn carousel in {duration:.2f} seconds")
|
||||
return response
|
||||
|
||||
@@ -315,7 +436,8 @@ async def generate_video_script(
|
||||
request: LinkedInVideoScriptRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
|
||||
):
|
||||
"""Generate a LinkedIn video script based on the provided parameters."""
|
||||
start_time = time.time()
|
||||
@@ -330,9 +452,17 @@ async def generate_video_script(
|
||||
if not request.industry.strip():
|
||||
raise HTTPException(status_code=422, detail="Industry cannot be empty")
|
||||
|
||||
if request.video_length < 15 or request.video_length > 300:
|
||||
video_duration = getattr(request, 'video_duration', getattr(request, 'video_length', 60))
|
||||
if video_duration < 15 or video_duration > 300:
|
||||
raise HTTPException(status_code=422, detail="Video length must be between 15 and 300 seconds")
|
||||
|
||||
# Extract user_id
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
if not user_id:
|
||||
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
|
||||
|
||||
# Generate video script content
|
||||
response = await linkedin_service.generate_linkedin_video_script(request)
|
||||
|
||||
@@ -345,6 +475,47 @@ async def generate_video_script(
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=500, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if user_id and response.data:
|
||||
try:
|
||||
# Combine video script content
|
||||
text_content = f"# Video Script: {request.topic}\n\n"
|
||||
text_content += f"## Hook\n{response.data.hook}\n\n"
|
||||
text_content += "## Main Content\n"
|
||||
for scene in response.data.main_content:
|
||||
if isinstance(scene, dict):
|
||||
text_content += f"\n### Scene {scene.get('scene_number', '')}\n"
|
||||
text_content += f"{scene.get('content', '')}\n"
|
||||
if scene.get('duration'):
|
||||
text_content += f"Duration: {scene.get('duration')}s\n"
|
||||
if scene.get('visual_notes'):
|
||||
text_content += f"Visual Notes: {scene.get('visual_notes')}\n"
|
||||
text_content += f"\n## Conclusion\n{response.data.conclusion}\n"
|
||||
if response.data.captions:
|
||||
text_content += f"\n## Captions\n" + "\n".join(response.data.captions) + "\n"
|
||||
if response.data.thumbnail_suggestions:
|
||||
text_content += f"\n## Thumbnail Suggestions\n" + "\n".join(response.data.thumbnail_suggestions) + "\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="linkedin_writer",
|
||||
title=f"LinkedIn Video Script: {request.topic[:80]}",
|
||||
description=f"LinkedIn video script for {request.industry} industry",
|
||||
prompt=f"Topic: {request.topic}\nIndustry: {request.industry}\nDuration: {video_duration}s",
|
||||
tags=["linkedin", "video_script", request.industry.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"video_duration": video_duration,
|
||||
"scene_count": len(response.data.main_content),
|
||||
"has_captions": bool(response.data.captions)
|
||||
},
|
||||
subdirectory="video_scripts",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track LinkedIn video script asset: {track_error}")
|
||||
|
||||
logger.info(f"Successfully generated LinkedIn video script in {duration:.2f} seconds")
|
||||
return response
|
||||
|
||||
@@ -387,7 +558,8 @@ async def generate_comment_response(
|
||||
request: LinkedInCommentResponseRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
|
||||
):
|
||||
"""Generate a LinkedIn comment response based on the provided parameters."""
|
||||
start_time = time.time()
|
||||
@@ -396,11 +568,21 @@ async def generate_comment_response(
|
||||
logger.info("Received LinkedIn comment response generation request")
|
||||
|
||||
# Validate request
|
||||
if not request.original_post.strip():
|
||||
raise HTTPException(status_code=422, detail="Original post cannot be empty")
|
||||
original_comment = getattr(request, 'original_comment', getattr(request, 'comment', ''))
|
||||
post_context = getattr(request, 'post_context', getattr(request, 'original_post', ''))
|
||||
|
||||
if not request.comment.strip():
|
||||
raise HTTPException(status_code=422, detail="Comment cannot be empty")
|
||||
if not original_comment.strip():
|
||||
raise HTTPException(status_code=422, detail="Original comment cannot be empty")
|
||||
|
||||
if not post_context.strip():
|
||||
raise HTTPException(status_code=422, detail="Post context cannot be empty")
|
||||
|
||||
# Extract user_id
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
if not user_id:
|
||||
user_id = http_request.headers.get("X-User-ID") or http_request.headers.get("Authorization")
|
||||
|
||||
# Generate comment response
|
||||
response = await linkedin_service.generate_linkedin_comment_response(request)
|
||||
@@ -414,6 +596,38 @@ async def generate_comment_response(
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=500, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if user_id and hasattr(response, 'response') and response.response:
|
||||
try:
|
||||
text_content = f"# Comment Response\n\n"
|
||||
text_content += f"## Original Comment\n{original_comment}\n\n"
|
||||
text_content += f"## Post Context\n{post_context}\n\n"
|
||||
text_content += f"## Generated Response\n{response.response}\n"
|
||||
if hasattr(response, 'alternatives') and response.alternatives:
|
||||
text_content += f"\n## Alternative Responses\n"
|
||||
for i, alt in enumerate(response.alternatives, 1):
|
||||
text_content += f"\n### Alternative {i}\n{alt}\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="linkedin_writer",
|
||||
title=f"LinkedIn Comment Response: {original_comment[:60]}",
|
||||
description=f"LinkedIn comment response for {request.industry} industry",
|
||||
prompt=f"Original Comment: {original_comment}\nPost Context: {post_context}\nIndustry: {request.industry}",
|
||||
tags=["linkedin", "comment_response", request.industry.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"response_length": getattr(request, 'response_length', 'medium'),
|
||||
"tone": request.tone.value if hasattr(request.tone, 'value') else str(request.tone),
|
||||
"has_alternatives": hasattr(response, 'alternatives') and bool(response.alternatives)
|
||||
},
|
||||
subdirectory="comment_responses",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track LinkedIn comment response asset: {track_error}")
|
||||
|
||||
logger.info(f"Successfully generated LinkedIn comment response in {duration:.2f} seconds")
|
||||
return response
|
||||
|
||||
|
||||
640
backend/routers/product_marketing.py
Normal file
640
backend/routers/product_marketing.py
Normal file
@@ -0,0 +1,640 @@
|
||||
"""API endpoints for Product Marketing Suite."""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.product_marketing import (
|
||||
ProductMarketingOrchestrator,
|
||||
BrandDNASyncService,
|
||||
AssetAuditService,
|
||||
ChannelPackService,
|
||||
)
|
||||
from services.product_marketing.campaign_storage import CampaignStorageService
|
||||
from services.product_marketing.product_image_service import ProductImageService, ProductImageRequest
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
from services.database import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
logger = get_service_logger("api.product_marketing")
|
||||
router = APIRouter(prefix="/api/product-marketing", tags=["product-marketing"])
|
||||
|
||||
|
||||
# ====================
|
||||
# REQUEST MODELS
|
||||
# ====================
|
||||
|
||||
class CampaignCreateRequest(BaseModel):
|
||||
"""Request to create a new campaign blueprint."""
|
||||
campaign_name: str = Field(..., description="Campaign name")
|
||||
goal: str = Field(..., description="Campaign goal (product_launch, awareness, conversion, etc.)")
|
||||
kpi: Optional[str] = Field(None, description="Key performance indicator")
|
||||
channels: List[str] = Field(..., description="Target channels (instagram, linkedin, tiktok, etc.)")
|
||||
product_context: Optional[Dict[str, Any]] = Field(None, description="Product information")
|
||||
|
||||
|
||||
class AssetProposalRequest(BaseModel):
|
||||
"""Request to generate asset proposals."""
|
||||
campaign_id: str = Field(..., description="Campaign ID")
|
||||
product_context: Optional[Dict[str, Any]] = Field(None, description="Product information")
|
||||
|
||||
|
||||
class AssetGenerateRequest(BaseModel):
|
||||
"""Request to generate a specific asset."""
|
||||
asset_proposal: Dict[str, Any] = Field(..., description="Asset proposal from generate_proposals")
|
||||
product_context: Optional[Dict[str, Any]] = Field(None, description="Product information")
|
||||
|
||||
|
||||
class AssetAuditRequest(BaseModel):
|
||||
"""Request to audit uploaded assets."""
|
||||
image_base64: str = Field(..., description="Base64 encoded image")
|
||||
asset_metadata: Optional[Dict[str, Any]] = Field(None, description="Asset metadata")
|
||||
|
||||
|
||||
# ====================
|
||||
# DEPENDENCY
|
||||
# ====================
|
||||
|
||||
def get_orchestrator() -> ProductMarketingOrchestrator:
|
||||
"""Get Product Marketing Orchestrator instance."""
|
||||
return ProductMarketingOrchestrator()
|
||||
|
||||
|
||||
def get_campaign_storage() -> CampaignStorageService:
|
||||
"""Get Campaign Storage Service instance."""
|
||||
return CampaignStorageService()
|
||||
|
||||
|
||||
def _require_user_id(current_user: Dict[str, Any], operation: str) -> str:
|
||||
"""Ensure user_id is available for protected operations."""
|
||||
user_id = current_user.get("sub") or current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
logger.error(
|
||||
"[Product Marketing] ❌ Missing user_id for %s operation - blocking request",
|
||||
operation,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authenticated user required for product marketing operations.",
|
||||
)
|
||||
return str(user_id)
|
||||
|
||||
|
||||
# ====================
|
||||
# CAMPAIGN ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.post("/campaigns/validate-preflight", summary="Validate Campaign Pre-flight")
|
||||
async def validate_campaign_preflight(
|
||||
request: CampaignCreateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
orchestrator: ProductMarketingOrchestrator = Depends(get_orchestrator)
|
||||
):
|
||||
"""Validate campaign blueprint against subscription limits before creation.
|
||||
|
||||
This endpoint:
|
||||
- Creates a temporary blueprint to estimate costs
|
||||
- Validates subscription limits
|
||||
- Returns cost estimates and validation results
|
||||
- Does NOT save anything to database
|
||||
"""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "campaign pre-flight validation")
|
||||
logger.info(f"[Product Marketing] Pre-flight validation for user {user_id}")
|
||||
|
||||
# Create temporary blueprint for validation (not saved)
|
||||
campaign_data = {
|
||||
"campaign_name": request.campaign_name or "Temporary Campaign",
|
||||
"goal": request.goal,
|
||||
"kpi": request.kpi,
|
||||
"channels": request.channels,
|
||||
}
|
||||
|
||||
blueprint = orchestrator.create_campaign_blueprint(user_id, campaign_data)
|
||||
|
||||
# Run pre-flight validation
|
||||
validation_result = orchestrator.validate_campaign_preflight(user_id, blueprint)
|
||||
|
||||
logger.info(f"[Product Marketing] ✅ Pre-flight validation completed: can_proceed={validation_result.get('can_proceed')}")
|
||||
return validation_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error in pre-flight validation: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Pre-flight validation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/campaigns/create-blueprint", summary="Create Campaign Blueprint")
|
||||
async def create_campaign_blueprint(
|
||||
request: CampaignCreateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
orchestrator: ProductMarketingOrchestrator = Depends(get_orchestrator)
|
||||
):
|
||||
"""Create a campaign blueprint with personalized asset nodes.
|
||||
|
||||
This endpoint:
|
||||
- Uses onboarding data to personalize the blueprint
|
||||
- Generates campaign phases (teaser, launch, nurture)
|
||||
- Creates asset nodes for each phase and channel
|
||||
- Returns blueprint ready for AI proposal generation
|
||||
"""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "campaign blueprint creation")
|
||||
logger.info(f"[Product Marketing] Creating blueprint for user {user_id}: {request.campaign_name}")
|
||||
|
||||
campaign_data = {
|
||||
"campaign_name": request.campaign_name,
|
||||
"goal": request.goal,
|
||||
"kpi": request.kpi,
|
||||
"channels": request.channels,
|
||||
}
|
||||
|
||||
blueprint = orchestrator.create_campaign_blueprint(user_id, campaign_data)
|
||||
|
||||
# Convert blueprint to dict for JSON response
|
||||
blueprint_dict = {
|
||||
"campaign_id": blueprint.campaign_id,
|
||||
"campaign_name": blueprint.campaign_name,
|
||||
"goal": blueprint.goal,
|
||||
"kpi": blueprint.kpi,
|
||||
"phases": blueprint.phases,
|
||||
"asset_nodes": [
|
||||
{
|
||||
"asset_id": node.asset_id,
|
||||
"asset_type": node.asset_type,
|
||||
"channel": node.channel,
|
||||
"status": node.status,
|
||||
}
|
||||
for node in blueprint.asset_nodes
|
||||
],
|
||||
"channels": blueprint.channels,
|
||||
"status": blueprint.status,
|
||||
}
|
||||
|
||||
# Save to database
|
||||
campaign_storage = get_campaign_storage()
|
||||
campaign_storage.save_campaign(user_id, blueprint_dict)
|
||||
|
||||
logger.info(f"[Product Marketing] ✅ Blueprint created and saved: {blueprint.campaign_id}")
|
||||
return blueprint_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error creating blueprint: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Campaign blueprint creation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/campaigns/{campaign_id}/generate-proposals", summary="Generate Asset Proposals")
|
||||
async def generate_asset_proposals(
|
||||
campaign_id: str,
|
||||
request: AssetProposalRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
orchestrator: ProductMarketingOrchestrator = Depends(get_orchestrator)
|
||||
):
|
||||
"""Generate AI proposals for all assets in a campaign blueprint.
|
||||
|
||||
This endpoint:
|
||||
- Uses specialized marketing prompts with brand DNA
|
||||
- Recommends templates, providers, and settings
|
||||
- Provides cost estimates
|
||||
- Returns proposals ready for user approval
|
||||
"""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "asset proposal generation")
|
||||
logger.info(f"[Product Marketing] Generating proposals for campaign {campaign_id}")
|
||||
|
||||
# Fetch blueprint from database
|
||||
campaign_storage = get_campaign_storage()
|
||||
campaign = campaign_storage.get_campaign(user_id, campaign_id)
|
||||
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
# Reconstruct blueprint from database
|
||||
from services.product_marketing.orchestrator import CampaignBlueprint, CampaignAssetNode
|
||||
|
||||
asset_nodes = []
|
||||
if campaign.asset_nodes:
|
||||
for node_data in campaign.asset_nodes:
|
||||
asset_nodes.append(CampaignAssetNode(
|
||||
asset_id=node_data.get('asset_id'),
|
||||
asset_type=node_data.get('asset_type'),
|
||||
channel=node_data.get('channel'),
|
||||
status=node_data.get('status', 'draft'),
|
||||
))
|
||||
|
||||
blueprint = CampaignBlueprint(
|
||||
campaign_id=campaign.campaign_id,
|
||||
campaign_name=campaign.campaign_name,
|
||||
goal=campaign.goal,
|
||||
kpi=campaign.kpi,
|
||||
channels=campaign.channels or [],
|
||||
asset_nodes=asset_nodes,
|
||||
)
|
||||
|
||||
proposals = orchestrator.generate_asset_proposals(
|
||||
user_id=user_id,
|
||||
blueprint=blueprint,
|
||||
product_context=request.product_context,
|
||||
)
|
||||
|
||||
# Save proposals to database
|
||||
try:
|
||||
campaign_storage.save_proposals(user_id, campaign_id, proposals)
|
||||
logger.info(f"[Product Marketing] ✅ Saved {proposals['total_assets']} proposals to database")
|
||||
except Exception as save_error:
|
||||
logger.error(f"[Product Marketing] ⚠️ Failed to save proposals to database: {str(save_error)}")
|
||||
# Continue even if save fails - proposals are still returned to user
|
||||
# This allows the workflow to continue, but proposals won't persist
|
||||
|
||||
logger.info(f"[Product Marketing] ✅ Generated {proposals['total_assets']} proposals")
|
||||
return proposals
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error generating proposals: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Asset proposal generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/assets/generate", summary="Generate Asset")
|
||||
async def generate_asset(
|
||||
request: AssetGenerateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
orchestrator: ProductMarketingOrchestrator = Depends(get_orchestrator)
|
||||
):
|
||||
"""Generate a single asset using Image Studio APIs.
|
||||
|
||||
This endpoint:
|
||||
- Reuses existing Image Studio APIs
|
||||
- Applies specialized marketing prompts
|
||||
- Automatically tracks assets in Asset Library
|
||||
- Validates subscription limits
|
||||
"""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "asset generation")
|
||||
logger.info(f"[Product Marketing] Generating asset for user {user_id}")
|
||||
|
||||
result = await orchestrator.generate_asset(
|
||||
user_id=user_id,
|
||||
asset_proposal=request.asset_proposal,
|
||||
product_context=request.product_context,
|
||||
)
|
||||
|
||||
logger.info(f"[Product Marketing] ✅ Asset generated successfully")
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error generating asset: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Asset generation failed: {str(e)}")
|
||||
|
||||
|
||||
# ====================
|
||||
# BRAND DNA ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.get("/brand-dna", summary="Get Brand DNA Tokens")
|
||||
async def get_brand_dna(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
|
||||
):
|
||||
"""Get brand DNA tokens for the authenticated user.
|
||||
|
||||
Returns normalized brand DNA from onboarding and persona data.
|
||||
"""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "brand DNA retrieval")
|
||||
brand_tokens = brand_dna_sync.get_brand_dna_tokens(user_id)
|
||||
|
||||
return {"brand_dna": brand_tokens}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error getting brand DNA: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/brand-dna/channel/{channel}", summary="Get Channel-Specific Brand DNA")
|
||||
async def get_channel_brand_dna(
|
||||
channel: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
|
||||
):
|
||||
"""Get channel-specific brand DNA adaptations."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "channel brand DNA retrieval")
|
||||
channel_dna = brand_dna_sync.get_channel_specific_dna(user_id, channel)
|
||||
|
||||
return {"channel": channel, "brand_dna": channel_dna}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error getting channel DNA: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# ASSET AUDIT ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.post("/assets/audit", summary="Audit Asset")
|
||||
async def audit_asset(
|
||||
request: AssetAuditRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
asset_audit: AssetAuditService = Depends(lambda: AssetAuditService())
|
||||
):
|
||||
"""Audit an uploaded asset and get enhancement recommendations."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "asset audit")
|
||||
audit_result = asset_audit.audit_asset(
|
||||
request.image_base64,
|
||||
request.asset_metadata,
|
||||
)
|
||||
|
||||
return audit_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error auditing asset: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# CHANNEL PACK ENDPOINTS
|
||||
# ====================
|
||||
|
||||
@router.get("/channels/{channel}/pack", summary="Get Channel Pack")
|
||||
async def get_channel_pack(
|
||||
channel: str,
|
||||
asset_type: str = "social_post",
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
channel_pack: ChannelPackService = Depends(lambda: ChannelPackService())
|
||||
):
|
||||
"""Get channel-specific pack configuration with templates and optimization tips."""
|
||||
try:
|
||||
pack = channel_pack.get_channel_pack(channel, asset_type)
|
||||
return pack
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error getting channel pack: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# CAMPAIGN LISTING & RETRIEVAL
|
||||
# ====================
|
||||
|
||||
@router.get("/campaigns", summary="List Campaigns")
|
||||
async def list_campaigns(
|
||||
status: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
campaign_storage: CampaignStorageService = Depends(get_campaign_storage)
|
||||
):
|
||||
"""List all campaigns for the authenticated user."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "list campaigns")
|
||||
campaigns = campaign_storage.list_campaigns(user_id, status=status)
|
||||
|
||||
return {
|
||||
"campaigns": [
|
||||
{
|
||||
"campaign_id": c.campaign_id,
|
||||
"campaign_name": c.campaign_name,
|
||||
"goal": c.goal,
|
||||
"kpi": c.kpi,
|
||||
"status": c.status,
|
||||
"channels": c.channels,
|
||||
"phases": c.phases,
|
||||
"asset_nodes": c.asset_nodes,
|
||||
"created_at": c.created_at.isoformat() if c.created_at else None,
|
||||
"updated_at": c.updated_at.isoformat() if c.updated_at else None,
|
||||
}
|
||||
for c in campaigns
|
||||
],
|
||||
"total": len(campaigns),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error listing campaigns: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/campaigns/{campaign_id}", summary="Get Campaign")
|
||||
async def get_campaign(
|
||||
campaign_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
campaign_storage: CampaignStorageService = Depends(get_campaign_storage)
|
||||
):
|
||||
"""Get a specific campaign by ID."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "get campaign")
|
||||
campaign = campaign_storage.get_campaign(user_id, campaign_id)
|
||||
|
||||
if not campaign:
|
||||
raise HTTPException(status_code=404, detail="Campaign not found")
|
||||
|
||||
return {
|
||||
"campaign_id": campaign.campaign_id,
|
||||
"campaign_name": campaign.campaign_name,
|
||||
"goal": campaign.goal,
|
||||
"kpi": campaign.kpi,
|
||||
"status": campaign.status,
|
||||
"channels": campaign.channels,
|
||||
"phases": campaign.phases,
|
||||
"asset_nodes": campaign.asset_nodes,
|
||||
"product_context": campaign.product_context,
|
||||
"created_at": campaign.created_at.isoformat() if campaign.created_at else None,
|
||||
"updated_at": campaign.updated_at.isoformat() if campaign.updated_at else None,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error getting campaign: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/campaigns/{campaign_id}/proposals", summary="Get Campaign Proposals")
|
||||
async def get_campaign_proposals(
|
||||
campaign_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
campaign_storage: CampaignStorageService = Depends(get_campaign_storage)
|
||||
):
|
||||
"""Get proposals for a campaign."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "get proposals")
|
||||
proposals = campaign_storage.get_proposals(user_id, campaign_id)
|
||||
|
||||
proposals_dict = {}
|
||||
for proposal in proposals:
|
||||
proposals_dict[proposal.asset_node_id] = {
|
||||
"asset_id": proposal.asset_node_id,
|
||||
"asset_type": proposal.asset_type,
|
||||
"channel": proposal.channel,
|
||||
"proposed_prompt": proposal.proposed_prompt,
|
||||
"recommended_template": proposal.recommended_template,
|
||||
"recommended_provider": proposal.recommended_provider,
|
||||
"cost_estimate": proposal.cost_estimate,
|
||||
"concept_summary": proposal.concept_summary,
|
||||
"status": proposal.status,
|
||||
}
|
||||
|
||||
return {
|
||||
"proposals": proposals_dict,
|
||||
"total_assets": len(proposals),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error getting proposals: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# PRODUCT ASSET ENDPOINTS (Product Marketing Suite - Product Assets)
|
||||
# ====================
|
||||
|
||||
class ProductPhotoshootRequest(BaseModel):
|
||||
"""Request for product image photoshoot generation."""
|
||||
product_name: str = Field(..., description="Product name")
|
||||
product_description: str = Field(..., description="Product description")
|
||||
environment: str = Field(default="studio", description="Environment: studio, lifestyle, outdoor, minimalist, luxury")
|
||||
background_style: str = Field(default="white", description="Background: white, transparent, lifestyle, branded")
|
||||
lighting: str = Field(default="natural", description="Lighting: natural, studio, dramatic, soft")
|
||||
product_variant: Optional[str] = Field(None, description="Product variant (color, size, etc.)")
|
||||
angle: Optional[str] = Field(None, description="Product angle: front, side, top, 360")
|
||||
style: str = Field(default="photorealistic", description="Style: photorealistic, minimalist, luxury, technical")
|
||||
resolution: str = Field(default="1024x1024", description="Resolution (e.g., 1024x1024, 1280x720)")
|
||||
num_variations: int = Field(default=1, description="Number of variations to generate")
|
||||
brand_colors: Optional[List[str]] = Field(None, description="Brand color palette")
|
||||
additional_context: Optional[str] = Field(None, description="Additional context for generation")
|
||||
|
||||
|
||||
def get_product_image_service() -> ProductImageService:
|
||||
"""Get Product Image Service instance."""
|
||||
return ProductImageService()
|
||||
|
||||
|
||||
@router.post("/products/photoshoot", summary="Generate Product Image")
|
||||
async def generate_product_image(
|
||||
request: ProductPhotoshootRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
product_image_service: ProductImageService = Depends(get_product_image_service),
|
||||
brand_dna_sync: BrandDNASyncService = Depends(lambda: BrandDNASyncService())
|
||||
):
|
||||
"""Generate professional product images using AI.
|
||||
|
||||
This endpoint:
|
||||
- Generates product images optimized for e-commerce
|
||||
- Supports multiple environments and styles
|
||||
- Integrates with brand DNA for personalization
|
||||
- Automatically saves to Asset Library
|
||||
"""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "product image generation")
|
||||
logger.info(f"[Product Marketing] Generating product image for '{request.product_name}'")
|
||||
|
||||
# Get brand DNA for personalization
|
||||
brand_context = None
|
||||
try:
|
||||
brand_dna = brand_dna_sync.get_brand_dna_tokens(user_id)
|
||||
brand_context = {
|
||||
"visual_identity": brand_dna.get("visual_identity", {}),
|
||||
"persona": brand_dna.get("persona", {}),
|
||||
}
|
||||
except Exception as brand_error:
|
||||
logger.warning(f"[Product Marketing] Could not load brand DNA: {str(brand_error)}")
|
||||
|
||||
# Convert request to service request
|
||||
service_request = ProductImageRequest(
|
||||
product_name=request.product_name,
|
||||
product_description=request.product_description,
|
||||
environment=request.environment,
|
||||
background_style=request.background_style,
|
||||
lighting=request.lighting,
|
||||
product_variant=request.product_variant,
|
||||
angle=request.angle,
|
||||
style=request.style,
|
||||
resolution=request.resolution,
|
||||
num_variations=request.num_variations,
|
||||
brand_colors=request.brand_colors,
|
||||
additional_context=request.additional_context,
|
||||
)
|
||||
|
||||
# Generate product image
|
||||
result = await product_image_service.generate_product_image(
|
||||
request=service_request,
|
||||
user_id=user_id,
|
||||
brand_context=brand_context,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
raise HTTPException(status_code=500, detail=result.error or "Product image generation failed")
|
||||
|
||||
logger.info(f"[Product Marketing] ✅ Generated product image: {result.asset_id}")
|
||||
|
||||
# Return result (image_bytes will be served via separate endpoint)
|
||||
return {
|
||||
"success": True,
|
||||
"product_name": result.product_name,
|
||||
"image_url": result.image_url,
|
||||
"asset_id": result.asset_id,
|
||||
"provider": result.provider,
|
||||
"model": result.model,
|
||||
"cost": result.cost,
|
||||
"generation_time": result.generation_time,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error generating product image: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Product image generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/products/images/{filename}", summary="Serve Product Image")
|
||||
async def serve_product_image(
|
||||
filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Serve generated product images."""
|
||||
try:
|
||||
from fastapi.responses import FileResponse
|
||||
from pathlib import Path
|
||||
|
||||
_require_user_id(current_user, "serving product image")
|
||||
|
||||
# Locate image file
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
image_path = base_dir / "product_images" / filename
|
||||
|
||||
if not image_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
media_type="image/png",
|
||||
filename=filename
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Marketing] ❌ Error serving product image: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ====================
|
||||
# HEALTH CHECK
|
||||
# ====================
|
||||
|
||||
@router.get("/health", summary="Health Check")
|
||||
async def health_check():
|
||||
"""Health check endpoint for Product Marketing Suite."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "product_marketing",
|
||||
"version": "1.0.0",
|
||||
"modules": {
|
||||
"orchestrator": "available",
|
||||
"prompt_builder": "available",
|
||||
"brand_dna_sync": "available",
|
||||
"asset_audit": "available",
|
||||
"channel_pack": "available",
|
||||
"product_image_service": "available",
|
||||
}
|
||||
}
|
||||
|
||||
88
backend/scripts/create_product_asset_tables.py
Normal file
88
backend/scripts/create_product_asset_tables.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Database Migration Script for Product Asset Tables
|
||||
Creates all tables needed for Product Marketing Suite (product asset creation).
|
||||
These tables are separate from campaign-related tables and focus on product-specific assets.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine, text, inspect
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
import traceback
|
||||
|
||||
# Import models - Product Asset models use SubscriptionBase
|
||||
from models.subscription_models import Base as SubscriptionBase
|
||||
from models.product_asset_models import ProductAsset, ProductStyleTemplate, EcommerceExport
|
||||
from services.database import DATABASE_URL
|
||||
|
||||
|
||||
def create_product_asset_tables():
|
||||
"""Create all product asset tables."""
|
||||
|
||||
try:
|
||||
# Create engine
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
|
||||
# Create all tables (product asset models share SubscriptionBase)
|
||||
logger.info("Creating product asset tables for Product Marketing Suite...")
|
||||
SubscriptionBase.metadata.create_all(bind=engine)
|
||||
logger.info("✅ Product asset tables created successfully")
|
||||
|
||||
# Verify tables were created
|
||||
with engine.connect() as conn:
|
||||
# Check if tables exist
|
||||
inspector = inspect(engine)
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
expected_tables = [
|
||||
'product_assets',
|
||||
'product_style_templates',
|
||||
'product_ecommerce_exports'
|
||||
]
|
||||
|
||||
created_tables = [t for t in expected_tables if t in tables]
|
||||
missing_tables = [t for t in expected_tables if t not in tables]
|
||||
|
||||
if created_tables:
|
||||
logger.info(f"✅ Created tables: {', '.join(created_tables)}")
|
||||
|
||||
if missing_tables:
|
||||
logger.warning(f"⚠️ Missing tables: {', '.join(missing_tables)}")
|
||||
else:
|
||||
logger.info("🎉 All product asset tables verified!")
|
||||
|
||||
# Verify indexes were created
|
||||
with engine.connect() as conn:
|
||||
inspector = inspect(engine)
|
||||
|
||||
# Check ProductAsset indexes
|
||||
product_asset_indexes = inspector.get_indexes('product_assets')
|
||||
logger.info(f"✅ ProductAsset indexes: {len(product_asset_indexes)} indexes created")
|
||||
|
||||
# Check ProductStyleTemplate indexes
|
||||
style_template_indexes = inspector.get_indexes('product_style_templates')
|
||||
logger.info(f"✅ ProductStyleTemplate indexes: {len(style_template_indexes)} indexes created")
|
||||
|
||||
# Check EcommerceExport indexes
|
||||
ecommerce_export_indexes = inspector.get_indexes('product_ecommerce_exports')
|
||||
logger.info(f"✅ EcommerceExport indexes: {len(ecommerce_export_indexes)} indexes created")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating product asset tables: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = create_product_asset_tables()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
71
backend/scripts/create_product_marketing_tables.py
Normal file
71
backend/scripts/create_product_marketing_tables.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Database Migration Script for Product Marketing Suite
|
||||
Creates all tables needed for campaigns, proposals, and generated assets.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine, text, inspect
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
import traceback
|
||||
|
||||
# Import models - Product Marketing uses SubscriptionBase
|
||||
# Import the Base first, then import product marketing models to register them
|
||||
from models.subscription_models import Base as SubscriptionBase
|
||||
from models.product_marketing_models import Campaign, CampaignProposal, CampaignAsset
|
||||
from services.database import DATABASE_URL
|
||||
|
||||
def create_product_marketing_tables():
|
||||
"""Create all product marketing tables."""
|
||||
|
||||
try:
|
||||
# Create engine
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
|
||||
# Create all tables (product marketing models share SubscriptionBase)
|
||||
logger.info("Creating product marketing tables...")
|
||||
SubscriptionBase.metadata.create_all(bind=engine)
|
||||
logger.info("✅ Product marketing tables created successfully")
|
||||
|
||||
# Verify tables were created
|
||||
with engine.connect() as conn:
|
||||
# Check if tables exist
|
||||
from sqlalchemy import inspect as sqlalchemy_inspect
|
||||
inspector = sqlalchemy_inspect(engine)
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
expected_tables = [
|
||||
'product_marketing_campaigns',
|
||||
'product_marketing_proposals',
|
||||
'product_marketing_assets'
|
||||
]
|
||||
|
||||
created_tables = [t for t in expected_tables if t in tables]
|
||||
missing_tables = [t for t in expected_tables if t not in tables]
|
||||
|
||||
if created_tables:
|
||||
logger.info(f"✅ Created tables: {', '.join(created_tables)}")
|
||||
|
||||
if missing_tables:
|
||||
logger.warning(f"⚠️ Missing tables: {', '.join(missing_tables)}")
|
||||
else:
|
||||
logger.info("🎉 All product marketing tables verified!")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating product marketing tables: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = create_product_marketing_tables()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
@@ -38,7 +38,7 @@ class ContentAssetService:
|
||||
description: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
asset_metadata: Optional[Dict[str, Any]] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
cost: Optional[float] = None,
|
||||
@@ -60,7 +60,7 @@ class ContentAssetService:
|
||||
description: Asset description (optional)
|
||||
prompt: Generation prompt (optional)
|
||||
tags: List of tags (optional)
|
||||
metadata: Additional metadata (optional)
|
||||
asset_metadata: Additional metadata (optional)
|
||||
provider: AI provider used (optional)
|
||||
model: Model used (optional)
|
||||
cost: Generation cost (optional)
|
||||
@@ -83,7 +83,7 @@ class ContentAssetService:
|
||||
description=description,
|
||||
prompt=prompt,
|
||||
tags=tags or [],
|
||||
metadata=metadata or {},
|
||||
asset_metadata=asset_metadata or {},
|
||||
provider=provider,
|
||||
model=model,
|
||||
cost=cost or 0.0,
|
||||
@@ -222,7 +222,7 @@ class ContentAssetService:
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
asset_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[ContentAsset]:
|
||||
"""Update asset metadata."""
|
||||
try:
|
||||
@@ -236,8 +236,8 @@ class ContentAssetService:
|
||||
asset.description = description
|
||||
if tags is not None:
|
||||
asset.tags = tags
|
||||
if metadata is not None:
|
||||
asset.metadata = {**(asset.metadata or {}), **metadata}
|
||||
if asset_metadata is not None:
|
||||
asset.asset_metadata = {**(asset.asset_metadata or {}), **asset_metadata}
|
||||
|
||||
asset.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
|
||||
@@ -21,6 +21,10 @@ from models.persona_models import Base as PersonaBase
|
||||
from models.subscription_models import Base as SubscriptionBase
|
||||
from models.user_business_info import Base as UserBusinessInfoBase
|
||||
from models.content_asset_models import Base as ContentAssetBase
|
||||
# Product Marketing models use SubscriptionBase, but import to ensure models are registered
|
||||
from models.product_marketing_models import Campaign, CampaignProposal, CampaignAsset
|
||||
# Product Asset models (Product Marketing Suite - product assets, not campaigns)
|
||||
from models.product_asset_models import ProductAsset, ProductStyleTemplate, EcommerceExport
|
||||
|
||||
# Database configuration
|
||||
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./alwrity.db')
|
||||
@@ -73,10 +77,10 @@ def init_database():
|
||||
EnhancedStrategyBase.metadata.create_all(bind=engine)
|
||||
MonitoringBase.metadata.create_all(bind=engine)
|
||||
PersonaBase.metadata.create_all(bind=engine)
|
||||
SubscriptionBase.metadata.create_all(bind=engine)
|
||||
SubscriptionBase.metadata.create_all(bind=engine) # Includes product_marketing models
|
||||
UserBusinessInfoBase.metadata.create_all(bind=engine)
|
||||
ContentAssetBase.metadata.create_all(bind=engine)
|
||||
logger.info("Database initialized successfully with all models including subscription system, business info, and content assets")
|
||||
logger.info("Database initialized successfully with all models including subscription system, product marketing, business info, and content assets")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error initializing database: {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -6,6 +6,11 @@ from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .transform_service import (
|
||||
TransformStudioService,
|
||||
TransformImageToVideoRequest,
|
||||
TalkingAvatarRequest,
|
||||
)
|
||||
from .templates import PlatformTemplates, TemplateManager
|
||||
|
||||
__all__ = [
|
||||
@@ -20,6 +25,9 @@ __all__ = [
|
||||
"ControlStudioRequest",
|
||||
"SocialOptimizerService",
|
||||
"SocialOptimizerRequest",
|
||||
"TransformStudioService",
|
||||
"TransformImageToVideoRequest",
|
||||
"TalkingAvatarRequest",
|
||||
"PlatformTemplates",
|
||||
"TemplateManager",
|
||||
]
|
||||
|
||||
155
backend/services/image_studio/infinitetalk_adapter.py
Normal file
155
backend/services/image_studio/infinitetalk_adapter.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""InfiniteTalk adapter for Transform Studio."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_studio.infinitetalk")
|
||||
|
||||
|
||||
class InfiniteTalkService:
|
||||
"""Adapter for InfiniteTalk in Transform Studio context."""
|
||||
|
||||
def __init__(self, client: Optional[WaveSpeedClient] = None):
|
||||
"""Initialize InfiniteTalk service adapter."""
|
||||
self.client = client or WaveSpeedClient()
|
||||
logger.info("[InfiniteTalk Adapter] Service initialized")
|
||||
|
||||
def calculate_cost(self, resolution: str, duration: float) -> float:
|
||||
"""Calculate cost for InfiniteTalk video.
|
||||
|
||||
Args:
|
||||
resolution: Output resolution (480p or 720p)
|
||||
duration: Video duration in seconds
|
||||
|
||||
Returns:
|
||||
Cost in USD
|
||||
"""
|
||||
# InfiniteTalk pricing: $0.03/s (480p) or $0.06/s (720p)
|
||||
# Minimum charge: 5 seconds
|
||||
cost_per_second = 0.03 if resolution == "480p" else 0.06
|
||||
actual_duration = max(5.0, duration) # Minimum 5 seconds
|
||||
return cost_per_second * actual_duration
|
||||
|
||||
async def create_talking_avatar(
|
||||
self,
|
||||
image_base64: str,
|
||||
audio_base64: str,
|
||||
resolution: str = "720p",
|
||||
prompt: Optional[str] = None,
|
||||
mask_image_base64: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
user_id: str = "transform_studio",
|
||||
) -> Dict[str, Any]:
|
||||
"""Create talking avatar video using InfiniteTalk.
|
||||
|
||||
Args:
|
||||
image_base64: Person image in base64 or data URI
|
||||
audio_base64: Audio file in base64 or data URI
|
||||
resolution: Output resolution (480p or 720p)
|
||||
prompt: Optional prompt for expression/style
|
||||
mask_image_base64: Optional mask for animatable regions
|
||||
seed: Optional random seed
|
||||
user_id: User ID for tracking
|
||||
|
||||
Returns:
|
||||
Dictionary with video bytes, metadata, and cost
|
||||
"""
|
||||
# Validate resolution
|
||||
if resolution not in ["480p", "720p"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Resolution must be '480p' or '720p' for InfiniteTalk"
|
||||
)
|
||||
|
||||
# Decode image
|
||||
import base64
|
||||
try:
|
||||
if image_base64.startswith("data:"):
|
||||
if "," not in image_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = image_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
|
||||
image_mime = mime_parts.strip() or "image/png"
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_mime = "image/png"
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode image: {str(e)}"
|
||||
)
|
||||
|
||||
# Decode audio
|
||||
try:
|
||||
if audio_base64.startswith("data:"):
|
||||
if "," not in audio_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = audio_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "audio/mpeg"
|
||||
audio_mime = mime_parts.strip() or "audio/mpeg"
|
||||
audio_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
audio_mime = "audio/mpeg"
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode audio: {str(e)}"
|
||||
)
|
||||
|
||||
# Call existing InfiniteTalk function (run in thread since it's synchronous)
|
||||
# Note: We pass empty dicts for scene_data and story_context since
|
||||
# Transform Studio doesn't have story context
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
animate_scene_with_voiceover,
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
scene_data={}, # Empty for Transform Studio
|
||||
story_context={}, # Empty for Transform Studio
|
||||
user_id=user_id,
|
||||
resolution=resolution,
|
||||
prompt_override=prompt,
|
||||
image_mime=image_mime,
|
||||
audio_mime=audio_mime,
|
||||
client=self.client,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[InfiniteTalk Adapter] Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"InfiniteTalk generation failed: {str(e)}"
|
||||
)
|
||||
|
||||
# Calculate actual cost based on duration
|
||||
actual_cost = self.calculate_cost(resolution, result.get("duration", 5.0))
|
||||
|
||||
# Update result with actual cost and additional metadata
|
||||
result["cost"] = actual_cost
|
||||
result["resolution"] = resolution
|
||||
|
||||
# Get video dimensions from resolution
|
||||
resolution_dims = {
|
||||
"480p": (854, 480),
|
||||
"720p": (1280, 720),
|
||||
}
|
||||
width, height = resolution_dims.get(resolution, (1280, 720))
|
||||
result["width"] = width
|
||||
result["height"] = height
|
||||
|
||||
logger.info(
|
||||
f"[InfiniteTalk Adapter] ✅ Generated talking avatar: "
|
||||
f"resolution={resolution}, duration={result.get('duration', 5.0)}s, cost=${actual_cost:.2f}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -7,6 +7,11 @@ from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .transform_service import (
|
||||
TransformStudioService,
|
||||
TransformImageToVideoRequest,
|
||||
TalkingAvatarRequest,
|
||||
)
|
||||
from .templates import Platform, TemplateCategory, ImageTemplate
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
@@ -24,6 +29,7 @@ class ImageStudioManager:
|
||||
self.upscale_service = UpscaleStudioService()
|
||||
self.control_service = ControlStudioService()
|
||||
self.social_optimizer_service = SocialOptimizerService()
|
||||
self.transform_service = TransformStudioService()
|
||||
logger.info("[Image Studio Manager] Initialized successfully")
|
||||
|
||||
# ====================
|
||||
@@ -339,4 +345,35 @@ class ImageStudioManager:
|
||||
}
|
||||
|
||||
return specs.get(platform, {})
|
||||
|
||||
# ====================
|
||||
# TRANSFORM STUDIO
|
||||
# ====================
|
||||
|
||||
async def transform_image_to_video(
|
||||
self,
|
||||
request: TransformImageToVideoRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Transform image to video using WAN 2.5."""
|
||||
logger.info("[Image Studio] Transform image-to-video request from user: %s", user_id)
|
||||
return await self.transform_service.transform_image_to_video(request, user_id=user_id or "anonymous")
|
||||
|
||||
async def create_talking_avatar(
|
||||
self,
|
||||
request: TalkingAvatarRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create talking avatar using InfiniteTalk."""
|
||||
logger.info("[Image Studio] Talking avatar request from user: %s", user_id)
|
||||
return await self.transform_service.create_talking_avatar(request, user_id=user_id or "anonymous")
|
||||
|
||||
def estimate_transform_cost(
|
||||
self,
|
||||
operation: str,
|
||||
resolution: str,
|
||||
duration: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate cost for transform operation."""
|
||||
return self.transform_service.estimate_cost(operation, resolution, duration)
|
||||
|
||||
|
||||
379
backend/services/image_studio/transform_service.py
Normal file
379
backend/services/image_studio/transform_service.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""Transform Studio service for image-to-video and talking avatar generation."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from .wan25_service import WAN25Service
|
||||
from .infinitetalk_adapter import InfiniteTalkService
|
||||
from services.llm_providers.main_video_generation import track_video_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
from utils.file_storage import save_file_safely, sanitize_filename
|
||||
|
||||
logger = get_service_logger("image_studio.transform")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformImageToVideoRequest:
|
||||
"""Request for WAN 2.5 image-to-video."""
|
||||
image_base64: str
|
||||
prompt: str
|
||||
audio_base64: Optional[str] = None
|
||||
resolution: str = "720p" # 480p, 720p, 1080p
|
||||
duration: int = 5 # 5 or 10 seconds
|
||||
negative_prompt: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
enable_prompt_expansion: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class TalkingAvatarRequest:
|
||||
"""Request for InfiniteTalk talking avatar."""
|
||||
image_base64: str
|
||||
audio_base64: str
|
||||
resolution: str = "720p" # 480p or 720p
|
||||
prompt: Optional[str] = None
|
||||
mask_image_base64: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
class TransformStudioService:
|
||||
"""Service for Transform Studio operations."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Transform Studio service."""
|
||||
self.wan25_service = WAN25Service()
|
||||
self.infinitetalk_service = InfiniteTalkService()
|
||||
|
||||
# Video output directory
|
||||
# __file__ is: backend/services/image_studio/transform_service.py
|
||||
# We need: backend/transform_videos
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
self.output_dir = base_dir / "transform_videos"
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Verify directory was created
|
||||
if not self.output_dir.exists():
|
||||
raise RuntimeError(f"Failed to create transform_videos directory: {self.output_dir}")
|
||||
|
||||
logger.info(f"[Transform Studio] Initialized with output directory: {self.output_dir}")
|
||||
|
||||
def _save_video_file(
|
||||
self,
|
||||
video_bytes: bytes,
|
||||
operation_type: str,
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Save video file to disk.
|
||||
|
||||
Args:
|
||||
video_bytes: Video content as bytes
|
||||
operation_type: Type of operation (e.g., "image-to-video", "talking-avatar")
|
||||
user_id: User ID for directory organization
|
||||
|
||||
Returns:
|
||||
Dictionary with filename, file_path, and file_url
|
||||
"""
|
||||
# Create user-specific directory
|
||||
user_dir = self.output_dir / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate filename
|
||||
filename = f"{operation_type}_{uuid.uuid4().hex[:8]}.mp4"
|
||||
filename = sanitize_filename(filename)
|
||||
|
||||
# Save file
|
||||
file_path, error = save_file_safely(
|
||||
content=video_bytes,
|
||||
directory=user_dir,
|
||||
filename=filename,
|
||||
max_file_size=500 * 1024 * 1024 # 500MB max for videos
|
||||
)
|
||||
|
||||
if error:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to save video file: {error}"
|
||||
)
|
||||
|
||||
file_url = f"/api/image-studio/videos/{user_id}/{filename}"
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"file_path": str(file_path),
|
||||
"file_url": file_url,
|
||||
"file_size": len(video_bytes),
|
||||
}
|
||||
|
||||
async def transform_image_to_video(
|
||||
self,
|
||||
request: TransformImageToVideoRequest,
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Transform image to video using WAN 2.5.
|
||||
|
||||
Args:
|
||||
request: Transform request
|
||||
user_id: User ID for tracking and file organization
|
||||
|
||||
Returns:
|
||||
Dictionary with video URL, metadata, and cost
|
||||
"""
|
||||
logger.info(
|
||||
f"[Transform Studio] Image-to-video request from user {user_id}: "
|
||||
f"resolution={request.resolution}, duration={request.duration}s"
|
||||
)
|
||||
|
||||
# Generate video using WAN 2.5
|
||||
result = await self.wan25_service.generate_video(
|
||||
image_base64=request.image_base64,
|
||||
prompt=request.prompt,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
enable_prompt_expansion=request.enable_prompt_expansion,
|
||||
)
|
||||
|
||||
# Save video to disk
|
||||
save_result = self._save_video_file(
|
||||
video_bytes=result["video_bytes"],
|
||||
operation_type="image-to-video",
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Track usage
|
||||
try:
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=result["provider"],
|
||||
model_name=result["model_name"],
|
||||
prompt=result["prompt"],
|
||||
video_bytes=result["video_bytes"],
|
||||
cost_override=result["cost"],
|
||||
)
|
||||
logger.info(
|
||||
f"[Transform Studio] Usage tracked: {usage_info.get('current_calls', 0)} / "
|
||||
f"{usage_info.get('video_limit_display', '∞')} videos, "
|
||||
f"cost=${result['cost']:.2f}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to track usage: {e}")
|
||||
|
||||
# Save to asset library
|
||||
try:
|
||||
from services.database import get_db
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="image_studio",
|
||||
filename=save_result["filename"],
|
||||
file_url=save_result["file_url"],
|
||||
file_path=save_result["file_path"],
|
||||
file_size=save_result["file_size"],
|
||||
mime_type="video/mp4",
|
||||
title=f"Transform: Image-to-Video ({request.resolution})",
|
||||
description=f"Generated video using WAN 2.5: {request.prompt[:100]}",
|
||||
prompt=result["prompt"],
|
||||
tags=["image_studio", "transform", "video", "image-to-video", request.resolution],
|
||||
provider=result["provider"],
|
||||
model=result["model_name"],
|
||||
cost=result["cost"],
|
||||
asset_metadata={
|
||||
"resolution": request.resolution,
|
||||
"duration": result["duration"],
|
||||
"operation": "image-to-video",
|
||||
"width": result["width"],
|
||||
"height": result["height"],
|
||||
}
|
||||
)
|
||||
logger.info(f"[Transform Studio] Video saved to asset library")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to save to asset library: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": save_result["file_url"],
|
||||
"video_base64": None, # Don't include base64 for large videos
|
||||
"duration": result["duration"],
|
||||
"resolution": result["resolution"],
|
||||
"width": result["width"],
|
||||
"height": result["height"],
|
||||
"file_size": save_result["file_size"],
|
||||
"cost": result["cost"],
|
||||
"provider": result["provider"],
|
||||
"model": result["model_name"],
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
|
||||
async def create_talking_avatar(
|
||||
self,
|
||||
request: TalkingAvatarRequest,
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create talking avatar using InfiniteTalk.
|
||||
|
||||
Args:
|
||||
request: Talking avatar request
|
||||
user_id: User ID for tracking and file organization
|
||||
|
||||
Returns:
|
||||
Dictionary with video URL, metadata, and cost
|
||||
"""
|
||||
logger.info(
|
||||
f"[Transform Studio] Talking avatar request from user {user_id}: "
|
||||
f"resolution={request.resolution}"
|
||||
)
|
||||
|
||||
# Generate video using InfiniteTalk
|
||||
result = await self.infinitetalk_service.create_talking_avatar(
|
||||
image_base64=request.image_base64,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
prompt=request.prompt,
|
||||
mask_image_base64=request.mask_image_base64,
|
||||
seed=request.seed,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Save video to disk
|
||||
save_result = self._save_video_file(
|
||||
video_bytes=result["video_bytes"],
|
||||
operation_type="talking-avatar",
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Track usage
|
||||
try:
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=result["provider"],
|
||||
model_name=result["model_name"],
|
||||
prompt=result.get("prompt", ""),
|
||||
video_bytes=result["video_bytes"],
|
||||
cost_override=result["cost"],
|
||||
)
|
||||
logger.info(
|
||||
f"[Transform Studio] Usage tracked: {usage_info.get('current_calls', 0)} / "
|
||||
f"{usage_info.get('video_limit_display', '∞')} videos, "
|
||||
f"cost=${result['cost']:.2f}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to track usage: {e}")
|
||||
|
||||
# Save to asset library
|
||||
try:
|
||||
from services.database import get_db
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="image_studio",
|
||||
filename=save_result["filename"],
|
||||
file_url=save_result["file_url"],
|
||||
file_path=save_result["file_path"],
|
||||
file_size=save_result["file_size"],
|
||||
mime_type="video/mp4",
|
||||
title=f"Transform: Talking Avatar ({request.resolution})",
|
||||
description="Generated talking avatar video using InfiniteTalk",
|
||||
prompt=result.get("prompt", ""),
|
||||
tags=["image_studio", "transform", "video", "talking-avatar", request.resolution],
|
||||
provider=result["provider"],
|
||||
model=result["model_name"],
|
||||
cost=result["cost"],
|
||||
asset_metadata={
|
||||
"resolution": request.resolution,
|
||||
"duration": result.get("duration", 5.0),
|
||||
"operation": "talking-avatar",
|
||||
"width": result.get("width", 1280),
|
||||
"height": result.get("height", 720),
|
||||
}
|
||||
)
|
||||
logger.info(f"[Transform Studio] Video saved to asset library")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to save to asset library: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": save_result["file_url"],
|
||||
"video_base64": None, # Don't include base64 for large videos
|
||||
"duration": result.get("duration", 5.0),
|
||||
"resolution": result.get("resolution", request.resolution),
|
||||
"width": result.get("width", 1280),
|
||||
"height": result.get("height", 720),
|
||||
"file_size": save_result["file_size"],
|
||||
"cost": result["cost"],
|
||||
"provider": result["provider"],
|
||||
"model": result["model_name"],
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
|
||||
def estimate_cost(
|
||||
self,
|
||||
operation: str,
|
||||
resolution: str,
|
||||
duration: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate cost for transform operation.
|
||||
|
||||
Args:
|
||||
operation: Operation type ("image-to-video" or "talking-avatar")
|
||||
resolution: Output resolution
|
||||
duration: Video duration in seconds (for image-to-video)
|
||||
|
||||
Returns:
|
||||
Cost estimation details
|
||||
"""
|
||||
if operation == "image-to-video":
|
||||
if duration is None:
|
||||
duration = 5
|
||||
cost = self.wan25_service.calculate_cost(resolution, duration)
|
||||
return {
|
||||
"estimated_cost": cost,
|
||||
"breakdown": {
|
||||
"base_cost": 0.0,
|
||||
"per_second": self.wan25_service.calculate_cost(resolution, 1),
|
||||
"duration": duration,
|
||||
"total": cost,
|
||||
},
|
||||
"currency": "USD",
|
||||
"provider": "wavespeed",
|
||||
"model": "alibaba/wan-2.5/image-to-video",
|
||||
}
|
||||
elif operation == "talking-avatar":
|
||||
# InfiniteTalk minimum is 5 seconds
|
||||
estimated_duration = duration or 5.0
|
||||
cost = self.infinitetalk_service.calculate_cost(resolution, estimated_duration)
|
||||
return {
|
||||
"estimated_cost": cost,
|
||||
"breakdown": {
|
||||
"base_cost": 0.0,
|
||||
"per_second": self.infinitetalk_service.calculate_cost(resolution, 1.0),
|
||||
"duration": estimated_duration,
|
||||
"total": cost,
|
||||
},
|
||||
"currency": "USD",
|
||||
"provider": "wavespeed",
|
||||
"model": "wavespeed-ai/infinitetalk",
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown operation: {operation}")
|
||||
|
||||
295
backend/services/image_studio/wan25_service.py
Normal file
295
backend/services/image_studio/wan25_service.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""WAN 2.5 service for Alibaba image-to-video generation via WaveSpeed."""
|
||||
|
||||
import base64
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_studio.wan25")
|
||||
|
||||
WAN25_MODEL_PATH = "alibaba/wan-2.5/image-to-video"
|
||||
WAN25_MODEL_NAME = "alibaba/wan-2.5/image-to-video"
|
||||
|
||||
# Pricing per second (from WaveSpeed docs)
|
||||
PRICING = {
|
||||
"480p": 0.05, # $0.05 per second
|
||||
"720p": 0.10, # $0.10 per second
|
||||
"1080p": 0.15, # $0.15 per second
|
||||
}
|
||||
|
||||
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10MB (recommended)
|
||||
MAX_AUDIO_BYTES = 15 * 1024 * 1024 # 15MB (API limit)
|
||||
MIN_AUDIO_DURATION = 3 # seconds
|
||||
MAX_AUDIO_DURATION = 30 # seconds
|
||||
|
||||
|
||||
def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
|
||||
"""Convert bytes to data URI."""
|
||||
encoded = base64.b64encode(content_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{encoded}"
|
||||
|
||||
|
||||
def _decode_base64_image(image_base64: str) -> tuple[bytes, str]:
|
||||
"""Decode base64 image, handling data URIs."""
|
||||
if image_base64.startswith("data:"):
|
||||
# Extract mime type and base64 data
|
||||
if "," not in image_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = image_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
|
||||
mime_type = mime_parts.strip()
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
# Assume it's raw base64
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
mime_type = "image/png" # Default
|
||||
|
||||
return image_bytes, mime_type
|
||||
|
||||
|
||||
def _decode_base64_audio(audio_base64: str) -> tuple[bytes, str]:
|
||||
"""Decode base64 audio, handling data URIs."""
|
||||
if audio_base64.startswith("data:"):
|
||||
if "," not in audio_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = audio_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "audio/mpeg"
|
||||
mime_type = mime_parts.strip()
|
||||
if not mime_type:
|
||||
mime_type = "audio/mpeg"
|
||||
audio_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
mime_type = "audio/mpeg" # Default
|
||||
|
||||
return audio_bytes, mime_type
|
||||
|
||||
|
||||
class WAN25Service:
|
||||
"""Service for Alibaba WAN 2.5 image-to-video generation."""
|
||||
|
||||
def __init__(self, client: Optional[WaveSpeedClient] = None):
|
||||
"""Initialize WAN 2.5 service."""
|
||||
self.client = client or WaveSpeedClient()
|
||||
logger.info("[WAN 2.5] Service initialized")
|
||||
|
||||
def calculate_cost(self, resolution: str, duration: int) -> float:
|
||||
"""Calculate cost for video generation.
|
||||
|
||||
Args:
|
||||
resolution: Output resolution (480p, 720p, 1080p)
|
||||
duration: Video duration in seconds (5 or 10)
|
||||
|
||||
Returns:
|
||||
Cost in USD
|
||||
"""
|
||||
cost_per_second = PRICING.get(resolution, PRICING["720p"])
|
||||
return cost_per_second * duration
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
audio_base64: Optional[str] = None,
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_prompt_expansion: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate video using WAN 2.5.
|
||||
|
||||
Args:
|
||||
image_base64: Image in base64 or data URI format
|
||||
prompt: Text prompt describing the video
|
||||
audio_base64: Optional audio file (wav/mp3, 3-30s, ≤15MB)
|
||||
resolution: Output resolution (480p, 720p, 1080p)
|
||||
duration: Video duration in seconds (5 or 10)
|
||||
negative_prompt: Optional negative prompt
|
||||
seed: Optional random seed for reproducibility
|
||||
enable_prompt_expansion: Enable prompt optimizer
|
||||
|
||||
Returns:
|
||||
Dictionary with video bytes, metadata, and cost
|
||||
"""
|
||||
# Validate resolution
|
||||
if resolution not in PRICING:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid resolution: {resolution}. Must be one of: {list(PRICING.keys())}"
|
||||
)
|
||||
|
||||
# Validate duration
|
||||
if duration not in [5, 10]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid duration: {duration}. Must be 5 or 10 seconds"
|
||||
)
|
||||
|
||||
# Validate prompt
|
||||
if not prompt or not prompt.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Prompt is required and cannot be empty"
|
||||
)
|
||||
|
||||
# Decode image
|
||||
try:
|
||||
image_bytes, image_mime = _decode_base64_image(image_base64)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode image: {str(e)}"
|
||||
)
|
||||
|
||||
# Validate image size
|
||||
if len(image_bytes) > MAX_IMAGE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image exceeds {MAX_IMAGE_BYTES / (1024*1024):.0f}MB limit"
|
||||
)
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"image": _as_data_uri(image_bytes, image_mime),
|
||||
"prompt": prompt,
|
||||
"resolution": resolution,
|
||||
"duration": duration,
|
||||
"enable_prompt_expansion": enable_prompt_expansion,
|
||||
}
|
||||
|
||||
# Add optional audio
|
||||
if audio_base64:
|
||||
try:
|
||||
audio_bytes, audio_mime = _decode_base64_audio(audio_base64)
|
||||
|
||||
# Validate audio size
|
||||
if len(audio_bytes) > MAX_AUDIO_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Audio exceeds {MAX_AUDIO_BYTES / (1024*1024):.0f}MB limit"
|
||||
)
|
||||
|
||||
# Note: Audio duration validation would require audio analysis
|
||||
# For now, we rely on API to handle it (API keeps first 5s/10s if longer)
|
||||
|
||||
payload["audio"] = _as_data_uri(audio_bytes, audio_mime)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode audio: {str(e)}"
|
||||
)
|
||||
|
||||
# Add optional parameters
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
# Submit to WaveSpeed
|
||||
logger.info(
|
||||
f"[WAN 2.5] Submitting video generation request: resolution={resolution}, duration={duration}s"
|
||||
)
|
||||
|
||||
try:
|
||||
prediction_id = self.client.submit_image_to_video(
|
||||
WAN25_MODEL_PATH,
|
||||
payload,
|
||||
timeout=60
|
||||
)
|
||||
except HTTPException as e:
|
||||
logger.error(f"[WAN 2.5] Submission failed: {e.detail}")
|
||||
raise
|
||||
|
||||
# Poll for completion
|
||||
logger.info(f"[WAN 2.5] Polling for completion: prediction_id={prediction_id}")
|
||||
|
||||
try:
|
||||
# WAN 2.5 typically takes 1-2 minutes
|
||||
result = self.client.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=180, # 3 minutes max
|
||||
interval_seconds=2.0
|
||||
)
|
||||
except HTTPException as e:
|
||||
detail = e.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
raise HTTPException(status_code=e.status_code, detail=detail)
|
||||
|
||||
# Extract video URL
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WAN 2.5 completed but returned no outputs"
|
||||
)
|
||||
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Invalid video URL format: {video_url}"
|
||||
)
|
||||
|
||||
# Download video (run synchronous request in thread)
|
||||
logger.info(f"[WAN 2.5] Downloading video from: {video_url}")
|
||||
video_response = await asyncio.to_thread(
|
||||
requests.get,
|
||||
video_url,
|
||||
timeout=180
|
||||
)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Failed to download WAN 2.5 video",
|
||||
"status_code": video_response.status_code,
|
||||
"response": video_response.text[:200],
|
||||
}
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
metadata = result.get("metadata") or {}
|
||||
|
||||
# Calculate cost
|
||||
cost = self.calculate_cost(resolution, duration)
|
||||
|
||||
# Get video dimensions from resolution
|
||||
resolution_dims = {
|
||||
"480p": (854, 480),
|
||||
"720p": (1280, 720),
|
||||
"1080p": (1920, 1080),
|
||||
}
|
||||
width, height = resolution_dims.get(resolution, (1280, 720))
|
||||
|
||||
logger.info(
|
||||
f"[WAN 2.5] ✅ Generated video: {len(video_bytes)} bytes, "
|
||||
f"resolution={resolution}, duration={duration}s, cost=${cost:.2f}"
|
||||
)
|
||||
|
||||
return {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": float(duration),
|
||||
"model_name": WAN25_MODEL_NAME,
|
||||
"cost": cost,
|
||||
"provider": "wavespeed",
|
||||
"source_video_url": video_url,
|
||||
"prediction_id": prediction_id,
|
||||
"resolution": resolution,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
20
backend/services/product_marketing/__init__.py
Normal file
20
backend/services/product_marketing/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Product Marketing Suite service package."""
|
||||
|
||||
from .orchestrator import ProductMarketingOrchestrator
|
||||
from .brand_dna_sync import BrandDNASyncService
|
||||
from .prompt_builder import ProductMarketingPromptBuilder
|
||||
from .asset_audit import AssetAuditService
|
||||
from .channel_pack import ChannelPackService
|
||||
from .campaign_storage import CampaignStorageService
|
||||
from .product_image_service import ProductImageService
|
||||
|
||||
__all__ = [
|
||||
"ProductMarketingOrchestrator",
|
||||
"BrandDNASyncService",
|
||||
"ProductMarketingPromptBuilder",
|
||||
"AssetAuditService",
|
||||
"ChannelPackService",
|
||||
"CampaignStorageService",
|
||||
"ProductImageService",
|
||||
]
|
||||
|
||||
205
backend/services/product_marketing/asset_audit.py
Normal file
205
backend/services/product_marketing/asset_audit.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Asset Audit Service
|
||||
Analyzes uploaded assets and recommends enhancement operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class AssetAuditService:
|
||||
"""Service to audit assets and recommend enhancements."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Asset Audit Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Asset Audit] Service initialized")
|
||||
|
||||
def audit_asset(
|
||||
self,
|
||||
image_base64: str,
|
||||
asset_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Audit an uploaded asset and recommend enhancement operations.
|
||||
|
||||
Args:
|
||||
image_base64: Base64 encoded image
|
||||
asset_metadata: Optional metadata about the asset
|
||||
|
||||
Returns:
|
||||
Audit results with recommendations
|
||||
"""
|
||||
try:
|
||||
# Decode image
|
||||
image_bytes = self._decode_base64(image_base64)
|
||||
if not image_bytes:
|
||||
raise ValueError("Invalid image data")
|
||||
|
||||
# Analyze image
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
width, height = image.size
|
||||
format_type = image.format or "PNG"
|
||||
mode = image.mode
|
||||
|
||||
# Basic quality checks
|
||||
quality_score = self._assess_quality(image, width, height)
|
||||
|
||||
# Generate recommendations
|
||||
recommendations = []
|
||||
|
||||
# Resolution recommendations
|
||||
if width < 1080 or height < 1080:
|
||||
recommendations.append({
|
||||
"operation": "upscale",
|
||||
"priority": "high",
|
||||
"reason": f"Image resolution ({width}x{height}) is below recommended 1080p for social media",
|
||||
"suggested_mode": "fast" if width < 512 else "conservative",
|
||||
})
|
||||
|
||||
# Background recommendations
|
||||
if mode == "RGBA" and self._has_transparency(image):
|
||||
recommendations.append({
|
||||
"operation": "remove_background",
|
||||
"priority": "low",
|
||||
"reason": "Image already has transparency, background removal may not be needed",
|
||||
})
|
||||
else:
|
||||
recommendations.append({
|
||||
"operation": "remove_background",
|
||||
"priority": "medium",
|
||||
"reason": "Background removal can create versatile product images",
|
||||
})
|
||||
|
||||
# Enhancement recommendations based on quality
|
||||
if quality_score < 0.7:
|
||||
recommendations.append({
|
||||
"operation": "enhance",
|
||||
"priority": "high",
|
||||
"reason": f"Image quality score ({quality_score:.2f}) suggests enhancement needed",
|
||||
"suggested_operations": ["upscale", "general_edit"],
|
||||
})
|
||||
|
||||
# Format recommendations
|
||||
if format_type not in ["PNG", "JPEG"]:
|
||||
recommendations.append({
|
||||
"operation": "convert",
|
||||
"priority": "low",
|
||||
"reason": f"Format {format_type} may not be optimal for web/social media",
|
||||
"suggested_format": "PNG" if mode == "RGBA" else "JPEG",
|
||||
})
|
||||
|
||||
audit_result = {
|
||||
"asset_info": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"format": format_type,
|
||||
"mode": mode,
|
||||
"quality_score": quality_score,
|
||||
},
|
||||
"recommendations": recommendations,
|
||||
"status": "usable" if quality_score > 0.6 else "needs_enhancement",
|
||||
}
|
||||
|
||||
logger.info(f"[Asset Audit] Audited asset: {width}x{height}, quality: {quality_score:.2f}")
|
||||
return audit_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Asset Audit] Error auditing asset: {str(e)}")
|
||||
return {
|
||||
"asset_info": {},
|
||||
"recommendations": [],
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def _decode_base64(self, image_base64: str) -> Optional[bytes]:
|
||||
"""Decode base64 image data."""
|
||||
try:
|
||||
if image_base64.startswith("data:"):
|
||||
_, b64data = image_base64.split(",", 1)
|
||||
else:
|
||||
b64data = image_base64
|
||||
return base64.b64decode(b64data)
|
||||
except Exception as e:
|
||||
logger.error(f"[Asset Audit] Error decoding base64: {str(e)}")
|
||||
return None
|
||||
|
||||
def _has_transparency(self, image: Image.Image) -> bool:
|
||||
"""Check if image has transparency."""
|
||||
if image.mode in ("RGBA", "LA"):
|
||||
alpha = image.split()[-1]
|
||||
return any(pixel < 255 for pixel in alpha.getdata())
|
||||
return False
|
||||
|
||||
def _assess_quality(self, image: Image.Image, width: int, height: int) -> float:
|
||||
"""
|
||||
Assess image quality score (0.0 to 1.0).
|
||||
|
||||
Simple heuristic based on resolution and format.
|
||||
"""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Resolution scoring
|
||||
min_dimension = min(width, height)
|
||||
if min_dimension >= 1080:
|
||||
score += 0.3
|
||||
elif min_dimension >= 512:
|
||||
score += 0.2
|
||||
elif min_dimension >= 256:
|
||||
score += 0.1
|
||||
|
||||
# Format scoring
|
||||
if image.format in ["PNG", "JPEG"]:
|
||||
score += 0.1
|
||||
|
||||
# Mode scoring
|
||||
if image.mode in ["RGB", "RGBA"]:
|
||||
score += 0.1
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def batch_audit_assets(
|
||||
self,
|
||||
assets: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Audit multiple assets in batch.
|
||||
|
||||
Args:
|
||||
assets: List of asset dictionaries with 'image_base64' and optional 'metadata'
|
||||
|
||||
Returns:
|
||||
Batch audit results
|
||||
"""
|
||||
results = []
|
||||
for asset in assets:
|
||||
audit_result = self.audit_asset(
|
||||
asset.get('image_base64'),
|
||||
asset.get('metadata')
|
||||
)
|
||||
results.append({
|
||||
"asset_id": asset.get('id'),
|
||||
"audit": audit_result,
|
||||
})
|
||||
|
||||
# Summary statistics
|
||||
total_assets = len(results)
|
||||
usable_count = sum(1 for r in results if r["audit"]["status"] == "usable")
|
||||
needs_enhancement_count = sum(
|
||||
1 for r in results if r["audit"]["status"] == "needs_enhancement"
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"summary": {
|
||||
"total_assets": total_assets,
|
||||
"usable": usable_count,
|
||||
"needs_enhancement": needs_enhancement_count,
|
||||
"error": total_assets - usable_count - needs_enhancement_count,
|
||||
},
|
||||
}
|
||||
|
||||
176
backend/services/product_marketing/brand_dna_sync.py
Normal file
176
backend/services/product_marketing/brand_dna_sync.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Brand DNA Sync Service
|
||||
Normalizes persona data and onboarding information into reusable brand tokens.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.onboarding import OnboardingDatabaseService
|
||||
from services.database import SessionLocal
|
||||
|
||||
|
||||
class BrandDNASyncService:
|
||||
"""Service to sync and normalize brand DNA from onboarding and persona data."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Brand DNA Sync Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Brand DNA Sync] Service initialized")
|
||||
|
||||
def get_brand_dna_tokens(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract and normalize brand DNA tokens from onboarding and persona data.
|
||||
|
||||
Args:
|
||||
user_id: User ID to fetch data for
|
||||
|
||||
Returns:
|
||||
Dictionary of brand DNA tokens ready for prompt injection
|
||||
"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
brand_tokens = {
|
||||
"writing_style": {},
|
||||
"target_audience": {},
|
||||
"visual_identity": {},
|
||||
"persona": {},
|
||||
"competitive_positioning": {},
|
||||
}
|
||||
|
||||
# Extract writing style from website analysis
|
||||
if website_analysis:
|
||||
writing_style = website_analysis.get('writing_style') or {}
|
||||
target_audience = website_analysis.get('target_audience') or {}
|
||||
brand_analysis = website_analysis.get('brand_analysis') or {}
|
||||
style_guidelines = website_analysis.get('style_guidelines') or {}
|
||||
|
||||
# Ensure writing_style is a dict before accessing
|
||||
if isinstance(writing_style, dict):
|
||||
brand_tokens["writing_style"] = {
|
||||
"tone": writing_style.get('tone', 'professional'),
|
||||
"voice": writing_style.get('voice', 'authoritative'),
|
||||
"complexity": writing_style.get('complexity', 'intermediate'),
|
||||
"engagement_level": writing_style.get('engagement_level', 'moderate'),
|
||||
}
|
||||
|
||||
# Ensure target_audience is a dict before accessing
|
||||
if isinstance(target_audience, dict):
|
||||
brand_tokens["target_audience"] = {
|
||||
"demographics": target_audience.get('demographics', []),
|
||||
"industry_focus": target_audience.get('industry_focus', 'general'),
|
||||
"expertise_level": target_audience.get('expertise_level', 'intermediate'),
|
||||
}
|
||||
|
||||
# Ensure brand_analysis is a dict before accessing
|
||||
if isinstance(brand_analysis, dict) and brand_analysis:
|
||||
brand_tokens["visual_identity"] = {
|
||||
"color_palette": brand_analysis.get('color_palette', []),
|
||||
"brand_values": brand_analysis.get('brand_values', []),
|
||||
"positioning": brand_analysis.get('positioning', ''),
|
||||
}
|
||||
|
||||
# Add style_guidelines if available and visual_identity exists
|
||||
if style_guidelines and isinstance(style_guidelines, dict):
|
||||
if "visual_identity" not in brand_tokens:
|
||||
brand_tokens["visual_identity"] = {}
|
||||
brand_tokens["visual_identity"]["style_guidelines"] = style_guidelines
|
||||
|
||||
# Extract persona data
|
||||
if persona_data:
|
||||
core_persona = persona_data.get('corePersona') or {}
|
||||
platform_personas = persona_data.get('platformPersonas') or {}
|
||||
|
||||
# Ensure core_persona is a dict before accessing
|
||||
if isinstance(core_persona, dict) and core_persona:
|
||||
brand_tokens["persona"] = {
|
||||
"persona_name": core_persona.get('persona_name', ''),
|
||||
"archetype": core_persona.get('archetype', ''),
|
||||
"core_belief": core_persona.get('core_belief', ''),
|
||||
"linguistic_fingerprint": core_persona.get('linguistic_fingerprint', {}),
|
||||
}
|
||||
|
||||
# Ensure persona dict exists before setting platform_personas
|
||||
if "persona" not in brand_tokens:
|
||||
brand_tokens["persona"] = {}
|
||||
|
||||
# Only set platform_personas if it's a valid dict
|
||||
if isinstance(platform_personas, dict):
|
||||
brand_tokens["persona"]["platform_personas"] = platform_personas
|
||||
|
||||
# Extract competitive positioning
|
||||
if competitor_analyses and isinstance(competitor_analyses, list) and len(competitor_analyses) > 0:
|
||||
# Extract differentiation points
|
||||
brand_tokens["competitive_positioning"] = {
|
||||
"differentiators": [],
|
||||
"unique_value_props": [],
|
||||
}
|
||||
|
||||
for competitor in competitor_analyses[:3]: # Top 3 competitors
|
||||
if not isinstance(competitor, dict):
|
||||
continue
|
||||
|
||||
analysis_data = competitor.get('analysis_data') or {}
|
||||
if isinstance(analysis_data, dict) and analysis_data:
|
||||
competitive_insights = analysis_data.get('competitive_analysis') or {}
|
||||
if isinstance(competitive_insights, dict) and competitive_insights:
|
||||
differentiators = competitive_insights.get('differentiators', [])
|
||||
if isinstance(differentiators, list) and differentiators:
|
||||
brand_tokens["competitive_positioning"]["differentiators"].extend(
|
||||
differentiators[:2]
|
||||
)
|
||||
|
||||
logger.info(f"[Brand DNA Sync] Extracted brand tokens for user {user_id}")
|
||||
return brand_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Brand DNA Sync] Error extracting brand tokens: {str(e)}")
|
||||
return {
|
||||
"writing_style": {"tone": "professional", "voice": "authoritative"},
|
||||
"target_audience": {"demographics": [], "expertise_level": "intermediate"},
|
||||
"visual_identity": {},
|
||||
"persona": {},
|
||||
"competitive_positioning": {},
|
||||
}
|
||||
|
||||
def get_channel_specific_dna(
|
||||
self,
|
||||
user_id: str,
|
||||
channel: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get channel-specific brand DNA adaptations.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
channel: Target channel (instagram, linkedin, tiktok, etc.)
|
||||
|
||||
Returns:
|
||||
Channel-specific brand DNA tokens
|
||||
"""
|
||||
brand_tokens = self.get_brand_dna_tokens(user_id)
|
||||
channel_dna = brand_tokens.copy()
|
||||
|
||||
# Get platform-specific persona if available
|
||||
persona = brand_tokens.get("persona") or {}
|
||||
platform_personas = persona.get("platform_personas") or {}
|
||||
|
||||
if isinstance(platform_personas, dict) and channel in platform_personas:
|
||||
platform_persona = platform_personas[channel]
|
||||
if isinstance(platform_persona, dict):
|
||||
channel_dna["platform_adaptation"] = {
|
||||
"content_format_rules": platform_persona.get('content_format_rules') or {},
|
||||
"engagement_patterns": platform_persona.get('engagement_patterns') or {},
|
||||
"visual_identity": platform_persona.get('visual_identity') or {},
|
||||
}
|
||||
|
||||
return channel_dna
|
||||
|
||||
222
backend/services/product_marketing/campaign_storage.py
Normal file
222
backend/services/product_marketing/campaign_storage.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Campaign Storage Service
|
||||
Handles database persistence for campaigns, proposals, and assets.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
|
||||
from models.product_marketing_models import Campaign, CampaignProposal, CampaignAsset, CampaignStatus
|
||||
from services.database import SessionLocal
|
||||
|
||||
|
||||
class CampaignStorageService:
|
||||
"""Service for storing and retrieving campaigns from database."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Campaign Storage Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Campaign Storage] Service initialized")
|
||||
|
||||
def save_campaign(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_data: Dict[str, Any]
|
||||
) -> Campaign:
|
||||
"""
|
||||
Save campaign blueprint to database.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
campaign_data: Campaign blueprint data
|
||||
|
||||
Returns:
|
||||
Saved Campaign object
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
campaign_id = campaign_data.get('campaign_id')
|
||||
|
||||
# Check if campaign exists
|
||||
existing = db.query(Campaign).filter(
|
||||
Campaign.campaign_id == campaign_id,
|
||||
Campaign.user_id == user_id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Update existing campaign
|
||||
existing.campaign_name = campaign_data.get('campaign_name', existing.campaign_name)
|
||||
existing.goal = campaign_data.get('goal', existing.goal)
|
||||
existing.kpi = campaign_data.get('kpi', existing.kpi)
|
||||
existing.status = campaign_data.get('status', existing.status)
|
||||
existing.phases = campaign_data.get('phases', existing.phases)
|
||||
existing.channels = campaign_data.get('channels', existing.channels)
|
||||
existing.asset_nodes = campaign_data.get('asset_nodes', existing.asset_nodes)
|
||||
existing.product_context = campaign_data.get('product_context', existing.product_context)
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
logger.info(f"[Campaign Storage] Updated campaign {campaign_id}")
|
||||
return existing
|
||||
else:
|
||||
# Create new campaign
|
||||
campaign = Campaign(
|
||||
campaign_id=campaign_id,
|
||||
user_id=user_id,
|
||||
campaign_name=campaign_data.get('campaign_name'),
|
||||
goal=campaign_data.get('goal'),
|
||||
kpi=campaign_data.get('kpi'),
|
||||
status=campaign_data.get('status', 'draft'),
|
||||
phases=campaign_data.get('phases'),
|
||||
channels=campaign_data.get('channels', []),
|
||||
asset_nodes=campaign_data.get('asset_nodes', []),
|
||||
product_context=campaign_data.get('product_context'),
|
||||
)
|
||||
db.add(campaign)
|
||||
db.commit()
|
||||
db.refresh(campaign)
|
||||
logger.info(f"[Campaign Storage] Saved new campaign {campaign_id}")
|
||||
return campaign
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error saving campaign: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_campaign(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str
|
||||
) -> Optional[Campaign]:
|
||||
"""Get campaign by ID."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
campaign = db.query(Campaign).filter(
|
||||
Campaign.campaign_id == campaign_id,
|
||||
Campaign.user_id == user_id
|
||||
).first()
|
||||
return campaign
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Storage] Error getting campaign: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def list_campaigns(
|
||||
self,
|
||||
user_id: str,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[Campaign]:
|
||||
"""List campaigns for user."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
query = db.query(Campaign).filter(Campaign.user_id == user_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(Campaign.status == status)
|
||||
|
||||
campaigns = query.order_by(desc(Campaign.created_at)).limit(limit).all()
|
||||
return campaigns
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Storage] Error listing campaigns: {str(e)}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def save_proposals(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str,
|
||||
proposals: Dict[str, Any]
|
||||
) -> List[CampaignProposal]:
|
||||
"""Save asset proposals for a campaign."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Delete existing proposals for this campaign
|
||||
db.query(CampaignProposal).filter(
|
||||
CampaignProposal.campaign_id == campaign_id,
|
||||
CampaignProposal.user_id == user_id
|
||||
).delete()
|
||||
|
||||
# Create new proposals
|
||||
saved_proposals = []
|
||||
for asset_id, proposal_data in proposals.get('proposals', {}).items():
|
||||
proposal = CampaignProposal(
|
||||
campaign_id=campaign_id,
|
||||
user_id=user_id,
|
||||
asset_node_id=asset_id,
|
||||
asset_type=proposal_data.get('asset_type'),
|
||||
channel=proposal_data.get('channel'),
|
||||
proposed_prompt=proposal_data.get('proposed_prompt'),
|
||||
recommended_template=proposal_data.get('recommended_template'),
|
||||
recommended_provider=proposal_data.get('recommended_provider'),
|
||||
recommended_model=proposal_data.get('recommended_model'),
|
||||
cost_estimate=proposal_data.get('cost_estimate', 0.0),
|
||||
concept_summary=proposal_data.get('concept_summary'),
|
||||
status='proposed',
|
||||
)
|
||||
db.add(proposal)
|
||||
saved_proposals.append(proposal)
|
||||
|
||||
db.commit()
|
||||
for proposal in saved_proposals:
|
||||
db.refresh(proposal)
|
||||
|
||||
logger.info(f"[Campaign Storage] Saved {len(saved_proposals)} proposals for campaign {campaign_id}")
|
||||
return saved_proposals
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error saving proposals: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_proposals(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str
|
||||
) -> List[CampaignProposal]:
|
||||
"""Get proposals for a campaign."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
proposals = db.query(CampaignProposal).filter(
|
||||
CampaignProposal.campaign_id == campaign_id,
|
||||
CampaignProposal.user_id == user_id
|
||||
).all()
|
||||
return proposals
|
||||
except Exception as e:
|
||||
logger.error(f"[Campaign Storage] Error getting proposals: {str(e)}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def update_campaign_status(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_id: str,
|
||||
status: str
|
||||
) -> bool:
|
||||
"""Update campaign status."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
campaign = db.query(Campaign).filter(
|
||||
Campaign.campaign_id == campaign_id,
|
||||
Campaign.user_id == user_id
|
||||
).first()
|
||||
|
||||
if campaign:
|
||||
campaign.status = status
|
||||
db.commit()
|
||||
logger.info(f"[Campaign Storage] Updated campaign {campaign_id} status to {status}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"[Campaign Storage] Error updating status: {str(e)}")
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
180
backend/services/product_marketing/channel_pack.py
Normal file
180
backend/services/product_marketing/channel_pack.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Channel Pack Service
|
||||
Maps channels to templates, copy frameworks, and platform-specific optimizations.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.image_studio.templates import Platform, TemplateManager
|
||||
from services.image_studio.social_optimizer_service import SocialOptimizerService
|
||||
|
||||
|
||||
class ChannelPackService:
|
||||
"""Service to build channel-specific asset packs."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Channel Pack Service."""
|
||||
self.template_manager = TemplateManager()
|
||||
self.social_optimizer = SocialOptimizerService()
|
||||
self.logger = logger
|
||||
logger.info("[Channel Pack] Service initialized")
|
||||
|
||||
def get_channel_pack(
|
||||
self,
|
||||
channel: str,
|
||||
asset_type: str = "social_post"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get channel-specific pack configuration.
|
||||
|
||||
Args:
|
||||
channel: Target channel (instagram, linkedin, tiktok, facebook, twitter, pinterest, youtube)
|
||||
asset_type: Type of asset (social_post, story, reel, cover, etc.)
|
||||
|
||||
Returns:
|
||||
Channel pack configuration with templates, dimensions, copy frameworks
|
||||
"""
|
||||
try:
|
||||
# Map channel string to Platform enum
|
||||
platform_map = {
|
||||
'instagram': Platform.INSTAGRAM,
|
||||
'linkedin': Platform.LINKEDIN,
|
||||
'tiktok': Platform.TIKTOK,
|
||||
'facebook': Platform.FACEBOOK,
|
||||
'twitter': Platform.TWITTER,
|
||||
'pinterest': Platform.PINTEREST,
|
||||
'youtube': Platform.YOUTUBE,
|
||||
}
|
||||
|
||||
platform = platform_map.get(channel.lower())
|
||||
if not platform:
|
||||
raise ValueError(f"Unsupported channel: {channel}")
|
||||
|
||||
# Get templates for this platform
|
||||
templates = self.template_manager.get_platform_templates().get(platform, [])
|
||||
|
||||
# Get platform formats
|
||||
formats = self.social_optimizer.get_platform_formats(platform)
|
||||
|
||||
# Build channel pack
|
||||
pack = {
|
||||
"channel": channel,
|
||||
"platform": platform.value,
|
||||
"asset_type": asset_type,
|
||||
"templates": [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"dimensions": f"{t.aspect_ratio.width}x{t.aspect_ratio.height}",
|
||||
"aspect_ratio": t.aspect_ratio.ratio,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"quality": t.quality,
|
||||
}
|
||||
for t in templates
|
||||
],
|
||||
"formats": formats,
|
||||
"copy_framework": self._get_copy_framework(channel, asset_type),
|
||||
"optimization_tips": self._get_optimization_tips(channel),
|
||||
}
|
||||
|
||||
logger.info(f"[Channel Pack] Built pack for {channel} ({asset_type})")
|
||||
return pack
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Channel Pack] Error building pack: {str(e)}")
|
||||
return {
|
||||
"channel": channel,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def _get_copy_framework(
|
||||
self,
|
||||
channel: str,
|
||||
asset_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get copy framework for channel and asset type."""
|
||||
frameworks = {
|
||||
"instagram": {
|
||||
"social_post": {
|
||||
"caption_length": "125-150 words optimal",
|
||||
"hashtags": "5-10 relevant hashtags",
|
||||
"cta": "Clear call-to-action in first line",
|
||||
"emoji": "Use 1-3 emojis strategically",
|
||||
},
|
||||
"story": {
|
||||
"text_overlay": "Keep text minimal, readable at small size",
|
||||
"cta": "Swipe-up or link sticker",
|
||||
},
|
||||
},
|
||||
"linkedin": {
|
||||
"social_post": {
|
||||
"length": "150-300 words for maximum engagement",
|
||||
"hashtags": "3-5 professional hashtags",
|
||||
"tone": "Professional, thought-leadership focused",
|
||||
"cta": "Engage with question or call-to-action",
|
||||
},
|
||||
},
|
||||
"tiktok": {
|
||||
"video": {
|
||||
"hook": "Strong hook in first 3 seconds",
|
||||
"caption": "Short, engaging, use trending hashtags",
|
||||
"hashtags": "3-5 trending hashtags",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return frameworks.get(channel, {}).get(asset_type, {})
|
||||
|
||||
def _get_optimization_tips(self, channel: str) -> List[str]:
|
||||
"""Get optimization tips for channel."""
|
||||
tips = {
|
||||
"instagram": [
|
||||
"Use square (1:1) or portrait (4:5) for feed posts",
|
||||
"Include text overlay safe zones (15% top/bottom, 10% left/right)",
|
||||
"Optimize for mobile viewing",
|
||||
],
|
||||
"linkedin": {
|
||||
"Use landscape (1.91:1) for feed posts",
|
||||
"Professional photography style",
|
||||
"Include clear value proposition",
|
||||
},
|
||||
"tiktok": {
|
||||
"Vertical format (9:16) required",
|
||||
"Eye-catching first frame",
|
||||
"Fast-paced, engaging content",
|
||||
},
|
||||
}
|
||||
|
||||
return tips.get(channel, [])
|
||||
|
||||
def build_multi_channel_pack(
|
||||
self,
|
||||
channels: List[str],
|
||||
source_image_base64: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build optimized asset pack for multiple channels from single source.
|
||||
|
||||
Args:
|
||||
channels: List of target channels
|
||||
source_image_base64: Source image to optimize
|
||||
|
||||
Returns:
|
||||
Multi-channel pack with optimized variants
|
||||
"""
|
||||
pack_results = []
|
||||
|
||||
for channel in channels:
|
||||
pack = self.get_channel_pack(channel)
|
||||
pack_results.append({
|
||||
"channel": channel,
|
||||
"pack": pack,
|
||||
})
|
||||
|
||||
return {
|
||||
"source_image": "provided",
|
||||
"channels": pack_results,
|
||||
"total_variants": len(channels),
|
||||
}
|
||||
|
||||
469
backend/services/product_marketing/orchestrator.py
Normal file
469
backend/services/product_marketing/orchestrator.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
Product Marketing Orchestrator
|
||||
Main service that orchestrates campaign workflows and asset generation.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
|
||||
from services.image_studio import ImageStudioManager, CreateStudioRequest
|
||||
from .prompt_builder import ProductMarketingPromptBuilder
|
||||
from .brand_dna_sync import BrandDNASyncService
|
||||
from .asset_audit import AssetAuditService
|
||||
from .channel_pack import ChannelPackService
|
||||
from services.database import SessionLocal
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignAssetNode:
|
||||
"""Represents an asset node in the campaign graph."""
|
||||
asset_id: str
|
||||
asset_type: str # image, video, text, audio
|
||||
channel: str
|
||||
status: str # draft, generating, ready, approved
|
||||
prompt: Optional[str] = None
|
||||
template_id: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
cost_estimate: Optional[float] = None
|
||||
generated_asset_id: Optional[int] = None # Asset Library ID
|
||||
|
||||
|
||||
@dataclass
|
||||
class CampaignBlueprint:
|
||||
"""Campaign blueprint with phases and asset nodes."""
|
||||
campaign_id: str
|
||||
campaign_name: str
|
||||
goal: str
|
||||
kpi: Optional[str] = None
|
||||
phases: List[Dict[str, Any]] = None # teaser, launch, nurture
|
||||
asset_nodes: List[CampaignAssetNode] = None
|
||||
channels: List[str] = None
|
||||
status: str = "draft" # draft, generating, ready, published
|
||||
|
||||
|
||||
class ProductMarketingOrchestrator:
|
||||
"""Main orchestrator for Product Marketing Suite."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Marketing Orchestrator."""
|
||||
self.image_studio = ImageStudioManager()
|
||||
self.prompt_builder = ProductMarketingPromptBuilder()
|
||||
self.brand_dna_sync = BrandDNASyncService()
|
||||
self.asset_audit = AssetAuditService()
|
||||
self.channel_pack = ChannelPackService()
|
||||
self.logger = logger
|
||||
logger.info("[Product Marketing Orchestrator] Initialized")
|
||||
|
||||
def create_campaign_blueprint(
|
||||
self,
|
||||
user_id: str,
|
||||
campaign_data: Dict[str, Any]
|
||||
) -> CampaignBlueprint:
|
||||
"""
|
||||
Create campaign blueprint from user input and onboarding data.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
campaign_data: Campaign information (name, goal, channels, etc.)
|
||||
|
||||
Returns:
|
||||
Campaign blueprint with asset nodes
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
campaign_id = campaign_data.get('campaign_id') or f"campaign_{user_id}_{int(time.time())}"
|
||||
campaign_name = campaign_data.get('campaign_name', 'New Campaign')
|
||||
goal = campaign_data.get('goal', 'product_launch')
|
||||
channels = campaign_data.get('channels', [])
|
||||
|
||||
# Get brand DNA for personalization
|
||||
brand_dna = self.brand_dna_sync.get_brand_dna_tokens(user_id)
|
||||
|
||||
# Build campaign phases
|
||||
phases = self._build_campaign_phases(goal, channels)
|
||||
|
||||
# Generate asset nodes for each phase and channel
|
||||
asset_nodes = []
|
||||
for phase in phases:
|
||||
phase_name = phase.get('name')
|
||||
for channel in channels:
|
||||
# Determine required assets for this phase + channel
|
||||
required_assets = self._get_required_assets(phase_name, channel)
|
||||
|
||||
for asset_type in required_assets:
|
||||
asset_node = CampaignAssetNode(
|
||||
asset_id=f"{campaign_id}_{phase_name}_{channel}_{asset_type}",
|
||||
asset_type=asset_type,
|
||||
channel=channel,
|
||||
status="draft",
|
||||
)
|
||||
asset_nodes.append(asset_node)
|
||||
|
||||
blueprint = CampaignBlueprint(
|
||||
campaign_id=campaign_id,
|
||||
campaign_name=campaign_name,
|
||||
goal=goal,
|
||||
kpi=campaign_data.get('kpi'),
|
||||
phases=phases,
|
||||
asset_nodes=asset_nodes,
|
||||
channels=channels,
|
||||
status="draft",
|
||||
)
|
||||
|
||||
logger.info(f"[Orchestrator] Created blueprint for campaign {campaign_id} with {len(asset_nodes)} assets")
|
||||
return blueprint
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error creating blueprint: {str(e)}")
|
||||
raise
|
||||
|
||||
def generate_asset_proposals(
|
||||
self,
|
||||
user_id: str,
|
||||
blueprint: CampaignBlueprint,
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate AI proposals for each asset node in the blueprint.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
blueprint: Campaign blueprint
|
||||
product_context: Product information
|
||||
|
||||
Returns:
|
||||
Dictionary with proposals for each asset node
|
||||
"""
|
||||
try:
|
||||
proposals = {}
|
||||
|
||||
for asset_node in blueprint.asset_nodes:
|
||||
# Build specialized prompt based on asset type and channel
|
||||
if asset_node.asset_type == "image":
|
||||
base_prompt = product_context.get('product_description', 'Product image') if product_context else 'Marketing image'
|
||||
enhanced_prompt = self.prompt_builder.build_marketing_image_prompt(
|
||||
base_prompt=base_prompt,
|
||||
user_id=user_id,
|
||||
channel=asset_node.channel,
|
||||
asset_type="hero_image",
|
||||
product_context=product_context,
|
||||
)
|
||||
|
||||
# Get channel pack for template recommendations
|
||||
channel_pack = self.channel_pack.get_channel_pack(asset_node.channel)
|
||||
recommended_template = channel_pack.get('templates', [{}])[0] if channel_pack.get('templates') else None
|
||||
|
||||
# Estimate cost
|
||||
cost_estimate = self._estimate_asset_cost("image", asset_node.channel)
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"channel": asset_node.channel,
|
||||
"proposed_prompt": enhanced_prompt,
|
||||
"recommended_template": recommended_template.get('id') if recommended_template else None,
|
||||
"recommended_provider": recommended_template.get('recommended_provider', 'wavespeed') if recommended_template else 'wavespeed',
|
||||
"cost_estimate": cost_estimate,
|
||||
"concept_summary": self._generate_concept_summary(enhanced_prompt),
|
||||
}
|
||||
|
||||
elif asset_node.asset_type == "text":
|
||||
base_request = f"Write {asset_node.channel} {asset_node.asset_type} for product launch"
|
||||
enhanced_prompt = self.prompt_builder.build_marketing_copy_prompt(
|
||||
base_request=base_request,
|
||||
user_id=user_id,
|
||||
channel=asset_node.channel,
|
||||
content_type="caption",
|
||||
product_context=product_context,
|
||||
)
|
||||
|
||||
proposals[asset_node.asset_id] = {
|
||||
"asset_id": asset_node.asset_id,
|
||||
"asset_type": asset_node.asset_type,
|
||||
"channel": asset_node.channel,
|
||||
"proposed_prompt": enhanced_prompt,
|
||||
"cost_estimate": 0.0, # Text generation cost is minimal
|
||||
"concept_summary": "Marketing copy optimized for channel and persona",
|
||||
}
|
||||
|
||||
logger.info(f"[Orchestrator] Generated {len(proposals)} asset proposals")
|
||||
return {"proposals": proposals, "total_assets": len(proposals)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error generating proposals: {str(e)}")
|
||||
raise
|
||||
|
||||
async def generate_asset(
|
||||
self,
|
||||
user_id: str,
|
||||
asset_proposal: Dict[str, Any],
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a single asset using Image Studio APIs.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
asset_proposal: Asset proposal from generate_asset_proposals
|
||||
product_context: Product information
|
||||
|
||||
Returns:
|
||||
Generated asset result
|
||||
"""
|
||||
try:
|
||||
asset_type = asset_proposal.get('asset_type')
|
||||
|
||||
if asset_type == "image":
|
||||
# Build CreateStudioRequest
|
||||
create_request = CreateStudioRequest(
|
||||
prompt=asset_proposal.get('proposed_prompt'),
|
||||
template_id=asset_proposal.get('recommended_template'),
|
||||
provider=asset_proposal.get('recommended_provider', 'wavespeed'),
|
||||
quality="premium",
|
||||
enhance_prompt=True,
|
||||
use_persona=True,
|
||||
num_variations=1,
|
||||
)
|
||||
|
||||
# Generate image using Image Studio
|
||||
result = await self.image_studio.create_image(create_request, user_id=user_id)
|
||||
|
||||
# Asset is automatically tracked in Asset Library via Image Studio
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "image",
|
||||
"result": result,
|
||||
"asset_library_ids": [
|
||||
r.get('asset_id') for r in result.get('results', [])
|
||||
if r.get('asset_id')
|
||||
],
|
||||
}
|
||||
|
||||
elif asset_type == "text":
|
||||
# Import text generation service and tracker
|
||||
import asyncio
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
from services.database import SessionLocal
|
||||
|
||||
# Get enhanced prompt from proposal
|
||||
text_prompt = asset_proposal.get('proposed_prompt', '')
|
||||
channel = asset_proposal.get('channel', 'social')
|
||||
asset_id = asset_proposal.get('asset_id', '')
|
||||
|
||||
# Extract campaign_id - try from asset_proposal first, then from asset_id
|
||||
# asset_id format: {campaign_id}_{phase}_{channel}_{type}
|
||||
campaign_id = asset_proposal.get('campaign_id')
|
||||
if not campaign_id and asset_id and '_' in asset_id:
|
||||
# Try to extract: asset_id might be "campaign_user123_1234567890_teaser_instagram_text"
|
||||
# We need to find where phase_name starts (common phases: teaser, launch, nurture)
|
||||
parts = asset_id.split('_')
|
||||
# Find phase indicator (usually one of: teaser, launch, nurture)
|
||||
phase_indicators = ['teaser', 'launch', 'nurture', 'prelaunch', 'postlaunch']
|
||||
phase_idx = None
|
||||
for i, part in enumerate(parts):
|
||||
if part.lower() in phase_indicators:
|
||||
phase_idx = i
|
||||
break
|
||||
if phase_idx and phase_idx > 0:
|
||||
# Campaign ID is everything before the phase
|
||||
campaign_id = '_'.join(parts[:phase_idx])
|
||||
|
||||
# If still not found, use None (metadata will work without it)
|
||||
if not campaign_id:
|
||||
logger.warning(f"[Orchestrator] Could not extract campaign_id from asset_id: {asset_id}")
|
||||
|
||||
# Build system prompt for marketing copy
|
||||
system_prompt = f"""You are an expert marketing copywriter specializing in {channel} content.
|
||||
Generate compelling, on-brand marketing copy that:
|
||||
- Is optimized for {channel} platform best practices
|
||||
- Includes a clear call-to-action
|
||||
- Uses appropriate tone and style for the platform
|
||||
- Is concise and engaging
|
||||
- Aligns with the product marketing context provided
|
||||
|
||||
Return only the final copy text without explanations or markdown formatting."""
|
||||
|
||||
# Run synchronous llm_text_gen in thread pool
|
||||
logger.info(f"[Orchestrator] Generating text asset for channel: {channel}")
|
||||
generated_text = await asyncio.to_thread(
|
||||
llm_text_gen,
|
||||
prompt=text_prompt,
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if not generated_text or not generated_text.strip():
|
||||
raise ValueError("Text generation returned empty content")
|
||||
|
||||
# Save to Asset Library
|
||||
db = SessionLocal()
|
||||
asset_library_id = None
|
||||
try:
|
||||
asset_library_id = save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=generated_text.strip(),
|
||||
source_module="product_marketing",
|
||||
title=f"{channel.title()} Copy: {asset_id.split('_')[-1] if '_' in asset_id else 'Marketing Copy'}",
|
||||
description=f"Marketing copy for {channel} platform generated from campaign proposal",
|
||||
prompt=text_prompt,
|
||||
tags=["product_marketing", channel.lower(), "text", "copy"],
|
||||
asset_metadata={
|
||||
"campaign_id": campaign_id,
|
||||
"asset_id": asset_id,
|
||||
"asset_type": "text",
|
||||
"channel": channel,
|
||||
"concept_summary": asset_proposal.get('concept_summary'),
|
||||
},
|
||||
subdirectory="campaigns",
|
||||
file_extension=".txt"
|
||||
)
|
||||
|
||||
if asset_library_id:
|
||||
logger.info(f"[Orchestrator] ✅ Text asset saved to library: ID={asset_library_id}")
|
||||
else:
|
||||
logger.warning(f"[Orchestrator] ⚠️ Text asset tracking returned None")
|
||||
|
||||
except Exception as save_error:
|
||||
logger.error(f"[Orchestrator] ⚠️ Failed to save text asset to library: {str(save_error)}")
|
||||
# Continue even if save fails - text is still generated
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"asset_type": "text",
|
||||
"content": generated_text.strip(),
|
||||
"asset_library_id": asset_library_id,
|
||||
"channel": channel,
|
||||
}
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported asset type: {asset_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error generating asset: {str(e)}")
|
||||
raise
|
||||
|
||||
def validate_campaign_preflight(
|
||||
self,
|
||||
user_id: str,
|
||||
blueprint: CampaignBlueprint
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate campaign blueprint against subscription limits before generation.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
blueprint: Campaign blueprint
|
||||
|
||||
Returns:
|
||||
Pre-flight validation results
|
||||
"""
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Count operations needed
|
||||
image_count = sum(1 for node in blueprint.asset_nodes if node.asset_type == "image")
|
||||
text_count = sum(1 for node in blueprint.asset_nodes if node.asset_type == "text")
|
||||
|
||||
# Estimate total cost
|
||||
total_cost = 0.0
|
||||
for node in blueprint.asset_nodes:
|
||||
if node.cost_estimate:
|
||||
total_cost += node.cost_estimate
|
||||
|
||||
# Validate image generation limits
|
||||
operations = []
|
||||
if image_count > 0:
|
||||
operations.append({
|
||||
'provider': 'stability', # Default provider
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'wavespeed',
|
||||
'operation_type': 'image_generation',
|
||||
})
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations * image_count if operations else []
|
||||
)
|
||||
|
||||
return {
|
||||
"can_proceed": can_proceed,
|
||||
"message": message,
|
||||
"error_details": error_details,
|
||||
"summary": {
|
||||
"total_assets": len(blueprint.asset_nodes),
|
||||
"image_count": image_count,
|
||||
"text_count": text_count,
|
||||
"estimated_cost": total_cost,
|
||||
},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Orchestrator] Error in pre-flight validation: {str(e)}")
|
||||
return {
|
||||
"can_proceed": False,
|
||||
"message": f"Validation error: {str(e)}",
|
||||
"error_details": {},
|
||||
}
|
||||
|
||||
def _build_campaign_phases(
|
||||
self,
|
||||
goal: str,
|
||||
channels: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Build campaign phases based on goal."""
|
||||
if goal == "product_launch":
|
||||
return [
|
||||
{"name": "teaser", "duration_days": 7, "purpose": "Build anticipation"},
|
||||
{"name": "launch", "duration_days": 3, "purpose": "Official launch"},
|
||||
{"name": "nurture", "duration_days": 14, "purpose": "Sustain engagement"},
|
||||
]
|
||||
else:
|
||||
return [
|
||||
{"name": "campaign", "duration_days": 30, "purpose": "Campaign execution"},
|
||||
]
|
||||
|
||||
def _get_required_assets(
|
||||
self,
|
||||
phase: str,
|
||||
channel: str
|
||||
) -> List[str]:
|
||||
"""Get required asset types for phase and channel."""
|
||||
# Default: image for all phases and channels
|
||||
assets = ["image"]
|
||||
|
||||
# Add text/copy for social channels
|
||||
if channel in ["instagram", "linkedin", "facebook", "twitter"]:
|
||||
assets.append("text")
|
||||
|
||||
return assets
|
||||
|
||||
def _estimate_asset_cost(
|
||||
self,
|
||||
asset_type: str,
|
||||
channel: str
|
||||
) -> float:
|
||||
"""Estimate cost for asset generation."""
|
||||
if asset_type == "image":
|
||||
# Premium quality image: ~5-6 credits
|
||||
return 5.0
|
||||
elif asset_type == "text":
|
||||
return 0.0 # Text generation is typically included
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def _generate_concept_summary(self, prompt: str) -> str:
|
||||
"""Generate a brief concept summary from prompt."""
|
||||
# Simple extraction: take first 100 chars
|
||||
return prompt[:100] + "..." if len(prompt) > 100 else prompt
|
||||
|
||||
634
backend/services/product_marketing/product_image_service.py
Normal file
634
backend/services/product_marketing/product_image_service.py
Normal file
@@ -0,0 +1,634 @@
|
||||
"""
|
||||
Product Image Service
|
||||
Specialized service for generating product-focused images using AI models.
|
||||
Optimized for e-commerce product photography, product showcases, and product marketing assets.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from services.database import SessionLocal
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
class ProductImageServiceError(Exception):
|
||||
"""Base exception for Product Image Service errors."""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(ProductImageServiceError):
|
||||
"""Validation error for invalid requests."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageGenerationError(ProductImageServiceError):
|
||||
"""Error during image generation."""
|
||||
pass
|
||||
|
||||
|
||||
class StorageError(ProductImageServiceError):
|
||||
"""Error saving image to storage."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductImageRequest:
|
||||
"""Request for product image generation."""
|
||||
product_name: str
|
||||
product_description: str
|
||||
environment: str = "studio" # studio, lifestyle, outdoor, minimalist, luxury
|
||||
background_style: str = "white" # white, transparent, lifestyle, branded
|
||||
lighting: str = "natural" # natural, studio, dramatic, soft
|
||||
product_variant: Optional[str] = None # color, size, etc.
|
||||
angle: Optional[str] = None # front, side, top, 360, etc.
|
||||
style: str = "photorealistic" # photorealistic, minimalist, luxury, technical
|
||||
resolution: str = "1024x1024" # 1024x1024, 1280x720, etc.
|
||||
num_variations: int = 1
|
||||
brand_colors: Optional[List[str]] = None # Brand color palette
|
||||
additional_context: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProductImageResult:
|
||||
"""Result from product image generation."""
|
||||
success: bool
|
||||
product_name: str
|
||||
image_url: Optional[str] = None
|
||||
image_bytes: Optional[bytes] = None
|
||||
asset_id: Optional[int] = None # Asset Library ID
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
cost: float = 0.0
|
||||
generation_time: float = 0.0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ProductImageService:
|
||||
"""Service for generating product marketing images."""
|
||||
|
||||
# Product photography style presets
|
||||
ENVIRONMENT_PROMPTS = {
|
||||
"studio": "professional studio photography, clean white background, even lighting",
|
||||
"lifestyle": "lifestyle photography, product in use, natural environment, relatable setting",
|
||||
"outdoor": "outdoor photography, natural lighting, outdoor environment, dynamic setting",
|
||||
"minimalist": "minimalist product photography, simple composition, clean aesthetic",
|
||||
"luxury": "luxury product photography, premium aesthetic, sophisticated lighting, high-end",
|
||||
}
|
||||
|
||||
BACKGROUND_STYLES = {
|
||||
"white": "clean white background",
|
||||
"transparent": "transparent background, isolated product",
|
||||
"lifestyle": "lifestyle background, contextual environment",
|
||||
"branded": "branded background with brand colors",
|
||||
}
|
||||
|
||||
LIGHTING_STYLES = {
|
||||
"natural": "natural lighting, soft shadows, balanced exposure",
|
||||
"studio": "professional studio lighting, even illumination, no harsh shadows",
|
||||
"dramatic": "dramatic lighting, high contrast, artistic shadows",
|
||||
"soft": "soft diffused lighting, gentle shadows, elegant",
|
||||
}
|
||||
|
||||
# Valid values for request parameters
|
||||
VALID_ENVIRONMENTS = {"studio", "lifestyle", "outdoor", "minimalist", "luxury"}
|
||||
VALID_BACKGROUND_STYLES = {"white", "transparent", "lifestyle", "branded"}
|
||||
VALID_LIGHTING_STYLES = {"natural", "studio", "dramatic", "soft"}
|
||||
VALID_STYLES = {"photorealistic", "minimalist", "luxury", "technical"}
|
||||
VALID_ANGLES = {"front", "side", "top", "360"}
|
||||
|
||||
# Maximum values
|
||||
MAX_RESOLUTION = (4096, 4096)
|
||||
MIN_RESOLUTION = (256, 256)
|
||||
MAX_NUM_VARIATIONS = 10
|
||||
MAX_PRODUCT_NAME_LENGTH = 500
|
||||
MAX_PRODUCT_DESCRIPTION_LENGTH = 2000
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Image Service."""
|
||||
try:
|
||||
self.wavespeed_client = WaveSpeedClient()
|
||||
logger.info("[Product Image Service] Initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Image Service] Failed to initialize WaveSpeed client: {str(e)}")
|
||||
raise ProductImageServiceError(f"Failed to initialize service: {str(e)}") from e
|
||||
|
||||
def validate_request(self, request: ProductImageRequest) -> None:
|
||||
"""
|
||||
Validate product image generation request.
|
||||
|
||||
Args:
|
||||
request: Product image generation request
|
||||
|
||||
Raises:
|
||||
ValidationError: If request is invalid
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# Validate product_name
|
||||
if not request.product_name or not request.product_name.strip():
|
||||
errors.append("Product name is required")
|
||||
elif len(request.product_name) > self.MAX_PRODUCT_NAME_LENGTH:
|
||||
errors.append(f"Product name must be <= {self.MAX_PRODUCT_NAME_LENGTH} characters")
|
||||
|
||||
# Validate product_description
|
||||
if request.product_description and len(request.product_description) > self.MAX_PRODUCT_DESCRIPTION_LENGTH:
|
||||
errors.append(f"Product description must be <= {self.MAX_PRODUCT_DESCRIPTION_LENGTH} characters")
|
||||
|
||||
# Validate environment
|
||||
if request.environment not in self.VALID_ENVIRONMENTS:
|
||||
errors.append(f"Invalid environment: {request.environment}. Valid: {', '.join(self.VALID_ENVIRONMENTS)}")
|
||||
|
||||
# Validate background_style
|
||||
if request.background_style not in self.VALID_BACKGROUND_STYLES:
|
||||
errors.append(f"Invalid background_style: {request.background_style}. Valid: {', '.join(self.VALID_BACKGROUND_STYLES)}")
|
||||
|
||||
# Validate lighting
|
||||
if request.lighting not in self.VALID_LIGHTING_STYLES:
|
||||
errors.append(f"Invalid lighting: {request.lighting}. Valid: {', '.join(self.VALID_LIGHTING_STYLES)}")
|
||||
|
||||
# Validate style
|
||||
if request.style not in self.VALID_STYLES:
|
||||
errors.append(f"Invalid style: {request.style}. Valid: {', '.join(self.VALID_STYLES)}")
|
||||
|
||||
# Validate angle
|
||||
if request.angle and request.angle not in self.VALID_ANGLES:
|
||||
errors.append(f"Invalid angle: {request.angle}. Valid: {', '.join(self.VALID_ANGLES)}")
|
||||
|
||||
# Validate num_variations
|
||||
if request.num_variations < 1:
|
||||
errors.append("num_variations must be >= 1")
|
||||
elif request.num_variations > self.MAX_NUM_VARIATIONS:
|
||||
errors.append(f"num_variations must be <= {self.MAX_NUM_VARIATIONS}")
|
||||
|
||||
# Validate resolution
|
||||
try:
|
||||
width, height = self._parse_resolution(request.resolution)
|
||||
if width < self.MIN_RESOLUTION[0] or height < self.MIN_RESOLUTION[1]:
|
||||
errors.append(f"Resolution must be >= {self.MIN_RESOLUTION[0]}x{self.MIN_RESOLUTION[1]}")
|
||||
if width > self.MAX_RESOLUTION[0] or height > self.MAX_RESOLUTION[1]:
|
||||
errors.append(f"Resolution must be <= {self.MAX_RESOLUTION[0]}x{self.MAX_RESOLUTION[1]}")
|
||||
except Exception as e:
|
||||
errors.append(f"Invalid resolution format: {request.resolution}. Error: {str(e)}")
|
||||
|
||||
if errors:
|
||||
raise ValidationError(f"Validation failed: {'; '.join(errors)}")
|
||||
|
||||
def build_product_prompt(
|
||||
self,
|
||||
request: ProductImageRequest,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build optimized prompt for product image generation.
|
||||
|
||||
Args:
|
||||
request: Product image generation request
|
||||
brand_context: Optional brand DNA context for personalization
|
||||
|
||||
Returns:
|
||||
Optimized prompt string
|
||||
"""
|
||||
prompt_parts = []
|
||||
|
||||
# Base product description
|
||||
prompt_parts.append(f"Professional product photography of {request.product_name}")
|
||||
if request.product_description:
|
||||
prompt_parts.append(f": {request.product_description}")
|
||||
|
||||
# Product variant
|
||||
if request.product_variant:
|
||||
prompt_parts.append(f", {request.product_variant}")
|
||||
|
||||
# Environment and style
|
||||
env_prompt = self.ENVIRONMENT_PROMPTS.get(request.environment, self.ENVIRONMENT_PROMPTS["studio"])
|
||||
prompt_parts.append(f", {env_prompt}")
|
||||
|
||||
# Background
|
||||
bg_prompt = self.BACKGROUND_STYLES.get(request.background_style, self.BACKGROUND_STYLES["white"])
|
||||
if request.background_style == "branded" and request.brand_colors:
|
||||
bg_prompt += f", using brand colors: {', '.join(request.brand_colors)}"
|
||||
prompt_parts.append(f", {bg_prompt}")
|
||||
|
||||
# Lighting
|
||||
lighting_prompt = self.LIGHTING_STYLES.get(request.lighting, self.LIGHTING_STYLES["natural"])
|
||||
prompt_parts.append(f", {lighting_prompt}")
|
||||
|
||||
# Angle/view
|
||||
if request.angle:
|
||||
angle_map = {
|
||||
"front": "front view, centered composition",
|
||||
"side": "side profile view, showing depth",
|
||||
"top": "top-down view, flat lay style",
|
||||
"360": "3/4 angle view, showing multiple sides",
|
||||
}
|
||||
angle_prompt = angle_map.get(request.angle, request.angle)
|
||||
prompt_parts.append(f", {angle_prompt}")
|
||||
|
||||
# Style
|
||||
style_map = {
|
||||
"photorealistic": "photorealistic, highly detailed, professional photography",
|
||||
"minimalist": "minimalist aesthetic, clean composition, simple and elegant",
|
||||
"luxury": "luxury aesthetic, premium quality, sophisticated and refined",
|
||||
"technical": "technical product photography, detailed features, professional documentation style",
|
||||
}
|
||||
style_prompt = style_map.get(request.style, style_map["photorealistic"])
|
||||
prompt_parts.append(f", {style_prompt}")
|
||||
|
||||
# Additional context
|
||||
if request.additional_context:
|
||||
prompt_parts.append(f", {request.additional_context}")
|
||||
|
||||
# Brand DNA integration (if available)
|
||||
if brand_context:
|
||||
brand_tone = brand_context.get("visual_identity", {}).get("style_guidelines")
|
||||
if brand_tone:
|
||||
prompt_parts.append(f", brand style: {brand_tone}")
|
||||
|
||||
# Quality keywords
|
||||
prompt_parts.append(", high resolution, professional quality, sharp focus, commercial photography")
|
||||
|
||||
full_prompt = " ".join(prompt_parts)
|
||||
logger.debug(f"[Product Image Service] Built prompt: {full_prompt[:200]}...")
|
||||
|
||||
return full_prompt
|
||||
|
||||
def _generate_image_with_retry(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
width: int,
|
||||
height: int,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 2.0
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image with retry logic for transient failures.
|
||||
|
||||
Args:
|
||||
model: Model to use
|
||||
prompt: Generation prompt
|
||||
width: Image width
|
||||
height: Image height
|
||||
max_retries: Maximum number of retries
|
||||
retry_delay: Delay between retries in seconds
|
||||
|
||||
Returns:
|
||||
Generated image bytes
|
||||
|
||||
Raises:
|
||||
ImageGenerationError: If generation fails after retries
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info(f"[Product Image Service] Image generation attempt {attempt + 1}/{max_retries}")
|
||||
|
||||
image_bytes = self.wavespeed_client.generate_image(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
enable_sync_mode=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if not image_bytes:
|
||||
raise ValueError("Image generation returned empty result")
|
||||
|
||||
if len(image_bytes) < 100: # Sanity check: image should be at least 100 bytes
|
||||
raise ValueError(f"Generated image too small: {len(image_bytes)} bytes")
|
||||
|
||||
logger.info(f"[Product Image Service] ✅ Image generated successfully: {len(image_bytes)} bytes")
|
||||
return image_bytes
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_msg = str(e)
|
||||
logger.warning(f"[Product Image Service] Attempt {attempt + 1} failed: {error_msg}")
|
||||
|
||||
# Don't retry on validation errors or client errors (4xx)
|
||||
if "4" in error_msg or "validation" in error_msg.lower() or "invalid" in error_msg.lower():
|
||||
logger.error(f"[Product Image Service] Non-retryable error: {error_msg}")
|
||||
raise ImageGenerationError(f"Image generation failed: {error_msg}") from e
|
||||
|
||||
# Retry on transient errors
|
||||
if attempt < max_retries - 1:
|
||||
logger.info(f"[Product Image Service] Retrying in {retry_delay} seconds...")
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 1.5 # Exponential backoff
|
||||
else:
|
||||
logger.error(f"[Product Image Service] All retry attempts failed")
|
||||
|
||||
raise ImageGenerationError(f"Image generation failed after {max_retries} attempts: {str(last_error)}") from last_error
|
||||
|
||||
async def generate_product_image(
|
||||
self,
|
||||
request: ProductImageRequest,
|
||||
user_id: str,
|
||||
brand_context: Optional[Dict[str, Any]] = None
|
||||
) -> ProductImageResult:
|
||||
"""
|
||||
Generate product image using AI models.
|
||||
|
||||
Args:
|
||||
request: Product image generation request
|
||||
user_id: User ID for tracking
|
||||
brand_context: Optional brand DNA for personalization
|
||||
|
||||
Returns:
|
||||
ProductImageResult with generated image
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Validate request
|
||||
self.validate_request(request)
|
||||
|
||||
# Validate user_id
|
||||
if not user_id or not user_id.strip():
|
||||
raise ValidationError("user_id is required")
|
||||
|
||||
# Build optimized prompt
|
||||
prompt = self.build_product_prompt(request, brand_context)
|
||||
|
||||
# Parse resolution
|
||||
width, height = self._parse_resolution(request.resolution)
|
||||
|
||||
# Select model based on style/quality needs
|
||||
model = "ideogram-v3-turbo" # Default to Ideogram V3 for photorealistic products
|
||||
if request.style == "minimalist":
|
||||
model = "ideogram-v3-turbo" # Still use Ideogram for quality
|
||||
elif request.style == "technical":
|
||||
model = "ideogram-v3-turbo"
|
||||
|
||||
logger.info(f"[Product Image Service] Generating product image for '{request.product_name}' using {model}")
|
||||
|
||||
# Generate image using WaveSpeed with retry logic
|
||||
try:
|
||||
image_bytes = self._generate_image_with_retry(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
max_retries=3,
|
||||
retry_delay=2.0
|
||||
)
|
||||
except ImageGenerationError as e:
|
||||
logger.error(f"[Product Image Service] Image generation failed: {str(e)}")
|
||||
generation_time = time.time() - start_time
|
||||
return ProductImageResult(
|
||||
success=False,
|
||||
product_name=request.product_name,
|
||||
error=f"Image generation failed: {str(e)}",
|
||||
generation_time=generation_time,
|
||||
)
|
||||
|
||||
# Save image to file and Asset Library
|
||||
asset_id = None
|
||||
image_url = None
|
||||
|
||||
try:
|
||||
asset_id, image_url = self._save_product_image(
|
||||
image_bytes=image_bytes,
|
||||
request=request,
|
||||
user_id=user_id,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
start_time=start_time
|
||||
)
|
||||
except StorageError as storage_error:
|
||||
logger.error(f"[Product Image Service] Storage failed: {str(storage_error)}", exc_info=True)
|
||||
# Continue with generation result even if storage fails
|
||||
# The image_bytes is still available in the result
|
||||
except Exception as save_error:
|
||||
logger.error(f"[Product Image Service] Unexpected error saving image: {str(save_error)}", exc_info=True)
|
||||
# Continue even if save fails
|
||||
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
return ProductImageResult(
|
||||
success=True,
|
||||
product_name=request.product_name,
|
||||
image_url=image_url,
|
||||
image_bytes=image_bytes,
|
||||
asset_id=asset_id,
|
||||
provider="wavespeed",
|
||||
model=model,
|
||||
cost=0.10,
|
||||
generation_time=generation_time,
|
||||
)
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error(f"[Product Image Service] Validation error: {str(ve)}")
|
||||
generation_time = time.time() - start_time
|
||||
return ProductImageResult(
|
||||
success=False,
|
||||
product_name=request.product_name if hasattr(request, 'product_name') else "unknown",
|
||||
error=f"Validation error: {str(ve)}",
|
||||
generation_time=generation_time,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Image Service] ❌ Unexpected error generating product image: {str(e)}", exc_info=True)
|
||||
generation_time = time.time() - start_time
|
||||
return ProductImageResult(
|
||||
success=False,
|
||||
product_name=request.product_name if hasattr(request, 'product_name') else "unknown",
|
||||
error=f"Unexpected error: {str(e)}",
|
||||
generation_time=generation_time,
|
||||
)
|
||||
|
||||
def _save_product_image(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
request: ProductImageRequest,
|
||||
user_id: str,
|
||||
prompt: str,
|
||||
model: str,
|
||||
start_time: float
|
||||
) -> tuple[Optional[int], Optional[str]]:
|
||||
"""
|
||||
Save product image to disk and Asset Library.
|
||||
|
||||
Args:
|
||||
image_bytes: Generated image bytes
|
||||
request: Product image generation request
|
||||
user_id: User ID
|
||||
prompt: Generation prompt
|
||||
model: Model used
|
||||
start_time: Generation start time
|
||||
|
||||
Returns:
|
||||
Tuple of (asset_id, image_url)
|
||||
|
||||
Raises:
|
||||
StorageError: If saving fails
|
||||
"""
|
||||
db = None
|
||||
asset_id = None
|
||||
image_url = None
|
||||
image_path = None
|
||||
|
||||
try:
|
||||
# Generate filename
|
||||
product_hash = hashlib.md5(request.product_name.encode()).hexdigest()[:8]
|
||||
timestamp = int(start_time)
|
||||
filename = f"product_{product_hash}_{timestamp}.png"
|
||||
|
||||
# Determine base directory and create product_images folder
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
product_images_dir = base_dir / "product_images"
|
||||
|
||||
# Create directory with error handling
|
||||
try:
|
||||
product_images_dir.mkdir(parents=True, exist_ok=True)
|
||||
except PermissionError as pe:
|
||||
raise StorageError(f"Permission denied creating directory: {str(pe)}") from pe
|
||||
except OSError as oe:
|
||||
raise StorageError(f"Failed to create directory: {str(oe)}") from oe
|
||||
|
||||
# Check disk space (rough estimate - at least 10MB free)
|
||||
try:
|
||||
stat = shutil.disk_usage(product_images_dir)
|
||||
free_space_mb = stat.free / (1024 * 1024)
|
||||
if free_space_mb < 10:
|
||||
raise StorageError(f"Insufficient disk space: {free_space_mb:.1f}MB free (need at least 10MB)")
|
||||
except OSError as oe:
|
||||
logger.warning(f"[Product Image Service] Could not check disk space: {str(oe)}")
|
||||
|
||||
# Save image to disk
|
||||
image_path = product_images_dir / filename
|
||||
try:
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
# Verify file was written
|
||||
if not image_path.exists() or image_path.stat().st_size == 0:
|
||||
raise StorageError("Image file was not written correctly")
|
||||
except PermissionError as pe:
|
||||
raise StorageError(f"Permission denied writing file: {str(pe)}") from pe
|
||||
except OSError as oe:
|
||||
raise StorageError(f"Failed to write file: {str(oe)}") from oe
|
||||
|
||||
file_size = len(image_bytes)
|
||||
image_url = f"/api/product-marketing/images/{filename}"
|
||||
|
||||
# Save to Asset Library
|
||||
db = SessionLocal()
|
||||
try:
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="product_marketing",
|
||||
filename=filename,
|
||||
file_url=image_url,
|
||||
file_path=str(image_path),
|
||||
file_size=file_size,
|
||||
mime_type="image/png",
|
||||
title=f"{request.product_name} - Product Image",
|
||||
description=f"Product image: {request.product_description or request.product_name}",
|
||||
prompt=prompt,
|
||||
tags=["product_marketing", "product_image", request.environment, request.style],
|
||||
provider="wavespeed",
|
||||
model=model,
|
||||
cost=0.10, # Estimated cost for Ideogram V3
|
||||
asset_metadata={
|
||||
"product_name": request.product_name,
|
||||
"product_description": request.product_description,
|
||||
"environment": request.environment,
|
||||
"background_style": request.background_style,
|
||||
"lighting": request.lighting,
|
||||
"style": request.style,
|
||||
"variant": request.product_variant,
|
||||
"angle": request.angle,
|
||||
},
|
||||
)
|
||||
|
||||
if asset_id:
|
||||
logger.info(f"[Product Image Service] ✅ Saved product image to Asset Library: ID={asset_id}")
|
||||
else:
|
||||
logger.warning(f"[Product Image Service] ⚠️ Asset Library save returned None (file saved but not tracked)")
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"[Product Image Service] Database error saving to Asset Library: {str(db_error)}", exc_info=True)
|
||||
# File is saved, but database tracking failed
|
||||
# This is not critical - image is still accessible
|
||||
raise StorageError(f"Failed to save to Asset Library: {str(db_error)}") from db_error
|
||||
finally:
|
||||
if db:
|
||||
try:
|
||||
db.close()
|
||||
except Exception as close_error:
|
||||
logger.warning(f"[Product Image Service] Error closing database: {str(close_error)}")
|
||||
|
||||
return (asset_id, image_url)
|
||||
|
||||
except StorageError:
|
||||
# Clean up partial files on storage error
|
||||
if image_path and image_path.exists():
|
||||
try:
|
||||
image_path.unlink()
|
||||
logger.info(f"[Product Image Service] Cleaned up partial file: {image_path}")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"[Product Image Service] Failed to cleanup partial file: {str(cleanup_error)}")
|
||||
raise
|
||||
|
||||
def _parse_resolution(self, resolution: str) -> tuple[int, int]:
|
||||
"""
|
||||
Parse resolution string to width, height tuple.
|
||||
|
||||
Args:
|
||||
resolution: Resolution string (e.g., "1024x1024", "square", "landscape")
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
try:
|
||||
resolution = resolution.strip().lower()
|
||||
|
||||
if "x" in resolution:
|
||||
parts = resolution.split("x")
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Invalid resolution format: {resolution}")
|
||||
width = int(parts[0].strip())
|
||||
height = int(parts[1].strip())
|
||||
|
||||
# Validate resolution values
|
||||
if width < 1 or height < 1:
|
||||
raise ValueError(f"Resolution dimensions must be positive: {width}x{height}")
|
||||
|
||||
return (width, height)
|
||||
elif resolution == "square":
|
||||
return (1024, 1024)
|
||||
elif resolution == "landscape":
|
||||
return (1280, 720)
|
||||
elif resolution == "portrait":
|
||||
return (720, 1280)
|
||||
else:
|
||||
# Try to parse as single number (assume square)
|
||||
try:
|
||||
size = int(resolution)
|
||||
return (size, size)
|
||||
except ValueError:
|
||||
# Default to square
|
||||
logger.warning(f"[Product Image Service] Could not parse resolution '{resolution}', defaulting to 1024x1024")
|
||||
return (1024, 1024)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Product Image Service] Error parsing resolution '{resolution}': {str(e)}, defaulting to 1024x1024")
|
||||
return (1024, 1024)
|
||||
|
||||
def estimate_cost(self, request: ProductImageRequest) -> float:
|
||||
"""Estimate cost for product image generation."""
|
||||
# Ideogram V3 Turbo: ~$0.10 per image
|
||||
# Multiply by number of variations
|
||||
base_cost = 0.10
|
||||
return base_cost * request.num_variations
|
||||
|
||||
304
backend/services/product_marketing/prompt_builder.py
Normal file
304
backend/services/product_marketing/prompt_builder.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Product Marketing Prompt Builder
|
||||
Extends AIPromptOptimizer with marketing-specific prompt enhancement.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.ai_prompt_optimizer import AIPromptOptimizer
|
||||
from services.onboarding import OnboardingDataService
|
||||
from services.onboarding.database_service import OnboardingDatabaseService
|
||||
from services.persona_data_service import PersonaDataService
|
||||
from services.database import SessionLocal
|
||||
|
||||
|
||||
class ProductMarketingPromptBuilder(AIPromptOptimizer):
|
||||
"""Specialized prompt builder for marketing assets with onboarding data integration."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Marketing Prompt Builder."""
|
||||
super().__init__()
|
||||
self.onboarding_data_service = OnboardingDataService()
|
||||
self.logger = logger
|
||||
logger.info("[Product Marketing Prompt Builder] Initialized")
|
||||
|
||||
def build_marketing_image_prompt(
|
||||
self,
|
||||
base_prompt: str,
|
||||
user_id: str,
|
||||
channel: Optional[str] = None,
|
||||
asset_type: str = "hero_image",
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build enhanced marketing image prompt with brand DNA and persona data.
|
||||
|
||||
Args:
|
||||
base_prompt: Base product description or image concept
|
||||
user_id: User ID to fetch onboarding data
|
||||
channel: Target channel (instagram, linkedin, tiktok, etc.)
|
||||
asset_type: Type of asset (hero_image, product_photo, lifestyle, etc.)
|
||||
product_context: Additional product information
|
||||
|
||||
Returns:
|
||||
Enhanced prompt with brand DNA, persona style, and marketing context
|
||||
"""
|
||||
try:
|
||||
# Get onboarding data
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Build prompt layers
|
||||
enhanced_prompt = base_prompt
|
||||
|
||||
# Layer 1: Brand DNA (from website_analysis)
|
||||
if website_analysis:
|
||||
writing_style = website_analysis.get('writing_style', {})
|
||||
target_audience = website_analysis.get('target_audience', {})
|
||||
brand_analysis = website_analysis.get('brand_analysis', {})
|
||||
style_guidelines = website_analysis.get('style_guidelines', {})
|
||||
|
||||
# Add brand tone and style
|
||||
tone = writing_style.get('tone', 'professional')
|
||||
voice = writing_style.get('voice', 'authoritative')
|
||||
brand_enhancement = f", {tone} tone, {voice} voice"
|
||||
|
||||
# Add target audience context
|
||||
demographics = target_audience.get('demographics', [])
|
||||
if demographics:
|
||||
audience_context = f", targeting {', '.join(demographics[:2])}"
|
||||
enhanced_prompt += audience_context
|
||||
|
||||
# Add brand visual identity if available
|
||||
if brand_analysis:
|
||||
color_palette = brand_analysis.get('color_palette', [])
|
||||
if color_palette:
|
||||
colors = ', '.join(color_palette[:3])
|
||||
enhanced_prompt += f", brand colors: {colors}"
|
||||
|
||||
# Layer 2: Persona Visual Style (from persona_data)
|
||||
if persona_data:
|
||||
core_persona = persona_data.get('corePersona', {})
|
||||
platform_personas = persona_data.get('platformPersonas', {})
|
||||
|
||||
if core_persona:
|
||||
persona_name = core_persona.get('persona_name', '')
|
||||
archetype = core_persona.get('archetype', '')
|
||||
if persona_name:
|
||||
enhanced_prompt += f", {persona_name} style"
|
||||
|
||||
# Channel-specific persona adaptation
|
||||
if channel and platform_personas:
|
||||
platform_persona = platform_personas.get(channel, {})
|
||||
if platform_persona:
|
||||
visual_identity = platform_persona.get('visual_identity', {})
|
||||
if visual_identity:
|
||||
aesthetic = visual_identity.get('aesthetic_preferences', '')
|
||||
if aesthetic:
|
||||
enhanced_prompt += f", {aesthetic} aesthetic"
|
||||
|
||||
# Layer 3: Channel Optimization
|
||||
channel_enhancements = {
|
||||
'instagram': ', Instagram-optimized composition, vibrant colors, engaging visual',
|
||||
'linkedin': ', professional photography, clean composition, business-focused',
|
||||
'tiktok': ', dynamic composition, eye-catching, vertical format optimized',
|
||||
'facebook': ', social media optimized, engaging, shareable visual',
|
||||
'twitter': ', Twitter card optimized, clear focal point, readable at small size',
|
||||
'pinterest': ', Pinterest-optimized, vertical format, detailed and informative',
|
||||
}
|
||||
|
||||
if channel and channel.lower() in channel_enhancements:
|
||||
enhanced_prompt += channel_enhancements[channel.lower()]
|
||||
|
||||
# Layer 4: Asset Type Specific
|
||||
asset_type_enhancements = {
|
||||
'hero_image': ', hero image style, prominent product placement, professional photography',
|
||||
'product_photo': ', product photography, clean background, detailed product showcase',
|
||||
'lifestyle': ', lifestyle photography, natural setting, authentic scene',
|
||||
'social_post': ', social media post, engaging composition, optimized for engagement',
|
||||
}
|
||||
|
||||
if asset_type in asset_type_enhancements:
|
||||
enhanced_prompt += asset_type_enhancements[asset_type]
|
||||
|
||||
# Layer 5: Competitive Differentiation
|
||||
if competitor_analyses and len(competitor_analyses) > 0:
|
||||
# Extract unique positioning from competitor analysis
|
||||
enhanced_prompt += ", unique positioning, differentiated visual style"
|
||||
|
||||
# Layer 6: Quality Descriptors
|
||||
enhanced_prompt += ", professional photography, high quality, detailed, sharp focus, natural lighting"
|
||||
|
||||
# Layer 7: Marketing Context
|
||||
if product_context:
|
||||
marketing_goal = product_context.get('marketing_goal', '')
|
||||
if marketing_goal:
|
||||
enhanced_prompt += f", {marketing_goal} focused"
|
||||
|
||||
logger.info(f"[Marketing Prompt] Enhanced prompt for user {user_id}: {enhanced_prompt[:200]}...")
|
||||
return enhanced_prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Marketing Prompt] Error building prompt: {str(e)}")
|
||||
# Return base prompt with minimal enhancement if error
|
||||
return f"{base_prompt}, professional photography, high quality"
|
||||
|
||||
def build_marketing_copy_prompt(
|
||||
self,
|
||||
base_request: str,
|
||||
user_id: str,
|
||||
channel: Optional[str] = None,
|
||||
content_type: str = "caption",
|
||||
product_context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build enhanced marketing copy prompt with persona linguistic fingerprint.
|
||||
|
||||
Args:
|
||||
base_request: Base content request (e.g., "Write Instagram caption for product launch")
|
||||
user_id: User ID to fetch onboarding data
|
||||
channel: Target channel (instagram, linkedin, etc.)
|
||||
content_type: Type of content (caption, cta, email, ad_copy, etc.)
|
||||
product_context: Additional product information
|
||||
|
||||
Returns:
|
||||
Enhanced prompt with persona style, brand voice, and marketing context
|
||||
"""
|
||||
try:
|
||||
# Get onboarding data
|
||||
db = SessionLocal()
|
||||
try:
|
||||
onboarding_db = OnboardingDatabaseService(db)
|
||||
website_analysis = onboarding_db.get_website_analysis(user_id, db)
|
||||
persona_data = onboarding_db.get_persona_data(user_id, db)
|
||||
competitor_analyses = onboarding_db.get_competitor_analysis(user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Build enhanced prompt
|
||||
enhanced_prompt = base_request
|
||||
|
||||
# Add persona linguistic fingerprint
|
||||
if persona_data:
|
||||
core_persona = persona_data.get('corePersona', {})
|
||||
platform_personas = persona_data.get('platformPersonas', {})
|
||||
|
||||
if core_persona:
|
||||
persona_name = core_persona.get('persona_name', '')
|
||||
linguistic_fingerprint = core_persona.get('linguistic_fingerprint', {})
|
||||
|
||||
if persona_name:
|
||||
enhanced_prompt += f"\n\nFollow {persona_name} persona style:"
|
||||
|
||||
if linguistic_fingerprint:
|
||||
sentence_metrics = linguistic_fingerprint.get('sentence_metrics', {})
|
||||
lexical_features = linguistic_fingerprint.get('lexical_features', {})
|
||||
|
||||
if sentence_metrics:
|
||||
avg_length = sentence_metrics.get('average_sentence_length_words', '')
|
||||
if avg_length:
|
||||
enhanced_prompt += f"\n- Average sentence length: {avg_length} words"
|
||||
|
||||
if lexical_features:
|
||||
go_to_words = lexical_features.get('go_to_words', [])
|
||||
avoid_words = lexical_features.get('avoid_words', [])
|
||||
vocabulary_level = lexical_features.get('vocabulary_level', '')
|
||||
|
||||
if go_to_words:
|
||||
enhanced_prompt += f"\n- Use these words: {', '.join(go_to_words[:5])}"
|
||||
if avoid_words:
|
||||
enhanced_prompt += f"\n- Avoid these words: {', '.join(avoid_words[:5])}"
|
||||
if vocabulary_level:
|
||||
enhanced_prompt += f"\n- Vocabulary level: {vocabulary_level}"
|
||||
|
||||
# Channel-specific persona adaptation
|
||||
if channel and platform_personas:
|
||||
platform_persona = platform_personas.get(channel, {})
|
||||
if platform_persona:
|
||||
content_format_rules = platform_persona.get('content_format_rules', {})
|
||||
engagement_patterns = platform_persona.get('engagement_patterns', {})
|
||||
|
||||
if content_format_rules:
|
||||
char_limit = content_format_rules.get('character_limit', '')
|
||||
hashtag_strategy = content_format_rules.get('hashtag_strategy', '')
|
||||
|
||||
if char_limit:
|
||||
enhanced_prompt += f"\n- Character limit: {char_limit}"
|
||||
if hashtag_strategy:
|
||||
enhanced_prompt += f"\n- Hashtag strategy: {hashtag_strategy}"
|
||||
|
||||
# Add brand voice
|
||||
if website_analysis:
|
||||
writing_style = website_analysis.get('writing_style', {})
|
||||
target_audience = website_analysis.get('target_audience', {})
|
||||
|
||||
tone = writing_style.get('tone', 'professional')
|
||||
voice = writing_style.get('voice', 'authoritative')
|
||||
enhanced_prompt += f"\n- Brand tone: {tone}, Brand voice: {voice}"
|
||||
|
||||
demographics = target_audience.get('demographics', [])
|
||||
expertise_level = target_audience.get('expertise_level', 'intermediate')
|
||||
if demographics:
|
||||
enhanced_prompt += f"\n- Target audience: {', '.join(demographics[:2])}, {expertise_level} level"
|
||||
|
||||
# Add competitive positioning
|
||||
if competitor_analyses and len(competitor_analyses) > 0:
|
||||
enhanced_prompt += "\n- Differentiate from competitors, highlight unique value propositions"
|
||||
|
||||
# Add marketing context
|
||||
if product_context:
|
||||
marketing_goal = product_context.get('marketing_goal', '')
|
||||
if marketing_goal:
|
||||
enhanced_prompt += f"\n- Marketing goal: {marketing_goal}"
|
||||
|
||||
logger.info(f"[Marketing Copy Prompt] Enhanced for user {user_id}: {enhanced_prompt[:200]}...")
|
||||
return enhanced_prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Marketing Copy Prompt] Error building prompt: {str(e)}")
|
||||
return base_request
|
||||
|
||||
def optimize_marketing_prompt(
|
||||
self,
|
||||
prompt_type: str,
|
||||
base_prompt: str,
|
||||
user_id: str,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Main entry point for marketing prompt optimization.
|
||||
|
||||
Args:
|
||||
prompt_type: Type of prompt (image, copy, video_script, etc.)
|
||||
base_prompt: Base prompt to enhance
|
||||
user_id: User ID for personalization
|
||||
context: Additional context (channel, asset_type, product_context, etc.)
|
||||
|
||||
Returns:
|
||||
Optimized marketing prompt
|
||||
"""
|
||||
context = context or {}
|
||||
channel = context.get('channel')
|
||||
asset_type = context.get('asset_type', 'hero_image')
|
||||
content_type = context.get('content_type', 'caption')
|
||||
product_context = context.get('product_context')
|
||||
|
||||
if prompt_type == 'image':
|
||||
return self.build_marketing_image_prompt(
|
||||
base_prompt, user_id, channel, asset_type, product_context
|
||||
)
|
||||
elif prompt_type in ['copy', 'caption', 'cta', 'email', 'ad_copy']:
|
||||
return self.build_marketing_copy_prompt(
|
||||
base_prompt, user_id, channel, content_type, product_context
|
||||
)
|
||||
else:
|
||||
# Default: minimal enhancement
|
||||
return f"{base_prompt}, professional quality, marketing optimized"
|
||||
|
||||
@@ -51,7 +51,7 @@ def save_asset_to_library(
|
||||
description: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
tags: Optional[list] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
asset_metadata: Optional[Dict[str, Any]] = None,
|
||||
provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
cost: Optional[float] = None,
|
||||
@@ -77,7 +77,7 @@ def save_asset_to_library(
|
||||
description: Asset description (optional)
|
||||
prompt: Generation prompt (optional)
|
||||
tags: List of tags (optional)
|
||||
metadata: Additional metadata (optional)
|
||||
asset_metadata: Additional metadata (optional)
|
||||
provider: AI provider used (optional)
|
||||
model: Model used (optional)
|
||||
cost: Generation cost (optional)
|
||||
@@ -143,7 +143,7 @@ def save_asset_to_library(
|
||||
description=description,
|
||||
prompt=prompt,
|
||||
tags=tags or [],
|
||||
metadata=metadata or {},
|
||||
asset_metadata=asset_metadata or {},
|
||||
provider=provider,
|
||||
model=model,
|
||||
cost=cost,
|
||||
|
||||
246
backend/utils/file_storage.py
Normal file
246
backend/utils/file_storage.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
File Storage Utility
|
||||
Robust file storage helper for saving generated content assets.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum filename length
|
||||
MAX_FILENAME_LENGTH = 255
|
||||
|
||||
# Allowed characters in filenames (alphanumeric, dash, underscore, dot)
|
||||
ALLOWED_FILENAME_CHARS = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.')
|
||||
|
||||
|
||||
def sanitize_filename(filename: str, max_length: int = 100) -> str:
|
||||
"""
|
||||
Sanitize filename to be filesystem-safe.
|
||||
|
||||
Args:
|
||||
filename: Original filename
|
||||
max_length: Maximum length for filename
|
||||
|
||||
Returns:
|
||||
Sanitized filename
|
||||
"""
|
||||
if not filename:
|
||||
return f"file_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Remove path separators and other dangerous characters
|
||||
sanitized = "".join(c if c in ALLOWED_FILENAME_CHARS else '_' for c in filename)
|
||||
|
||||
# Remove leading/trailing dots and spaces
|
||||
sanitized = sanitized.strip('. ')
|
||||
|
||||
# Ensure it's not empty
|
||||
if not sanitized:
|
||||
sanitized = f"file_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Truncate if too long
|
||||
if len(sanitized) > max_length:
|
||||
name, ext = os.path.splitext(sanitized)
|
||||
max_name_length = max_length - len(ext) - 1
|
||||
sanitized = name[:max_name_length] + ext
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def ensure_directory_exists(directory: Path) -> bool:
|
||||
"""
|
||||
Ensure directory exists, creating it if necessary.
|
||||
|
||||
Args:
|
||||
directory: Path to directory
|
||||
|
||||
Returns:
|
||||
True if directory exists or was created, False otherwise
|
||||
"""
|
||||
try:
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create directory {directory}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def save_file_safely(
|
||||
content: bytes,
|
||||
directory: Path,
|
||||
filename: str,
|
||||
max_file_size: int = 100 * 1024 * 1024 # 100MB default
|
||||
) -> Tuple[Optional[Path], Optional[str]]:
|
||||
"""
|
||||
Safely save file content to disk.
|
||||
|
||||
Args:
|
||||
content: File content as bytes
|
||||
directory: Directory to save file in
|
||||
filename: Filename (will be sanitized)
|
||||
max_file_size: Maximum allowed file size in bytes
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, error_message). file_path is None on error.
|
||||
"""
|
||||
try:
|
||||
# Validate file size
|
||||
if len(content) > max_file_size:
|
||||
return None, f"File size {len(content)} exceeds maximum {max_file_size}"
|
||||
|
||||
if len(content) == 0:
|
||||
return None, "File content is empty"
|
||||
|
||||
# Ensure directory exists
|
||||
if not ensure_directory_exists(directory):
|
||||
return None, f"Failed to create directory: {directory}"
|
||||
|
||||
# Sanitize filename
|
||||
safe_filename = sanitize_filename(filename)
|
||||
|
||||
# Construct full path
|
||||
file_path = directory / safe_filename
|
||||
|
||||
# Check if file already exists (unlikely with UUID, but check anyway)
|
||||
if file_path.exists():
|
||||
# Add UUID to make it unique
|
||||
name, ext = os.path.splitext(safe_filename)
|
||||
safe_filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
file_path = directory / safe_filename
|
||||
|
||||
# Write file atomically (write to temp file first, then rename)
|
||||
temp_path = file_path.with_suffix(file_path.suffix + '.tmp')
|
||||
try:
|
||||
with open(temp_path, 'wb') as f:
|
||||
f.write(content)
|
||||
|
||||
# Atomic rename
|
||||
temp_path.replace(file_path)
|
||||
|
||||
logger.info(f"Successfully saved file: {file_path} ({len(content)} bytes)")
|
||||
return file_path, None
|
||||
|
||||
except Exception as write_error:
|
||||
# Clean up temp file if it exists
|
||||
if temp_path.exists():
|
||||
try:
|
||||
temp_path.unlink()
|
||||
except:
|
||||
pass
|
||||
raise write_error
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving file: {e}", exc_info=True)
|
||||
return None, str(e)
|
||||
|
||||
|
||||
def generate_unique_filename(
|
||||
prefix: str,
|
||||
extension: str = ".png",
|
||||
include_uuid: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Generate a unique filename.
|
||||
|
||||
Args:
|
||||
prefix: Filename prefix
|
||||
extension: File extension (with or without dot)
|
||||
include_uuid: Whether to include UUID in filename
|
||||
|
||||
Returns:
|
||||
Unique filename
|
||||
"""
|
||||
if not extension.startswith('.'):
|
||||
extension = '.' + extension
|
||||
|
||||
prefix = sanitize_filename(prefix, max_length=50)
|
||||
|
||||
if include_uuid:
|
||||
unique_id = uuid.uuid4().hex[:8]
|
||||
return f"{prefix}_{unique_id}{extension}"
|
||||
else:
|
||||
return f"{prefix}{extension}"
|
||||
|
||||
|
||||
def save_text_file_safely(
|
||||
content: str,
|
||||
directory: Path,
|
||||
filename: str,
|
||||
encoding: str = 'utf-8',
|
||||
max_file_size: int = 10 * 1024 * 1024 # 10MB default for text
|
||||
) -> Tuple[Optional[Path], Optional[str]]:
|
||||
"""
|
||||
Safely save text content to disk.
|
||||
|
||||
Args:
|
||||
content: Text content as string
|
||||
directory: Directory to save file in
|
||||
filename: Filename (will be sanitized)
|
||||
encoding: Text encoding (default: utf-8)
|
||||
max_file_size: Maximum allowed file size in bytes
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, error_message). file_path is None on error.
|
||||
"""
|
||||
try:
|
||||
# Validate content
|
||||
if not content or not isinstance(content, str):
|
||||
return None, "Content must be a non-empty string"
|
||||
|
||||
# Convert to bytes for size check
|
||||
content_bytes = content.encode(encoding)
|
||||
|
||||
# Validate file size
|
||||
if len(content_bytes) > max_file_size:
|
||||
return None, f"File size {len(content_bytes)} exceeds maximum {max_file_size}"
|
||||
|
||||
# Ensure directory exists
|
||||
if not ensure_directory_exists(directory):
|
||||
return None, f"Failed to create directory: {directory}"
|
||||
|
||||
# Sanitize filename
|
||||
safe_filename = sanitize_filename(filename)
|
||||
|
||||
# Ensure .txt extension if not present
|
||||
if not safe_filename.endswith(('.txt', '.md', '.json')):
|
||||
safe_filename = os.path.splitext(safe_filename)[0] + '.txt'
|
||||
|
||||
# Construct full path
|
||||
file_path = directory / safe_filename
|
||||
|
||||
# Check if file already exists
|
||||
if file_path.exists():
|
||||
# Add UUID to make it unique
|
||||
name, ext = os.path.splitext(safe_filename)
|
||||
safe_filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
|
||||
file_path = directory / safe_filename
|
||||
|
||||
# Write file atomically (write to temp file first, then rename)
|
||||
temp_path = file_path.with_suffix(file_path.suffix + '.tmp')
|
||||
try:
|
||||
with open(temp_path, 'w', encoding=encoding) as f:
|
||||
f.write(content)
|
||||
|
||||
# Atomic rename
|
||||
temp_path.replace(file_path)
|
||||
|
||||
logger.info(f"Successfully saved text file: {file_path} ({len(content_bytes)} bytes, {len(content)} chars)")
|
||||
return file_path, None
|
||||
|
||||
except Exception as write_error:
|
||||
# Clean up temp file if it exists
|
||||
if temp_path.exists():
|
||||
try:
|
||||
temp_path.unlink()
|
||||
except:
|
||||
pass
|
||||
raise write_error
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving text file: {e}", exc_info=True)
|
||||
return None, str(e)
|
||||
|
||||
133
backend/utils/text_asset_tracker.py
Normal file
133
backend/utils/text_asset_tracker.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Text Asset Tracker Utility
|
||||
Helper utility for saving and tracking text content as files in the asset library.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from sqlalchemy.orm import Session
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from utils.file_storage import save_text_file_safely, generate_unique_filename, sanitize_filename
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def save_and_track_text_content(
|
||||
db: Session,
|
||||
user_id: str,
|
||||
content: str,
|
||||
source_module: str,
|
||||
title: str,
|
||||
description: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
tags: Optional[list] = None,
|
||||
asset_metadata: Optional[Dict[str, Any]] = None,
|
||||
base_dir: Optional[Path] = None,
|
||||
subdirectory: Optional[str] = None,
|
||||
file_extension: str = ".txt"
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
Save text content to disk and track it in the asset library.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: Clerk user ID
|
||||
content: Text content to save
|
||||
source_module: Source module name (e.g., "linkedin_writer", "facebook_writer")
|
||||
title: Title for the asset
|
||||
description: Description of the content
|
||||
prompt: Original prompt used for generation
|
||||
tags: List of tags for search/filtering
|
||||
asset_metadata: Additional metadata
|
||||
base_dir: Base directory for file storage (defaults to backend/{module}_text)
|
||||
subdirectory: Optional subdirectory (e.g., "posts", "articles")
|
||||
file_extension: File extension (.txt, .md, etc.)
|
||||
|
||||
Returns:
|
||||
Asset ID if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
if not content or not isinstance(content, str) or len(content.strip()) == 0:
|
||||
logger.warning("Empty or invalid content provided")
|
||||
return None
|
||||
|
||||
if not user_id or not isinstance(user_id, str):
|
||||
logger.error("Invalid user_id provided")
|
||||
return None
|
||||
|
||||
# Determine output directory
|
||||
if base_dir is None:
|
||||
# Default to backend/{module}_text
|
||||
base_dir = Path(__file__).parent.parent
|
||||
module_name = source_module.replace('_', '')
|
||||
output_dir = base_dir / f"{module_name}_text"
|
||||
else:
|
||||
output_dir = base_dir
|
||||
|
||||
# Add subdirectory if specified
|
||||
if subdirectory:
|
||||
output_dir = output_dir / subdirectory
|
||||
|
||||
# Generate safe filename from title
|
||||
safe_title = sanitize_filename(title, max_length=80)
|
||||
filename = generate_unique_filename(
|
||||
prefix=safe_title,
|
||||
extension=file_extension,
|
||||
include_uuid=True
|
||||
)
|
||||
|
||||
# Save text file
|
||||
file_path, save_error = save_text_file_safely(
|
||||
content=content,
|
||||
directory=output_dir,
|
||||
filename=filename,
|
||||
encoding='utf-8',
|
||||
max_file_size=10 * 1024 * 1024 # 10MB for text
|
||||
)
|
||||
|
||||
if not file_path or save_error:
|
||||
logger.error(f"Failed to save text file: {save_error}")
|
||||
return None
|
||||
|
||||
# Generate file URL
|
||||
relative_path = file_path.relative_to(base_dir)
|
||||
file_url = f"/api/text-assets/{relative_path.as_posix()}"
|
||||
|
||||
# Prepare metadata
|
||||
final_metadata = asset_metadata or {}
|
||||
final_metadata.update({
|
||||
"status": "completed",
|
||||
"character_count": len(content),
|
||||
"word_count": len(content.split())
|
||||
})
|
||||
|
||||
# Save to asset library
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="text",
|
||||
source_module=source_module,
|
||||
filename=filename,
|
||||
file_url=file_url,
|
||||
file_path=str(file_path),
|
||||
file_size=len(content.encode('utf-8')),
|
||||
mime_type="text/plain" if file_extension == ".txt" else "text/markdown",
|
||||
title=title,
|
||||
description=description or f"Generated {source_module.replace('_', ' ')} content",
|
||||
prompt=prompt,
|
||||
tags=tags or [source_module, "text"],
|
||||
asset_metadata=final_metadata
|
||||
)
|
||||
|
||||
if asset_id:
|
||||
logger.info(f"✅ Text asset saved to library: ID={asset_id}, filename={filename}")
|
||||
else:
|
||||
logger.warning(f"Asset tracking returned None for {filename}")
|
||||
|
||||
return asset_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error saving and tracking text content: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user