AI Video Generation Implementation

This commit is contained in:
ajaysi
2025-11-17 17:38:23 +05:30
parent 4901b7eb72
commit bf7493c366
132 changed files with 6200 additions and 19475 deletions

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field
from services.llm_providers.main_image_generation import generate_image
from services.llm_providers.main_image_editing import edit_image
from services.llm_providers.main_text_generation import llm_text_gen
from utils.logger_utils import get_service_logger
from middleware.auth_middleware import get_current_user
@@ -125,6 +126,14 @@ def generate(
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
@@ -137,6 +146,8 @@ def generate(
├─ Actual Provider: {result.provider}
├─ Model: {result.model or 'default'}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
@@ -195,6 +206,26 @@ class ImagePromptSuggestResponse(BaseModel):
suggestions: list[PromptSuggestion]
class ImageEditRequest(BaseModel):
image_base64: str
prompt: str
provider: Optional[str] = Field(None, pattern="^(huggingface)$")
model: Optional[str] = None
guidance_scale: Optional[float] = None
steps: Optional[int] = None
seed: Optional[int] = None
class ImageEditResponse(BaseModel):
success: bool = True
image_base64: str
width: int
height: int
provider: str
model: Optional[str] = None
seed: Optional[int] = None
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
def suggest_prompts(
req: ImagePromptSuggestRequest,
@@ -316,3 +347,136 @@ def suggest_prompts(
logger.error(f"Prompt suggestion failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/edit", response_model=ImageEditResponse)
def edit(
req: ImageEditRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImageEditResponse:
"""Edit image with subscription checking."""
try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Decode base64 image
try:
input_image_bytes = base64.b64decode(req.image_base64)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid image_base64: {str(e)}")
# Validation is now handled inside edit_image function
result = edit_image(
input_image_bytes=input_image_bytes,
prompt=req.prompt,
options={
"provider": req.provider,
"model": req.model,
"guidance_scale": req.guidance_scale,
"steps": req.steps,
"seed": req.seed,
},
user_id=user_id, # Pass user_id for validation inside edit_image
)
edited_image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
# TRACK USAGE after successful image editing
if result:
logger.info(f"[images.edit] ✅ Image editing successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
try:
# Get or create usage summary
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[images.edit] Looking for usage summary: user_id={user_id}, period={current_period}")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.info(f"[images.edit] Creating new usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
current_calls_before = getattr(summary, "image_edit_calls", 0) or 0
# Update image editing counters (separate from image generation)
new_calls = current_calls_before + 1
setattr(summary, "image_edit_calls", new_calls)
logger.debug(f"[images.edit] Updated image_edit_calls: {current_calls_before} -> {new_calls}")
# Update totals
old_total_calls = summary.total_calls or 0
summary.total_calls = old_total_calls + 1
logger.debug(f"[images.edit] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get image generation stats for unified log
current_image_gen_calls = getattr(summary, "stability_calls", 0) or 0
image_gen_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[images.edit] ✅ Successfully tracked usage: user {user_id} -> image_edit -> {new_calls} calls")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Image Editing
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: image_edit
├─ Actual Provider: {result.provider}
├─ Model: {result.model or 'default'}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Images: {current_image_gen_calls} / {image_gen_limit if image_gen_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[images.edit] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[images.edit] ❌ Failed to track usage: {usage_error}", exc_info=True)
return ImageEditResponse(
image_base64=edited_image_b64,
width=result.width,
height=result.height,
provider=result.provider,
model=result.model,
seed=result.seed,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Image editing failed: {e}", exc_info=True)
# Provide a clean, actionable message to the client
raise HTTPException(
status_code=500,
detail="Image editing service is temporarily unavailable or the connection was reset. Please try again."
)

View File

@@ -6,7 +6,7 @@ content generation, and full story creation.
"""
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from typing import Any, Dict, Union, List
from typing import Any, Dict, Union, List, Optional
from loguru import logger
from middleware.auth_middleware import get_current_user
@@ -37,6 +37,16 @@ from models.story_models import (
from services.story_writer.story_service import StoryWriterService
from .task_manager import task_manager
from .cache_manager import cache_manager
from uuid import uuid4
from pydantic import BaseModel
from pathlib import Path
from .utils.auth import require_authenticated_user
from .utils.media_utils import resolve_media_file
from .utils.hd_video import (
generate_hd_video_payload,
generate_hd_video_scene_payload,
)
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
@@ -503,11 +513,11 @@ async def get_task_status(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/task/{task_id}/result", response_model=StoryFullGenerationResponse)
@router.get("/task/{task_id}/result")
async def get_task_result(
task_id: str,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> StoryFullGenerationResponse:
) -> Dict[str, Any]:
"""Get the result of a completed story generation task."""
try:
if not current_user:
@@ -528,7 +538,19 @@ async def get_task_result(
if not result:
raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")
return StoryFullGenerationResponse(**result, success=True, task_id=task_id)
# Some tasks return a full-story payload compatible with StoryFullGenerationResponse,
# others (e.g., video-only) return a dict like {"video": {...}, "success": True}.
# To avoid model conflicts, return a generic payload and include task_id.
# Frontend callers can branch on keys present (e.g., "video").
if isinstance(result, dict):
# Ensure success flag present without duplicating
payload = {**result}
payload.setdefault("success", True)
payload["task_id"] = task_id
return payload
# Fallback: wrap non-dict results
return {"result": result, "success": True, "task_id": task_id}
except HTTPException:
raise
@@ -536,6 +558,69 @@ async def get_task_result(
logger.error(f"[StoryWriter] Failed to get task result: {e}")
raise HTTPException(status_code=500, detail=str(e))
class HDVideoRequest(BaseModel):
prompt: str
provider: str = "huggingface"
model: str | None = None
num_frames: int | None = None
guidance_scale: float | None = None
num_inference_steps: int | None = None
negative_prompt: str | None = None
seed: int | None = None
@router.post("/hd-video")
async def generate_hd_video(
request: HDVideoRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Generate an HD AI animation using provider text-to-video (Hugging Face for now).
Saves the returned bytes as a video file and returns the secured URL.
"""
try:
user_id = require_authenticated_user(current_user)
return generate_hd_video_payload(request, user_id)
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to generate HD video: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
class HDVideoSceneRequest(BaseModel):
scene_number: int
scene_data: Dict[str, Any]
story_context: Dict[str, Any]
all_scenes: List[Dict[str, Any]]
scene_image_url: Optional[str] = None
provider: str = "huggingface"
model: str | None = None
num_frames: int | None = None
guidance_scale: float | None = None
num_inference_steps: int | None = None
negative_prompt: str | None = None
seed: int | None = None
@router.post("/hd-video-scene")
async def generate_hd_video_scene(
request: HDVideoSceneRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Generate HD AI video for a single scene with AI-enhanced prompt.
Uses prompt enhancer to create HunyuanVideo-optimized prompt from story context.
"""
try:
user_id = require_authenticated_user(current_user)
return generate_hd_video_scene_payload(request, user_id)
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to generate HD video for scene: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# ---------------------------
# Image Generation Endpoints
@@ -614,31 +699,20 @@ async def serve_scene_image(
):
"""Serve a generated story scene image."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Import image generation service to get output directory
require_authenticated_user(current_user)
from services.story_writer.image_generation_service import StoryImageGenerationService
from fastapi.responses import FileResponse
image_service = StoryImageGenerationService()
image_path = image_service.output_dir / image_filename
if not image_path.exists():
raise HTTPException(status_code=404, detail=f"Image not found: {image_filename}")
# Validate that the file is within the output directory (security check)
try:
image_path.resolve().relative_to(image_service.output_dir.resolve())
except ValueError:
raise HTTPException(status_code=403, detail="Access denied")
image_path = resolve_media_file(image_service.output_dir, image_filename)
return FileResponse(
path=str(image_path),
media_type="image/png",
filename=image_filename
)
except HTTPException:
raise
except Exception as e:
@@ -726,31 +800,20 @@ async def serve_scene_audio(
):
"""Serve a generated story scene audio file."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Import audio generation service to get output directory
require_authenticated_user(current_user)
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from fastapi.responses import FileResponse
audio_service = StoryAudioGenerationService()
audio_path = audio_service.output_dir / audio_filename
if not audio_path.exists():
raise HTTPException(status_code=404, detail=f"Audio not found: {audio_filename}")
# Validate that the file is within the output directory (security check)
try:
audio_path.resolve().relative_to(audio_service.output_dir.resolve())
except ValueError:
raise HTTPException(status_code=403, detail="Access denied")
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
return FileResponse(
path=str(audio_path),
media_type="audio/mpeg",
filename=audio_filename
)
except HTTPException:
raise
except Exception as e:
@@ -869,6 +932,99 @@ async def generate_story_video(
logger.error(f"[StoryWriter] Failed to generate video: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/generate-video-async", response_model=Dict[str, Any])
async def generate_story_video_async(
request: StoryVideoGenerationRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Generate a video asynchronously with progress updates via task manager.
Frontend can poll /api/story/task/{task_id}/status to show progress messages.
"""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
if not request.scenes or len(request.scenes) == 0:
raise HTTPException(status_code=400, detail="At least one scene is required")
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
raise HTTPException(status_code=400, detail="Number of scenes, image URLs, and audio URLs must match")
task_id = task_manager.create_task("story_video_generation")
background_tasks.add_task(
_execute_video_generation_task,
task_id=task_id,
request=request,
user_id=user_id
)
return {"task_id": task_id, "status": "pending", "message": "Video generation started"}
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to start async video generation: {e}")
raise HTTPException(status_code=500, detail=str(e))
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
"""Background task to generate story video with progress mapped to task manager."""
from services.story_writer.video_generation_service import StoryVideoGenerationService
from services.story_writer.image_generation_service import StoryImageGenerationService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
try:
task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
video_service = StoryVideoGenerationService()
image_service = StoryImageGenerationService()
audio_service = StoryAudioGenerationService()
# Prepare assets
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
image_paths, audio_paths, valid_scenes = [], [], []
for idx, (scene, image_url, audio_url) in enumerate(zip(scenes_data, request.image_urls, request.audio_urls)):
image_filename = (image_url.split('/')[-1] if '/' in image_url else image_url).split('?')[0]
audio_filename = (audio_url.split('/')[-1] if '/' in audio_url else audio_url).split('?')[0]
image_path = image_service.output_dir / image_filename
audio_path = audio_service.output_dir / audio_filename
if not image_path.exists():
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
continue
if not audio_path.exists():
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
continue
image_paths.append(str(image_path))
audio_paths.append(str(audio_path))
valid_scenes.append(scene)
if not image_paths or not audio_paths or len(image_paths) != len(audio_paths):
raise RuntimeError("No valid or mismatched image/audio assets for video generation.")
# Map service progress (0-100) to task progress (5-95)
def progress_callback(sub_progress: float, msg: str):
overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)
result = video_service.generate_story_video(
scenes=valid_scenes,
image_paths=image_paths,
audio_paths=audio_paths,
user_id=user_id,
story_title=request.story_title or "Story",
fps=request.fps or 24,
transition_duration=request.transition_duration or 0.5,
progress_callback=progress_callback
)
task_manager.update_task_status(
task_id,
"completed",
progress=100.0,
message="Video generation complete!",
result={"video": result, "success": True}
)
except Exception as e:
logger.error(f"[StoryWriter] Async video generation failed: {e}", exc_info=True)
task_manager.update_task_status(task_id, "failed", error=str(e), message=f"Video generation failed: {e}")
@router.post("/generate-complete-video", response_model=Dict[str, Any])
async def generate_complete_story_video(
@@ -1111,31 +1267,20 @@ async def serve_story_video(
):
"""Serve a generated story video file."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Import video generation service to get output directory
require_authenticated_user(current_user)
from services.story_writer.video_generation_service import StoryVideoGenerationService
from fastapi.responses import FileResponse
video_service = StoryVideoGenerationService()
video_path = video_service.output_dir / video_filename
if not video_path.exists():
raise HTTPException(status_code=404, detail=f"Video not found: {video_filename}")
# Validate that the file is within the output directory (security check)
try:
video_path.resolve().relative_to(video_service.output_dir.resolve())
except ValueError:
raise HTTPException(status_code=403, detail="Access denied")
video_path = resolve_media_file(video_service.output_dir, video_filename)
return FileResponse(
path=str(video_path),
media_type="video/mp4",
filename=video_filename
)
except HTTPException:
raise
except Exception as e:

View File

@@ -0,0 +1,8 @@
"""
Utility helpers for Story Writer API routes.
Grouped here to keep the main router lean while reusing common logic
such as authentication guards, media resolution, and HD video helpers.
"""

View File

@@ -0,0 +1,23 @@
from typing import Any, Dict
from fastapi import HTTPException, status
def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
"""
Validates the current user dictionary provided by Clerk middleware and
returns the normalized user_id. Raises HTTP 401 if authentication fails.
"""
if not current_user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
user_id = str(current_user.get("id", "")).strip()
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid user ID in authentication token",
)
return user_id

View File

@@ -0,0 +1,154 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from fastapi import HTTPException
from loguru import logger
from uuid import uuid4
from .media_utils import load_story_image_bytes
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""Handles synchronous HD video generation."""
from services.llm_providers.main_video_generation import ai_video_generate
from services.story_writer.video_generation_service import StoryVideoGenerationService
video_service = StoryVideoGenerationService()
output_dir = video_service.output_dir
output_dir.mkdir(parents=True, exist_ok=True)
kwargs: Dict[str, Any] = {}
if getattr(request, "model", None):
kwargs["model"] = request.model
if getattr(request, "num_frames", None):
kwargs["num_frames"] = request.num_frames
if getattr(request, "guidance_scale", None) is not None:
kwargs["guidance_scale"] = request.guidance_scale
if getattr(request, "num_inference_steps", None):
kwargs["num_inference_steps"] = request.num_inference_steps
if getattr(request, "negative_prompt", None):
kwargs["negative_prompt"] = request.negative_prompt
if getattr(request, "seed", None) is not None:
kwargs["seed"] = request.seed
logger.info(f"[StoryWriter] Generating HD video via {getattr(request, 'provider', 'huggingface')} for user {user_id}")
raw_bytes = ai_video_generate(
prompt=request.prompt,
provider=getattr(request, "provider", None) or "huggingface",
user_id=user_id,
**kwargs,
)
filename = f"hd_{uuid4().hex}.mp4"
file_path = output_dir / filename
with open(file_path, "wb") as fh:
fh.write(raw_bytes)
logger.info(f"[StoryWriter] HD video saved to {file_path}")
return {
"success": True,
"video_filename": filename,
"video_url": f"/api/story/videos/{filename}",
"provider": getattr(request, "provider", None) or "huggingface",
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
}
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""
Handles per-scene HD video generation including prompt enhancement,
subscription validation, and optional image conditioning.
"""
from services.database import get_db as get_db_validation
from services.onboarding.api_key_manager import APIKeyManager
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_video_generation_operations
from services.story_writer.prompt_enhancer_service import enhance_scene_prompt_for_video
from services.llm_providers.main_video_generation import ai_video_generate
from services.story_writer.video_generation_service import StoryVideoGenerationService
scene_number = request.scene_number
logger.info(f"[StoryWriter] Generating HD video for scene {scene_number} for user {user_id}")
# Step 1: Validate API key
hf_token = APIKeyManager().get_api_key("hf_token")
if not hf_token:
logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
raise HTTPException(
status_code=400,
detail={
"error": "Hugging Face API token is not configured. Please configure your HF token in settings.",
"message": "Hugging Face API token is not configured. Please configure your HF token in settings.",
},
)
# Step 2: Subscription limits
db_validation = next(get_db_validation())
try:
pricing_service = PricingService(db_validation)
logger.info(f"[StoryWriter] Pre-flight: Checking video generation limits for user {user_id}...")
validate_video_generation_operations(pricing_service=pricing_service, user_id=user_id)
logger.info("[StoryWriter] Pre-flight: ✅ Video generation limits validated - proceeding")
finally:
db_validation.close()
# Stage 1: Prompt enhancement
enhanced_prompt = enhance_scene_prompt_for_video(
current_scene=request.scene_data,
story_context=request.story_context,
all_scenes=request.all_scenes,
user_id=user_id,
)
logger.info(f"[StoryWriter] Generated enhanced prompt ({len(enhanced_prompt)} chars) for scene {scene_number}")
# Stage 2: Optional image reference
scene_image_bytes: Optional[bytes] = None
if getattr(request, "scene_image_url", None):
scene_image_bytes = load_story_image_bytes(request.scene_image_url)
if scene_image_bytes:
logger.info(f"[StoryWriter] Using scene image reference for scene {scene_number}")
else:
logger.warning(f"[StoryWriter] Scene image could not be loaded for scene {scene_number}, falling back to text-only video")
kwargs: Dict[str, Any] = {}
if getattr(request, "model", None):
kwargs["model"] = request.model
if getattr(request, "num_frames", None):
kwargs["num_frames"] = request.num_frames
if getattr(request, "guidance_scale", None) is not None:
kwargs["guidance_scale"] = request.guidance_scale
if getattr(request, "num_inference_steps", None):
kwargs["num_inference_steps"] = request.num_inference_steps
if getattr(request, "negative_prompt", None):
kwargs["negative_prompt"] = request.negative_prompt
if getattr(request, "seed", None) is not None:
kwargs["seed"] = request.seed
raw_bytes = ai_video_generate(
prompt=enhanced_prompt,
provider=getattr(request, "provider", None) or "huggingface",
user_id=user_id,
input_image_bytes=scene_image_bytes,
**kwargs,
)
video_service = StoryVideoGenerationService()
save_result = video_service.save_scene_video(
video_bytes=raw_bytes,
scene_number=scene_number,
user_id=user_id,
)
logger.info(f"[StoryWriter] HD video saved for scene {scene_number}: {save_result.get('video_filename')}")
return {
"success": True,
"scene_number": scene_number,
"video_filename": save_result.get("video_filename"),
"video_url": save_result.get("video_url"),
"prompt_used": enhanced_prompt,
"provider": getattr(request, "provider", None) or "huggingface",
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
}

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from fastapi import HTTPException, status
from loguru import logger
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
"""
Resolve an authenticated story image URL (e.g., /api/story/images/<file>) to raw bytes.
Returns None if the file cannot be located.
"""
if not image_url:
return None
try:
parsed = urlparse(image_url)
path = parsed.path if parsed.scheme else image_url
prefix = "/api/story/images/"
if prefix not in path:
logger.warning(f"[StoryWriter] Unsupported image URL for video reference: {image_url}")
return None
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
if not filename:
return None
file_path = (STORY_IMAGES_DIR / filename).resolve()
if not str(file_path).startswith(str(STORY_IMAGES_DIR)):
logger.error(f"[StoryWriter] Attempted path traversal when resolving image: {image_url}")
return None
if not file_path.exists():
logger.warning(f"[StoryWriter] Referenced scene image not found on disk: {file_path}")
return None
return file_path.read_bytes()
except Exception as exc:
logger.error(f"[StoryWriter] Failed to load reference image for video gen: {exc}")
return None
def resolve_media_file(base_dir: Path, filename: str) -> Path:
"""
Returns a safe resolved path for a media file stored under base_dir.
Guards against directory traversal and ensures the file exists.
"""
filename = filename.split("?")[0].strip()
resolved = (base_dir / filename).resolve()
try:
resolved.relative_to(base_dir.resolve())
except ValueError:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
if not resolved.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")
return resolved

View File

@@ -14,7 +14,7 @@ from functools import lru_cache
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from services.subscription.log_wrapping_service import LogWrappingService
from services.subscription.schema_utils import ensure_subscription_plan_columns
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
import sqlite3
from middleware.auth_middleware import get_current_user
from models.subscription_models import (
@@ -114,6 +114,8 @@ async def get_subscription_plans(
"metaphor_calls": plan.metaphor_calls_limit,
"firecrawl_calls": plan.firecrawl_calls_limit,
"stability_calls": plan.stability_calls_limit,
"video_calls": getattr(plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
"gemini_tokens": plan.gemini_tokens_limit,
"openai_tokens": plan.openai_tokens_limit,
"anthropic_tokens": plan.anthropic_tokens_limit,
@@ -132,7 +134,7 @@ async def get_subscription_plans(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and 'exa_calls_limit' in error_str:
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
logger.warning("Missing column detected in subscription plans query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
@@ -237,6 +239,8 @@ async def get_user_subscription(
"metaphor_calls": free_plan.metaphor_calls_limit,
"firecrawl_calls": free_plan.firecrawl_calls_limit,
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
@@ -334,6 +338,8 @@ async def get_subscription_status(
"metaphor_calls": free_plan.metaphor_calls_limit,
"firecrawl_calls": free_plan.firecrawl_calls_limit,
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
@@ -399,15 +405,16 @@ async def get_subscription_status(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and 'exa_calls_limit' in error_str:
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
# Try to fix schema and retry once
logger.warning("Missing column detected in subscription status query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
schema_utils._checked_subscription_plan_columns = False
ensure_subscription_plan_columns(db)
db.commit() # Ensure schema changes are committed
db.expire_all()
# Retry the query
# Retry the query - query subscription without eager loading plan
subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
@@ -437,11 +444,21 @@ async def get_subscription_status(
"metaphor_calls": free_plan.metaphor_calls_limit,
"firecrawl_calls": free_plan.firecrawl_calls_limit,
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
}
elif subscription:
# Query plan separately after schema fix to avoid lazy loading issues
plan = db.query(SubscriptionPlan).filter(
SubscriptionPlan.id == subscription.plan_id
).first()
if not plan:
raise HTTPException(status_code=404, detail="Plan not found")
now = datetime.utcnow()
if subscription.current_period_end < now:
if getattr(subscription, 'auto_renew', False):
@@ -456,8 +473,8 @@ async def get_subscription_status(
"success": True,
"data": {
"active": False,
"plan": subscription.plan.tier.value,
"tier": subscription.plan.tier.value,
"plan": plan.tier.value,
"tier": plan.tier.value,
"can_use_api": False,
"reason": "Subscription expired"
}
@@ -466,21 +483,23 @@ async def get_subscription_status(
"success": True,
"data": {
"active": True,
"plan": subscription.plan.tier.value,
"tier": subscription.plan.tier.value,
"plan": plan.tier.value,
"tier": plan.tier.value,
"can_use_api": True,
"limits": {
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": subscription.plan.gemini_calls_limit,
"openai_calls": subscription.plan.openai_calls_limit,
"anthropic_calls": subscription.plan.anthropic_calls_limit,
"mistral_calls": subscription.plan.mistral_calls_limit,
"tavily_calls": subscription.plan.tavily_calls_limit,
"serper_calls": subscription.plan.serper_calls_limit,
"metaphor_calls": subscription.plan.metaphor_calls_limit,
"firecrawl_calls": subscription.plan.firecrawl_calls_limit,
"stability_calls": subscription.plan.stability_calls_limit,
"monthly_cost": subscription.plan.monthly_cost_limit
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": plan.gemini_calls_limit,
"openai_calls": plan.openai_calls_limit,
"anthropic_calls": plan.anthropic_calls_limit,
"mistral_calls": plan.mistral_calls_limit,
"tavily_calls": plan.tavily_calls_limit,
"serper_calls": plan.serper_calls_limit,
"metaphor_calls": plan.metaphor_calls_limit,
"firecrawl_calls": plan.firecrawl_calls_limit,
"stability_calls": plan.stability_calls_limit,
"video_calls": getattr(plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
"monthly_cost": plan.monthly_cost_limit
}
}
}
@@ -893,6 +912,7 @@ async def get_dashboard_data(
try:
ensure_subscription_plan_columns(db)
ensure_usage_summaries_columns(db)
# Serve from short TTL cache to avoid hammering DB on bursts
import time
now = time.time()
@@ -966,7 +986,72 @@ async def get_dashboard_data(
_dashboard_cache_ts[user_id] = now
return response_payload
except Exception as e:
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str):
logger.warning("Missing column detected in dashboard query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
schema_utils._checked_usage_summaries_columns = False
schema_utils._checked_subscription_plan_columns = False
ensure_usage_summaries_columns(db)
ensure_subscription_plan_columns(db)
db.expire_all()
# Retry the query
usage_service = UsageTrackingService(db)
pricing_service = PricingService(db)
current_usage = usage_service.get_user_usage_stats(user_id)
trends = usage_service.get_usage_trends(user_id, 6)
limits = pricing_service.get_user_limits(user_id)
alerts = db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.is_read == False
).order_by(UsageAlert.created_at.desc()).limit(5).all()
alerts_data = [
{
"id": alert.id,
"type": alert.alert_type,
"title": alert.title,
"message": alert.message,
"severity": alert.severity,
"created_at": alert.created_at.isoformat()
}
for alert in alerts
]
current_cost = current_usage.get('total_cost', 0)
days_in_period = 30
current_day = datetime.now().day
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
response_payload = {
"success": True,
"data": {
"current_usage": current_usage,
"trends": trends,
"limits": limits,
"alerts": alerts_data,
"projections": {
"projected_monthly_cost": round(projected_cost, 2),
"cost_limit": limits.get('limits', {}).get('monthly_cost', 0) if limits else 0,
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
},
"summary": {
"total_api_calls_this_month": current_usage.get('total_calls', 0),
"total_cost_this_month": current_usage.get('total_cost', 0),
"usage_status": current_usage.get('usage_status', 'active'),
"unread_alerts": len(alerts_data)
}
}
}
return response_payload
except Exception as retry_err:
logger.error(f"Schema fix and retry failed: {retry_err}")
raise HTTPException(status_code=500, detail=f"Database schema error: {str(e)}")
logger.error(f"Error getting dashboard data: {e}")
raise HTTPException(status_code=500, detail=str(e))