242 lines
11 KiB
Python
242 lines
11 KiB
Python
"""
|
|
Pre-flight check endpoints for operation validation and cost estimation.
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.orm import Session
|
|
from typing import Dict, Any
|
|
from loguru import logger
|
|
|
|
from services.database import get_db
|
|
from services.subscription import PricingService
|
|
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
|
from middleware.auth_middleware import get_current_user
|
|
from models.subscription_models import APIProvider, UsageSummary
|
|
from ..dependencies import get_user_id_from_token
|
|
from ..models import PreflightCheckRequest
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
@router.post("/preflight-check")
|
|
async def preflight_check(
|
|
request: PreflightCheckRequest,
|
|
db: Session = Depends(get_db),
|
|
current_user: Dict[str, Any] = Depends(get_current_user)
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
Pre-flight check for operations with cost estimation.
|
|
|
|
Lightweight endpoint that:
|
|
- Validates if operations are allowed based on subscription limits
|
|
- Estimates cost for operations
|
|
- Returns usage information and remaining quota
|
|
|
|
Uses caching to minimize DB load (< 100ms with cache hit).
|
|
"""
|
|
try:
|
|
user_id = get_user_id_from_token(current_user)
|
|
|
|
# Ensure schema columns exist
|
|
try:
|
|
ensure_subscription_plan_columns(db)
|
|
ensure_usage_summaries_columns(db)
|
|
except Exception as schema_err:
|
|
logger.warning(f"Schema check failed: {schema_err}")
|
|
|
|
pricing_service = PricingService(db)
|
|
|
|
# Convert request operations to internal format
|
|
operations_to_validate = []
|
|
for op in request.operations:
|
|
try:
|
|
# Map provider string to APIProvider enum
|
|
provider_str = op.provider.lower()
|
|
if provider_str == "huggingface":
|
|
provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
|
|
elif provider_str == "video":
|
|
provider_enum = APIProvider.VIDEO
|
|
elif provider_str == "fal-ai" or provider_str == "fal":
|
|
provider_enum = APIProvider.VIDEO # Map fal-ai to VIDEO as it's primarily used for media gen
|
|
elif provider_str == "image_edit":
|
|
provider_enum = APIProvider.IMAGE_EDIT
|
|
elif provider_str == "stability":
|
|
provider_enum = APIProvider.STABILITY
|
|
elif provider_str == "audio":
|
|
provider_enum = APIProvider.AUDIO
|
|
else:
|
|
try:
|
|
provider_enum = APIProvider(provider_str)
|
|
except ValueError:
|
|
logger.warning(f"Unknown provider: {provider_str}, skipping")
|
|
continue
|
|
|
|
operations_to_validate.append({
|
|
'provider': provider_enum,
|
|
'tokens_requested': op.tokens_requested or 0,
|
|
'actual_provider_name': op.actual_provider_name or op.provider,
|
|
'operation_type': op.operation_type,
|
|
'model': op.model
|
|
})
|
|
except Exception as e:
|
|
logger.warning(f"Error processing operation {op.operation_type}: {e}")
|
|
continue
|
|
|
|
if not operations_to_validate:
|
|
raise HTTPException(status_code=400, detail="No valid operations provided")
|
|
|
|
# Perform pre-flight validation
|
|
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
|
user_id=user_id,
|
|
operations=operations_to_validate
|
|
)
|
|
|
|
# Get pricing and cost estimation for each operation
|
|
operation_results = []
|
|
total_cost = 0.0
|
|
|
|
for op in operations_to_validate:
|
|
op_result = {
|
|
'provider': op['actual_provider_name'],
|
|
'operation_type': op['operation_type'],
|
|
'cost': 0.0,
|
|
'allowed': can_proceed,
|
|
'limit_info': None,
|
|
'message': None
|
|
}
|
|
|
|
# Get pricing for this operation
|
|
model_name = op.get('model')
|
|
if model_name:
|
|
pricing_info = pricing_service.get_pricing_for_provider_model(
|
|
op['provider'],
|
|
model_name
|
|
)
|
|
|
|
if pricing_info:
|
|
# Determine cost based on operation type
|
|
if op['provider'] in [APIProvider.VIDEO, APIProvider.IMAGE_EDIT, APIProvider.STABILITY]:
|
|
cost = pricing_info.get('cost_per_request', 0.0) or pricing_info.get('cost_per_image', 0.0) or 0.0
|
|
elif op['provider'] == APIProvider.AUDIO:
|
|
model_lower = (model_name or "").lower()
|
|
if model_lower == "minimax/voice-clone":
|
|
cost = pricing_info.get('cost_per_request', 0.5) or 0.5
|
|
elif model_lower == "wavespeed-ai/qwen3-tts/voice-clone":
|
|
chars = max(0, int(op.get('tokens_requested') or 0))
|
|
cost = max(0.005, 0.005 * (chars / 100.0))
|
|
else:
|
|
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * op['tokens_requested']
|
|
elif op['tokens_requested'] > 0:
|
|
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * op['tokens_requested']
|
|
else:
|
|
cost = pricing_info.get('cost_per_request', 0.0) or 0.0
|
|
|
|
op_result['cost'] = round(cost, 4)
|
|
total_cost += cost
|
|
else:
|
|
# Use default cost if pricing not found
|
|
if op['provider'] == APIProvider.VIDEO:
|
|
op_result['cost'] = 0.10 # Default video cost
|
|
total_cost += 0.10
|
|
elif op['provider'] == APIProvider.IMAGE_EDIT:
|
|
op_result['cost'] = 0.05 # Default image edit cost
|
|
total_cost += 0.05
|
|
elif op['provider'] == APIProvider.STABILITY:
|
|
op_result['cost'] = 0.04 # Default image generation cost
|
|
total_cost += 0.04
|
|
elif op['provider'] == APIProvider.AUDIO:
|
|
# Default audio cost: $0.05 per 1,000 characters
|
|
cost = (op['tokens_requested'] / 1000.0) * 0.05
|
|
op_result['cost'] = round(cost, 4)
|
|
total_cost += cost
|
|
|
|
# Get limit information
|
|
limit_info = None
|
|
if error_details and not can_proceed:
|
|
usage_info = error_details.get('usage_info', {})
|
|
if usage_info:
|
|
op_result['message'] = message
|
|
limit_info = {
|
|
'current_usage': usage_info.get('current_usage', 0),
|
|
'limit': usage_info.get('limit', 0),
|
|
'remaining': max(0, usage_info.get('limit', 0) - usage_info.get('current_usage', 0))
|
|
}
|
|
op_result['limit_info'] = limit_info
|
|
else:
|
|
# Get current usage for this provider
|
|
limits = pricing_service.get_user_limits(user_id)
|
|
if limits:
|
|
usage_summary = db.query(UsageSummary).filter(
|
|
UsageSummary.user_id == user_id,
|
|
UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
|
|
).first()
|
|
|
|
if usage_summary:
|
|
if op['provider'] == APIProvider.VIDEO:
|
|
current = getattr(usage_summary, 'video_calls', 0) or 0
|
|
limit = limits['limits'].get('video_calls', 0)
|
|
elif op['provider'] == APIProvider.IMAGE_EDIT:
|
|
current = getattr(usage_summary, 'image_edit_calls', 0) or 0
|
|
limit = limits['limits'].get('image_edit_calls', 0)
|
|
elif op['provider'] == APIProvider.STABILITY:
|
|
current = getattr(usage_summary, 'stability_calls', 0) or 0
|
|
limit = limits['limits'].get('stability_calls', 0)
|
|
elif op['provider'] == APIProvider.AUDIO:
|
|
current = getattr(usage_summary, 'audio_calls', 0) or 0
|
|
limit = limits['limits'].get('audio_calls', 0)
|
|
else:
|
|
# For LLM providers, use token limits
|
|
provider_key = op['provider'].value
|
|
current_tokens = getattr(usage_summary, f"{provider_key}_tokens", 0) or 0
|
|
limit = limits['limits'].get(f"{provider_key}_tokens", 0)
|
|
current = current_tokens
|
|
|
|
limit_info = {
|
|
'current_usage': current,
|
|
'limit': limit,
|
|
'remaining': max(0, limit - current) if limit > 0 else float('inf')
|
|
}
|
|
op_result['limit_info'] = limit_info
|
|
|
|
operation_results.append(op_result)
|
|
|
|
# Get overall usage summary
|
|
limits = pricing_service.get_user_limits(user_id)
|
|
usage_summary = None
|
|
if limits:
|
|
usage_summary = db.query(UsageSummary).filter(
|
|
UsageSummary.user_id == user_id,
|
|
UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
|
|
).first()
|
|
|
|
response_data = {
|
|
'can_proceed': can_proceed,
|
|
'estimated_cost': round(total_cost, 4),
|
|
'operations': operation_results,
|
|
'total_cost': round(total_cost, 4),
|
|
'usage_summary': None,
|
|
'cached': False # TODO: Track if result was cached
|
|
}
|
|
|
|
if usage_summary and limits:
|
|
# For video generation, show video limits
|
|
video_current = getattr(usage_summary, 'video_calls', 0) or 0
|
|
video_limit = limits['limits'].get('video_calls', 0)
|
|
|
|
response_data['usage_summary'] = {
|
|
'current_calls': video_current,
|
|
'limit': video_limit,
|
|
'remaining': max(0, video_limit - video_current) if video_limit > 0 else float('inf')
|
|
}
|
|
|
|
return {
|
|
"success": True,
|
|
"data": response_data
|
|
}
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error in pre-flight check: {e}", exc_info=True)
|
|
raise HTTPException(status_code=500, detail=f"Pre-flight check failed: {str(e)}")
|