1510 lines
68 KiB
Python
1510 lines
68 KiB
Python
"""
Subscription and Usage API Routes

Provides endpoints for subscription management and usage monitoring.
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import desc, func
|
|
from typing import Dict, Any, Optional, List
|
|
from datetime import datetime, timedelta
|
|
from loguru import logger
|
|
from functools import lru_cache
|
|
|
|
from services.database import get_db
|
|
from services.subscription import UsageTrackingService, PricingService
|
|
from services.subscription.log_wrapping_service import LogWrappingService
|
|
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
|
import sqlite3
|
|
from middleware.auth_middleware import get_current_user
|
|
from models.subscription_models import (
|
|
APIProvider, SubscriptionPlan, UserSubscription, UsageSummary,
|
|
APIProviderPricing, UsageAlert, SubscriptionTier, BillingCycle, UsageStatus,
|
|
APIUsageLog, SubscriptionRenewalHistory
|
|
)
|
|
|
|
# Router for every endpoint in this module; all routes are mounted under
# the /api/subscription prefix and tagged "subscription".
router = APIRouter(prefix="/api/subscription", tags=["subscription"])

# Simple in-process cache for dashboard responses to smooth bursts.
# Cache key: user_id. TTL-like behavior is implemented via a timestamp check
# (the check itself is not in this part of the file).
# NOTE(review): these are plain module-level dicts, so presumably the cache is
# per-process and not shared between workers - confirm against deployment.
_dashboard_cache: Dict[str, Dict[str, Any]] = {}
# Timestamp recorded per cache key; float type per the annotation.
_dashboard_cache_ts: Dict[str, float] = {}
# Lifetime of a cached entry, in seconds.
_DASHBOARD_CACHE_TTL_SEC = 600.0
|
|
|
|
@router.get("/usage/{user_id}")
async def get_user_usage(
    user_id: str,
    billing_period: Optional[str] = Query(None, description="Billing period (YYYY-MM)"),
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Return the usage statistics of ``user_id`` for one billing period.

    Args:
        user_id: Owner of the usage data; must match the authenticated user.
        billing_period: Optional ``YYYY-MM`` period, forwarded unchanged to
            the usage service.
        db: Request-scoped database session.
        current_user: Authenticated user mapping; its ``id`` is compared
            against ``user_id``.

    Returns:
        ``{"success": True, "data": <stats>}``.

    Raises:
        HTTPException: 403 when the caller requests another user's data,
            500 when the usage service fails.
    """
    # Callers may only read their own usage.
    requester_id = current_user.get('id')
    if requester_id != user_id:
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        stats = UsageTrackingService(db).get_user_usage_stats(user_id, billing_period)
    except Exception as exc:
        # The detail is deliberately generic; the real cause only goes to the log.
        logger.error(f"Error getting user usage: {exc}")
        raise HTTPException(status_code=500, detail="Failed to get user usage")

    return {"success": True, "data": stats}
|
|
|
|
@router.get("/usage/{user_id}/trends")
async def get_usage_trends(
    user_id: str,
    months: int = Query(6, ge=1, le=24, description="Number of months to include"),
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Get usage trends over time for a user.

    Args:
        user_id: Owner of the usage data; must match the authenticated user.
        months: Number of months to include (validated to 1..24 by FastAPI).
        db: Request-scoped database session.
        current_user: Authenticated user mapping; its ``id`` is compared
            against ``user_id``.

    Returns:
        ``{"success": True, "data": <trends>}``.

    Raises:
        HTTPException: 403 when the caller requests another user's data,
            500 when the usage service fails.
    """
    # Verify user can only access their own data (same rule as the sibling
    # /usage/{user_id} endpoint); without it any caller could read any
    # user's trends by guessing the id.
    if current_user.get('id') != user_id:
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        usage_service = UsageTrackingService(db)
        trends = usage_service.get_usage_trends(user_id, months)
    except Exception as e:
        logger.error(f"Error getting usage trends: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "success": True,
        "data": trends
    }
|
|
|
|
def _serialize_listed_plan(plan: SubscriptionPlan) -> Dict[str, Any]:
    """Convert one SubscriptionPlan row into the public /plans payload."""
    return {
        "id": plan.id,
        "name": plan.name,
        "tier": plan.tier.value,
        "price_monthly": plan.price_monthly,
        "price_yearly": plan.price_yearly,
        "description": plan.description,
        "features": plan.features or [],
        "limits": {
            "ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
            "gemini_calls": plan.gemini_calls_limit,
            "openai_calls": plan.openai_calls_limit,
            "anthropic_calls": plan.anthropic_calls_limit,
            "mistral_calls": plan.mistral_calls_limit,
            "tavily_calls": plan.tavily_calls_limit,
            "serper_calls": plan.serper_calls_limit,
            "metaphor_calls": plan.metaphor_calls_limit,
            "firecrawl_calls": plan.firecrawl_calls_limit,
            "stability_calls": plan.stability_calls_limit,
            # Newer columns are read defensively in case the mapped model
            # predates them.
            "video_calls": getattr(plan, 'video_calls_limit', 0),
            "image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
            "audio_calls": getattr(plan, 'audio_calls_limit', 0),
            "gemini_tokens": plan.gemini_tokens_limit,
            "openai_tokens": plan.openai_tokens_limit,
            "anthropic_tokens": plan.anthropic_tokens_limit,
            "mistral_tokens": plan.mistral_tokens_limit,
            "monthly_cost": plan.monthly_cost_limit
        }
    }


def _query_active_plans_response(db: Session) -> Dict[str, Any]:
    """Query the active plans (cheapest first) and build the response body."""
    plans = db.query(SubscriptionPlan).filter(
        SubscriptionPlan.is_active == True
    ).order_by(SubscriptionPlan.price_monthly).all()

    plans_data = [_serialize_listed_plan(plan) for plan in plans]
    return {
        "success": True,
        "data": {
            "plans": plans_data,
            "total": len(plans_data)
        }
    }


@router.get("/plans")
async def get_subscription_plans(
    db: Session = Depends(get_db)
) -> Dict[str, Any]:
    """Get all available subscription plans.

    The query is attempted once; if it fails because a recently added limit
    column is missing, the schema helper is re-run and the query is retried
    a single time. Both attempts use the same serializer, so the response
    shape does not depend on which attempt succeeded.

    Returns:
        ``{"success": True, "data": {"plans": [...], "total": <int>}}``.

    Raises:
        HTTPException: 500 when the query fails, including when the schema
            fix and retry also fail.
    """
    try:
        ensure_subscription_plan_columns(db)
    except Exception as schema_err:
        # Best effort: the query below gets a chance to trigger the retry path.
        logger.warning(f"Schema check failed, will retry on query: {schema_err}")

    try:
        return _query_active_plans_response(db)
    except Exception as e:
        error_str = str(e).lower()
        drifted_columns = (
            'exa_calls_limit', 'video_calls_limit',
            'image_edit_calls_limit', 'audio_calls_limit',
        )
        if 'no such column' in error_str and any(col in error_str for col in drifted_columns):
            logger.warning("Missing column detected in subscription plans query, attempting schema fix...")
            try:
                import services.subscription.schema_utils as schema_utils
                # Reset the module's "already checked" flag so the helper
                # really re-inspects the table.
                schema_utils._checked_subscription_plan_columns = False
                ensure_subscription_plan_columns(db)
                db.expire_all()
                # Retry the query
                return _query_active_plans_response(db)
            except Exception as retry_err:
                logger.error(f"Schema fix and retry failed: {retry_err}")
                raise HTTPException(status_code=500, detail=f"Database schema error: {str(e)}")

        logger.error(f"Error getting subscription plans: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
def _subscription_view_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
    """Build the ``limits`` payload shown with a user's subscription."""
    return {
        "ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
        "gemini_calls": plan.gemini_calls_limit,
        "openai_calls": plan.openai_calls_limit,
        "anthropic_calls": plan.anthropic_calls_limit,
        "mistral_calls": plan.mistral_calls_limit,
        "tavily_calls": plan.tavily_calls_limit,
        "serper_calls": plan.serper_calls_limit,
        "metaphor_calls": plan.metaphor_calls_limit,
        "firecrawl_calls": plan.firecrawl_calls_limit,
        "stability_calls": plan.stability_calls_limit,
        "video_calls": getattr(plan, 'video_calls_limit', 0),
        "image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
        "audio_calls": getattr(plan, 'audio_calls_limit', 0),
        "monthly_cost": plan.monthly_cost_limit
    }


@router.get("/user/{user_id}/subscription")
async def get_user_subscription(
    user_id: str,
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Get user's current subscription information.

    When the user has no active subscription, the FREE tier plan is returned
    with ``"subscription": None`` and ``"status": "free"``.

    Args:
        user_id: Owner of the subscription; must match the authenticated user.
        db: Request-scoped database session.
        current_user: Authenticated user mapping; its ``id`` is compared
            against ``user_id``.

    Returns:
        ``{"success": True, "data": {"subscription", "plan", "limits", ...}}``.
        Paid and free responses expose the same ``limits`` keys.

    Raises:
        HTTPException: 403 for another user's data, 404 when there is neither
            an active subscription nor a FREE plan, 500 on unexpected errors.
    """
    # Verify user can only access their own data
    if current_user.get('id') != user_id:
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        ensure_subscription_plan_columns(db)
        subscription = db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id,
            UserSubscription.is_active == True
        ).first()

        if not subscription:
            # Return free tier information
            # NOTE(review): unlike /status/{user_id}, this lookup does not
            # filter on is_active - confirm whether that is intentional.
            free_plan = db.query(SubscriptionPlan).filter(
                SubscriptionPlan.tier == SubscriptionTier.FREE
            ).first()

            if not free_plan:
                raise HTTPException(status_code=404, detail="No subscription plan found")

            return {
                "success": True,
                "data": {
                    "subscription": None,
                    "plan": {
                        "id": free_plan.id,
                        "name": free_plan.name,
                        "tier": free_plan.tier.value,
                        "price_monthly": free_plan.price_monthly,
                        "description": free_plan.description,
                        "is_free": True
                    },
                    "status": "free",
                    "limits": _subscription_view_limits(free_plan)
                }
            }

        plan = subscription.plan
        return {
            "success": True,
            "data": {
                "subscription": {
                    "id": subscription.id,
                    "billing_cycle": subscription.billing_cycle.value,
                    "current_period_start": subscription.current_period_start.isoformat(),
                    "current_period_end": subscription.current_period_end.isoformat(),
                    "status": subscription.status.value,
                    "auto_renew": subscription.auto_renew,
                    "created_at": subscription.created_at.isoformat()
                },
                "plan": {
                    "id": plan.id,
                    "name": plan.name,
                    "tier": plan.tier.value,
                    "price_monthly": plan.price_monthly,
                    "price_yearly": plan.price_yearly,
                    "description": plan.description,
                    "is_free": False
                },
                "limits": _subscription_view_limits(plan)
            }
        }

    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must keep their status code
        # instead of being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Error getting user subscription: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
def _status_check_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
    """Build the ``limits`` payload returned by the status endpoint."""
    return {
        "ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
        "gemini_calls": plan.gemini_calls_limit,
        "openai_calls": plan.openai_calls_limit,
        "anthropic_calls": plan.anthropic_calls_limit,
        "mistral_calls": plan.mistral_calls_limit,
        "tavily_calls": plan.tavily_calls_limit,
        "serper_calls": plan.serper_calls_limit,
        "metaphor_calls": plan.metaphor_calls_limit,
        "firecrawl_calls": plan.firecrawl_calls_limit,
        "stability_calls": plan.stability_calls_limit,
        "video_calls": getattr(plan, 'video_calls_limit', 0),
        "image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
        "audio_calls": getattr(plan, 'audio_calls_limit', 0),
        "monthly_cost": plan.monthly_cost_limit
    }


def _evaluate_subscription_status(
    db: Session,
    user_id: str,
    *,
    query_plan_explicitly: bool,
) -> Dict[str, Any]:
    """Compute the status response for ``user_id`` (one attempt, no retry).

    ``query_plan_explicitly`` selects how the plan row is obtained: through
    the ``subscription.plan`` relationship (False) or via a separate query by
    ``plan_id`` (True, used after a schema fix to avoid lazy-loading issues).
    """
    subscription = db.query(UserSubscription).filter(
        UserSubscription.user_id == user_id,
        UserSubscription.is_active == True
    ).first()

    if not subscription:
        # No paid subscription: fall back to the active FREE plan, if any.
        free_plan = db.query(SubscriptionPlan).filter(
            SubscriptionPlan.tier == SubscriptionTier.FREE,
            SubscriptionPlan.is_active == True
        ).first()

        if free_plan:
            return {
                "success": True,
                "data": {
                    "active": True,
                    "plan": "free",
                    "tier": "free",
                    "can_use_api": True,
                    "limits": _status_check_limits(free_plan)
                }
            }

        return {
            "success": True,
            "data": {
                "active": False,
                "plan": "none",
                "tier": "none",
                "can_use_api": False,
                "reason": "No active subscription or free tier found"
            }
        }

    if query_plan_explicitly:
        plan = db.query(SubscriptionPlan).filter(
            SubscriptionPlan.id == subscription.plan_id
        ).first()
    else:
        plan = subscription.plan

    if not plan:
        raise HTTPException(status_code=404, detail="Plan not found")

    # Check if subscription is within valid period; auto-advance if expired and auto_renew
    now = datetime.utcnow()
    if subscription.current_period_end < now:
        if getattr(subscription, 'auto_renew', False):
            try:
                # NOTE(review): the rest of this module imports PricingService
                # from services.subscription - confirm this path is intended.
                from services.pricing_service import PricingService
                pricing = PricingService(db)
                # reuse helper to ensure current
                pricing._ensure_subscription_current(subscription)
            except Exception as advance_err:
                # Best effort: a failed auto-advance is logged and the
                # subscription is still reported as active below.
                logger.error(f"Failed to auto-advance subscription: {advance_err}")
        else:
            return {
                "success": True,
                "data": {
                    "active": False,
                    "plan": plan.tier.value,
                    "tier": plan.tier.value,
                    "can_use_api": False,
                    "reason": "Subscription expired"
                }
            }

    return {
        "success": True,
        "data": {
            "active": True,
            "plan": plan.tier.value,
            "tier": plan.tier.value,
            "can_use_api": True,
            "limits": _status_check_limits(plan)
        }
    }


@router.get("/status/{user_id}")
async def get_subscription_status(
    user_id: str,
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Get simple subscription status for enforcement checks.

    The status is computed once; if that fails because a recently added plan
    column is missing, the schema helper is re-run and the computation is
    retried a single time. Both attempts share one implementation, so every
    branch (free, none, expired, active) behaves the same on retry.

    Args:
        user_id: Owner of the subscription; must match the authenticated user.
        db: Request-scoped database session.
        current_user: Authenticated user mapping; its ``id`` is compared
            against ``user_id``.

    Returns:
        ``{"success": True, "data": {"active", "plan", "tier", "can_use_api",
        ...}}`` with ``limits`` when usable or ``reason`` when not.

    Raises:
        HTTPException: 403 for another user's data, 404 when the
            subscription's plan row is missing, 500 on unexpected errors.
    """
    # Verify user can only access their own data
    if current_user.get('id') != user_id:
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        ensure_subscription_plan_columns(db)
    except Exception as schema_err:
        logger.warning(f"Schema check failed, will retry on query: {schema_err}")

    try:
        return _evaluate_subscription_status(db, user_id, query_plan_explicitly=False)
    except HTTPException:
        raise
    except Exception as e:
        error_str = str(e).lower()
        drifted_columns = (
            'exa_calls_limit', 'video_calls_limit',
            'image_edit_calls_limit', 'audio_calls_limit',
        )
        if 'no such column' in error_str and any(col in error_str for col in drifted_columns):
            # Try to fix schema and retry once
            logger.warning("Missing column detected in subscription status query, attempting schema fix...")
            try:
                import services.subscription.schema_utils as schema_utils
                # Reset the module's "already checked" flag so the helper
                # really re-inspects the table.
                schema_utils._checked_subscription_plan_columns = False
                ensure_subscription_plan_columns(db)
                db.commit()  # Ensure schema changes are committed
                db.expire_all()
                return _evaluate_subscription_status(db, user_id, query_plan_explicitly=True)
            except HTTPException:
                # A 404 from the retry is a real answer, not a schema error.
                raise
            except Exception as retry_err:
                logger.error(f"Schema fix and retry failed: {retry_err}")
                raise HTTPException(status_code=500, detail=f"Database schema error: {str(e)}")

        logger.error(f"Error getting subscription status: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
def _renewal_usage_status_text(usage: Any) -> str:
    """Return a usage row's status as text, tolerating non-enum values."""
    status = usage.usage_status
    return status.value if hasattr(status, 'value') else str(status)


def _log_renewal_usage_counters(usage: Any) -> None:
    """Log the per-provider counters of one usage summary row."""
    logger.info(f" ├─ Gemini: {usage.gemini_tokens or 0} tokens / {usage.gemini_calls or 0} calls")
    logger.info(f" ├─ Mistral/HF: {usage.mistral_tokens or 0} tokens / {usage.mistral_calls or 0} calls")
    logger.info(f" ├─ OpenAI: {usage.openai_tokens or 0} tokens / {usage.openai_calls or 0} calls")
    logger.info(f" ├─ Stability (Images): {usage.stability_calls or 0} calls")
    logger.info(f" ├─ Total Tokens: {usage.total_tokens or 0}")
    logger.info(f" ├─ Total Calls: {usage.total_calls or 0}")
    logger.info(f" └─ Usage Status: {_renewal_usage_status_text(usage)}")


def _classify_renewal(previous_plan: Any, previous_tier: Optional[str], new_plan: Any, new_plan_id: Any) -> str:
    """Classify a plan change as "renewal", "upgrade" or "downgrade"."""
    if previous_plan and previous_plan.id == new_plan_id:
        # Same plan - this is a renewal
        return "renewal"
    if not previous_plan:
        # No previous plan to compare against; keep the default label.
        return "new"

    # Different plan - compare tiers; unknown tiers rank like "free".
    tier_order = {"free": 0, "basic": 1, "pro": 2, "enterprise": 3}
    previous_rank = tier_order.get(previous_tier or "free", 0)
    new_rank = tier_order.get(new_plan.tier.value, 0)
    if new_rank > previous_rank:
        return "upgrade"
    if new_rank < previous_rank:
        return "downgrade"
    return "renewal"  # Same tier, different plan name


@router.post("/subscribe/{user_id}")
async def subscribe_to_plan(
    user_id: str,
    subscription_data: dict,
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Create or update a user's subscription (renewal).

    Steps: validate the request, snapshot current usage, create or update the
    subscription together with a renewal-history row (one commit), clear the
    cached limits, then reset the current billing period's usage counters.
    Cache clearing and the usage reset are best effort: their failures are
    logged and do not fail the request.

    Args:
        user_id: User to subscribe; must match the authenticated user.
        subscription_data: Request body with ``plan_id`` (required) and
            ``billing_cycle`` (optional, defaults to ``"monthly"``).
        db: Request-scoped database session.
        current_user: Authenticated user mapping; its ``id`` is compared
            against ``user_id``.

    Returns:
        ``{"success": True, "message": ..., "data": {...}}`` describing the
        new period and the plan's limits.

    Raises:
        HTTPException: 403 for another user, 400 for a missing ``plan_id`` or
            an invalid ``billing_cycle``, 404 for an unknown/inactive plan,
            500 on unexpected errors (the session is rolled back).
    """
    # Verify user can only subscribe/renew their own subscription
    if current_user.get('id') != user_id:
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        ensure_subscription_plan_columns(db)
        plan_id = subscription_data.get('plan_id')
        billing_cycle = subscription_data.get('billing_cycle', 'monthly')

        if not plan_id:
            raise HTTPException(status_code=400, detail="plan_id is required")

        # Reject an unknown cycle up front as a client error instead of
        # letting the enum lookup surface as a 500 later on.
        try:
            cycle = BillingCycle(billing_cycle)
        except ValueError:
            raise HTTPException(status_code=400, detail=f"Invalid billing_cycle: {billing_cycle}")

        # Get the plan
        plan = db.query(SubscriptionPlan).filter(
            SubscriptionPlan.id == plan_id,
            SubscriptionPlan.is_active == True
        ).first()

        if not plan:
            raise HTTPException(status_code=404, detail="Plan not found")

        # Check if user already has an active subscription
        existing_subscription = db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id,
            UserSubscription.is_active == True
        ).first()

        now = datetime.utcnow()
        period_end = now + timedelta(days=365 if billing_cycle == 'yearly' else 30)
        current_period = now.strftime("%Y-%m")

        # Track renewal history - capture BEFORE updating subscription
        previous_period_start = None
        previous_period_end = None
        previous_plan_name = None
        previous_plan_tier = None
        renewal_type = "new"
        renewal_count = 0

        # Get usage snapshot BEFORE renewal (capture current state)
        usage_before_snapshot = None
        usage_before = db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == current_period
        ).first()

        if usage_before:
            usage_before_snapshot = {
                "total_calls": usage_before.total_calls or 0,
                "total_tokens": usage_before.total_tokens or 0,
                "total_cost": float(usage_before.total_cost) if usage_before.total_cost else 0.0,
                "gemini_calls": usage_before.gemini_calls or 0,
                "mistral_calls": usage_before.mistral_calls or 0,
                "usage_status": _renewal_usage_status_text(usage_before)
            }

        if existing_subscription:
            # This is a renewal/update - capture previous subscription state BEFORE updating
            previous_period_start = existing_subscription.current_period_start
            previous_period_end = existing_subscription.current_period_end
            previous_plan = existing_subscription.plan
            previous_plan_name = previous_plan.name if previous_plan else None
            previous_plan_tier = previous_plan.tier.value if previous_plan else None

            renewal_type = _classify_renewal(previous_plan, previous_plan_tier, plan, plan_id)

            # Get renewal count (how many times this user has renewed)
            last_renewal = db.query(SubscriptionRenewalHistory).filter(
                SubscriptionRenewalHistory.user_id == user_id
            ).order_by(SubscriptionRenewalHistory.created_at.desc()).first()
            renewal_count = last_renewal.renewal_count + 1 if last_renewal else 1

            # Update existing subscription
            existing_subscription.plan_id = plan_id
            existing_subscription.billing_cycle = cycle
            existing_subscription.current_period_start = now
            existing_subscription.current_period_end = period_end
            existing_subscription.updated_at = now
            subscription = existing_subscription
        else:
            # Create new subscription
            subscription = UserSubscription(
                user_id=user_id,
                plan_id=plan_id,
                billing_cycle=cycle,
                current_period_start=now,
                current_period_end=period_end,
                status=UsageStatus.ACTIVE,
                is_active=True,
                auto_renew=True
            )
            db.add(subscription)

        # The new period end is already known, so the history row is added
        # to the same transaction: subscription change and history are
        # committed (or rolled back) together.
        renewal_history = SubscriptionRenewalHistory(
            user_id=user_id,
            plan_id=plan_id,
            plan_name=plan.name,
            plan_tier=plan.tier.value,
            previous_period_start=previous_period_start,
            previous_period_end=previous_period_end,
            new_period_start=now,
            new_period_end=period_end,
            billing_cycle=cycle,
            renewal_type=renewal_type,
            renewal_count=renewal_count,
            previous_plan_name=previous_plan_name,
            previous_plan_tier=previous_plan_tier,
            usage_before_renewal=usage_before_snapshot,  # Usage snapshot captured BEFORE renewal
            payment_amount=plan.price_yearly if billing_cycle == 'yearly' else plan.price_monthly,
            payment_status="paid",  # Assume paid for now (can be updated if payment processing is added)
            payment_date=now
        )
        db.add(renewal_history)
        db.commit()

        # Get current usage BEFORE reset for logging (fresh read after commit)
        usage_before = db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == current_period
        ).first()

        # Log renewal request details
        logger.info("=" * 80)
        logger.info(f"[SUBSCRIPTION RENEWAL] 🔄 Processing renewal request")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" ├─ Plan: {plan.name} (ID: {plan_id}, Tier: {plan.tier.value})")
        logger.info(f" ├─ Billing Cycle: {billing_cycle}")
        logger.info(f" ├─ Period Start: {now.strftime('%Y-%m-%d %H:%M:%S')}")
        logger.info(f" └─ Period End: {subscription.current_period_end.strftime('%Y-%m-%d %H:%M:%S')}")

        if usage_before:
            logger.info(f" 📊 Current Usage BEFORE Reset (Period: {current_period}):")
            _log_renewal_usage_counters(usage_before)
        else:
            logger.info(f" 📊 No usage summary found for period {current_period} (will be created on reset)")

        # Clear subscription limits cache to force refresh on next check
        # IMPORTANT: Do this BEFORE resetting usage to ensure cache is cleared first
        try:
            # Clear cache for this specific user (class-level cache shared across all instances)
            cleared_count = PricingService.clear_user_cache(user_id)
            logger.info(f" 🗑️ Cleared {cleared_count} subscription cache entries for user {user_id}")

            # Also expire all SQLAlchemy objects to force fresh reads
            db.expire_all()
            logger.info(f" 🔄 Expired all SQLAlchemy objects to force fresh reads")
        except Exception as cache_err:
            logger.error(f" ❌ Failed to clear cache after subscribe: {cache_err}")

        # Reset usage status for current billing period so new plan takes effect immediately
        try:
            usage_service = UsageTrackingService(db)
            reset_result = await usage_service.reset_current_billing_period(user_id)

            # Force commit to ensure reset is persisted
            db.commit()

            # Expire all SQLAlchemy objects to force fresh reads
            db.expire_all()

            # Re-query usage summary from DB after reset to get fresh data (fresh query)
            usage_after = db.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == current_period
            ).first()

            # Refresh the usage object if found to ensure we have latest data
            if usage_after:
                db.refresh(usage_after)

            if reset_result and reset_result.get('reset'):
                logger.info(f" ✅ Usage counters RESET successfully")
                if usage_after:
                    logger.info(f" 📊 New Usage AFTER Reset:")
                    _log_renewal_usage_counters(usage_after)
                else:
                    logger.warning(f" ⚠️ Usage summary not found after reset - may need to be created on next API call")
            else:
                reason = (reset_result or {}).get('reason', 'unknown')
                logger.warning(f" ⚠️ Reset returned: {reason}")
        except Exception as reset_err:
            # logger.exception attaches the traceback; the stdlib-style
            # exc_info=True kwarg is only a format argument to loguru.
            logger.exception(f" ❌ Failed to reset usage after subscribe: {reset_err}")

        logger.info(f" ✅ Renewal completed: User {user_id} → {plan.name} ({billing_cycle})")
        logger.info("=" * 80)

        return {
            "success": True,
            "message": f"Successfully subscribed to {plan.name}",
            "data": {
                "subscription_id": subscription.id,
                "plan_name": plan.name,
                "billing_cycle": billing_cycle,
                "current_period_start": subscription.current_period_start.isoformat(),
                "current_period_end": subscription.current_period_end.isoformat(),
                "status": subscription.status.value,
                "limits": {
                    "ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
                    "gemini_calls": plan.gemini_calls_limit,
                    "openai_calls": plan.openai_calls_limit,
                    "anthropic_calls": plan.anthropic_calls_limit,
                    "mistral_calls": plan.mistral_calls_limit,
                    "tavily_calls": plan.tavily_calls_limit,
                    "serper_calls": plan.serper_calls_limit,
                    "metaphor_calls": plan.metaphor_calls_limit,
                    "firecrawl_calls": plan.firecrawl_calls_limit,
                    "stability_calls": plan.stability_calls_limit,
                    "video_calls": getattr(plan, 'video_calls_limit', 0),
                    "image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
                    "audio_calls": getattr(plan, 'audio_calls_limit', 0),
                    "monthly_cost": plan.monthly_cost_limit
                }
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error subscribing to plan: {e}")
        db.rollback()
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@router.get("/pricing")
async def get_api_pricing(
    provider: Optional[str] = Query(None, description="API provider"),
    db: Session = Depends(get_db)
) -> Dict[str, Any]:
    """Get API pricing information.

    Args:
        provider: Optional provider name (matched case-insensitively against
            ``APIProvider``); when omitted, pricing for all providers is
            returned.
        db: Request-scoped database session.

    Returns:
        ``{"success": True, "data": {"pricing": [...], "total": <int>}}``
        covering the active pricing rows.

    Raises:
        HTTPException: 400 for an unknown provider, 500 on unexpected errors.
    """
    try:
        query = db.query(APIProviderPricing).filter(
            APIProviderPricing.is_active == True
        )

        if provider:
            try:
                api_provider = APIProvider(provider.lower())
            except ValueError:
                raise HTTPException(status_code=400, detail=f"Invalid provider: {provider}")
            query = query.filter(APIProviderPricing.provider == api_provider)

        pricing_list = [
            {
                "provider": pricing.provider.value,
                "model_name": pricing.model_name,
                "cost_per_input_token": pricing.cost_per_input_token,
                "cost_per_output_token": pricing.cost_per_output_token,
                "cost_per_request": pricing.cost_per_request,
                "cost_per_search": pricing.cost_per_search,
                "cost_per_image": pricing.cost_per_image,
                "cost_per_page": pricing.cost_per_page,
                "description": pricing.description,
                # A row without an effective date should not fail the listing.
                "effective_date": pricing.effective_date.isoformat() if pricing.effective_date else None
            }
            for pricing in query.all()
        ]

        return {
            "success": True,
            "data": {
                "pricing": pricing_list,
                "total": len(pricing_list)
            }
        }

    except HTTPException:
        # Keep the 400 for an invalid provider; the generic handler below
        # would otherwise turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error getting API pricing: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@router.get("/alerts/{user_id}")
async def get_usage_alerts(
    user_id: str,
    unread_only: bool = Query(False, description="Only return unread alerts"),
    limit: int = Query(50, ge=1, le=100, description="Maximum number of alerts"),
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Get usage alerts for a user, newest first.

    Args:
        user_id: Owner of the alerts; must match the authenticated user.
        unread_only: When true, only alerts with ``is_read == False``.
        limit: Maximum number of alerts to return (validated to 1..100).
        db: Request-scoped database session.
        current_user: Authenticated user mapping; its ``id`` is compared
            against ``user_id``.

    Returns:
        ``{"success": True, "data": {"alerts", "total", "unread_count"}}``.
        ``total`` and ``unread_count`` describe the returned page only, not
        all alerts stored for the user.

    Raises:
        HTTPException: 403 when the caller requests another user's alerts,
            500 on unexpected errors.
    """
    # Verify user can only access their own data, as the other
    # /{user_id} endpoints in this module do.
    if current_user.get('id') != user_id:
        raise HTTPException(status_code=403, detail="Access denied")

    try:
        query = db.query(UsageAlert).filter(
            UsageAlert.user_id == user_id
        )

        if unread_only:
            query = query.filter(UsageAlert.is_read == False)

        alerts = query.order_by(
            UsageAlert.created_at.desc()
        ).limit(limit).all()

        alerts_data = [
            {
                "id": alert.id,
                "type": alert.alert_type,
                "threshold_percentage": alert.threshold_percentage,
                "provider": alert.provider.value if alert.provider else None,
                "title": alert.title,
                "message": alert.message,
                "severity": alert.severity,
                "is_sent": alert.is_sent,
                "sent_at": alert.sent_at.isoformat() if alert.sent_at else None,
                "is_read": alert.is_read,
                "read_at": alert.read_at.isoformat() if alert.read_at else None,
                "billing_period": alert.billing_period,
                "created_at": alert.created_at.isoformat()
            }
            for alert in alerts
        ]

        return {
            "success": True,
            "data": {
                "alerts": alerts_data,
                "total": len(alerts_data),
                "unread_count": sum(1 for a in alerts_data if not a["is_read"])
            }
        }

    except Exception as e:
        logger.error(f"Error getting usage alerts: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@router.post("/alerts/{alert_id}/mark-read")
async def mark_alert_read(
    alert_id: int,
    db: Session = Depends(get_db)
) -> Dict[str, Any]:
    """Mark an alert as read.

    Sets ``is_read`` to true and ``read_at`` to the current UTC time, then
    commits.

    Args:
        alert_id: Primary key of the alert, taken from the URL path.
        db: Database session injected by FastAPI.

    Returns:
        ``{"success": True, "message": "Alert marked as read"}``.

    Raises:
        HTTPException: 404 when no alert has this id; 500 on any other failure.
    """
    # NOTE(review): no ownership check is performed here -- any caller can mark
    # any alert as read. Confirm whether authentication should be required.
    try:
        alert = db.query(UsageAlert).filter(UsageAlert.id == alert_id).first()

        if not alert:
            raise HTTPException(status_code=404, detail="Alert not found")

        alert.is_read = True
        alert.read_at = datetime.utcnow()
        db.commit()

        return {
            "success": True,
            "message": "Alert marked as read"
        }

    except HTTPException:
        # Without this clause the 404 above is caught by the generic handler
        # below and reported to the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Error marking alert as read: {e}")
        # Discard the pending change so the session is not left holding a
        # failed transaction.
        db.rollback()
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
def _build_dashboard_payload(db: Session, user_id: str) -> Dict[str, Any]:
    """Query usage, trends, limits and unread alerts and assemble the dashboard response."""
    usage_service = UsageTrackingService(db)
    pricing_service = PricingService(db)

    current_usage = usage_service.get_user_usage_stats(user_id)
    # Usage trends for the last 6 months.
    trends = usage_service.get_usage_trends(user_id, 6)
    limits = pricing_service.get_user_limits(user_id)

    # Five most recent unread alerts.
    alerts = db.query(UsageAlert).filter(
        UsageAlert.user_id == user_id,
        UsageAlert.is_read == False  # noqa: E712 -- SQLAlchemy column expression
    ).order_by(UsageAlert.created_at.desc()).limit(5).all()

    alerts_data = [
        {
            "id": alert.id,
            "type": alert.alert_type,
            "title": alert.title,
            "message": alert.message,
            "severity": alert.severity,
            "created_at": alert.created_at.isoformat()
        }
        for alert in alerts
    ]

    # Linear projection: average daily cost so far, scaled to a 30-day period.
    current_cost = current_usage.get('total_cost', 0)
    days_in_period = 30
    current_day = datetime.now().day
    projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0

    plan_limits = limits.get('limits', {}) if limits else {}
    # `or 1` keeps a null monthly_cost from reaching max(), which would raise
    # TypeError; a missing or zero limit still falls back to 1 as before.
    cost_denominator = max(plan_limits.get('monthly_cost') or 1, 1)

    return {
        "success": True,
        "data": {
            "current_usage": current_usage,
            "trends": trends,
            "limits": limits,
            "alerts": alerts_data,
            "projections": {
                "projected_monthly_cost": round(projected_cost, 2),
                "cost_limit": plan_limits.get('monthly_cost', 0) if limits else 0,
                "projected_usage_percentage": (projected_cost / cost_denominator) * 100 if limits else 0
            },
            "summary": {
                "total_api_calls_this_month": current_usage.get('total_calls', 0),
                "total_cost_this_month": current_usage.get('total_cost', 0),
                "usage_status": current_usage.get('usage_status', 'active'),
                "unread_alerts": len(alerts_data)
            }
        }
    }


@router.get("/dashboard/{user_id}")
async def get_dashboard_data(
    user_id: str,
    db: Session = Depends(get_db)
) -> Dict[str, Any]:
    """Get comprehensive dashboard data for usage monitoring.

    Responses are cached in-process per ``user_id`` for
    ``_DASHBOARD_CACHE_TTL_SEC`` seconds. Setting the environment variable
    ``SUBSCRIPTION_DASHBOARD_NOCACHE`` to ``1``/``true``/``yes``/``on``
    bypasses the cache read (fresh results are still stored).

    If a query fails with a "no such column" error for one of the known usage
    columns, the schema helpers are re-run and the query is retried once.

    Args:
        user_id: User whose dashboard is requested, taken from the URL path.
        db: Database session injected by FastAPI.

    Returns:
        ``{"success": True, "data": {...}}`` with current usage, trends,
        limits, unread alerts, cost projections and a summary.

    Raises:
        HTTPException: 500 on failure, including a failed schema-repair retry.
    """
    # NOTE(review): this route does not verify that the caller is `user_id`
    # (no `get_current_user` dependency) -- confirm whether that is intentional.
    try:
        ensure_subscription_plan_columns(db)
        ensure_usage_summaries_columns(db)

        # Serve from the TTL cache to avoid hammering the DB on bursts.
        import os
        import time
        now = time.time()
        # The request object is not available here, so an env flag is the
        # only cache override.
        nocache = os.getenv('SUBSCRIPTION_DASHBOARD_NOCACHE', 'false').lower() in {'1', 'true', 'yes', 'on'}
        if not nocache and user_id in _dashboard_cache and (now - _dashboard_cache_ts.get(user_id, 0)) < _DASHBOARD_CACHE_TTL_SEC:
            return _dashboard_cache[user_id]

        response_payload = _build_dashboard_payload(db, user_id)
        _dashboard_cache[user_id] = response_payload
        _dashboard_cache_ts[user_id] = now
        return response_payload

    except HTTPException:
        # Keep deliberate HTTP errors from lower layers intact.
        raise
    except Exception as e:
        # sqlite3.OperationalError is a subclass of Exception, so a single
        # clause covers both cases the original tuple listed.
        error_str = str(e).lower()
        repairable_columns = (
            'exa_calls', 'exa_cost',
            'video_calls', 'video_cost',
            'image_edit_calls', 'image_edit_cost',
            'audio_calls', 'audio_cost',
        )
        if 'no such column' in error_str and any(col in error_str for col in repairable_columns):
            logger.warning("Missing column detected in dashboard query, attempting schema fix...")
            try:
                import services.subscription.schema_utils as schema_utils
                # Reset the memoised "already checked" flags so the ensure_*
                # helpers actually re-inspect the tables.
                schema_utils._checked_usage_summaries_columns = False
                schema_utils._checked_subscription_plan_columns = False
                ensure_usage_summaries_columns(db)
                ensure_subscription_plan_columns(db)
                db.expire_all()
                # Retry once; as before, the retried result is not cached.
                return _build_dashboard_payload(db, user_id)
            except Exception as retry_err:
                logger.error(f"Schema fix and retry failed: {retry_err}")
                raise HTTPException(status_code=500, detail=f"Database schema error: {str(e)}")

        logger.error(f"Error getting dashboard data: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@router.get("/renewal-history/{user_id}")
async def get_renewal_history(
    user_id: str,
    limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
    offset: int = Query(0, ge=0, description="Pagination offset"),
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: Session = Depends(get_db)
) -> Dict[str, Any]:
    """
    Get subscription renewal history for a user.

    Args:
        user_id: User whose history is requested; must equal the
            authenticated user's id.
        limit: Page size (1-100).
        offset: Number of records to skip.
        current_user: Authenticated user injected by FastAPI.
        db: Database session injected by FastAPI.

    Returns:
        - List of renewal history records (most recent first)
        - Total count for pagination, plus ``limit``, ``offset`` and ``has_more``

    Raises:
        HTTPException: 403 when ``user_id`` is not the caller's id; 500 on
            any other failure.
    """
    def _iso(value):
        # Serialise an optional date/datetime column.
        return value.isoformat() if value else None

    try:
        # Verify user can only access their own data
        if current_user.get('id') != user_id:
            raise HTTPException(status_code=403, detail="Access denied")

        # One base query feeds both the count and the page so the two can
        # never drift apart.
        base_query = db.query(SubscriptionRenewalHistory).filter(
            SubscriptionRenewalHistory.user_id == user_id
        )
        total_count = base_query.count()

        renewals = base_query.order_by(
            SubscriptionRenewalHistory.created_at.desc()
        ).offset(offset).limit(limit).all()

        renewal_history = [
            {
                'id': renewal.id,
                'plan_name': renewal.plan_name,
                'plan_tier': renewal.plan_tier,
                'previous_period_start': _iso(renewal.previous_period_start),
                'previous_period_end': _iso(renewal.previous_period_end),
                'new_period_start': _iso(renewal.new_period_start),
                'new_period_end': _iso(renewal.new_period_end),
                'billing_cycle': renewal.billing_cycle.value if renewal.billing_cycle else None,
                'renewal_type': renewal.renewal_type,
                'renewal_count': renewal.renewal_count,
                'previous_plan_name': renewal.previous_plan_name,
                'previous_plan_tier': renewal.previous_plan_tier,
                'usage_before_renewal': renewal.usage_before_renewal,
                'payment_amount': float(renewal.payment_amount) if renewal.payment_amount else 0.0,
                'payment_status': renewal.payment_status,
                'payment_date': _iso(renewal.payment_date),
                'created_at': _iso(renewal.created_at)
            }
            for renewal in renewals
        ]

        return {
            "success": True,
            "data": {
                "renewals": renewal_history,
                "total_count": total_count,
                "limit": limit,
                "offset": offset,
                "has_more": (offset + limit) < total_count
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        # loguru has no `exc_info` option: passing it as a kwarg attaches no
        # traceback and makes loguru str.format() the message, which can raise
        # when the error text contains braces. opt(exception=True) is the
        # supported way to log the active exception.
        logger.opt(exception=True).error(f"Error getting renewal history: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@router.get("/usage-logs")
async def get_usage_logs(
    limit: int = Query(50, ge=1, le=5000, description="Number of logs to return"),
    offset: int = Query(0, ge=0, description="Pagination offset"),
    provider: Optional[str] = Query(None, description="Filter by provider"),
    status_code: Optional[int] = Query(None, description="Filter by HTTP status code"),
    billing_period: Optional[str] = Query(None, description="Filter by billing period (YYYY-MM)"),
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: Session = Depends(get_db)
) -> Dict[str, Any]:
    """
    Get API usage logs for the current user.

    Query Params:
    - limit: Number of logs to return (1-5000, default: 50)
    - offset: Pagination offset (default: 0)
    - provider: Filter by provider (e.g., "gemini", "openai", "huggingface")
    - status_code: Filter by HTTP status code (e.g., 200 for success, 400+ for errors)
    - billing_period: Filter by billing period (YYYY-MM format)

    Returns:
    - List of usage logs with API call details
    - Total count for pagination

    An unknown ``provider`` value yields an empty result rather than an error.

    Raises:
        HTTPException: 401 when no user id is available; 500 on any other failure.
    """
    try:
        # Get user_id from current_user
        user_id = str(current_user.get('id', '')) if current_user else None

        if not user_id:
            raise HTTPException(status_code=401, detail="User not authenticated")

        # Resolve the provider filter once, before any database work.
        provider_enum = None
        if provider:
            provider_lower = provider.lower()
            # Handle special case: huggingface maps to MISTRAL enum in database
            if provider_lower == "huggingface":
                provider_enum = APIProvider.MISTRAL
            else:
                try:
                    provider_enum = APIProvider(provider_lower)
                except ValueError:
                    # Invalid provider, return empty results
                    return {
                        "logs": [],
                        "total_count": 0,
                        "limit": limit,
                        "offset": offset,
                        "has_more": False
                    }

        # Check and wrap logs if necessary. This runs before the query is
        # built, so the count and the page below both see the post-wrap state
        # and the filters only need to be applied once.
        wrapping_service = LogWrappingService(db)
        wrap_result = wrapping_service.check_and_wrap_logs(user_id)
        if wrap_result.get('wrapped'):
            logger.info(f"[UsageLogs] Log wrapping completed for user {user_id}: {wrap_result.get('message')}")

        query = db.query(APIUsageLog).filter(
            APIUsageLog.user_id == user_id
        )
        if provider_enum is not None:
            query = query.filter(APIUsageLog.provider == provider_enum)
        if status_code is not None:
            query = query.filter(APIUsageLog.status_code == status_code)
        if billing_period:
            query = query.filter(APIUsageLog.billing_period == billing_period)

        # Get total count
        total_count = query.count()

        # Get paginated results, ordered by timestamp descending (most recent first)
        logs = query.order_by(desc(APIUsageLog.timestamp)).offset(offset).limit(limit).all()

        # Format logs for response
        formatted_logs = []
        for log in logs:
            # A missing status code is reported as 'failed' instead of raising
            # TypeError on the range comparison.
            succeeded = log.status_code is not None and 200 <= log.status_code < 300
            status = 'success' if succeeded else 'failed'

            # HuggingFace calls are stored under the MISTRAL enum, so every
            # "mistral" row is displayed as "huggingface".
            provider_display = log.provider.value if log.provider else None
            if provider_display == "mistral":
                provider_display = "huggingface"

            formatted_logs.append({
                'id': log.id,
                'timestamp': log.timestamp.isoformat() if log.timestamp else None,
                'provider': provider_display,
                'model_used': log.model_used,
                'endpoint': log.endpoint,
                'method': log.method,
                'tokens_input': log.tokens_input or 0,
                'tokens_output': log.tokens_output or 0,
                'tokens_total': log.tokens_total or 0,
                'cost_input': float(log.cost_input) if log.cost_input else 0.0,
                'cost_output': float(log.cost_output) if log.cost_output else 0.0,
                'cost_total': float(log.cost_total) if log.cost_total else 0.0,
                'response_time': float(log.response_time) if log.response_time else 0.0,
                'status_code': log.status_code,
                'status': status,
                'error_message': log.error_message,
                'billing_period': log.billing_period,
                'retry_count': log.retry_count or 0,
                'is_aggregated': log.endpoint == "[AGGREGATED]"  # Flag to indicate aggregated log
            })

        return {
            "logs": formatted_logs,
            "total_count": total_count,
            "limit": limit,
            "offset": offset,
            "has_more": (offset + limit) < total_count
        }

    except HTTPException:
        raise
    except Exception as e:
        # loguru has no `exc_info` option: passing it as a kwarg attaches no
        # traceback and makes loguru str.format() the message, which can raise
        # when the error text contains braces. opt(exception=True) is the
        # supported way to log the active exception.
        logger.opt(exception=True).error(f"Error getting usage logs: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
|
|
|
|
|
|
class PreflightOperationRequest(BaseModel):
    """A single operation submitted to the pre-flight check endpoint."""

    # Provider name as sent by the client; the endpoint lower-cases it and
    # maps it to an APIProvider member ("huggingface" maps to MISTRAL).
    provider: str
    # Model name used for the pricing lookup; no lookup happens when omitted.
    model: Optional[str] = None
    # Requested token count (character count for audio); None is treated as 0.
    tokens_requested: Optional[int] = 0
    # Caller-defined label echoed back in the per-operation result.
    operation_type: str
    # Display name reported in the result; falls back to `provider` when unset.
    actual_provider_name: Optional[str] = None
|
|
|
|
|
|
class PreflightCheckRequest(BaseModel):
    """Request body for the pre-flight check endpoint."""

    # Operations to validate and price together in one call.
    operations: List[PreflightOperationRequest]
|
|
|
|
|
|
@router.post("/preflight-check")
async def preflight_check(
    request: PreflightCheckRequest,
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Pre-flight check for operations with cost estimation.

    Lightweight endpoint that:
    - Validates if operations are allowed based on subscription limits
    - Estimates cost for operations
    - Returns usage information and remaining quota

    Operations whose provider cannot be mapped to ``APIProvider`` are skipped
    with a warning. ``remaining`` is ``None`` when the corresponding limit is
    0 (treated as unlimited). ``cached`` is always ``False``: this handler
    performs no caching itself.

    Args:
        request: Operations to validate.
        db: Database session injected by FastAPI.
        current_user: Authenticated user injected by FastAPI.

    Returns:
        ``{"success": True, "data": {...}}`` with ``can_proceed``,
        ``estimated_cost``/``total_cost``, per-operation results and a
        video-call usage summary.

    Raises:
        HTTPException: 401 when the token carries no user id; 400 when no
            operation has a recognised provider; 500 on any other failure.
    """
    try:
        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        # Ensure schema columns exist (best effort; a failure is only logged).
        try:
            ensure_subscription_plan_columns(db)
            ensure_usage_summaries_columns(db)
        except Exception as schema_err:
            logger.warning(f"Schema check failed: {schema_err}")

        pricing_service = PricingService(db)

        # Provider strings with an explicit enum mapping; anything else is
        # looked up by enum value.
        provider_aliases = {
            "huggingface": APIProvider.MISTRAL,  # HuggingFace is stored as MISTRAL
            "video": APIProvider.VIDEO,
            "image_edit": APIProvider.IMAGE_EDIT,
            "stability": APIProvider.STABILITY,
            "audio": APIProvider.AUDIO,
        }
        # Providers limited by call count, with their usage/limit key.
        call_count_keys = {
            APIProvider.VIDEO: 'video_calls',
            APIProvider.IMAGE_EDIT: 'image_edit_calls',
            APIProvider.STABILITY: 'stability_calls',
            APIProvider.AUDIO: 'audio_calls',
        }
        # Flat per-request fallback costs used when no pricing row is found.
        default_request_costs = {
            APIProvider.VIDEO: 0.10,       # Default video cost
            APIProvider.IMAGE_EDIT: 0.05,  # Default image edit cost
            APIProvider.STABILITY: 0.04,   # Default image generation cost
        }

        # Convert request operations to internal format. The model of each
        # accepted operation is kept in a parallel list: indexing
        # request.operations by position in the filtered list picks the wrong
        # model as soon as one operation has been skipped.
        operations_to_validate = []
        validated_models: List[Optional[str]] = []
        for op in request.operations:
            try:
                provider_str = op.provider.lower()
                provider_enum = provider_aliases.get(provider_str)
                if provider_enum is None:
                    try:
                        provider_enum = APIProvider(provider_str)
                    except ValueError:
                        logger.warning(f"Unknown provider: {provider_str}, skipping")
                        continue

                operations_to_validate.append({
                    'provider': provider_enum,
                    'tokens_requested': op.tokens_requested or 0,
                    'actual_provider_name': op.actual_provider_name or op.provider,
                    'operation_type': op.operation_type
                })
                validated_models.append(op.model)
            except Exception as e:
                logger.warning(f"Error processing operation {op.operation_type}: {e}")
                continue

        if not operations_to_validate:
            raise HTTPException(status_code=400, detail="No valid operations provided")

        # Perform pre-flight validation
        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        # Limits and the current usage summary do not change per operation,
        # so fetch them once instead of on every loop iteration.
        limits = pricing_service.get_user_limits(user_id)
        usage_summary = None
        if limits:
            usage_summary = db.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
            ).first()

        # Get pricing and cost estimation for each operation
        operation_results = []
        total_cost = 0.0

        for op, model_name in zip(operations_to_validate, validated_models):
            provider_enum = op['provider']
            tokens_requested = op['tokens_requested']

            op_result = {
                'provider': op['actual_provider_name'],
                'operation_type': op['operation_type'],
                'cost': 0.0,
                'allowed': can_proceed,
                'limit_info': None,
                'message': None
            }

            # Pricing is only looked up when the client named a model.
            if model_name:
                pricing_info = pricing_service.get_pricing_for_provider_model(
                    provider_enum,
                    model_name
                )

                cost = None
                if pricing_info:
                    # Determine cost based on operation type
                    if provider_enum in (APIProvider.VIDEO, APIProvider.IMAGE_EDIT, APIProvider.STABILITY):
                        cost = pricing_info.get('cost_per_request', 0.0) or pricing_info.get('cost_per_image', 0.0) or 0.0
                    elif provider_enum == APIProvider.AUDIO:
                        # Audio pricing is per character (every character is 1 token)
                        cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (tokens_requested / 1000.0)
                    elif tokens_requested > 0:
                        # Token-based cost estimation (rough estimate)
                        cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (tokens_requested / 1000)
                    else:
                        cost = pricing_info.get('cost_per_request', 0.0) or 0.0
                elif provider_enum == APIProvider.AUDIO:
                    # Default audio cost: $0.05 per 1,000 characters
                    cost = (tokens_requested / 1000.0) * 0.05
                else:
                    # Flat default for media providers; None (no estimate) otherwise.
                    cost = default_request_costs.get(provider_enum)

                if cost is not None:
                    op_result['cost'] = round(cost, 4)
                    total_cost += cost

            # Get limit information
            if error_details and not can_proceed:
                usage_info = error_details.get('usage_info', {})
                if usage_info:
                    blocked_usage = usage_info.get('current_usage', 0) or 0
                    blocked_limit = usage_info.get('limit', 0) or 0
                    op_result['message'] = message
                    op_result['limit_info'] = {
                        'current_usage': blocked_usage,
                        'limit': blocked_limit,
                        'remaining': max(0, blocked_limit - blocked_usage)
                    }
            elif limits and usage_summary:
                # Call-count key for media providers, token key for LLM providers.
                usage_key = call_count_keys.get(provider_enum)
                if usage_key is None:
                    usage_key = f"{provider_enum.value}_tokens"

                current = getattr(usage_summary, usage_key, 0) or 0
                limit = limits['limits'].get(usage_key, 0) or 0

                op_result['limit_info'] = {
                    'current_usage': current,
                    'limit': limit,
                    # None instead of float('inf'): Starlette's JSONResponse
                    # serialises with allow_nan=False and rejects infinity.
                    'remaining': max(0, limit - current) if limit > 0 else None
                }

            operation_results.append(op_result)

        response_data = {
            'can_proceed': can_proceed,
            'estimated_cost': round(total_cost, 4),
            'operations': operation_results,
            'total_cost': round(total_cost, 4),
            'usage_summary': None,
            'cached': False  # TODO: Track if result was cached
        }

        if usage_summary and limits:
            # The overall summary always reports video limits, whatever the
            # operations were.
            video_current = getattr(usage_summary, 'video_calls', 0) or 0
            video_limit = limits['limits'].get('video_calls', 0) or 0

            response_data['usage_summary'] = {
                'current_calls': video_current,
                'limit': video_limit,
                'remaining': max(0, video_limit - video_current) if video_limit > 0 else None
            }

        return {
            "success": True,
            "data": response_data
        }

    except HTTPException:
        raise
    except Exception as e:
        # loguru has no `exc_info` option: passing it as a kwarg attaches no
        # traceback and makes loguru str.format() the message, which can raise
        # when the error text contains braces. opt(exception=True) is the
        # supported way to log the active exception.
        logger.opt(exception=True).error(f"Error in pre-flight check: {e}")
        raise HTTPException(status_code=500, detail=f"Pre-flight check failed: {str(e)}")