AI Analysis and Content Strategy fixes. Enhanced Strategy Routes refactoring.
This commit is contained in:
30
backend/api/subscription/__init__.py
Normal file
30
backend/api/subscription/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
Subscription API Module
|
||||
Main router that includes all subscription-related endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .routes import (
|
||||
usage,
|
||||
plans,
|
||||
subscriptions,
|
||||
alerts,
|
||||
dashboard,
|
||||
logs,
|
||||
preflight
|
||||
)
|
||||
|
||||
# Create main router
|
||||
router = APIRouter(prefix="/api/subscription", tags=["subscription"])
|
||||
|
||||
# Include all sub-routers
|
||||
router.include_router(usage.router, tags=["subscription"])
|
||||
router.include_router(plans.router, tags=["subscription"])
|
||||
router.include_router(subscriptions.router, tags=["subscription"])
|
||||
router.include_router(alerts.router, tags=["subscription"])
|
||||
router.include_router(dashboard.router, tags=["subscription"])
|
||||
router.include_router(logs.router, tags=["subscription"])
|
||||
router.include_router(preflight.router, tags=["subscription"])
|
||||
|
||||
__all__ = ["router"]
|
||||
68
backend/api/subscription/cache.py
Normal file
68
backend/api/subscription/cache.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Cache management for subscription API endpoints.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
import time
|
||||
import os
|
||||
|
||||
|
||||
# Simple in-process cache for dashboard responses to smooth bursts
|
||||
# Cache key: (user_id). TTL-like behavior implemented via timestamp check
|
||||
_dashboard_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_dashboard_cache_ts: Dict[str, float] = {}
|
||||
_DASHBOARD_CACHE_TTL_SEC = 600.0
|
||||
|
||||
|
||||
def get_cached_dashboard(user_id: str) -> Dict[str, Any] | None:
|
||||
"""
|
||||
Get cached dashboard data if available and not expired.
|
||||
|
||||
Args:
|
||||
user_id: User ID to get cached data for
|
||||
|
||||
Returns:
|
||||
Cached dashboard data or None if not cached/expired
|
||||
"""
|
||||
# Check if caching is disabled via environment variable
|
||||
nocache = False
|
||||
try:
|
||||
nocache = os.getenv('SUBSCRIPTION_DASHBOARD_NOCACHE', 'false').lower() in {'1', 'true', 'yes', 'on'}
|
||||
except Exception:
|
||||
nocache = False
|
||||
|
||||
if nocache:
|
||||
return None
|
||||
|
||||
now = time.time()
|
||||
if user_id in _dashboard_cache and (now - _dashboard_cache_ts.get(user_id, 0)) < _DASHBOARD_CACHE_TTL_SEC:
|
||||
return _dashboard_cache[user_id]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def set_cached_dashboard(user_id: str, data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Cache dashboard data for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID to cache data for
|
||||
data: Dashboard data to cache
|
||||
"""
|
||||
_dashboard_cache[user_id] = data
|
||||
_dashboard_cache_ts[user_id] = time.time()
|
||||
|
||||
|
||||
def clear_dashboard_cache(user_id: str | None = None) -> None:
|
||||
"""
|
||||
Clear dashboard cache for a specific user or all users.
|
||||
|
||||
Args:
|
||||
user_id: User ID to clear cache for, or None to clear all
|
||||
"""
|
||||
if user_id:
|
||||
_dashboard_cache.pop(user_id, None)
|
||||
_dashboard_cache_ts.pop(user_id, None)
|
||||
else:
|
||||
_dashboard_cache.clear()
|
||||
_dashboard_cache_ts.clear()
|
||||
84
backend/api/subscription/dependencies.py
Normal file
84
backend/api/subscription/dependencies.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Shared dependencies for subscription API routes.
|
||||
"""
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any
|
||||
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.subscription.schema_utils import (
|
||||
ensure_subscription_plan_columns,
|
||||
ensure_usage_summaries_columns,
|
||||
ensure_api_usage_logs_columns
|
||||
)
|
||||
|
||||
|
||||
def verify_user_access(
|
||||
user_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> str:
|
||||
"""
|
||||
Verify that the current user can only access their own data.
|
||||
|
||||
Args:
|
||||
user_id: The user ID from the route parameter
|
||||
current_user: The authenticated user from the token
|
||||
|
||||
Returns:
|
||||
The verified user_id
|
||||
|
||||
Raises:
|
||||
HTTPException: If user tries to access another user's data
|
||||
"""
|
||||
if current_user.get('id') != user_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
return user_id
|
||||
|
||||
|
||||
def get_user_id_from_token(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> str:
|
||||
"""
|
||||
Extract user ID from authentication token.
|
||||
|
||||
Args:
|
||||
current_user: The authenticated user from the token
|
||||
|
||||
Returns:
|
||||
The user ID as a string
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not authenticated
|
||||
"""
|
||||
user_id = str(current_user.get('id', '')) if current_user else None
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
return user_id
|
||||
|
||||
|
||||
def ensure_schema_columns(
|
||||
db: Session = Depends(get_db),
|
||||
include_usage_logs: bool = False
|
||||
) -> Session:
|
||||
"""
|
||||
Ensure required schema columns exist before queries.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
include_usage_logs: Whether to check api_usage_logs columns
|
||||
|
||||
Returns:
|
||||
Database session
|
||||
"""
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
if include_usage_logs:
|
||||
ensure_api_usage_logs_columns(db)
|
||||
except Exception as schema_err:
|
||||
# Log warning but don't fail - will be caught by error handlers
|
||||
from loguru import logger
|
||||
logger.warning(f"Schema check failed, will retry on query: {schema_err}")
|
||||
return db
|
||||
20
backend/api/subscription/models.py
Normal file
20
backend/api/subscription/models.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Pydantic models for subscription API requests/responses.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
class PreflightOperationRequest(BaseModel):
|
||||
"""Request model for pre-flight check operation."""
|
||||
provider: str
|
||||
model: Optional[str] = None
|
||||
tokens_requested: Optional[int] = 0
|
||||
operation_type: str
|
||||
actual_provider_name: Optional[str] = None
|
||||
|
||||
|
||||
class PreflightCheckRequest(BaseModel):
|
||||
"""Request model for pre-flight check."""
|
||||
operations: List[PreflightOperationRequest]
|
||||
8
backend/api/subscription/routes/__init__.py
Normal file
8
backend/api/subscription/routes/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Subscription API Routes
|
||||
All route modules are imported here for easy access.
|
||||
"""
|
||||
|
||||
from . import usage, plans, subscriptions, alerts, dashboard, logs, preflight
|
||||
|
||||
__all__ = ["usage", "plans", "subscriptions", "alerts", "dashboard", "logs", "preflight"]
|
||||
94
backend/api/subscription/routes/alerts.py
Normal file
94
backend/api/subscription/routes/alerts.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Usage alerts endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
|
||||
from services.database import get_db
|
||||
from models.subscription_models import UsageAlert
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/alerts/{user_id}")
|
||||
async def get_usage_alerts(
|
||||
user_id: str,
|
||||
unread_only: bool = Query(False, description="Only return unread alerts"),
|
||||
limit: int = Query(50, ge=1, le=100, description="Maximum number of alerts"),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get usage alerts for a user."""
|
||||
|
||||
try:
|
||||
query = db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id
|
||||
)
|
||||
|
||||
if unread_only:
|
||||
query = query.filter(UsageAlert.is_read == False)
|
||||
|
||||
alerts = query.order_by(
|
||||
UsageAlert.created_at.desc()
|
||||
).limit(limit).all()
|
||||
|
||||
alerts_data = []
|
||||
for alert in alerts:
|
||||
alerts_data.append({
|
||||
"id": alert.id,
|
||||
"type": alert.alert_type,
|
||||
"threshold_percentage": alert.threshold_percentage,
|
||||
"provider": alert.provider.value if alert.provider else None,
|
||||
"title": alert.title,
|
||||
"message": alert.message,
|
||||
"severity": alert.severity,
|
||||
"is_sent": alert.is_sent,
|
||||
"sent_at": alert.sent_at.isoformat() if alert.sent_at else None,
|
||||
"is_read": alert.is_read,
|
||||
"read_at": alert.read_at.isoformat() if alert.read_at else None,
|
||||
"billing_period": alert.billing_period,
|
||||
"created_at": alert.created_at.isoformat()
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"alerts": alerts_data,
|
||||
"total": len(alerts_data),
|
||||
"unread_count": len([a for a in alerts_data if not a["is_read"]])
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage alerts: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/alerts/{alert_id}/mark-read")
|
||||
async def mark_alert_read(
|
||||
alert_id: int,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Mark an alert as read."""
|
||||
|
||||
try:
|
||||
alert = db.query(UsageAlert).filter(UsageAlert.id == alert_id).first()
|
||||
|
||||
if not alert:
|
||||
raise HTTPException(status_code=404, detail="Alert not found")
|
||||
|
||||
alert.is_read = True
|
||||
alert.read_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Alert marked as read"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error marking alert as read: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
170
backend/api/subscription/routes/dashboard.py
Normal file
170
backend/api/subscription/routes/dashboard.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Dashboard endpoints for comprehensive usage monitoring.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
import sqlite3
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
||||
from models.subscription_models import UsageAlert
|
||||
from ..cache import get_cached_dashboard, set_cached_dashboard
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/dashboard/{user_id}")
|
||||
async def get_dashboard_data(
|
||||
user_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comprehensive dashboard data for usage monitoring."""
|
||||
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
|
||||
# Check cache first
|
||||
cached_data = get_cached_dashboard(user_id)
|
||||
if cached_data:
|
||||
return cached_data
|
||||
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Get current usage stats
|
||||
current_usage = usage_service.get_user_usage_stats(user_id)
|
||||
|
||||
# Get usage trends (last 6 months)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
|
||||
# Get user limits
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Get unread alerts
|
||||
alerts = db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.is_read == False
|
||||
).order_by(UsageAlert.created_at.desc()).limit(5).all()
|
||||
|
||||
alerts_data = [
|
||||
{
|
||||
"id": alert.id,
|
||||
"type": alert.alert_type,
|
||||
"title": alert.title,
|
||||
"message": alert.message,
|
||||
"severity": alert.severity,
|
||||
"created_at": alert.created_at.isoformat()
|
||||
}
|
||||
for alert in alerts
|
||||
]
|
||||
|
||||
# Calculate cost projections
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"current_usage": current_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
"projections": {
|
||||
"projected_monthly_cost": round(projected_cost, 2),
|
||||
"cost_limit": limits.get('limits', {}).get('monthly_cost', 0) if limits else 0,
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Cache the response
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
return response_payload
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str or 'audio_calls' in error_str or 'audio_cost' in error_str):
|
||||
logger.warning("Missing column detected in dashboard query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
schema_utils._checked_usage_summaries_columns = False
|
||||
schema_utils._checked_subscription_plan_columns = False
|
||||
# Use the already imported functions from top of file
|
||||
ensure_usage_summaries_columns(db)
|
||||
ensure_subscription_plan_columns(db)
|
||||
db.expire_all()
|
||||
|
||||
# Retry the query
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
current_usage = usage_service.get_user_usage_stats(user_id)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
alerts = db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.is_read == False
|
||||
).order_by(UsageAlert.created_at.desc()).limit(5).all()
|
||||
|
||||
alerts_data = [
|
||||
{
|
||||
"id": alert.id,
|
||||
"type": alert.alert_type,
|
||||
"title": alert.title,
|
||||
"message": alert.message,
|
||||
"severity": alert.severity,
|
||||
"created_at": alert.created_at.isoformat()
|
||||
}
|
||||
for alert in alerts
|
||||
]
|
||||
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"current_usage": current_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
"projections": {
|
||||
"projected_monthly_cost": round(projected_cost, 2),
|
||||
"cost_limit": limits.get('limits', {}).get('monthly_cost', 0) if limits else 0,
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Cache the response after successful retry
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
return response_payload
|
||||
except Exception as retry_err:
|
||||
logger.error(f"Schema fix and retry failed: {retry_err}")
|
||||
raise HTTPException(status_code=500, detail=f"Database schema error: {str(e)}")
|
||||
|
||||
logger.error(f"Error getting dashboard data: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
198
backend/api/subscription/routes/logs.py
Normal file
198
backend/api/subscription/routes/logs.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
API usage logs endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
import sqlite3
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription.log_wrapping_service import LogWrappingService
|
||||
from services.subscription.schema_utils import ensure_api_usage_logs_columns
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.subscription_models import APIProvider, APIUsageLog
|
||||
from ..dependencies import get_user_id_from_token
|
||||
from ..utils import handle_schema_error
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/usage-logs")
|
||||
async def get_usage_logs(
|
||||
limit: int = Query(50, ge=1, le=5000, description="Number of logs to return"),
|
||||
offset: int = Query(0, ge=0, description="Pagination offset"),
|
||||
provider: Optional[str] = Query(None, description="Filter by provider"),
|
||||
status_code: Optional[int] = Query(None, description="Filter by HTTP status code"),
|
||||
billing_period: Optional[str] = Query(None, description="Filter by billing period (YYYY-MM)"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get API usage logs for the current user.
|
||||
|
||||
Query Params:
|
||||
- limit: Number of logs to return (1-5000, default: 50)
|
||||
- offset: Pagination offset (default: 0)
|
||||
- provider: Filter by provider (e.g., "gemini", "openai", "huggingface")
|
||||
- status_code: Filter by HTTP status code (e.g., 200 for success, 400+ for errors)
|
||||
- billing_period: Filter by billing period (YYYY-MM format)
|
||||
|
||||
Returns:
|
||||
- List of usage logs with API call details
|
||||
- Total count for pagination
|
||||
"""
|
||||
try:
|
||||
# Get user_id from current_user
|
||||
user_id = get_user_id_from_token(current_user)
|
||||
|
||||
# Ensure schema columns exist (especially actual_provider_name)
|
||||
ensure_api_usage_logs_columns(db)
|
||||
|
||||
# Build query
|
||||
query = db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if provider:
|
||||
provider_lower = provider.lower()
|
||||
# Handle special case: huggingface maps to MISTRAL enum in database
|
||||
if provider_lower == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
else:
|
||||
try:
|
||||
provider_enum = APIProvider(provider_lower)
|
||||
except ValueError:
|
||||
# Invalid provider, return empty results
|
||||
return {
|
||||
"logs": [],
|
||||
"total_count": 0,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": False
|
||||
}
|
||||
query = query.filter(APIUsageLog.provider == provider_enum)
|
||||
|
||||
if status_code is not None:
|
||||
query = query.filter(APIUsageLog.status_code == status_code)
|
||||
|
||||
if billing_period:
|
||||
query = query.filter(APIUsageLog.billing_period == billing_period)
|
||||
|
||||
# Check and wrap logs if necessary (before getting count)
|
||||
wrapping_service = LogWrappingService(db)
|
||||
wrap_result = wrapping_service.check_and_wrap_logs(user_id)
|
||||
if wrap_result.get('wrapped'):
|
||||
logger.info(f"[UsageLogs] Log wrapping completed for user {user_id}: {wrap_result.get('message')}")
|
||||
# Rebuild query after wrapping (in case filters changed)
|
||||
query = db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id
|
||||
)
|
||||
# Reapply filters
|
||||
if provider:
|
||||
provider_lower = provider.lower()
|
||||
if provider_lower == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
else:
|
||||
try:
|
||||
provider_enum = APIProvider(provider_lower)
|
||||
except ValueError:
|
||||
return {
|
||||
"logs": [],
|
||||
"total_count": 0,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": False
|
||||
}
|
||||
query = query.filter(APIUsageLog.provider == provider_enum)
|
||||
if status_code is not None:
|
||||
query = query.filter(APIUsageLog.status_code == status_code)
|
||||
if billing_period:
|
||||
query = query.filter(APIUsageLog.billing_period == billing_period)
|
||||
|
||||
# Get total count
|
||||
total_count = query.count()
|
||||
|
||||
# Get paginated results, ordered by timestamp descending (most recent first)
|
||||
logs = query.order_by(desc(APIUsageLog.timestamp)).offset(offset).limit(limit).all()
|
||||
|
||||
# Format logs for response
|
||||
formatted_logs = []
|
||||
for log in logs:
|
||||
# Determine status based on status_code
|
||||
status = 'success' if 200 <= log.status_code < 300 else 'failed'
|
||||
|
||||
# Handle provider display name - use actual_provider_name if available, otherwise detect from model/endpoint
|
||||
# This correctly identifies WaveSpeed, Google, HuggingFace, etc. instead of generic VIDEO/AUDIO/STABILITY
|
||||
provider_display = None
|
||||
actual_provider_name = None
|
||||
|
||||
# Safely get actual_provider_name (column may not exist yet)
|
||||
try:
|
||||
actual_provider_name = getattr(log, 'actual_provider_name', None)
|
||||
except (AttributeError, KeyError):
|
||||
actual_provider_name = None
|
||||
|
||||
if actual_provider_name:
|
||||
# Use the actual provider name (WaveSpeed, Google, HuggingFace, etc.)
|
||||
provider_display = actual_provider_name
|
||||
else:
|
||||
# For old logs without actual_provider_name, detect from model name and endpoint
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
provider_display = detect_actual_provider(
|
||||
provider_enum=log.provider,
|
||||
model_name=log.model_used,
|
||||
endpoint=log.endpoint
|
||||
)
|
||||
# Special handling for MISTRAL (HuggingFace)
|
||||
if provider_display == "mistral":
|
||||
provider_display = "huggingface"
|
||||
|
||||
formatted_logs.append({
|
||||
'id': log.id,
|
||||
'timestamp': log.timestamp.isoformat() if log.timestamp else None,
|
||||
'provider': provider_display,
|
||||
'actual_provider_name': actual_provider_name, # Include for frontend use
|
||||
'model_used': log.model_used,
|
||||
'endpoint': log.endpoint,
|
||||
'method': log.method,
|
||||
'tokens_input': log.tokens_input or 0,
|
||||
'tokens_output': log.tokens_output or 0,
|
||||
'tokens_total': log.tokens_total or 0,
|
||||
'cost_input': float(log.cost_input) if log.cost_input else 0.0,
|
||||
'cost_output': float(log.cost_output) if log.cost_output else 0.0,
|
||||
'cost_total': float(log.cost_total) if log.cost_total else 0.0,
|
||||
'response_time': float(log.response_time) if log.response_time else 0.0,
|
||||
'status_code': log.status_code,
|
||||
'status': status,
|
||||
'error_message': log.error_message,
|
||||
'billing_period': log.billing_period,
|
||||
'retry_count': log.retry_count or 0,
|
||||
'is_aggregated': log.endpoint == "[AGGREGATED]" # Flag to indicate aggregated log
|
||||
})
|
||||
|
||||
return {
|
||||
"logs": formatted_logs,
|
||||
"total_count": total_count,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": (offset + limit) < total_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and 'actual_provider_name' in error_str:
|
||||
return handle_schema_error(
|
||||
e,
|
||||
db,
|
||||
error_str,
|
||||
lambda: get_usage_logs(limit, offset, provider, status_code, billing_period, current_user, db)
|
||||
)
|
||||
|
||||
logger.error(f"Error getting usage logs: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
|
||||
120
backend/api/subscription/routes/plans.py
Normal file
120
backend/api/subscription/routes/plans.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Subscription plans endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
import sqlite3
|
||||
|
||||
from services.database import get_db
|
||||
from models.subscription_models import SubscriptionPlan
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns
|
||||
from ..utils import format_plan_limits, handle_schema_error
|
||||
from fastapi import Query
|
||||
from typing import Optional
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/plans")
|
||||
async def get_subscription_plans(
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get all available subscription plans."""
|
||||
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
except Exception as schema_err:
|
||||
logger.warning(f"Schema check failed, will retry on query: {schema_err}")
|
||||
|
||||
try:
|
||||
plans = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.is_active == True
|
||||
).order_by(SubscriptionPlan.price_monthly).all()
|
||||
|
||||
plans_data = []
|
||||
for plan in plans:
|
||||
plans_data.append({
|
||||
"id": plan.id,
|
||||
"name": plan.name,
|
||||
"tier": plan.tier.value,
|
||||
"price_monthly": plan.price_monthly,
|
||||
"price_yearly": plan.price_yearly,
|
||||
"description": plan.description,
|
||||
"features": plan.features or [],
|
||||
"limits": format_plan_limits(plan)
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"plans": plans_data,
|
||||
"total": len(plans_data)
|
||||
}
|
||||
}
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
|
||||
return handle_schema_error(
|
||||
e,
|
||||
db,
|
||||
error_str,
|
||||
lambda: get_subscription_plans(db)
|
||||
)
|
||||
|
||||
logger.error(f"Error getting subscription plans: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/pricing")
|
||||
async def get_api_pricing(
|
||||
provider: Optional[str] = Query(None, description="API provider"),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get API pricing information."""
|
||||
|
||||
try:
|
||||
from models.subscription_models import APIProvider, APIProviderPricing
|
||||
|
||||
query = db.query(APIProviderPricing).filter(
|
||||
APIProviderPricing.is_active == True
|
||||
)
|
||||
|
||||
if provider:
|
||||
try:
|
||||
api_provider = APIProvider(provider.lower())
|
||||
query = query.filter(APIProviderPricing.provider == api_provider)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid provider: {provider}")
|
||||
|
||||
pricing_data = query.all()
|
||||
|
||||
pricing_list = []
|
||||
for pricing in pricing_data:
|
||||
pricing_list.append({
|
||||
"provider": pricing.provider.value,
|
||||
"model_name": pricing.model_name,
|
||||
"cost_per_input_token": pricing.cost_per_input_token,
|
||||
"cost_per_output_token": pricing.cost_per_output_token,
|
||||
"cost_per_request": pricing.cost_per_request,
|
||||
"cost_per_search": pricing.cost_per_search,
|
||||
"cost_per_image": pricing.cost_per_image,
|
||||
"cost_per_page": pricing.cost_per_page,
|
||||
"description": pricing.description,
|
||||
"effective_date": pricing.effective_date.isoformat()
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"pricing": pricing_list,
|
||||
"total": len(pricing_list)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API pricing: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
233
backend/api/subscription/routes/preflight.py
Normal file
233
backend/api/subscription/routes/preflight.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Pre-flight check endpoints for operation validation and cost estimation.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.subscription_models import APIProvider, UsageSummary
|
||||
from ..dependencies import get_user_id_from_token
|
||||
from ..models import PreflightCheckRequest
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/preflight-check")
|
||||
async def preflight_check(
|
||||
request: PreflightCheckRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Pre-flight check for operations with cost estimation.
|
||||
|
||||
Lightweight endpoint that:
|
||||
- Validates if operations are allowed based on subscription limits
|
||||
- Estimates cost for operations
|
||||
- Returns usage information and remaining quota
|
||||
|
||||
Uses caching to minimize DB load (< 100ms with cache hit).
|
||||
"""
|
||||
try:
|
||||
user_id = get_user_id_from_token(current_user)
|
||||
|
||||
# Ensure schema columns exist
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
except Exception as schema_err:
|
||||
logger.warning(f"Schema check failed: {schema_err}")
|
||||
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Convert request operations to internal format
|
||||
operations_to_validate = []
|
||||
for op in request.operations:
|
||||
try:
|
||||
# Map provider string to APIProvider enum
|
||||
provider_str = op.provider.lower()
|
||||
if provider_str == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
|
||||
elif provider_str == "video":
|
||||
provider_enum = APIProvider.VIDEO
|
||||
elif provider_str == "image_edit":
|
||||
provider_enum = APIProvider.IMAGE_EDIT
|
||||
elif provider_str == "stability":
|
||||
provider_enum = APIProvider.STABILITY
|
||||
elif provider_str == "audio":
|
||||
provider_enum = APIProvider.AUDIO
|
||||
else:
|
||||
try:
|
||||
provider_enum = APIProvider(provider_str)
|
||||
except ValueError:
|
||||
logger.warning(f"Unknown provider: {provider_str}, skipping")
|
||||
continue
|
||||
|
||||
operations_to_validate.append({
|
||||
'provider': provider_enum,
|
||||
'tokens_requested': op.tokens_requested or 0,
|
||||
'actual_provider_name': op.actual_provider_name or op.provider,
|
||||
'operation_type': op.operation_type
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing operation {op.operation_type}: {e}")
|
||||
continue
|
||||
|
||||
if not operations_to_validate:
|
||||
raise HTTPException(status_code=400, detail="No valid operations provided")
|
||||
|
||||
# Perform pre-flight validation
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
# Get pricing and cost estimation for each operation
|
||||
operation_results = []
|
||||
total_cost = 0.0
|
||||
|
||||
for i, op in enumerate(operations_to_validate):
|
||||
op_result = {
|
||||
'provider': op['actual_provider_name'],
|
||||
'operation_type': op['operation_type'],
|
||||
'cost': 0.0,
|
||||
'allowed': can_proceed,
|
||||
'limit_info': None,
|
||||
'message': None
|
||||
}
|
||||
|
||||
# Get pricing for this operation
|
||||
model_name = request.operations[i].model
|
||||
if model_name:
|
||||
pricing_info = pricing_service.get_pricing_for_provider_model(
|
||||
op['provider'],
|
||||
model_name
|
||||
)
|
||||
|
||||
if pricing_info:
|
||||
# Determine cost based on operation type
|
||||
if op['provider'] in [APIProvider.VIDEO, APIProvider.IMAGE_EDIT, APIProvider.STABILITY]:
|
||||
cost = pricing_info.get('cost_per_request', 0.0) or pricing_info.get('cost_per_image', 0.0) or 0.0
|
||||
elif op['provider'] == APIProvider.AUDIO:
|
||||
# Audio pricing is per character (every character is 1 token)
|
||||
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (op['tokens_requested'] / 1000.0)
|
||||
elif op['tokens_requested'] > 0:
|
||||
# Token-based cost estimation (rough estimate)
|
||||
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (op['tokens_requested'] / 1000)
|
||||
else:
|
||||
cost = pricing_info.get('cost_per_request', 0.0) or 0.0
|
||||
|
||||
op_result['cost'] = round(cost, 4)
|
||||
total_cost += cost
|
||||
else:
|
||||
# Use default cost if pricing not found
|
||||
if op['provider'] == APIProvider.VIDEO:
|
||||
op_result['cost'] = 0.10 # Default video cost
|
||||
total_cost += 0.10
|
||||
elif op['provider'] == APIProvider.IMAGE_EDIT:
|
||||
op_result['cost'] = 0.05 # Default image edit cost
|
||||
total_cost += 0.05
|
||||
elif op['provider'] == APIProvider.STABILITY:
|
||||
op_result['cost'] = 0.04 # Default image generation cost
|
||||
total_cost += 0.04
|
||||
elif op['provider'] == APIProvider.AUDIO:
|
||||
# Default audio cost: $0.05 per 1,000 characters
|
||||
cost = (op['tokens_requested'] / 1000.0) * 0.05
|
||||
op_result['cost'] = round(cost, 4)
|
||||
total_cost += cost
|
||||
|
||||
# Get limit information
|
||||
limit_info = None
|
||||
if error_details and not can_proceed:
|
||||
usage_info = error_details.get('usage_info', {})
|
||||
if usage_info:
|
||||
op_result['message'] = message
|
||||
limit_info = {
|
||||
'current_usage': usage_info.get('current_usage', 0),
|
||||
'limit': usage_info.get('limit', 0),
|
||||
'remaining': max(0, usage_info.get('limit', 0) - usage_info.get('current_usage', 0))
|
||||
}
|
||||
op_result['limit_info'] = limit_info
|
||||
else:
|
||||
# Get current usage for this provider
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
if limits:
|
||||
usage_summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
|
||||
).first()
|
||||
|
||||
if usage_summary:
|
||||
if op['provider'] == APIProvider.VIDEO:
|
||||
current = getattr(usage_summary, 'video_calls', 0) or 0
|
||||
limit = limits['limits'].get('video_calls', 0)
|
||||
elif op['provider'] == APIProvider.IMAGE_EDIT:
|
||||
current = getattr(usage_summary, 'image_edit_calls', 0) or 0
|
||||
limit = limits['limits'].get('image_edit_calls', 0)
|
||||
elif op['provider'] == APIProvider.STABILITY:
|
||||
current = getattr(usage_summary, 'stability_calls', 0) or 0
|
||||
limit = limits['limits'].get('stability_calls', 0)
|
||||
elif op['provider'] == APIProvider.AUDIO:
|
||||
current = getattr(usage_summary, 'audio_calls', 0) or 0
|
||||
limit = limits['limits'].get('audio_calls', 0)
|
||||
else:
|
||||
# For LLM providers, use token limits
|
||||
provider_key = op['provider'].value
|
||||
current_tokens = getattr(usage_summary, f"{provider_key}_tokens", 0) or 0
|
||||
limit = limits['limits'].get(f"{provider_key}_tokens", 0)
|
||||
current = current_tokens
|
||||
|
||||
limit_info = {
|
||||
'current_usage': current,
|
||||
'limit': limit,
|
||||
'remaining': max(0, limit - current) if limit > 0 else float('inf')
|
||||
}
|
||||
op_result['limit_info'] = limit_info
|
||||
|
||||
operation_results.append(op_result)
|
||||
|
||||
# Get overall usage summary
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
usage_summary = None
|
||||
if limits:
|
||||
usage_summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
|
||||
).first()
|
||||
|
||||
response_data = {
|
||||
'can_proceed': can_proceed,
|
||||
'estimated_cost': round(total_cost, 4),
|
||||
'operations': operation_results,
|
||||
'total_cost': round(total_cost, 4),
|
||||
'usage_summary': None,
|
||||
'cached': False # TODO: Track if result was cached
|
||||
}
|
||||
|
||||
if usage_summary and limits:
|
||||
# For video generation, show video limits
|
||||
video_current = getattr(usage_summary, 'video_calls', 0) or 0
|
||||
video_limit = limits['limits'].get('video_calls', 0)
|
||||
|
||||
response_data['usage_summary'] = {
|
||||
'current_calls': video_current,
|
||||
'limit': video_limit,
|
||||
'remaining': max(0, video_limit - video_current) if video_limit > 0 else float('inf')
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": response_data
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pre-flight check: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Pre-flight check failed: {str(e)}")
|
||||
631
backend/api/subscription/routes/subscriptions.py
Normal file
631
backend/api/subscription/routes/subscriptions.py
Normal file
@@ -0,0 +1,631 @@
|
||||
"""
|
||||
User subscription management endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from loguru import logger
|
||||
import sqlite3
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.subscription_models import (
|
||||
SubscriptionPlan, UserSubscription, UsageSummary,
|
||||
SubscriptionTier, BillingCycle, UsageStatus, SubscriptionRenewalHistory
|
||||
)
|
||||
from ..dependencies import verify_user_access
|
||||
from ..utils import format_plan_limits, handle_schema_error
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/user/{user_id}/subscription")
|
||||
async def get_user_subscription(
|
||||
user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get user's current subscription information."""
|
||||
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if not subscription:
|
||||
# Return free tier information
|
||||
free_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE
|
||||
).first()
|
||||
|
||||
if free_plan:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"subscription": None,
|
||||
"plan": {
|
||||
"id": free_plan.id,
|
||||
"name": free_plan.name,
|
||||
"tier": free_plan.tier.value,
|
||||
"price_monthly": free_plan.price_monthly,
|
||||
"description": free_plan.description,
|
||||
"is_free": True
|
||||
},
|
||||
"status": "free",
|
||||
"limits": format_plan_limits(free_plan)
|
||||
}
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="No subscription plan found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"subscription": {
|
||||
"id": subscription.id,
|
||||
"billing_cycle": subscription.billing_cycle.value,
|
||||
"current_period_start": subscription.current_period_start.isoformat(),
|
||||
"current_period_end": subscription.current_period_end.isoformat(),
|
||||
"status": subscription.status.value,
|
||||
"auto_renew": subscription.auto_renew,
|
||||
"created_at": subscription.created_at.isoformat()
|
||||
},
|
||||
"plan": {
|
||||
"id": subscription.plan.id,
|
||||
"name": subscription.plan.name,
|
||||
"tier": subscription.plan.tier.value,
|
||||
"price_monthly": subscription.plan.price_monthly,
|
||||
"price_yearly": subscription.plan.price_yearly,
|
||||
"description": subscription.plan.description,
|
||||
"is_free": False
|
||||
},
|
||||
"limits": format_plan_limits(subscription.plan)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user subscription: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/status/{user_id}")
|
||||
async def get_subscription_status(
|
||||
user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get simple subscription status for enforcement checks."""
|
||||
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
except Exception as schema_err:
|
||||
logger.warning(f"Schema check failed, will retry on query: {schema_err}")
|
||||
|
||||
try:
|
||||
subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if not subscription:
|
||||
# Check if free tier exists
|
||||
free_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
|
||||
if free_plan:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": True,
|
||||
"plan": "free",
|
||||
"tier": "free",
|
||||
"can_use_api": True,
|
||||
"limits": format_plan_limits(free_plan)
|
||||
}
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": False,
|
||||
"plan": "none",
|
||||
"tier": "none",
|
||||
"can_use_api": False,
|
||||
"reason": "No active subscription or free tier found"
|
||||
}
|
||||
}
|
||||
|
||||
# Check if subscription is within valid period; auto-advance if expired and auto_renew
|
||||
now = datetime.utcnow()
|
||||
if subscription.current_period_end < now:
|
||||
if getattr(subscription, 'auto_renew', False):
|
||||
# advance period
|
||||
try:
|
||||
from services.pricing_service import PricingService
|
||||
pricing = PricingService(db)
|
||||
# reuse helper to ensure current
|
||||
pricing._ensure_subscription_current(subscription)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to auto-advance subscription: {e}")
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": False,
|
||||
"plan": subscription.plan.tier.value,
|
||||
"tier": subscription.plan.tier.value,
|
||||
"can_use_api": False,
|
||||
"reason": "Subscription expired"
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": True,
|
||||
"plan": subscription.plan.tier.value,
|
||||
"tier": subscription.plan.tier.value,
|
||||
"can_use_api": True,
|
||||
"limits": format_plan_limits(subscription.plan)
|
||||
}
|
||||
}
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
|
||||
# Try to fix schema and retry once
|
||||
logger.warning("Missing column detected in subscription status query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
schema_utils._checked_subscription_plan_columns = False
|
||||
ensure_subscription_plan_columns(db)
|
||||
db.commit() # Ensure schema changes are committed
|
||||
db.expire_all()
|
||||
# Retry the query - query subscription without eager loading plan
|
||||
subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if not subscription:
|
||||
free_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": True,
|
||||
"plan": "free",
|
||||
"tier": "free",
|
||||
"can_use_api": True,
|
||||
"limits": format_plan_limits(free_plan)
|
||||
}
|
||||
}
|
||||
elif subscription:
|
||||
# Query plan separately after schema fix to avoid lazy loading issues
|
||||
plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.id == subscription.plan_id
|
||||
).first()
|
||||
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
|
||||
now = datetime.utcnow()
|
||||
if subscription.current_period_end < now:
|
||||
if getattr(subscription, 'auto_renew', False):
|
||||
try:
|
||||
from services.pricing_service import PricingService
|
||||
pricing = PricingService(db)
|
||||
pricing._ensure_subscription_current(subscription)
|
||||
except Exception as e2:
|
||||
logger.error(f"Failed to auto-advance subscription: {e2}")
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": False,
|
||||
"plan": plan.tier.value,
|
||||
"tier": plan.tier.value,
|
||||
"can_use_api": False,
|
||||
"reason": "Subscription expired"
|
||||
}
|
||||
}
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"active": True,
|
||||
"plan": plan.tier.value,
|
||||
"tier": plan.tier.value,
|
||||
"can_use_api": True,
|
||||
"limits": format_plan_limits(plan)
|
||||
}
|
||||
}
|
||||
except Exception as retry_err:
|
||||
logger.error(f"Schema fix and retry failed: {retry_err}")
|
||||
raise HTTPException(status_code=500, detail=f"Database schema error: {str(e)}")
|
||||
|
||||
logger.error(f"Error getting subscription status: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/subscribe/{user_id}")
|
||||
async def subscribe_to_plan(
|
||||
user_id: str,
|
||||
subscription_data: dict,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Create or update a user's subscription (renewal)."""
|
||||
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
try:
|
||||
ensure_subscription_plan_columns(db)
|
||||
plan_id = subscription_data.get('plan_id')
|
||||
billing_cycle = subscription_data.get('billing_cycle', 'monthly')
|
||||
|
||||
if not plan_id:
|
||||
raise HTTPException(status_code=400, detail="plan_id is required")
|
||||
|
||||
# Get the plan
|
||||
plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.id == plan_id,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
|
||||
if not plan:
|
||||
raise HTTPException(status_code=404, detail="Plan not found")
|
||||
|
||||
# Check if user already has an active subscription
|
||||
existing_subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Track renewal history - capture BEFORE updating subscription
|
||||
previous_period_start = None
|
||||
previous_period_end = None
|
||||
previous_plan_name = None
|
||||
previous_plan_tier = None
|
||||
renewal_type = "new"
|
||||
renewal_count = 0
|
||||
|
||||
# Get usage snapshot BEFORE renewal (capture current state)
|
||||
usage_before_snapshot = None
|
||||
current_period = datetime.utcnow().strftime("%Y-%m")
|
||||
usage_before = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if usage_before:
|
||||
usage_before_snapshot = {
|
||||
"total_calls": usage_before.total_calls or 0,
|
||||
"total_tokens": usage_before.total_tokens or 0,
|
||||
"total_cost": float(usage_before.total_cost) if usage_before.total_cost else 0.0,
|
||||
"gemini_calls": usage_before.gemini_calls or 0,
|
||||
"mistral_calls": usage_before.mistral_calls or 0,
|
||||
"usage_status": usage_before.usage_status.value if hasattr(usage_before.usage_status, 'value') else str(usage_before.usage_status)
|
||||
}
|
||||
|
||||
if existing_subscription:
|
||||
# This is a renewal/update - capture previous subscription state BEFORE updating
|
||||
previous_period_start = existing_subscription.current_period_start
|
||||
previous_period_end = existing_subscription.current_period_end
|
||||
previous_plan = existing_subscription.plan
|
||||
previous_plan_name = previous_plan.name if previous_plan else None
|
||||
previous_plan_tier = previous_plan.tier.value if previous_plan else None
|
||||
|
||||
# Determine renewal type
|
||||
if previous_plan and previous_plan.id == plan_id:
|
||||
# Same plan - this is a renewal
|
||||
renewal_type = "renewal"
|
||||
elif previous_plan:
|
||||
# Different plan - check if upgrade or downgrade
|
||||
tier_order = {"free": 0, "basic": 1, "pro": 2, "enterprise": 3}
|
||||
previous_tier_order = tier_order.get(previous_plan_tier or "free", 0)
|
||||
new_tier_order = tier_order.get(plan.tier.value, 0)
|
||||
if new_tier_order > previous_tier_order:
|
||||
renewal_type = "upgrade"
|
||||
elif new_tier_order < previous_tier_order:
|
||||
renewal_type = "downgrade"
|
||||
else:
|
||||
renewal_type = "renewal" # Same tier, different plan name
|
||||
|
||||
# Get renewal count (how many times this user has renewed)
|
||||
last_renewal = db.query(SubscriptionRenewalHistory).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id
|
||||
).order_by(SubscriptionRenewalHistory.created_at.desc()).first()
|
||||
|
||||
if last_renewal:
|
||||
renewal_count = last_renewal.renewal_count + 1
|
||||
else:
|
||||
renewal_count = 1 # First renewal
|
||||
|
||||
# Update existing subscription
|
||||
existing_subscription.plan_id = plan_id
|
||||
existing_subscription.billing_cycle = BillingCycle(billing_cycle)
|
||||
existing_subscription.current_period_start = now
|
||||
existing_subscription.current_period_end = now + timedelta(
|
||||
days=365 if billing_cycle == 'yearly' else 30
|
||||
)
|
||||
existing_subscription.updated_at = now
|
||||
|
||||
subscription = existing_subscription
|
||||
else:
|
||||
# Create new subscription
|
||||
subscription = UserSubscription(
|
||||
user_id=user_id,
|
||||
plan_id=plan_id,
|
||||
billing_cycle=BillingCycle(billing_cycle),
|
||||
current_period_start=now,
|
||||
current_period_end=now + timedelta(
|
||||
days=365 if billing_cycle == 'yearly' else 30
|
||||
),
|
||||
status=UsageStatus.ACTIVE,
|
||||
is_active=True,
|
||||
auto_renew=True
|
||||
)
|
||||
db.add(subscription)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Create renewal history record AFTER subscription update (so we have the new period_end)
|
||||
renewal_history = SubscriptionRenewalHistory(
|
||||
user_id=user_id,
|
||||
plan_id=plan_id,
|
||||
plan_name=plan.name,
|
||||
plan_tier=plan.tier.value,
|
||||
previous_period_start=previous_period_start,
|
||||
previous_period_end=previous_period_end,
|
||||
new_period_start=now,
|
||||
new_period_end=subscription.current_period_end,
|
||||
billing_cycle=BillingCycle(billing_cycle),
|
||||
renewal_type=renewal_type,
|
||||
renewal_count=renewal_count,
|
||||
previous_plan_name=previous_plan_name,
|
||||
previous_plan_tier=previous_plan_tier,
|
||||
usage_before_renewal=usage_before_snapshot, # Usage snapshot captured BEFORE renewal
|
||||
payment_amount=plan.price_yearly if billing_cycle == 'yearly' else plan.price_monthly,
|
||||
payment_status="paid", # Assume paid for now (can be updated if payment processing is added)
|
||||
payment_date=now
|
||||
)
|
||||
db.add(renewal_history)
|
||||
db.commit()
|
||||
|
||||
# Get current usage BEFORE reset for logging
|
||||
current_period = datetime.utcnow().strftime("%Y-%m")
|
||||
usage_before = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
# Log renewal request details
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"[SUBSCRIPTION RENEWAL] 🔄 Processing renewal request")
|
||||
logger.info(f" ├─ User: {user_id}")
|
||||
logger.info(f" ├─ Plan: {plan.name} (ID: {plan_id}, Tier: {plan.tier.value})")
|
||||
logger.info(f" ├─ Billing Cycle: {billing_cycle}")
|
||||
logger.info(f" ├─ Period Start: {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logger.info(f" └─ Period End: {subscription.current_period_end.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
if usage_before:
|
||||
logger.info(f" 📊 Current Usage BEFORE Reset (Period: {current_period}):")
|
||||
logger.info(f" ├─ Gemini: {usage_before.gemini_tokens or 0} tokens / {usage_before.gemini_calls or 0} calls")
|
||||
logger.info(f" ├─ Mistral/HF: {usage_before.mistral_tokens or 0} tokens / {usage_before.mistral_calls or 0} calls")
|
||||
logger.info(f" ├─ OpenAI: {usage_before.openai_tokens or 0} tokens / {usage_before.openai_calls or 0} calls")
|
||||
logger.info(f" ├─ Stability (Images): {usage_before.stability_calls or 0} calls")
|
||||
logger.info(f" ├─ Total Tokens: {usage_before.total_tokens or 0}")
|
||||
logger.info(f" ├─ Total Calls: {usage_before.total_calls or 0}")
|
||||
logger.info(f" └─ Usage Status: {usage_before.usage_status.value}")
|
||||
else:
|
||||
logger.info(f" 📊 No usage summary found for period {current_period} (will be created on reset)")
|
||||
|
||||
# Clear subscription limits cache to force refresh on next check
|
||||
# IMPORTANT: Do this BEFORE resetting usage to ensure cache is cleared first
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
# Clear cache for this specific user (class-level cache shared across all instances)
|
||||
cleared_count = PricingService.clear_user_cache(user_id)
|
||||
logger.info(f" 🗑️ Cleared {cleared_count} subscription cache entries for user {user_id}")
|
||||
|
||||
# Also expire all SQLAlchemy objects to force fresh reads
|
||||
db.expire_all()
|
||||
logger.info(f" 🔄 Expired all SQLAlchemy objects to force fresh reads")
|
||||
except Exception as cache_err:
|
||||
logger.error(f" ❌ Failed to clear cache after subscribe: {cache_err}")
|
||||
|
||||
# Reset usage status for current billing period so new plan takes effect immediately
|
||||
reset_result = None
|
||||
try:
|
||||
usage_service = UsageTrackingService(db)
|
||||
reset_result = await usage_service.reset_current_billing_period(user_id)
|
||||
|
||||
# Force commit to ensure reset is persisted
|
||||
db.commit()
|
||||
|
||||
# Expire all SQLAlchemy objects to force fresh reads
|
||||
db.expire_all()
|
||||
|
||||
# Re-query usage summary from DB after reset to get fresh data (fresh query)
|
||||
usage_after = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
# Refresh the usage object if found to ensure we have latest data
|
||||
if usage_after:
|
||||
db.refresh(usage_after)
|
||||
|
||||
if reset_result.get('reset'):
|
||||
logger.info(f" ✅ Usage counters RESET successfully")
|
||||
if usage_after:
|
||||
logger.info(f" 📊 New Usage AFTER Reset:")
|
||||
logger.info(f" ├─ Gemini: {usage_after.gemini_tokens or 0} tokens / {usage_after.gemini_calls or 0} calls")
|
||||
logger.info(f" ├─ Mistral/HF: {usage_after.mistral_tokens or 0} tokens / {usage_after.mistral_calls or 0} calls")
|
||||
logger.info(f" ├─ OpenAI: {usage_after.openai_tokens or 0} tokens / {usage_after.openai_calls or 0} calls")
|
||||
logger.info(f" ├─ Stability (Images): {usage_after.stability_calls or 0} calls")
|
||||
logger.info(f" ├─ Total Tokens: {usage_after.total_tokens or 0}")
|
||||
logger.info(f" ├─ Total Calls: {usage_after.total_calls or 0}")
|
||||
logger.info(f" └─ Usage Status: {usage_after.usage_status.value}")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Usage summary not found after reset - may need to be created on next API call")
|
||||
else:
|
||||
logger.warning(f" ⚠️ Reset returned: {reset_result.get('reason', 'unknown')}")
|
||||
except Exception as reset_err:
|
||||
logger.error(f" ❌ Failed to reset usage after subscribe: {reset_err}", exc_info=True)
|
||||
|
||||
logger.info(f" ✅ Renewal completed: User {user_id} → {plan.name} ({billing_cycle})")
|
||||
logger.info("=" * 80)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully subscribed to {plan.name}",
|
||||
"data": {
|
||||
"subscription_id": subscription.id,
|
||||
"plan_name": plan.name,
|
||||
"billing_cycle": billing_cycle,
|
||||
"current_period_start": subscription.current_period_start.isoformat(),
|
||||
"current_period_end": subscription.current_period_end.isoformat(),
|
||||
"status": subscription.status.value,
|
||||
"limits": format_plan_limits(plan)
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error subscribing to plan: {e}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/renewal-history/{user_id}")
|
||||
async def get_renewal_history(
|
||||
user_id: str,
|
||||
limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
|
||||
offset: int = Query(0, ge=0, description="Pagination offset"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get subscription renewal history for a user.
|
||||
|
||||
Automatically applies retention policies:
|
||||
- Compresses usage snapshots for records 12-24 months old
|
||||
- Removes usage snapshots for records 24-84 months old
|
||||
- Preserves payment data indefinitely
|
||||
|
||||
Returns:
|
||||
- List of renewal history records
|
||||
- Total count for pagination
|
||||
"""
|
||||
try:
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
# Apply retention policies before fetching
|
||||
from services.subscription.renewal_history_retention import RenewalHistoryRetentionService
|
||||
retention_service = RenewalHistoryRetentionService(db)
|
||||
retention_result = retention_service.check_and_apply_retention(user_id)
|
||||
if retention_result.get('retention_applied'):
|
||||
logger.info(f"[RenewalHistory] Retention applied for user {user_id}: {retention_result.get('message')}")
|
||||
|
||||
# Get total count
|
||||
total_count = db.query(SubscriptionRenewalHistory).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id
|
||||
).count()
|
||||
|
||||
# Get paginated results, ordered by created_at descending (most recent first)
|
||||
renewals = db.query(SubscriptionRenewalHistory).filter(
|
||||
SubscriptionRenewalHistory.user_id == user_id
|
||||
).order_by(SubscriptionRenewalHistory.created_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
# Format renewal history for response
|
||||
renewal_history = []
|
||||
for renewal in renewals:
|
||||
renewal_history.append({
|
||||
'id': renewal.id,
|
||||
'plan_name': renewal.plan_name,
|
||||
'plan_tier': renewal.plan_tier,
|
||||
'previous_period_start': renewal.previous_period_start.isoformat() if renewal.previous_period_start else None,
|
||||
'previous_period_end': renewal.previous_period_end.isoformat() if renewal.previous_period_end else None,
|
||||
'new_period_start': renewal.new_period_start.isoformat() if renewal.new_period_start else None,
|
||||
'new_period_end': renewal.new_period_end.isoformat() if renewal.new_period_end else None,
|
||||
'billing_cycle': renewal.billing_cycle.value if renewal.billing_cycle else None,
|
||||
'renewal_type': renewal.renewal_type,
|
||||
'renewal_count': renewal.renewal_count,
|
||||
'previous_plan_name': renewal.previous_plan_name,
|
||||
'previous_plan_tier': renewal.previous_plan_tier,
|
||||
'usage_before_renewal': renewal.usage_before_renewal,
|
||||
'payment_amount': float(renewal.payment_amount) if renewal.payment_amount else 0.0,
|
||||
'payment_status': renewal.payment_status,
|
||||
'payment_date': renewal.payment_date.isoformat() if renewal.payment_date else None,
|
||||
'created_at': renewal.created_at.isoformat() if renewal.created_at else None
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"renewals": renewal_history,
|
||||
"total_count": total_count,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"has_more": (offset + limit) < total_count
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting renewal history: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/renewal-history/{user_id}/retention-stats")
|
||||
async def get_renewal_retention_stats(
|
||||
user_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get retention statistics for a user's renewal history.
|
||||
|
||||
Returns breakdown by retention tier:
|
||||
- Recent records (0-12 months): Full records with usage snapshots
|
||||
- To compress (12-24 months): Records that need snapshot compression
|
||||
- To summarize (24-84 months): Records that need snapshot removal
|
||||
- To archive (84+ months): Records ready for archive
|
||||
"""
|
||||
try:
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
from services.subscription.renewal_history_retention import RenewalHistoryRetentionService
|
||||
retention_service = RenewalHistoryRetentionService(db)
|
||||
stats = retention_service.get_retention_stats(user_id)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": stats
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting renewal retention stats: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
62
backend/api/subscription/routes/usage.py
Normal file
62
backend/api/subscription/routes/usage.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Usage statistics endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService
|
||||
from ..dependencies import verify_user_access
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/usage/{user_id}")
|
||||
async def get_user_usage(
|
||||
user_id: str,
|
||||
billing_period: Optional[str] = Query(None, description="Billing period (YYYY-MM)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user."""
|
||||
|
||||
# Verify user can only access their own data
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
try:
|
||||
usage_service = UsageTrackingService(db)
|
||||
stats = usage_service.get_user_usage_stats(user_id, billing_period)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": stats
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user usage: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get user usage")
|
||||
|
||||
|
||||
@router.get("/usage/{user_id}/trends")
|
||||
async def get_usage_trends(
|
||||
user_id: str,
|
||||
months: int = Query(6, ge=1, le=24, description="Number of months to include"),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get usage trends over time."""
|
||||
|
||||
try:
|
||||
usage_service = UsageTrackingService(db)
|
||||
trends = usage_service.get_usage_trends(user_id, months)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": trends
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage trends: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
98
backend/api/subscription/utils.py
Normal file
98
backend/api/subscription/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Shared utility functions for subscription API routes.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
import sqlite3
|
||||
|
||||
from models.subscription_models import SubscriptionPlan
|
||||
|
||||
|
||||
def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"""
|
||||
Format subscription plan limits for API response.
|
||||
|
||||
Args:
|
||||
plan: SubscriptionPlan model instance
|
||||
|
||||
Returns:
|
||||
Dictionary with formatted limits
|
||||
"""
|
||||
return {
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
"anthropic_calls": plan.anthropic_calls_limit,
|
||||
"mistral_calls": plan.mistral_calls_limit,
|
||||
"tavily_calls": plan.tavily_calls_limit,
|
||||
"serper_calls": plan.serper_calls_limit,
|
||||
"metaphor_calls": plan.metaphor_calls_limit,
|
||||
"firecrawl_calls": plan.firecrawl_calls_limit,
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"video_calls": getattr(plan, 'video_calls_limit', 0) or 0,
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0) or 0,
|
||||
"audio_calls": getattr(plan, 'audio_calls_limit', 0) or 0,
|
||||
"exa_calls": getattr(plan, 'exa_calls_limit', 0) or 0,
|
||||
"gemini_tokens": plan.gemini_tokens_limit,
|
||||
"openai_tokens": plan.openai_tokens_limit,
|
||||
"anthropic_tokens": plan.anthropic_tokens_limit,
|
||||
"mistral_tokens": plan.mistral_tokens_limit,
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
}
|
||||
|
||||
|
||||
def handle_schema_error(
|
||||
error: Exception,
|
||||
db: Session,
|
||||
error_str: str,
|
||||
retry_func: callable
|
||||
) -> Any:
|
||||
"""
|
||||
Handle database schema errors by fixing schema and retrying.
|
||||
|
||||
Args:
|
||||
error: The original exception
|
||||
error_str: Lowercase string representation of error
|
||||
db: Database session
|
||||
retry_func: Function to retry after schema fix
|
||||
|
||||
Returns:
|
||||
Result from retry_func
|
||||
|
||||
Raises:
|
||||
HTTPException: If schema fix fails
|
||||
"""
|
||||
if 'no such column' in error_str:
|
||||
logger.warning("Missing column detected, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
|
||||
# Reset schema check flags based on error type
|
||||
if 'exa_calls_limit' in error_str or 'video_calls_limit' in error_str or \
|
||||
'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str:
|
||||
schema_utils._checked_subscription_plan_columns = False
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns
|
||||
ensure_subscription_plan_columns(db)
|
||||
elif 'exa_calls' in error_str or 'exa_cost' in error_str or \
|
||||
'video_calls' in error_str or 'video_cost' in error_str or \
|
||||
'image_edit_calls' in error_str or 'image_edit_cost' in error_str or \
|
||||
'audio_calls' in error_str or 'audio_cost' in error_str:
|
||||
schema_utils._checked_usage_summaries_columns = False
|
||||
schema_utils._checked_subscription_plan_columns = False
|
||||
from services.subscription.schema_utils import ensure_usage_summaries_columns, ensure_subscription_plan_columns
|
||||
ensure_usage_summaries_columns(db)
|
||||
ensure_subscription_plan_columns(db)
|
||||
elif 'actual_provider_name' in error_str:
|
||||
schema_utils._checked_api_usage_logs_columns = False
|
||||
from services.subscription.schema_utils import ensure_api_usage_logs_columns
|
||||
ensure_api_usage_logs_columns(db)
|
||||
|
||||
db.expire_all()
|
||||
return retry_func()
|
||||
except Exception as retry_err:
|
||||
logger.error(f"Schema fix and retry failed: {retry_err}")
|
||||
raise HTTPException(status_code=500, detail=f"Database schema error: {str(error)}")
|
||||
|
||||
raise error
|
||||
Reference in New Issue
Block a user