AI Video Generation Implementation
This commit is contained in:
@@ -9,6 +9,7 @@ from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.main_image_editing import edit_image
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.logger_utils import get_service_logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
@@ -125,6 +126,14 @@ def generate(
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
|
||||
|
||||
@@ -137,6 +146,8 @@ def generate(
|
||||
├─ Actual Provider: {result.provider}
|
||||
├─ Model: {result.model or 'default'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
@@ -195,6 +206,26 @@ class ImagePromptSuggestResponse(BaseModel):
|
||||
suggestions: list[PromptSuggestion]
|
||||
|
||||
|
||||
class ImageEditRequest(BaseModel):
|
||||
image_base64: str
|
||||
prompt: str
|
||||
provider: Optional[str] = Field(None, pattern="^(huggingface)$")
|
||||
model: Optional[str] = None
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
class ImageEditResponse(BaseModel):
|
||||
success: bool = True
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
model: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
|
||||
def suggest_prompts(
|
||||
req: ImagePromptSuggestRequest,
|
||||
@@ -316,3 +347,136 @@ def suggest_prompts(
|
||||
logger.error(f"Prompt suggestion failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/edit", response_model=ImageEditResponse)
|
||||
def edit(
|
||||
req: ImageEditRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ImageEditResponse:
|
||||
"""Edit image with subscription checking."""
|
||||
try:
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
# Decode base64 image
|
||||
try:
|
||||
input_image_bytes = base64.b64decode(req.image_base64)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid image_base64: {str(e)}")
|
||||
|
||||
# Validation is now handled inside edit_image function
|
||||
result = edit_image(
|
||||
input_image_bytes=input_image_bytes,
|
||||
prompt=req.prompt,
|
||||
options={
|
||||
"provider": req.provider,
|
||||
"model": req.model,
|
||||
"guidance_scale": req.guidance_scale,
|
||||
"steps": req.steps,
|
||||
"seed": req.seed,
|
||||
},
|
||||
user_id=user_id, # Pass user_id for validation inside edit_image
|
||||
)
|
||||
edited_image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
|
||||
# TRACK USAGE after successful image editing
|
||||
if result:
|
||||
logger.info(f"[images.edit] ✅ Image editing successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Get or create usage summary
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.debug(f"[images.edit] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[images.edit] Creating new usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
|
||||
# Get "before" state for unified log
|
||||
current_calls_before = getattr(summary, "image_edit_calls", 0) or 0
|
||||
|
||||
# Update image editing counters (separate from image generation)
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, "image_edit_calls", new_calls)
|
||||
logger.debug(f"[images.edit] Updated image_edit_calls: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update totals
|
||||
old_total_calls = summary.total_calls or 0
|
||||
summary.total_calls = old_total_calls + 1
|
||||
logger.debug(f"[images.edit] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get image generation stats for unified log
|
||||
current_image_gen_calls = getattr(summary, "stability_calls", 0) or 0
|
||||
image_gen_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[images.edit] ✅ Successfully tracked usage: user {user_id} -> image_edit -> {new_calls} calls")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Editing
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: image_edit
|
||||
├─ Actual Provider: {result.provider}
|
||||
├─ Model: {result.model or 'default'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Images: {current_image_gen_calls} / {image_gen_limit if image_gen_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[images.edit] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
# Non-blocking: log error but don't fail the request
|
||||
logger.error(f"[images.edit] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return ImageEditResponse(
|
||||
image_base64=edited_image_b64,
|
||||
width=result.width,
|
||||
height=result.height,
|
||||
provider=result.provider,
|
||||
model=result.model,
|
||||
seed=result.seed,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Image editing failed: {e}", exc_info=True)
|
||||
# Provide a clean, actionable message to the client
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Image editing service is temporarily unavailable or the connection was reset. Please try again."
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ content generation, and full story creation.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from typing import Any, Dict, Union, List
|
||||
from typing import Any, Dict, Union, List, Optional
|
||||
from loguru import logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
@@ -37,6 +37,16 @@ from models.story_models import (
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
from .task_manager import task_manager
|
||||
from .cache_manager import cache_manager
|
||||
from uuid import uuid4
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
|
||||
from .utils.auth import require_authenticated_user
|
||||
from .utils.media_utils import resolve_media_file
|
||||
from .utils.hd_video import (
|
||||
generate_hd_video_payload,
|
||||
generate_hd_video_scene_payload,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
|
||||
@@ -503,11 +513,11 @@ async def get_task_status(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/result", response_model=StoryFullGenerationResponse)
|
||||
@router.get("/task/{task_id}/result")
|
||||
async def get_task_result(
|
||||
task_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> StoryFullGenerationResponse:
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the result of a completed story generation task."""
|
||||
try:
|
||||
if not current_user:
|
||||
@@ -528,7 +538,19 @@ async def get_task_result(
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")
|
||||
|
||||
return StoryFullGenerationResponse(**result, success=True, task_id=task_id)
|
||||
# Some tasks return a full-story payload compatible with StoryFullGenerationResponse,
|
||||
# others (e.g., video-only) return a dict like {"video": {...}, "success": True}.
|
||||
# To avoid model conflicts, return a generic payload and include task_id.
|
||||
# Frontend callers can branch on keys present (e.g., "video").
|
||||
if isinstance(result, dict):
|
||||
# Ensure success flag present without duplicating
|
||||
payload = {**result}
|
||||
payload.setdefault("success", True)
|
||||
payload["task_id"] = task_id
|
||||
return payload
|
||||
|
||||
# Fallback: wrap non-dict results
|
||||
return {"result": result, "success": True, "task_id": task_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -536,6 +558,69 @@ async def get_task_result(
|
||||
logger.error(f"[StoryWriter] Failed to get task result: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
class HDVideoRequest(BaseModel):
|
||||
prompt: str
|
||||
provider: str = "huggingface"
|
||||
model: str | None = None
|
||||
num_frames: int | None = None
|
||||
guidance_scale: float | None = None
|
||||
num_inference_steps: int | None = None
|
||||
negative_prompt: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
@router.post("/hd-video")
|
||||
async def generate_hd_video(
|
||||
request: HDVideoRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate an HD AI animation using provider text-to-video (Hugging Face for now).
|
||||
Saves the returned bytes as a video file and returns the secured URL.
|
||||
"""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
return generate_hd_video_payload(request, user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to generate HD video: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
class HDVideoSceneRequest(BaseModel):
|
||||
scene_number: int
|
||||
scene_data: Dict[str, Any]
|
||||
story_context: Dict[str, Any]
|
||||
all_scenes: List[Dict[str, Any]]
|
||||
scene_image_url: Optional[str] = None
|
||||
provider: str = "huggingface"
|
||||
model: str | None = None
|
||||
num_frames: int | None = None
|
||||
guidance_scale: float | None = None
|
||||
num_inference_steps: int | None = None
|
||||
negative_prompt: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
|
||||
@router.post("/hd-video-scene")
|
||||
async def generate_hd_video_scene(
|
||||
request: HDVideoSceneRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate HD AI video for a single scene with AI-enhanced prompt.
|
||||
Uses prompt enhancer to create HunyuanVideo-optimized prompt from story context.
|
||||
"""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
return generate_hd_video_scene_payload(request, user_id)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to generate HD video for scene: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Image Generation Endpoints
|
||||
@@ -614,31 +699,20 @@ async def serve_scene_image(
|
||||
):
|
||||
"""Serve a generated story scene image."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import image generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
image_service = StoryImageGenerationService()
|
||||
image_path = image_service.output_dir / image_filename
|
||||
|
||||
if not image_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Image not found: {image_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
image_path.resolve().relative_to(image_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
image_path = resolve_media_file(image_service.output_dir, image_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(image_path),
|
||||
media_type="image/png",
|
||||
filename=image_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -726,31 +800,20 @@ async def serve_scene_audio(
|
||||
):
|
||||
"""Serve a generated story scene audio file."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import audio generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
audio_service = StoryAudioGenerationService()
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
|
||||
if not audio_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Audio not found: {audio_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
audio_path.resolve().relative_to(audio_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(audio_path),
|
||||
media_type="audio/mpeg",
|
||||
filename=audio_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -869,6 +932,99 @@ async def generate_story_video(
|
||||
logger.error(f"[StoryWriter] Failed to generate video: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/generate-video-async", response_model=Dict[str, Any])
|
||||
async def generate_story_video_async(
|
||||
request: StoryVideoGenerationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a video asynchronously with progress updates via task manager.
|
||||
Frontend can poll /api/story/task/{task_id}/status to show progress messages.
|
||||
"""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
if not request.scenes or len(request.scenes) == 0:
|
||||
raise HTTPException(status_code=400, detail="At least one scene is required")
|
||||
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
|
||||
raise HTTPException(status_code=400, detail="Number of scenes, image URLs, and audio URLs must match")
|
||||
|
||||
task_id = task_manager.create_task("story_video_generation")
|
||||
background_tasks.add_task(
|
||||
_execute_video_generation_task,
|
||||
task_id=task_id,
|
||||
request=request,
|
||||
user_id=user_id
|
||||
)
|
||||
return {"task_id": task_id, "status": "pending", "message": "Video generation started"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to start async video generation: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
|
||||
"""Background task to generate story video with progress mapped to task manager."""
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
try:
|
||||
task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
|
||||
video_service = StoryVideoGenerationService()
|
||||
image_service = StoryImageGenerationService()
|
||||
audio_service = StoryAudioGenerationService()
|
||||
|
||||
# Prepare assets
|
||||
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
|
||||
image_paths, audio_paths, valid_scenes = [], [], []
|
||||
for idx, (scene, image_url, audio_url) in enumerate(zip(scenes_data, request.image_urls, request.audio_urls)):
|
||||
image_filename = (image_url.split('/')[-1] if '/' in image_url else image_url).split('?')[0]
|
||||
audio_filename = (audio_url.split('/')[-1] if '/' in audio_url else audio_url).split('?')[0]
|
||||
image_path = image_service.output_dir / image_filename
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
if not image_path.exists():
|
||||
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
|
||||
continue
|
||||
if not audio_path.exists():
|
||||
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
|
||||
continue
|
||||
image_paths.append(str(image_path))
|
||||
audio_paths.append(str(audio_path))
|
||||
valid_scenes.append(scene)
|
||||
|
||||
if not image_paths or not audio_paths or len(image_paths) != len(audio_paths):
|
||||
raise RuntimeError("No valid or mismatched image/audio assets for video generation.")
|
||||
|
||||
# Map service progress (0-100) to task progress (5-95)
|
||||
def progress_callback(sub_progress: float, msg: str):
|
||||
overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
|
||||
task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)
|
||||
|
||||
result = video_service.generate_story_video(
|
||||
scenes=valid_scenes,
|
||||
image_paths=image_paths,
|
||||
audio_paths=audio_paths,
|
||||
user_id=user_id,
|
||||
story_title=request.story_title or "Story",
|
||||
fps=request.fps or 24,
|
||||
transition_duration=request.transition_duration or 0.5,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="Video generation complete!",
|
||||
result={"video": result, "success": True}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Async video generation failed: {e}", exc_info=True)
|
||||
task_manager.update_task_status(task_id, "failed", error=str(e), message=f"Video generation failed: {e}")
|
||||
|
||||
@router.post("/generate-complete-video", response_model=Dict[str, Any])
|
||||
async def generate_complete_story_video(
|
||||
@@ -1111,31 +1267,20 @@ async def serve_story_video(
|
||||
):
|
||||
"""Serve a generated story video file."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
# Import video generation service to get output directory
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
video_path = video_service.output_dir / video_filename
|
||||
|
||||
if not video_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Video not found: {video_filename}")
|
||||
|
||||
# Validate that the file is within the output directory (security check)
|
||||
try:
|
||||
video_path.resolve().relative_to(video_service.output_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
video_path = resolve_media_file(video_service.output_dir, video_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
8
backend/api/story_writer/utils/__init__.py
Normal file
8
backend/api/story_writer/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Utility helpers for Story Writer API routes.
|
||||
|
||||
Grouped here to keep the main router lean while reusing common logic
|
||||
such as authentication guards, media resolution, and HD video helpers.
|
||||
"""
|
||||
|
||||
|
||||
23
backend/api/story_writer/utils/auth.py
Normal file
23
backend/api/story_writer/utils/auth.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
|
||||
"""
|
||||
Validates the current user dictionary provided by Clerk middleware and
|
||||
returns the normalized user_id. Raises HTTP 401 if authentication fails.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", "")).strip()
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid user ID in authentication token",
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
|
||||
154
backend/api/story_writer/utils/hd_video.py
Normal file
154
backend/api/story_writer/utils/hd_video.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from uuid import uuid4
|
||||
|
||||
from .media_utils import load_story_image_bytes
|
||||
|
||||
|
||||
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
|
||||
"""Handles synchronous HD video generation."""
|
||||
from services.llm_providers.main_video_generation import ai_video_generate
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
output_dir = video_service.output_dir
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if getattr(request, "model", None):
|
||||
kwargs["model"] = request.model
|
||||
if getattr(request, "num_frames", None):
|
||||
kwargs["num_frames"] = request.num_frames
|
||||
if getattr(request, "guidance_scale", None) is not None:
|
||||
kwargs["guidance_scale"] = request.guidance_scale
|
||||
if getattr(request, "num_inference_steps", None):
|
||||
kwargs["num_inference_steps"] = request.num_inference_steps
|
||||
if getattr(request, "negative_prompt", None):
|
||||
kwargs["negative_prompt"] = request.negative_prompt
|
||||
if getattr(request, "seed", None) is not None:
|
||||
kwargs["seed"] = request.seed
|
||||
|
||||
logger.info(f"[StoryWriter] Generating HD video via {getattr(request, 'provider', 'huggingface')} for user {user_id}")
|
||||
raw_bytes = ai_video_generate(
|
||||
prompt=request.prompt,
|
||||
provider=getattr(request, "provider", None) or "huggingface",
|
||||
user_id=user_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
filename = f"hd_{uuid4().hex}.mp4"
|
||||
file_path = output_dir / filename
|
||||
with open(file_path, "wb") as fh:
|
||||
fh.write(raw_bytes)
|
||||
|
||||
logger.info(f"[StoryWriter] HD video saved to {file_path}")
|
||||
return {
|
||||
"success": True,
|
||||
"video_filename": filename,
|
||||
"video_url": f"/api/story/videos/{filename}",
|
||||
"provider": getattr(request, "provider", None) or "huggingface",
|
||||
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
|
||||
}
|
||||
|
||||
|
||||
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Handles per-scene HD video generation including prompt enhancement,
|
||||
subscription validation, and optional image conditioning.
|
||||
"""
|
||||
from services.database import get_db as get_db_validation
|
||||
from services.onboarding.api_key_manager import APIKeyManager
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_video_generation_operations
|
||||
from services.story_writer.prompt_enhancer_service import enhance_scene_prompt_for_video
|
||||
from services.llm_providers.main_video_generation import ai_video_generate
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
|
||||
scene_number = request.scene_number
|
||||
logger.info(f"[StoryWriter] Generating HD video for scene {scene_number} for user {user_id}")
|
||||
|
||||
# Step 1: Validate API key
|
||||
hf_token = APIKeyManager().get_api_key("hf_token")
|
||||
if not hf_token:
|
||||
logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Hugging Face API token is not configured. Please configure your HF token in settings.",
|
||||
"message": "Hugging Face API token is not configured. Please configure your HF token in settings.",
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: Subscription limits
|
||||
db_validation = next(get_db_validation())
|
||||
try:
|
||||
pricing_service = PricingService(db_validation)
|
||||
logger.info(f"[StoryWriter] Pre-flight: Checking video generation limits for user {user_id}...")
|
||||
validate_video_generation_operations(pricing_service=pricing_service, user_id=user_id)
|
||||
logger.info("[StoryWriter] Pre-flight: ✅ Video generation limits validated - proceeding")
|
||||
finally:
|
||||
db_validation.close()
|
||||
|
||||
# Stage 1: Prompt enhancement
|
||||
enhanced_prompt = enhance_scene_prompt_for_video(
|
||||
current_scene=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
all_scenes=request.all_scenes,
|
||||
user_id=user_id,
|
||||
)
|
||||
logger.info(f"[StoryWriter] Generated enhanced prompt ({len(enhanced_prompt)} chars) for scene {scene_number}")
|
||||
|
||||
# Stage 2: Optional image reference
|
||||
scene_image_bytes: Optional[bytes] = None
|
||||
if getattr(request, "scene_image_url", None):
|
||||
scene_image_bytes = load_story_image_bytes(request.scene_image_url)
|
||||
if scene_image_bytes:
|
||||
logger.info(f"[StoryWriter] Using scene image reference for scene {scene_number}")
|
||||
else:
|
||||
logger.warning(f"[StoryWriter] Scene image could not be loaded for scene {scene_number}, falling back to text-only video")
|
||||
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if getattr(request, "model", None):
|
||||
kwargs["model"] = request.model
|
||||
if getattr(request, "num_frames", None):
|
||||
kwargs["num_frames"] = request.num_frames
|
||||
if getattr(request, "guidance_scale", None) is not None:
|
||||
kwargs["guidance_scale"] = request.guidance_scale
|
||||
if getattr(request, "num_inference_steps", None):
|
||||
kwargs["num_inference_steps"] = request.num_inference_steps
|
||||
if getattr(request, "negative_prompt", None):
|
||||
kwargs["negative_prompt"] = request.negative_prompt
|
||||
if getattr(request, "seed", None) is not None:
|
||||
kwargs["seed"] = request.seed
|
||||
|
||||
raw_bytes = ai_video_generate(
|
||||
prompt=enhanced_prompt,
|
||||
provider=getattr(request, "provider", None) or "huggingface",
|
||||
user_id=user_id,
|
||||
input_image_bytes=scene_image_bytes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=raw_bytes,
|
||||
scene_number=scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] HD video saved for scene {scene_number}: {save_result.get('video_filename')}")
|
||||
return {
|
||||
"success": True,
|
||||
"scene_number": scene_number,
|
||||
"video_filename": save_result.get("video_filename"),
|
||||
"video_url": save_result.get("video_url"),
|
||||
"prompt_used": enhanced_prompt,
|
||||
"provider": getattr(request, "provider", None) or "huggingface",
|
||||
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
|
||||
}
|
||||
|
||||
|
||||
69
backend/api/story_writer/utils/media_utils.py
Normal file
69
backend/api/story_writer/utils/media_utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from loguru import logger
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
|
||||
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
|
||||
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
|
||||
"""
|
||||
Resolve an authenticated story image URL (e.g., /api/story/images/<file>) to raw bytes.
|
||||
Returns None if the file cannot be located.
|
||||
"""
|
||||
if not image_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = urlparse(image_url)
|
||||
path = parsed.path if parsed.scheme else image_url
|
||||
prefix = "/api/story/images/"
|
||||
if prefix not in path:
|
||||
logger.warning(f"[StoryWriter] Unsupported image URL for video reference: {image_url}")
|
||||
return None
|
||||
|
||||
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
file_path = (STORY_IMAGES_DIR / filename).resolve()
|
||||
if not str(file_path).startswith(str(STORY_IMAGES_DIR)):
|
||||
logger.error(f"[StoryWriter] Attempted path traversal when resolving image: {image_url}")
|
||||
return None
|
||||
|
||||
if not file_path.exists():
|
||||
logger.warning(f"[StoryWriter] Referenced scene image not found on disk: {file_path}")
|
||||
return None
|
||||
|
||||
return file_path.read_bytes()
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to load reference image for video gen: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_media_file(base_dir: Path, filename: str) -> Path:
|
||||
"""
|
||||
Returns a safe resolved path for a media file stored under base_dir.
|
||||
Guards against directory traversal and ensures the file exists.
|
||||
"""
|
||||
filename = filename.split("?")[0].strip()
|
||||
resolved = (base_dir / filename).resolve()
|
||||
|
||||
try:
|
||||
resolved.relative_to(base_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
|
||||
if not resolved.exists():
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from functools import lru_cache
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from services.subscription.log_wrapping_service import LogWrappingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
||||
import sqlite3
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.subscription_models import (
|
||||
@@ -114,6 +114,8 @@ async def get_subscription_plans(
|
||||
"metaphor_calls": plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": plan.firecrawl_calls_limit,
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"video_calls": getattr(plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
|
||||
"gemini_tokens": plan.gemini_tokens_limit,
|
||||
"openai_tokens": plan.openai_tokens_limit,
|
||||
"anthropic_tokens": plan.anthropic_tokens_limit,
|
||||
@@ -132,7 +134,7 @@ async def get_subscription_plans(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and 'exa_calls_limit' in error_str:
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
|
||||
logger.warning("Missing column detected in subscription plans query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
@@ -237,6 +239,8 @@ async def get_user_subscription(
|
||||
"metaphor_calls": free_plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": free_plan.firecrawl_calls_limit,
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -334,6 +338,8 @@ async def get_subscription_status(
|
||||
"metaphor_calls": free_plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": free_plan.firecrawl_calls_limit,
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -399,15 +405,16 @@ async def get_subscription_status(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and 'exa_calls_limit' in error_str:
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
|
||||
# Try to fix schema and retry once
|
||||
logger.warning("Missing column detected in subscription status query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
schema_utils._checked_subscription_plan_columns = False
|
||||
ensure_subscription_plan_columns(db)
|
||||
db.commit() # Ensure schema changes are committed
|
||||
db.expire_all()
|
||||
# Retry the query
|
||||
# Retry the query - query subscription without eager loading plan
|
||||
subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
@@ -437,11 +444,21 @@ async def get_subscription_status(
|
||||
"metaphor_calls": free_plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": free_plan.firecrawl_calls_limit,
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
}
|
||||
elif subscription:
|
||||
# Query plan separately after schema fix to avoid lazy loading issues
|
||||
plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.id == subscription.plan_id
|
||||
).first()
|
||||
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
|
||||
now = datetime.utcnow()
|
||||
if subscription.current_period_end < now:
|
||||
if getattr(subscription, 'auto_renew', False):
|
||||
@@ -456,8 +473,8 @@ async def get_subscription_status(
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": False,
|
||||
"plan": subscription.plan.tier.value,
|
||||
"tier": subscription.plan.tier.value,
|
||||
"plan": plan.tier.value,
|
||||
"tier": plan.tier.value,
|
||||
"can_use_api": False,
|
||||
"reason": "Subscription expired"
|
||||
}
|
||||
@@ -466,21 +483,23 @@ async def get_subscription_status(
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": True,
|
||||
"plan": subscription.plan.tier.value,
|
||||
"tier": subscription.plan.tier.value,
|
||||
"plan": plan.tier.value,
|
||||
"tier": plan.tier.value,
|
||||
"can_use_api": True,
|
||||
"limits": {
|
||||
"ai_text_generation_calls": getattr(subscription.plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": subscription.plan.gemini_calls_limit,
|
||||
"openai_calls": subscription.plan.openai_calls_limit,
|
||||
"anthropic_calls": subscription.plan.anthropic_calls_limit,
|
||||
"mistral_calls": subscription.plan.mistral_calls_limit,
|
||||
"tavily_calls": subscription.plan.tavily_calls_limit,
|
||||
"serper_calls": subscription.plan.serper_calls_limit,
|
||||
"metaphor_calls": subscription.plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": subscription.plan.firecrawl_calls_limit,
|
||||
"stability_calls": subscription.plan.stability_calls_limit,
|
||||
"monthly_cost": subscription.plan.monthly_cost_limit
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
"anthropic_calls": plan.anthropic_calls_limit,
|
||||
"mistral_calls": plan.mistral_calls_limit,
|
||||
"tavily_calls": plan.tavily_calls_limit,
|
||||
"serper_calls": plan.serper_calls_limit,
|
||||
"metaphor_calls": plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": plan.firecrawl_calls_limit,
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"video_calls": getattr(plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -893,6 +912,7 @@ async def get_dashboard_data(
|
||||
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
# Serve from short TTL cache to avoid hammering DB on bursts
|
||||
import time
|
||||
now = time.time()
|
||||
@@ -966,7 +986,72 @@ async def get_dashboard_data(
|
||||
_dashboard_cache_ts[user_id] = now
|
||||
return response_payload
|
||||
|
||||
except Exception as e:
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str):
|
||||
logger.warning("Missing column detected in dashboard query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
schema_utils._checked_usage_summaries_columns = False
|
||||
schema_utils._checked_subscription_plan_columns = False
|
||||
ensure_usage_summaries_columns(db)
|
||||
ensure_subscription_plan_columns(db)
|
||||
db.expire_all()
|
||||
# Retry the query
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
current_usage = usage_service.get_user_usage_stats(user_id)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
alerts = db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.is_read == False
|
||||
).order_by(UsageAlert.created_at.desc()).limit(5).all()
|
||||
|
||||
alerts_data = [
|
||||
{
|
||||
"id": alert.id,
|
||||
"type": alert.alert_type,
|
||||
"title": alert.title,
|
||||
"message": alert.message,
|
||||
"severity": alert.severity,
|
||||
"created_at": alert.created_at.isoformat()
|
||||
}
|
||||
for alert in alerts
|
||||
]
|
||||
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"current_usage": current_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
"projections": {
|
||||
"projected_monthly_cost": round(projected_cost, 2),
|
||||
"cost_limit": limits.get('limits', {}).get('monthly_cost', 0) if limits else 0,
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
}
|
||||
return response_payload
|
||||
except Exception as retry_err:
|
||||
logger.error(f"Schema fix and retry failed: {retry_err}")
|
||||
raise HTTPException(status_code=500, detail=f"Database schema error: {str(e)}")
|
||||
|
||||
logger.error(f"Error getting dashboard data: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@@ -93,6 +93,17 @@ def setup_clean_logging():
|
||||
format="{time:HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}\n",
|
||||
filter=warning_only_filter
|
||||
)
|
||||
# Add a focused sink to surface Story Video Generation INFO logs in console
|
||||
def video_generation_filter(record):
|
||||
msg = record.get("message", "")
|
||||
name = record.get("name", "")
|
||||
return "[StoryVideoGeneration]" in msg or "services.story_writer.video_generation_service" in name
|
||||
logger.add(
|
||||
sys.stdout.write,
|
||||
level="INFO",
|
||||
format="{time:HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}\n",
|
||||
filter=video_generation_filter
|
||||
)
|
||||
else:
|
||||
# In verbose mode, show all log levels with detailed formatting
|
||||
logger.add(
|
||||
|
||||
@@ -35,6 +35,8 @@ class APIProvider(enum.Enum):
|
||||
FIRECRAWL = "firecrawl"
|
||||
STABILITY = "stability"
|
||||
EXA = "exa"
|
||||
VIDEO = "video"
|
||||
IMAGE_EDIT = "image_edit"
|
||||
|
||||
class BillingCycle(enum.Enum):
|
||||
MONTHLY = "monthly"
|
||||
@@ -68,6 +70,8 @@ class SubscriptionPlan(Base):
|
||||
firecrawl_calls_limit = Column(Integer, default=0)
|
||||
stability_calls_limit = Column(Integer, default=0) # Image generation
|
||||
exa_calls_limit = Column(Integer, default=0) # Exa neural search
|
||||
video_calls_limit = Column(Integer, default=0) # AI video generation
|
||||
image_edit_calls_limit = Column(Integer, default=0) # AI image editing
|
||||
|
||||
# Token Limits (for LLM providers)
|
||||
gemini_tokens_limit = Column(Integer, default=0)
|
||||
@@ -185,6 +189,8 @@ class UsageSummary(Base):
|
||||
firecrawl_calls = Column(Integer, default=0)
|
||||
stability_calls = Column(Integer, default=0)
|
||||
exa_calls = Column(Integer, default=0)
|
||||
video_calls = Column(Integer, default=0) # AI video generation
|
||||
image_edit_calls = Column(Integer, default=0) # AI image editing
|
||||
|
||||
# Token Usage
|
||||
gemini_tokens = Column(Integer, default=0)
|
||||
@@ -203,6 +209,8 @@ class UsageSummary(Base):
|
||||
firecrawl_cost = Column(Float, default=0.0)
|
||||
stability_cost = Column(Float, default=0.0)
|
||||
exa_cost = Column(Float, default=0.0)
|
||||
video_cost = Column(Float, default=0.0) # AI video generation
|
||||
image_edit_cost = Column(Float, default=0.0) # AI image editing
|
||||
|
||||
# Totals
|
||||
total_calls = Column(Integer, default=0)
|
||||
|
||||
@@ -53,7 +53,7 @@ nltk>=3.8.0
|
||||
|
||||
# Image and audio processing for Stability AI
|
||||
Pillow>=10.0.0
|
||||
huggingface_hub>=0.24.0
|
||||
huggingface_hub>=1.1.4
|
||||
scikit-learn>=1.3.0
|
||||
|
||||
# Text-to-Speech (TTS) dependencies
|
||||
@@ -61,7 +61,7 @@ gtts>=2.4.0
|
||||
pyttsx3>=2.90
|
||||
|
||||
# Video composition dependencies
|
||||
moviepy>=1.0.3
|
||||
moviepy==2.1.2
|
||||
imageio>=2.31.0
|
||||
imageio-ffmpeg>=0.4.9
|
||||
|
||||
|
||||
102
backend/scripts/update_image_edit_limits.py
Normal file
102
backend/scripts/update_image_edit_limits.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Script to update existing subscription plans with image_edit_calls_limit values.
|
||||
|
||||
This script updates the SubscriptionPlan table to set image_edit_calls_limit
|
||||
for plans that were created before this column was added.
|
||||
|
||||
Limits:
|
||||
- Free: 10 image editing calls/month
|
||||
- Basic: 30 image editing calls/month
|
||||
- Pro: 100 image editing calls/month
|
||||
- Enterprise: 0 (unlimited)
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import SubscriptionPlan, SubscriptionTier
|
||||
from services.database import DATABASE_URL
|
||||
|
||||
def update_image_edit_limits():
|
||||
"""Update existing subscription plans with image_edit_calls_limit values."""
|
||||
|
||||
try:
|
||||
# Create engine
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# Ensure schema columns exist
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns
|
||||
ensure_subscription_plan_columns(db)
|
||||
|
||||
# Define limits for each tier
|
||||
limits_by_tier = {
|
||||
SubscriptionTier.FREE: 10,
|
||||
SubscriptionTier.BASIC: 30,
|
||||
SubscriptionTier.PRO: 100,
|
||||
SubscriptionTier.ENTERPRISE: 0, # Unlimited
|
||||
}
|
||||
|
||||
updated_count = 0
|
||||
|
||||
# Update each plan
|
||||
for tier, limit in limits_by_tier.items():
|
||||
plans = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == tier,
|
||||
SubscriptionPlan.is_active == True
|
||||
).all()
|
||||
|
||||
for plan in plans:
|
||||
current_limit = getattr(plan, 'image_edit_calls_limit', 0) or 0
|
||||
|
||||
# Only update if limit is 0 (not set) or if it's different
|
||||
if current_limit != limit:
|
||||
setattr(plan, 'image_edit_calls_limit', limit)
|
||||
plan.updated_at = datetime.now(timezone.utc)
|
||||
updated_count += 1
|
||||
logger.info(f"Updated {plan.name} plan ({tier.value}): image_edit_calls_limit = {current_limit} -> {limit}")
|
||||
else:
|
||||
logger.debug(f"Plan {plan.name} ({tier.value}) already has image_edit_calls_limit = {limit}")
|
||||
|
||||
# Commit changes
|
||||
db.commit()
|
||||
|
||||
if updated_count > 0:
|
||||
logger.info(f"✅ Successfully updated {updated_count} subscription plan(s) with image_edit_calls_limit")
|
||||
else:
|
||||
logger.info("✅ All subscription plans already have correct image_edit_calls_limit values")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error updating image_edit_limits: {e}")
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating database connection: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("🔄 Updating subscription plans with image_edit_calls_limit...")
|
||||
success = update_image_edit_limits()
|
||||
if success:
|
||||
logger.info("🎉 Image edit limits update completed successfully!")
|
||||
else:
|
||||
logger.error("❌ Image edit limits update failed")
|
||||
sys.exit(1)
|
||||
|
||||
165
backend/services/llm_providers/main_image_editing.py
Normal file
165
backend/services/llm_providers/main_image_editing.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import io
|
||||
from typing import Optional, Dict, Any
|
||||
from PIL import Image
|
||||
|
||||
from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
)
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
try:
|
||||
from huggingface_hub import InferenceClient
|
||||
HF_HUB_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_HUB_AVAILABLE = False
|
||||
|
||||
|
||||
logger = get_service_logger("image_editing.facade")
|
||||
|
||||
|
||||
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
|
||||
"HF_IMAGE_EDIT_MODEL",
|
||||
"Qwen/Qwen-Image-Edit",
|
||||
)
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
"""Select provider for image editing. Defaults to huggingface with fal-ai."""
|
||||
if explicit:
|
||||
return explicit
|
||||
# Default to huggingface for image editing (best support for image-to-image)
|
||||
return "huggingface"
|
||||
|
||||
|
||||
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get InferenceClient for the specified provider."""
|
||||
if not HF_HUB_AVAILABLE:
|
||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||
|
||||
if provider_name == "huggingface":
|
||||
api_key = api_key or os.getenv("HF_TOKEN")
|
||||
if not api_key:
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
|
||||
# Use fal-ai provider for fast inference
|
||||
return InferenceClient(provider="fal-ai", api_key=api_key)
|
||||
|
||||
raise ValueError(f"Unknown image editing provider: {provider_name}")
|
||||
|
||||
|
||||
def edit_image(
|
||||
input_image_bytes: bytes,
|
||||
prompt: str,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""Edit image with pre-flight validation.
|
||||
|
||||
Args:
|
||||
input_image_bytes: Input image as bytes (PNG/JPEG)
|
||||
prompt: Natural language prompt describing desired edits (e.g., "Turn the cat into a tiger")
|
||||
options: Image editing options (provider, model, etc.)
|
||||
user_id: User ID for subscription checking (optional, but required for validation)
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image bytes and metadata
|
||||
|
||||
Best Practices for Prompts:
|
||||
- Use clear, specific language describing desired changes
|
||||
- Describe what should change and what should remain
|
||||
- Examples: "Turn the cat into a tiger", "Change background to forest",
|
||||
"Make it look like a watercolor painting"
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image editing before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_editing_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_editing_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Editing] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Image Editing] ✅ Pre-flight validation passed - proceeding with image editing")
|
||||
|
||||
# Validate input
|
||||
if not input_image_bytes:
|
||||
raise ValueError("input_image_bytes is required")
|
||||
if not prompt or not prompt.strip():
|
||||
raise ValueError("prompt is required for image editing")
|
||||
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
model = opts.get("model") or DEFAULT_IMAGE_EDIT_MODEL
|
||||
|
||||
logger.info(f"[Image Editing] Editing image via provider={provider_name} model={model}")
|
||||
|
||||
# Get provider client
|
||||
client = _get_provider_client(provider_name, opts.get("api_key"))
|
||||
|
||||
# Prepare parameters for image-to-image
|
||||
params: Dict[str, Any] = {}
|
||||
if opts.get("guidance_scale") is not None:
|
||||
params["guidance_scale"] = opts.get("guidance_scale")
|
||||
if opts.get("steps") is not None:
|
||||
params["num_inference_steps"] = opts.get("steps")
|
||||
if opts.get("seed") is not None:
|
||||
params["seed"] = opts.get("seed")
|
||||
|
||||
try:
|
||||
# Convert input image bytes to PIL Image for validation
|
||||
input_image = Image.open(io.BytesIO(input_image_bytes))
|
||||
width = input_image.width
|
||||
height = input_image.height
|
||||
|
||||
# Use image_to_image method from Hugging Face InferenceClient
|
||||
# This follows the pattern from the Hugging Face documentation
|
||||
# Docs: https://huggingface.co/docs/inference-providers/en/guides/image-editor
|
||||
edited_image: Image.Image = client.image_to_image(
|
||||
image=input_image,
|
||||
prompt=prompt.strip(),
|
||||
model=model,
|
||||
**params,
|
||||
)
|
||||
|
||||
# Convert edited image back to bytes
|
||||
with io.BytesIO() as buf:
|
||||
edited_image.save(buf, format="PNG")
|
||||
edited_image_bytes = buf.getvalue()
|
||||
|
||||
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=edited_image_bytes,
|
||||
width=edited_image.width,
|
||||
height=edited_image.height,
|
||||
provider="huggingface",
|
||||
model=model,
|
||||
seed=opts.get("seed"),
|
||||
metadata={
|
||||
"provider": "fal-ai",
|
||||
"operation": "image_editing",
|
||||
"original_width": width,
|
||||
"original_height": height,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Error editing image: {e}", exc_info=True)
|
||||
raise RuntimeError(f"Image editing failed: {str(e)}")
|
||||
|
||||
@@ -507,6 +507,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
|
||||
import sys
|
||||
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
|
||||
@@ -562,6 +570,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
@@ -802,6 +811,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
@@ -819,6 +836,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
|
||||
355
backend/services/llm_providers/main_video_generation.py
Normal file
355
backend/services/llm_providers/main_video_generation.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""
|
||||
Main Video Generation Service
|
||||
|
||||
Provides a unified interface for AI video generation providers.
|
||||
Initial support: Hugging Face Inference Providers (text-to-video).
|
||||
Stubs included for Gemini (Veo 3) and OpenAI (Sora) for future use.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import base64
|
||||
import io
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
try:
|
||||
from huggingface_hub import InferenceClient
|
||||
HF_HUB_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_HUB_AVAILABLE = False
|
||||
InferenceClient = None
|
||||
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("video_generation_service")
|
||||
|
||||
|
||||
class VideoProviderNotImplemented(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> Optional[str]:
|
||||
try:
|
||||
manager = APIKeyManager()
|
||||
mapping = {
|
||||
"huggingface": "hf_token",
|
||||
"gemini": "gemini", # placeholder for Veo 3
|
||||
"openai": "openai_api_key", # placeholder for Sora
|
||||
}
|
||||
return manager.get_api_key(mapping.get(provider, provider))
|
||||
except Exception as e:
|
||||
logger.error(f"[video_gen] Failed to read API key for {provider}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _coerce_video_bytes(output: Any) -> bytes:
|
||||
"""
|
||||
Normalizes the different return shapes that huggingface_hub may emit for video tasks.
|
||||
Depending on the provider/library version we may get:
|
||||
- raw bytes
|
||||
- an object with `.video` or `.bytes` attributes (plus optional `.save`)
|
||||
- a dict containing a `video` key with bytes/base64 data
|
||||
"""
|
||||
data: Union[bytes, bytearray, memoryview, io.BufferedIOBase, None] = None
|
||||
|
||||
if isinstance(output, (bytes, bytearray, memoryview)):
|
||||
return bytes(output)
|
||||
|
||||
# Objects with direct attribute access
|
||||
if hasattr(output, "video"):
|
||||
data = getattr(output, "video")
|
||||
elif hasattr(output, "bytes"):
|
||||
data = getattr(output, "bytes")
|
||||
elif isinstance(output, dict) and "video" in output:
|
||||
data = output["video"]
|
||||
else:
|
||||
data = output
|
||||
|
||||
# Handle file-like responses
|
||||
if hasattr(data, "read"):
|
||||
data = data.read()
|
||||
|
||||
if isinstance(data, (bytes, bytearray, memoryview)):
|
||||
return bytes(data)
|
||||
|
||||
if isinstance(data, str):
|
||||
# Expecting data URI or raw base64 string
|
||||
if data.startswith("data:"):
|
||||
_, encoded = data.split(",", 1)
|
||||
return base64.b64decode(encoded)
|
||||
try:
|
||||
return base64.b64decode(data)
|
||||
except Exception as exc:
|
||||
raise TypeError(f"Unable to decode string video payload: {exc}") from exc
|
||||
|
||||
raise TypeError(f"Unsupported video payload type: {type(data)}")
|
||||
|
||||
|
||||
def _generate_with_huggingface(
|
||||
prompt: str,
|
||||
num_frames: int = 24 * 4,
|
||||
guidance_scale: float = 7.5,
|
||||
num_inference_steps: int = 30,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
model: str = "tencent/HunyuanVideo",
|
||||
input_image_bytes: Optional[bytes] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generates video bytes using Hugging Face's InferenceClient.
|
||||
"""
|
||||
if not HF_HUB_AVAILABLE:
|
||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||
|
||||
token = _get_api_key("huggingface")
|
||||
if not token:
|
||||
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
|
||||
|
||||
client = InferenceClient(
|
||||
model=model,
|
||||
provider="fal-ai",
|
||||
token=token,
|
||||
)
|
||||
logger.info("[video_gen] Using HuggingFace provider 'fal-ai'")
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"num_frames": num_frames,
|
||||
"guidance_scale": guidance_scale,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
if negative_prompt:
|
||||
params["negative_prompt"] = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
|
||||
if seed is not None:
|
||||
params["seed"] = seed
|
||||
|
||||
logger.info(
|
||||
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=%s",
|
||||
model,
|
||||
num_frames,
|
||||
num_inference_steps,
|
||||
"image-to-video" if input_image_bytes else "text-to-video",
|
||||
)
|
||||
|
||||
try:
|
||||
call_kwargs = {**params, "model": model}
|
||||
if input_image_bytes:
|
||||
video_output = client.image_to_video(
|
||||
image=input_image_bytes,
|
||||
prompt=prompt,
|
||||
**call_kwargs,
|
||||
)
|
||||
else:
|
||||
video_output = client.text_to_video(
|
||||
prompt,
|
||||
**call_kwargs,
|
||||
)
|
||||
|
||||
video_bytes = _coerce_video_bytes(video_output)
|
||||
|
||||
if not isinstance(video_bytes, bytes):
|
||||
raise TypeError(f"Expected bytes from text_to_video, got {type(video_bytes)}")
|
||||
|
||||
if len(video_bytes) == 0:
|
||||
raise ValueError("Received empty video bytes from Hugging Face API")
|
||||
|
||||
logger.info(f"[video_gen] Successfully generated video: {len(video_bytes)} bytes")
|
||||
return video_bytes
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_type = type(e).__name__
|
||||
logger.error(f"[video_gen] HF error ({error_type}): {error_msg}", exc_info=True)
|
||||
raise HTTPException(status_code=502, detail={
|
||||
"error": f"Hugging Face video generation failed: {error_msg}",
|
||||
"error_type": error_type
|
||||
})
|
||||
|
||||
|
||||
def _generate_with_gemini(prompt: str, **kwargs) -> bytes:
|
||||
raise VideoProviderNotImplemented("Gemini Veo 3 integration coming soon.")
|
||||
|
||||
def _generate_with_openai(prompt: str, **kwargs) -> bytes:
|
||||
raise VideoProviderNotImplemented("OpenAI Sora integration coming soon.")
|
||||
|
||||
|
||||
def ai_video_generate(
|
||||
prompt: str,
|
||||
provider: str = "huggingface",
|
||||
user_id: Optional[str] = None,
|
||||
input_image_bytes: Optional[bytes] = None,
|
||||
**kwargs,
|
||||
) -> bytes:
|
||||
"""
|
||||
Unified video generation entry point.
|
||||
|
||||
- provider: 'huggingface' (default), 'gemini' (veo3 stub), 'openai' (sora stub)
|
||||
- kwargs: num_frames, guidance_scale, num_inference_steps, negative_prompt, seed, model
|
||||
- input_image_bytes: optional bytes for image-to-video flows (uses image as motion anchor)
|
||||
|
||||
Returns raw video bytes (mp4/webm depending on provider).
|
||||
"""
|
||||
logger.info(f"[video_gen] provider={provider}")
|
||||
|
||||
# Enforce authentication usage like text gen does
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription/usage tracking.")
|
||||
|
||||
# PRE-FLIGHT VALIDATION: Validate video generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_video_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_video_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Video Generation] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with video generation")
|
||||
|
||||
# Generate video
|
||||
model_name = kwargs.get("model", "tencent/HunyuanVideo")
|
||||
try:
|
||||
if provider == "huggingface":
|
||||
video_bytes = _generate_with_huggingface(
|
||||
prompt=prompt,
|
||||
input_image_bytes=input_image_bytes,
|
||||
**kwargs,
|
||||
)
|
||||
elif provider == "gemini":
|
||||
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
|
||||
elif provider == "openai":
|
||||
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown video provider: {provider}")
|
||||
|
||||
# Track usage AFTER successful generation
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
from models.subscription_models import APIProvider, UsageSummary, APIUsageLog
|
||||
from datetime import datetime
|
||||
from services.subscription import PricingService
|
||||
|
||||
# Create pricing service for tracking (uses same DB session)
|
||||
pricing_service_track = PricingService(db_track)
|
||||
|
||||
# Get current billing period
|
||||
current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
usage_summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage_summary:
|
||||
usage_summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(usage_summary)
|
||||
db_track.commit()
|
||||
|
||||
# Calculate cost using pricing service
|
||||
cost_info = pricing_service_track.get_pricing_for_provider_model(
|
||||
APIProvider.VIDEO,
|
||||
model_name
|
||||
)
|
||||
cost_per_video = cost_info.get('cost_per_request', 0.10) if cost_info else 0.10
|
||||
|
||||
# Get "before" state for unified log
|
||||
current_video_calls_before = getattr(usage_summary, 'video_calls', 0) or 0
|
||||
current_video_cost = getattr(usage_summary, 'video_cost', 0.0) or 0.0
|
||||
|
||||
# Increment video_calls and track cost
|
||||
new_video_calls = current_video_calls_before + 1
|
||||
usage_summary.video_calls = new_video_calls
|
||||
usage_summary.video_cost = current_video_cost + cost_per_video
|
||||
usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
|
||||
usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
|
||||
|
||||
# Get plan details for unified log (before commit, in case commit fails)
|
||||
limits = pricing_service_track.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get image and image editing stats for unified log
|
||||
current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Create usage log entry for audit trail
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.VIDEO,
|
||||
endpoint=f"/video-generation/{provider}",
|
||||
method="POST",
|
||||
model_used=model_name,
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost_per_video,
|
||||
response_time=0.0, # Could track actual time if needed
|
||||
status_code=200,
|
||||
request_size=len(prompt.encode('utf-8')),
|
||||
response_size=len(video_bytes),
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[video_gen] ✅ Successfully tracked usage: user {user_id} -> 1 video call, ${cost_per_video:.4f} cost")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
# Flush immediately to ensure it's visible in console/logs
|
||||
import sys
|
||||
log_message = f"""
|
||||
[SUBSCRIPTION] Video Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: video
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model_name or 'default'}
|
||||
├─ Calls: {current_video_calls_before} → {new_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
"""
|
||||
print(log_message, flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
# Don't fail video generation if tracking fails - video is already generated
|
||||
finally:
|
||||
db_track.close()
|
||||
|
||||
return video_bytes
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., from validation or API errors)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
|
||||
|
||||
352
backend/services/story_writer/prompt_enhancer_service.py
Normal file
352
backend/services/story_writer/prompt_enhancer_service.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
Prompt Enhancement Service for HunyuanVideo Generation
|
||||
|
||||
Uses AI to deeply understand story context and generate optimized
|
||||
HunyuanVideo prompts following best practices with 7 components.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
class PromptEnhancerService:
|
||||
"""Service for generating HunyuanVideo-optimized prompts from story context."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the prompt enhancer service."""
|
||||
logger.info("[PromptEnhancer] Service initialized")
|
||||
|
||||
def enhance_scene_prompt(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
all_scenes: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate a HunyuanVideo-optimized prompt for a scene using two-stage AI analysis.
|
||||
|
||||
Args:
|
||||
current_scene: Scene data for the scene being processed
|
||||
story_context: Complete story context (setup, premise, outline, story text)
|
||||
all_scenes: List of all scenes for consistency analysis
|
||||
user_id: Clerk user ID for subscription checking
|
||||
|
||||
Returns:
|
||||
str: Optimized HunyuanVideo prompt (300-500 words) with 7 components
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[PromptEnhancer] Enhancing prompt for scene {current_scene.get('scene_number', 'unknown')}")
|
||||
|
||||
# Stage 1: Deep story context analysis
|
||||
story_insights = self._analyze_story_context(
|
||||
current_scene=current_scene,
|
||||
story_context=story_context,
|
||||
all_scenes=all_scenes,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Stage 2: Generate optimized HunyuanVideo prompt
|
||||
optimized_prompt = self._generate_hunyuan_prompt(
|
||||
current_scene=current_scene,
|
||||
story_context=story_context,
|
||||
story_insights=story_insights,
|
||||
all_scenes=all_scenes,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
logger.info(f"[PromptEnhancer] Generated prompt length: {len(optimized_prompt)} characters")
|
||||
return optimized_prompt
|
||||
|
||||
except HTTPException as http_err:
|
||||
# Propagate subscription limit errors (429) to frontend for modal display
|
||||
# Only fallback for other HTTP errors (5xx, etc.)
|
||||
if http_err.status_code == 429:
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.warning(f"[PromptEnhancer] Subscription limit exceeded (HTTP 429): {error_msg}")
|
||||
# Re-raise to propagate to frontend for subscription modal
|
||||
raise
|
||||
else:
|
||||
# For other HTTP errors, log and fallback
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.error(f"[PromptEnhancer] Error enhancing prompt (HTTP {http_err.status_code}): {error_msg}", exc_info=True)
|
||||
return self._generate_fallback_prompt(current_scene, story_context)
|
||||
except Exception as e:
|
||||
logger.error(f"[PromptEnhancer] Error enhancing prompt: {str(e)}", exc_info=True)
|
||||
# Fallback to basic prompt if enhancement fails
|
||||
return self._generate_fallback_prompt(current_scene, story_context)
|
||||
|
||||
def _analyze_story_context(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
all_scenes: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Stage 1: Use AI to analyze complete story context and extract insights.
|
||||
|
||||
Returns:
|
||||
str: Story insights as JSON string for use in prompt generation
|
||||
"""
|
||||
# Build comprehensive context for analysis
|
||||
analysis_prompt = f"""You are analyzing a complete story to extract key insights for AI video generation.
|
||||
|
||||
**STORY SETUP:**
|
||||
- Persona: {story_context.get('persona', 'N/A')}
|
||||
- Setting: {story_context.get('story_setting', 'N/A')}
|
||||
- Characters: {story_context.get('characters', 'N/A')}
|
||||
- Plot Elements: {story_context.get('plot_elements', 'N/A')}
|
||||
- Writing Style: {story_context.get('writing_style', 'N/A')}
|
||||
- Tone: {story_context.get('story_tone', 'N/A')}
|
||||
- Narrative POV: {story_context.get('narrative_pov', 'N/A')}
|
||||
- Audience: {story_context.get('audience_age_group', 'N/A')}
|
||||
- Content Rating: {story_context.get('content_rating', 'N/A')}
|
||||
|
||||
**STORY PREMISE:**
|
||||
{story_context.get('premise', 'N/A')}
|
||||
|
||||
**STORY CONTENT:**
|
||||
{story_context.get('story_content', 'N/A')[:2000]}...
|
||||
|
||||
**ALL SCENES OVERVIEW:**
|
||||
"""
|
||||
# Add summary of all scenes
|
||||
for idx, scene in enumerate(all_scenes, 1):
|
||||
scene_num = scene.get('scene_number', idx)
|
||||
analysis_prompt += f"\nScene {scene_num}: {scene.get('title', 'Untitled')}"
|
||||
analysis_prompt += f"\n Description: {scene.get('description', '')[:150]}..."
|
||||
analysis_prompt += f"\n Image Prompt: {scene.get('image_prompt', '')[:150]}..."
|
||||
if scene.get('character_descriptions'):
|
||||
chars = ', '.join(scene.get('character_descriptions', [])[:3])
|
||||
analysis_prompt += f"\n Characters: {chars}"
|
||||
analysis_prompt += "\n"
|
||||
|
||||
analysis_prompt += f"""
|
||||
**CURRENT SCENE FOR VIDEO GENERATION:**
|
||||
Scene {current_scene.get('scene_number', 'N/A')}: {current_scene.get('title', 'Untitled')}
|
||||
Description: {current_scene.get('description', '')}
|
||||
Image Prompt: {current_scene.get('image_prompt', '')}
|
||||
Key Events: {', '.join(current_scene.get('key_events', [])[:5])}
|
||||
Character Descriptions: {', '.join(current_scene.get('character_descriptions', [])[:5])}
|
||||
|
||||
**YOUR TASK:**
|
||||
Analyze this story and extract key insights for video generation. Focus on:
|
||||
1. Narrative arc and position of current scene within it
|
||||
2. Character consistency (how characters appear across scenes)
|
||||
3. Visual style patterns from image prompts
|
||||
4. Tone and atmosphere progression
|
||||
5. Key themes and motifs
|
||||
6. Visual narrative flow
|
||||
7. Camera and composition needs for this specific scene
|
||||
|
||||
Provide your analysis as structured insights that can guide prompt generation.
|
||||
"""
|
||||
|
||||
try:
|
||||
insights = llm_text_gen(
|
||||
prompt=analysis_prompt,
|
||||
system_prompt="You are an expert story analyst specializing in visual narrative and cinematic storytelling. Provide detailed, actionable insights for video generation.",
|
||||
user_id=user_id
|
||||
)
|
||||
logger.debug(f"[PromptEnhancer] Story insights extracted: {insights[:200]}...")
|
||||
return insights
|
||||
except HTTPException as http_err:
|
||||
# Propagate subscription limit errors (429) to frontend
|
||||
if http_err.status_code == 429:
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.warning(f"[PromptEnhancer] Subscription limit exceeded during story analysis (HTTP 429): {error_msg}")
|
||||
# Re-raise to propagate to frontend for subscription modal
|
||||
raise
|
||||
else:
|
||||
# For other HTTP errors, log and fallback
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.warning(f"[PromptEnhancer] Story analysis failed (HTTP {http_err.status_code}): {error_msg}, using basic context")
|
||||
return "Standard narrative flow with consistent character presentation"
|
||||
except Exception as e:
|
||||
logger.warning(f"[PromptEnhancer] Story analysis failed, using basic context: {str(e)}")
|
||||
return "Standard narrative flow with consistent character presentation"
|
||||
|
||||
def _generate_hunyuan_prompt(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
story_insights: str,
|
||||
all_scenes: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Stage 2: Generate scene-specific HunyuanVideo prompt with all 7 components.
|
||||
|
||||
Returns:
|
||||
str: Complete HunyuanVideo prompt (300-500 words)
|
||||
"""
|
||||
# Collect character descriptions across all scenes for consistency
|
||||
all_characters = {}
|
||||
for scene in all_scenes:
|
||||
for char_desc in scene.get('character_descriptions', []):
|
||||
if char_desc and char_desc not in all_characters:
|
||||
all_characters[char_desc] = scene.get('scene_number', 0)
|
||||
|
||||
# Collect image prompts for visual style reference
|
||||
image_prompts = [scene.get('image_prompt', '') for scene in all_scenes if scene.get('image_prompt')]
|
||||
|
||||
# Determine scene position in narrative arc
|
||||
current_scene_num = current_scene.get('scene_number', 0)
|
||||
total_scenes = len(all_scenes)
|
||||
scene_position = "beginning" if current_scene_num <= total_scenes // 3 else ("middle" if current_scene_num <= 2 * total_scenes // 3 else "climax")
|
||||
|
||||
prompt_generation_request = f"""Generate a professional HunyuanVideo prompt for this story scene.
|
||||
|
||||
**STORY INSIGHTS (from deep analysis):**
|
||||
{story_insights}
|
||||
|
||||
**STORY SETUP:**
|
||||
- Setting: {story_context.get('story_setting', 'N/A')}
|
||||
- Tone: {story_context.get('story_tone', 'N/A')}
|
||||
- Style: {story_context.get('writing_style', 'N/A')}
|
||||
- Audience: {story_context.get('audience_age_group', 'N/A')}
|
||||
|
||||
**VISUAL STYLE REFERENCE (from generated images):**
|
||||
{chr(10).join([f"- {prompt[:100]}..." for prompt in image_prompts[:3]])}
|
||||
|
||||
**CHARACTER CONSISTENCY (across all scenes):**
|
||||
{chr(10).join([f"- {char}" for char in list(all_characters.keys())[:5]])}
|
||||
|
||||
**CURRENT SCENE DETAILS:**
|
||||
- Scene {current_scene.get('scene_number', 'N/A')} of {total_scenes} (narrative position: {scene_position})
|
||||
- Title: {current_scene.get('title', 'Untitled')}
|
||||
- Description: {current_scene.get('description', '')}
|
||||
- Image Prompt: {current_scene.get('image_prompt', '')}
|
||||
- Key Events: {', '.join(current_scene.get('key_events', [])[:5])}
|
||||
- Characters in scene: {', '.join(current_scene.get('character_descriptions', [])[:5])}
|
||||
- Audio Narration: {current_scene.get('audio_narration', '')[:200]}
|
||||
|
||||
**REQUIREMENTS:**
|
||||
Create a comprehensive HunyuanVideo prompt (300-500 words) following the 7-component structure:
|
||||
|
||||
1. **SUBJECT**: Clearly define the main focus - characters, objects, or action. Include character descriptions that match the visual style from image prompts and maintain consistency across scenes.
|
||||
|
||||
2. **SCENE**: Describe the environment and setting. Ensure it matches the story_setting and aligns with the visual style established in previous scenes.
|
||||
|
||||
3. **MOTION**: Detail the specific actions and movements. Reference key_events and ensure motion fits the narrative flow and story_insights about the scene's position in the arc.
|
||||
|
||||
4. **CAMERA MOVEMENT**: Specify cinematic camera work appropriate for this moment in the story. Consider the narrative position ({scene_position}) - use establishing shots for beginning, dynamic shots for climax.
|
||||
|
||||
5. **ATMOSPHERE**: Set the emotional tone. This should reflect the story_tone but also consider where we are in the narrative arc based on story_insights.
|
||||
|
||||
6. **LIGHTING**: Define lighting that matches the visual style from image prompts and supports the atmosphere. Ensure consistency with the established visual aesthetic.
|
||||
|
||||
7. **SHOT COMPOSITION**: Describe framing and composition that serves the visual narrative. Consider the story's visual style and ensure it flows naturally with the overall story.
|
||||
|
||||
Write the prompt as a flowing, detailed description (not a list) that integrates all 7 components naturally. Make it vivid, cinematic, and consistent with the story's established visual and narrative style. The prompt should be between 300-500 words.
|
||||
"""
|
||||
|
||||
try:
|
||||
optimized_prompt = llm_text_gen(
|
||||
prompt=prompt_generation_request,
|
||||
system_prompt="You are an expert video prompt engineer specializing in HunyuanVideo text-to-video generation. Create detailed, cinematic prompts that follow best practices and ensure high-quality video output.",
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Clean up and validate prompt length
|
||||
optimized_prompt = optimized_prompt.strip()
|
||||
word_count = len(optimized_prompt.split())
|
||||
|
||||
if word_count < 200:
|
||||
logger.warning(f"[PromptEnhancer] Generated prompt is too short ({word_count} words), enhancing...")
|
||||
# Add more detail if too short
|
||||
optimized_prompt += self._add_cinematic_details(current_scene, story_context)
|
||||
elif word_count > 600:
|
||||
logger.warning(f"[PromptEnhancer] Generated prompt is too long ({word_count} words), trimming...")
|
||||
# Trim if too long (keep first ~500 words)
|
||||
words = optimized_prompt.split()
|
||||
optimized_prompt = ' '.join(words[:500])
|
||||
|
||||
logger.info(f"[PromptEnhancer] Generated prompt: {len(optimized_prompt.split())} words")
|
||||
return optimized_prompt
|
||||
|
||||
except HTTPException as http_err:
|
||||
# Propagate subscription limit errors (429) to frontend
|
||||
if http_err.status_code == 429:
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.warning(f"[PromptEnhancer] Subscription limit exceeded during prompt generation (HTTP 429): {error_msg}")
|
||||
# Re-raise to propagate to frontend for subscription modal
|
||||
raise
|
||||
else:
|
||||
# For other HTTP errors, log and fallback
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.error(f"[PromptEnhancer] Prompt generation failed (HTTP {http_err.status_code}): {error_msg}", exc_info=True)
|
||||
return self._generate_fallback_prompt(current_scene, story_context)
|
||||
except Exception as e:
|
||||
logger.error(f"[PromptEnhancer] Prompt generation failed: {str(e)}", exc_info=True)
|
||||
return self._generate_fallback_prompt(current_scene, story_context)
|
||||
|
||||
def _add_cinematic_details(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Add cinematic details to enhance a too-short prompt."""
|
||||
return f"""
|
||||
|
||||
The scene unfolds with careful attention to visual storytelling. The {story_context.get('story_setting', 'environment')} serves as more than background - it actively participates in the narrative. Lighting and composition work together to emphasize the emotional weight of this moment, with camera movements that guide the viewer's attention naturally through the space. Every element - from the way light falls to the positioning of characters - contributes to the overall narrative impact.
|
||||
"""
|
||||
|
||||
def _extract_error_message(self, http_err: HTTPException) -> str:
|
||||
"""
|
||||
Extract meaningful error message from HTTPException.
|
||||
|
||||
Handles both dict-based details (from subscription limit errors) and string details.
|
||||
"""
|
||||
if isinstance(http_err.detail, dict):
|
||||
# For subscription limit errors, extract the 'message' or 'error' field
|
||||
return http_err.detail.get('message') or http_err.detail.get('error') or str(http_err.detail)
|
||||
elif isinstance(http_err.detail, str):
|
||||
return http_err.detail
|
||||
else:
|
||||
return str(http_err.detail)
|
||||
|
||||
def _generate_fallback_prompt(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Generate a basic fallback prompt if AI enhancement fails."""
|
||||
scene_title = current_scene.get('title', 'Untitled Scene')
|
||||
scene_desc = current_scene.get('description', '')
|
||||
image_prompt = current_scene.get('image_prompt', '')
|
||||
setting = story_context.get('story_setting', 'the scene')
|
||||
tone = story_context.get('story_tone', 'engaging')
|
||||
|
||||
return f"""A cinematic scene titled "{scene_title}" set in {setting}. {scene_desc[:200]}.
|
||||
The scene features {', '.join(current_scene.get('character_descriptions', [])[:2]) if current_scene.get('character_descriptions') else 'the main characters'}.
|
||||
Visual style follows: {image_prompt[:150]}.
|
||||
The {tone} atmosphere is enhanced by natural lighting and dynamic camera movements that follow the action.
|
||||
Shot composition emphasizes the narrative importance of this moment, with careful framing that draws attention to key elements.
|
||||
The scene maintains visual consistency with previous moments while advancing the story's visual narrative."""
|
||||
|
||||
|
||||
def enhance_scene_prompt_for_video(
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
all_scenes: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Convenience function to enhance a scene prompt for HunyuanVideo generation.
|
||||
|
||||
Args:
|
||||
current_scene: Scene data for the scene being processed
|
||||
story_context: Complete story context dictionary
|
||||
all_scenes: List of all scenes for consistency
|
||||
user_id: Clerk user ID for subscription checking
|
||||
|
||||
Returns:
|
||||
str: Optimized HunyuanVideo prompt
|
||||
"""
|
||||
service = PromptEnhancerService()
|
||||
return service.enhance_scene_prompt(current_scene, story_context, all_scenes, user_id)
|
||||
|
||||
@@ -41,6 +41,47 @@ class StoryVideoGenerationService:
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
return f"story_{clean_title}_{unique_id}.mp4"
|
||||
|
||||
def save_scene_video(self, video_bytes: bytes, scene_number: int, user_id: str) -> Dict[str, str]:
|
||||
"""
|
||||
Save individual scene video bytes to file.
|
||||
|
||||
Parameters:
|
||||
video_bytes: Raw video file bytes (mp4/webm format)
|
||||
scene_number: Scene number for naming
|
||||
user_id: Clerk user ID for naming
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Video metadata with video_url and video_filename
|
||||
"""
|
||||
try:
|
||||
# Generate filename with scene number and user ID
|
||||
clean_user_id = "".join(c if c.isalnum() or c in ('-', '_') else '_' for c in user_id[:16])
|
||||
timestamp = str(uuid.uuid4())[:8]
|
||||
filename = f"scene_{scene_number}_{clean_user_id}_{timestamp}.mp4"
|
||||
|
||||
video_path = self.output_dir / filename
|
||||
|
||||
# Write video bytes to file
|
||||
with open(video_path, 'wb') as f:
|
||||
f.write(video_bytes)
|
||||
|
||||
file_size = video_path.stat().st_size
|
||||
logger.info(f"[StoryVideoGeneration] Saved scene {scene_number} video: {filename} ({file_size} bytes)")
|
||||
|
||||
# Generate URL path (relative to /api/story/videos/)
|
||||
video_url = f"/api/story/videos/{filename}"
|
||||
|
||||
return {
|
||||
"video_filename": filename,
|
||||
"video_url": video_url,
|
||||
"video_path": str(video_path),
|
||||
"file_size": file_size
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryVideoGeneration] Error saving scene video: {e}", exc_info=True)
|
||||
raise RuntimeError(f"Failed to save scene video: {str(e)}") from e
|
||||
|
||||
def generate_scene_video(
|
||||
self,
|
||||
scene: Dict[str, Any],
|
||||
@@ -125,12 +166,12 @@ class StoryVideoGenerationService:
|
||||
# Use provided duration or audio duration
|
||||
video_duration = duration if duration is not None else audio_duration
|
||||
|
||||
# Create image clip
|
||||
image_clip = ImageClip(str(image_file)).set_duration(video_duration)
|
||||
image_clip = image_clip.set_fps(fps)
|
||||
# Create image clip (MoviePy v2: use with_* API)
|
||||
image_clip = ImageClip(str(image_file)).with_duration(video_duration)
|
||||
image_clip = image_clip.with_fps(fps)
|
||||
|
||||
# Set audio to image clip
|
||||
video_clip = image_clip.set_audio(audio_clip)
|
||||
video_clip = image_clip.with_audio(audio_clip)
|
||||
|
||||
# Generate video filename
|
||||
video_filename = f"scene_{scene_number}_{scene_title.replace(' ', '_').replace('/', '_')[:50]}_{uuid.uuid4().hex[:8]}.mp4"
|
||||
@@ -274,12 +315,12 @@ class StoryVideoGenerationService:
|
||||
audio_clip = AudioFileClip(str(audio_file))
|
||||
audio_duration = audio_clip.duration
|
||||
|
||||
# Create image clip
|
||||
image_clip = ImageClip(str(image_file)).set_duration(audio_duration)
|
||||
image_clip = image_clip.set_fps(fps)
|
||||
# Create image clip (MoviePy v2: use with_* API)
|
||||
image_clip = ImageClip(str(image_file)).with_duration(audio_duration)
|
||||
image_clip = image_clip.with_fps(fps)
|
||||
|
||||
# Set audio to image clip
|
||||
video_clip = image_clip.set_audio(audio_clip)
|
||||
video_clip = image_clip.with_audio(audio_clip)
|
||||
scene_clips.append(video_clip)
|
||||
|
||||
total_duration += audio_duration
|
||||
|
||||
46
backend/services/story_writer/video_preflight.py
Normal file
46
backend/services/story_writer/video_preflight.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def log_video_stack_diagnostics() -> None:
|
||||
try:
|
||||
import sys
|
||||
import platform
|
||||
import importlib
|
||||
|
||||
mv = importlib.import_module("moviepy")
|
||||
im = importlib.import_module("imageio")
|
||||
try:
|
||||
import imageio_ffmpeg as iff
|
||||
ff = iff.get_ffmpeg_exe()
|
||||
except Exception:
|
||||
ff = "unresolved"
|
||||
logger.info(
|
||||
"[VideoStack] py={} plat={} moviepy={} imageio={} ffmpeg={}",
|
||||
sys.executable,
|
||||
platform.platform(),
|
||||
getattr(mv, "__version__", "NA"),
|
||||
getattr(im, "__version__", "NA"),
|
||||
ff,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("[VideoStack] diagnostics failed: {}", e)
|
||||
|
||||
|
||||
def assert_supported_moviepy() -> None:
|
||||
"""Fail fast if MoviePy isn't version 2.x."""
|
||||
try:
|
||||
import pkg_resources as pr
|
||||
mv = pr.get_distribution("moviepy").version
|
||||
if not mv.startswith("2."):
|
||||
raise RuntimeError(
|
||||
f"Unsupported MoviePy version {mv}. Expected 2.x. "
|
||||
"Please install with: pip install moviepy==2.1.2"
|
||||
)
|
||||
except Exception as e:
|
||||
# Log and re-raise so startup fails clearly
|
||||
logger.error("[VideoStack] version check failed: {}", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -694,6 +694,44 @@ class LimitValidator:
|
||||
|
||||
total_images = projected_images
|
||||
|
||||
# Check video generation limits
|
||||
elif provider == APIProvider.VIDEO:
|
||||
video_limit = limits.get('video_calls', 0) or 0
|
||||
total_video_calls = usage.video_calls or 0
|
||||
projected_video_calls = total_video_calls + 1
|
||||
|
||||
if video_limit > 0 and projected_video_calls > video_limit:
|
||||
error_info = {
|
||||
'current_calls': total_video_calls,
|
||||
'limit': video_limit,
|
||||
'provider': 'video',
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"Video generation limit would be exceeded. Would use {projected_video_calls} of {video_limit} videos this billing period.", {
|
||||
'error_type': 'video_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# Check image editing limits
|
||||
elif provider == APIProvider.IMAGE_EDIT:
|
||||
image_edit_limit = limits.get('image_edit_calls', 0) or 0
|
||||
total_image_edit_calls = getattr(usage, 'image_edit_calls', 0) or 0
|
||||
projected_image_edit_calls = total_image_edit_calls + 1
|
||||
|
||||
if image_edit_limit > 0 and projected_image_edit_calls > image_edit_limit:
|
||||
error_info = {
|
||||
'current_calls': total_image_edit_calls,
|
||||
'limit': image_edit_limit,
|
||||
'provider': 'image_edit',
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"Image editing limit would be exceeded. Would use {projected_image_edit_calls} of {image_edit_limit} image edits this billing period.", {
|
||||
'error_type': 'image_edit_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# Check other provider-specific limits
|
||||
else:
|
||||
provider_calls_key = f"{provider_name}_calls"
|
||||
|
||||
@@ -299,3 +299,124 @@ def validate_image_generation_operations(
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_image_editing_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Validate image editing operation before making API calls.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
|
||||
Returns:
|
||||
None - raises HTTPException with 429 status if validation fails
|
||||
"""
|
||||
try:
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.IMAGE_EDIT,
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'image_edit',
|
||||
'operation_type': 'image_editing'
|
||||
}
|
||||
]
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.error(f"[Pre-flight Validator] Image editing blocked for user {user_id}: {message}")
|
||||
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', 'image_edit') if usage_info else 'image_edit'
|
||||
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ Image editing validated for user {user_id}")
|
||||
# Validation passed - no return needed (function raises HTTPException if validation fails)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating image editing: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
'error': f"Failed to validate image editing: {str(e)}",
|
||||
'message': f"Failed to validate image editing: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_video_generation_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Validate video generation operation before making API calls.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
|
||||
Returns:
|
||||
None - raises HTTPException with 429 status if validation fails
|
||||
"""
|
||||
try:
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.VIDEO,
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'video',
|
||||
'operation_type': 'video_generation'
|
||||
}
|
||||
]
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.error(f"[Pre-flight Validator] Video generation blocked for user {user_id}: {message}")
|
||||
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', 'video') if usage_info else 'video'
|
||||
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ Video generation validated for user {user_id}")
|
||||
# Validation passed - no return needed (function raises HTTPException if validation fails)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating video generation: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
'error': f"Failed to validate video generation: {str(e)}",
|
||||
'message': f"Failed to validate video generation: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -295,10 +295,22 @@ class PricingService:
|
||||
"model_name": "exa-search",
|
||||
"cost_per_request": 0.005, # $0.005 per search (1-25 results)
|
||||
"description": "Exa Neural Search API"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "tencent/HunyuanVideo",
|
||||
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
|
||||
"description": "HuggingFace AI Video Generation (HunyuanVideo)"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "default",
|
||||
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
|
||||
"description": "AI Video Generation default pricing"
|
||||
}
|
||||
]
|
||||
|
||||
# Combine all pricing data
|
||||
# Combine all pricing data (include video pricing in search_pricing list)
|
||||
all_pricing = gemini_pricing + openai_pricing + anthropic_pricing + mistral_pricing + search_pricing
|
||||
|
||||
# Insert or update pricing data
|
||||
@@ -344,6 +356,8 @@ class PricingService:
|
||||
"firecrawl_calls_limit": 10,
|
||||
"stability_calls_limit": 5,
|
||||
"exa_calls_limit": 100,
|
||||
"video_calls_limit": 0, # No video generation for free tier
|
||||
"image_edit_calls_limit": 10, # 10 AI image editing calls/month
|
||||
"gemini_tokens_limit": 100000,
|
||||
"monthly_cost_limit": 0.0,
|
||||
"features": ["basic_content_generation", "limited_research"],
|
||||
@@ -365,6 +379,8 @@ class PricingService:
|
||||
"firecrawl_calls_limit": 100,
|
||||
"stability_calls_limit": 5,
|
||||
"exa_calls_limit": 500,
|
||||
"video_calls_limit": 20, # 20 videos/month for basic plan
|
||||
"image_edit_calls_limit": 30, # 30 AI image editing calls/month
|
||||
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
@@ -388,6 +404,8 @@ class PricingService:
|
||||
"firecrawl_calls_limit": 500,
|
||||
"stability_calls_limit": 200,
|
||||
"exa_calls_limit": 2000,
|
||||
"video_calls_limit": 50, # 50 videos/month for pro plan
|
||||
"image_edit_calls_limit": 100, # 100 AI image editing calls/month
|
||||
"gemini_tokens_limit": 5000000,
|
||||
"openai_tokens_limit": 2500000,
|
||||
"anthropic_tokens_limit": 1000000,
|
||||
@@ -411,6 +429,8 @@ class PricingService:
|
||||
"firecrawl_calls_limit": 0,
|
||||
"stability_calls_limit": 0,
|
||||
"exa_calls_limit": 0, # Unlimited
|
||||
"video_calls_limit": 0, # Unlimited for enterprise
|
||||
"image_edit_calls_limit": 0, # Unlimited image editing for enterprise
|
||||
"gemini_tokens_limit": 0,
|
||||
"openai_tokens_limit": 0,
|
||||
"anthropic_tokens_limit": 0,
|
||||
@@ -429,6 +449,20 @@ class PricingService:
|
||||
if not existing:
|
||||
plan = SubscriptionPlan(**plan_data)
|
||||
self.db.add(plan)
|
||||
else:
|
||||
# Update existing plan with new limits (e.g., image_edit_calls_limit)
|
||||
# This ensures existing plans get new columns like image_edit_calls_limit
|
||||
for key, value in plan_data.items():
|
||||
if key not in ["name", "tier"]: # Don't overwrite name/tier
|
||||
try:
|
||||
# Try to set the attribute (works even if column was just added)
|
||||
setattr(existing, key, value)
|
||||
except (AttributeError, Exception) as e:
|
||||
# If attribute doesn't exist yet (column not migrated), skip it
|
||||
# Schema migration will add it, then this will update it on next run
|
||||
logger.debug(f"Could not set {key} on plan {existing.name}: {e}")
|
||||
existing.updated_at = datetime.utcnow()
|
||||
logger.debug(f"Updated existing plan: {existing.name}")
|
||||
|
||||
self.db.commit()
|
||||
logger.debug("Default subscription plans initialized")
|
||||
@@ -615,6 +649,8 @@ class PricingService:
|
||||
'metaphor_calls': plan.metaphor_calls_limit,
|
||||
'firecrawl_calls': plan.firecrawl_calls_limit,
|
||||
'stability_calls': plan.stability_calls_limit,
|
||||
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
|
||||
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column
|
||||
# Token limits
|
||||
'gemini_tokens': plan.gemini_tokens_limit,
|
||||
'openai_tokens': plan.openai_tokens_limit,
|
||||
|
||||
@@ -29,6 +29,8 @@ def ensure_subscription_plan_columns(db: Session) -> None:
|
||||
# Columns we may reference in models but might be missing in older DBs
|
||||
required_columns = {
|
||||
"exa_calls_limit": "INTEGER DEFAULT 0",
|
||||
"video_calls_limit": "INTEGER DEFAULT 0",
|
||||
"image_edit_calls_limit": "INTEGER DEFAULT 0",
|
||||
}
|
||||
|
||||
for col_name, ddl in required_columns.items():
|
||||
@@ -78,6 +80,10 @@ def ensure_usage_summaries_columns(db: Session) -> None:
|
||||
required_columns = {
|
||||
"exa_calls": "INTEGER DEFAULT 0",
|
||||
"exa_cost": "REAL DEFAULT 0.0",
|
||||
"video_calls": "INTEGER DEFAULT 0",
|
||||
"video_cost": "REAL DEFAULT 0.0",
|
||||
"image_edit_calls": "INTEGER DEFAULT 0",
|
||||
"image_edit_cost": "REAL DEFAULT 0.0",
|
||||
}
|
||||
|
||||
for col_name, ddl in required_columns.items():
|
||||
|
||||
@@ -608,6 +608,12 @@ class UsageTrackingService:
|
||||
# Reset image generation counters
|
||||
summary.stability_calls = 0
|
||||
|
||||
# Reset video generation counters
|
||||
summary.video_calls = 0
|
||||
|
||||
# Reset image editing counters
|
||||
summary.image_edit_calls = 0
|
||||
|
||||
# Reset cost counters
|
||||
summary.gemini_cost = 0.0
|
||||
summary.openai_cost = 0.0
|
||||
@@ -618,6 +624,9 @@ class UsageTrackingService:
|
||||
summary.metaphor_cost = 0.0
|
||||
summary.firecrawl_cost = 0.0
|
||||
summary.stability_cost = 0.0
|
||||
summary.exa_cost = 0.0
|
||||
summary.video_cost = 0.0
|
||||
summary.image_edit_cost = 0.0
|
||||
|
||||
# Reset totals
|
||||
summary.total_calls = 0
|
||||
|
||||
@@ -161,9 +161,29 @@ def start_backend(enable_reload=False, production_mode=False):
|
||||
|
||||
# Set up clean logging for end users
|
||||
from logging_config import setup_clean_logging, get_uvicorn_log_level
|
||||
# Video stack preflight (diagnostics + version assert)
|
||||
try:
|
||||
from services.story_writer.video_preflight import (
|
||||
log_video_stack_diagnostics,
|
||||
assert_supported_moviepy,
|
||||
)
|
||||
except Exception:
|
||||
# Preflight is optional; continue if module missing
|
||||
log_video_stack_diagnostics = None
|
||||
assert_supported_moviepy = None
|
||||
|
||||
verbose_mode = setup_clean_logging()
|
||||
uvicorn_log_level = get_uvicorn_log_level()
|
||||
|
||||
# Log diagnostics and assert versions (fail fast if misconfigured)
|
||||
try:
|
||||
if log_video_stack_diagnostics:
|
||||
log_video_stack_diagnostics()
|
||||
if assert_supported_moviepy:
|
||||
assert_supported_moviepy()
|
||||
except Exception as _video_stack_err:
|
||||
print(f"[ERROR] Video stack preflight failed: {_video_stack_err}")
|
||||
return False
|
||||
|
||||
uvicorn.run(
|
||||
"app:app",
|
||||
|
||||
Reference in New Issue
Block a user