AI story writer enhancements, text to video and voice generation, subscription management, and more.
This commit is contained in:
@@ -4,6 +4,7 @@ Provides endpoints for subscription management and usage monitoring.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, func
|
||||
from typing import Dict, Any, Optional, List
|
||||
@@ -116,6 +117,7 @@ async def get_subscription_plans(
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"video_calls": getattr(plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
|
||||
"audio_calls": getattr(plan, 'audio_calls_limit', 0),
|
||||
"gemini_tokens": plan.gemini_tokens_limit,
|
||||
"openai_tokens": plan.openai_tokens_limit,
|
||||
"anthropic_tokens": plan.anthropic_tokens_limit,
|
||||
@@ -134,7 +136,7 @@ async def get_subscription_plans(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
|
||||
logger.warning("Missing column detected in subscription plans query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
@@ -241,6 +243,7 @@ async def get_user_subscription(
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"audio_calls": getattr(free_plan, 'audio_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -340,6 +343,7 @@ async def get_subscription_status(
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"audio_calls": getattr(free_plan, 'audio_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -405,7 +409,7 @@ async def get_subscription_status(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
|
||||
# Try to fix schema and retry once
|
||||
logger.warning("Missing column detected in subscription status query, attempting schema fix...")
|
||||
try:
|
||||
@@ -499,6 +503,7 @@ async def get_subscription_status(
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"video_calls": getattr(plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
|
||||
"audio_calls": getattr(plan, 'audio_calls_limit', 0),
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -988,7 +993,7 @@ async def get_dashboard_data(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str):
|
||||
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str or 'audio_calls' in error_str or 'audio_cost' in error_str):
|
||||
logger.warning("Missing column detected in dashboard query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
@@ -1271,4 +1276,235 @@ async def get_usage_logs(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage logs: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
|
||||
|
||||
|
||||
class PreflightOperationRequest(BaseModel):
|
||||
"""Request model for pre-flight check operation."""
|
||||
provider: str
|
||||
model: Optional[str] = None
|
||||
tokens_requested: Optional[int] = 0
|
||||
operation_type: str
|
||||
actual_provider_name: Optional[str] = None
|
||||
|
||||
|
||||
class PreflightCheckRequest(BaseModel):
|
||||
"""Request model for pre-flight check."""
|
||||
operations: List[PreflightOperationRequest]
|
||||
|
||||
|
||||
@router.post("/preflight-check")
|
||||
async def preflight_check(
|
||||
request: PreflightCheckRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Pre-flight check for operations with cost estimation.
|
||||
|
||||
Lightweight endpoint that:
|
||||
- Validates if operations are allowed based on subscription limits
|
||||
- Estimates cost for operations
|
||||
- Returns usage information and remaining quota
|
||||
|
||||
Uses caching to minimize DB load (< 100ms with cache hit).
|
||||
"""
|
||||
try:
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
# Ensure schema columns exist
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
except Exception as schema_err:
|
||||
logger.warning(f"Schema check failed: {schema_err}")
|
||||
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Convert request operations to internal format
|
||||
operations_to_validate = []
|
||||
for op in request.operations:
|
||||
try:
|
||||
# Map provider string to APIProvider enum
|
||||
provider_str = op.provider.lower()
|
||||
if provider_str == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
|
||||
elif provider_str == "video":
|
||||
provider_enum = APIProvider.VIDEO
|
||||
elif provider_str == "image_edit":
|
||||
provider_enum = APIProvider.IMAGE_EDIT
|
||||
elif provider_str == "stability":
|
||||
provider_enum = APIProvider.STABILITY
|
||||
elif provider_str == "audio":
|
||||
provider_enum = APIProvider.AUDIO
|
||||
else:
|
||||
try:
|
||||
provider_enum = APIProvider(provider_str)
|
||||
except ValueError:
|
||||
logger.warning(f"Unknown provider: {provider_str}, skipping")
|
||||
continue
|
||||
|
||||
operations_to_validate.append({
|
||||
'provider': provider_enum,
|
||||
'tokens_requested': op.tokens_requested or 0,
|
||||
'actual_provider_name': op.actual_provider_name or op.provider,
|
||||
'operation_type': op.operation_type
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing operation {op.operation_type}: {e}")
|
||||
continue
|
||||
|
||||
if not operations_to_validate:
|
||||
raise HTTPException(status_code=400, detail="No valid operations provided")
|
||||
|
||||
# Perform pre-flight validation
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
# Get pricing and cost estimation for each operation
|
||||
operation_results = []
|
||||
total_cost = 0.0
|
||||
|
||||
for i, op in enumerate(operations_to_validate):
|
||||
op_result = {
|
||||
'provider': op['actual_provider_name'],
|
||||
'operation_type': op['operation_type'],
|
||||
'cost': 0.0,
|
||||
'allowed': can_proceed,
|
||||
'limit_info': None,
|
||||
'message': None
|
||||
}
|
||||
|
||||
# Get pricing for this operation
|
||||
model_name = request.operations[i].model
|
||||
if model_name:
|
||||
pricing_info = pricing_service.get_pricing_for_provider_model(
|
||||
op['provider'],
|
||||
model_name
|
||||
)
|
||||
|
||||
if pricing_info:
|
||||
# Determine cost based on operation type
|
||||
if op['provider'] in [APIProvider.VIDEO, APIProvider.IMAGE_EDIT, APIProvider.STABILITY]:
|
||||
cost = pricing_info.get('cost_per_request', 0.0) or pricing_info.get('cost_per_image', 0.0) or 0.0
|
||||
elif op['provider'] == APIProvider.AUDIO:
|
||||
# Audio pricing is per character (every character is 1 token)
|
||||
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (op['tokens_requested'] / 1000.0)
|
||||
elif op['tokens_requested'] > 0:
|
||||
# Token-based cost estimation (rough estimate)
|
||||
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (op['tokens_requested'] / 1000)
|
||||
else:
|
||||
cost = pricing_info.get('cost_per_request', 0.0) or 0.0
|
||||
|
||||
op_result['cost'] = round(cost, 4)
|
||||
total_cost += cost
|
||||
else:
|
||||
# Use default cost if pricing not found
|
||||
if op['provider'] == APIProvider.VIDEO:
|
||||
op_result['cost'] = 0.10 # Default video cost
|
||||
total_cost += 0.10
|
||||
elif op['provider'] == APIProvider.IMAGE_EDIT:
|
||||
op_result['cost'] = 0.05 # Default image edit cost
|
||||
total_cost += 0.05
|
||||
elif op['provider'] == APIProvider.STABILITY:
|
||||
op_result['cost'] = 0.04 # Default image generation cost
|
||||
total_cost += 0.04
|
||||
elif op['provider'] == APIProvider.AUDIO:
|
||||
# Default audio cost: $0.05 per 1,000 characters
|
||||
cost = (op['tokens_requested'] / 1000.0) * 0.05
|
||||
op_result['cost'] = round(cost, 4)
|
||||
total_cost += cost
|
||||
|
||||
# Get limit information
|
||||
limit_info = None
|
||||
if error_details and not can_proceed:
|
||||
usage_info = error_details.get('usage_info', {})
|
||||
if usage_info:
|
||||
op_result['message'] = message
|
||||
limit_info = {
|
||||
'current_usage': usage_info.get('current_usage', 0),
|
||||
'limit': usage_info.get('limit', 0),
|
||||
'remaining': max(0, usage_info.get('limit', 0) - usage_info.get('current_usage', 0))
|
||||
}
|
||||
op_result['limit_info'] = limit_info
|
||||
else:
|
||||
# Get current usage for this provider
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
if limits:
|
||||
usage_summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
|
||||
).first()
|
||||
|
||||
if usage_summary:
|
||||
if op['provider'] == APIProvider.VIDEO:
|
||||
current = getattr(usage_summary, 'video_calls', 0) or 0
|
||||
limit = limits['limits'].get('video_calls', 0)
|
||||
elif op['provider'] == APIProvider.IMAGE_EDIT:
|
||||
current = getattr(usage_summary, 'image_edit_calls', 0) or 0
|
||||
limit = limits['limits'].get('image_edit_calls', 0)
|
||||
elif op['provider'] == APIProvider.STABILITY:
|
||||
current = getattr(usage_summary, 'stability_calls', 0) or 0
|
||||
limit = limits['limits'].get('stability_calls', 0)
|
||||
elif op['provider'] == APIProvider.AUDIO:
|
||||
current = getattr(usage_summary, 'audio_calls', 0) or 0
|
||||
limit = limits['limits'].get('audio_calls', 0)
|
||||
else:
|
||||
# For LLM providers, use token limits
|
||||
provider_key = op['provider'].value
|
||||
current_tokens = getattr(usage_summary, f"{provider_key}_tokens", 0) or 0
|
||||
limit = limits['limits'].get(f"{provider_key}_tokens", 0)
|
||||
current = current_tokens
|
||||
|
||||
limit_info = {
|
||||
'current_usage': current,
|
||||
'limit': limit,
|
||||
'remaining': max(0, limit - current) if limit > 0 else float('inf')
|
||||
}
|
||||
op_result['limit_info'] = limit_info
|
||||
|
||||
operation_results.append(op_result)
|
||||
|
||||
# Get overall usage summary
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
usage_summary = None
|
||||
if limits:
|
||||
usage_summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
|
||||
).first()
|
||||
|
||||
response_data = {
|
||||
'can_proceed': can_proceed,
|
||||
'estimated_cost': round(total_cost, 4),
|
||||
'operations': operation_results,
|
||||
'total_cost': round(total_cost, 4),
|
||||
'usage_summary': None,
|
||||
'cached': False # TODO: Track if result was cached
|
||||
}
|
||||
|
||||
if usage_summary and limits:
|
||||
# For video generation, show video limits
|
||||
video_current = getattr(usage_summary, 'video_calls', 0) or 0
|
||||
video_limit = limits['limits'].get('video_calls', 0)
|
||||
|
||||
response_data['usage_summary'] = {
|
||||
'current_calls': video_current,
|
||||
'limit': video_limit,
|
||||
'remaining': max(0, video_limit - video_current) if video_limit > 0 else float('inf')
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": response_data
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pre-flight check: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Pre-flight check failed: {str(e)}")
|
||||
Reference in New Issue
Block a user