Files
ALwrity/backend/api/subscription/routes/preflight.py

241 lines
11 KiB
Python

"""
Pre-flight check endpoints for operation validation and cost estimation.
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Dict, Any
from loguru import logger
from services.database import get_db
from services.subscription import PricingService
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
from middleware.auth_middleware import get_current_user
from models.subscription_models import APIProvider, UsageSummary
from ..dependencies import get_user_id_from_token
from ..models import PreflightCheckRequest
router = APIRouter()
def _resolve_provider(provider_str: str):
    """Map a lower-cased provider string to an ``APIProvider`` member.

    A handful of request-level names are mapped explicitly; anything else is
    looked up by enum value. Returns ``None`` when the string matches neither.
    """
    aliases = {
        "huggingface": APIProvider.MISTRAL,  # Maps to HuggingFace
        "video": APIProvider.VIDEO,
        "image_edit": APIProvider.IMAGE_EDIT,
        "stability": APIProvider.STABILITY,
        "audio": APIProvider.AUDIO,
    }
    if provider_str in aliases:
        return aliases[provider_str]
    try:
        return APIProvider(provider_str)
    except ValueError:
        return None


def _estimate_operation_cost(pricing_service, provider, model_name, tokens_requested) -> float:
    """Return the unrounded cost estimate for one operation.

    No model name means no estimate (0.0). When a model is given but the
    pricing service has no entry for it, per-provider default costs are used.
    """
    if not model_name:
        return 0.0

    pricing_info = pricing_service.get_pricing_for_provider_model(provider, model_name)

    if not pricing_info:
        # Use default cost if pricing not found
        if provider == APIProvider.VIDEO:
            return 0.10  # Default video cost
        if provider == APIProvider.IMAGE_EDIT:
            return 0.05  # Default image edit cost
        if provider == APIProvider.STABILITY:
            return 0.04  # Default image generation cost
        if provider == APIProvider.AUDIO:
            # Default audio cost: $0.05 per 1,000 characters
            return (tokens_requested / 1000.0) * 0.05
        return 0.0

    if provider in (APIProvider.VIDEO, APIProvider.IMAGE_EDIT, APIProvider.STABILITY):
        # Flat per-call pricing; fall back to the per-image price if present.
        return (
            pricing_info.get('cost_per_request', 0.0)
            or pricing_info.get('cost_per_image', 0.0)
            or 0.0
        )

    if provider == APIProvider.AUDIO:
        model_lower = (model_name or "").lower()
        if model_lower == "minimax/voice-clone":
            return pricing_info.get('cost_per_request', 0.5) or 0.5
        if model_lower == "wavespeed-ai/qwen3-tts/voice-clone":
            # 0.005 per 100 characters, with a 0.005 minimum charge.
            chars = max(0, int(tokens_requested or 0))
            return max(0.005, 0.005 * (chars / 100.0))
        # Audio pricing is per character (every character is 1 token)
        return (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (tokens_requested / 1000.0)

    if tokens_requested > 0:
        # Token-based cost estimation (rough estimate)
        return (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (tokens_requested / 1000)

    return pricing_info.get('cost_per_request', 0.0) or 0.0


def _usage_field_name(provider) -> str:
    """Return the name used both as the ``UsageSummary`` attribute and as the limits key."""
    call_fields = {
        APIProvider.VIDEO: 'video_calls',
        APIProvider.IMAGE_EDIT: 'image_edit_calls',
        APIProvider.STABILITY: 'stability_calls',
        APIProvider.AUDIO: 'audio_calls',
    }
    field = call_fields.get(provider)
    if field is None:
        # For LLM providers, use token limits
        field = f"{provider.value}_tokens"
    return field


def _remaining_quota(limit, current):
    """Return the remaining quota, or ``None`` when ``limit`` is not positive.

    A non-positive limit is treated as "no limit". ``None`` is returned rather
    than ``float('inf')`` because infinity is not a valid JSON number and
    cannot be relied on to survive response serialization.
    """
    if limit > 0:
        return max(0, limit - current)
    return None


@router.post("/preflight-check")
async def preflight_check(
    request: PreflightCheckRequest,
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Pre-flight check for operations with cost estimation.

    Lightweight endpoint that:
    - Validates if operations are allowed based on subscription limits
    - Estimates cost for operations
    - Returns usage information and remaining quota

    Operations whose provider cannot be resolved are skipped with a warning.
    The ``cached`` field of the response is currently always ``False``.
    A ``remaining`` value of ``null`` means the corresponding limit is not
    positive, which this endpoint treats as unlimited.

    Returns:
        ``{"success": True, "data": {...}}`` where ``data`` holds
        ``can_proceed``, ``estimated_cost``, ``total_cost``, per-operation
        results under ``operations``, ``usage_summary`` and ``cached``.

    Raises:
        HTTPException: 400 when no operation in the request is valid;
            500 when any unexpected error occurs.
    """
    try:
        user_id = get_user_id_from_token(current_user)

        # Ensure schema columns exist. A failure here is logged and ignored so
        # the check itself can still be attempted.
        try:
            ensure_subscription_plan_columns(db)
            ensure_usage_summaries_columns(db)
        except Exception as schema_err:
            logger.warning(f"Schema check failed: {schema_err}")

        pricing_service = PricingService(db)

        # Convert request operations to internal format. The requested model is
        # recorded in lockstep with each accepted operation: skipped operations
        # make positions in this list differ from positions in the request, so
        # the request must not be re-indexed later.
        operations_to_validate = []
        operation_models = []
        for op in request.operations:
            try:
                provider_str = op.provider.lower()
                provider_enum = _resolve_provider(provider_str)
                if provider_enum is None:
                    logger.warning(f"Unknown provider: {provider_str}, skipping")
                    continue

                validated = {
                    'provider': provider_enum,
                    'tokens_requested': op.tokens_requested or 0,
                    'actual_provider_name': op.actual_provider_name or op.provider,
                    'operation_type': op.operation_type
                }
                model_name = op.model
                operations_to_validate.append(validated)
                operation_models.append(model_name)
            except Exception as e:
                logger.warning(f"Error processing operation {op.operation_type}: {e}")
                continue

        if not operations_to_validate:
            raise HTTPException(status_code=400, detail="No valid operations provided")

        # Perform pre-flight validation
        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        # Limits and the current usage row do not depend on the operation, so
        # they are fetched once and reused for every result and the summary.
        limits = pricing_service.get_user_limits(user_id)
        usage_summary = None
        if limits:
            usage_summary = db.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
            ).first()

        limit_blocked = bool(error_details and not can_proceed)
        blocked_usage = (error_details.get('usage_info', {}) or {}) if limit_blocked else {}

        # Get pricing, cost estimation and limit information for each operation
        operation_results = []
        total_cost = 0.0
        for op, model_name in zip(operations_to_validate, operation_models):
            cost = _estimate_operation_cost(
                pricing_service,
                op['provider'],
                model_name,
                op['tokens_requested']
            )
            total_cost += cost

            op_result = {
                'provider': op['actual_provider_name'],
                'operation_type': op['operation_type'],
                'cost': round(cost, 4),
                'allowed': can_proceed,
                'limit_info': None,
                'message': None
            }

            if limit_blocked:
                # Validation failed: report the usage figures it returned.
                if blocked_usage:
                    used = blocked_usage.get('current_usage', 0) or 0
                    cap = blocked_usage.get('limit', 0) or 0
                    op_result['message'] = message
                    op_result['limit_info'] = {
                        'current_usage': used,
                        'limit': cap,
                        'remaining': max(0, cap - used)
                    }
            elif usage_summary:
                # Get current usage for this provider
                field = _usage_field_name(op['provider'])
                current = getattr(usage_summary, field, 0) or 0
                limit = limits['limits'].get(field, 0) or 0
                op_result['limit_info'] = {
                    'current_usage': current,
                    'limit': limit,
                    'remaining': _remaining_quota(limit, current)
                }

            operation_results.append(op_result)

        response_data = {
            'can_proceed': can_proceed,
            'estimated_cost': round(total_cost, 4),
            'operations': operation_results,
            'total_cost': round(total_cost, 4),
            'usage_summary': None,
            'cached': False  # TODO: Track if result was cached
        }

        if usage_summary and limits:
            # The overall summary always reports video limits, regardless of
            # which providers were requested.
            video_current = getattr(usage_summary, 'video_calls', 0) or 0
            video_limit = limits['limits'].get('video_calls', 0) or 0
            response_data['usage_summary'] = {
                'current_calls': video_current,
                'limit': video_limit,
                'remaining': _remaining_quota(video_limit, video_current)
            }

        return {
            "success": True,
            "data": response_data
        }
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback. Passing extra keyword
        # arguments to loguru would instead run the message through
        # str.format, which can fail when the error text contains braces.
        logger.exception(f"Error in pre-flight check: {e}")
        raise HTTPException(status_code=500, detail=f"Pre-flight check failed: {str(e)}") from e