AI Image Studio, AI podcast Maker, AI product Marketing
This commit is contained in:
@@ -10,6 +10,9 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from sqlalchemy.orm import Session
|
||||
from services.database import get_db as get_db_dependency
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
|
||||
from models.blog_models import (
|
||||
BlogResearchRequest,
|
||||
@@ -41,6 +44,10 @@ router = APIRouter(prefix="/api/blog", tags=["AI Blog Writer"])
|
||||
|
||||
service = BlogWriterService()
|
||||
recommendation_applier = BlogSEORecommendationApplier()
|
||||
|
||||
|
||||
# Use the proper database dependency from services.database
|
||||
get_db = get_db_dependency
|
||||
# ---------------------------
|
||||
# SEO Recommendation Endpoints
|
||||
# ---------------------------
|
||||
@@ -272,10 +279,41 @@ async def rebalance_outline(outline_data: Dict[str, Any], target_words: int = 15
|
||||
|
||||
# Content Generation Endpoints
|
||||
@router.post("/section/generate", response_model=BlogSectionResponse)
|
||||
async def generate_section(request: BlogSectionRequest) -> BlogSectionResponse:
|
||||
async def generate_section(
|
||||
request: BlogSectionRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> BlogSectionResponse:
|
||||
"""Generate content for a specific section."""
|
||||
try:
|
||||
return await service.generate_section(request)
|
||||
response = await service.generate_section(request)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.markdown:
|
||||
try:
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
if user_id:
|
||||
section_heading = getattr(request, 'section_heading', getattr(request, 'heading', 'Section'))
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.markdown,
|
||||
source_module="blog_writer",
|
||||
title=f"Blog Section: {section_heading[:60]}",
|
||||
description=f"Blog section content",
|
||||
prompt=f"Section: {section_heading}\nKeywords: {getattr(request, 'keywords', [])}",
|
||||
tags=["blog", "section", "content"],
|
||||
asset_metadata={
|
||||
"section_id": getattr(request, 'section_id', None),
|
||||
"word_count": len(response.markdown.split()),
|
||||
},
|
||||
subdirectory="sections",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track blog section asset: {track_error}")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate section: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -321,13 +359,48 @@ async def start_content_generation(
|
||||
|
||||
|
||||
@router.get("/content/status/{task_id}")
|
||||
async def content_generation_status(task_id: str) -> Dict[str, Any]:
|
||||
async def content_generation_status(
|
||||
task_id: str,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Poll status for content generation task."""
|
||||
try:
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Track blog content when task completes (non-blocking)
|
||||
if status.get('status') == 'completed' and status.get('result'):
|
||||
try:
|
||||
result = status.get('result', {})
|
||||
if result.get('sections') and len(result.get('sections', [])) > 0:
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
if user_id:
|
||||
# Combine all sections into full blog content
|
||||
blog_content = f"# {result.get('title', 'Untitled Blog')}\n\n"
|
||||
for section in result.get('sections', []):
|
||||
blog_content += f"\n## {section.get('heading', 'Section')}\n\n{section.get('content', '')}\n\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=blog_content,
|
||||
source_module="blog_writer",
|
||||
title=f"Blog: {result.get('title', 'Untitled Blog')[:60]}",
|
||||
description=f"Complete blog post with {len(result.get('sections', []))} sections",
|
||||
prompt=f"Title: {result.get('title', 'Untitled')}\nSections: {len(result.get('sections', []))}",
|
||||
tags=["blog", "complete", "content"],
|
||||
asset_metadata={
|
||||
"section_count": len(result.get('sections', [])),
|
||||
"model": result.get('model'),
|
||||
},
|
||||
subdirectory="complete",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track blog content asset: {track_error}")
|
||||
|
||||
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
|
||||
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
|
||||
error_data = status.get('error_data', {}) or {}
|
||||
@@ -420,10 +493,40 @@ async def analyze_flow_advanced(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
@router.post("/section/optimize", response_model=BlogOptimizeResponse)
|
||||
async def optimize_section(request: BlogOptimizeRequest) -> BlogOptimizeResponse:
|
||||
async def optimize_section(
|
||||
request: BlogOptimizeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> BlogOptimizeResponse:
|
||||
"""Optimize a specific section for better quality and engagement."""
|
||||
try:
|
||||
return await service.optimize_section(request)
|
||||
response = await service.optimize_section(request)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.optimized:
|
||||
try:
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
if user_id:
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.optimized,
|
||||
source_module="blog_writer",
|
||||
title=f"Optimized Blog Section",
|
||||
description=f"Optimized blog section content",
|
||||
prompt=f"Original Content: {request.content[:200]}\nGoals: {request.goals}",
|
||||
tags=["blog", "section", "optimized"],
|
||||
asset_metadata={
|
||||
"optimization_goals": request.goals,
|
||||
"word_count": len(response.optimized.split()),
|
||||
},
|
||||
subdirectory="sections/optimized",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track optimized blog section asset: {track_error}")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to optimize section: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -591,13 +694,49 @@ async def start_medium_generation(
|
||||
|
||||
|
||||
@router.get("/generate/medium/status/{task_id}")
|
||||
async def medium_generation_status(task_id: str):
|
||||
async def medium_generation_status(
|
||||
task_id: str,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Poll status for medium blog generation task."""
|
||||
try:
|
||||
status = await task_manager.get_task_status(task_id)
|
||||
if status is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Track blog content when task completes (non-blocking)
|
||||
if status.get('status') == 'completed' and status.get('result'):
|
||||
try:
|
||||
result = status.get('result', {})
|
||||
if result.get('sections') and len(result.get('sections', [])) > 0:
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
if user_id:
|
||||
# Combine all sections into full blog content
|
||||
blog_content = f"# {result.get('title', 'Untitled Blog')}\n\n"
|
||||
for section in result.get('sections', []):
|
||||
blog_content += f"\n## {section.get('heading', 'Section')}\n\n{section.get('content', '')}\n\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=blog_content,
|
||||
source_module="blog_writer",
|
||||
title=f"Medium Blog: {result.get('title', 'Untitled Blog')[:60]}",
|
||||
description=f"Medium-length blog post with {len(result.get('sections', []))} sections",
|
||||
prompt=f"Title: {result.get('title', 'Untitled')}\nSections: {len(result.get('sections', []))}",
|
||||
tags=["blog", "medium", "complete"],
|
||||
asset_metadata={
|
||||
"section_count": len(result.get('sections', [])),
|
||||
"model": result.get('model'),
|
||||
"generation_time_ms": result.get('generation_time_ms'),
|
||||
},
|
||||
subdirectory="medium",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track medium blog asset: {track_error}")
|
||||
|
||||
# If task failed with subscription error, return HTTP error so frontend interceptor can catch it
|
||||
if status.get('status') == 'failed' and status.get('error_status') in [429, 402]:
|
||||
error_data = status.get('error_data', {}) or {}
|
||||
@@ -677,7 +816,8 @@ async def rewrite_status(task_id: str):
|
||||
@router.post("/titles/generate-seo")
|
||||
async def generate_seo_titles(
|
||||
request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate 5 SEO-optimized blog titles using research and outline data."""
|
||||
try:
|
||||
@@ -722,6 +862,30 @@ async def generate_seo_titles(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Save and track titles (non-blocking)
|
||||
if titles and len(titles) > 0:
|
||||
try:
|
||||
titles_content = "# SEO Blog Titles\n\n" + "\n".join([f"{i+1}. {title}" for i, title in enumerate(titles)])
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=titles_content,
|
||||
source_module="blog_writer",
|
||||
title=f"SEO Blog Titles: {primary_keywords[0] if primary_keywords else 'Blog'}",
|
||||
description=f"SEO-optimized blog title suggestions",
|
||||
prompt=f"Primary Keywords: {primary_keywords}\nSearch Intent: {search_intent}\nWord Count: {word_count}",
|
||||
tags=["blog", "titles", "seo"],
|
||||
asset_metadata={
|
||||
"title_count": len(titles),
|
||||
"primary_keywords": primary_keywords,
|
||||
"search_intent": search_intent,
|
||||
},
|
||||
subdirectory="titles",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track SEO titles asset: {track_error}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"titles": titles
|
||||
@@ -736,7 +900,8 @@ async def generate_seo_titles(
|
||||
@router.post("/introductions/generate")
|
||||
async def generate_introductions(
|
||||
request: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate 3 varied blog introductions using research, outline, and content."""
|
||||
try:
|
||||
@@ -781,6 +946,33 @@ async def generate_introductions(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Save and track introductions (non-blocking)
|
||||
if introductions and len(introductions) > 0:
|
||||
try:
|
||||
intro_content = f"# Blog Introductions for: {blog_title}\n\n"
|
||||
for i, intro in enumerate(introductions, 1):
|
||||
intro_content += f"## Introduction {i}\n\n{intro}\n\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=intro_content,
|
||||
source_module="blog_writer",
|
||||
title=f"Blog Introductions: {blog_title[:60]}",
|
||||
description=f"Blog introduction variations",
|
||||
prompt=f"Blog Title: {blog_title}\nPrimary Keywords: {primary_keywords}\nSearch Intent: {search_intent}",
|
||||
tags=["blog", "introductions"],
|
||||
asset_metadata={
|
||||
"introduction_count": len(introductions),
|
||||
"blog_title": blog_title,
|
||||
"search_intent": search_intent,
|
||||
},
|
||||
subdirectory="introductions",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track blog introductions asset: {track_error}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"introductions": introductions
|
||||
|
||||
@@ -21,6 +21,7 @@ from models.blog_models import (
|
||||
)
|
||||
from services.blog_writer.blog_service import BlogWriterService
|
||||
from services.blog_writer.database_task_manager import DatabaseTaskManager
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -281,6 +282,9 @@ class TaskManager:
|
||||
self.task_storage[task_id]["status"] = "completed"
|
||||
self.task_storage[task_id]["result"] = result.dict()
|
||||
await self.update_progress(task_id, f"✅ Generated {len(result.sections)} sections successfully.")
|
||||
|
||||
# Note: Blog content tracking is handled in the status endpoint
|
||||
# to ensure we have proper database session and user context
|
||||
|
||||
except HTTPException as http_error:
|
||||
# Handle HTTPException (e.g., 429 subscription limit) - preserve error details for frontend
|
||||
|
||||
@@ -32,7 +32,7 @@ class AssetResponse(BaseModel):
|
||||
description: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
tags: List[str] = []
|
||||
metadata: Dict[str, Any] = {}
|
||||
asset_metadata: Dict[str, Any] = {}
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
cost: float = 0.0
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""FastAPI router for Facebook Writer endpoints."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import *
|
||||
from ..services import *
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db as get_db_dependency
|
||||
from utils.text_asset_tracker import save_and_track_text_content
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -115,9 +119,17 @@ async def get_available_tools():
|
||||
return {"tools": tools, "total_count": len(tools)}
|
||||
|
||||
|
||||
# Use the proper database dependency from services.database
|
||||
get_db = get_db_dependency
|
||||
|
||||
|
||||
# Content Creation Endpoints
|
||||
@router.post("/post/generate", response_model=FacebookPostResponse)
|
||||
async def generate_facebook_post(request: FacebookPostRequest):
|
||||
async def generate_facebook_post(
|
||||
request: FacebookPostRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook post with engagement optimization."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook post for business: {request.business_type}")
|
||||
@@ -126,6 +138,37 @@ async def generate_facebook_post(request: FacebookPostRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.content:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = response.content
|
||||
if response.analytics:
|
||||
text_content += f"\n\n## Analytics\nExpected Reach: {response.analytics.expected_reach}\nExpected Engagement: {response.analytics.expected_engagement}\nBest Time to Post: {response.analytics.best_time_to_post}"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Post: {request.business_type[:60]}",
|
||||
description=f"Facebook post for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nTarget Audience: {request.target_audience}\nGoal: {request.post_goal.value if hasattr(request.post_goal, 'value') else request.post_goal}\nTone: {request.post_tone.value if hasattr(request.post_tone, 'value') else request.post_tone}",
|
||||
tags=["facebook", "post", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"post_goal": request.post_goal.value if hasattr(request.post_goal, 'value') else str(request.post_goal),
|
||||
"post_tone": request.post_tone.value if hasattr(request.post_tone, 'value') else str(request.post_tone),
|
||||
"media_type": request.media_type.value if hasattr(request.media_type, 'value') else str(request.media_type)
|
||||
},
|
||||
subdirectory="posts"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook post asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -134,7 +177,11 @@ async def generate_facebook_post(request: FacebookPostRequest):
|
||||
|
||||
|
||||
@router.post("/story/generate", response_model=FacebookStoryResponse)
|
||||
async def generate_facebook_story(request: FacebookStoryRequest):
|
||||
async def generate_facebook_story(
|
||||
request: FacebookStoryRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook story with visual suggestions."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook story for business: {request.business_type}")
|
||||
@@ -143,6 +190,31 @@ async def generate_facebook_story(request: FacebookStoryRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.content:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Story: {request.business_type[:60]}",
|
||||
description=f"Facebook story for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nStory Type: {request.story_type.value if hasattr(request.story_type, 'value') else request.story_type}",
|
||||
tags=["facebook", "story", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"story_type": request.story_type.value if hasattr(request.story_type, 'value') else str(request.story_type)
|
||||
},
|
||||
subdirectory="stories"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook story asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -151,7 +223,11 @@ async def generate_facebook_story(request: FacebookStoryRequest):
|
||||
|
||||
|
||||
@router.post("/reel/generate", response_model=FacebookReelResponse)
|
||||
async def generate_facebook_reel(request: FacebookReelRequest):
|
||||
async def generate_facebook_reel(
|
||||
request: FacebookReelRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook reel script with music suggestions."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook reel for business: {request.business_type}")
|
||||
@@ -160,6 +236,42 @@ async def generate_facebook_reel(request: FacebookReelRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.script:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = f"# Facebook Reel Script\n\n## Script\n{response.script}\n"
|
||||
if response.scene_breakdown:
|
||||
text_content += f"\n## Scene Breakdown\n" + "\n".join([f"{i+1}. {scene}" for i, scene in enumerate(response.scene_breakdown)]) + "\n"
|
||||
if response.music_suggestions:
|
||||
text_content += f"\n## Music Suggestions\n" + "\n".join(response.music_suggestions) + "\n"
|
||||
if response.hashtag_suggestions:
|
||||
text_content += f"\n## Hashtag Suggestions\n" + " ".join([f"#{tag}" for tag in response.hashtag_suggestions]) + "\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Reel: {request.topic[:60]}",
|
||||
description=f"Facebook reel script for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nTopic: {request.topic}\nReel Type: {request.reel_type.value if hasattr(request.reel_type, 'value') else request.reel_type}\nLength: {request.reel_length.value if hasattr(request.reel_length, 'value') else request.reel_length}",
|
||||
tags=["facebook", "reel", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"reel_type": request.reel_type.value if hasattr(request.reel_type, 'value') else str(request.reel_type),
|
||||
"reel_length": request.reel_length.value if hasattr(request.reel_length, 'value') else str(request.reel_length),
|
||||
"reel_style": request.reel_style.value if hasattr(request.reel_style, 'value') else str(request.reel_style)
|
||||
},
|
||||
subdirectory="reels",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook reel asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -168,7 +280,11 @@ async def generate_facebook_reel(request: FacebookReelRequest):
|
||||
|
||||
|
||||
@router.post("/carousel/generate", response_model=FacebookCarouselResponse)
|
||||
async def generate_facebook_carousel(request: FacebookCarouselRequest):
|
||||
async def generate_facebook_carousel(
|
||||
request: FacebookCarouselRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook carousel post with multiple slides."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook carousel for business: {request.business_type}")
|
||||
@@ -177,6 +293,44 @@ async def generate_facebook_carousel(request: FacebookCarouselRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.main_caption and response.slides:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = f"# Facebook Carousel\n\n## Main Caption\n{response.main_caption}\n\n"
|
||||
text_content += "## Slides\n"
|
||||
for i, slide in enumerate(response.slides, 1):
|
||||
text_content += f"\n### Slide {i}: {slide.title}\n{slide.content}\n"
|
||||
if slide.image_description:
|
||||
text_content += f"Image Description: {slide.image_description}\n"
|
||||
|
||||
if response.hashtag_suggestions:
|
||||
text_content += f"\n## Hashtag Suggestions\n" + " ".join([f"#{tag}" for tag in response.hashtag_suggestions]) + "\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Carousel: {request.topic[:60]}",
|
||||
description=f"Facebook carousel for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nTopic: {request.topic}\nCarousel Type: {request.carousel_type.value if hasattr(request.carousel_type, 'value') else request.carousel_type}\nSlides: {request.num_slides}",
|
||||
tags=["facebook", "carousel", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"carousel_type": request.carousel_type.value if hasattr(request.carousel_type, 'value') else str(request.carousel_type),
|
||||
"num_slides": request.num_slides,
|
||||
"has_cta": request.include_cta
|
||||
},
|
||||
subdirectory="carousels",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook carousel asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -186,7 +340,11 @@ async def generate_facebook_carousel(request: FacebookCarouselRequest):
|
||||
|
||||
# Business Tools Endpoints
|
||||
@router.post("/event/generate", response_model=FacebookEventResponse)
|
||||
async def generate_facebook_event(request: FacebookEventRequest):
|
||||
async def generate_facebook_event(
|
||||
request: FacebookEventRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook event description."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook event: {request.event_name}")
|
||||
@@ -195,6 +353,36 @@ async def generate_facebook_event(request: FacebookEventRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.description:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = f"# Facebook Event: {request.event_name}\n\n## Description\n{response.description}\n"
|
||||
if hasattr(response, 'details') and response.details:
|
||||
text_content += f"\n## Details\n{response.details}\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Event: {request.event_name[:60]}",
|
||||
description=f"Facebook event description for {request.event_name}",
|
||||
prompt=f"Event Name: {request.event_name}\nEvent Type: {getattr(request, 'event_type', 'N/A')}\nDate: {getattr(request, 'event_date', 'N/A')}",
|
||||
tags=["facebook", "event", request.event_name.lower().replace(' ', '_')[:20]],
|
||||
asset_metadata={
|
||||
"event_name": request.event_name,
|
||||
"event_type": getattr(request, 'event_type', None)
|
||||
},
|
||||
subdirectory="events"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook event asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -203,7 +391,11 @@ async def generate_facebook_event(request: FacebookEventRequest):
|
||||
|
||||
|
||||
@router.post("/group-post/generate", response_model=FacebookGroupPostResponse)
|
||||
async def generate_facebook_group_post(request: FacebookGroupPostRequest):
|
||||
async def generate_facebook_group_post(
|
||||
request: FacebookGroupPostRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook group post following community guidelines."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook group post for: {request.group_name}")
|
||||
@@ -212,6 +404,32 @@ async def generate_facebook_group_post(request: FacebookGroupPostRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.content:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Group Post: {request.group_name[:60]}",
|
||||
description=f"Facebook group post for {request.group_name}",
|
||||
prompt=f"Group Name: {request.group_name}\nTopic: {getattr(request, 'topic', 'N/A')}",
|
||||
tags=["facebook", "group_post", request.group_name.lower().replace(' ', '_')[:20]],
|
||||
asset_metadata={
|
||||
"group_name": request.group_name,
|
||||
"group_type": getattr(request, 'group_type', None)
|
||||
},
|
||||
subdirectory="group_posts"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook group post asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -220,7 +438,11 @@ async def generate_facebook_group_post(request: FacebookGroupPostRequest):
|
||||
|
||||
|
||||
@router.post("/page-about/generate", response_model=FacebookPageAboutResponse)
|
||||
async def generate_facebook_page_about(request: FacebookPageAboutRequest):
|
||||
async def generate_facebook_page_about(
|
||||
request: FacebookPageAboutRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate a Facebook page about section."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook page about for: {request.business_name}")
|
||||
@@ -229,6 +451,32 @@ async def generate_facebook_page_about(request: FacebookPageAboutRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.about_section:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=response.about_section,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Page About: {request.business_name[:60]}",
|
||||
description=f"Facebook page about section for {request.business_name}",
|
||||
prompt=f"Business Name: {request.business_name}\nBusiness Type: {getattr(request, 'business_type', 'N/A')}",
|
||||
tags=["facebook", "page_about", request.business_name.lower().replace(' ', '_')[:20]],
|
||||
asset_metadata={
|
||||
"business_name": request.business_name,
|
||||
"business_type": getattr(request, 'business_type', None)
|
||||
},
|
||||
subdirectory="page_about"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook page about asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -238,7 +486,11 @@ async def generate_facebook_page_about(request: FacebookPageAboutRequest):
|
||||
|
||||
# Marketing Tools Endpoints
|
||||
@router.post("/ad-copy/generate", response_model=FacebookAdCopyResponse)
|
||||
async def generate_facebook_ad_copy(request: FacebookAdCopyRequest):
|
||||
async def generate_facebook_ad_copy(
|
||||
request: FacebookAdCopyRequest,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Generate Facebook ad copy with targeting suggestions."""
|
||||
try:
|
||||
logger.info(f"Generating Facebook ad copy for: {request.business_type}")
|
||||
@@ -247,6 +499,41 @@ async def generate_facebook_ad_copy(request: FacebookAdCopyRequest):
|
||||
if not response.success:
|
||||
raise HTTPException(status_code=400, detail=response.error)
|
||||
|
||||
# Save and track text content (non-blocking)
|
||||
if response.ad_copy:
|
||||
try:
|
||||
user_id = None
|
||||
if current_user:
|
||||
user_id = str(current_user.get('id', '') or current_user.get('sub', ''))
|
||||
|
||||
if user_id:
|
||||
text_content = f"# Facebook Ad Copy\n\n## Ad Copy\n{response.ad_copy}\n"
|
||||
if hasattr(response, 'headline') and response.headline:
|
||||
text_content += f"\n## Headline\n{response.headline}\n"
|
||||
if hasattr(response, 'description') and response.description:
|
||||
text_content += f"\n## Description\n{response.description}\n"
|
||||
if hasattr(response, 'targeting_suggestions') and response.targeting_suggestions:
|
||||
text_content += f"\n## Targeting Suggestions\n" + "\n".join(response.targeting_suggestions) + "\n"
|
||||
|
||||
save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=text_content,
|
||||
source_module="facebook_writer",
|
||||
title=f"Facebook Ad Copy: {request.business_type[:60]}",
|
||||
description=f"Facebook ad copy for {request.business_type}",
|
||||
prompt=f"Business Type: {request.business_type}\nAd Objective: {getattr(request, 'ad_objective', 'N/A')}\nTarget Audience: {getattr(request, 'target_audience', 'N/A')}",
|
||||
tags=["facebook", "ad_copy", request.business_type.lower().replace(' ', '_')],
|
||||
asset_metadata={
|
||||
"ad_objective": getattr(request, 'ad_objective', None),
|
||||
"budget": getattr(request, 'budget', None)
|
||||
},
|
||||
subdirectory="ad_copy",
|
||||
file_extension=".md"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"Failed to track Facebook ad copy asset: {track_error}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,10 +2,14 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
@@ -16,6 +20,8 @@ from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import APIProvider, UsageSummary
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from utils.file_storage import save_file_safely, generate_unique_filename, sanitize_filename
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/images", tags=["images"])
|
||||
@@ -37,6 +43,7 @@ class ImageGenerateRequest(BaseModel):
|
||||
class ImageGenerateResponse(BaseModel):
|
||||
success: bool = True
|
||||
image_base64: str
|
||||
image_url: Optional[str] = None # URL to saved image file
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
@@ -47,7 +54,8 @@ class ImageGenerateResponse(BaseModel):
|
||||
@router.post("/generate", response_model=ImageGenerateResponse)
|
||||
def generate(
|
||||
req: ImageGenerateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> ImageGenerateResponse:
|
||||
"""Generate image with subscription checking."""
|
||||
try:
|
||||
@@ -80,6 +88,78 @@ def generate(
|
||||
)
|
||||
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
|
||||
# Save image to disk and track in asset library
|
||||
image_url = None
|
||||
image_filename = None
|
||||
image_path = None
|
||||
|
||||
try:
|
||||
# Create output directory for image studio images
|
||||
base_dir = Path(__file__).parent.parent
|
||||
output_dir = base_dir / "image_studio_images"
|
||||
|
||||
# Generate safe filename from prompt
|
||||
clean_prompt = sanitize_filename(req.prompt[:50], max_length=50)
|
||||
image_filename = generate_unique_filename(
|
||||
prefix=f"img_{clean_prompt}",
|
||||
extension=".png",
|
||||
include_uuid=True
|
||||
)
|
||||
|
||||
# Save file safely
|
||||
image_path, save_error = save_file_safely(
|
||||
content=result.image_bytes,
|
||||
directory=output_dir,
|
||||
filename=image_filename,
|
||||
max_file_size=50 * 1024 * 1024 # 50MB for images
|
||||
)
|
||||
|
||||
if image_path and not save_error:
|
||||
# Generate file URL (will be served via API endpoint)
|
||||
image_url = f"/api/images/image-studio/images/{image_path.name}"
|
||||
|
||||
logger.info(f"[images.generate] Saved image to: {image_path} ({len(result.image_bytes)} bytes)")
|
||||
|
||||
# Save to asset library (non-blocking)
|
||||
try:
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="image_studio",
|
||||
filename=image_path.name,
|
||||
file_url=image_url,
|
||||
file_path=str(image_path),
|
||||
file_size=len(result.image_bytes),
|
||||
mime_type="image/png",
|
||||
title=req.prompt[:100] if len(req.prompt) <= 100 else req.prompt[:97] + "...",
|
||||
description=f"Generated image: {req.prompt[:200]}" if len(req.prompt) > 200 else req.prompt,
|
||||
prompt=req.prompt,
|
||||
tags=["image_studio", "generated", result.provider] if result.provider else ["image_studio", "generated"],
|
||||
provider=result.provider,
|
||||
model=result.model,
|
||||
asset_metadata={
|
||||
"width": result.width,
|
||||
"height": result.height,
|
||||
"seed": result.seed,
|
||||
"status": "completed",
|
||||
"negative_prompt": req.negative_prompt
|
||||
}
|
||||
)
|
||||
if asset_id:
|
||||
logger.info(f"[images.generate] ✅ Asset saved to library: ID={asset_id}, filename={image_path.name}")
|
||||
else:
|
||||
logger.warning(f"[images.generate] Asset tracking returned None (may have failed silently)")
|
||||
except Exception as asset_error:
|
||||
logger.error(f"[images.generate] Failed to save asset to library: {asset_error}", exc_info=True)
|
||||
# Don't fail the request if asset tracking fails
|
||||
else:
|
||||
logger.warning(f"[images.generate] Failed to save image to disk: {save_error}")
|
||||
# Continue without failing the request - base64 is still available
|
||||
except Exception as save_error:
|
||||
logger.error(f"[images.generate] Unexpected error saving image: {save_error}", exc_info=True)
|
||||
# Continue without failing the request
|
||||
|
||||
# TRACK USAGE after successful image generation
|
||||
if result:
|
||||
logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
|
||||
@@ -168,6 +248,7 @@ def generate(
|
||||
|
||||
return ImageGenerateResponse(
|
||||
image_base64=image_b64,
|
||||
image_url=image_url,
|
||||
width=result.width,
|
||||
height=result.height,
|
||||
provider=result.provider,
|
||||
@@ -226,6 +307,7 @@ class ImageEditRequest(BaseModel):
|
||||
class ImageEditResponse(BaseModel):
|
||||
success: bool = True
|
||||
image_base64: str
|
||||
image_url: Optional[str] = None # URL to saved edited image file
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
@@ -358,7 +440,8 @@ def suggest_prompts(
|
||||
@router.post("/edit", response_model=ImageEditResponse)
|
||||
def edit(
|
||||
req: ImageEditRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> ImageEditResponse:
|
||||
"""Edit image with subscription checking."""
|
||||
try:
|
||||
@@ -391,6 +474,78 @@ def edit(
|
||||
)
|
||||
edited_image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
|
||||
# Save edited image to disk and track in asset library
|
||||
image_url = None
|
||||
image_filename = None
|
||||
image_path = None
|
||||
|
||||
try:
|
||||
# Create output directory for image studio edited images
|
||||
base_dir = Path(__file__).parent.parent
|
||||
output_dir = base_dir / "image_studio_images" / "edited"
|
||||
|
||||
# Generate safe filename from prompt
|
||||
clean_prompt = sanitize_filename(req.prompt[:50], max_length=50)
|
||||
image_filename = generate_unique_filename(
|
||||
prefix=f"edited_{clean_prompt}",
|
||||
extension=".png",
|
||||
include_uuid=True
|
||||
)
|
||||
|
||||
# Save file safely
|
||||
image_path, save_error = save_file_safely(
|
||||
content=result.image_bytes,
|
||||
directory=output_dir,
|
||||
filename=image_filename,
|
||||
max_file_size=50 * 1024 * 1024 # 50MB for images
|
||||
)
|
||||
|
||||
if image_path and not save_error:
|
||||
# Generate file URL
|
||||
image_url = f"/api/images/image-studio/images/edited/{image_path.name}"
|
||||
|
||||
logger.info(f"[images.edit] Saved edited image to: {image_path} ({len(result.image_bytes)} bytes)")
|
||||
|
||||
# Save to asset library (non-blocking)
|
||||
try:
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="image_studio",
|
||||
filename=image_path.name,
|
||||
file_url=image_url,
|
||||
file_path=str(image_path),
|
||||
file_size=len(result.image_bytes),
|
||||
mime_type="image/png",
|
||||
title=f"Edited: {req.prompt[:100]}" if len(req.prompt) <= 100 else f"Edited: {req.prompt[:97]}...",
|
||||
description=f"Edited image with prompt: {req.prompt[:200]}" if len(req.prompt) > 200 else f"Edited image with prompt: {req.prompt}",
|
||||
prompt=req.prompt,
|
||||
tags=["image_studio", "edited", result.provider] if result.provider else ["image_studio", "edited"],
|
||||
provider=result.provider,
|
||||
model=result.model,
|
||||
asset_metadata={
|
||||
"width": result.width,
|
||||
"height": result.height,
|
||||
"seed": result.seed,
|
||||
"status": "completed",
|
||||
"operation": "edit"
|
||||
}
|
||||
)
|
||||
if asset_id:
|
||||
logger.info(f"[images.edit] ✅ Asset saved to library: ID={asset_id}, filename={image_path.name}")
|
||||
else:
|
||||
logger.warning(f"[images.edit] Asset tracking returned None (may have failed silently)")
|
||||
except Exception as asset_error:
|
||||
logger.error(f"[images.edit] Failed to save asset to library: {asset_error}", exc_info=True)
|
||||
# Don't fail the request if asset tracking fails
|
||||
else:
|
||||
logger.warning(f"[images.edit] Failed to save edited image to disk: {save_error}")
|
||||
# Continue without failing the request - base64 is still available
|
||||
except Exception as save_error:
|
||||
logger.error(f"[images.edit] Unexpected error saving edited image: {save_error}", exc_info=True)
|
||||
# Continue without failing the request
|
||||
|
||||
# TRACK USAGE after successful image editing
|
||||
if result:
|
||||
logger.info(f"[images.edit] ✅ Image editing successful, tracking usage for user {user_id}")
|
||||
@@ -478,6 +633,7 @@ def edit(
|
||||
|
||||
return ImageEditResponse(
|
||||
image_base64=edited_image_b64,
|
||||
image_url=image_url,
|
||||
width=result.width,
|
||||
height=result.height,
|
||||
provider=result.provider,
|
||||
@@ -494,3 +650,55 @@ def edit(
|
||||
detail="Image editing service is temporarily unavailable or the connection was reset. Please try again."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Image Serving Endpoints
|
||||
# ---------------------------
|
||||
|
||||
@router.get("/image-studio/images/{image_filename:path}")
|
||||
async def serve_image_studio_image(
|
||||
image_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Serve a generated or edited image from Image Studio."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Determine if it's an edited image or regular image
|
||||
base_dir = Path(__file__).parent.parent
|
||||
image_studio_dir = (base_dir / "image_studio_images").resolve()
|
||||
|
||||
if image_filename.startswith("edited/"):
|
||||
# Remove "edited/" prefix and serve from edited directory
|
||||
actual_filename = image_filename.replace("edited/", "", 1)
|
||||
image_path = (image_studio_dir / "edited" / actual_filename).resolve()
|
||||
base_subdir = (image_studio_dir / "edited").resolve()
|
||||
else:
|
||||
image_path = (image_studio_dir / image_filename).resolve()
|
||||
base_subdir = image_studio_dir
|
||||
|
||||
# Security: Prevent directory traversal attacks
|
||||
# Ensure the resolved path is within the intended directory
|
||||
try:
|
||||
image_path.relative_to(base_subdir)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied: Invalid image path"
|
||||
)
|
||||
|
||||
if not image_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
media_type="image/png",
|
||||
filename=image_path.name
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[images] Failed to serve image: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,18 +4,20 @@ Each module focuses on a related set of routes to keep the primary
|
||||
`router.py` concise and easier to maintain.
|
||||
"""
|
||||
|
||||
from . import story_setup
|
||||
from . import story_content
|
||||
from . import story_tasks
|
||||
from . import media_generation
|
||||
from . import video_generation
|
||||
from . import cache_routes
|
||||
from . import media_generation
|
||||
from . import scene_animation
|
||||
from . import story_content
|
||||
from . import story_setup
|
||||
from . import story_tasks
|
||||
from . import video_generation
|
||||
|
||||
__all__ = [
|
||||
"story_setup",
|
||||
"story_content",
|
||||
"story_tasks",
|
||||
"media_generation",
|
||||
"video_generation",
|
||||
"cache_routes",
|
||||
"media_generation",
|
||||
"scene_animation",
|
||||
"story_content",
|
||||
"story_setup",
|
||||
"story_tasks",
|
||||
"video_generation",
|
||||
]
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from models.story_models import (
|
||||
@@ -18,8 +20,10 @@ from models.story_models import (
|
||||
GenerateAIAudioResponse,
|
||||
StoryScene,
|
||||
)
|
||||
from services.database import get_db
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
from ..utils.auth import require_authenticated_user
|
||||
from ..utils.media_utils import resolve_media_file
|
||||
@@ -34,6 +38,7 @@ audio_service = StoryAudioGenerationService()
|
||||
async def generate_scene_images(
|
||||
request: StoryImageGenerationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StoryImageGenerationResponse:
|
||||
"""Generate images for story scenes."""
|
||||
try:
|
||||
@@ -70,6 +75,37 @@ async def generate_scene_images(
|
||||
for result in image_results
|
||||
]
|
||||
|
||||
# Save assets to library
|
||||
for result in image_results:
|
||||
if not result.get("error") and result.get("image_url"):
|
||||
try:
|
||||
scene_number = result.get("scene_number", 0)
|
||||
# Safely get prompt from scenes_data with bounds checking
|
||||
prompt = None
|
||||
if scene_number > 0 and scene_number <= len(scenes_data):
|
||||
prompt = scenes_data[scene_number - 1].get("image_prompt")
|
||||
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="story_writer",
|
||||
filename=result.get("image_filename", ""),
|
||||
file_url=result.get("image_url", ""),
|
||||
file_path=result.get("image_path"),
|
||||
file_size=result.get("file_size"),
|
||||
mime_type="image/png",
|
||||
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
|
||||
description=f"Story scene image for scene {scene_number}",
|
||||
prompt=prompt,
|
||||
tags=["story_writer", "scene", f"scene_{scene_number}"],
|
||||
provider=result.get("provider"),
|
||||
model=result.get("model"),
|
||||
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save image asset to library: {e}")
|
||||
|
||||
return StoryImageGenerationResponse(images=image_models, success=True)
|
||||
|
||||
except HTTPException:
|
||||
@@ -163,6 +199,7 @@ async def serve_scene_image(
|
||||
async def generate_scene_audio(
|
||||
request: StoryAudioGenerationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StoryAudioGenerationResponse:
|
||||
"""Generate audio narration for story scenes."""
|
||||
try:
|
||||
@@ -185,18 +222,52 @@ async def generate_scene_audio(
|
||||
|
||||
audio_models: List[StoryAudioResult] = []
|
||||
for result in audio_results:
|
||||
audio_url = result.get("audio_url") or ""
|
||||
audio_filename = result.get("audio_filename") or ""
|
||||
|
||||
audio_models.append(
|
||||
StoryAudioResult(
|
||||
scene_number=result.get("scene_number", 0),
|
||||
scene_title=result.get("scene_title", "Untitled"),
|
||||
audio_filename=result.get("audio_filename") or "",
|
||||
audio_url=result.get("audio_url") or "",
|
||||
audio_filename=audio_filename,
|
||||
audio_url=audio_url,
|
||||
provider=result.get("provider", "unknown"),
|
||||
file_size=result.get("file_size", 0),
|
||||
error=result.get("error"),
|
||||
)
|
||||
)
|
||||
|
||||
# Save assets to library
|
||||
if not result.get("error") and audio_url:
|
||||
try:
|
||||
scene_number = result.get("scene_number", 0)
|
||||
# Safely get prompt from scenes_data with bounds checking
|
||||
prompt = None
|
||||
if scene_number > 0 and scene_number <= len(scenes_data):
|
||||
prompt = scenes_data[scene_number - 1].get("text")
|
||||
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="audio",
|
||||
source_module="story_writer",
|
||||
filename=audio_filename,
|
||||
file_url=audio_url,
|
||||
file_path=result.get("audio_path"),
|
||||
file_size=result.get("file_size"),
|
||||
mime_type="audio/mpeg",
|
||||
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
|
||||
description=f"Story scene audio narration for scene {scene_number}",
|
||||
prompt=prompt,
|
||||
tags=["story_writer", "audio", "narration", f"scene_{scene_number}"],
|
||||
provider=result.get("provider"),
|
||||
model=result.get("model"),
|
||||
cost=result.get("cost"),
|
||||
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save audio asset to library: {e}")
|
||||
|
||||
return StoryAudioGenerationResponse(audio_files=audio_models, success=True)
|
||||
|
||||
except HTTPException:
|
||||
@@ -287,3 +358,59 @@ async def serve_scene_audio(
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
class PromptOptimizeRequest(BaseModel):
|
||||
text: str = Field(..., description="The prompt text to optimize")
|
||||
mode: Optional[str] = Field(default="image", pattern="^(image|video)$", description="Optimization mode: 'image' or 'video'")
|
||||
style: Optional[str] = Field(
|
||||
default="default",
|
||||
pattern="^(default|artistic|photographic|technical|anime|realistic)$",
|
||||
description="Style: 'default', 'artistic', 'photographic', 'technical', 'anime', or 'realistic'"
|
||||
)
|
||||
image: Optional[str] = Field(None, description="Base64-encoded image for context (optional)")
|
||||
|
||||
|
||||
class PromptOptimizeResponse(BaseModel):
|
||||
optimized_prompt: str
|
||||
success: bool
|
||||
|
||||
|
||||
@router.post("/optimize-prompt", response_model=PromptOptimizeResponse)
|
||||
async def optimize_prompt(
|
||||
request: PromptOptimizeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> PromptOptimizeResponse:
|
||||
"""Optimize an image prompt using WaveSpeed prompt optimizer."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.text or not request.text.strip():
|
||||
raise HTTPException(status_code=400, detail="Prompt text is required")
|
||||
|
||||
logger.info(f"[StoryWriter] Optimizing prompt for user {user_id} (mode={request.mode}, style={request.style})")
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
|
||||
client = WaveSpeedClient()
|
||||
optimized_prompt = client.optimize_prompt(
|
||||
text=request.text.strip(),
|
||||
mode=request.mode or "image",
|
||||
style=request.style or "default",
|
||||
image=request.image, # Optional base64 image
|
||||
enable_sync_mode=True,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] Prompt optimized successfully for user {user_id}")
|
||||
|
||||
return PromptOptimizeResponse(
|
||||
optimized_prompt=optimized_prompt,
|
||||
success=True
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to optimize prompt: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
||||
484
backend/api/story_writer/routes/scene_animation.py
Normal file
484
backend/api/story_writer/routes/scene_animation.py
Normal file
@@ -0,0 +1,484 @@
|
||||
"""
|
||||
Scene Animation Routes
|
||||
|
||||
Handles scene animation endpoints using WaveSpeed Kling and InfiniteTalk.
|
||||
"""
|
||||
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
AnimateSceneRequest,
|
||||
AnimateSceneResponse,
|
||||
AnimateSceneVoiceoverRequest,
|
||||
ResumeSceneAnimationRequest,
|
||||
)
|
||||
from services.database import get_db
|
||||
from services.llm_providers.main_video_generation import track_video_usage
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_scene_animation_operation
|
||||
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
|
||||
from services.wavespeed.kling_animation import animate_scene_image, resume_scene_animation
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from ..task_manager import task_manager
|
||||
from ..utils.auth import require_authenticated_user
|
||||
from ..utils.media_utils import load_story_audio_bytes, load_story_image_bytes
|
||||
|
||||
router = APIRouter()
|
||||
scene_logger = get_service_logger("api.story_writer.scene_animation")
|
||||
AI_VIDEO_SUBDIR = Path("AI_Videos")
|
||||
|
||||
|
||||
def _build_authenticated_media_url(request: Request, path: str) -> str:
|
||||
"""Append the caller's auth token to a media URL so <video>/<img> tags can access it."""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
token: Optional[str] = None
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header.replace("Bearer ", "").strip()
|
||||
elif "token" in request.query_params:
|
||||
token = request.query_params["token"]
|
||||
|
||||
if token:
|
||||
separator = "&" if "?" in path else "?"
|
||||
path = f"{path}{separator}token={quote(token)}"
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def _guess_mime_from_url(url: str, fallback: str) -> str:
|
||||
"""Guess MIME type from URL."""
|
||||
if not url:
|
||||
return fallback
|
||||
mime, _ = mimetypes.guess_type(url)
|
||||
return mime or fallback
|
||||
|
||||
|
||||
@router.post("/animate-scene-preview", response_model=AnimateSceneResponse)
|
||||
async def animate_scene_preview(
|
||||
request_obj: Request,
|
||||
request: AnimateSceneRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> AnimateSceneResponse:
|
||||
"""
|
||||
Animate a single scene image using WaveSpeed Kling v2.5 Turbo Std.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
duration = request.duration or 5
|
||||
if duration not in (5, 10):
|
||||
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds.")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] User=%s scene=%s duration=%s image_url=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
duration,
|
||||
request.image_url,
|
||||
)
|
||||
|
||||
image_bytes = load_story_image_bytes(request.image_url)
|
||||
if not image_bytes:
|
||||
scene_logger.warning("[AnimateScene] Missing image bytes for user=%s scene=%s", user_id, request.scene_number)
|
||||
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
animation_result = animate_scene_image(
|
||||
image_bytes=image_bytes,
|
||||
scene_data=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
user_id=user_id,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
video_url = _build_authenticated_media_url(
|
||||
request_obj, f"/api/story/videos/ai/{video_filename}"
|
||||
)
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateScene] Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] ✅ Completed user=%s scene=%s duration=%s cost=$%.2f video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
animation_result["duration"],
|
||||
animation_result["cost"],
|
||||
video_url,
|
||||
)
|
||||
|
||||
# Save video asset to library
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="story_writer",
|
||||
filename=video_filename,
|
||||
file_url=video_url,
|
||||
file_path=str(ai_video_dir / video_filename),
|
||||
file_size=len(animation_result["video_bytes"]),
|
||||
mime_type="video/mp4",
|
||||
title=f"Scene {request.scene_number} Animation",
|
||||
description=f"Animated scene {request.scene_number} from story",
|
||||
prompt=animation_result["prompt"],
|
||||
tags=["story_writer", "video", "animation", f"scene_{request.scene_number}"],
|
||||
provider=animation_result["provider"],
|
||||
model=animation_result.get("model_name"),
|
||||
cost=animation_result["cost"],
|
||||
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/animate-scene-resume", response_model=AnimateSceneResponse)
|
||||
async def resume_scene_animation_endpoint(
|
||||
request_obj: Request,
|
||||
request: ResumeSceneAnimationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> AnimateSceneResponse:
|
||||
"""Resume downloading a WaveSpeed animation when the initial call timed out."""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] Resume requested user=%s scene=%s prediction=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
request.prediction_id,
|
||||
)
|
||||
|
||||
animation_result = resume_scene_animation(
|
||||
prediction_id=request.prediction_id,
|
||||
duration=request.duration or 5,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
video_url = _build_authenticated_media_url(
|
||||
request_obj, f"/api/story/videos/ai/{video_filename}"
|
||||
)
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateScene] (Resume) Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] ✅ Resume completed user=%s scene=%s prediction=%s video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
request.prediction_id,
|
||||
video_url,
|
||||
)
|
||||
|
||||
return AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/animate-scene-voiceover", response_model=Dict[str, Any])
|
||||
async def animate_scene_voiceover_endpoint(
|
||||
request_obj: Request,
|
||||
request: AnimateSceneVoiceoverRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Animate a scene using WaveSpeed InfiniteTalk (image + audio) asynchronously.
|
||||
Returns task_id for polling since InfiniteTalk can take up to 10 minutes.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateSceneVoiceover] User=%s scene=%s resolution=%s (async)",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
request.resolution or "720p",
|
||||
)
|
||||
|
||||
image_bytes = load_story_image_bytes(request.image_url)
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
|
||||
|
||||
audio_bytes = load_story_audio_bytes(request.audio_url)
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=404, detail="Scene audio not found. Generate audio first.")
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Extract token for authenticated URL building (if needed)
|
||||
auth_token = None
|
||||
auth_header = request_obj.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.replace("Bearer ", "").strip()
|
||||
|
||||
# Create async task
|
||||
task_id = task_manager.create_task("scene_voiceover_animation")
|
||||
background_tasks.add_task(
|
||||
_execute_voiceover_animation_task,
|
||||
task_id=task_id,
|
||||
request=request,
|
||||
user_id=user_id,
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
auth_token=auth_token,
|
||||
)
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "InfiniteTalk animation started. This may take up to 10 minutes.",
|
||||
}
|
||||
|
||||
|
||||
def _execute_voiceover_animation_task(
|
||||
task_id: str,
|
||||
request: AnimateSceneVoiceoverRequest,
|
||||
user_id: str,
|
||||
image_bytes: bytes,
|
||||
audio_bytes: bytes,
|
||||
auth_token: Optional[str] = None,
|
||||
):
|
||||
"""Background task to generate InfiniteTalk video with progress updates."""
|
||||
try:
|
||||
task_manager.update_task_status(
|
||||
task_id, "processing", progress=5.0, message="Submitting to WaveSpeed InfiniteTalk..."
|
||||
)
|
||||
|
||||
animation_result = animate_scene_with_voiceover(
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
scene_data=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
user_id=user_id,
|
||||
resolution=request.resolution or "720p",
|
||||
prompt_override=request.prompt,
|
||||
image_mime=_guess_mime_from_url(request.image_url, "image/png"),
|
||||
audio_mime=_guess_mime_from_url(request.audio_url, "audio/mpeg"),
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id, "processing", progress=80.0, message="Saving video file..."
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
# Build authenticated URL if token provided, otherwise return plain URL
|
||||
video_url = f"/api/story/videos/ai/{video_filename}"
|
||||
if auth_token:
|
||||
video_url = f"{video_url}?token={quote(auth_token)}"
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateSceneVoiceover] Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateSceneVoiceover] ✅ Completed user=%s scene=%s cost=$%.2f video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
animation_result["cost"],
|
||||
video_url,
|
||||
)
|
||||
|
||||
# Save video asset to library
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="story_writer",
|
||||
filename=video_filename,
|
||||
file_url=video_url,
|
||||
file_path=str(ai_video_dir / video_filename),
|
||||
file_size=len(animation_result["video_bytes"]),
|
||||
mime_type="video/mp4",
|
||||
title=f"Scene {request.scene_number} Animation (Voiceover)",
|
||||
description=f"Animated scene {request.scene_number} with voiceover from story",
|
||||
prompt=animation_result["prompt"],
|
||||
tags=["story_writer", "video", "animation", "voiceover", f"scene_{request.scene_number}"],
|
||||
provider=animation_result["provider"],
|
||||
model=animation_result.get("model_name"),
|
||||
cost=animation_result["cost"],
|
||||
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
result = AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="InfiniteTalk animation complete!",
|
||||
result=result.dict(),
|
||||
)
|
||||
except HTTPException as exc:
|
||||
error_msg = str(exc.detail) if isinstance(exc.detail, str) else exc.detail.get("error", "Animation failed") if isinstance(exc.detail, dict) else "Animation failed"
|
||||
scene_logger.error(f"[AnimateSceneVoiceover] Failed: {error_msg}")
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=error_msg,
|
||||
message=f"InfiniteTalk animation failed: {error_msg}",
|
||||
)
|
||||
except Exception as exc:
|
||||
error_msg = str(exc)
|
||||
scene_logger.error(f"[AnimateSceneVoiceover] Error: {error_msg}", exc_info=True)
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=error_msg,
|
||||
message=f"InfiniteTalk animation error: {error_msg}",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Any, Dict, List
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
@@ -18,6 +20,7 @@ from ..utils.auth import require_authenticated_user
|
||||
|
||||
router = APIRouter()
|
||||
story_service = StoryWriterService()
|
||||
scene_approval_store: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
@router.post("/generate-start", response_model=StoryContentResponse)
|
||||
@@ -193,3 +196,45 @@ async def continue_story(
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
class SceneApprovalRequest(BaseModel):
|
||||
project_id: str
|
||||
scene_id: str
|
||||
approved: bool = True
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/script/approve")
|
||||
async def approve_script_scene(
|
||||
request: SceneApprovalRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Persist scene approval metadata for auditing."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
approvals = scene_approval_store.setdefault(request.project_id, {})
|
||||
approvals[request.scene_id] = {
|
||||
"approved": request.approved,
|
||||
"notes": request.notes,
|
||||
"user_id": user_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
logger.info(
|
||||
"[StoryWriter] Scene approval recorded user=%s project=%s scene=%s approved=%s",
|
||||
user_id,
|
||||
request.project_id,
|
||||
request.scene_id,
|
||||
request.approved,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"project_id": request.project_id,
|
||||
"scene_id": request.scene_id,
|
||||
"approved": request.approved,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to approve scene: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
||||
@@ -509,3 +509,30 @@ async def serve_story_video(
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/videos/ai/{video_filename}")
|
||||
async def serve_ai_story_video(
|
||||
video_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Serve a generated AI scene animation video."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()
|
||||
video_service_ai = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
video_path = resolve_media_file(video_service_ai.output_dir, video_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to serve AI video: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user