Save local changes (GSC/Bing integrations) before merging PR #354
This commit is contained in:
@@ -36,6 +36,172 @@ class VideoProviderNotImplemented(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _track_video_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/video-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Video Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reusable usage tracking helper for all video operations.
|
||||
|
||||
Args:
|
||||
user_id: User ID for tracking
|
||||
provider: Provider name
|
||||
model: Model name used
|
||||
operation_type: Type of operation (for logging)
|
||||
result_bytes: Generated video bytes
|
||||
cost: Cost of the operation
|
||||
prompt: Optional prompt text
|
||||
endpoint: API endpoint path
|
||||
metadata: Optional additional metadata
|
||||
log_prefix: Logging prefix
|
||||
response_time: API response time
|
||||
|
||||
Returns:
|
||||
Dictionary with tracking information
|
||||
"""
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "video_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "video_cost", 0.0) or 0.0
|
||||
|
||||
# Update video calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET video_calls = :new_calls,
|
||||
video_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Create usage log
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.WAVESPEED, # Default for video
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=provider,
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=response_time,
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes) if result_bytes else 0,
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
|
||||
# Get limits for display
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
video_limit_display = video_limit if (video_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
image_edit_limit_display = image_edit_limit if (image_edit_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {video_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {
|
||||
"current_calls": new_calls,
|
||||
"cost": cost,
|
||||
"total_cost": new_cost,
|
||||
}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> Optional[str]:
|
||||
try:
|
||||
manager = APIKeyManager()
|
||||
@@ -500,156 +666,74 @@ async def ai_video_generate(
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with {operation_type}")
|
||||
|
||||
# Progress callback: Initial submission
|
||||
if progress_callback:
|
||||
progress_callback(10.0, f"Submitting {operation_type} request to {provider}...")
|
||||
|
||||
# Generate video based on operation type
|
||||
model_name = kwargs.get("model", _get_default_model(operation_type, provider))
|
||||
|
||||
# Track response time for video generation
|
||||
|
||||
# Track response time
|
||||
import time
|
||||
from datetime import datetime
|
||||
start_time = time.time()
|
||||
|
||||
# Execute operation based on type
|
||||
result = {}
|
||||
try:
|
||||
if operation_type == "text-to-video":
|
||||
if provider == "huggingface":
|
||||
video_bytes = _generate_with_huggingface(
|
||||
prompt=prompt,
|
||||
**kwargs,
|
||||
)
|
||||
# For text-to-video, create metadata dict (HuggingFace doesn't return metadata)
|
||||
result_dict = {
|
||||
video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs)
|
||||
result = {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": kwargs.get("duration", 5.0),
|
||||
"model_name": model_name,
|
||||
"cost": 0.10, # Default cost, will be calculated in track_video_usage
|
||||
"provider": provider,
|
||||
"resolution": kwargs.get("resolution", "720p"),
|
||||
"width": 1280, # Default, actual may vary
|
||||
"height": 720, # Default, actual may vary
|
||||
"metadata": {},
|
||||
"model_name": kwargs.get("model", "tencent/HunyuanVideo"),
|
||||
"provider": "huggingface",
|
||||
"cost": 0.0, # HuggingFace inference is free/low cost
|
||||
}
|
||||
elif provider == "wavespeed":
|
||||
# WaveSpeed text-to-video - use unified service
|
||||
result_dict = await _generate_text_to_video_wavespeed(
|
||||
result = await _generate_text_to_video_wavespeed(
|
||||
prompt=prompt,
|
||||
progress_callback=progress_callback,
|
||||
**kwargs,
|
||||
**kwargs
|
||||
)
|
||||
elif provider == "gemini":
|
||||
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
|
||||
result_dict = {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": kwargs.get("duration", 5.0),
|
||||
"model_name": model_name,
|
||||
"cost": 0.10,
|
||||
"provider": provider,
|
||||
"resolution": kwargs.get("resolution", "720p"),
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"metadata": {},
|
||||
}
|
||||
result = {"video_bytes": _generate_with_gemini(prompt, **kwargs)}
|
||||
elif provider == "openai":
|
||||
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
|
||||
result_dict = {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": kwargs.get("duration", 5.0),
|
||||
"model_name": model_name,
|
||||
"cost": 0.10,
|
||||
"provider": provider,
|
||||
"resolution": kwargs.get("resolution", "720p"),
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"metadata": {},
|
||||
}
|
||||
result = {"video_bytes": _generate_with_openai(prompt, **kwargs)}
|
||||
else:
|
||||
raise RuntimeError(f"Unknown provider for text-to-video: {provider}")
|
||||
|
||||
raise ValueError(f"Unknown provider for text-to-video: {provider}")
|
||||
|
||||
elif operation_type == "image-to-video":
|
||||
if provider == "wavespeed":
|
||||
# Progress callback: Starting generation
|
||||
if progress_callback:
|
||||
progress_callback(20.0, "Video generation in progress...")
|
||||
|
||||
# Handle async call from sync context
|
||||
# Since ai_video_generate is sync, we need to run async function
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# We're in an async context - use ThreadPoolExecutor to run in new event loop
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run,
|
||||
_generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or kwargs.get("prompt", ""),
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
result_dict = future.result()
|
||||
else:
|
||||
# Event loop exists but not running - use it
|
||||
result_dict = loop.run_until_complete(_generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or kwargs.get("prompt", ""),
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
))
|
||||
except RuntimeError:
|
||||
# No event loop exists, create a new one
|
||||
result_dict = asyncio.run(_generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or kwargs.get("prompt", ""),
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
))
|
||||
video_bytes = result_dict["video_bytes"]
|
||||
model_name = result_dict.get("model_name", model_name)
|
||||
|
||||
# Progress callback: Processing result
|
||||
if progress_callback:
|
||||
progress_callback(90.0, "Processing video result...")
|
||||
result = await _generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or "",
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown provider for image-to-video: {provider}. Only 'wavespeed' is supported.")
|
||||
raise ValueError(f"Unknown provider for image-to-video: {provider}")
|
||||
|
||||
# Track usage (same pattern as text generation)
|
||||
# Use cost from result_dict if available, otherwise calculate
|
||||
response_time = time.time() - start_time
|
||||
cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
|
||||
track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
prompt=result_dict.get("prompt", prompt or ""),
|
||||
video_bytes=video_bytes,
|
||||
cost_override=cost_override,
|
||||
response_time=response_time,
|
||||
)
|
||||
|
||||
# Progress callback: Complete
|
||||
if progress_callback:
|
||||
progress_callback(100.0, "Video generation complete!")
|
||||
|
||||
return result_dict
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., from validation or API errors)
|
||||
raise
|
||||
# TRACK USAGE after successful API call
|
||||
video_bytes = result.get("video_bytes")
|
||||
if user_id and video_bytes:
|
||||
_track_video_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=result.get("provider", provider),
|
||||
model=result.get("model_name", kwargs.get("model", "unknown")),
|
||||
operation_type=operation_type,
|
||||
result_bytes=video_bytes,
|
||||
cost=result.get("cost", 0.0),
|
||||
prompt=prompt,
|
||||
endpoint="/video-generation",
|
||||
metadata=result.get("metadata"),
|
||||
log_prefix=f"[{operation_type.replace('-', ' ').title()}]",
|
||||
response_time=response_time
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
# Log failure but don't track usage (no cost incurred)
|
||||
logger.error(f"[video_gen] Generation failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def _get_default_model(operation_type: str, provider: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user