AI Video Generation Implementation

This commit is contained in:
ajaysi
2025-11-17 17:38:23 +05:30
parent 4901b7eb72
commit bf7493c366
132 changed files with 6200 additions and 19475 deletions

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field
from services.llm_providers.main_image_generation import generate_image
from services.llm_providers.main_image_editing import edit_image
from services.llm_providers.main_text_generation import llm_text_gen
from utils.logger_utils import get_service_logger
from middleware.auth_middleware import get_current_user
@@ -125,6 +126,14 @@ def generate(
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
@@ -137,6 +146,8 @@ def generate(
├─ Actual Provider: {result.provider}
├─ Model: {result.model or 'default'}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
@@ -195,6 +206,26 @@ class ImagePromptSuggestResponse(BaseModel):
suggestions: list[PromptSuggestion]
class ImageEditRequest(BaseModel):
image_base64: str
prompt: str
provider: Optional[str] = Field(None, pattern="^(huggingface)$")
model: Optional[str] = None
guidance_scale: Optional[float] = None
steps: Optional[int] = None
seed: Optional[int] = None
class ImageEditResponse(BaseModel):
success: bool = True
image_base64: str
width: int
height: int
provider: str
model: Optional[str] = None
seed: Optional[int] = None
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
def suggest_prompts(
req: ImagePromptSuggestRequest,
@@ -316,3 +347,136 @@ def suggest_prompts(
logger.error(f"Prompt suggestion failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/edit", response_model=ImageEditResponse)
def edit(
req: ImageEditRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImageEditResponse:
"""Edit image with subscription checking."""
try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Decode base64 image
try:
input_image_bytes = base64.b64decode(req.image_base64)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid image_base64: {str(e)}")
# Validation is now handled inside edit_image function
result = edit_image(
input_image_bytes=input_image_bytes,
prompt=req.prompt,
options={
"provider": req.provider,
"model": req.model,
"guidance_scale": req.guidance_scale,
"steps": req.steps,
"seed": req.seed,
},
user_id=user_id, # Pass user_id for validation inside edit_image
)
edited_image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
# TRACK USAGE after successful image editing
if result:
logger.info(f"[images.edit] ✅ Image editing successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
try:
# Get or create usage summary
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[images.edit] Looking for usage summary: user_id={user_id}, period={current_period}")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.info(f"[images.edit] Creating new usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
current_calls_before = getattr(summary, "image_edit_calls", 0) or 0
# Update image editing counters (separate from image generation)
new_calls = current_calls_before + 1
setattr(summary, "image_edit_calls", new_calls)
logger.debug(f"[images.edit] Updated image_edit_calls: {current_calls_before} -> {new_calls}")
# Update totals
old_total_calls = summary.total_calls or 0
summary.total_calls = old_total_calls + 1
logger.debug(f"[images.edit] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get image generation stats for unified log
current_image_gen_calls = getattr(summary, "stability_calls", 0) or 0
image_gen_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[images.edit] ✅ Successfully tracked usage: user {user_id} -> image_edit -> {new_calls} calls")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Image Editing
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: image_edit
├─ Actual Provider: {result.provider}
├─ Model: {result.model or 'default'}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Images: {current_image_gen_calls} / {image_gen_limit if image_gen_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[images.edit] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[images.edit] ❌ Failed to track usage: {usage_error}", exc_info=True)
return ImageEditResponse(
image_base64=edited_image_b64,
width=result.width,
height=result.height,
provider=result.provider,
model=result.model,
seed=result.seed,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Image editing failed: {e}", exc_info=True)
# Provide a clean, actionable message to the client
raise HTTPException(
status_code=500,
detail="Image editing service is temporarily unavailable or the connection was reset. Please try again."
)

View File

@@ -6,7 +6,7 @@ content generation, and full story creation.
"""
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from typing import Any, Dict, Union, List
from typing import Any, Dict, Union, List, Optional
from loguru import logger
from middleware.auth_middleware import get_current_user
@@ -37,6 +37,16 @@ from models.story_models import (
from services.story_writer.story_service import StoryWriterService
from .task_manager import task_manager
from .cache_manager import cache_manager
from uuid import uuid4
from pydantic import BaseModel
from pathlib import Path
from .utils.auth import require_authenticated_user
from .utils.media_utils import resolve_media_file
from .utils.hd_video import (
generate_hd_video_payload,
generate_hd_video_scene_payload,
)
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
@@ -503,11 +513,11 @@ async def get_task_status(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/task/{task_id}/result", response_model=StoryFullGenerationResponse)
@router.get("/task/{task_id}/result")
async def get_task_result(
task_id: str,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> StoryFullGenerationResponse:
) -> Dict[str, Any]:
"""Get the result of a completed story generation task."""
try:
if not current_user:
@@ -528,7 +538,19 @@ async def get_task_result(
if not result:
raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")
return StoryFullGenerationResponse(**result, success=True, task_id=task_id)
# Some tasks return a full-story payload compatible with StoryFullGenerationResponse,
# others (e.g., video-only) return a dict like {"video": {...}, "success": True}.
# To avoid model conflicts, return a generic payload and include task_id.
# Frontend callers can branch on keys present (e.g., "video").
if isinstance(result, dict):
# Ensure success flag present without duplicating
payload = {**result}
payload.setdefault("success", True)
payload["task_id"] = task_id
return payload
# Fallback: wrap non-dict results
return {"result": result, "success": True, "task_id": task_id}
except HTTPException:
raise
@@ -536,6 +558,69 @@ async def get_task_result(
logger.error(f"[StoryWriter] Failed to get task result: {e}")
raise HTTPException(status_code=500, detail=str(e))
class HDVideoRequest(BaseModel):
prompt: str
provider: str = "huggingface"
model: str | None = None
num_frames: int | None = None
guidance_scale: float | None = None
num_inference_steps: int | None = None
negative_prompt: str | None = None
seed: int | None = None
@router.post("/hd-video")
async def generate_hd_video(
request: HDVideoRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Generate an HD AI animation using provider text-to-video (Hugging Face for now).
Saves the returned bytes as a video file and returns the secured URL.
"""
try:
user_id = require_authenticated_user(current_user)
return generate_hd_video_payload(request, user_id)
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to generate HD video: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
class HDVideoSceneRequest(BaseModel):
scene_number: int
scene_data: Dict[str, Any]
story_context: Dict[str, Any]
all_scenes: List[Dict[str, Any]]
scene_image_url: Optional[str] = None
provider: str = "huggingface"
model: str | None = None
num_frames: int | None = None
guidance_scale: float | None = None
num_inference_steps: int | None = None
negative_prompt: str | None = None
seed: int | None = None
@router.post("/hd-video-scene")
async def generate_hd_video_scene(
request: HDVideoSceneRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Generate HD AI video for a single scene with AI-enhanced prompt.
Uses prompt enhancer to create HunyuanVideo-optimized prompt from story context.
"""
try:
user_id = require_authenticated_user(current_user)
return generate_hd_video_scene_payload(request, user_id)
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to generate HD video for scene: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# ---------------------------
# Image Generation Endpoints
@@ -614,31 +699,20 @@ async def serve_scene_image(
):
"""Serve a generated story scene image."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Import image generation service to get output directory
require_authenticated_user(current_user)
from services.story_writer.image_generation_service import StoryImageGenerationService
from fastapi.responses import FileResponse
image_service = StoryImageGenerationService()
image_path = image_service.output_dir / image_filename
if not image_path.exists():
raise HTTPException(status_code=404, detail=f"Image not found: {image_filename}")
# Validate that the file is within the output directory (security check)
try:
image_path.resolve().relative_to(image_service.output_dir.resolve())
except ValueError:
raise HTTPException(status_code=403, detail="Access denied")
image_path = resolve_media_file(image_service.output_dir, image_filename)
return FileResponse(
path=str(image_path),
media_type="image/png",
filename=image_filename
)
except HTTPException:
raise
except Exception as e:
@@ -726,31 +800,20 @@ async def serve_scene_audio(
):
"""Serve a generated story scene audio file."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Import audio generation service to get output directory
require_authenticated_user(current_user)
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from fastapi.responses import FileResponse
audio_service = StoryAudioGenerationService()
audio_path = audio_service.output_dir / audio_filename
if not audio_path.exists():
raise HTTPException(status_code=404, detail=f"Audio not found: {audio_filename}")
# Validate that the file is within the output directory (security check)
try:
audio_path.resolve().relative_to(audio_service.output_dir.resolve())
except ValueError:
raise HTTPException(status_code=403, detail="Access denied")
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
return FileResponse(
path=str(audio_path),
media_type="audio/mpeg",
filename=audio_filename
)
except HTTPException:
raise
except Exception as e:
@@ -869,6 +932,99 @@ async def generate_story_video(
logger.error(f"[StoryWriter] Failed to generate video: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/generate-video-async", response_model=Dict[str, Any])
async def generate_story_video_async(
request: StoryVideoGenerationRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Generate a video asynchronously with progress updates via task manager.
Frontend can poll /api/story/task/{task_id}/status to show progress messages.
"""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
if not request.scenes or len(request.scenes) == 0:
raise HTTPException(status_code=400, detail="At least one scene is required")
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
raise HTTPException(status_code=400, detail="Number of scenes, image URLs, and audio URLs must match")
task_id = task_manager.create_task("story_video_generation")
background_tasks.add_task(
_execute_video_generation_task,
task_id=task_id,
request=request,
user_id=user_id
)
return {"task_id": task_id, "status": "pending", "message": "Video generation started"}
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to start async video generation: {e}")
raise HTTPException(status_code=500, detail=str(e))
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
"""Background task to generate story video with progress mapped to task manager."""
from services.story_writer.video_generation_service import StoryVideoGenerationService
from services.story_writer.image_generation_service import StoryImageGenerationService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
try:
task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
video_service = StoryVideoGenerationService()
image_service = StoryImageGenerationService()
audio_service = StoryAudioGenerationService()
# Prepare assets
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
image_paths, audio_paths, valid_scenes = [], [], []
for idx, (scene, image_url, audio_url) in enumerate(zip(scenes_data, request.image_urls, request.audio_urls)):
image_filename = (image_url.split('/')[-1] if '/' in image_url else image_url).split('?')[0]
audio_filename = (audio_url.split('/')[-1] if '/' in audio_url else audio_url).split('?')[0]
image_path = image_service.output_dir / image_filename
audio_path = audio_service.output_dir / audio_filename
if not image_path.exists():
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
continue
if not audio_path.exists():
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
continue
image_paths.append(str(image_path))
audio_paths.append(str(audio_path))
valid_scenes.append(scene)
if not image_paths or not audio_paths or len(image_paths) != len(audio_paths):
raise RuntimeError("No valid or mismatched image/audio assets for video generation.")
# Map service progress (0-100) to task progress (5-95)
def progress_callback(sub_progress: float, msg: str):
overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)
result = video_service.generate_story_video(
scenes=valid_scenes,
image_paths=image_paths,
audio_paths=audio_paths,
user_id=user_id,
story_title=request.story_title or "Story",
fps=request.fps or 24,
transition_duration=request.transition_duration or 0.5,
progress_callback=progress_callback
)
task_manager.update_task_status(
task_id,
"completed",
progress=100.0,
message="Video generation complete!",
result={"video": result, "success": True}
)
except Exception as e:
logger.error(f"[StoryWriter] Async video generation failed: {e}", exc_info=True)
task_manager.update_task_status(task_id, "failed", error=str(e), message=f"Video generation failed: {e}")
@router.post("/generate-complete-video", response_model=Dict[str, Any])
async def generate_complete_story_video(
@@ -1111,31 +1267,20 @@ async def serve_story_video(
):
"""Serve a generated story video file."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
# Import video generation service to get output directory
require_authenticated_user(current_user)
from services.story_writer.video_generation_service import StoryVideoGenerationService
from fastapi.responses import FileResponse
video_service = StoryVideoGenerationService()
video_path = video_service.output_dir / video_filename
if not video_path.exists():
raise HTTPException(status_code=404, detail=f"Video not found: {video_filename}")
# Validate that the file is within the output directory (security check)
try:
video_path.resolve().relative_to(video_service.output_dir.resolve())
except ValueError:
raise HTTPException(status_code=403, detail="Access denied")
video_path = resolve_media_file(video_service.output_dir, video_filename)
return FileResponse(
path=str(video_path),
media_type="video/mp4",
filename=video_filename
)
except HTTPException:
raise
except Exception as e:

View File

@@ -0,0 +1,8 @@
"""
Utility helpers for Story Writer API routes.
Grouped here to keep the main router lean while reusing common logic
such as authentication guards, media resolution, and HD video helpers.
"""

View File

@@ -0,0 +1,23 @@
from typing import Any, Dict
from fastapi import HTTPException, status
def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
"""
Validates the current user dictionary provided by Clerk middleware and
returns the normalized user_id. Raises HTTP 401 if authentication fails.
"""
if not current_user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
user_id = str(current_user.get("id", "")).strip()
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid user ID in authentication token",
)
return user_id

View File

@@ -0,0 +1,154 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from fastapi import HTTPException
from loguru import logger
from uuid import uuid4
from .media_utils import load_story_image_bytes
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""Handles synchronous HD video generation."""
from services.llm_providers.main_video_generation import ai_video_generate
from services.story_writer.video_generation_service import StoryVideoGenerationService
video_service = StoryVideoGenerationService()
output_dir = video_service.output_dir
output_dir.mkdir(parents=True, exist_ok=True)
kwargs: Dict[str, Any] = {}
if getattr(request, "model", None):
kwargs["model"] = request.model
if getattr(request, "num_frames", None):
kwargs["num_frames"] = request.num_frames
if getattr(request, "guidance_scale", None) is not None:
kwargs["guidance_scale"] = request.guidance_scale
if getattr(request, "num_inference_steps", None):
kwargs["num_inference_steps"] = request.num_inference_steps
if getattr(request, "negative_prompt", None):
kwargs["negative_prompt"] = request.negative_prompt
if getattr(request, "seed", None) is not None:
kwargs["seed"] = request.seed
logger.info(f"[StoryWriter] Generating HD video via {getattr(request, 'provider', 'huggingface')} for user {user_id}")
raw_bytes = ai_video_generate(
prompt=request.prompt,
provider=getattr(request, "provider", None) or "huggingface",
user_id=user_id,
**kwargs,
)
filename = f"hd_{uuid4().hex}.mp4"
file_path = output_dir / filename
with open(file_path, "wb") as fh:
fh.write(raw_bytes)
logger.info(f"[StoryWriter] HD video saved to {file_path}")
return {
"success": True,
"video_filename": filename,
"video_url": f"/api/story/videos/{filename}",
"provider": getattr(request, "provider", None) or "huggingface",
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
}
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""
Handles per-scene HD video generation including prompt enhancement,
subscription validation, and optional image conditioning.
"""
from services.database import get_db as get_db_validation
from services.onboarding.api_key_manager import APIKeyManager
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_video_generation_operations
from services.story_writer.prompt_enhancer_service import enhance_scene_prompt_for_video
from services.llm_providers.main_video_generation import ai_video_generate
from services.story_writer.video_generation_service import StoryVideoGenerationService
scene_number = request.scene_number
logger.info(f"[StoryWriter] Generating HD video for scene {scene_number} for user {user_id}")
# Step 1: Validate API key
hf_token = APIKeyManager().get_api_key("hf_token")
if not hf_token:
logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
raise HTTPException(
status_code=400,
detail={
"error": "Hugging Face API token is not configured. Please configure your HF token in settings.",
"message": "Hugging Face API token is not configured. Please configure your HF token in settings.",
},
)
# Step 2: Subscription limits
db_validation = next(get_db_validation())
try:
pricing_service = PricingService(db_validation)
logger.info(f"[StoryWriter] Pre-flight: Checking video generation limits for user {user_id}...")
validate_video_generation_operations(pricing_service=pricing_service, user_id=user_id)
logger.info("[StoryWriter] Pre-flight: ✅ Video generation limits validated - proceeding")
finally:
db_validation.close()
# Stage 1: Prompt enhancement
enhanced_prompt = enhance_scene_prompt_for_video(
current_scene=request.scene_data,
story_context=request.story_context,
all_scenes=request.all_scenes,
user_id=user_id,
)
logger.info(f"[StoryWriter] Generated enhanced prompt ({len(enhanced_prompt)} chars) for scene {scene_number}")
# Stage 2: Optional image reference
scene_image_bytes: Optional[bytes] = None
if getattr(request, "scene_image_url", None):
scene_image_bytes = load_story_image_bytes(request.scene_image_url)
if scene_image_bytes:
logger.info(f"[StoryWriter] Using scene image reference for scene {scene_number}")
else:
logger.warning(f"[StoryWriter] Scene image could not be loaded for scene {scene_number}, falling back to text-only video")
kwargs: Dict[str, Any] = {}
if getattr(request, "model", None):
kwargs["model"] = request.model
if getattr(request, "num_frames", None):
kwargs["num_frames"] = request.num_frames
if getattr(request, "guidance_scale", None) is not None:
kwargs["guidance_scale"] = request.guidance_scale
if getattr(request, "num_inference_steps", None):
kwargs["num_inference_steps"] = request.num_inference_steps
if getattr(request, "negative_prompt", None):
kwargs["negative_prompt"] = request.negative_prompt
if getattr(request, "seed", None) is not None:
kwargs["seed"] = request.seed
raw_bytes = ai_video_generate(
prompt=enhanced_prompt,
provider=getattr(request, "provider", None) or "huggingface",
user_id=user_id,
input_image_bytes=scene_image_bytes,
**kwargs,
)
video_service = StoryVideoGenerationService()
save_result = video_service.save_scene_video(
video_bytes=raw_bytes,
scene_number=scene_number,
user_id=user_id,
)
logger.info(f"[StoryWriter] HD video saved for scene {scene_number}: {save_result.get('video_filename')}")
return {
"success": True,
"scene_number": scene_number,
"video_filename": save_result.get("video_filename"),
"video_url": save_result.get("video_url"),
"prompt_used": enhanced_prompt,
"provider": getattr(request, "provider", None) or "huggingface",
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
}

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from fastapi import HTTPException, status
from loguru import logger
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
"""
Resolve an authenticated story image URL (e.g., /api/story/images/<file>) to raw bytes.
Returns None if the file cannot be located.
"""
if not image_url:
return None
try:
parsed = urlparse(image_url)
path = parsed.path if parsed.scheme else image_url
prefix = "/api/story/images/"
if prefix not in path:
logger.warning(f"[StoryWriter] Unsupported image URL for video reference: {image_url}")
return None
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
if not filename:
return None
file_path = (STORY_IMAGES_DIR / filename).resolve()
if not str(file_path).startswith(str(STORY_IMAGES_DIR)):
logger.error(f"[StoryWriter] Attempted path traversal when resolving image: {image_url}")
return None
if not file_path.exists():
logger.warning(f"[StoryWriter] Referenced scene image not found on disk: {file_path}")
return None
return file_path.read_bytes()
except Exception as exc:
logger.error(f"[StoryWriter] Failed to load reference image for video gen: {exc}")
return None
def resolve_media_file(base_dir: Path, filename: str) -> Path:
"""
Returns a safe resolved path for a media file stored under base_dir.
Guards against directory traversal and ensures the file exists.
"""
filename = filename.split("?")[0].strip()
resolved = (base_dir / filename).resolve()
try:
resolved.relative_to(base_dir.resolve())
except ValueError:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
if not resolved.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")
return resolved

View File

@@ -14,7 +14,7 @@ from functools import lru_cache
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from services.subscription.log_wrapping_service import LogWrappingService
from services.subscription.schema_utils import ensure_subscription_plan_columns
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
import sqlite3
from middleware.auth_middleware import get_current_user
from models.subscription_models import (
@@ -114,6 +114,8 @@ async def get_subscription_plans(
"metaphor_calls": plan.metaphor_calls_limit,
"firecrawl_calls": plan.firecrawl_calls_limit,
"stability_calls": plan.stability_calls_limit,
"video_calls": getattr(plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
"gemini_tokens": plan.gemini_tokens_limit,
"openai_tokens": plan.openai_tokens_limit,
"anthropic_tokens": plan.anthropic_tokens_limit,
@@ -132,7 +134,7 @@ async def get_subscription_plans(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and 'exa_calls_limit' in error_str:
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
logger.warning("Missing column detected in subscription plans query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
@@ -237,6 +239,8 @@ async def get_user_subscription(
"metaphor_calls": free_plan.metaphor_calls_limit,
"firecrawl_calls": free_plan.firecrawl_calls_limit,
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
@@ -334,6 +338,8 @@ async def get_subscription_status(
"metaphor_calls": free_plan.metaphor_calls_limit,
"firecrawl_calls": free_plan.firecrawl_calls_limit,
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
@@ -399,15 +405,16 @@ async def get_subscription_status(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and 'exa_calls_limit' in error_str:
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
# Try to fix schema and retry once
logger.warning("Missing column detected in subscription status query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
schema_utils._checked_subscription_plan_columns = False
ensure_subscription_plan_columns(db)
db.commit() # Ensure schema changes are committed
db.expire_all()
# Retry the query
# Retry the query - query subscription without eager loading plan
subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
@@ -437,11 +444,21 @@ async def get_subscription_status(
"metaphor_calls": free_plan.metaphor_calls_limit,
"firecrawl_calls": free_plan.firecrawl_calls_limit,
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
}
elif subscription:
# Query plan separately after schema fix to avoid lazy loading issues
plan = db.query(SubscriptionPlan).filter(
SubscriptionPlan.id == subscription.plan_id
).first()
if not plan:
raise HTTPException(status_code=404, detail="Plan not found")
now = datetime.utcnow()
if subscription.current_period_end < now:
if getattr(subscription, 'auto_renew', False):
@@ -456,8 +473,8 @@ async def get_subscription_status(
"success": True,
"data": {
"active": False,
"plan": subscription.plan.tier.value,
"tier": subscription.plan.tier.value,
"plan": plan.tier.value,
"tier": plan.tier.value,
"can_use_api": False,
"reason": "Subscription expired"
}
@@ -466,21 +483,23 @@ async def get_subscription_status(
"success": True,
"data": {
"active": True,
"plan": subscription.plan.tier.value,
"tier": subscription.plan.tier.value,
"plan": plan.tier.value,
"tier": plan.tier.value,
"can_use_api": True,
"limits": {
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": subscription.plan.gemini_calls_limit,
"openai_calls": subscription.plan.openai_calls_limit,
"anthropic_calls": subscription.plan.anthropic_calls_limit,
"mistral_calls": subscription.plan.mistral_calls_limit,
"tavily_calls": subscription.plan.tavily_calls_limit,
"serper_calls": subscription.plan.serper_calls_limit,
"metaphor_calls": subscription.plan.metaphor_calls_limit,
"firecrawl_calls": subscription.plan.firecrawl_calls_limit,
"stability_calls": subscription.plan.stability_calls_limit,
"monthly_cost": subscription.plan.monthly_cost_limit
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": plan.gemini_calls_limit,
"openai_calls": plan.openai_calls_limit,
"anthropic_calls": plan.anthropic_calls_limit,
"mistral_calls": plan.mistral_calls_limit,
"tavily_calls": plan.tavily_calls_limit,
"serper_calls": plan.serper_calls_limit,
"metaphor_calls": plan.metaphor_calls_limit,
"firecrawl_calls": plan.firecrawl_calls_limit,
"stability_calls": plan.stability_calls_limit,
"video_calls": getattr(plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
"monthly_cost": plan.monthly_cost_limit
}
}
}
@@ -893,6 +912,7 @@ async def get_dashboard_data(
try:
ensure_subscription_plan_columns(db)
ensure_usage_summaries_columns(db)
# Serve from short TTL cache to avoid hammering DB on bursts
import time
now = time.time()
@@ -966,7 +986,72 @@ async def get_dashboard_data(
_dashboard_cache_ts[user_id] = now
return response_payload
except Exception as e:
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str):
logger.warning("Missing column detected in dashboard query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
schema_utils._checked_usage_summaries_columns = False
schema_utils._checked_subscription_plan_columns = False
ensure_usage_summaries_columns(db)
ensure_subscription_plan_columns(db)
db.expire_all()
# Retry the query
usage_service = UsageTrackingService(db)
pricing_service = PricingService(db)
current_usage = usage_service.get_user_usage_stats(user_id)
trends = usage_service.get_usage_trends(user_id, 6)
limits = pricing_service.get_user_limits(user_id)
alerts = db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.is_read == False
).order_by(UsageAlert.created_at.desc()).limit(5).all()
alerts_data = [
{
"id": alert.id,
"type": alert.alert_type,
"title": alert.title,
"message": alert.message,
"severity": alert.severity,
"created_at": alert.created_at.isoformat()
}
for alert in alerts
]
current_cost = current_usage.get('total_cost', 0)
days_in_period = 30
current_day = datetime.now().day
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
response_payload = {
"success": True,
"data": {
"current_usage": current_usage,
"trends": trends,
"limits": limits,
"alerts": alerts_data,
"projections": {
"projected_monthly_cost": round(projected_cost, 2),
"cost_limit": limits.get('limits', {}).get('monthly_cost', 0) if limits else 0,
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
},
"summary": {
"total_api_calls_this_month": current_usage.get('total_calls', 0),
"total_cost_this_month": current_usage.get('total_cost', 0),
"usage_status": current_usage.get('usage_status', 'active'),
"unread_alerts": len(alerts_data)
}
}
}
return response_payload
except Exception as retry_err:
logger.error(f"Schema fix and retry failed: {retry_err}")
raise HTTPException(status_code=500, detail=f"Database schema error: {str(e)}")
logger.error(f"Error getting dashboard data: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -93,6 +93,17 @@ def setup_clean_logging():
format="{time:HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}\n",
filter=warning_only_filter
)
# Add a focused sink to surface Story Video Generation INFO logs in console
def video_generation_filter(record):
msg = record.get("message", "")
name = record.get("name", "")
return "[StoryVideoGeneration]" in msg or "services.story_writer.video_generation_service" in name
logger.add(
sys.stdout.write,
level="INFO",
format="{time:HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}\n",
filter=video_generation_filter
)
else:
# In verbose mode, show all log levels with detailed formatting
logger.add(

View File

@@ -35,6 +35,8 @@ class APIProvider(enum.Enum):
FIRECRAWL = "firecrawl"
STABILITY = "stability"
EXA = "exa"
VIDEO = "video"
IMAGE_EDIT = "image_edit"
class BillingCycle(enum.Enum):
MONTHLY = "monthly"
@@ -68,6 +70,8 @@ class SubscriptionPlan(Base):
firecrawl_calls_limit = Column(Integer, default=0)
stability_calls_limit = Column(Integer, default=0) # Image generation
exa_calls_limit = Column(Integer, default=0) # Exa neural search
video_calls_limit = Column(Integer, default=0) # AI video generation
image_edit_calls_limit = Column(Integer, default=0) # AI image editing
# Token Limits (for LLM providers)
gemini_tokens_limit = Column(Integer, default=0)
@@ -185,6 +189,8 @@ class UsageSummary(Base):
firecrawl_calls = Column(Integer, default=0)
stability_calls = Column(Integer, default=0)
exa_calls = Column(Integer, default=0)
video_calls = Column(Integer, default=0) # AI video generation
image_edit_calls = Column(Integer, default=0) # AI image editing
# Token Usage
gemini_tokens = Column(Integer, default=0)
@@ -203,6 +209,8 @@ class UsageSummary(Base):
firecrawl_cost = Column(Float, default=0.0)
stability_cost = Column(Float, default=0.0)
exa_cost = Column(Float, default=0.0)
video_cost = Column(Float, default=0.0) # AI video generation
image_edit_cost = Column(Float, default=0.0) # AI image editing
# Totals
total_calls = Column(Integer, default=0)

View File

@@ -53,7 +53,7 @@ nltk>=3.8.0
# Image and audio processing for Stability AI
Pillow>=10.0.0
huggingface_hub>=0.24.0
huggingface_hub>=1.1.4
scikit-learn>=1.3.0
# Text-to-Speech (TTS) dependencies
@@ -61,7 +61,7 @@ gtts>=2.4.0
pyttsx3>=2.90
# Video composition dependencies
moviepy>=1.0.3
moviepy==2.1.2
imageio>=2.31.0
imageio-ffmpeg>=0.4.9

View File

@@ -0,0 +1,102 @@
"""
Script to update existing subscription plans with image_edit_calls_limit values.
This script updates the SubscriptionPlan table to set image_edit_calls_limit
for plans that were created before this column was added.
Limits:
- Free: 10 image editing calls/month
- Basic: 30 image editing calls/month
- Pro: 100 image editing calls/month
- Enterprise: 0 (unlimited)
"""
import sys
import os
from pathlib import Path
from datetime import datetime, timezone
# Add the backend directory to Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from loguru import logger
from models.subscription_models import SubscriptionPlan, SubscriptionTier
from services.database import DATABASE_URL
def update_image_edit_limits():
"""Update existing subscription plans with image_edit_calls_limit values."""
try:
# Create engine
engine = create_engine(DATABASE_URL, echo=False)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
try:
# Ensure schema columns exist
from services.subscription.schema_utils import ensure_subscription_plan_columns
ensure_subscription_plan_columns(db)
# Define limits for each tier
limits_by_tier = {
SubscriptionTier.FREE: 10,
SubscriptionTier.BASIC: 30,
SubscriptionTier.PRO: 100,
SubscriptionTier.ENTERPRISE: 0, # Unlimited
}
updated_count = 0
# Update each plan
for tier, limit in limits_by_tier.items():
plans = db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == tier,
SubscriptionPlan.is_active == True
).all()
for plan in plans:
current_limit = getattr(plan, 'image_edit_calls_limit', 0) or 0
# Only update if limit is 0 (not set) or if it's different
if current_limit != limit:
setattr(plan, 'image_edit_calls_limit', limit)
plan.updated_at = datetime.now(timezone.utc)
updated_count += 1
logger.info(f"Updated {plan.name} plan ({tier.value}): image_edit_calls_limit = {current_limit} -> {limit}")
else:
logger.debug(f"Plan {plan.name} ({tier.value}) already has image_edit_calls_limit = {limit}")
# Commit changes
db.commit()
if updated_count > 0:
logger.info(f"✅ Successfully updated {updated_count} subscription plan(s) with image_edit_calls_limit")
else:
logger.info("✅ All subscription plans already have correct image_edit_calls_limit values")
return True
except Exception as e:
logger.error(f"❌ Error updating image_edit_limits: {e}")
db.rollback()
raise
finally:
db.close()
except Exception as e:
logger.error(f"❌ Error creating database connection: {e}")
raise
if __name__ == "__main__":
logger.info("🔄 Updating subscription plans with image_edit_calls_limit...")
success = update_image_edit_limits()
if success:
logger.info("🎉 Image edit limits update completed successfully!")
else:
logger.error("❌ Image edit limits update failed")
sys.exit(1)

View File

@@ -0,0 +1,165 @@
from __future__ import annotations
import os
import io
from typing import Optional, Dict, Any
from PIL import Image
from .image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
)
from utils.logger_utils import get_service_logger
try:
from huggingface_hub import InferenceClient
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
logger = get_service_logger("image_editing.facade")
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
"HF_IMAGE_EDIT_MODEL",
"Qwen/Qwen-Image-Edit",
)
def _select_provider(explicit: Optional[str]) -> str:
"""Select provider for image editing. Defaults to huggingface with fal-ai."""
if explicit:
return explicit
# Default to huggingface for image editing (best support for image-to-image)
return "huggingface"
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
"""Get InferenceClient for the specified provider."""
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
if provider_name == "huggingface":
api_key = api_key or os.getenv("HF_TOKEN")
if not api_key:
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
# Use fal-ai provider for fast inference
return InferenceClient(provider="fal-ai", api_key=api_key)
raise ValueError(f"Unknown image editing provider: {provider_name}")
def edit_image(
input_image_bytes: bytes,
prompt: str,
options: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None
) -> ImageGenerationResult:
"""Edit image with pre-flight validation.
Args:
input_image_bytes: Input image as bytes (PNG/JPEG)
prompt: Natural language prompt describing desired edits (e.g., "Turn the cat into a tiger")
options: Image editing options (provider, model, etc.)
user_id: User ID for subscription checking (optional, but required for validation)
Returns:
ImageGenerationResult with edited image bytes and metadata
Best Practices for Prompts:
- Use clear, specific language describing desired changes
- Describe what should change and what should remain
- Examples: "Turn the cat into a tiger", "Change background to forest",
"Make it look like a watercolor painting"
"""
# PRE-FLIGHT VALIDATION: Validate image editing before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_editing_operations
from fastapi import HTTPException
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_image_editing_operations(
pricing_service=pricing_service,
user_id=user_id
)
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Editing] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
logger.info(f"[Image Editing] ✅ Pre-flight validation passed - proceeding with image editing")
# Validate input
if not input_image_bytes:
raise ValueError("input_image_bytes is required")
if not prompt or not prompt.strip():
raise ValueError("prompt is required for image editing")
opts = options or {}
provider_name = _select_provider(opts.get("provider"))
model = opts.get("model") or DEFAULT_IMAGE_EDIT_MODEL
logger.info(f"[Image Editing] Editing image via provider={provider_name} model={model}")
# Get provider client
client = _get_provider_client(provider_name, opts.get("api_key"))
# Prepare parameters for image-to-image
params: Dict[str, Any] = {}
if opts.get("guidance_scale") is not None:
params["guidance_scale"] = opts.get("guidance_scale")
if opts.get("steps") is not None:
params["num_inference_steps"] = opts.get("steps")
if opts.get("seed") is not None:
params["seed"] = opts.get("seed")
try:
# Convert input image bytes to PIL Image for validation
input_image = Image.open(io.BytesIO(input_image_bytes))
width = input_image.width
height = input_image.height
# Use image_to_image method from Hugging Face InferenceClient
# This follows the pattern from the Hugging Face documentation
# Docs: https://huggingface.co/docs/inference-providers/en/guides/image-editor
edited_image: Image.Image = client.image_to_image(
image=input_image,
prompt=prompt.strip(),
model=model,
**params,
)
# Convert edited image back to bytes
with io.BytesIO() as buf:
edited_image.save(buf, format="PNG")
edited_image_bytes = buf.getvalue()
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
return ImageGenerationResult(
image_bytes=edited_image_bytes,
width=edited_image.width,
height=edited_image.height,
provider="huggingface",
model=model,
seed=opts.get("seed"),
metadata={
"provider": "fal-ai",
"operation": "image_editing",
"original_width": width,
"original_height": height,
},
)
except Exception as e:
logger.error(f"[Image Editing] ❌ Error editing image: {e}", exc_info=True)
raise RuntimeError(f"Image editing failed: {str(e)}")

View File

@@ -507,6 +507,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
import sys
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
@@ -562,6 +570,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
@@ -802,6 +811,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
@@ -819,6 +836,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:

View File

@@ -0,0 +1,355 @@
"""
Main Video Generation Service
Provides a unified interface for AI video generation providers.
Initial support: Hugging Face Inference Providers (text-to-video).
Stubs included for Gemini (Veo 3) and OpenAI (Sora) for future use.
"""
from __future__ import annotations
import os
import base64
import io
from typing import Any, Dict, Optional, Union
from fastapi import HTTPException
try:
from huggingface_hub import InferenceClient
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
InferenceClient = None
from ..onboarding.api_key_manager import APIKeyManager
from utils.logger_utils import get_service_logger
logger = get_service_logger("video_generation_service")
class VideoProviderNotImplemented(Exception):
pass
def _get_api_key(provider: str) -> Optional[str]:
try:
manager = APIKeyManager()
mapping = {
"huggingface": "hf_token",
"gemini": "gemini", # placeholder for Veo 3
"openai": "openai_api_key", # placeholder for Sora
}
return manager.get_api_key(mapping.get(provider, provider))
except Exception as e:
logger.error(f"[video_gen] Failed to read API key for {provider}: {e}")
return None
def _coerce_video_bytes(output: Any) -> bytes:
"""
Normalizes the different return shapes that huggingface_hub may emit for video tasks.
Depending on the provider/library version we may get:
- raw bytes
- an object with `.video` or `.bytes` attributes (plus optional `.save`)
- a dict containing a `video` key with bytes/base64 data
"""
data: Union[bytes, bytearray, memoryview, io.BufferedIOBase, None] = None
if isinstance(output, (bytes, bytearray, memoryview)):
return bytes(output)
# Objects with direct attribute access
if hasattr(output, "video"):
data = getattr(output, "video")
elif hasattr(output, "bytes"):
data = getattr(output, "bytes")
elif isinstance(output, dict) and "video" in output:
data = output["video"]
else:
data = output
# Handle file-like responses
if hasattr(data, "read"):
data = data.read()
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if isinstance(data, str):
# Expecting data URI or raw base64 string
if data.startswith("data:"):
_, encoded = data.split(",", 1)
return base64.b64decode(encoded)
try:
return base64.b64decode(data)
except Exception as exc:
raise TypeError(f"Unable to decode string video payload: {exc}") from exc
raise TypeError(f"Unsupported video payload type: {type(data)}")
def _generate_with_huggingface(
prompt: str,
num_frames: int = 24 * 4,
guidance_scale: float = 7.5,
num_inference_steps: int = 30,
negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
model: str = "tencent/HunyuanVideo",
input_image_bytes: Optional[bytes] = None,
) -> bytes:
"""
Generates video bytes using Hugging Face's InferenceClient.
"""
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
token = _get_api_key("huggingface")
if not token:
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
client = InferenceClient(
model=model,
provider="fal-ai",
token=token,
)
logger.info("[video_gen] Using HuggingFace provider 'fal-ai'")
params: Dict[str, Any] = {
"num_frames": num_frames,
"guidance_scale": guidance_scale,
"num_inference_steps": num_inference_steps,
}
if negative_prompt:
params["negative_prompt"] = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
if seed is not None:
params["seed"] = seed
logger.info(
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=%s",
model,
num_frames,
num_inference_steps,
"image-to-video" if input_image_bytes else "text-to-video",
)
try:
call_kwargs = {**params, "model": model}
if input_image_bytes:
video_output = client.image_to_video(
image=input_image_bytes,
prompt=prompt,
**call_kwargs,
)
else:
video_output = client.text_to_video(
prompt,
**call_kwargs,
)
video_bytes = _coerce_video_bytes(video_output)
if not isinstance(video_bytes, bytes):
raise TypeError(f"Expected bytes from text_to_video, got {type(video_bytes)}")
if len(video_bytes) == 0:
raise ValueError("Received empty video bytes from Hugging Face API")
logger.info(f"[video_gen] Successfully generated video: {len(video_bytes)} bytes")
return video_bytes
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
logger.error(f"[video_gen] HF error ({error_type}): {error_msg}", exc_info=True)
raise HTTPException(status_code=502, detail={
"error": f"Hugging Face video generation failed: {error_msg}",
"error_type": error_type
})
def _generate_with_gemini(prompt: str, **kwargs) -> bytes:
raise VideoProviderNotImplemented("Gemini Veo 3 integration coming soon.")
def _generate_with_openai(prompt: str, **kwargs) -> bytes:
raise VideoProviderNotImplemented("OpenAI Sora integration coming soon.")
def ai_video_generate(
prompt: str,
provider: str = "huggingface",
user_id: Optional[str] = None,
input_image_bytes: Optional[bytes] = None,
**kwargs,
) -> bytes:
"""
Unified video generation entry point.
- provider: 'huggingface' (default), 'gemini' (veo3 stub), 'openai' (sora stub)
- kwargs: num_frames, guidance_scale, num_inference_steps, negative_prompt, seed, model
- input_image_bytes: optional bytes for image-to-video flows (uses image as motion anchor)
Returns raw video bytes (mp4/webm depending on provider).
"""
logger.info(f"[video_gen] provider={provider}")
# Enforce authentication usage like text gen does
if not user_id:
raise RuntimeError("user_id is required for subscription/usage tracking.")
# PRE-FLIGHT VALIDATION: Validate video generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_video_generation_operations
from fastapi import HTTPException
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_video_generation_operations(
pricing_service=pricing_service,
user_id=user_id
)
except HTTPException:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Video Generation] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with video generation")
# Generate video
model_name = kwargs.get("model", "tencent/HunyuanVideo")
try:
if provider == "huggingface":
video_bytes = _generate_with_huggingface(
prompt=prompt,
input_image_bytes=input_image_bytes,
**kwargs,
)
elif provider == "gemini":
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
elif provider == "openai":
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
else:
raise RuntimeError(f"Unknown video provider: {provider}")
# Track usage AFTER successful generation
db_track = next(get_db())
try:
from models.subscription_models import APIProvider, UsageSummary, APIUsageLog
from datetime import datetime
from services.subscription import PricingService
# Create pricing service for tracking (uses same DB session)
pricing_service_track = PricingService(db_track)
# Get current billing period
current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
usage_summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage_summary:
usage_summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(usage_summary)
db_track.commit()
# Calculate cost using pricing service
cost_info = pricing_service_track.get_pricing_for_provider_model(
APIProvider.VIDEO,
model_name
)
cost_per_video = cost_info.get('cost_per_request', 0.10) if cost_info else 0.10
# Get "before" state for unified log
current_video_calls_before = getattr(usage_summary, 'video_calls', 0) or 0
current_video_cost = getattr(usage_summary, 'video_cost', 0.0) or 0.0
# Increment video_calls and track cost
new_video_calls = current_video_calls_before + 1
usage_summary.video_calls = new_video_calls
usage_summary.video_cost = current_video_cost + cost_per_video
usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
# Get plan details for unified log (before commit, in case commit fails)
limits = pricing_service_track.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get image and image editing stats for unified log
current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Create usage log entry for audit trail
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.VIDEO,
endpoint=f"/video-generation/{provider}",
method="POST",
model_used=model_name,
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=cost_per_video,
response_time=0.0, # Could track actual time if needed
status_code=200,
request_size=len(prompt.encode('utf-8')),
response_size=len(video_bytes),
billing_period=current_period
)
db_track.add(usage_log)
db_track.commit()
logger.info(f"[video_gen] ✅ Successfully tracked usage: user {user_id} -> 1 video call, ${cost_per_video:.4f} cost")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Flush immediately to ensure it's visible in console/logs
import sys
log_message = f"""
[SUBSCRIPTION] Video Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: video
├─ Actual Provider: {provider}
├─ Model: {model_name or 'default'}
├─ Calls: {current_video_calls_before}{new_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
"""
print(log_message, flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
db_track.rollback()
# Don't fail video generation if tracking fails - video is already generated
finally:
db_track.close()
return video_bytes
except HTTPException:
# Re-raise HTTPExceptions (e.g., from validation or API errors)
raise
except Exception as e:
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": str(e)})

View File

@@ -0,0 +1,352 @@
"""
Prompt Enhancement Service for HunyuanVideo Generation
Uses AI to deeply understand story context and generate optimized
HunyuanVideo prompts following best practices with 7 components.
"""
from typing import Dict, Any, List, Optional
from loguru import logger
from fastapi import HTTPException
from services.llm_providers.main_text_generation import llm_text_gen
class PromptEnhancerService:
"""Service for generating HunyuanVideo-optimized prompts from story context."""
def __init__(self):
"""Initialize the prompt enhancer service."""
logger.info("[PromptEnhancer] Service initialized")
def enhance_scene_prompt(
self,
current_scene: Dict[str, Any],
story_context: Dict[str, Any],
all_scenes: List[Dict[str, Any]],
user_id: str
) -> str:
"""
Generate a HunyuanVideo-optimized prompt for a scene using two-stage AI analysis.
Args:
current_scene: Scene data for the scene being processed
story_context: Complete story context (setup, premise, outline, story text)
all_scenes: List of all scenes for consistency analysis
user_id: Clerk user ID for subscription checking
Returns:
str: Optimized HunyuanVideo prompt (300-500 words) with 7 components
"""
try:
logger.info(f"[PromptEnhancer] Enhancing prompt for scene {current_scene.get('scene_number', 'unknown')}")
# Stage 1: Deep story context analysis
story_insights = self._analyze_story_context(
current_scene=current_scene,
story_context=story_context,
all_scenes=all_scenes,
user_id=user_id
)
# Stage 2: Generate optimized HunyuanVideo prompt
optimized_prompt = self._generate_hunyuan_prompt(
current_scene=current_scene,
story_context=story_context,
story_insights=story_insights,
all_scenes=all_scenes,
user_id=user_id
)
logger.info(f"[PromptEnhancer] Generated prompt length: {len(optimized_prompt)} characters")
return optimized_prompt
except HTTPException as http_err:
# Propagate subscription limit errors (429) to frontend for modal display
# Only fallback for other HTTP errors (5xx, etc.)
if http_err.status_code == 429:
error_msg = self._extract_error_message(http_err)
logger.warning(f"[PromptEnhancer] Subscription limit exceeded (HTTP 429): {error_msg}")
# Re-raise to propagate to frontend for subscription modal
raise
else:
# For other HTTP errors, log and fallback
error_msg = self._extract_error_message(http_err)
logger.error(f"[PromptEnhancer] Error enhancing prompt (HTTP {http_err.status_code}): {error_msg}", exc_info=True)
return self._generate_fallback_prompt(current_scene, story_context)
except Exception as e:
logger.error(f"[PromptEnhancer] Error enhancing prompt: {str(e)}", exc_info=True)
# Fallback to basic prompt if enhancement fails
return self._generate_fallback_prompt(current_scene, story_context)
def _analyze_story_context(
self,
current_scene: Dict[str, Any],
story_context: Dict[str, Any],
all_scenes: List[Dict[str, Any]],
user_id: str
) -> str:
"""
Stage 1: Use AI to analyze complete story context and extract insights.
Returns:
str: Story insights as JSON string for use in prompt generation
"""
# Build comprehensive context for analysis
analysis_prompt = f"""You are analyzing a complete story to extract key insights for AI video generation.
**STORY SETUP:**
- Persona: {story_context.get('persona', 'N/A')}
- Setting: {story_context.get('story_setting', 'N/A')}
- Characters: {story_context.get('characters', 'N/A')}
- Plot Elements: {story_context.get('plot_elements', 'N/A')}
- Writing Style: {story_context.get('writing_style', 'N/A')}
- Tone: {story_context.get('story_tone', 'N/A')}
- Narrative POV: {story_context.get('narrative_pov', 'N/A')}
- Audience: {story_context.get('audience_age_group', 'N/A')}
- Content Rating: {story_context.get('content_rating', 'N/A')}
**STORY PREMISE:**
{story_context.get('premise', 'N/A')}
**STORY CONTENT:**
{story_context.get('story_content', 'N/A')[:2000]}...
**ALL SCENES OVERVIEW:**
"""
# Add summary of all scenes
for idx, scene in enumerate(all_scenes, 1):
scene_num = scene.get('scene_number', idx)
analysis_prompt += f"\nScene {scene_num}: {scene.get('title', 'Untitled')}"
analysis_prompt += f"\n Description: {scene.get('description', '')[:150]}..."
analysis_prompt += f"\n Image Prompt: {scene.get('image_prompt', '')[:150]}..."
if scene.get('character_descriptions'):
chars = ', '.join(scene.get('character_descriptions', [])[:3])
analysis_prompt += f"\n Characters: {chars}"
analysis_prompt += "\n"
analysis_prompt += f"""
**CURRENT SCENE FOR VIDEO GENERATION:**
Scene {current_scene.get('scene_number', 'N/A')}: {current_scene.get('title', 'Untitled')}
Description: {current_scene.get('description', '')}
Image Prompt: {current_scene.get('image_prompt', '')}
Key Events: {', '.join(current_scene.get('key_events', [])[:5])}
Character Descriptions: {', '.join(current_scene.get('character_descriptions', [])[:5])}
**YOUR TASK:**
Analyze this story and extract key insights for video generation. Focus on:
1. Narrative arc and position of current scene within it
2. Character consistency (how characters appear across scenes)
3. Visual style patterns from image prompts
4. Tone and atmosphere progression
5. Key themes and motifs
6. Visual narrative flow
7. Camera and composition needs for this specific scene
Provide your analysis as structured insights that can guide prompt generation.
"""
try:
insights = llm_text_gen(
prompt=analysis_prompt,
system_prompt="You are an expert story analyst specializing in visual narrative and cinematic storytelling. Provide detailed, actionable insights for video generation.",
user_id=user_id
)
logger.debug(f"[PromptEnhancer] Story insights extracted: {insights[:200]}...")
return insights
except HTTPException as http_err:
# Propagate subscription limit errors (429) to frontend
if http_err.status_code == 429:
error_msg = self._extract_error_message(http_err)
logger.warning(f"[PromptEnhancer] Subscription limit exceeded during story analysis (HTTP 429): {error_msg}")
# Re-raise to propagate to frontend for subscription modal
raise
else:
# For other HTTP errors, log and fallback
error_msg = self._extract_error_message(http_err)
logger.warning(f"[PromptEnhancer] Story analysis failed (HTTP {http_err.status_code}): {error_msg}, using basic context")
return "Standard narrative flow with consistent character presentation"
except Exception as e:
logger.warning(f"[PromptEnhancer] Story analysis failed, using basic context: {str(e)}")
return "Standard narrative flow with consistent character presentation"
def _generate_hunyuan_prompt(
self,
current_scene: Dict[str, Any],
story_context: Dict[str, Any],
story_insights: str,
all_scenes: List[Dict[str, Any]],
user_id: str
) -> str:
"""
Stage 2: Generate scene-specific HunyuanVideo prompt with all 7 components.
Returns:
str: Complete HunyuanVideo prompt (300-500 words)
"""
# Collect character descriptions across all scenes for consistency
all_characters = {}
for scene in all_scenes:
for char_desc in scene.get('character_descriptions', []):
if char_desc and char_desc not in all_characters:
all_characters[char_desc] = scene.get('scene_number', 0)
# Collect image prompts for visual style reference
image_prompts = [scene.get('image_prompt', '') for scene in all_scenes if scene.get('image_prompt')]
# Determine scene position in narrative arc
current_scene_num = current_scene.get('scene_number', 0)
total_scenes = len(all_scenes)
scene_position = "beginning" if current_scene_num <= total_scenes // 3 else ("middle" if current_scene_num <= 2 * total_scenes // 3 else "climax")
prompt_generation_request = f"""Generate a professional HunyuanVideo prompt for this story scene.
**STORY INSIGHTS (from deep analysis):**
{story_insights}
**STORY SETUP:**
- Setting: {story_context.get('story_setting', 'N/A')}
- Tone: {story_context.get('story_tone', 'N/A')}
- Style: {story_context.get('writing_style', 'N/A')}
- Audience: {story_context.get('audience_age_group', 'N/A')}
**VISUAL STYLE REFERENCE (from generated images):**
{chr(10).join([f"- {prompt[:100]}..." for prompt in image_prompts[:3]])}
**CHARACTER CONSISTENCY (across all scenes):**
{chr(10).join([f"- {char}" for char in list(all_characters.keys())[:5]])}
**CURRENT SCENE DETAILS:**
- Scene {current_scene.get('scene_number', 'N/A')} of {total_scenes} (narrative position: {scene_position})
- Title: {current_scene.get('title', 'Untitled')}
- Description: {current_scene.get('description', '')}
- Image Prompt: {current_scene.get('image_prompt', '')}
- Key Events: {', '.join(current_scene.get('key_events', [])[:5])}
- Characters in scene: {', '.join(current_scene.get('character_descriptions', [])[:5])}
- Audio Narration: {current_scene.get('audio_narration', '')[:200]}
**REQUIREMENTS:**
Create a comprehensive HunyuanVideo prompt (300-500 words) following the 7-component structure:
1. **SUBJECT**: Clearly define the main focus - characters, objects, or action. Include character descriptions that match the visual style from image prompts and maintain consistency across scenes.
2. **SCENE**: Describe the environment and setting. Ensure it matches the story_setting and aligns with the visual style established in previous scenes.
3. **MOTION**: Detail the specific actions and movements. Reference key_events and ensure motion fits the narrative flow and story_insights about the scene's position in the arc.
4. **CAMERA MOVEMENT**: Specify cinematic camera work appropriate for this moment in the story. Consider the narrative position ({scene_position}) - use establishing shots for beginning, dynamic shots for climax.
5. **ATMOSPHERE**: Set the emotional tone. This should reflect the story_tone but also consider where we are in the narrative arc based on story_insights.
6. **LIGHTING**: Define lighting that matches the visual style from image prompts and supports the atmosphere. Ensure consistency with the established visual aesthetic.
7. **SHOT COMPOSITION**: Describe framing and composition that serves the visual narrative. Consider the story's visual style and ensure it flows naturally with the overall story.
Write the prompt as a flowing, detailed description (not a list) that integrates all 7 components naturally. Make it vivid, cinematic, and consistent with the story's established visual and narrative style. The prompt should be between 300-500 words.
"""
try:
optimized_prompt = llm_text_gen(
prompt=prompt_generation_request,
system_prompt="You are an expert video prompt engineer specializing in HunyuanVideo text-to-video generation. Create detailed, cinematic prompts that follow best practices and ensure high-quality video output.",
user_id=user_id
)
# Clean up and validate prompt length
optimized_prompt = optimized_prompt.strip()
word_count = len(optimized_prompt.split())
if word_count < 200:
logger.warning(f"[PromptEnhancer] Generated prompt is too short ({word_count} words), enhancing...")
# Add more detail if too short
optimized_prompt += self._add_cinematic_details(current_scene, story_context)
elif word_count > 600:
logger.warning(f"[PromptEnhancer] Generated prompt is too long ({word_count} words), trimming...")
# Trim if too long (keep first ~500 words)
words = optimized_prompt.split()
optimized_prompt = ' '.join(words[:500])
logger.info(f"[PromptEnhancer] Generated prompt: {len(optimized_prompt.split())} words")
return optimized_prompt
except HTTPException as http_err:
# Propagate subscription limit errors (429) to frontend
if http_err.status_code == 429:
error_msg = self._extract_error_message(http_err)
logger.warning(f"[PromptEnhancer] Subscription limit exceeded during prompt generation (HTTP 429): {error_msg}")
# Re-raise to propagate to frontend for subscription modal
raise
else:
# For other HTTP errors, log and fallback
error_msg = self._extract_error_message(http_err)
logger.error(f"[PromptEnhancer] Prompt generation failed (HTTP {http_err.status_code}): {error_msg}", exc_info=True)
return self._generate_fallback_prompt(current_scene, story_context)
except Exception as e:
logger.error(f"[PromptEnhancer] Prompt generation failed: {str(e)}", exc_info=True)
return self._generate_fallback_prompt(current_scene, story_context)
def _add_cinematic_details(
self,
current_scene: Dict[str, Any],
story_context: Dict[str, Any]
) -> str:
"""Add cinematic details to enhance a too-short prompt."""
return f"""
The scene unfolds with careful attention to visual storytelling. The {story_context.get('story_setting', 'environment')} serves as more than background - it actively participates in the narrative. Lighting and composition work together to emphasize the emotional weight of this moment, with camera movements that guide the viewer's attention naturally through the space. Every element - from the way light falls to the positioning of characters - contributes to the overall narrative impact.
"""
def _extract_error_message(self, http_err: HTTPException) -> str:
"""
Extract meaningful error message from HTTPException.
Handles both dict-based details (from subscription limit errors) and string details.
"""
if isinstance(http_err.detail, dict):
# For subscription limit errors, extract the 'message' or 'error' field
return http_err.detail.get('message') or http_err.detail.get('error') or str(http_err.detail)
elif isinstance(http_err.detail, str):
return http_err.detail
else:
return str(http_err.detail)
def _generate_fallback_prompt(
self,
current_scene: Dict[str, Any],
story_context: Dict[str, Any]
) -> str:
"""Generate a basic fallback prompt if AI enhancement fails."""
scene_title = current_scene.get('title', 'Untitled Scene')
scene_desc = current_scene.get('description', '')
image_prompt = current_scene.get('image_prompt', '')
setting = story_context.get('story_setting', 'the scene')
tone = story_context.get('story_tone', 'engaging')
return f"""A cinematic scene titled "{scene_title}" set in {setting}. {scene_desc[:200]}.
The scene features {', '.join(current_scene.get('character_descriptions', [])[:2]) if current_scene.get('character_descriptions') else 'the main characters'}.
Visual style follows: {image_prompt[:150]}.
The {tone} atmosphere is enhanced by natural lighting and dynamic camera movements that follow the action.
Shot composition emphasizes the narrative importance of this moment, with careful framing that draws attention to key elements.
The scene maintains visual consistency with previous moments while advancing the story's visual narrative."""
def enhance_scene_prompt_for_video(
current_scene: Dict[str, Any],
story_context: Dict[str, Any],
all_scenes: List[Dict[str, Any]],
user_id: str
) -> str:
"""
Convenience function to enhance a scene prompt for HunyuanVideo generation.
Args:
current_scene: Scene data for the scene being processed
story_context: Complete story context dictionary
all_scenes: List of all scenes for consistency
user_id: Clerk user ID for subscription checking
Returns:
str: Optimized HunyuanVideo prompt
"""
service = PromptEnhancerService()
return service.enhance_scene_prompt(current_scene, story_context, all_scenes, user_id)

View File

@@ -41,6 +41,47 @@ class StoryVideoGenerationService:
unique_id = str(uuid.uuid4())[:8]
return f"story_{clean_title}_{unique_id}.mp4"
def save_scene_video(self, video_bytes: bytes, scene_number: int, user_id: str) -> Dict[str, str]:
"""
Save individual scene video bytes to file.
Parameters:
video_bytes: Raw video file bytes (mp4/webm format)
scene_number: Scene number for naming
user_id: Clerk user ID for naming
Returns:
Dict[str, str]: Video metadata with video_url and video_filename
"""
try:
# Generate filename with scene number and user ID
clean_user_id = "".join(c if c.isalnum() or c in ('-', '_') else '_' for c in user_id[:16])
timestamp = str(uuid.uuid4())[:8]
filename = f"scene_{scene_number}_{clean_user_id}_{timestamp}.mp4"
video_path = self.output_dir / filename
# Write video bytes to file
with open(video_path, 'wb') as f:
f.write(video_bytes)
file_size = video_path.stat().st_size
logger.info(f"[StoryVideoGeneration] Saved scene {scene_number} video: {filename} ({file_size} bytes)")
# Generate URL path (relative to /api/story/videos/)
video_url = f"/api/story/videos/{filename}"
return {
"video_filename": filename,
"video_url": video_url,
"video_path": str(video_path),
"file_size": file_size
}
except Exception as e:
logger.error(f"[StoryVideoGeneration] Error saving scene video: {e}", exc_info=True)
raise RuntimeError(f"Failed to save scene video: {str(e)}") from e
def generate_scene_video(
self,
scene: Dict[str, Any],
@@ -125,12 +166,12 @@ class StoryVideoGenerationService:
# Use provided duration or audio duration
video_duration = duration if duration is not None else audio_duration
# Create image clip
image_clip = ImageClip(str(image_file)).set_duration(video_duration)
image_clip = image_clip.set_fps(fps)
# Create image clip (MoviePy v2: use with_* API)
image_clip = ImageClip(str(image_file)).with_duration(video_duration)
image_clip = image_clip.with_fps(fps)
# Set audio to image clip
video_clip = image_clip.set_audio(audio_clip)
video_clip = image_clip.with_audio(audio_clip)
# Generate video filename
video_filename = f"scene_{scene_number}_{scene_title.replace(' ', '_').replace('/', '_')[:50]}_{uuid.uuid4().hex[:8]}.mp4"
@@ -274,12 +315,12 @@ class StoryVideoGenerationService:
audio_clip = AudioFileClip(str(audio_file))
audio_duration = audio_clip.duration
# Create image clip
image_clip = ImageClip(str(image_file)).set_duration(audio_duration)
image_clip = image_clip.set_fps(fps)
# Create image clip (MoviePy v2: use with_* API)
image_clip = ImageClip(str(image_file)).with_duration(audio_duration)
image_clip = image_clip.with_fps(fps)
# Set audio to image clip
video_clip = image_clip.set_audio(audio_clip)
video_clip = image_clip.with_audio(audio_clip)
scene_clips.append(video_clip)
total_duration += audio_duration

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
from loguru import logger
def log_video_stack_diagnostics() -> None:
try:
import sys
import platform
import importlib
mv = importlib.import_module("moviepy")
im = importlib.import_module("imageio")
try:
import imageio_ffmpeg as iff
ff = iff.get_ffmpeg_exe()
except Exception:
ff = "unresolved"
logger.info(
"[VideoStack] py={} plat={} moviepy={} imageio={} ffmpeg={}",
sys.executable,
platform.platform(),
getattr(mv, "__version__", "NA"),
getattr(im, "__version__", "NA"),
ff,
)
except Exception as e:
logger.error("[VideoStack] diagnostics failed: {}", e)
def assert_supported_moviepy() -> None:
"""Fail fast if MoviePy isn't version 2.x."""
try:
import pkg_resources as pr
mv = pr.get_distribution("moviepy").version
if not mv.startswith("2."):
raise RuntimeError(
f"Unsupported MoviePy version {mv}. Expected 2.x. "
"Please install with: pip install moviepy==2.1.2"
)
except Exception as e:
# Log and re-raise so startup fails clearly
logger.error("[VideoStack] version check failed: {}", e)
raise

View File

@@ -694,6 +694,44 @@ class LimitValidator:
total_images = projected_images
# Check video generation limits
elif provider == APIProvider.VIDEO:
video_limit = limits.get('video_calls', 0) or 0
total_video_calls = usage.video_calls or 0
projected_video_calls = total_video_calls + 1
if video_limit > 0 and projected_video_calls > video_limit:
error_info = {
'current_calls': total_video_calls,
'limit': video_limit,
'provider': 'video',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"Video generation limit would be exceeded. Would use {projected_video_calls} of {video_limit} videos this billing period.", {
'error_type': 'video_limit',
'usage_info': error_info
}
# Check image editing limits
elif provider == APIProvider.IMAGE_EDIT:
image_edit_limit = limits.get('image_edit_calls', 0) or 0
total_image_edit_calls = getattr(usage, 'image_edit_calls', 0) or 0
projected_image_edit_calls = total_image_edit_calls + 1
if image_edit_limit > 0 and projected_image_edit_calls > image_edit_limit:
error_info = {
'current_calls': total_image_edit_calls,
'limit': image_edit_limit,
'provider': 'image_edit',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"Image editing limit would be exceeded. Would use {projected_image_edit_calls} of {image_edit_limit} image edits this billing period.", {
'error_type': 'image_edit_limit',
'usage_info': error_info
}
# Check other provider-specific limits
else:
provider_calls_key = f"{provider_name}_calls"

View File

@@ -299,3 +299,124 @@ def validate_image_generation_operations(
}
)
def validate_image_editing_operations(
pricing_service: PricingService,
user_id: str
) -> None:
"""
Validate image editing operation before making API calls.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
Returns:
None - raises HTTPException with 429 status if validation fails
"""
try:
operations_to_validate = [
{
'provider': APIProvider.IMAGE_EDIT,
'tokens_requested': 0,
'actual_provider_name': 'image_edit',
'operation_type': 'image_editing'
}
]
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Image editing blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'image_edit') if usage_info else 'image_edit'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Image editing validated for user {user_id}")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating image editing: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate image editing: {str(e)}",
'message': f"Failed to validate image editing: {str(e)}"
}
)
def validate_video_generation_operations(
pricing_service: PricingService,
user_id: str
) -> None:
"""
Validate video generation operation before making API calls.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
Returns:
None - raises HTTPException with 429 status if validation fails
"""
try:
operations_to_validate = [
{
'provider': APIProvider.VIDEO,
'tokens_requested': 0,
'actual_provider_name': 'video',
'operation_type': 'video_generation'
}
]
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Video generation blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'video') if usage_info else 'video'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Video generation validated for user {user_id}")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating video generation: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate video generation: {str(e)}",
'message': f"Failed to validate video generation: {str(e)}"
}
)

View File

@@ -295,10 +295,22 @@ class PricingService:
"model_name": "exa-search",
"cost_per_request": 0.005, # $0.005 per search (1-25 results)
"description": "Exa Neural Search API"
},
{
"provider": APIProvider.VIDEO,
"model_name": "tencent/HunyuanVideo",
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
"description": "HuggingFace AI Video Generation (HunyuanVideo)"
},
{
"provider": APIProvider.VIDEO,
"model_name": "default",
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
"description": "AI Video Generation default pricing"
}
]
# Combine all pricing data
# Combine all pricing data (include video pricing in search_pricing list)
all_pricing = gemini_pricing + openai_pricing + anthropic_pricing + mistral_pricing + search_pricing
# Insert or update pricing data
@@ -344,6 +356,8 @@ class PricingService:
"firecrawl_calls_limit": 10,
"stability_calls_limit": 5,
"exa_calls_limit": 100,
"video_calls_limit": 0, # No video generation for free tier
"image_edit_calls_limit": 10, # 10 AI image editing calls/month
"gemini_tokens_limit": 100000,
"monthly_cost_limit": 0.0,
"features": ["basic_content_generation", "limited_research"],
@@ -365,6 +379,8 @@ class PricingService:
"firecrawl_calls_limit": 100,
"stability_calls_limit": 5,
"exa_calls_limit": 500,
"video_calls_limit": 20, # 20 videos/month for basic plan
"image_edit_calls_limit": 30, # 30 AI image editing calls/month
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
@@ -388,6 +404,8 @@ class PricingService:
"firecrawl_calls_limit": 500,
"stability_calls_limit": 200,
"exa_calls_limit": 2000,
"video_calls_limit": 50, # 50 videos/month for pro plan
"image_edit_calls_limit": 100, # 100 AI image editing calls/month
"gemini_tokens_limit": 5000000,
"openai_tokens_limit": 2500000,
"anthropic_tokens_limit": 1000000,
@@ -411,6 +429,8 @@ class PricingService:
"firecrawl_calls_limit": 0,
"stability_calls_limit": 0,
"exa_calls_limit": 0, # Unlimited
"video_calls_limit": 0, # Unlimited for enterprise
"image_edit_calls_limit": 0, # Unlimited image editing for enterprise
"gemini_tokens_limit": 0,
"openai_tokens_limit": 0,
"anthropic_tokens_limit": 0,
@@ -429,6 +449,20 @@ class PricingService:
if not existing:
plan = SubscriptionPlan(**plan_data)
self.db.add(plan)
else:
# Update existing plan with new limits (e.g., image_edit_calls_limit)
# This ensures existing plans get new columns like image_edit_calls_limit
for key, value in plan_data.items():
if key not in ["name", "tier"]: # Don't overwrite name/tier
try:
# Try to set the attribute (works even if column was just added)
setattr(existing, key, value)
except (AttributeError, Exception) as e:
# If attribute doesn't exist yet (column not migrated), skip it
# Schema migration will add it, then this will update it on next run
logger.debug(f"Could not set {key} on plan {existing.name}: {e}")
existing.updated_at = datetime.utcnow()
logger.debug(f"Updated existing plan: {existing.name}")
self.db.commit()
logger.debug("Default subscription plans initialized")
@@ -615,6 +649,8 @@ class PricingService:
'metaphor_calls': plan.metaphor_calls_limit,
'firecrawl_calls': plan.firecrawl_calls_limit,
'stability_calls': plan.stability_calls_limit,
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column
# Token limits
'gemini_tokens': plan.gemini_tokens_limit,
'openai_tokens': plan.openai_tokens_limit,

View File

@@ -29,6 +29,8 @@ def ensure_subscription_plan_columns(db: Session) -> None:
# Columns we may reference in models but might be missing in older DBs
required_columns = {
"exa_calls_limit": "INTEGER DEFAULT 0",
"video_calls_limit": "INTEGER DEFAULT 0",
"image_edit_calls_limit": "INTEGER DEFAULT 0",
}
for col_name, ddl in required_columns.items():
@@ -78,6 +80,10 @@ def ensure_usage_summaries_columns(db: Session) -> None:
required_columns = {
"exa_calls": "INTEGER DEFAULT 0",
"exa_cost": "REAL DEFAULT 0.0",
"video_calls": "INTEGER DEFAULT 0",
"video_cost": "REAL DEFAULT 0.0",
"image_edit_calls": "INTEGER DEFAULT 0",
"image_edit_cost": "REAL DEFAULT 0.0",
}
for col_name, ddl in required_columns.items():

View File

@@ -608,6 +608,12 @@ class UsageTrackingService:
# Reset image generation counters
summary.stability_calls = 0
# Reset video generation counters
summary.video_calls = 0
# Reset image editing counters
summary.image_edit_calls = 0
# Reset cost counters
summary.gemini_cost = 0.0
summary.openai_cost = 0.0
@@ -618,6 +624,9 @@ class UsageTrackingService:
summary.metaphor_cost = 0.0
summary.firecrawl_cost = 0.0
summary.stability_cost = 0.0
summary.exa_cost = 0.0
summary.video_cost = 0.0
summary.image_edit_cost = 0.0
# Reset totals
summary.total_calls = 0

View File

@@ -161,9 +161,29 @@ def start_backend(enable_reload=False, production_mode=False):
# Set up clean logging for end users
from logging_config import setup_clean_logging, get_uvicorn_log_level
# Video stack preflight (diagnostics + version assert)
try:
from services.story_writer.video_preflight import (
log_video_stack_diagnostics,
assert_supported_moviepy,
)
except Exception:
# Preflight is optional; continue if module missing
log_video_stack_diagnostics = None
assert_supported_moviepy = None
verbose_mode = setup_clean_logging()
uvicorn_log_level = get_uvicorn_log_level()
# Log diagnostics and assert versions (fail fast if misconfigured)
try:
if log_video_stack_diagnostics:
log_video_stack_diagnostics()
if assert_supported_moviepy:
assert_supported_moviepy()
except Exception as _video_stack_err:
print(f"[ERROR] Video stack preflight failed: {_video_stack_err}")
return False
uvicorn.run(
"app:app",