AI story writer enhancements, text to video and voice generation, subscription management, and more.

This commit is contained in:
ajaysi
2025-11-19 09:55:32 +05:30
parent bf7493c366
commit e96525347b
64 changed files with 10367 additions and 400 deletions

View File

@@ -0,0 +1,301 @@
"""
Main Audio Generation Service for ALwrity Backend.
This service provides AI-powered text-to-speech functionality using WaveSpeed Minimax Speech 02 HD.
"""
from __future__ import annotations
import sys
from typing import Optional, Dict, Any
from datetime import datetime
from loguru import logger
from fastapi import HTTPException
from services.wavespeed.client import WaveSpeedClient
from services.onboarding.api_key_manager import APIKeyManager
from utils.logger_utils import get_service_logger
logger = get_service_logger("audio_generation")
class AudioGenerationResult:
    """Plain value holder describing one finished text-to-speech request."""

    def __init__(
        self,
        audio_bytes: bytes,
        provider: str,
        model: str,
        voice_id: str,
        text_length: int,
        file_size: int,
    ):
        # Raw audio payload as returned by the speech provider.
        self.audio_bytes = audio_bytes
        # Where the audio came from: provider name, model identifier, voice.
        self.provider, self.model, self.voice_id = provider, model, voice_id
        # Size bookkeeping supplied by the caller (input characters and
        # output bytes); stored as given, not validated here.
        self.text_length, self.file_size = text_length, file_size
def _check_audio_subscription_limits(user_id: str, character_count: int) -> None:
    """Pre-flight subscription check for an audio request.

    Asks ``PricingService.check_usage_limits`` whether ``user_id`` may consume
    ``character_count`` more characters (characters are passed as "tokens").

    Raises:
        HTTPException: 429 when the pricing service refuses the request.
        RuntimeError: when the check itself fails for any other reason.
    """
    try:
        from services.database import get_db
        from services.subscription import PricingService
        from models.subscription_models import UsageSummary, APIProvider

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            # Strict enforcement: AUDIO is the billing bucket, WaveSpeed the
            # actual upstream provider.
            can_proceed, message, usage_info = pricing_service.check_usage_limits(
                user_id=user_id,
                provider=APIProvider.AUDIO,
                tokens_requested=character_count,
                actual_provider_name="wavespeed",
            )
            if not can_proceed:
                logger.warning(f"[audio_gen] Subscription limit exceeded for user {user_id}: {message}")
                error_detail = {
                    'error': message,
                    'message': message,
                    'provider': 'wavespeed',
                    'usage_info': usage_info if usage_info else {}
                }
                raise HTTPException(status_code=429, detail=error_detail)

            # NOTE(review): the result of this lookup is not used afterwards.
            # It is kept so that a failure here still aborts the request
            # exactly as before; consider removing it once confirmed unneeded.
            current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
            db.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == current_period
            ).first()
        finally:
            db.close()
    except (HTTPException, RuntimeError):
        raise
    except Exception as sub_error:
        logger.error(f"[audio_gen] Subscription check failed for user {user_id}: {sub_error}")
        raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error


def _track_audio_usage(
    *,
    user_id: str,
    voice_id: str,
    character_count: int,
    estimated_cost: float,
    request_size: int,
    response_size: int,
) -> None:
    """Record one successful audio call in the usage summary and usage log.

    Best-effort: every failure is logged and swallowed so that audio which has
    already been generated is still returned to the caller.

    ``request_size`` / ``response_size`` are passed in as plain integers so
    this helper never needs the request text itself. The previous inline code
    imported ``sqlalchemy.text`` under the name ``text``, which replaced the
    request string and made ``text.encode(...)`` fail, rolling back every
    tracking attempt.
    """
    try:
        from services.database import get_db

        db_track = next(get_db())
        try:
            from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
            from services.subscription import PricingService
            # Aliased on purpose: must not collide with any variable named ``text``.
            from sqlalchemy import text as sql_text

            pricing = PricingService(db_track)
            current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

            # Get or create the usage summary row for this billing period.
            summary = db_track.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == current_period
            ).first()
            if not summary:
                summary = UsageSummary(
                    user_id=user_id,
                    billing_period=current_period
                )
                db_track.add(summary)
                db_track.flush()

            # Snapshot values before the update (used for the log line too).
            current_calls_before = getattr(summary, "audio_calls", 0) or 0
            current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
            new_calls = current_calls_before + 1
            new_cost = current_cost_before + estimated_cost

            # Direct SQL UPDATE for the audio columns ("dynamic attributes"
            # per the original author); values are bound parameters.
            update_query = sql_text("""
                UPDATE usage_summaries
                SET audio_calls = :new_calls,
                    audio_cost = :new_cost
                WHERE user_id = :user_id AND billing_period = :period
            """)
            db_track.execute(update_query, {
                'new_calls': new_calls,
                'new_cost': new_cost,
                'user_id': user_id,
                'period': current_period
            })

            # Update the aggregate totals through the ORM object.
            summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
            summary.total_calls = (summary.total_calls or 0) + 1
            summary.updated_at = datetime.utcnow()

            # Audit-trail entry for this call.
            usage_log = APIUsageLog(
                user_id=user_id,
                provider=APIProvider.AUDIO,
                endpoint="/audio-generation/wavespeed",
                method="POST",
                model_used="minimax/speech-02-hd",
                tokens_input=character_count,
                tokens_output=0,
                tokens_total=character_count,
                cost_input=0.0,
                cost_output=0.0,
                cost_total=estimated_cost,
                response_time=0.0,
                status_code=200,
                request_size=request_size,
                response_size=response_size,
                billing_period=current_period,
            )
            db_track.add(usage_log)

            # Plan details are only needed for the log message below, so a
            # missing 'limits' key must not abort (and roll back) tracking.
            limits = pricing.get_user_limits(user_id)
            plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
            tier = limits.get('tier', 'unknown') if limits else 'unknown'
            plan_limits = (limits.get('limits') or {}) if limits else {}
            audio_limit = plan_limits.get("audio_calls", 0)
            # Only show ∞ for Enterprise tier when limit is 0 (unlimited)
            audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'

            # Related counters shown alongside audio in the unified log.
            current_image_calls = getattr(summary, "stability_calls", 0) or 0
            image_limit = plan_limits.get("stability_calls", 0)
            current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
            image_edit_limit = plan_limits.get("image_edit_calls", 0)
            current_video_calls = getattr(summary, "video_calls", 0) or 0
            video_limit = plan_limits.get("video_calls", 0)

            db_track.commit()
            logger.info(f"[audio_gen] ✅ Successfully tracked usage: user {user_id} -> audio -> {new_calls} calls, ${estimated_cost:.4f}")

            # UNIFIED SUBSCRIPTION LOG - before/after state in one message.
            print(f"""
[SUBSCRIPTION] Audio Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: wavespeed
├─ Actual Provider: wavespeed
├─ Model: minimax/speech-02-hd
├─ Voice: {voice_id}
├─ Calls: {current_calls_before} → {new_calls} / {audio_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Characters: {character_count}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
        except Exception as track_error:
            logger.error(f"[audio_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
            db_track.rollback()
        finally:
            db_track.close()
    except Exception as usage_error:
        logger.error(f"[audio_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)


def generate_audio(
    text: str,
    voice_id: str = "Wise_Woman",
    speed: float = 1.0,
    volume: float = 1.0,
    pitch: float = 0.0,
    emotion: str = "happy",
    user_id: Optional[str] = None,
    **kwargs
) -> AudioGenerationResult:
    """
    Generate audio using AI text-to-speech with subscription tracking.

    Flow: validate ``user_id`` -> pre-flight subscription check -> WaveSpeed
    speech call -> best-effort usage tracking -> result object.

    Args:
        text: Text to convert to speech (max 10000 characters)
        voice_id: Voice ID (default: "Wise_Woman")
        speed: Speech speed (0.5-2.0, default: 1.0)
        volume: Speech volume (0.1-10.0, default: 1.0)
        pitch: Speech pitch (-12 to 12, default: 0.0)
        emotion: Emotion (default: "happy")
        user_id: User ID for subscription checking (required)
        **kwargs: Additional parameters (sample_rate, bitrate, format, etc.)
            forwarded unchanged to ``WaveSpeedClient.generate_speech``.

    Returns:
        AudioGenerationResult: Generated audio result

    Raises:
        RuntimeError: If user_id is missing or the subscription check fails.
        HTTPException: 429 when subscription limits are exceeded, 502 when
            the speech API call fails, 500 for any other unexpected error.
    """
    try:
        logger.info("[audio_gen] Starting audio generation")
        logger.debug(f"[audio_gen] Text length: {len(text)} characters, voice: {voice_id}")

        # SUBSCRIPTION CHECK - Required and strict enforcement
        if not user_id:
            raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

        # Cost is per character (every character counts as 1 "token").
        # Pricing: $0.05 per 1,000 characters
        character_count = len(text)
        cost_per_1000_chars = 0.05
        estimated_cost = (character_count / 1000.0) * cost_per_1000_chars

        _check_audio_subscription_limits(user_id, character_count)

        # Generate audio using WaveSpeed
        try:
            client = WaveSpeedClient()
            audio_bytes = client.generate_speech(
                text=text,
                voice_id=voice_id,
                speed=speed,
                volume=volume,
                pitch=pitch,
                emotion=emotion,
                enable_sync_mode=True,
                **kwargs
            )
            logger.info(f"[audio_gen] ✅ API call successful, generated {len(audio_bytes)} bytes")
        except HTTPException:
            raise
        except Exception as api_error:
            logger.error(f"[audio_gen] Audio generation API failed: {api_error}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "Audio generation failed",
                    "message": str(api_error)
                }
            ) from api_error

        # TRACK USAGE after successful API call (never blocks the response)
        if audio_bytes:
            logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
            _track_audio_usage(
                user_id=user_id,
                voice_id=voice_id,
                character_count=character_count,
                estimated_cost=estimated_cost,
                request_size=len(text.encode("utf-8")),
                response_size=len(audio_bytes),
            )

        return AudioGenerationResult(
            audio_bytes=audio_bytes,
            provider="wavespeed",
            model="minimax/speech-02-hd",
            voice_id=voice_id,
            text_length=character_count,
            file_size=len(audio_bytes),
        )
    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[audio_gen] Error generating audio: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Audio generation failed",
                "message": str(e)
            }
        ) from e

View File

@@ -515,6 +515,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
import sys
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
@@ -571,6 +577,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
@@ -819,6 +827,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
@@ -838,6 +852,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:

View File

@@ -10,6 +10,7 @@ from __future__ import annotations
import os
import base64
import io
import sys
from typing import Any, Dict, Optional, Union
from fastapi import HTTPException
@@ -22,11 +23,11 @@ except ImportError:
InferenceClient = None
from ..onboarding.api_key_manager import APIKeyManager
from services.subscription import PricingService
from utils.logger_utils import get_service_logger
logger = get_service_logger("video_generation_service")
class VideoProviderNotImplemented(Exception):
    """Marker exception for a video provider that has no implementation.

    Adds nothing to ``Exception``; presumably raised by provider stubs
    outside this view -- confirm against the call sites.
    """
@@ -48,44 +49,80 @@ def _get_api_key(provider: str) -> Optional[str]:
def _coerce_video_bytes(output: Any) -> bytes:
    """
    Normalizes the different return shapes that huggingface_hub may emit for video tasks.

    According to HF docs, text_to_video() should return bytes directly; the
    remaining branches are defensive. Shapes are tried in this order:

    1. raw ``bytes`` / ``bytearray`` / ``memoryview``
    2. a file-like object (anything with ``read()``)
    3. an object exposing a ``video`` or ``bytes`` attribute
    4. a dict with a ``video`` key, or one of ``data`` / ``content`` /
       ``file`` / ``result`` / ``output``
    5. a ``str`` holding a ``data:`` URI or raw base64
    6. any object implementing ``__bytes__``

    Raises:
        TypeError: when none of the shapes above yields bytes.
    """

    def _payload_to_bytes(candidate: Any) -> Optional[bytes]:
        # Bytes-like values and readable objects are the only nested payloads
        # accepted; anything else (including None) means "keep looking".
        if isinstance(candidate, (bytes, bytearray, memoryview)):
            return bytes(candidate)
        if hasattr(candidate, "read"):
            return bytes(candidate.read())
        return None

    logger.debug(f"[video_gen] _coerce_video_bytes received type: {type(output)}")

    # Most common case: bytes directly
    if isinstance(output, (bytes, bytearray, memoryview)):
        logger.debug(f"[video_gen] Output is bytes: {len(output)} bytes")
        return bytes(output)

    # Handle file-like objects
    if hasattr(output, "read"):
        logger.debug("[video_gen] Output has read() method, reading...")
        data = output.read()
        if isinstance(data, (bytes, bytearray, memoryview)):
            return bytes(data)
        raise TypeError(f"File-like object returned non-bytes: {type(data)}")

    # Objects with direct attribute access; an unusable attribute value falls
    # through to the next candidate instead of failing immediately.
    for attr_name in ("video", "bytes"):
        if hasattr(output, attr_name):
            logger.debug(f"[video_gen] Output has '{attr_name}' attribute")
            payload = _payload_to_bytes(getattr(output, attr_name))
            if payload is not None:
                return payload

    # Dict handling - but this shouldn't happen with text_to_video()
    if isinstance(output, dict):
        logger.warning(f"[video_gen] Received dict output (unexpected): keys={list(output.keys())}")
        # .get() keeps a missing key from surfacing as a KeyError.
        for key in ("video", "data", "content", "file", "result", "output"):
            payload = _payload_to_bytes(output.get(key))
            if payload is not None:
                return payload
        raise TypeError(f"Dict output has no recognized video key. Keys: {list(output.keys())}")

    # String handling (base64, optionally wrapped in a data URI)
    if isinstance(output, str):
        logger.debug("[video_gen] Output is string, attempting base64 decode")
        try:
            # A malformed data URI (no comma) is reported as TypeError too.
            encoded = output.split(",", 1)[1] if output.startswith("data:") else output
            return base64.b64decode(encoded)
        except Exception as exc:
            raise TypeError(f"Unable to decode string video payload: {exc}") from exc

    # Fallback: try to use output directly
    logger.warning(f"[video_gen] Unexpected output type: {type(output)}, attempting direct conversion")
    if hasattr(output, "__bytes__"):
        try:
            return bytes(output)
        except Exception as exc:
            # Deliberately non-fatal: the TypeError below reports the failure.
            logger.debug(f"[video_gen] bytes() conversion failed: {exc}")
    raise TypeError(f"Unsupported video payload type: {type(output)}. Output: {str(output)[:200]}")
def _generate_with_huggingface(
@@ -96,7 +133,6 @@ def _generate_with_huggingface(
negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
model: str = "tencent/HunyuanVideo",
input_image_bytes: Optional[bytes] = None,
) -> bytes:
"""
Generates video bytes using Hugging Face's InferenceClient.
@@ -109,7 +145,6 @@ def _generate_with_huggingface(
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
client = InferenceClient(
model=model,
provider="fal-ai",
token=token,
)
@@ -126,26 +161,25 @@ def _generate_with_huggingface(
params["seed"] = seed
logger.info(
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=%s",
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=text-to-video",
model,
num_frames,
num_inference_steps,
"image-to-video" if input_image_bytes else "text-to-video",
)
try:
call_kwargs = {**params, "model": model}
if input_image_bytes:
video_output = client.image_to_video(
image=input_image_bytes,
prompt=prompt,
**call_kwargs,
)
else:
video_output = client.text_to_video(
prompt,
**call_kwargs,
)
logger.info("[video_gen] Calling client.text_to_video()...")
video_output = client.text_to_video(
prompt=prompt,
model=model,
**params,
)
logger.info(f"[video_gen] text_to_video() returned type: {type(video_output)}")
if isinstance(video_output, dict):
logger.info(f"[video_gen] Dict keys: {list(video_output.keys())}")
elif hasattr(video_output, "__dict__"):
logger.info(f"[video_gen] Object attributes: {dir(video_output)}")
video_bytes = _coerce_video_bytes(video_output)
@@ -158,6 +192,15 @@ def _generate_with_huggingface(
logger.info(f"[video_gen] Successfully generated video: {len(video_bytes)} bytes")
return video_bytes
except KeyError as e:
error_msg = str(e)
logger.error(f"[video_gen] HF KeyError: {error_msg}", exc_info=True)
logger.error(f"[video_gen] This suggests the API response format is unexpected. Check logs above for response type.")
raise HTTPException(status_code=502, detail={
"error": f"Hugging Face API returned unexpected response format: {error_msg}",
"error_type": "KeyError",
"hint": "The API response may have changed. Check server logs for details."
})
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
@@ -179,7 +222,6 @@ def ai_video_generate(
prompt: str,
provider: str = "huggingface",
user_id: Optional[str] = None,
input_image_bytes: Optional[bytes] = None,
**kwargs,
) -> bytes:
"""
@@ -187,7 +229,6 @@ def ai_video_generate(
- provider: 'huggingface' (default), 'gemini' (veo3 stub), 'openai' (sora stub)
- kwargs: num_frames, guidance_scale, num_inference_steps, negative_prompt, seed, model
- input_image_bytes: optional bytes for image-to-video flows (uses image as motion anchor)
Returns raw video bytes (mp4/webm depending on provider).
"""
@@ -200,7 +241,6 @@ def ai_video_generate(
# PRE-FLIGHT VALIDATION: Validate video generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_video_generation_operations
from fastapi import HTTPException
@@ -227,7 +267,6 @@ def ai_video_generate(
if provider == "huggingface":
video_bytes = _generate_with_huggingface(
prompt=prompt,
input_image_bytes=input_image_bytes,
**kwargs,
)
elif provider == "gemini":
@@ -237,112 +276,14 @@ def ai_video_generate(
else:
raise RuntimeError(f"Unknown video provider: {provider}")
# Track usage AFTER successful generation
db_track = next(get_db())
try:
from models.subscription_models import APIProvider, UsageSummary, APIUsageLog
from datetime import datetime
from services.subscription import PricingService
# Create pricing service for tracking (uses same DB session)
pricing_service_track = PricingService(db_track)
# Get current billing period
current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
usage_summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage_summary:
usage_summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(usage_summary)
db_track.commit()
# Calculate cost using pricing service
cost_info = pricing_service_track.get_pricing_for_provider_model(
APIProvider.VIDEO,
model_name
)
cost_per_video = cost_info.get('cost_per_request', 0.10) if cost_info else 0.10
# Get "before" state for unified log
current_video_calls_before = getattr(usage_summary, 'video_calls', 0) or 0
current_video_cost = getattr(usage_summary, 'video_cost', 0.0) or 0.0
# Increment video_calls and track cost
new_video_calls = current_video_calls_before + 1
usage_summary.video_calls = new_video_calls
usage_summary.video_cost = current_video_cost + cost_per_video
usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
# Get plan details for unified log (before commit, in case commit fails)
limits = pricing_service_track.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get image and image editing stats for unified log
current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Create usage log entry for audit trail
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.VIDEO,
endpoint=f"/video-generation/{provider}",
method="POST",
model_used=model_name,
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=cost_per_video,
response_time=0.0, # Could track actual time if needed
status_code=200,
request_size=len(prompt.encode('utf-8')),
response_size=len(video_bytes),
billing_period=current_period
)
db_track.add(usage_log)
db_track.commit()
logger.info(f"[video_gen] ✅ Successfully tracked usage: user {user_id} -> 1 video call, ${cost_per_video:.4f} cost")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Flush immediately to ensure it's visible in console/logs
import sys
log_message = f"""
[SUBSCRIPTION] Video Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: video
├─ Actual Provider: {provider}
├─ Model: {model_name or 'default'}
├─ Calls: {current_video_calls_before}{new_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
"""
print(log_message, flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
db_track.rollback()
# Don't fail video generation if tracking fails - video is already generated
finally:
db_track.close()
track_video_usage(
user_id=user_id,
provider=provider,
model_name=model_name,
prompt=prompt,
video_bytes=video_bytes,
)
return video_bytes
except HTTPException:
@@ -353,3 +294,139 @@ def ai_video_generate(
raise HTTPException(status_code=500, detail={"error": str(e)})
def track_video_usage(
    *,
    user_id: str,
    provider: str,
    model_name: str,
    prompt: str,
    video_bytes: bytes,
    cost_override: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Record one generated video against the user's subscription usage.

    Loads (or creates) the UsageSummary row for the current billing period,
    bumps the video and total counters/costs, appends an APIUsageLog entry,
    commits, and logs a unified [SUBSCRIPTION] message.

    Args:
        user_id: User whose usage is updated.
        provider: Actual provider name; used in the logged endpoint/message.
        model_name: Model used; also the key for the pricing lookup.
        prompt: Prompt text; only its UTF-8 size is recorded.
        video_bytes: Generated video; only its length is recorded.
        cost_override: When not None, used instead of the looked-up price.

    Returns:
        Dict with previous/current call counts, limit info and costs.
        NOTE: when tracking fails the error is logged, the transaction is
        rolled back and the function falls through without a return value.
    """
    from datetime import datetime
    from models.subscription_models import APIProvider, APIUsageLog, UsageSummary
    from services.database import get_db

    session = next(get_db())
    try:
        logger.info(f"[video_gen] Starting usage tracking for user={user_id}, provider={provider}, model={model_name}")
        pricing = PricingService(session)

        # Fall back to the calendar month when no billing period is known.
        period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
        logger.debug(f"[video_gen] Billing period: {period}")

        summary_query = session.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == period,
        )
        summary = summary_query.first()
        if summary:
            logger.debug(f"[video_gen] Found existing UsageSummary: video_calls={getattr(summary, 'video_calls', 0)}")
        else:
            logger.debug(f"[video_gen] Creating new UsageSummary for user={user_id}, period={period}")
            summary = UsageSummary(user_id=user_id, billing_period=period)
            session.add(summary)
            session.commit()
            session.refresh(summary)

        # Price resolution: explicit override > configured price > 0.10.
        pricing_row = pricing.get_pricing_for_provider_model(APIProvider.VIDEO, model_name)
        configured_cost = pricing_row.get("cost_per_request") if pricing_row else None
        default_cost = 0.10 if configured_cost is None else configured_cost
        if cost_override is not None:
            cost_per_video = cost_override
        else:
            cost_per_video = default_cost
        logger.debug(f"[video_gen] Cost per video: ${cost_per_video} (override={cost_override}, default={default_cost})")

        # Remember the pre-update count for the before/after log line.
        calls_before = getattr(summary, "video_calls", 0) or 0
        cost_before = getattr(summary, "video_cost", 0.0) or 0.0
        summary.video_calls = calls_before + 1
        summary.video_cost = cost_before + cost_per_video
        summary.total_calls = (summary.total_calls or 0) + 1
        summary.total_cost = (summary.total_cost or 0.0) + cost_per_video
        # Ensure the object is in the session
        session.add(summary)
        logger.debug(f"[video_gen] Updated usage_summary: video_calls={calls_before}{summary.video_calls}")

        # Plan details and sibling counters, gathered for the unified log.
        limits = pricing.get_user_limits(user_id)
        if limits:
            plan_name = limits.get("plan_name", "unknown")
            tier = limits.get("tier", "unknown")
            video_limit = limits["limits"].get("video_calls", 0)
        else:
            plan_name = "unknown"
            tier = "unknown"
            video_limit = 0
        image_calls = getattr(summary, "stability_calls", 0) or 0
        image_limit = limits["limits"].get("stability_calls", 0) if limits else 0
        image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
        image_edit_limit = limits["limits"].get("image_edit_calls", 0) if limits else 0
        audio_calls = getattr(summary, "audio_calls", 0) or 0
        audio_limit = limits["limits"].get("audio_calls", 0) if limits else 0
        # Only show ∞ for Enterprise tier when limit is 0 (unlimited)
        if audio_limit > 0 or tier != 'enterprise':
            audio_limit_display = audio_limit
        else:
            audio_limit_display = ''

        # Audit-trail row; token fields are zero because video is billed per request.
        session.add(
            APIUsageLog(
                user_id=user_id,
                provider=APIProvider.VIDEO,
                endpoint=f"/video-generation/{provider}",
                method="POST",
                model_used=model_name,
                tokens_input=0,
                tokens_output=0,
                tokens_total=0,
                cost_input=0.0,
                cost_output=0.0,
                cost_total=cost_per_video,
                response_time=0.0,
                status_code=200,
                request_size=len(prompt.encode("utf-8")),
                response_size=len(video_bytes),
                billing_period=period,
            )
        )

        logger.debug("[video_gen] Flushing changes before commit...")
        session.flush()
        logger.debug("[video_gen] Committing usage tracking changes...")
        session.commit()
        session.refresh(summary)
        logger.debug(f"[video_gen] Commit successful. Final video_calls: {summary.video_calls}, video_cost: {summary.video_cost}")

        video_limit_display = video_limit if video_limit > 0 else ''
        logger.info(f"""
[SUBSCRIPTION] Video Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: video
├─ Actual Provider: {provider}
├─ Model: {model_name or 'default'}
├─ Calls: {calls_before}{summary.video_calls} / {video_limit_display}
├─ Images: {image_calls} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Audio: {audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")

        return {
            "previous_calls": calls_before,
            "current_calls": summary.video_calls,
            "video_limit": video_limit,
            "video_limit_display": video_limit_display,
            "cost_per_video": cost_per_video,
            "total_video_cost": summary.video_cost,
        }
    except Exception as track_error:
        # Tracking is best-effort: log, undo the partial transaction, move on.
        logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
        logger.error(f"[video_gen] Exception type: {type(track_error).__name__}", exc_info=True)
        session.rollback()
    finally:
        session.close()