AI Video Generation Implementation
This commit is contained in:
165
backend/services/llm_providers/main_image_editing.py
Normal file
165
backend/services/llm_providers/main_image_editing.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import io
|
||||
from typing import Optional, Dict, Any
|
||||
from PIL import Image
|
||||
|
||||
from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
)
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
try:
|
||||
from huggingface_hub import InferenceClient
|
||||
HF_HUB_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_HUB_AVAILABLE = False
|
||||
|
||||
|
||||
logger = get_service_logger("image_editing.facade")
|
||||
|
||||
|
||||
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
|
||||
"HF_IMAGE_EDIT_MODEL",
|
||||
"Qwen/Qwen-Image-Edit",
|
||||
)
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
"""Select provider for image editing. Defaults to huggingface with fal-ai."""
|
||||
if explicit:
|
||||
return explicit
|
||||
# Default to huggingface for image editing (best support for image-to-image)
|
||||
return "huggingface"
|
||||
|
||||
|
||||
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get InferenceClient for the specified provider."""
|
||||
if not HF_HUB_AVAILABLE:
|
||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||
|
||||
if provider_name == "huggingface":
|
||||
api_key = api_key or os.getenv("HF_TOKEN")
|
||||
if not api_key:
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
|
||||
# Use fal-ai provider for fast inference
|
||||
return InferenceClient(provider="fal-ai", api_key=api_key)
|
||||
|
||||
raise ValueError(f"Unknown image editing provider: {provider_name}")
|
||||
|
||||
|
||||
def edit_image(
|
||||
input_image_bytes: bytes,
|
||||
prompt: str,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""Edit image with pre-flight validation.
|
||||
|
||||
Args:
|
||||
input_image_bytes: Input image as bytes (PNG/JPEG)
|
||||
prompt: Natural language prompt describing desired edits (e.g., "Turn the cat into a tiger")
|
||||
options: Image editing options (provider, model, etc.)
|
||||
user_id: User ID for subscription checking (optional, but required for validation)
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image bytes and metadata
|
||||
|
||||
Best Practices for Prompts:
|
||||
- Use clear, specific language describing desired changes
|
||||
- Describe what should change and what should remain
|
||||
- Examples: "Turn the cat into a tiger", "Change background to forest",
|
||||
"Make it look like a watercolor painting"
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image editing before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_editing_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_editing_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Editing] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Image Editing] ✅ Pre-flight validation passed - proceeding with image editing")
|
||||
|
||||
# Validate input
|
||||
if not input_image_bytes:
|
||||
raise ValueError("input_image_bytes is required")
|
||||
if not prompt or not prompt.strip():
|
||||
raise ValueError("prompt is required for image editing")
|
||||
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
model = opts.get("model") or DEFAULT_IMAGE_EDIT_MODEL
|
||||
|
||||
logger.info(f"[Image Editing] Editing image via provider={provider_name} model={model}")
|
||||
|
||||
# Get provider client
|
||||
client = _get_provider_client(provider_name, opts.get("api_key"))
|
||||
|
||||
# Prepare parameters for image-to-image
|
||||
params: Dict[str, Any] = {}
|
||||
if opts.get("guidance_scale") is not None:
|
||||
params["guidance_scale"] = opts.get("guidance_scale")
|
||||
if opts.get("steps") is not None:
|
||||
params["num_inference_steps"] = opts.get("steps")
|
||||
if opts.get("seed") is not None:
|
||||
params["seed"] = opts.get("seed")
|
||||
|
||||
try:
|
||||
# Convert input image bytes to PIL Image for validation
|
||||
input_image = Image.open(io.BytesIO(input_image_bytes))
|
||||
width = input_image.width
|
||||
height = input_image.height
|
||||
|
||||
# Use image_to_image method from Hugging Face InferenceClient
|
||||
# This follows the pattern from the Hugging Face documentation
|
||||
# Docs: https://huggingface.co/docs/inference-providers/en/guides/image-editor
|
||||
edited_image: Image.Image = client.image_to_image(
|
||||
image=input_image,
|
||||
prompt=prompt.strip(),
|
||||
model=model,
|
||||
**params,
|
||||
)
|
||||
|
||||
# Convert edited image back to bytes
|
||||
with io.BytesIO() as buf:
|
||||
edited_image.save(buf, format="PNG")
|
||||
edited_image_bytes = buf.getvalue()
|
||||
|
||||
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=edited_image_bytes,
|
||||
width=edited_image.width,
|
||||
height=edited_image.height,
|
||||
provider="huggingface",
|
||||
model=model,
|
||||
seed=opts.get("seed"),
|
||||
metadata={
|
||||
"provider": "fal-ai",
|
||||
"operation": "image_editing",
|
||||
"original_width": width,
|
||||
"original_height": height,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Error editing image: {e}", exc_info=True)
|
||||
raise RuntimeError(f"Image editing failed: {str(e)}")
|
||||
|
||||
@@ -507,6 +507,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
|
||||
import sys
|
||||
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
|
||||
@@ -562,6 +570,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
@@ -802,6 +811,14 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
@@ -819,6 +836,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
|
||||
355
backend/services/llm_providers/main_video_generation.py
Normal file
355
backend/services/llm_providers/main_video_generation.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""
|
||||
Main Video Generation Service
|
||||
|
||||
Provides a unified interface for AI video generation providers.
|
||||
Initial support: Hugging Face Inference Providers (text-to-video).
|
||||
Stubs included for Gemini (Veo 3) and OpenAI (Sora) for future use.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import base64
|
||||
import io
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
try:
|
||||
from huggingface_hub import InferenceClient
|
||||
HF_HUB_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_HUB_AVAILABLE = False
|
||||
InferenceClient = None
|
||||
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("video_generation_service")
|
||||
|
||||
|
||||
class VideoProviderNotImplemented(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> Optional[str]:
|
||||
try:
|
||||
manager = APIKeyManager()
|
||||
mapping = {
|
||||
"huggingface": "hf_token",
|
||||
"gemini": "gemini", # placeholder for Veo 3
|
||||
"openai": "openai_api_key", # placeholder for Sora
|
||||
}
|
||||
return manager.get_api_key(mapping.get(provider, provider))
|
||||
except Exception as e:
|
||||
logger.error(f"[video_gen] Failed to read API key for {provider}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _coerce_video_bytes(output: Any) -> bytes:
|
||||
"""
|
||||
Normalizes the different return shapes that huggingface_hub may emit for video tasks.
|
||||
Depending on the provider/library version we may get:
|
||||
- raw bytes
|
||||
- an object with `.video` or `.bytes` attributes (plus optional `.save`)
|
||||
- a dict containing a `video` key with bytes/base64 data
|
||||
"""
|
||||
data: Union[bytes, bytearray, memoryview, io.BufferedIOBase, None] = None
|
||||
|
||||
if isinstance(output, (bytes, bytearray, memoryview)):
|
||||
return bytes(output)
|
||||
|
||||
# Objects with direct attribute access
|
||||
if hasattr(output, "video"):
|
||||
data = getattr(output, "video")
|
||||
elif hasattr(output, "bytes"):
|
||||
data = getattr(output, "bytes")
|
||||
elif isinstance(output, dict) and "video" in output:
|
||||
data = output["video"]
|
||||
else:
|
||||
data = output
|
||||
|
||||
# Handle file-like responses
|
||||
if hasattr(data, "read"):
|
||||
data = data.read()
|
||||
|
||||
if isinstance(data, (bytes, bytearray, memoryview)):
|
||||
return bytes(data)
|
||||
|
||||
if isinstance(data, str):
|
||||
# Expecting data URI or raw base64 string
|
||||
if data.startswith("data:"):
|
||||
_, encoded = data.split(",", 1)
|
||||
return base64.b64decode(encoded)
|
||||
try:
|
||||
return base64.b64decode(data)
|
||||
except Exception as exc:
|
||||
raise TypeError(f"Unable to decode string video payload: {exc}") from exc
|
||||
|
||||
raise TypeError(f"Unsupported video payload type: {type(data)}")
|
||||
|
||||
|
||||
def _generate_with_huggingface(
|
||||
prompt: str,
|
||||
num_frames: int = 24 * 4,
|
||||
guidance_scale: float = 7.5,
|
||||
num_inference_steps: int = 30,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
model: str = "tencent/HunyuanVideo",
|
||||
input_image_bytes: Optional[bytes] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generates video bytes using Hugging Face's InferenceClient.
|
||||
"""
|
||||
if not HF_HUB_AVAILABLE:
|
||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||
|
||||
token = _get_api_key("huggingface")
|
||||
if not token:
|
||||
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
|
||||
|
||||
client = InferenceClient(
|
||||
model=model,
|
||||
provider="fal-ai",
|
||||
token=token,
|
||||
)
|
||||
logger.info("[video_gen] Using HuggingFace provider 'fal-ai'")
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"num_frames": num_frames,
|
||||
"guidance_scale": guidance_scale,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
if negative_prompt:
|
||||
params["negative_prompt"] = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
|
||||
if seed is not None:
|
||||
params["seed"] = seed
|
||||
|
||||
logger.info(
|
||||
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=%s",
|
||||
model,
|
||||
num_frames,
|
||||
num_inference_steps,
|
||||
"image-to-video" if input_image_bytes else "text-to-video",
|
||||
)
|
||||
|
||||
try:
|
||||
call_kwargs = {**params, "model": model}
|
||||
if input_image_bytes:
|
||||
video_output = client.image_to_video(
|
||||
image=input_image_bytes,
|
||||
prompt=prompt,
|
||||
**call_kwargs,
|
||||
)
|
||||
else:
|
||||
video_output = client.text_to_video(
|
||||
prompt,
|
||||
**call_kwargs,
|
||||
)
|
||||
|
||||
video_bytes = _coerce_video_bytes(video_output)
|
||||
|
||||
if not isinstance(video_bytes, bytes):
|
||||
raise TypeError(f"Expected bytes from text_to_video, got {type(video_bytes)}")
|
||||
|
||||
if len(video_bytes) == 0:
|
||||
raise ValueError("Received empty video bytes from Hugging Face API")
|
||||
|
||||
logger.info(f"[video_gen] Successfully generated video: {len(video_bytes)} bytes")
|
||||
return video_bytes
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_type = type(e).__name__
|
||||
logger.error(f"[video_gen] HF error ({error_type}): {error_msg}", exc_info=True)
|
||||
raise HTTPException(status_code=502, detail={
|
||||
"error": f"Hugging Face video generation failed: {error_msg}",
|
||||
"error_type": error_type
|
||||
})
|
||||
|
||||
|
||||
def _generate_with_gemini(prompt: str, **kwargs) -> bytes:
|
||||
raise VideoProviderNotImplemented("Gemini Veo 3 integration coming soon.")
|
||||
|
||||
def _generate_with_openai(prompt: str, **kwargs) -> bytes:
|
||||
raise VideoProviderNotImplemented("OpenAI Sora integration coming soon.")
|
||||
|
||||
|
||||
def ai_video_generate(
|
||||
prompt: str,
|
||||
provider: str = "huggingface",
|
||||
user_id: Optional[str] = None,
|
||||
input_image_bytes: Optional[bytes] = None,
|
||||
**kwargs,
|
||||
) -> bytes:
|
||||
"""
|
||||
Unified video generation entry point.
|
||||
|
||||
- provider: 'huggingface' (default), 'gemini' (veo3 stub), 'openai' (sora stub)
|
||||
- kwargs: num_frames, guidance_scale, num_inference_steps, negative_prompt, seed, model
|
||||
- input_image_bytes: optional bytes for image-to-video flows (uses image as motion anchor)
|
||||
|
||||
Returns raw video bytes (mp4/webm depending on provider).
|
||||
"""
|
||||
logger.info(f"[video_gen] provider={provider}")
|
||||
|
||||
# Enforce authentication usage like text gen does
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription/usage tracking.")
|
||||
|
||||
# PRE-FLIGHT VALIDATION: Validate video generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_video_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_video_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Video Generation] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with video generation")
|
||||
|
||||
# Generate video
|
||||
model_name = kwargs.get("model", "tencent/HunyuanVideo")
|
||||
try:
|
||||
if provider == "huggingface":
|
||||
video_bytes = _generate_with_huggingface(
|
||||
prompt=prompt,
|
||||
input_image_bytes=input_image_bytes,
|
||||
**kwargs,
|
||||
)
|
||||
elif provider == "gemini":
|
||||
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
|
||||
elif provider == "openai":
|
||||
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown video provider: {provider}")
|
||||
|
||||
# Track usage AFTER successful generation
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
from models.subscription_models import APIProvider, UsageSummary, APIUsageLog
|
||||
from datetime import datetime
|
||||
from services.subscription import PricingService
|
||||
|
||||
# Create pricing service for tracking (uses same DB session)
|
||||
pricing_service_track = PricingService(db_track)
|
||||
|
||||
# Get current billing period
|
||||
current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
usage_summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage_summary:
|
||||
usage_summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(usage_summary)
|
||||
db_track.commit()
|
||||
|
||||
# Calculate cost using pricing service
|
||||
cost_info = pricing_service_track.get_pricing_for_provider_model(
|
||||
APIProvider.VIDEO,
|
||||
model_name
|
||||
)
|
||||
cost_per_video = cost_info.get('cost_per_request', 0.10) if cost_info else 0.10
|
||||
|
||||
# Get "before" state for unified log
|
||||
current_video_calls_before = getattr(usage_summary, 'video_calls', 0) or 0
|
||||
current_video_cost = getattr(usage_summary, 'video_cost', 0.0) or 0.0
|
||||
|
||||
# Increment video_calls and track cost
|
||||
new_video_calls = current_video_calls_before + 1
|
||||
usage_summary.video_calls = new_video_calls
|
||||
usage_summary.video_cost = current_video_cost + cost_per_video
|
||||
usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
|
||||
usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
|
||||
|
||||
# Get plan details for unified log (before commit, in case commit fails)
|
||||
limits = pricing_service_track.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get image and image editing stats for unified log
|
||||
current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Create usage log entry for audit trail
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.VIDEO,
|
||||
endpoint=f"/video-generation/{provider}",
|
||||
method="POST",
|
||||
model_used=model_name,
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost_per_video,
|
||||
response_time=0.0, # Could track actual time if needed
|
||||
status_code=200,
|
||||
request_size=len(prompt.encode('utf-8')),
|
||||
response_size=len(video_bytes),
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[video_gen] ✅ Successfully tracked usage: user {user_id} -> 1 video call, ${cost_per_video:.4f} cost")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
# Flush immediately to ensure it's visible in console/logs
|
||||
import sys
|
||||
log_message = f"""
|
||||
[SUBSCRIPTION] Video Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: video
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model_name or 'default'}
|
||||
├─ Calls: {current_video_calls_before} → {new_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
"""
|
||||
print(log_message, flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
# Don't fail video generation if tracking fails - video is already generated
|
||||
finally:
|
||||
db_track.close()
|
||||
|
||||
return video_bytes
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., from validation or API errors)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
|
||||
|
||||
352
backend/services/story_writer/prompt_enhancer_service.py
Normal file
352
backend/services/story_writer/prompt_enhancer_service.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
Prompt Enhancement Service for HunyuanVideo Generation
|
||||
|
||||
Uses AI to deeply understand story context and generate optimized
|
||||
HunyuanVideo prompts following best practices with 7 components.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
class PromptEnhancerService:
|
||||
"""Service for generating HunyuanVideo-optimized prompts from story context."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the prompt enhancer service."""
|
||||
logger.info("[PromptEnhancer] Service initialized")
|
||||
|
||||
def enhance_scene_prompt(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
all_scenes: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate a HunyuanVideo-optimized prompt for a scene using two-stage AI analysis.
|
||||
|
||||
Args:
|
||||
current_scene: Scene data for the scene being processed
|
||||
story_context: Complete story context (setup, premise, outline, story text)
|
||||
all_scenes: List of all scenes for consistency analysis
|
||||
user_id: Clerk user ID for subscription checking
|
||||
|
||||
Returns:
|
||||
str: Optimized HunyuanVideo prompt (300-500 words) with 7 components
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[PromptEnhancer] Enhancing prompt for scene {current_scene.get('scene_number', 'unknown')}")
|
||||
|
||||
# Stage 1: Deep story context analysis
|
||||
story_insights = self._analyze_story_context(
|
||||
current_scene=current_scene,
|
||||
story_context=story_context,
|
||||
all_scenes=all_scenes,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Stage 2: Generate optimized HunyuanVideo prompt
|
||||
optimized_prompt = self._generate_hunyuan_prompt(
|
||||
current_scene=current_scene,
|
||||
story_context=story_context,
|
||||
story_insights=story_insights,
|
||||
all_scenes=all_scenes,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
logger.info(f"[PromptEnhancer] Generated prompt length: {len(optimized_prompt)} characters")
|
||||
return optimized_prompt
|
||||
|
||||
except HTTPException as http_err:
|
||||
# Propagate subscription limit errors (429) to frontend for modal display
|
||||
# Only fallback for other HTTP errors (5xx, etc.)
|
||||
if http_err.status_code == 429:
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.warning(f"[PromptEnhancer] Subscription limit exceeded (HTTP 429): {error_msg}")
|
||||
# Re-raise to propagate to frontend for subscription modal
|
||||
raise
|
||||
else:
|
||||
# For other HTTP errors, log and fallback
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.error(f"[PromptEnhancer] Error enhancing prompt (HTTP {http_err.status_code}): {error_msg}", exc_info=True)
|
||||
return self._generate_fallback_prompt(current_scene, story_context)
|
||||
except Exception as e:
|
||||
logger.error(f"[PromptEnhancer] Error enhancing prompt: {str(e)}", exc_info=True)
|
||||
# Fallback to basic prompt if enhancement fails
|
||||
return self._generate_fallback_prompt(current_scene, story_context)
|
||||
|
||||
def _analyze_story_context(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
all_scenes: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Stage 1: Use AI to analyze complete story context and extract insights.
|
||||
|
||||
Returns:
|
||||
str: Story insights as JSON string for use in prompt generation
|
||||
"""
|
||||
# Build comprehensive context for analysis
|
||||
analysis_prompt = f"""You are analyzing a complete story to extract key insights for AI video generation.
|
||||
|
||||
**STORY SETUP:**
|
||||
- Persona: {story_context.get('persona', 'N/A')}
|
||||
- Setting: {story_context.get('story_setting', 'N/A')}
|
||||
- Characters: {story_context.get('characters', 'N/A')}
|
||||
- Plot Elements: {story_context.get('plot_elements', 'N/A')}
|
||||
- Writing Style: {story_context.get('writing_style', 'N/A')}
|
||||
- Tone: {story_context.get('story_tone', 'N/A')}
|
||||
- Narrative POV: {story_context.get('narrative_pov', 'N/A')}
|
||||
- Audience: {story_context.get('audience_age_group', 'N/A')}
|
||||
- Content Rating: {story_context.get('content_rating', 'N/A')}
|
||||
|
||||
**STORY PREMISE:**
|
||||
{story_context.get('premise', 'N/A')}
|
||||
|
||||
**STORY CONTENT:**
|
||||
{story_context.get('story_content', 'N/A')[:2000]}...
|
||||
|
||||
**ALL SCENES OVERVIEW:**
|
||||
"""
|
||||
# Add summary of all scenes
|
||||
for idx, scene in enumerate(all_scenes, 1):
|
||||
scene_num = scene.get('scene_number', idx)
|
||||
analysis_prompt += f"\nScene {scene_num}: {scene.get('title', 'Untitled')}"
|
||||
analysis_prompt += f"\n Description: {scene.get('description', '')[:150]}..."
|
||||
analysis_prompt += f"\n Image Prompt: {scene.get('image_prompt', '')[:150]}..."
|
||||
if scene.get('character_descriptions'):
|
||||
chars = ', '.join(scene.get('character_descriptions', [])[:3])
|
||||
analysis_prompt += f"\n Characters: {chars}"
|
||||
analysis_prompt += "\n"
|
||||
|
||||
analysis_prompt += f"""
|
||||
**CURRENT SCENE FOR VIDEO GENERATION:**
|
||||
Scene {current_scene.get('scene_number', 'N/A')}: {current_scene.get('title', 'Untitled')}
|
||||
Description: {current_scene.get('description', '')}
|
||||
Image Prompt: {current_scene.get('image_prompt', '')}
|
||||
Key Events: {', '.join(current_scene.get('key_events', [])[:5])}
|
||||
Character Descriptions: {', '.join(current_scene.get('character_descriptions', [])[:5])}
|
||||
|
||||
**YOUR TASK:**
|
||||
Analyze this story and extract key insights for video generation. Focus on:
|
||||
1. Narrative arc and position of current scene within it
|
||||
2. Character consistency (how characters appear across scenes)
|
||||
3. Visual style patterns from image prompts
|
||||
4. Tone and atmosphere progression
|
||||
5. Key themes and motifs
|
||||
6. Visual narrative flow
|
||||
7. Camera and composition needs for this specific scene
|
||||
|
||||
Provide your analysis as structured insights that can guide prompt generation.
|
||||
"""
|
||||
|
||||
try:
|
||||
insights = llm_text_gen(
|
||||
prompt=analysis_prompt,
|
||||
system_prompt="You are an expert story analyst specializing in visual narrative and cinematic storytelling. Provide detailed, actionable insights for video generation.",
|
||||
user_id=user_id
|
||||
)
|
||||
logger.debug(f"[PromptEnhancer] Story insights extracted: {insights[:200]}...")
|
||||
return insights
|
||||
except HTTPException as http_err:
|
||||
# Propagate subscription limit errors (429) to frontend
|
||||
if http_err.status_code == 429:
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.warning(f"[PromptEnhancer] Subscription limit exceeded during story analysis (HTTP 429): {error_msg}")
|
||||
# Re-raise to propagate to frontend for subscription modal
|
||||
raise
|
||||
else:
|
||||
# For other HTTP errors, log and fallback
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.warning(f"[PromptEnhancer] Story analysis failed (HTTP {http_err.status_code}): {error_msg}, using basic context")
|
||||
return "Standard narrative flow with consistent character presentation"
|
||||
except Exception as e:
|
||||
logger.warning(f"[PromptEnhancer] Story analysis failed, using basic context: {str(e)}")
|
||||
return "Standard narrative flow with consistent character presentation"
|
||||
|
||||
def _generate_hunyuan_prompt(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
story_insights: str,
|
||||
all_scenes: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Stage 2: Generate scene-specific HunyuanVideo prompt with all 7 components.
|
||||
|
||||
Returns:
|
||||
str: Complete HunyuanVideo prompt (300-500 words)
|
||||
"""
|
||||
# Collect character descriptions across all scenes for consistency
|
||||
all_characters = {}
|
||||
for scene in all_scenes:
|
||||
for char_desc in scene.get('character_descriptions', []):
|
||||
if char_desc and char_desc not in all_characters:
|
||||
all_characters[char_desc] = scene.get('scene_number', 0)
|
||||
|
||||
# Collect image prompts for visual style reference
|
||||
image_prompts = [scene.get('image_prompt', '') for scene in all_scenes if scene.get('image_prompt')]
|
||||
|
||||
# Determine scene position in narrative arc
|
||||
current_scene_num = current_scene.get('scene_number', 0)
|
||||
total_scenes = len(all_scenes)
|
||||
scene_position = "beginning" if current_scene_num <= total_scenes // 3 else ("middle" if current_scene_num <= 2 * total_scenes // 3 else "climax")
|
||||
|
||||
prompt_generation_request = f"""Generate a professional HunyuanVideo prompt for this story scene.
|
||||
|
||||
**STORY INSIGHTS (from deep analysis):**
|
||||
{story_insights}
|
||||
|
||||
**STORY SETUP:**
|
||||
- Setting: {story_context.get('story_setting', 'N/A')}
|
||||
- Tone: {story_context.get('story_tone', 'N/A')}
|
||||
- Style: {story_context.get('writing_style', 'N/A')}
|
||||
- Audience: {story_context.get('audience_age_group', 'N/A')}
|
||||
|
||||
**VISUAL STYLE REFERENCE (from generated images):**
|
||||
{chr(10).join([f"- {prompt[:100]}..." for prompt in image_prompts[:3]])}
|
||||
|
||||
**CHARACTER CONSISTENCY (across all scenes):**
|
||||
{chr(10).join([f"- {char}" for char in list(all_characters.keys())[:5]])}
|
||||
|
||||
**CURRENT SCENE DETAILS:**
|
||||
- Scene {current_scene.get('scene_number', 'N/A')} of {total_scenes} (narrative position: {scene_position})
|
||||
- Title: {current_scene.get('title', 'Untitled')}
|
||||
- Description: {current_scene.get('description', '')}
|
||||
- Image Prompt: {current_scene.get('image_prompt', '')}
|
||||
- Key Events: {', '.join(current_scene.get('key_events', [])[:5])}
|
||||
- Characters in scene: {', '.join(current_scene.get('character_descriptions', [])[:5])}
|
||||
- Audio Narration: {current_scene.get('audio_narration', '')[:200]}
|
||||
|
||||
**REQUIREMENTS:**
|
||||
Create a comprehensive HunyuanVideo prompt (300-500 words) following the 7-component structure:
|
||||
|
||||
1. **SUBJECT**: Clearly define the main focus - characters, objects, or action. Include character descriptions that match the visual style from image prompts and maintain consistency across scenes.
|
||||
|
||||
2. **SCENE**: Describe the environment and setting. Ensure it matches the story_setting and aligns with the visual style established in previous scenes.
|
||||
|
||||
3. **MOTION**: Detail the specific actions and movements. Reference key_events and ensure motion fits the narrative flow and story_insights about the scene's position in the arc.
|
||||
|
||||
4. **CAMERA MOVEMENT**: Specify cinematic camera work appropriate for this moment in the story. Consider the narrative position ({scene_position}) - use establishing shots for beginning, dynamic shots for climax.
|
||||
|
||||
5. **ATMOSPHERE**: Set the emotional tone. This should reflect the story_tone but also consider where we are in the narrative arc based on story_insights.
|
||||
|
||||
6. **LIGHTING**: Define lighting that matches the visual style from image prompts and supports the atmosphere. Ensure consistency with the established visual aesthetic.
|
||||
|
||||
7. **SHOT COMPOSITION**: Describe framing and composition that serves the visual narrative. Consider the story's visual style and ensure it flows naturally with the overall story.
|
||||
|
||||
Write the prompt as a flowing, detailed description (not a list) that integrates all 7 components naturally. Make it vivid, cinematic, and consistent with the story's established visual and narrative style. The prompt should be between 300-500 words.
|
||||
"""
|
||||
|
||||
try:
|
||||
optimized_prompt = llm_text_gen(
|
||||
prompt=prompt_generation_request,
|
||||
system_prompt="You are an expert video prompt engineer specializing in HunyuanVideo text-to-video generation. Create detailed, cinematic prompts that follow best practices and ensure high-quality video output.",
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Clean up and validate prompt length
|
||||
optimized_prompt = optimized_prompt.strip()
|
||||
word_count = len(optimized_prompt.split())
|
||||
|
||||
if word_count < 200:
|
||||
logger.warning(f"[PromptEnhancer] Generated prompt is too short ({word_count} words), enhancing...")
|
||||
# Add more detail if too short
|
||||
optimized_prompt += self._add_cinematic_details(current_scene, story_context)
|
||||
elif word_count > 600:
|
||||
logger.warning(f"[PromptEnhancer] Generated prompt is too long ({word_count} words), trimming...")
|
||||
# Trim if too long (keep first ~500 words)
|
||||
words = optimized_prompt.split()
|
||||
optimized_prompt = ' '.join(words[:500])
|
||||
|
||||
logger.info(f"[PromptEnhancer] Generated prompt: {len(optimized_prompt.split())} words")
|
||||
return optimized_prompt
|
||||
|
||||
except HTTPException as http_err:
|
||||
# Propagate subscription limit errors (429) to frontend
|
||||
if http_err.status_code == 429:
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.warning(f"[PromptEnhancer] Subscription limit exceeded during prompt generation (HTTP 429): {error_msg}")
|
||||
# Re-raise to propagate to frontend for subscription modal
|
||||
raise
|
||||
else:
|
||||
# For other HTTP errors, log and fallback
|
||||
error_msg = self._extract_error_message(http_err)
|
||||
logger.error(f"[PromptEnhancer] Prompt generation failed (HTTP {http_err.status_code}): {error_msg}", exc_info=True)
|
||||
return self._generate_fallback_prompt(current_scene, story_context)
|
||||
except Exception as e:
|
||||
logger.error(f"[PromptEnhancer] Prompt generation failed: {str(e)}", exc_info=True)
|
||||
return self._generate_fallback_prompt(current_scene, story_context)
|
||||
|
||||
def _add_cinematic_details(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Add cinematic details to enhance a too-short prompt."""
|
||||
return f"""
|
||||
|
||||
The scene unfolds with careful attention to visual storytelling. The {story_context.get('story_setting', 'environment')} serves as more than background - it actively participates in the narrative. Lighting and composition work together to emphasize the emotional weight of this moment, with camera movements that guide the viewer's attention naturally through the space. Every element - from the way light falls to the positioning of characters - contributes to the overall narrative impact.
|
||||
"""
|
||||
|
||||
def _extract_error_message(self, http_err: HTTPException) -> str:
|
||||
"""
|
||||
Extract meaningful error message from HTTPException.
|
||||
|
||||
Handles both dict-based details (from subscription limit errors) and string details.
|
||||
"""
|
||||
if isinstance(http_err.detail, dict):
|
||||
# For subscription limit errors, extract the 'message' or 'error' field
|
||||
return http_err.detail.get('message') or http_err.detail.get('error') or str(http_err.detail)
|
||||
elif isinstance(http_err.detail, str):
|
||||
return http_err.detail
|
||||
else:
|
||||
return str(http_err.detail)
|
||||
|
||||
def _generate_fallback_prompt(
|
||||
self,
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Generate a basic fallback prompt if AI enhancement fails."""
|
||||
scene_title = current_scene.get('title', 'Untitled Scene')
|
||||
scene_desc = current_scene.get('description', '')
|
||||
image_prompt = current_scene.get('image_prompt', '')
|
||||
setting = story_context.get('story_setting', 'the scene')
|
||||
tone = story_context.get('story_tone', 'engaging')
|
||||
|
||||
return f"""A cinematic scene titled "{scene_title}" set in {setting}. {scene_desc[:200]}.
|
||||
The scene features {', '.join(current_scene.get('character_descriptions', [])[:2]) if current_scene.get('character_descriptions') else 'the main characters'}.
|
||||
Visual style follows: {image_prompt[:150]}.
|
||||
The {tone} atmosphere is enhanced by natural lighting and dynamic camera movements that follow the action.
|
||||
Shot composition emphasizes the narrative importance of this moment, with careful framing that draws attention to key elements.
|
||||
The scene maintains visual consistency with previous moments while advancing the story's visual narrative."""
|
||||
|
||||
|
||||
def enhance_scene_prompt_for_video(
|
||||
current_scene: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
all_scenes: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Convenience function to enhance a scene prompt for HunyuanVideo generation.
|
||||
|
||||
Args:
|
||||
current_scene: Scene data for the scene being processed
|
||||
story_context: Complete story context dictionary
|
||||
all_scenes: List of all scenes for consistency
|
||||
user_id: Clerk user ID for subscription checking
|
||||
|
||||
Returns:
|
||||
str: Optimized HunyuanVideo prompt
|
||||
"""
|
||||
service = PromptEnhancerService()
|
||||
return service.enhance_scene_prompt(current_scene, story_context, all_scenes, user_id)
|
||||
|
||||
@@ -41,6 +41,47 @@ class StoryVideoGenerationService:
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
return f"story_{clean_title}_{unique_id}.mp4"
|
||||
|
||||
def save_scene_video(self, video_bytes: bytes, scene_number: int, user_id: str) -> Dict[str, str]:
|
||||
"""
|
||||
Save individual scene video bytes to file.
|
||||
|
||||
Parameters:
|
||||
video_bytes: Raw video file bytes (mp4/webm format)
|
||||
scene_number: Scene number for naming
|
||||
user_id: Clerk user ID for naming
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Video metadata with video_url and video_filename
|
||||
"""
|
||||
try:
|
||||
# Generate filename with scene number and user ID
|
||||
clean_user_id = "".join(c if c.isalnum() or c in ('-', '_') else '_' for c in user_id[:16])
|
||||
timestamp = str(uuid.uuid4())[:8]
|
||||
filename = f"scene_{scene_number}_{clean_user_id}_{timestamp}.mp4"
|
||||
|
||||
video_path = self.output_dir / filename
|
||||
|
||||
# Write video bytes to file
|
||||
with open(video_path, 'wb') as f:
|
||||
f.write(video_bytes)
|
||||
|
||||
file_size = video_path.stat().st_size
|
||||
logger.info(f"[StoryVideoGeneration] Saved scene {scene_number} video: {filename} ({file_size} bytes)")
|
||||
|
||||
# Generate URL path (relative to /api/story/videos/)
|
||||
video_url = f"/api/story/videos/{filename}"
|
||||
|
||||
return {
|
||||
"video_filename": filename,
|
||||
"video_url": video_url,
|
||||
"video_path": str(video_path),
|
||||
"file_size": file_size
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryVideoGeneration] Error saving scene video: {e}", exc_info=True)
|
||||
raise RuntimeError(f"Failed to save scene video: {str(e)}") from e
|
||||
|
||||
def generate_scene_video(
|
||||
self,
|
||||
scene: Dict[str, Any],
|
||||
@@ -125,12 +166,12 @@ class StoryVideoGenerationService:
|
||||
# Use provided duration or audio duration
|
||||
video_duration = duration if duration is not None else audio_duration
|
||||
|
||||
# Create image clip
|
||||
image_clip = ImageClip(str(image_file)).set_duration(video_duration)
|
||||
image_clip = image_clip.set_fps(fps)
|
||||
# Create image clip (MoviePy v2: use with_* API)
|
||||
image_clip = ImageClip(str(image_file)).with_duration(video_duration)
|
||||
image_clip = image_clip.with_fps(fps)
|
||||
|
||||
# Set audio to image clip
|
||||
video_clip = image_clip.set_audio(audio_clip)
|
||||
video_clip = image_clip.with_audio(audio_clip)
|
||||
|
||||
# Generate video filename
|
||||
video_filename = f"scene_{scene_number}_{scene_title.replace(' ', '_').replace('/', '_')[:50]}_{uuid.uuid4().hex[:8]}.mp4"
|
||||
@@ -274,12 +315,12 @@ class StoryVideoGenerationService:
|
||||
audio_clip = AudioFileClip(str(audio_file))
|
||||
audio_duration = audio_clip.duration
|
||||
|
||||
# Create image clip
|
||||
image_clip = ImageClip(str(image_file)).set_duration(audio_duration)
|
||||
image_clip = image_clip.set_fps(fps)
|
||||
# Create image clip (MoviePy v2: use with_* API)
|
||||
image_clip = ImageClip(str(image_file)).with_duration(audio_duration)
|
||||
image_clip = image_clip.with_fps(fps)
|
||||
|
||||
# Set audio to image clip
|
||||
video_clip = image_clip.set_audio(audio_clip)
|
||||
video_clip = image_clip.with_audio(audio_clip)
|
||||
scene_clips.append(video_clip)
|
||||
|
||||
total_duration += audio_duration
|
||||
|
||||
46
backend/services/story_writer/video_preflight.py
Normal file
46
backend/services/story_writer/video_preflight.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def log_video_stack_diagnostics() -> None:
|
||||
try:
|
||||
import sys
|
||||
import platform
|
||||
import importlib
|
||||
|
||||
mv = importlib.import_module("moviepy")
|
||||
im = importlib.import_module("imageio")
|
||||
try:
|
||||
import imageio_ffmpeg as iff
|
||||
ff = iff.get_ffmpeg_exe()
|
||||
except Exception:
|
||||
ff = "unresolved"
|
||||
logger.info(
|
||||
"[VideoStack] py={} plat={} moviepy={} imageio={} ffmpeg={}",
|
||||
sys.executable,
|
||||
platform.platform(),
|
||||
getattr(mv, "__version__", "NA"),
|
||||
getattr(im, "__version__", "NA"),
|
||||
ff,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("[VideoStack] diagnostics failed: {}", e)
|
||||
|
||||
|
||||
def assert_supported_moviepy() -> None:
|
||||
"""Fail fast if MoviePy isn't version 2.x."""
|
||||
try:
|
||||
import pkg_resources as pr
|
||||
mv = pr.get_distribution("moviepy").version
|
||||
if not mv.startswith("2."):
|
||||
raise RuntimeError(
|
||||
f"Unsupported MoviePy version {mv}. Expected 2.x. "
|
||||
"Please install with: pip install moviepy==2.1.2"
|
||||
)
|
||||
except Exception as e:
|
||||
# Log and re-raise so startup fails clearly
|
||||
logger.error("[VideoStack] version check failed: {}", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -694,6 +694,44 @@ class LimitValidator:
|
||||
|
||||
total_images = projected_images
|
||||
|
||||
# Check video generation limits
|
||||
elif provider == APIProvider.VIDEO:
|
||||
video_limit = limits.get('video_calls', 0) or 0
|
||||
total_video_calls = usage.video_calls or 0
|
||||
projected_video_calls = total_video_calls + 1
|
||||
|
||||
if video_limit > 0 and projected_video_calls > video_limit:
|
||||
error_info = {
|
||||
'current_calls': total_video_calls,
|
||||
'limit': video_limit,
|
||||
'provider': 'video',
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"Video generation limit would be exceeded. Would use {projected_video_calls} of {video_limit} videos this billing period.", {
|
||||
'error_type': 'video_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# Check image editing limits
|
||||
elif provider == APIProvider.IMAGE_EDIT:
|
||||
image_edit_limit = limits.get('image_edit_calls', 0) or 0
|
||||
total_image_edit_calls = getattr(usage, 'image_edit_calls', 0) or 0
|
||||
projected_image_edit_calls = total_image_edit_calls + 1
|
||||
|
||||
if image_edit_limit > 0 and projected_image_edit_calls > image_edit_limit:
|
||||
error_info = {
|
||||
'current_calls': total_image_edit_calls,
|
||||
'limit': image_edit_limit,
|
||||
'provider': 'image_edit',
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"Image editing limit would be exceeded. Would use {projected_image_edit_calls} of {image_edit_limit} image edits this billing period.", {
|
||||
'error_type': 'image_edit_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# Check other provider-specific limits
|
||||
else:
|
||||
provider_calls_key = f"{provider_name}_calls"
|
||||
|
||||
@@ -299,3 +299,124 @@ def validate_image_generation_operations(
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_image_editing_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Validate image editing operation before making API calls.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
|
||||
Returns:
|
||||
None - raises HTTPException with 429 status if validation fails
|
||||
"""
|
||||
try:
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.IMAGE_EDIT,
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'image_edit',
|
||||
'operation_type': 'image_editing'
|
||||
}
|
||||
]
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.error(f"[Pre-flight Validator] Image editing blocked for user {user_id}: {message}")
|
||||
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', 'image_edit') if usage_info else 'image_edit'
|
||||
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ Image editing validated for user {user_id}")
|
||||
# Validation passed - no return needed (function raises HTTPException if validation fails)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating image editing: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
'error': f"Failed to validate image editing: {str(e)}",
|
||||
'message': f"Failed to validate image editing: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_video_generation_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Validate video generation operation before making API calls.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
|
||||
Returns:
|
||||
None - raises HTTPException with 429 status if validation fails
|
||||
"""
|
||||
try:
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.VIDEO,
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'video',
|
||||
'operation_type': 'video_generation'
|
||||
}
|
||||
]
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.error(f"[Pre-flight Validator] Video generation blocked for user {user_id}: {message}")
|
||||
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', 'video') if usage_info else 'video'
|
||||
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ Video generation validated for user {user_id}")
|
||||
# Validation passed - no return needed (function raises HTTPException if validation fails)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating video generation: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
'error': f"Failed to validate video generation: {str(e)}",
|
||||
'message': f"Failed to validate video generation: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -295,10 +295,22 @@ class PricingService:
|
||||
"model_name": "exa-search",
|
||||
"cost_per_request": 0.005, # $0.005 per search (1-25 results)
|
||||
"description": "Exa Neural Search API"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "tencent/HunyuanVideo",
|
||||
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
|
||||
"description": "HuggingFace AI Video Generation (HunyuanVideo)"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "default",
|
||||
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
|
||||
"description": "AI Video Generation default pricing"
|
||||
}
|
||||
]
|
||||
|
||||
# Combine all pricing data
|
||||
# Combine all pricing data (include video pricing in search_pricing list)
|
||||
all_pricing = gemini_pricing + openai_pricing + anthropic_pricing + mistral_pricing + search_pricing
|
||||
|
||||
# Insert or update pricing data
|
||||
@@ -344,6 +356,8 @@ class PricingService:
|
||||
"firecrawl_calls_limit": 10,
|
||||
"stability_calls_limit": 5,
|
||||
"exa_calls_limit": 100,
|
||||
"video_calls_limit": 0, # No video generation for free tier
|
||||
"image_edit_calls_limit": 10, # 10 AI image editing calls/month
|
||||
"gemini_tokens_limit": 100000,
|
||||
"monthly_cost_limit": 0.0,
|
||||
"features": ["basic_content_generation", "limited_research"],
|
||||
@@ -365,6 +379,8 @@ class PricingService:
|
||||
"firecrawl_calls_limit": 100,
|
||||
"stability_calls_limit": 5,
|
||||
"exa_calls_limit": 500,
|
||||
"video_calls_limit": 20, # 20 videos/month for basic plan
|
||||
"image_edit_calls_limit": 30, # 30 AI image editing calls/month
|
||||
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
@@ -388,6 +404,8 @@ class PricingService:
|
||||
"firecrawl_calls_limit": 500,
|
||||
"stability_calls_limit": 200,
|
||||
"exa_calls_limit": 2000,
|
||||
"video_calls_limit": 50, # 50 videos/month for pro plan
|
||||
"image_edit_calls_limit": 100, # 100 AI image editing calls/month
|
||||
"gemini_tokens_limit": 5000000,
|
||||
"openai_tokens_limit": 2500000,
|
||||
"anthropic_tokens_limit": 1000000,
|
||||
@@ -411,6 +429,8 @@ class PricingService:
|
||||
"firecrawl_calls_limit": 0,
|
||||
"stability_calls_limit": 0,
|
||||
"exa_calls_limit": 0, # Unlimited
|
||||
"video_calls_limit": 0, # Unlimited for enterprise
|
||||
"image_edit_calls_limit": 0, # Unlimited image editing for enterprise
|
||||
"gemini_tokens_limit": 0,
|
||||
"openai_tokens_limit": 0,
|
||||
"anthropic_tokens_limit": 0,
|
||||
@@ -429,6 +449,20 @@ class PricingService:
|
||||
if not existing:
|
||||
plan = SubscriptionPlan(**plan_data)
|
||||
self.db.add(plan)
|
||||
else:
|
||||
# Update existing plan with new limits (e.g., image_edit_calls_limit)
|
||||
# This ensures existing plans get new columns like image_edit_calls_limit
|
||||
for key, value in plan_data.items():
|
||||
if key not in ["name", "tier"]: # Don't overwrite name/tier
|
||||
try:
|
||||
# Try to set the attribute (works even if column was just added)
|
||||
setattr(existing, key, value)
|
||||
except (AttributeError, Exception) as e:
|
||||
# If attribute doesn't exist yet (column not migrated), skip it
|
||||
# Schema migration will add it, then this will update it on next run
|
||||
logger.debug(f"Could not set {key} on plan {existing.name}: {e}")
|
||||
existing.updated_at = datetime.utcnow()
|
||||
logger.debug(f"Updated existing plan: {existing.name}")
|
||||
|
||||
self.db.commit()
|
||||
logger.debug("Default subscription plans initialized")
|
||||
@@ -615,6 +649,8 @@ class PricingService:
|
||||
'metaphor_calls': plan.metaphor_calls_limit,
|
||||
'firecrawl_calls': plan.firecrawl_calls_limit,
|
||||
'stability_calls': plan.stability_calls_limit,
|
||||
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
|
||||
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column
|
||||
# Token limits
|
||||
'gemini_tokens': plan.gemini_tokens_limit,
|
||||
'openai_tokens': plan.openai_tokens_limit,
|
||||
|
||||
@@ -29,6 +29,8 @@ def ensure_subscription_plan_columns(db: Session) -> None:
|
||||
# Columns we may reference in models but might be missing in older DBs
|
||||
required_columns = {
|
||||
"exa_calls_limit": "INTEGER DEFAULT 0",
|
||||
"video_calls_limit": "INTEGER DEFAULT 0",
|
||||
"image_edit_calls_limit": "INTEGER DEFAULT 0",
|
||||
}
|
||||
|
||||
for col_name, ddl in required_columns.items():
|
||||
@@ -78,6 +80,10 @@ def ensure_usage_summaries_columns(db: Session) -> None:
|
||||
required_columns = {
|
||||
"exa_calls": "INTEGER DEFAULT 0",
|
||||
"exa_cost": "REAL DEFAULT 0.0",
|
||||
"video_calls": "INTEGER DEFAULT 0",
|
||||
"video_cost": "REAL DEFAULT 0.0",
|
||||
"image_edit_calls": "INTEGER DEFAULT 0",
|
||||
"image_edit_cost": "REAL DEFAULT 0.0",
|
||||
}
|
||||
|
||||
for col_name, ddl in required_columns.items():
|
||||
|
||||
@@ -608,6 +608,12 @@ class UsageTrackingService:
|
||||
# Reset image generation counters
|
||||
summary.stability_calls = 0
|
||||
|
||||
# Reset video generation counters
|
||||
summary.video_calls = 0
|
||||
|
||||
# Reset image editing counters
|
||||
summary.image_edit_calls = 0
|
||||
|
||||
# Reset cost counters
|
||||
summary.gemini_cost = 0.0
|
||||
summary.openai_cost = 0.0
|
||||
@@ -618,6 +624,9 @@ class UsageTrackingService:
|
||||
summary.metaphor_cost = 0.0
|
||||
summary.firecrawl_cost = 0.0
|
||||
summary.stability_cost = 0.0
|
||||
summary.exa_cost = 0.0
|
||||
summary.video_cost = 0.0
|
||||
summary.image_edit_cost = 0.0
|
||||
|
||||
# Reset totals
|
||||
summary.total_calls = 0
|
||||
|
||||
Reference in New Issue
Block a user