AI Video Generation Implementation
This commit is contained in:
@@ -9,6 +9,7 @@ from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.main_image_editing import edit_image
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.logger_utils import get_service_logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
@@ -125,6 +126,14 @@ def generate(
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
|
||||
|
||||
@@ -137,6 +146,8 @@ def generate(
|
||||
├─ Actual Provider: {result.provider}
|
||||
├─ Model: {result.model or 'default'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
@@ -195,6 +206,26 @@ class ImagePromptSuggestResponse(BaseModel):
|
||||
suggestions: list[PromptSuggestion]
|
||||
|
||||
|
||||
class ImageEditRequest(BaseModel):
    """Request body accepted by the `/edit` image editing route."""

    # Base64-encoded source image; the route decodes it with base64.b64decode.
    image_base64: str
    # Text instruction forwarded to edit_image as the editing prompt.
    prompt: str
    # Optional provider name; the pattern only admits the literal "huggingface".
    provider: Optional[str] = Field(None, pattern="^(huggingface)$")
    # The remaining options are passed to edit_image inside its `options` dict;
    # None means the caller did not set them.
    model: Optional[str] = None
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None
|
||||
|
||||
|
||||
class ImageEditResponse(BaseModel):
    """Response body returned by the `/edit` image editing route."""

    # Defaults to True; the route never sets it explicitly.
    success: bool = True
    # Edited image, base64-encoded from the bytes returned by edit_image.
    image_base64: str
    # The fields below are copied from the edit_image result object.
    width: int
    height: int
    provider: str
    model: Optional[str] = None
    seed: Optional[int] = None
|
||||
|
||||
|
||||
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
|
||||
def suggest_prompts(
|
||||
req: ImagePromptSuggestRequest,
|
||||
@@ -316,3 +347,136 @@ def suggest_prompts(
|
||||
logger.error(f"Prompt suggestion failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/edit", response_model=ImageEditResponse)
def edit(
    req: ImageEditRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImageEditResponse:
    """
    Edit an image for the authenticated user and record the usage.

    Flow: validate the authenticated user, decode the base64 input image,
    call edit_image, then (best effort) increment the user's
    `image_edit_calls` and `total_calls` counters for the current billing
    period and print a unified subscription log entry.

    Raises:
        HTTPException 401: no authenticated user or empty user id.
        HTTPException 400: `image_base64` could not be decoded.
        HTTPException 500: any other failure while editing the image.
    """
    try:
        # Extract Clerk user ID (required)
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        # NOTE(review): an explicit `id: None` becomes the truthy string "None"
        # here and is not rejected by the check below - confirm the middleware
        # can never produce that.
        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        # Decode base64 image
        try:
            input_image_bytes = base64.b64decode(req.image_base64)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image_base64: {str(e)}")

        # Validation is now handled inside edit_image function
        result = edit_image(
            input_image_bytes=input_image_bytes,
            prompt=req.prompt,
            options={
                "provider": req.provider,
                "model": req.model,
                "guidance_scale": req.guidance_scale,
                "steps": req.steps,
                "seed": req.seed,
            },
            user_id=user_id,  # Pass user_id for validation inside edit_image
        )
        edited_image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")

        # TRACK USAGE after successful image editing
        # NOTE(review): `result` was already dereferenced above, so a falsy
        # result would have failed before reaching this guard.
        if result:
            logger.info(f"[images.edit] ✅ Image editing successful, tracking usage for user {user_id}")
            try:
                db_track = next(get_db())
                try:
                    # Get or create usage summary
                    pricing = PricingService(db_track)
                    # Falls back to the current calendar month when no billing period is returned.
                    current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                    logger.debug(f"[images.edit] Looking for usage summary: user_id={user_id}, period={current_period}")

                    summary = db_track.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()

                    if not summary:
                        logger.info(f"[images.edit] Creating new usage summary for user {user_id}, period {current_period}")
                        summary = UsageSummary(
                            user_id=user_id,
                            billing_period=current_period
                        )
                        db_track.add(summary)
                        db_track.flush()  # Ensure summary is persisted before updating

                    # Get "before" state for unified log
                    current_calls_before = getattr(summary, "image_edit_calls", 0) or 0

                    # Update image editing counters (separate from image generation)
                    new_calls = current_calls_before + 1
                    setattr(summary, "image_edit_calls", new_calls)
                    logger.debug(f"[images.edit] Updated image_edit_calls: {current_calls_before} -> {new_calls}")

                    # Update totals
                    old_total_calls = summary.total_calls or 0
                    summary.total_calls = old_total_calls + 1
                    logger.debug(f"[images.edit] Updated totals: calls {old_total_calls} -> {summary.total_calls}")

                    # Get plan details for unified log
                    # These lookups run before the commit, so an exception raised
                    # here rolls back the counter updates made above.
                    limits = pricing.get_user_limits(user_id)
                    plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
                    tier = limits.get('tier', 'unknown') if limits else 'unknown'
                    call_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0

                    # Get image generation stats for unified log
                    current_image_gen_calls = getattr(summary, "stability_calls", 0) or 0
                    image_gen_limit = limits['limits'].get("stability_calls", 0) if limits else 0

                    # Get video stats for unified log
                    current_video_calls = getattr(summary, "video_calls", 0) or 0
                    video_limit = limits['limits'].get("video_calls", 0) if limits else 0

                    db_track.commit()
                    logger.info(f"[images.edit] ✅ Successfully tracked usage: user {user_id} -> image_edit -> {new_calls} calls")

                    # UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
                    # A limit of 0 (or less) is displayed as the infinity symbol.
                    print(f"""
[SUBSCRIPTION] Image Editing
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: image_edit
├─ Actual Provider: {result.provider}
├─ Model: {result.model or 'default'}
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
├─ Images: {current_image_gen_calls} / {image_gen_limit if image_gen_limit > 0 else '∞'}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
└─ Status: ✅ Allowed & Tracked
""")
                except Exception as track_error:
                    logger.error(f"[images.edit] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                    db_track.rollback()
                finally:
                    db_track.close()
            except Exception as usage_error:
                # Non-blocking: log error but don't fail the request
                logger.error(f"[images.edit] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return ImageEditResponse(
            image_base64=edited_image_b64,
            width=result.width,
            height=result.height,
            provider=result.provider,
            model=result.model,
            seed=result.seed,
        )
    except HTTPException:
        # Deliberate HTTP errors (401/400) pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Image editing failed: {e}", exc_info=True)
        # Provide a clean, actionable message to the client
        raise HTTPException(
            status_code=500,
            detail="Image editing service is temporarily unavailable or the connection was reset. Please try again."
        )
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ content generation, and full story creation.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from typing import Any, Dict, Union, List
|
||||
from typing import Any, Dict, Union, List, Optional
|
||||
from loguru import logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
@@ -37,6 +37,16 @@ from models.story_models import (
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
from .task_manager import task_manager
|
||||
from .cache_manager import cache_manager
|
||||
from uuid import uuid4
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
|
||||
from .utils.auth import require_authenticated_user
|
||||
from .utils.media_utils import resolve_media_file
|
||||
from .utils.hd_video import (
|
||||
generate_hd_video_payload,
|
||||
generate_hd_video_scene_payload,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
|
||||
@@ -503,11 +513,11 @@ async def get_task_status(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/result", response_model=StoryFullGenerationResponse)
|
||||
@router.get("/task/{task_id}/result")
|
||||
async def get_task_result(
|
||||
task_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> StoryFullGenerationResponse:
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the result of a completed story generation task."""
|
||||
try:
|
||||
if not current_user:
|
||||
@@ -528,7 +538,19 @@ async def get_task_result(
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")
|
||||
|
||||
return StoryFullGenerationResponse(**result, success=True, task_id=task_id)
|
||||
# Some tasks return a full-story payload compatible with StoryFullGenerationResponse,
|
||||
# others (e.g., video-only) return a dict like {"video": {...}, "success": True}.
|
||||
# To avoid model conflicts, return a generic payload and include task_id.
|
||||
# Frontend callers can branch on keys present (e.g., "video").
|
||||
if isinstance(result, dict):
|
||||
# Ensure success flag present without duplicating
|
||||
payload = {**result}
|
||||
payload.setdefault("success", True)
|
||||
payload["task_id"] = task_id
|
||||
return payload
|
||||
|
||||
# Fallback: wrap non-dict results
|
||||
return {"result": result, "success": True, "task_id": task_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -536,6 +558,69 @@ async def get_task_result(
|
||||
logger.error(f"[StoryWriter] Failed to get task result: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
class HDVideoRequest(BaseModel):
    """Request body for the `/hd-video` text-to-video route."""

    # Text prompt sent to the video provider as-is.
    prompt: str
    # Provider name; defaults to "huggingface".
    provider: str = "huggingface"
    # Optional generation settings; None means "not set by the caller".
    model: str | None = None
    num_frames: int | None = None
    guidance_scale: float | None = None
    num_inference_steps: int | None = None
    negative_prompt: str | None = None
    seed: int | None = None
|
||||
|
||||
@router.post("/hd-video")
async def generate_hd_video(
    request: HDVideoRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Generate an HD AI animation using provider text-to-video (Hugging Face for now).
    Saves the returned bytes as a video file and returns the secured URL.

    Args:
        request: Prompt and optional generation settings.
        current_user: Authenticated user dictionary injected by the auth dependency.

    Returns:
        The payload dictionary built by generate_hd_video_payload.

    Raises:
        HTTPException: 401 when the user is not authenticated, 500 for any
            unexpected failure; HTTPExceptions raised downstream pass through.
    """
    # Imported locally so this route does not depend on a file-level import change.
    from fastapi.concurrency import run_in_threadpool

    try:
        user_id = require_authenticated_user(current_user)
        # generate_hd_video_payload is synchronous (provider call + file write).
        # Calling it directly inside this coroutine would block the event loop
        # for the whole generation, so it is run in the worker threadpool.
        return await run_in_threadpool(generate_hd_video_payload, request, user_id)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to generate HD video: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
class HDVideoSceneRequest(BaseModel):
    """Request body for the `/hd-video-scene` per-scene video route."""

    # Number of the scene being rendered; echoed back in the response payload.
    scene_number: int
    # Scene, story and sibling-scene data handed to the prompt enhancer.
    scene_data: Dict[str, Any]
    story_context: Dict[str, Any]
    all_scenes: List[Dict[str, Any]]
    # Optional story image URL used as a visual reference for the video.
    scene_image_url: Optional[str] = None
    # Provider name; defaults to "huggingface".
    provider: str = "huggingface"
    # Optional generation settings; None means "not set by the caller".
    model: str | None = None
    num_frames: int | None = None
    guidance_scale: float | None = None
    num_inference_steps: int | None = None
    negative_prompt: str | None = None
    seed: int | None = None
|
||||
|
||||
|
||||
@router.post("/hd-video-scene")
async def generate_hd_video_scene(
    request: HDVideoSceneRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Generate HD AI video for a single scene with AI-enhanced prompt.
    Uses prompt enhancer to create HunyuanVideo-optimized prompt from story context.

    Args:
        request: Scene data, story context and optional generation settings.
        current_user: Authenticated user dictionary injected by the auth dependency.

    Returns:
        The payload dictionary built by generate_hd_video_scene_payload.

    Raises:
        HTTPException: 401 when the user is not authenticated, 500 for any
            unexpected failure; HTTPExceptions raised downstream pass through.
    """
    # Imported locally so this route does not depend on a file-level import change.
    from fastapi.concurrency import run_in_threadpool

    try:
        user_id = require_authenticated_user(current_user)
        # The payload builder is synchronous (DB validation, prompt enhancement,
        # provider call, file write). Running it in the threadpool keeps the
        # event loop free while the video is generated.
        return await run_in_threadpool(generate_hd_video_scene_payload, request, user_id)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to generate HD video for scene: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Image Generation Endpoints
|
||||
@@ -614,31 +699,20 @@ async def serve_scene_image(
|
||||
):
|
||||
"""Serve a generated story scene image."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import image generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
image_service = StoryImageGenerationService()
|
||||
image_path = image_service.output_dir / image_filename
|
||||
|
||||
if not image_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Image not found: {image_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
image_path.resolve().relative_to(image_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
image_path = resolve_media_file(image_service.output_dir, image_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
media_type="image/png",
|
||||
filename=image_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -726,31 +800,20 @@ async def serve_scene_audio(
|
||||
):
|
||||
"""Serve a generated story scene audio file."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import audio generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
audio_service = StoryAudioGenerationService()
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
|
||||
if not audio_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Audio not found: {audio_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
audio_path.resolve().relative_to(audio_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(audio_path),
|
||||
media_type="audio/mpeg",
|
||||
filename=audio_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -869,6 +932,99 @@ async def generate_story_video(
|
||||
logger.error(f"[StoryWriter] Failed to generate video: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/generate-video-async", response_model=Dict[str, Any])
async def generate_story_video_async(
    request: StoryVideoGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Generate a video asynchronously with progress updates via task manager.
    Frontend can poll /api/story/task/{task_id}/status to show progress messages.

    Args:
        request: Scenes plus matching image and audio URLs.
        background_tasks: FastAPI background task queue used to run the job.
        current_user: Authenticated user dictionary injected by the auth dependency.

    Returns:
        A dict with the new `task_id`, a `status` of "pending" and a message.

    Raises:
        HTTPException: 401 when unauthenticated, 400 when the scene list is
            empty or the scene/image/audio counts differ, 500 otherwise.
    """
    try:
        # Shared guard used by the other routes in this module; it also strips
        # surrounding whitespace from the user id.
        user_id = require_authenticated_user(current_user)

        if not request.scenes:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        scene_count = len(request.scenes)
        if scene_count != len(request.image_urls) or scene_count != len(request.audio_urls):
            raise HTTPException(status_code=400, detail="Number of scenes, image URLs, and audio URLs must match")

        task_id = task_manager.create_task("story_video_generation")
        background_tasks.add_task(
            _execute_video_generation_task,
            task_id=task_id,
            request=request,
            user_id=user_id
        )
        return {"task_id": task_id, "status": "pending", "message": "Video generation started"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to start async video generation: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
    """
    Background task to generate story video with progress mapped to task manager.

    Resolves each scene's image and audio URL to a file in the corresponding
    service output directory, skips scenes whose assets are missing, renders
    the video and stores the outcome on the task. Failures are logged and
    recorded on the task as "failed"; nothing is raised to the caller.

    Args:
        task_id: Identifier of the task to update in task_manager.
        request: Scenes, image/audio URLs and rendering options.
        user_id: Id of the user the video is generated for.
    """
    from services.story_writer.video_generation_service import StoryVideoGenerationService
    from services.story_writer.image_generation_service import StoryImageGenerationService
    from services.story_writer.audio_generation_service import StoryAudioGenerationService

    def _filename_from_url(url: str) -> str:
        """Return the last path segment of `url` with any query string removed."""
        return url.split('/')[-1].split('?')[0]

    try:
        task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
        video_service = StoryVideoGenerationService()
        image_service = StoryImageGenerationService()
        audio_service = StoryAudioGenerationService()

        # Prepare assets
        scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
        image_paths, audio_paths, valid_scenes = [], [], []
        for scene, image_url, audio_url in zip(scenes_data, request.image_urls, request.audio_urls):
            image_path = image_service.output_dir / _filename_from_url(image_url)
            audio_path = audio_service.output_dir / _filename_from_url(audio_url)
            # is_file() rather than exists(): a URL ending in "/" produces an
            # empty filename, which resolves to the output directory itself and
            # must not be accepted as an asset.
            if not image_path.is_file():
                logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
                continue
            if not audio_path.is_file():
                logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
                continue
            image_paths.append(str(image_path))
            audio_paths.append(str(audio_path))
            valid_scenes.append(scene)

        # The three lists grow in lockstep, so one emptiness check covers all of them.
        if not valid_scenes:
            raise RuntimeError("No valid or mismatched image/audio assets for video generation.")

        # Map service progress (0-100) to task progress (5-95)
        def progress_callback(sub_progress: float, msg: str):
            overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
            task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)

        result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=request.story_title or "Story",
            fps=request.fps or 24,
            transition_duration=request.transition_duration or 0.5,
            progress_callback=progress_callback
        )

        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            message="Video generation complete!",
            result={"video": result, "success": True}
        )
    except Exception as e:
        # Background tasks have no caller to report to, so the failure is
        # recorded on the task for the status endpoint to surface.
        logger.error(f"[StoryWriter] Async video generation failed: {e}", exc_info=True)
        task_manager.update_task_status(task_id, "failed", error=str(e), message=f"Video generation failed: {e}")
|
||||
|
||||
@router.post("/generate-complete-video", response_model=Dict[str, Any])
|
||||
async def generate_complete_story_video(
|
||||
@@ -1111,31 +1267,20 @@ async def serve_story_video(
|
||||
):
|
||||
"""Serve a generated story video file."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import video generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
video_path = video_service.output_dir / video_filename
|
||||
|
||||
if not video_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Video not found: {video_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
video_path.resolve().relative_to(video_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
video_path = resolve_media_file(video_service.output_dir, video_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
8
backend/api/story_writer/utils/__init__.py
Normal file
8
backend/api/story_writer/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Utility helpers for Story Writer API routes.
|
||||
|
||||
Grouped here to keep the main router lean while reusing common logic
|
||||
such as authentication guards, media resolution, and HD video helpers.
|
||||
"""
|
||||
|
||||
|
||||
23
backend/api/story_writer/utils/auth.py
Normal file
23
backend/api/story_writer/utils/auth.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
|
||||
"""
|
||||
Validates the current user dictionary provided by Clerk middleware and
|
||||
returns the normalized user_id. Raises HTTP 401 if authentication fails.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", "")).strip()
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid user ID in authentication token",
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
|
||||
154
backend/api/story_writer/utils/hd_video.py
Normal file
154
backend/api/story_writer/utils/hd_video.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from uuid import uuid4
|
||||
|
||||
from .media_utils import load_story_image_bytes
|
||||
|
||||
|
||||
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
    """
    Run a synchronous HD video generation and describe the saved file.

    Forwards the prompt and any optional settings found on `request` to
    ai_video_generate, writes the returned bytes to a uniquely named
    `hd_<hex>.mp4` file in the story video output directory, and returns a
    dict with the file name, its `/api/story/videos/` URL, and the provider
    and model names.

    Args:
        request: Object exposing `prompt` and optional generation attributes.
        user_id: Id of the user the video is generated for.
    """
    from services.llm_providers.main_video_generation import ai_video_generate
    from services.story_writer.video_generation_service import StoryVideoGenerationService

    target_dir = StoryVideoGenerationService().output_dir
    target_dir.mkdir(parents=True, exist_ok=True)

    # (attribute name, forward only when truthy). Attributes flagged False are
    # forwarded whenever they are not None, so 0 / 0.0 remain valid values.
    option_rules = (
        ("model", True),
        ("num_frames", True),
        ("guidance_scale", False),
        ("num_inference_steps", True),
        ("negative_prompt", True),
        ("seed", False),
    )
    generation_options: Dict[str, Any] = {}
    for option, needs_truthy in option_rules:
        value = getattr(request, option, None)
        wanted = bool(value) if needs_truthy else value is not None
        if wanted:
            generation_options[option] = value

    provider_name = getattr(request, "provider", None) or "huggingface"

    logger.info(f"[StoryWriter] Generating HD video via {getattr(request, 'provider', 'huggingface')} for user {user_id}")
    video_bytes = ai_video_generate(
        prompt=request.prompt,
        provider=provider_name,
        user_id=user_id,
        **generation_options,
    )

    video_name = f"hd_{uuid4().hex}.mp4"
    destination = target_dir / video_name
    destination.write_bytes(video_bytes)

    logger.info(f"[StoryWriter] HD video saved to {destination}")
    payload: Dict[str, Any] = {"success": True}
    payload["video_filename"] = video_name
    payload["video_url"] = f"/api/story/videos/{video_name}"
    payload["provider"] = provider_name
    payload["model"] = getattr(request, "model", None) or "tencent/HunyuanVideo"
    return payload
|
||||
|
||||
|
||||
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
    """
    Produce the response payload for a single-scene HD video request.

    In order: verify that a Hugging Face token is configured, run the
    subscription pre-flight validation, build an enhanced prompt from the
    scene and story context, optionally load a reference image, call the
    video provider, and save the result via
    StoryVideoGenerationService.save_scene_video.

    Args:
        request: Object exposing the scene fields and optional generation attributes.
        user_id: Id of the user the video is generated for.

    Raises:
        HTTPException: 400 when no Hugging Face token is configured.
    """
    from services.database import get_db as get_db_validation
    from services.onboarding.api_key_manager import APIKeyManager
    from services.subscription import PricingService
    from services.subscription.preflight_validator import validate_video_generation_operations
    from services.story_writer.prompt_enhancer_service import enhance_scene_prompt_for_video
    from services.llm_providers.main_video_generation import ai_video_generate
    from services.story_writer.video_generation_service import StoryVideoGenerationService

    scene_no = request.scene_number
    logger.info(f"[StoryWriter] Generating HD video for scene {scene_no} for user {user_id}")

    # Step 1: Validate API key
    if not APIKeyManager().get_api_key("hf_token"):
        logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
        missing_token_text = "Hugging Face API token is not configured. Please configure your HF token in settings."
        raise HTTPException(
            status_code=400,
            detail={"error": missing_token_text, "message": missing_token_text},
        )

    # Step 2: Subscription limits (the session is always closed afterwards)
    session = next(get_db_validation())
    try:
        pricing = PricingService(session)
        logger.info(f"[StoryWriter] Pre-flight: Checking video generation limits for user {user_id}...")
        validate_video_generation_operations(pricing_service=pricing, user_id=user_id)
        logger.info("[StoryWriter] Pre-flight: ✅ Video generation limits validated - proceeding")
    finally:
        session.close()

    # Stage 1: Prompt enhancement
    prompt_text = enhance_scene_prompt_for_video(
        current_scene=request.scene_data,
        story_context=request.story_context,
        all_scenes=request.all_scenes,
        user_id=user_id,
    )
    logger.info(f"[StoryWriter] Generated enhanced prompt ({len(prompt_text)} chars) for scene {scene_no}")

    # Stage 2: Optional image reference
    reference_image: Optional[bytes] = None
    reference_url = getattr(request, "scene_image_url", None)
    if reference_url:
        reference_image = load_story_image_bytes(reference_url)
        if not reference_image:
            logger.warning(f"[StoryWriter] Scene image could not be loaded for scene {scene_no}, falling back to text-only video")
        else:
            logger.info(f"[StoryWriter] Using scene image reference for scene {scene_no}")

    # (attribute name, forward only when truthy). Attributes flagged False are
    # forwarded whenever they are not None, so 0 / 0.0 remain valid values.
    option_rules = (
        ("model", True),
        ("num_frames", True),
        ("guidance_scale", False),
        ("num_inference_steps", True),
        ("negative_prompt", True),
        ("seed", False),
    )
    generation_options: Dict[str, Any] = {}
    for option, needs_truthy in option_rules:
        value = getattr(request, option, None)
        wanted = bool(value) if needs_truthy else value is not None
        if wanted:
            generation_options[option] = value

    provider_name = getattr(request, "provider", None) or "huggingface"
    video_bytes = ai_video_generate(
        prompt=prompt_text,
        provider=provider_name,
        user_id=user_id,
        input_image_bytes=reference_image,
        **generation_options,
    )

    saved = StoryVideoGenerationService().save_scene_video(
        video_bytes=video_bytes,
        scene_number=scene_no,
        user_id=user_id,
    )

    logger.info(f"[StoryWriter] HD video saved for scene {scene_no}: {saved.get('video_filename')}")
    payload: Dict[str, Any] = {"success": True, "scene_number": scene_no}
    payload["video_filename"] = saved.get("video_filename")
    payload["video_url"] = saved.get("video_url")
    payload["prompt_used"] = prompt_text
    payload["provider"] = provider_name
    payload["model"] = getattr(request, "model", None) or "tencent/HunyuanVideo"
    return payload
|
||||
|
||||
|
||||
69
backend/api/story_writer/utils/media_utils.py
Normal file
69
backend/api/story_writer/utils/media_utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from loguru import logger
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
|
||||
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
|
||||
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
    """
    Resolve an authenticated story image URL (e.g., /api/story/images/<file>) to raw bytes.

    Args:
        image_url: Absolute URL or bare path containing the
            ``/api/story/images/`` prefix. Any query string is ignored.

    Returns:
        The file contents, or None when the URL is empty or unsupported, the
        resolved path falls outside STORY_IMAGES_DIR, the target is not an
        existing regular file, or reading fails for any reason.
    """
    if not image_url:
        return None

    try:
        parsed = urlparse(image_url)
        # Without a scheme the input is treated as a bare path; a query string
        # that may still be attached is stripped below.
        path = parsed.path if parsed.scheme else image_url
        prefix = "/api/story/images/"
        if prefix not in path:
            logger.warning(f"[StoryWriter] Unsupported image URL for video reference: {image_url}")
            return None

        filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
        if not filename:
            return None

        file_path = (STORY_IMAGES_DIR / filename).resolve()
        # Compare path components rather than string prefixes: a sibling such
        # as "<backend>/story_images_evil" shares the textual prefix of
        # STORY_IMAGES_DIR and would slip through a startswith() check.
        try:
            file_path.relative_to(STORY_IMAGES_DIR)
        except ValueError:
            logger.error(f"[StoryWriter] Attempted path traversal when resolving image: {image_url}")
            return None

        # is_file() also rejects directories, which cannot be read as bytes.
        if not file_path.is_file():
            logger.warning(f"[StoryWriter] Referenced scene image not found on disk: {file_path}")
            return None

        return file_path.read_bytes()
    except Exception as exc:
        # Best-effort loader: callers treat None as "no reference image".
        logger.error(f"[StoryWriter] Failed to load reference image for video gen: {exc}")
        return None
|
||||
|
||||
|
||||
def resolve_media_file(base_dir: Path, filename: str) -> Path:
    """
    Returns a safe resolved path for a media file stored under base_dir.
    Guards against directory traversal and ensures the file exists.

    Args:
        base_dir: Directory the media file must live under.
        filename: Requested file name; anything after the first "?" is dropped
            and surrounding whitespace is trimmed.

    Returns:
        The fully resolved path of the requested file.

    Raises:
        HTTPException: 403 when the resolved path lies outside base_dir,
            404 when it does not point to an existing regular file.
    """
    filename = filename.split("?")[0].strip()
    root = base_dir.resolve()
    resolved = (base_dir / filename).resolve()

    try:
        resolved.relative_to(root)
    except ValueError:
        # The ValueError is an implementation detail; do not chain it.
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied") from None

    # is_file() rather than exists(): an empty filename resolves to base_dir
    # itself, and a directory is not a servable media file.
    if not resolved.is_file():
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")

    return resolved
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from functools import lru_cache
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from services.subscription.log_wrapping_service import LogWrappingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
||||
import sqlite3
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.subscription_models import (
|
||||
@@ -114,6 +114,8 @@ async def get_subscription_plans(
|
||||
"metaphor_calls": plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": plan.firecrawl_calls_limit,
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"video_calls": getattr(plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
|
||||
"gemini_tokens": plan.gemini_tokens_limit,
|
||||
"openai_tokens": plan.openai_tokens_limit,
|
||||
"anthropic_tokens": plan.anthropic_tokens_limit,
|
||||
@@ -132,7 +134,7 @@ async def get_subscription_plans(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and 'exa_calls_limit' in error_str:
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
|
||||
logger.warning("Missing column detected in subscription plans query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
@@ -237,6 +239,8 @@ async def get_user_subscription(
|
||||
"metaphor_calls": free_plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": free_plan.firecrawl_calls_limit,
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -334,6 +338,8 @@ async def get_subscription_status(
|
||||
"metaphor_calls": free_plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": free_plan.firecrawl_calls_limit,
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -399,15 +405,16 @@ async def get_subscription_status(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and 'exa_calls_limit' in error_str:
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
|
||||
# Try to fix schema and retry once
|
||||
logger.warning("Missing column detected in subscription status query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
schema_utils._checked_subscription_plan_columns = False
|
||||
ensure_subscription_plan_columns(db)
|
||||
db.commit() # Ensure schema changes are committed
|
||||
db.expire_all()
|
||||
# Retry the query
|
||||
# Retry the query - query subscription without eager loading plan
|
||||
subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
@@ -437,11 +444,21 @@ async def get_subscription_status(
|
||||
"metaphor_calls": free_plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": free_plan.firecrawl_calls_limit,
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
}
|
||||
elif subscription:
|
||||
# Query plan separately after schema fix to avoid lazy loading issues
|
||||
plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.id == subscription.plan_id
|
||||
).first()
|
||||
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
|
||||
now = datetime.utcnow()
|
||||
if subscription.current_period_end < now:
|
||||
if getattr(subscription, 'auto_renew', False):
|
||||
@@ -456,8 +473,8 @@ async def get_subscription_status(
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": False,
|
||||
"plan": subscription.plan.tier.value,
|
||||
"tier": subscription.plan.tier.value,
|
||||
"plan": plan.tier.value,
|
||||
"tier": plan.tier.value,
|
||||
"can_use_api": False,
|
||||
"reason": "Subscription expired"
|
||||
}
|
||||
@@ -466,21 +483,23 @@ async def get_subscription_status(
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": True,
|
||||
"plan": subscription.plan.tier.value,
|
||||
"tier": subscription.plan.tier.value,
|
||||
"plan": plan.tier.value,
|
||||
"tier": plan.tier.value,
|
||||
"can_use_api": True,
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": subscription.plan.gemini_calls_limit,
|
||||
"openai_calls": subscription.plan.openai_calls_limit,
|
||||
"anthropic_calls": subscription.plan.anthropic_calls_limit,
|
||||
"mistral_calls": subscription.plan.mistral_calls_limit,
|
||||
"tavily_calls": subscription.plan.tavily_calls_limit,
|
||||
"serper_calls": subscription.plan.serper_calls_limit,
|
||||
"metaphor_calls": subscription.plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": subscription.plan.firecrawl_calls_limit,
|
||||
"stability_calls": subscription.plan.stability_calls_limit,
|
||||
"monthly_cost": subscription.plan.monthly_cost_limit
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
"anthropic_calls": plan.anthropic_calls_limit,
|
||||
"mistral_calls": plan.mistral_calls_limit,
|
||||
"tavily_calls": plan.tavily_calls_limit,
|
||||
"serper_calls": plan.serper_calls_limit,
|
||||
"metaphor_calls": plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": plan.firecrawl_calls_limit,
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"video_calls": getattr(plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -893,6 +912,7 @@ async def get_dashboard_data(
|
||||
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
# Serve from short TTL cache to avoid hammering DB on bursts
|
||||
import time
|
||||
now = time.time()
|
||||
@@ -966,7 +986,72 @@ async def get_dashboard_data(
|
||||
_dashboard_cache_ts[user_id] = now
|
||||
return response_payload
|
||||
|
||||
except Exception as e:
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str):
|
||||
logger.warning("Missing column detected in dashboard query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
schema_utils._checked_usage_summaries_columns = False
|
||||
schema_utils._checked_subscription_plan_columns = False
|
||||
ensure_usage_summaries_columns(db)
|
||||
ensure_subscription_plan_columns(db)
|
||||
db.expire_all()
|
||||
# Retry the query
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
current_usage = usage_service.get_user_usage_stats(user_id)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
alerts = db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.is_read == False
|
||||
).order_by(UsageAlert.created_at.desc()).limit(5).all()
|
||||
|
||||
alerts_data = [
|
||||
{
|
||||
"id": alert.id,
|
||||
"type": alert.alert_type,
|
||||
"title": alert.title,
|
||||
"message": alert.message,
|
||||
"severity": alert.severity,
|
||||
"created_at": alert.created_at.isoformat()
|
||||
}
|
||||
for alert in alerts
|
||||
]
|
||||
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"current_usage": current_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
"projections": {
|
||||
"projected_monthly_cost": round(projected_cost, 2),
|
||||
"cost_limit": limits.get('limits', {}).get('monthly_cost', 0) if limits else 0,
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
}
|
||||
return response_payload
|
||||
except Exception as retry_err:
|
||||
logger.error(f"Schema fix and retry failed: {retry_err}")
|
||||
raise HTTPException(status_code=500, detail=f"Database schema error: {str(e)}")
|
||||
|
||||
logger.error(f"Error getting dashboard data: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user