AI story writer enhancements, text to video and voice generation, subscription management, and more.

This commit is contained in:
ajaysi
2025-11-19 09:55:32 +05:30
parent bf7493c366
commit e96525347b
64 changed files with 10367 additions and 400 deletions

View File

@@ -4,6 +4,7 @@ Provides endpoints for subscription management and usage monitoring.
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy import desc, func
from typing import Dict, Any, Optional, List
@@ -116,6 +117,7 @@ async def get_subscription_plans(
"stability_calls": plan.stability_calls_limit,
"video_calls": getattr(plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
"audio_calls": getattr(plan, 'audio_calls_limit', 0),
"gemini_tokens": plan.gemini_tokens_limit,
"openai_tokens": plan.openai_tokens_limit,
"anthropic_tokens": plan.anthropic_tokens_limit,
@@ -134,7 +136,7 @@ async def get_subscription_plans(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
logger.warning("Missing column detected in subscription plans query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
@@ -241,6 +243,7 @@ async def get_user_subscription(
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"audio_calls": getattr(free_plan, 'audio_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
@@ -340,6 +343,7 @@ async def get_subscription_status(
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"audio_calls": getattr(free_plan, 'audio_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
@@ -405,7 +409,7 @@ async def get_subscription_status(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
# Try to fix schema and retry once
logger.warning("Missing column detected in subscription status query, attempting schema fix...")
try:
@@ -499,6 +503,7 @@ async def get_subscription_status(
"stability_calls": plan.stability_calls_limit,
"video_calls": getattr(plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
"audio_calls": getattr(plan, 'audio_calls_limit', 0),
"monthly_cost": plan.monthly_cost_limit
}
}
@@ -988,7 +993,7 @@ async def get_dashboard_data(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str):
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str or 'audio_calls' in error_str or 'audio_cost' in error_str):
logger.warning("Missing column detected in dashboard query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
@@ -1271,4 +1276,235 @@ async def get_usage_logs(
raise
except Exception as e:
logger.error(f"Error getting usage logs: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
class PreflightOperationRequest(BaseModel):
"""Request model for pre-flight check operation."""
provider: str
model: Optional[str] = None
tokens_requested: Optional[int] = 0
operation_type: str
actual_provider_name: Optional[str] = None
class PreflightCheckRequest(BaseModel):
"""Request model for pre-flight check."""
operations: List[PreflightOperationRequest]
@router.post("/preflight-check")
async def preflight_check(
request: PreflightCheckRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
Pre-flight check for operations with cost estimation.
Lightweight endpoint that:
- Validates if operations are allowed based on subscription limits
- Estimates cost for operations
- Returns usage information and remaining quota
Uses caching to minimize DB load (< 100ms with cache hit).
"""
try:
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Ensure schema columns exist
try:
ensure_subscription_plan_columns(db)
ensure_usage_summaries_columns(db)
except Exception as schema_err:
logger.warning(f"Schema check failed: {schema_err}")
pricing_service = PricingService(db)
# Convert request operations to internal format
operations_to_validate = []
for op in request.operations:
try:
# Map provider string to APIProvider enum
provider_str = op.provider.lower()
if provider_str == "huggingface":
provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
elif provider_str == "video":
provider_enum = APIProvider.VIDEO
elif provider_str == "image_edit":
provider_enum = APIProvider.IMAGE_EDIT
elif provider_str == "stability":
provider_enum = APIProvider.STABILITY
elif provider_str == "audio":
provider_enum = APIProvider.AUDIO
else:
try:
provider_enum = APIProvider(provider_str)
except ValueError:
logger.warning(f"Unknown provider: {provider_str}, skipping")
continue
operations_to_validate.append({
'provider': provider_enum,
'tokens_requested': op.tokens_requested or 0,
'actual_provider_name': op.actual_provider_name or op.provider,
'operation_type': op.operation_type
})
except Exception as e:
logger.warning(f"Error processing operation {op.operation_type}: {e}")
continue
if not operations_to_validate:
raise HTTPException(status_code=400, detail="No valid operations provided")
# Perform pre-flight validation
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
# Get pricing and cost estimation for each operation
operation_results = []
total_cost = 0.0
for i, op in enumerate(operations_to_validate):
op_result = {
'provider': op['actual_provider_name'],
'operation_type': op['operation_type'],
'cost': 0.0,
'allowed': can_proceed,
'limit_info': None,
'message': None
}
# Get pricing for this operation
model_name = request.operations[i].model
if model_name:
pricing_info = pricing_service.get_pricing_for_provider_model(
op['provider'],
model_name
)
if pricing_info:
# Determine cost based on operation type
if op['provider'] in [APIProvider.VIDEO, APIProvider.IMAGE_EDIT, APIProvider.STABILITY]:
cost = pricing_info.get('cost_per_request', 0.0) or pricing_info.get('cost_per_image', 0.0) or 0.0
elif op['provider'] == APIProvider.AUDIO:
# Audio pricing is per character (every character is 1 token)
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (op['tokens_requested'] / 1000.0)
elif op['tokens_requested'] > 0:
# Token-based cost estimation (rough estimate)
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (op['tokens_requested'] / 1000)
else:
cost = pricing_info.get('cost_per_request', 0.0) or 0.0
op_result['cost'] = round(cost, 4)
total_cost += cost
else:
# Use default cost if pricing not found
if op['provider'] == APIProvider.VIDEO:
op_result['cost'] = 0.10 # Default video cost
total_cost += 0.10
elif op['provider'] == APIProvider.IMAGE_EDIT:
op_result['cost'] = 0.05 # Default image edit cost
total_cost += 0.05
elif op['provider'] == APIProvider.STABILITY:
op_result['cost'] = 0.04 # Default image generation cost
total_cost += 0.04
elif op['provider'] == APIProvider.AUDIO:
# Default audio cost: $0.05 per 1,000 characters
cost = (op['tokens_requested'] / 1000.0) * 0.05
op_result['cost'] = round(cost, 4)
total_cost += cost
# Get limit information
limit_info = None
if error_details and not can_proceed:
usage_info = error_details.get('usage_info', {})
if usage_info:
op_result['message'] = message
limit_info = {
'current_usage': usage_info.get('current_usage', 0),
'limit': usage_info.get('limit', 0),
'remaining': max(0, usage_info.get('limit', 0) - usage_info.get('current_usage', 0))
}
op_result['limit_info'] = limit_info
else:
# Get current usage for this provider
limits = pricing_service.get_user_limits(user_id)
if limits:
usage_summary = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
).first()
if usage_summary:
if op['provider'] == APIProvider.VIDEO:
current = getattr(usage_summary, 'video_calls', 0) or 0
limit = limits['limits'].get('video_calls', 0)
elif op['provider'] == APIProvider.IMAGE_EDIT:
current = getattr(usage_summary, 'image_edit_calls', 0) or 0
limit = limits['limits'].get('image_edit_calls', 0)
elif op['provider'] == APIProvider.STABILITY:
current = getattr(usage_summary, 'stability_calls', 0) or 0
limit = limits['limits'].get('stability_calls', 0)
elif op['provider'] == APIProvider.AUDIO:
current = getattr(usage_summary, 'audio_calls', 0) or 0
limit = limits['limits'].get('audio_calls', 0)
else:
# For LLM providers, use token limits
provider_key = op['provider'].value
current_tokens = getattr(usage_summary, f"{provider_key}_tokens", 0) or 0
limit = limits['limits'].get(f"{provider_key}_tokens", 0)
current = current_tokens
limit_info = {
'current_usage': current,
'limit': limit,
'remaining': max(0, limit - current) if limit > 0 else float('inf')
}
op_result['limit_info'] = limit_info
operation_results.append(op_result)
# Get overall usage summary
limits = pricing_service.get_user_limits(user_id)
usage_summary = None
if limits:
usage_summary = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
).first()
response_data = {
'can_proceed': can_proceed,
'estimated_cost': round(total_cost, 4),
'operations': operation_results,
'total_cost': round(total_cost, 4),
'usage_summary': None,
'cached': False # TODO: Track if result was cached
}
if usage_summary and limits:
# For video generation, show video limits
video_current = getattr(usage_summary, 'video_calls', 0) or 0
video_limit = limits['limits'].get('video_calls', 0)
response_data['usage_summary'] = {
'current_calls': video_current,
'limit': video_limit,
'remaining': max(0, video_limit - video_current) if video_limit > 0 else float('inf')
}
return {
"success": True,
"data": response_data
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in pre-flight check: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Pre-flight check failed: {str(e)}")