AI Video Generation Implementation

This commit is contained in:
ajaysi
2025-11-17 17:38:23 +05:30
parent 4901b7eb72
commit bf7493c366
132 changed files with 6200 additions and 19475 deletions

View File

@@ -6,7 +6,7 @@ content generation, and full story creation.
"""
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from typing import Any, Dict, Union, List
from typing import Any, Dict, Union, List, Optional
from loguru import logger
from middleware.auth_middleware import get_current_user
@@ -37,6 +37,16 @@ from models.story_models import (
from services.story_writer.story_service import StoryWriterService
from .task_manager import task_manager
from .cache_manager import cache_manager
from uuid import uuid4
from pydantic import BaseModel
from pathlib import Path
from .utils.auth import require_authenticated_user
from .utils.media_utils import resolve_media_file
from .utils.hd_video import (
generate_hd_video_payload,
generate_hd_video_scene_payload,
)
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
@@ -503,11 +513,11 @@ async def get_task_status(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/task/{task_id}/result", response_model=StoryFullGenerationResponse)
@router.get("/task/{task_id}/result")
async def get_task_result(
task_id: str,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> StoryFullGenerationResponse:
) -> Dict[str, Any]:
"""Get the result of a completed story generation task."""
try:
if not current_user:
@@ -528,7 +538,19 @@ async def get_task_result(
if not result:
raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")
return StoryFullGenerationResponse(**result, success=True, task_id=task_id)
# Some tasks return a full-story payload compatible with StoryFullGenerationResponse,
# others (e.g., video-only) return a dict like {"video": {...}, "success": True}.
# To avoid model conflicts, return a generic payload and include task_id.
# Frontend callers can branch on keys present (e.g., "video").
if isinstance(result, dict):
# Ensure success flag present without duplicating
payload = {**result}
payload.setdefault("success", True)
payload["task_id"] = task_id
return payload
# Fallback: wrap non-dict results
return {"result": result, "success": True, "task_id": task_id}
except HTTPException:
raise
@@ -536,6 +558,69 @@ async def get_task_result(
logger.error(f"[StoryWriter] Failed to get task result: {e}")
raise HTTPException(status_code=500, detail=str(e))
class HDVideoRequest(BaseModel):
prompt: str
provider: str = "huggingface"
model: str | None = None
num_frames: int | None = None
guidance_scale: float | None = None
num_inference_steps: int | None = None
negative_prompt: str | None = None
seed: int | None = None
@router.post("/hd-video")
async def generate_hd_video(
request: HDVideoRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Generate an HD AI animation using provider text-to-video (Hugging Face for now).
Saves the returned bytes as a video file and returns the secured URL.
"""
try:
user_id = require_authenticated_user(current_user)
return generate_hd_video_payload(request, user_id)
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to generate HD video: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
class HDVideoSceneRequest(BaseModel):
scene_number: int
scene_data: Dict[str, Any]
story_context: Dict[str, Any]
all_scenes: List[Dict[str, Any]]
scene_image_url: Optional[str] = None
provider: str = "huggingface"
model: str | None = None
num_frames: int | None = None
guidance_scale: float | None = None
num_inference_steps: int | None = None
negative_prompt: str | None = None
seed: int | None = None
@router.post("/hd-video-scene")
async def generate_hd_video_scene(
request: HDVideoSceneRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Generate HD AI video for a single scene with AI-enhanced prompt.
Uses prompt enhancer to create HunyuanVideo-optimized prompt from story context.
"""
try:
user_id = require_authenticated_user(current_user)
return generate_hd_video_scene_payload(request, user_id)
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to generate HD video for scene: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# ---------------------------
# Image Generation Endpoints
@@ -614,31 +699,20 @@ async def serve_scene_image(
):
"""Serve a generated story scene image."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Import image generation service to get output directory
require_authenticated_user(current_user)
from services.story_writer.image_generation_service import StoryImageGenerationService
from fastapi.responses import FileResponse
image_service = StoryImageGenerationService()
image_path = image_service.output_dir / image_filename
if not image_path.exists():
raise HTTPException(status_code=404, detail=f"Image not found: {image_filename}")
# Validate that the file is within the output directory (security check)
try:
image_path.resolve().relative_to(image_service.output_dir.resolve())
except ValueError:
raise HTTPException(status_code=403, detail="Access denied")
image_path = resolve_media_file(image_service.output_dir, image_filename)
return FileResponse(
path=str(image_path),
media_type="image/png",
filename=image_filename
)
except HTTPException:
raise
except Exception as e:
@@ -726,31 +800,20 @@ async def serve_scene_audio(
):
"""Serve a generated story scene audio file."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Import audio generation service to get output directory
require_authenticated_user(current_user)
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from fastapi.responses import FileResponse
audio_service = StoryAudioGenerationService()
audio_path = audio_service.output_dir / audio_filename
if not audio_path.exists():
raise HTTPException(status_code=404, detail=f"Audio not found: {audio_filename}")
# Validate that the file is within the output directory (security check)
try:
audio_path.resolve().relative_to(audio_service.output_dir.resolve())
except ValueError:
raise HTTPException(status_code=403, detail="Access denied")
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
return FileResponse(
path=str(audio_path),
media_type="audio/mpeg",
filename=audio_filename
)
except HTTPException:
raise
except Exception as e:
@@ -869,6 +932,99 @@ async def generate_story_video(
logger.error(f"[StoryWriter] Failed to generate video: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/generate-video-async", response_model=Dict[str, Any])
async def generate_story_video_async(
request: StoryVideoGenerationRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Generate a video asynchronously with progress updates via task manager.
Frontend can poll /api/story/task/{task_id}/status to show progress messages.
"""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
if not request.scenes or len(request.scenes) == 0:
raise HTTPException(status_code=400, detail="At least one scene is required")
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
raise HTTPException(status_code=400, detail="Number of scenes, image URLs, and audio URLs must match")
task_id = task_manager.create_task("story_video_generation")
background_tasks.add_task(
_execute_video_generation_task,
task_id=task_id,
request=request,
user_id=user_id
)
return {"task_id": task_id, "status": "pending", "message": "Video generation started"}
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to start async video generation: {e}")
raise HTTPException(status_code=500, detail=str(e))
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
"""Background task to generate story video with progress mapped to task manager."""
from services.story_writer.video_generation_service import StoryVideoGenerationService
from services.story_writer.image_generation_service import StoryImageGenerationService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
try:
task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
video_service = StoryVideoGenerationService()
image_service = StoryImageGenerationService()
audio_service = StoryAudioGenerationService()
# Prepare assets
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
image_paths, audio_paths, valid_scenes = [], [], []
for idx, (scene, image_url, audio_url) in enumerate(zip(scenes_data, request.image_urls, request.audio_urls)):
image_filename = (image_url.split('/')[-1] if '/' in image_url else image_url).split('?')[0]
audio_filename = (audio_url.split('/')[-1] if '/' in audio_url else audio_url).split('?')[0]
image_path = image_service.output_dir / image_filename
audio_path = audio_service.output_dir / audio_filename
if not image_path.exists():
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
continue
if not audio_path.exists():
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
continue
image_paths.append(str(image_path))
audio_paths.append(str(audio_path))
valid_scenes.append(scene)
if not image_paths or not audio_paths or len(image_paths) != len(audio_paths):
raise RuntimeError("No valid or mismatched image/audio assets for video generation.")
# Map service progress (0-100) to task progress (5-95)
def progress_callback(sub_progress: float, msg: str):
overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)
result = video_service.generate_story_video(
scenes=valid_scenes,
image_paths=image_paths,
audio_paths=audio_paths,
user_id=user_id,
story_title=request.story_title or "Story",
fps=request.fps or 24,
transition_duration=request.transition_duration or 0.5,
progress_callback=progress_callback
)
task_manager.update_task_status(
task_id,
"completed",
progress=100.0,
message="Video generation complete!",
result={"video": result, "success": True}
)
except Exception as e:
logger.error(f"[StoryWriter] Async video generation failed: {e}", exc_info=True)
task_manager.update_task_status(task_id, "failed", error=str(e), message=f"Video generation failed: {e}")
@router.post("/generate-complete-video", response_model=Dict[str, Any])
async def generate_complete_story_video(
@@ -1111,31 +1267,20 @@ async def serve_story_video(
):
"""Serve a generated story video file."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Import video generation service to get output directory
require_authenticated_user(current_user)
from services.story_writer.video_generation_service import StoryVideoGenerationService
from fastapi.responses import FileResponse
video_service = StoryVideoGenerationService()
video_path = video_service.output_dir / video_filename
if not video_path.exists():
raise HTTPException(status_code=404, detail=f"Video not found: {video_filename}")
# Validate that the file is within the output directory (security check)
try:
video_path.resolve().relative_to(video_service.output_dir.resolve())
except ValueError:
raise HTTPException(status_code=403, detail="Access denied")
video_path = resolve_media_file(video_service.output_dir, video_filename)
return FileResponse(
path=str(video_path),
media_type="video/mp4",
filename=video_filename
)
except HTTPException:
raise
except Exception as e:

View File

@@ -0,0 +1,8 @@
"""
Utility helpers for Story Writer API routes.
Grouped here to keep the main router lean while reusing common logic
such as authentication guards, media resolution, and HD video helpers.
"""

View File

@@ -0,0 +1,23 @@
from typing import Any, Dict
from fastapi import HTTPException, status
def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
"""
Validates the current user dictionary provided by Clerk middleware and
returns the normalized user_id. Raises HTTP 401 if authentication fails.
"""
if not current_user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
user_id = str(current_user.get("id", "")).strip()
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid user ID in authentication token",
)
return user_id

View File

@@ -0,0 +1,154 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from fastapi import HTTPException
from loguru import logger
from uuid import uuid4
from .media_utils import load_story_image_bytes
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""Handles synchronous HD video generation."""
from services.llm_providers.main_video_generation import ai_video_generate
from services.story_writer.video_generation_service import StoryVideoGenerationService
video_service = StoryVideoGenerationService()
output_dir = video_service.output_dir
output_dir.mkdir(parents=True, exist_ok=True)
kwargs: Dict[str, Any] = {}
if getattr(request, "model", None):
kwargs["model"] = request.model
if getattr(request, "num_frames", None):
kwargs["num_frames"] = request.num_frames
if getattr(request, "guidance_scale", None) is not None:
kwargs["guidance_scale"] = request.guidance_scale
if getattr(request, "num_inference_steps", None):
kwargs["num_inference_steps"] = request.num_inference_steps
if getattr(request, "negative_prompt", None):
kwargs["negative_prompt"] = request.negative_prompt
if getattr(request, "seed", None) is not None:
kwargs["seed"] = request.seed
logger.info(f"[StoryWriter] Generating HD video via {getattr(request, 'provider', 'huggingface')} for user {user_id}")
raw_bytes = ai_video_generate(
prompt=request.prompt,
provider=getattr(request, "provider", None) or "huggingface",
user_id=user_id,
**kwargs,
)
filename = f"hd_{uuid4().hex}.mp4"
file_path = output_dir / filename
with open(file_path, "wb") as fh:
fh.write(raw_bytes)
logger.info(f"[StoryWriter] HD video saved to {file_path}")
return {
"success": True,
"video_filename": filename,
"video_url": f"/api/story/videos/{filename}",
"provider": getattr(request, "provider", None) or "huggingface",
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
}
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""
Handles per-scene HD video generation including prompt enhancement,
subscription validation, and optional image conditioning.
"""
from services.database import get_db as get_db_validation
from services.onboarding.api_key_manager import APIKeyManager
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_video_generation_operations
from services.story_writer.prompt_enhancer_service import enhance_scene_prompt_for_video
from services.llm_providers.main_video_generation import ai_video_generate
from services.story_writer.video_generation_service import StoryVideoGenerationService
scene_number = request.scene_number
logger.info(f"[StoryWriter] Generating HD video for scene {scene_number} for user {user_id}")
# Step 1: Validate API key
hf_token = APIKeyManager().get_api_key("hf_token")
if not hf_token:
logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
raise HTTPException(
status_code=400,
detail={
"error": "Hugging Face API token is not configured. Please configure your HF token in settings.",
"message": "Hugging Face API token is not configured. Please configure your HF token in settings.",
},
)
# Step 2: Subscription limits
db_validation = next(get_db_validation())
try:
pricing_service = PricingService(db_validation)
logger.info(f"[StoryWriter] Pre-flight: Checking video generation limits for user {user_id}...")
validate_video_generation_operations(pricing_service=pricing_service, user_id=user_id)
logger.info("[StoryWriter] Pre-flight: ✅ Video generation limits validated - proceeding")
finally:
db_validation.close()
# Stage 1: Prompt enhancement
enhanced_prompt = enhance_scene_prompt_for_video(
current_scene=request.scene_data,
story_context=request.story_context,
all_scenes=request.all_scenes,
user_id=user_id,
)
logger.info(f"[StoryWriter] Generated enhanced prompt ({len(enhanced_prompt)} chars) for scene {scene_number}")
# Stage 2: Optional image reference
scene_image_bytes: Optional[bytes] = None
if getattr(request, "scene_image_url", None):
scene_image_bytes = load_story_image_bytes(request.scene_image_url)
if scene_image_bytes:
logger.info(f"[StoryWriter] Using scene image reference for scene {scene_number}")
else:
logger.warning(f"[StoryWriter] Scene image could not be loaded for scene {scene_number}, falling back to text-only video")
kwargs: Dict[str, Any] = {}
if getattr(request, "model", None):
kwargs["model"] = request.model
if getattr(request, "num_frames", None):
kwargs["num_frames"] = request.num_frames
if getattr(request, "guidance_scale", None) is not None:
kwargs["guidance_scale"] = request.guidance_scale
if getattr(request, "num_inference_steps", None):
kwargs["num_inference_steps"] = request.num_inference_steps
if getattr(request, "negative_prompt", None):
kwargs["negative_prompt"] = request.negative_prompt
if getattr(request, "seed", None) is not None:
kwargs["seed"] = request.seed
raw_bytes = ai_video_generate(
prompt=enhanced_prompt,
provider=getattr(request, "provider", None) or "huggingface",
user_id=user_id,
input_image_bytes=scene_image_bytes,
**kwargs,
)
video_service = StoryVideoGenerationService()
save_result = video_service.save_scene_video(
video_bytes=raw_bytes,
scene_number=scene_number,
user_id=user_id,
)
logger.info(f"[StoryWriter] HD video saved for scene {scene_number}: {save_result.get('video_filename')}")
return {
"success": True,
"scene_number": scene_number,
"video_filename": save_result.get("video_filename"),
"video_url": save_result.get("video_url"),
"prompt_used": enhanced_prompt,
"provider": getattr(request, "provider", None) or "huggingface",
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
}

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from fastapi import HTTPException, status
from loguru import logger
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
"""
Resolve an authenticated story image URL (e.g., /api/story/images/<file>) to raw bytes.
Returns None if the file cannot be located.
"""
if not image_url:
return None
try:
parsed = urlparse(image_url)
path = parsed.path if parsed.scheme else image_url
prefix = "/api/story/images/"
if prefix not in path:
logger.warning(f"[StoryWriter] Unsupported image URL for video reference: {image_url}")
return None
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
if not filename:
return None
file_path = (STORY_IMAGES_DIR / filename).resolve()
if not str(file_path).startswith(str(STORY_IMAGES_DIR)):
logger.error(f"[StoryWriter] Attempted path traversal when resolving image: {image_url}")
return None
if not file_path.exists():
logger.warning(f"[StoryWriter] Referenced scene image not found on disk: {file_path}")
return None
return file_path.read_bytes()
except Exception as exc:
logger.error(f"[StoryWriter] Failed to load reference image for video gen: {exc}")
return None
def resolve_media_file(base_dir: Path, filename: str) -> Path:
"""
Returns a safe resolved path for a media file stored under base_dir.
Guards against directory traversal and ensures the file exists.
"""
filename = filename.split("?")[0].strip()
resolved = (base_dir / filename).resolve()
try:
resolved.relative_to(base_dir.resolve())
except ValueError:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
if not resolved.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")
return resolved