AI Video Generation Implementation
This commit is contained in:
@@ -6,7 +6,7 @@ content generation, and full story creation.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from typing import Any, Dict, Union, List
|
||||
from typing import Any, Dict, Union, List, Optional
|
||||
from loguru import logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
@@ -37,6 +37,16 @@ from models.story_models import (
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
from .task_manager import task_manager
|
||||
from .cache_manager import cache_manager
|
||||
from uuid import uuid4
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
|
||||
from .utils.auth import require_authenticated_user
|
||||
from .utils.media_utils import resolve_media_file
|
||||
from .utils.hd_video import (
|
||||
generate_hd_video_payload,
|
||||
generate_hd_video_scene_payload,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
|
||||
@@ -503,11 +513,11 @@ async def get_task_status(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/result", response_model=StoryFullGenerationResponse)
|
||||
@router.get("/task/{task_id}/result")
|
||||
async def get_task_result(
|
||||
task_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> StoryFullGenerationResponse:
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the result of a completed story generation task."""
|
||||
try:
|
||||
if not current_user:
|
||||
@@ -528,7 +538,19 @@ async def get_task_result(
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")
|
||||
|
||||
return StoryFullGenerationResponse(**result, success=True, task_id=task_id)
|
||||
# Some tasks return a full-story payload compatible with StoryFullGenerationResponse,
|
||||
# others (e.g., video-only) return a dict like {"video": {...}, "success": True}.
|
||||
# To avoid model conflicts, return a generic payload and include task_id.
|
||||
# Frontend callers can branch on keys present (e.g., "video").
|
||||
if isinstance(result, dict):
|
||||
# Ensure success flag present without duplicating
|
||||
payload = {**result}
|
||||
payload.setdefault("success", True)
|
||||
payload["task_id"] = task_id
|
||||
return payload
|
||||
|
||||
# Fallback: wrap non-dict results
|
||||
return {"result": result, "success": True, "task_id": task_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -536,6 +558,69 @@ async def get_task_result(
|
||||
logger.error(f"[StoryWriter] Failed to get task result: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
class HDVideoRequest(BaseModel):
|
||||
prompt: str
|
||||
provider: str = "huggingface"
|
||||
model: str | None = None
|
||||
num_frames: int | None = None
|
||||
guidance_scale: float | None = None
|
||||
num_inference_steps: int | None = None
|
||||
negative_prompt: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
@router.post("/hd-video")
|
||||
async def generate_hd_video(
|
||||
request: HDVideoRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate an HD AI animation using provider text-to-video (Hugging Face for now).
|
||||
Saves the returned bytes as a video file and returns the secured URL.
|
||||
"""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
return generate_hd_video_payload(request, user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to generate HD video: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
class HDVideoSceneRequest(BaseModel):
|
||||
scene_number: int
|
||||
scene_data: Dict[str, Any]
|
||||
story_context: Dict[str, Any]
|
||||
all_scenes: List[Dict[str, Any]]
|
||||
scene_image_url: Optional[str] = None
|
||||
provider: str = "huggingface"
|
||||
model: str | None = None
|
||||
num_frames: int | None = None
|
||||
guidance_scale: float | None = None
|
||||
num_inference_steps: int | None = None
|
||||
negative_prompt: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
|
||||
@router.post("/hd-video-scene")
|
||||
async def generate_hd_video_scene(
|
||||
request: HDVideoSceneRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate HD AI video for a single scene with AI-enhanced prompt.
|
||||
Uses prompt enhancer to create HunyuanVideo-optimized prompt from story context.
|
||||
"""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
return generate_hd_video_scene_payload(request, user_id)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to generate HD video for scene: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Image Generation Endpoints
|
||||
@@ -614,31 +699,20 @@ async def serve_scene_image(
|
||||
):
|
||||
"""Serve a generated story scene image."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import image generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
image_service = StoryImageGenerationService()
|
||||
image_path = image_service.output_dir / image_filename
|
||||
|
||||
if not image_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Image not found: {image_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
image_path.resolve().relative_to(image_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
image_path = resolve_media_file(image_service.output_dir, image_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
media_type="image/png",
|
||||
filename=image_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -726,31 +800,20 @@ async def serve_scene_audio(
|
||||
):
|
||||
"""Serve a generated story scene audio file."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import audio generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
audio_service = StoryAudioGenerationService()
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
|
||||
if not audio_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Audio not found: {audio_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
audio_path.resolve().relative_to(audio_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(audio_path),
|
||||
media_type="audio/mpeg",
|
||||
filename=audio_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -869,6 +932,99 @@ async def generate_story_video(
|
||||
logger.error(f"[StoryWriter] Failed to generate video: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/generate-video-async", response_model=Dict[str, Any])
|
||||
async def generate_story_video_async(
|
||||
request: StoryVideoGenerationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a video asynchronously with progress updates via task manager.
|
||||
Frontend can poll /api/story/task/{task_id}/status to show progress messages.
|
||||
"""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
if not request.scenes or len(request.scenes) == 0:
|
||||
raise HTTPException(status_code=400, detail="At least one scene is required")
|
||||
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
|
||||
raise HTTPException(status_code=400, detail="Number of scenes, image URLs, and audio URLs must match")
|
||||
|
||||
task_id = task_manager.create_task("story_video_generation")
|
||||
background_tasks.add_task(
|
||||
_execute_video_generation_task,
|
||||
task_id=task_id,
|
||||
request=request,
|
||||
user_id=user_id
|
||||
)
|
||||
return {"task_id": task_id, "status": "pending", "message": "Video generation started"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to start async video generation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
|
||||
"""Background task to generate story video with progress mapped to task manager."""
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
try:
|
||||
task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
|
||||
video_service = StoryVideoGenerationService()
|
||||
image_service = StoryImageGenerationService()
|
||||
audio_service = StoryAudioGenerationService()
|
||||
|
||||
# Prepare assets
|
||||
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
|
||||
image_paths, audio_paths, valid_scenes = [], [], []
|
||||
for idx, (scene, image_url, audio_url) in enumerate(zip(scenes_data, request.image_urls, request.audio_urls)):
|
||||
image_filename = (image_url.split('/')[-1] if '/' in image_url else image_url).split('?')[0]
|
||||
audio_filename = (audio_url.split('/')[-1] if '/' in audio_url else audio_url).split('?')[0]
|
||||
image_path = image_service.output_dir / image_filename
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
if not image_path.exists():
|
||||
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
|
||||
continue
|
||||
if not audio_path.exists():
|
||||
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
|
||||
continue
|
||||
image_paths.append(str(image_path))
|
||||
audio_paths.append(str(audio_path))
|
||||
valid_scenes.append(scene)
|
||||
|
||||
if not image_paths or not audio_paths or len(image_paths) != len(audio_paths):
|
||||
raise RuntimeError("No valid or mismatched image/audio assets for video generation.")
|
||||
|
||||
# Map service progress (0-100) to task progress (5-95)
|
||||
def progress_callback(sub_progress: float, msg: str):
|
||||
overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
|
||||
task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)
|
||||
|
||||
result = video_service.generate_story_video(
|
||||
scenes=valid_scenes,
|
||||
image_paths=image_paths,
|
||||
audio_paths=audio_paths,
|
||||
user_id=user_id,
|
||||
story_title=request.story_title or "Story",
|
||||
fps=request.fps or 24,
|
||||
transition_duration=request.transition_duration or 0.5,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="Video generation complete!",
|
||||
result={"video": result, "success": True}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Async video generation failed: {e}", exc_info=True)
|
||||
task_manager.update_task_status(task_id, "failed", error=str(e), message=f"Video generation failed: {e}")
|
||||
|
||||
@router.post("/generate-complete-video", response_model=Dict[str, Any])
|
||||
async def generate_complete_story_video(
|
||||
@@ -1111,31 +1267,20 @@ async def serve_story_video(
|
||||
):
|
||||
"""Serve a generated story video file."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import video generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
video_path = video_service.output_dir / video_filename
|
||||
|
||||
if not video_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Video not found: {video_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
video_path.resolve().relative_to(video_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
video_path = resolve_media_file(video_service.output_dir, video_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user