AI Video Generation Implementation
This commit is contained in:
@@ -6,7 +6,7 @@ content generation, and full story creation.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from typing import Any, Dict, Union, List
|
||||
from typing import Any, Dict, Union, List, Optional
|
||||
from loguru import logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
@@ -37,6 +37,16 @@ from models.story_models import (
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
from .task_manager import task_manager
|
||||
from .cache_manager import cache_manager
|
||||
from uuid import uuid4
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
|
||||
from .utils.auth import require_authenticated_user
|
||||
from .utils.media_utils import resolve_media_file
|
||||
from .utils.hd_video import (
|
||||
generate_hd_video_payload,
|
||||
generate_hd_video_scene_payload,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
|
||||
@@ -503,11 +513,11 @@ async def get_task_status(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/result", response_model=StoryFullGenerationResponse)
|
||||
@router.get("/task/{task_id}/result")
|
||||
async def get_task_result(
|
||||
task_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> StoryFullGenerationResponse:
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the result of a completed story generation task."""
|
||||
try:
|
||||
if not current_user:
|
||||
@@ -528,7 +538,19 @@ async def get_task_result(
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")
|
||||
|
||||
return StoryFullGenerationResponse(**result, success=True, task_id=task_id)
|
||||
# Some tasks return a full-story payload compatible with StoryFullGenerationResponse,
|
||||
# others (e.g., video-only) return a dict like {"video": {...}, "success": True}.
|
||||
# To avoid model conflicts, return a generic payload and include task_id.
|
||||
# Frontend callers can branch on keys present (e.g., "video").
|
||||
if isinstance(result, dict):
|
||||
# Ensure success flag present without duplicating
|
||||
payload = {**result}
|
||||
payload.setdefault("success", True)
|
||||
payload["task_id"] = task_id
|
||||
return payload
|
||||
|
||||
# Fallback: wrap non-dict results
|
||||
return {"result": result, "success": True, "task_id": task_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -536,6 +558,69 @@ async def get_task_result(
|
||||
logger.error(f"[StoryWriter] Failed to get task result: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
class HDVideoRequest(BaseModel):
|
||||
prompt: str
|
||||
provider: str = "huggingface"
|
||||
model: str | None = None
|
||||
num_frames: int | None = None
|
||||
guidance_scale: float | None = None
|
||||
num_inference_steps: int | None = None
|
||||
negative_prompt: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
@router.post("/hd-video")
|
||||
async def generate_hd_video(
|
||||
request: HDVideoRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate an HD AI animation using provider text-to-video (Hugging Face for now).
|
||||
Saves the returned bytes as a video file and returns the secured URL.
|
||||
"""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
return generate_hd_video_payload(request, user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to generate HD video: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
class HDVideoSceneRequest(BaseModel):
|
||||
scene_number: int
|
||||
scene_data: Dict[str, Any]
|
||||
story_context: Dict[str, Any]
|
||||
all_scenes: List[Dict[str, Any]]
|
||||
scene_image_url: Optional[str] = None
|
||||
provider: str = "huggingface"
|
||||
model: str | None = None
|
||||
num_frames: int | None = None
|
||||
guidance_scale: float | None = None
|
||||
num_inference_steps: int | None = None
|
||||
negative_prompt: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
|
||||
@router.post("/hd-video-scene")
|
||||
async def generate_hd_video_scene(
|
||||
request: HDVideoSceneRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate HD AI video for a single scene with AI-enhanced prompt.
|
||||
Uses prompt enhancer to create HunyuanVideo-optimized prompt from story context.
|
||||
"""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
return generate_hd_video_scene_payload(request, user_id)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to generate HD video for scene: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Image Generation Endpoints
|
||||
@@ -614,31 +699,20 @@ async def serve_scene_image(
|
||||
):
|
||||
"""Serve a generated story scene image."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import image generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
image_service = StoryImageGenerationService()
|
||||
image_path = image_service.output_dir / image_filename
|
||||
|
||||
if not image_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Image not found: {image_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
image_path.resolve().relative_to(image_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
image_path = resolve_media_file(image_service.output_dir, image_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
media_type="image/png",
|
||||
filename=image_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -726,31 +800,20 @@ async def serve_scene_audio(
|
||||
):
|
||||
"""Serve a generated story scene audio file."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import audio generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
audio_service = StoryAudioGenerationService()
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
|
||||
if not audio_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Audio not found: {audio_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
audio_path.resolve().relative_to(audio_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(audio_path),
|
||||
media_type="audio/mpeg",
|
||||
filename=audio_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -869,6 +932,99 @@ async def generate_story_video(
|
||||
logger.error(f"[StoryWriter] Failed to generate video: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/generate-video-async", response_model=Dict[str, Any])
|
||||
async def generate_story_video_async(
|
||||
request: StoryVideoGenerationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a video asynchronously with progress updates via task manager.
|
||||
Frontend can poll /api/story/task/{task_id}/status to show progress messages.
|
||||
"""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
if not request.scenes or len(request.scenes) == 0:
|
||||
raise HTTPException(status_code=400, detail="At least one scene is required")
|
||||
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
|
||||
raise HTTPException(status_code=400, detail="Number of scenes, image URLs, and audio URLs must match")
|
||||
|
||||
task_id = task_manager.create_task("story_video_generation")
|
||||
background_tasks.add_task(
|
||||
_execute_video_generation_task,
|
||||
task_id=task_id,
|
||||
request=request,
|
||||
user_id=user_id
|
||||
)
|
||||
return {"task_id": task_id, "status": "pending", "message": "Video generation started"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to start async video generation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
|
||||
"""Background task to generate story video with progress mapped to task manager."""
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
try:
|
||||
task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
|
||||
video_service = StoryVideoGenerationService()
|
||||
image_service = StoryImageGenerationService()
|
||||
audio_service = StoryAudioGenerationService()
|
||||
|
||||
# Prepare assets
|
||||
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
|
||||
image_paths, audio_paths, valid_scenes = [], [], []
|
||||
for idx, (scene, image_url, audio_url) in enumerate(zip(scenes_data, request.image_urls, request.audio_urls)):
|
||||
image_filename = (image_url.split('/')[-1] if '/' in image_url else image_url).split('?')[0]
|
||||
audio_filename = (audio_url.split('/')[-1] if '/' in audio_url else audio_url).split('?')[0]
|
||||
image_path = image_service.output_dir / image_filename
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
if not image_path.exists():
|
||||
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
|
||||
continue
|
||||
if not audio_path.exists():
|
||||
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
|
||||
continue
|
||||
image_paths.append(str(image_path))
|
||||
audio_paths.append(str(audio_path))
|
||||
valid_scenes.append(scene)
|
||||
|
||||
if not image_paths or not audio_paths or len(image_paths) != len(audio_paths):
|
||||
raise RuntimeError("No valid or mismatched image/audio assets for video generation.")
|
||||
|
||||
# Map service progress (0-100) to task progress (5-95)
|
||||
def progress_callback(sub_progress: float, msg: str):
|
||||
overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
|
||||
task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)
|
||||
|
||||
result = video_service.generate_story_video(
|
||||
scenes=valid_scenes,
|
||||
image_paths=image_paths,
|
||||
audio_paths=audio_paths,
|
||||
user_id=user_id,
|
||||
story_title=request.story_title or "Story",
|
||||
fps=request.fps or 24,
|
||||
transition_duration=request.transition_duration or 0.5,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="Video generation complete!",
|
||||
result={"video": result, "success": True}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Async video generation failed: {e}", exc_info=True)
|
||||
task_manager.update_task_status(task_id, "failed", error=str(e), message=f"Video generation failed: {e}")
|
||||
|
||||
@router.post("/generate-complete-video", response_model=Dict[str, Any])
|
||||
async def generate_complete_story_video(
|
||||
@@ -1111,31 +1267,20 @@ async def serve_story_video(
|
||||
):
|
||||
"""Serve a generated story video file."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import video generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
video_path = video_service.output_dir / video_filename
|
||||
|
||||
if not video_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Video not found: {video_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
video_path.resolve().relative_to(video_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
video_path = resolve_media_file(video_service.output_dir, video_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
8
backend/api/story_writer/utils/__init__.py
Normal file
8
backend/api/story_writer/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Utility helpers for Story Writer API routes.
|
||||
|
||||
Grouped here to keep the main router lean while reusing common logic
|
||||
such as authentication guards, media resolution, and HD video helpers.
|
||||
"""
|
||||
|
||||
|
||||
23
backend/api/story_writer/utils/auth.py
Normal file
23
backend/api/story_writer/utils/auth.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
|
||||
"""
|
||||
Validates the current user dictionary provided by Clerk middleware and
|
||||
returns the normalized user_id. Raises HTTP 401 if authentication fails.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", "")).strip()
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid user ID in authentication token",
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
|
||||
154
backend/api/story_writer/utils/hd_video.py
Normal file
154
backend/api/story_writer/utils/hd_video.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from uuid import uuid4
|
||||
|
||||
from .media_utils import load_story_image_bytes
|
||||
|
||||
|
||||
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
|
||||
"""Handles synchronous HD video generation."""
|
||||
from services.llm_providers.main_video_generation import ai_video_generate
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
output_dir = video_service.output_dir
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if getattr(request, "model", None):
|
||||
kwargs["model"] = request.model
|
||||
if getattr(request, "num_frames", None):
|
||||
kwargs["num_frames"] = request.num_frames
|
||||
if getattr(request, "guidance_scale", None) is not None:
|
||||
kwargs["guidance_scale"] = request.guidance_scale
|
||||
if getattr(request, "num_inference_steps", None):
|
||||
kwargs["num_inference_steps"] = request.num_inference_steps
|
||||
if getattr(request, "negative_prompt", None):
|
||||
kwargs["negative_prompt"] = request.negative_prompt
|
||||
if getattr(request, "seed", None) is not None:
|
||||
kwargs["seed"] = request.seed
|
||||
|
||||
logger.info(f"[StoryWriter] Generating HD video via {getattr(request, 'provider', 'huggingface')} for user {user_id}")
|
||||
raw_bytes = ai_video_generate(
|
||||
prompt=request.prompt,
|
||||
provider=getattr(request, "provider", None) or "huggingface",
|
||||
user_id=user_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
filename = f"hd_{uuid4().hex}.mp4"
|
||||
file_path = output_dir / filename
|
||||
with open(file_path, "wb") as fh:
|
||||
fh.write(raw_bytes)
|
||||
|
||||
logger.info(f"[StoryWriter] HD video saved to {file_path}")
|
||||
return {
|
||||
"success": True,
|
||||
"video_filename": filename,
|
||||
"video_url": f"/api/story/videos/{filename}",
|
||||
"provider": getattr(request, "provider", None) or "huggingface",
|
||||
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
|
||||
}
|
||||
|
||||
|
||||
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Handles per-scene HD video generation including prompt enhancement,
|
||||
subscription validation, and optional image conditioning.
|
||||
"""
|
||||
from services.database import get_db as get_db_validation
|
||||
from services.onboarding.api_key_manager import APIKeyManager
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_video_generation_operations
|
||||
from services.story_writer.prompt_enhancer_service import enhance_scene_prompt_for_video
|
||||
from services.llm_providers.main_video_generation import ai_video_generate
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
|
||||
scene_number = request.scene_number
|
||||
logger.info(f"[StoryWriter] Generating HD video for scene {scene_number} for user {user_id}")
|
||||
|
||||
# Step 1: Validate API key
|
||||
hf_token = APIKeyManager().get_api_key("hf_token")
|
||||
if not hf_token:
|
||||
logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Hugging Face API token is not configured. Please configure your HF token in settings.",
|
||||
"message": "Hugging Face API token is not configured. Please configure your HF token in settings.",
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: Subscription limits
|
||||
db_validation = next(get_db_validation())
|
||||
try:
|
||||
pricing_service = PricingService(db_validation)
|
||||
logger.info(f"[StoryWriter] Pre-flight: Checking video generation limits for user {user_id}...")
|
||||
validate_video_generation_operations(pricing_service=pricing_service, user_id=user_id)
|
||||
logger.info("[StoryWriter] Pre-flight: ✅ Video generation limits validated - proceeding")
|
||||
finally:
|
||||
db_validation.close()
|
||||
|
||||
# Stage 1: Prompt enhancement
|
||||
enhanced_prompt = enhance_scene_prompt_for_video(
|
||||
current_scene=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
all_scenes=request.all_scenes,
|
||||
user_id=user_id,
|
||||
)
|
||||
logger.info(f"[StoryWriter] Generated enhanced prompt ({len(enhanced_prompt)} chars) for scene {scene_number}")
|
||||
|
||||
# Stage 2: Optional image reference
|
||||
scene_image_bytes: Optional[bytes] = None
|
||||
if getattr(request, "scene_image_url", None):
|
||||
scene_image_bytes = load_story_image_bytes(request.scene_image_url)
|
||||
if scene_image_bytes:
|
||||
logger.info(f"[StoryWriter] Using scene image reference for scene {scene_number}")
|
||||
else:
|
||||
logger.warning(f"[StoryWriter] Scene image could not be loaded for scene {scene_number}, falling back to text-only video")
|
||||
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if getattr(request, "model", None):
|
||||
kwargs["model"] = request.model
|
||||
if getattr(request, "num_frames", None):
|
||||
kwargs["num_frames"] = request.num_frames
|
||||
if getattr(request, "guidance_scale", None) is not None:
|
||||
kwargs["guidance_scale"] = request.guidance_scale
|
||||
if getattr(request, "num_inference_steps", None):
|
||||
kwargs["num_inference_steps"] = request.num_inference_steps
|
||||
if getattr(request, "negative_prompt", None):
|
||||
kwargs["negative_prompt"] = request.negative_prompt
|
||||
if getattr(request, "seed", None) is not None:
|
||||
kwargs["seed"] = request.seed
|
||||
|
||||
raw_bytes = ai_video_generate(
|
||||
prompt=enhanced_prompt,
|
||||
provider=getattr(request, "provider", None) or "huggingface",
|
||||
user_id=user_id,
|
||||
input_image_bytes=scene_image_bytes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=raw_bytes,
|
||||
scene_number=scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] HD video saved for scene {scene_number}: {save_result.get('video_filename')}")
|
||||
return {
|
||||
"success": True,
|
||||
"scene_number": scene_number,
|
||||
"video_filename": save_result.get("video_filename"),
|
||||
"video_url": save_result.get("video_url"),
|
||||
"prompt_used": enhanced_prompt,
|
||||
"provider": getattr(request, "provider", None) or "huggingface",
|
||||
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
|
||||
}
|
||||
|
||||
|
||||
69
backend/api/story_writer/utils/media_utils.py
Normal file
69
backend/api/story_writer/utils/media_utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from loguru import logger
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
|
||||
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
|
||||
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
|
||||
"""
|
||||
Resolve an authenticated story image URL (e.g., /api/story/images/<file>) to raw bytes.
|
||||
Returns None if the file cannot be located.
|
||||
"""
|
||||
if not image_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = urlparse(image_url)
|
||||
path = parsed.path if parsed.scheme else image_url
|
||||
prefix = "/api/story/images/"
|
||||
if prefix not in path:
|
||||
logger.warning(f"[StoryWriter] Unsupported image URL for video reference: {image_url}")
|
||||
return None
|
||||
|
||||
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
file_path = (STORY_IMAGES_DIR / filename).resolve()
|
||||
if not str(file_path).startswith(str(STORY_IMAGES_DIR)):
|
||||
logger.error(f"[StoryWriter] Attempted path traversal when resolving image: {image_url}")
|
||||
return None
|
||||
|
||||
if not file_path.exists():
|
||||
logger.warning(f"[StoryWriter] Referenced scene image not found on disk: {file_path}")
|
||||
return None
|
||||
|
||||
return file_path.read_bytes()
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to load reference image for video gen: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_media_file(base_dir: Path, filename: str) -> Path:
|
||||
"""
|
||||
Returns a safe resolved path for a media file stored under base_dir.
|
||||
Guards against directory traversal and ensures the file exists.
|
||||
"""
|
||||
filename = filename.split("?")[0].strip()
|
||||
resolved = (base_dir / filename).resolve()
|
||||
|
||||
try:
|
||||
resolved.relative_to(base_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
|
||||
if not resolved.exists():
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
Reference in New Issue
Block a user