Base code

This commit is contained in:
Kunthawat Greethong
2026-01-08 22:39:53 +07:00
parent 697115c61a
commit c35fa52117
2169 changed files with 626670 additions and 0 deletions


@@ -0,0 +1,185 @@
# Subscription Services Package
## Overview
This package consolidates all subscription, billing, and usage-tracking services and middleware into a single, well-organized module. It follows the same architectural pattern as the onboarding package for consistency and maintainability.
## Package Structure
```
backend/services/subscription/
├── __init__.py # Package exports
├── pricing_service.py # API pricing and cost calculations
├── usage_tracking_service.py # Usage tracking and limits
├── exception_handler.py # Exception handling
├── monitoring_middleware.py # API monitoring with usage tracking
└── README.md # This documentation
```
## Services
### PricingService
- **File**: `pricing_service.py`
- **Purpose**: Manages API pricing, cost calculation, and subscription limits
- **Key Features**:
- Dynamic pricing based on API provider and model
- Cost calculation for input/output tokens
- Subscription limit enforcement
- Billing period management
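
The token-based cost calculation described above can be sketched roughly as follows. The rate table, model names, and prices are illustrative placeholders, not the package's actual pricing (which lives in `APIProviderPricing` rows):

```python
# Illustrative per-million-token rates; NOT the package's real pricing data.
PRICING = {
    ("openai", "gpt-4o"): {"input": 2.50, "output": 10.00},      # USD per 1M tokens
    ("gemini", "gemini-1.5-pro"): {"input": 1.25, "output": 5.00},
}

def calculate_cost(provider: str, model: str, input_tokens: int, output_tokens: int) -> float:
    """Return the request cost in USD for the given token counts."""
    rates = PRICING[(provider, model)]
    return (input_tokens * rates["input"] + output_tokens * rates["output"]) / 1_000_000

cost = calculate_cost("openai", "gpt-4o", input_tokens=1_000, output_tokens=500)
```

In the real service the rates come from the database per provider/model, so the lookup step is a query rather than a dict access, but the arithmetic is the same.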
### UsageTrackingService
- **File**: `usage_tracking_service.py`
- **Purpose**: Comprehensive tracking of API usage, costs, and subscription limits
- **Key Features**:
- Real-time usage tracking
- Cost calculation and billing
- Usage limit enforcement with TTL caching
- Usage alerts and notifications
### SubscriptionExceptionHandler
- **File**: `exception_handler.py`
- **Purpose**: Centralized exception handling for subscription-related errors
- **Key Features**:
- Custom exception types
- Error handling decorators
- Consistent error responses
### Monitoring Middleware
- **File**: `monitoring_middleware.py`
- **Purpose**: FastAPI middleware for API monitoring and usage tracking
- **Key Features**:
- Request/response monitoring
- Usage tracking integration
- Performance metrics
- Database API monitoring
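
The request/response monitoring described above takes the shape FastAPI expects for an `app.middleware("http")` callable. This is a minimal hedged sketch; the header name and metric handling are illustrative, not the package's actual implementation:

```python
import time

# Sketch of an http-middleware callable; the real monitoring_middleware also
# feeds usage tracking and database monitoring, which is omitted here.
async def monitoring_middleware(request, call_next):
    start = time.perf_counter()
    response = await call_next(request)          # run the rest of the app
    elapsed_ms = (time.perf_counter() - start) * 1000
    # Illustrative metric: expose processing time as a response header
    response.headers["X-Process-Time-Ms"] = f"{elapsed_ms:.1f}"
    return response
```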
## Usage
### Import Pattern
Always use the consolidated package for subscription-related imports:
```python
# ✅ Correct - Use consolidated package
from services.subscription import PricingService, UsageTrackingService
from services.subscription import SubscriptionExceptionHandler
from services.subscription import check_usage_limits_middleware
# ❌ Incorrect - Old scattered imports
from services.pricing_service import PricingService
from services.usage_tracking_service import UsageTrackingService
from middleware.monitoring_middleware import check_usage_limits_middleware
```
### Service Initialization
```python
from services.subscription import PricingService, UsageTrackingService
from services.database import get_db
# Get database session
db = next(get_db())
# Initialize services
pricing_service = PricingService(db)
usage_service = UsageTrackingService(db)
```
### Middleware Registration
```python
from services.subscription import monitoring_middleware
# Register middleware in FastAPI app
app.middleware("http")(monitoring_middleware)
```
## Database Models
The subscription services use the following database models (defined in `backend/models/subscription_models.py`):
- `APIProvider` - API provider enumeration
- `SubscriptionPlan` - Subscription plan definitions
- `UserSubscription` - User subscription records
- `UsageSummary` - Usage summary by billing period
- `APIUsageLog` - Individual API usage logs
- `APIProviderPricing` - Pricing configuration
- `UsageAlert` - Usage limit alerts
- `SubscriptionTier` - Subscription tier definitions
- `BillingCycle` - Billing cycle enumeration
- `UsageStatus` - Usage status enumeration
## Key Features
### 1. Database-Only Persistence
- All data stored in database tables
- No file-based storage
- User-isolated data access
### 2. TTL Caching
- In-memory caching for performance
- 30-second TTL for usage limit checks
- 10-minute TTL for dashboard data
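
The TTL pattern above amounts to a dict of `{result, expires_at}` entries keyed by user and provider, as seen in the limit-check code later in this commit. A minimal stand-alone sketch (function names are illustrative):

```python
from datetime import datetime, timedelta

# In the real service this dict lives on PricingService as _limits_cache.
_limits_cache: dict = {}

def get_cached(key: str):
    """Return the cached result if present and not expired, else None."""
    entry = _limits_cache.get(key)
    if entry and entry["expires_at"] > datetime.utcnow():
        return entry["result"]
    return None

def set_cached(key: str, result, ttl_seconds: int = 30):
    """Store a result with a short expiry (30 s matches the limit-check TTL)."""
    _limits_cache[key] = {
        "result": result,
        "expires_at": datetime.utcnow() + timedelta(seconds=ttl_seconds),
    }

set_cached("user-1:openai", (True, "Within limits", {}))
hit = get_cached("user-1:openai")
```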
### 3. Real-time Monitoring
- Live API usage tracking
- Performance metrics collection
- Error rate monitoring
### 4. Flexible Pricing
- Per-provider pricing configuration
- Model-specific pricing
- Dynamic cost calculation
## Error Handling
The package provides comprehensive error handling:
```python
from services.subscription import (
SubscriptionException,
UsageLimitExceededException,
PricingException,
TrackingException
)
try:
# Subscription operation
pass
except UsageLimitExceededException as e:
# Handle usage limit exceeded
pass
except PricingException as e:
# Handle pricing error
pass
```
## Configuration
The services use environment variables for configuration:
- `SUBSCRIPTION_DASHBOARD_NOCACHE` - Bypass dashboard cache
- `ENABLE_ALPHA` - Enable alpha features (default: false)
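
Reading these flags can be done with a small helper like the one below; the helper name and the set of accepted truthy values are assumptions for illustration, not the package's actual code:

```python
import os

def env_flag(name: str, default: bool = False) -> bool:
    """Interpret an environment variable as a boolean flag."""
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "on"}

dashboard_nocache = env_flag("SUBSCRIPTION_DASHBOARD_NOCACHE")
enable_alpha = env_flag("ENABLE_ALPHA", default=False)
```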
## Migration from Old Structure
This package consolidates the following previously scattered files:
- `backend/services/pricing_service.py` → `subscription/pricing_service.py`
- `backend/services/usage_tracking_service.py` → `subscription/usage_tracking_service.py`
- `backend/services/subscription_exception_handler.py` → `subscription/exception_handler.py`
- `backend/middleware/monitoring_middleware.py` → `subscription/monitoring_middleware.py`
## Benefits
1. **Single Package**: All subscription logic in one location
2. **Clear Ownership**: Easy to find subscription-related code
3. **Better Organization**: Follows same pattern as onboarding
4. **Easier Maintenance**: Single source of truth for billing logic
5. **Consistent Architecture**: Matches onboarding consolidation
## Related Packages
- `services.onboarding` - Onboarding and user setup
- `models.subscription_models` - Database models
- `api.subscription_api` - API endpoints


@@ -0,0 +1,40 @@
# Subscription Services Package
# Consolidated subscription-related services and middleware
from .pricing_service import PricingService
from .usage_tracking_service import UsageTrackingService
from .exception_handler import (
SubscriptionException,
SubscriptionExceptionHandler,
UsageLimitExceededException,
PricingException,
TrackingException,
handle_usage_limit_error,
handle_pricing_error,
handle_tracking_error,
)
from .monitoring_middleware import (
DatabaseAPIMonitor,
check_usage_limits_middleware,
monitoring_middleware,
get_monitoring_stats,
get_lightweight_stats,
)
__all__ = [
"PricingService",
"UsageTrackingService",
"SubscriptionException",
"SubscriptionExceptionHandler",
"UsageLimitExceededException",
"PricingException",
"TrackingException",
"handle_usage_limit_error",
"handle_pricing_error",
"handle_tracking_error",
"DatabaseAPIMonitor",
"check_usage_limits_middleware",
"monitoring_middleware",
"get_monitoring_stats",
"get_lightweight_stats",
]


@@ -0,0 +1,412 @@
"""
Comprehensive Exception Handling and Logging for Subscription System
Provides robust error handling, logging, and monitoring for the usage-based subscription system.
"""
import traceback
import json
from datetime import datetime
from typing import Dict, Any, Optional, Union, List
from enum import Enum
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from models.subscription_models import APIProvider, UsageAlert
class SubscriptionErrorType(Enum):
USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
PRICING_ERROR = "pricing_error"
TRACKING_ERROR = "tracking_error"
DATABASE_ERROR = "database_error"
API_PROVIDER_ERROR = "api_provider_error"
AUTHENTICATION_ERROR = "authentication_error"
BILLING_ERROR = "billing_error"
CONFIGURATION_ERROR = "configuration_error"
class SubscriptionErrorSeverity(Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class SubscriptionException(Exception):
"""Base exception for subscription system errors."""
def __init__(
self,
message: str,
error_type: SubscriptionErrorType,
severity: SubscriptionErrorSeverity = SubscriptionErrorSeverity.MEDIUM,
user_id: str = None,
provider: APIProvider = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
self.message = message
self.error_type = error_type
self.severity = severity
self.user_id = user_id
self.provider = provider
self.context = context or {}
self.original_error = original_error
self.timestamp = datetime.utcnow()
super().__init__(message)
def to_dict(self) -> Dict[str, Any]:
"""Convert exception to dictionary for logging/storage."""
return {
"message": self.message,
"error_type": self.error_type.value,
"severity": self.severity.value,
"user_id": self.user_id,
"provider": self.provider.value if self.provider else None,
"context": self.context,
"timestamp": self.timestamp.isoformat(),
"original_error": str(self.original_error) if self.original_error else None,
"traceback": traceback.format_exc() if self.original_error else None
}
class UsageLimitExceededException(SubscriptionException):
"""Exception raised when usage limits are exceeded."""
def __init__(
self,
message: str,
user_id: str,
provider: APIProvider,
limit_type: str,
current_usage: Union[int, float],
limit_value: Union[int, float],
context: Dict[str, Any] = None
):
context = context or {}
context.update({
"limit_type": limit_type,
"current_usage": current_usage,
"limit_value": limit_value,
"usage_percentage": (current_usage / max(limit_value, 1)) * 100
})
super().__init__(
message=message,
error_type=SubscriptionErrorType.USAGE_LIMIT_EXCEEDED,
severity=SubscriptionErrorSeverity.HIGH,
user_id=user_id,
provider=provider,
context=context
)
class PricingException(SubscriptionException):
"""Exception raised for pricing calculation errors."""
def __init__(
self,
message: str,
provider: APIProvider = None,
model_name: str = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
context = context or {}
if model_name:
context["model_name"] = model_name
super().__init__(
message=message,
error_type=SubscriptionErrorType.PRICING_ERROR,
severity=SubscriptionErrorSeverity.MEDIUM,
provider=provider,
context=context,
original_error=original_error
)
class TrackingException(SubscriptionException):
"""Exception raised for usage tracking errors."""
def __init__(
self,
message: str,
user_id: str = None,
provider: APIProvider = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
super().__init__(
message=message,
error_type=SubscriptionErrorType.TRACKING_ERROR,
severity=SubscriptionErrorSeverity.MEDIUM,
user_id=user_id,
provider=provider,
context=context,
original_error=original_error
)
class SubscriptionExceptionHandler:
"""Comprehensive exception handler for the subscription system."""
def __init__(self, db: Session = None):
self.db = db
self._setup_logging()
    def _setup_logging(self):
        """Set up structured logging for subscription errors."""
        from utils.logger_utils import get_service_logger
        # Store the logger instead of discarding the return value
        self.logger = get_service_logger("subscription_exception_handler")
def handle_exception(
self,
error: Union[Exception, SubscriptionException],
context: Dict[str, Any] = None,
log_level: str = "error"
) -> Dict[str, Any]:
"""Handle and log subscription system exceptions."""
context = context or {}
# Convert regular exceptions to SubscriptionException
if not isinstance(error, SubscriptionException):
error = SubscriptionException(
message=str(error),
error_type=self._classify_error(error),
severity=self._determine_severity(error),
context=context,
original_error=error
)
# Log the error
error_data = error.to_dict()
error_data.update(context)
log_message = f"Subscription Error: {error.message}"
if log_level == "critical":
logger.critical(log_message, extra={"error_data": error_data})
elif log_level == "error":
logger.error(log_message, extra={"error_data": error_data})
elif log_level == "warning":
logger.warning(log_message, extra={"error_data": error_data})
else:
logger.info(log_message, extra={"error_data": error_data})
# Store critical errors in database for alerting
if error.severity in [SubscriptionErrorSeverity.HIGH, SubscriptionErrorSeverity.CRITICAL]:
self._store_error_alert(error)
# Return formatted error response
return self._format_error_response(error)
def _classify_error(self, error: Exception) -> SubscriptionErrorType:
"""Classify an exception into a subscription error type."""
error_str = str(error).lower()
error_type_name = type(error).__name__.lower()
if "limit" in error_str or "exceeded" in error_str:
return SubscriptionErrorType.USAGE_LIMIT_EXCEEDED
elif "pricing" in error_str or "cost" in error_str:
return SubscriptionErrorType.PRICING_ERROR
elif "tracking" in error_str or "usage" in error_str:
return SubscriptionErrorType.TRACKING_ERROR
elif "database" in error_str or "sql" in error_type_name:
return SubscriptionErrorType.DATABASE_ERROR
elif "api" in error_str or "provider" in error_str:
return SubscriptionErrorType.API_PROVIDER_ERROR
elif "auth" in error_str or "permission" in error_str:
return SubscriptionErrorType.AUTHENTICATION_ERROR
elif "billing" in error_str or "payment" in error_str:
return SubscriptionErrorType.BILLING_ERROR
else:
return SubscriptionErrorType.CONFIGURATION_ERROR
def _determine_severity(self, error: Exception) -> SubscriptionErrorSeverity:
"""Determine the severity of an error."""
error_str = str(error).lower()
error_type = type(error)
# Critical errors
if isinstance(error, (SQLAlchemyError, ConnectionError)):
return SubscriptionErrorSeverity.CRITICAL
# High severity errors
if "limit exceeded" in error_str or "unauthorized" in error_str:
return SubscriptionErrorSeverity.HIGH
# Medium severity errors
if "pricing" in error_str or "tracking" in error_str:
return SubscriptionErrorSeverity.MEDIUM
# Default to low
return SubscriptionErrorSeverity.LOW
def _store_error_alert(self, error: SubscriptionException):
"""Store critical errors as alerts in the database."""
if not self.db or not error.user_id:
return
try:
alert = UsageAlert(
user_id=error.user_id,
alert_type="system_error",
threshold_percentage=0,
provider=error.provider,
title=f"System Error: {error.error_type.value}",
message=error.message,
severity=error.severity.value,
billing_period=datetime.now().strftime("%Y-%m")
)
self.db.add(alert)
self.db.commit()
except Exception as e:
logger.error(f"Failed to store error alert: {e}")
def _format_error_response(self, error: SubscriptionException) -> Dict[str, Any]:
"""Format error for API response."""
response = {
"success": False,
"error": {
"type": error.error_type.value,
"message": error.message,
"severity": error.severity.value,
"timestamp": error.timestamp.isoformat()
}
}
# Add context for debugging (non-sensitive info only)
if error.context:
safe_context = {
k: v for k, v in error.context.items()
if k not in ["password", "token", "key", "secret"]
}
response["error"]["context"] = safe_context
# Add user-friendly message based on error type
user_messages = {
SubscriptionErrorType.USAGE_LIMIT_EXCEEDED:
"You have reached your usage limit. Please upgrade your plan or wait for the next billing cycle.",
SubscriptionErrorType.PRICING_ERROR:
"There was an issue calculating the cost for this request. Please try again.",
SubscriptionErrorType.TRACKING_ERROR:
"Unable to track usage for this request. Please contact support if this persists.",
SubscriptionErrorType.DATABASE_ERROR:
"A database error occurred. Please try again later.",
SubscriptionErrorType.API_PROVIDER_ERROR:
"There was an issue with the API provider. Please try again.",
SubscriptionErrorType.AUTHENTICATION_ERROR:
"Authentication failed. Please check your credentials.",
SubscriptionErrorType.BILLING_ERROR:
"There was a billing-related error. Please contact support.",
SubscriptionErrorType.CONFIGURATION_ERROR:
"System configuration error. Please contact support."
}
response["error"]["user_message"] = user_messages.get(
error.error_type,
"An unexpected error occurred. Please try again or contact support."
)
return response
# Utility functions for common error scenarios
def handle_usage_limit_error(
user_id: str,
provider: APIProvider,
limit_type: str,
current_usage: Union[int, float],
limit_value: Union[int, float],
db: Session = None
) -> Dict[str, Any]:
"""Handle usage limit exceeded errors."""
handler = SubscriptionExceptionHandler(db)
error = UsageLimitExceededException(
message=f"Usage limit exceeded for {limit_type}",
user_id=user_id,
provider=provider,
limit_type=limit_type,
current_usage=current_usage,
limit_value=limit_value
)
return handler.handle_exception(error, log_level="warning")
def handle_pricing_error(
message: str,
provider: APIProvider = None,
model_name: str = None,
original_error: Exception = None,
db: Session = None
) -> Dict[str, Any]:
"""Handle pricing calculation errors."""
handler = SubscriptionExceptionHandler(db)
error = PricingException(
message=message,
provider=provider,
model_name=model_name,
original_error=original_error
)
return handler.handle_exception(error)
def handle_tracking_error(
message: str,
user_id: str = None,
provider: APIProvider = None,
original_error: Exception = None,
db: Session = None
) -> Dict[str, Any]:
"""Handle usage tracking errors."""
handler = SubscriptionExceptionHandler(db)
error = TrackingException(
message=message,
user_id=user_id,
provider=provider,
original_error=original_error
)
return handler.handle_exception(error)
def log_usage_event(
user_id: str,
provider: APIProvider,
action: str,
details: Dict[str, Any] = None
):
"""Log usage events for monitoring and debugging."""
details = details or {}
log_data = {
"user_id": user_id,
"provider": provider.value,
"action": action,
"timestamp": datetime.utcnow().isoformat(),
**details
}
logger.info(f"Usage Tracking: {action}", extra={"usage_data": log_data})
# Decorator for automatic exception handling
def handle_subscription_errors(db: Session = None):
"""Decorator to automatically handle subscription-related exceptions."""
    def decorator(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # SubscriptionException subclasses Exception, so one handler
                # covers both; handle_exception normalizes plain exceptions
                handler = SubscriptionExceptionHandler(db)
                return handler.handle_exception(e)
        return wrapper
    return decorator


@@ -0,0 +1,799 @@
"""
Limit Validation Module
Handles subscription limit checking and validation logic.
Extracted from pricing_service.py for better modularity.
"""
from typing import Dict, Any, Optional, List, Tuple, TYPE_CHECKING
from datetime import datetime, timedelta
from sqlalchemy import text
from loguru import logger
from models.subscription_models import (
UserSubscription, UsageSummary, SubscriptionPlan,
APIProvider, SubscriptionTier
)
if TYPE_CHECKING:
from .pricing_service import PricingService
class LimitValidator:
"""Validates subscription limits for API usage."""
def __init__(self, pricing_service: 'PricingService'):
"""
Initialize limit validator with reference to PricingService.
Args:
pricing_service: Instance of PricingService to access helper methods and cache
"""
self.pricing_service = pricing_service
self.db = pricing_service.db
def check_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits.
Args:
user_id: User ID
provider: APIProvider enum (may be MISTRAL for HuggingFace)
tokens_requested: Estimated tokens for the request
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
Returns:
(can_proceed, error_message, usage_info)
"""
try:
# Use actual_provider_name if provided, otherwise use enum value
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
display_provider_name = actual_provider_name or provider.value
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
# Short TTL cache to reduce DB reads under sustained traffic
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self.pricing_service._limits_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
return tuple(cached['result']) # type: ignore
# Get user subscription first to check expiration
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if subscription:
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
else:
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
# Check subscription expiration (STRICT: deny if expired)
if subscription:
if subscription.current_period_end < now:
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
# Subscription expired - check if auto_renew is enabled
if not getattr(subscription, 'auto_renew', False):
# Expired and no auto-renew - deny access
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
'expired': True,
'period_end': subscription.current_period_end.isoformat()
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
# Try to auto-renew
if not self.pricing_service._ensure_subscription_current(subscription):
# Auto-renew failed - deny access
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
'expired': True,
'auto_renew_failed': True
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Get user limits with error handling (STRICT: fail on errors)
# CRITICAL: Expire SQLAlchemy objects to ensure we get fresh plan data after renewal
try:
# Force expire subscription and plan objects to avoid stale cache
if subscription and subscription.plan_id:
plan_obj = self.db.query(SubscriptionPlan).filter(SubscriptionPlan.id == subscription.plan_id).first()
if plan_obj:
self.db.expire(plan_obj)
logger.debug(f"[Subscription Check] Expired plan object to ensure fresh limits after renewal")
limits = self.pricing_service.get_user_limits(user_id)
if limits:
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
# Log token limits for debugging
token_limits = limits.get('limits', {})
logger.debug(f"[Subscription Check] Token limits: gemini={token_limits.get('gemini_tokens')}, mistral={token_limits.get('mistral_tokens')}, openai={token_limits.get('openai_tokens')}, anthropic={token_limits.get('anthropic_tokens')}")
else:
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
except Exception as e:
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
# STRICT: Fail closed - deny request if we can't check limits
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
if not limits:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
limits = self.pricing_service._plan_to_limits_dict(free_plan)
else:
# No subscription and no free tier - deny access
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
return False, "No subscription plan found. Please subscribe to a plan.", {}
# Get current usage for this billing period with error handling
# CRITICAL: Use fresh queries to avoid SQLAlchemy cache after renewal
try:
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Expire all objects to force fresh read from DB (critical after renewal)
self.db.expire_all()
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
usage = None
try:
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
# Map result to UsageSummary object
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
except Exception as sql_error:
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
# Fallback to ORM query
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to create usage summary: {str(create_error)}", {}
except Exception as e:
logger.error(f"Error getting usage summary for {user_id}: {e}")
self.db.rollback()
# STRICT: Fail closed on DB error
return False, f"Failed to retrieve usage summary: {str(e)}", {}
# Check call limits with error handling
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
try:
# Use display_provider_name for error messages, but provider.value for DB queries
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
# For LLM text generation providers, check against unified total_calls limit
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
if is_llm_provider:
# Use unified AI text generation limit (total_calls across all LLM providers)
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
if ai_text_gen_limit == 0:
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
current_total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
'provider': display_provider_name, # Use display name for consistency
'usage_info': {
'provider': display_provider_name, # Use display name for user-facing info
'current_calls': current_total_llm_calls,
'limit': ai_text_gen_limit,
'type': 'ai_text_generation',
'breakdown': {
'gemini': usage.gemini_calls or 0,
'openai': usage.openai_calls or 0,
'anthropic': usage.anthropic_calls or 0,
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
}
}
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
else:
# For non-LLM providers, check provider-specific limit
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if call_limit > 0 and current_calls >= call_limit:
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0,
'provider': display_provider_name # Use display name for consistency
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
else:
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
except Exception as e:
logger.error(f"Error checking call limits: {e}")
# Continue to next check
# Check token limits for LLM providers with error handling
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
try:
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
'provider': display_provider_name, # Use display name in error details
'usage_info': {
'provider': display_provider_name,
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'type': 'tokens'
}
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking token limits: {e}")
# Continue to next check
# Check cost limits with error handling
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
try:
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if cost_limit > 0 and usage.total_cost >= cost_limit:
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error checking cost limits: {e}")
# Continue to success case
# Calculate usage percentages for warnings
try:
# Determine which call variables to use based on provider type
if is_llm_provider:
# Use unified LLM call tracking
current_call_count = current_total_llm_calls
call_limit_value = ai_text_gen_limit
else:
# Use provider-specific call tracking
current_call_count = current_calls
call_limit_value = call_limit
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
result = (True, "Within limits", {
'current_calls': current_call_count,
'call_limit': call_limit_value,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
except Exception as e:
logger.error(f"Error calculating usage percentages: {e}")
# Return basic success
return True, "Within limits", {}
except Exception as e:
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
# STRICT: Fail closed - deny requests if subscription system fails
return False, f"Subscription check error: {str(e)}", {}
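The repeated cache-write pattern in `check_usage_limits` (store a result with a 30-second `expires_at`, serve it until then) is a simple TTL memo. A minimal standalone sketch of that pattern, with illustrative names rather than the service's actual cache API:

```python
from datetime import datetime, timedelta
from typing import Any, Dict, Optional


class TTLCache:
    """Minimal TTL cache mirroring the 30-second limits cache used above."""

    def __init__(self, ttl_seconds: int = 30):
        self.ttl = timedelta(seconds=ttl_seconds)
        self._store: Dict[str, Dict[str, Any]] = {}

    def get(self, key: str) -> Optional[Any]:
        # Serve the cached result only while it has not expired
        entry = self._store.get(key)
        if entry and entry['expires_at'] > datetime.now():
            return entry['result']
        return None

    def set(self, key: str, result: Any) -> None:
        # Store the result together with its expiry timestamp
        self._store[key] = {
            'result': result,
            'expires_at': datetime.now() + self.ttl,
        }
```

A short TTL keeps repeated limit checks cheap within one request burst while still picking up renewed subscriptions within seconds.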
def check_comprehensive_limits(
self,
user_id: str,
operations: List[Dict[str, Any]]
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
"""
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
before making the first external API call.
Args:
user_id: User ID
operations: List of operations to validate, each with:
- 'provider': APIProvider enum
- 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
- 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
- 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")
Returns:
(can_proceed, error_message, error_details)
If can_proceed is False, error_message explains which limit would be exceeded
"""
try:
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
# Get current usage and limits once
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")
# Ensure schema columns exist before querying
try:
from services.subscription.schema_utils import ensure_usage_summaries_columns
ensure_usage_summaries_columns(self.db)
except Exception as schema_err:
logger.warning(f"Schema check failed, will retry on query error: {schema_err}")
# Explicitly expire any cached objects and refresh from DB to ensure fresh data
self.db.expire_all()
try:
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
# CRITICAL: Explicitly refresh from database to get latest values (clears SQLAlchemy cache)
if usage:
self.db.refresh(usage)
except Exception as query_err:
error_str = str(query_err).lower()
if 'no such column' in error_str and 'exa_calls' in error_str:
logger.warning("Missing column detected in usage query, fixing schema and retrying...")
import services.subscription.schema_utils as schema_utils
schema_utils._checked_usage_summaries_columns = False
from services.subscription.schema_utils import ensure_usage_summaries_columns
ensure_usage_summaries_columns(self.db)
self.db.expire_all()
# Retry the query
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage)
else:
raise
# Log what we actually read from database
if usage:
logger.info(f"[Pre-flight Check] 📊 Usage Summary from DB (Period: {current_period}):")
logger.info(f" ├─ Gemini: {usage.gemini_tokens or 0} tokens / {usage.gemini_calls or 0} calls")
logger.info(f" ├─ Mistral/HF: {usage.mistral_tokens or 0} tokens / {usage.mistral_calls or 0} calls")
logger.info(f" ├─ Total Tokens: {usage.total_tokens or 0}")
logger.info(f" └─ Usage Status: {usage.usage_status.value if usage.usage_status else 'N/A'}")
else:
logger.info(f"[Pre-flight Check] 📊 No usage summary found for period {current_period} (will create new)")
if not usage:
# First usage this period, create summary
try:
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
except Exception as create_error:
logger.error(f"Error creating usage summary: {create_error}")
self.db.rollback()
return False, f"Failed to create usage summary: {str(create_error)}", {}
# Get user limits
limits_dict = self.pricing_service.get_user_limits(user_id)
if not limits_dict:
# No subscription found - check for free tier
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE,
SubscriptionPlan.is_active == True
).first()
if free_plan:
limits_dict = self.pricing_service._plan_to_limits_dict(free_plan)
else:
return False, "No subscription plan found. Please subscribe to a plan.", {}
limits = limits_dict.get('limits', {})
# Track cumulative usage across all operations
total_llm_calls = (
(usage.gemini_calls or 0) +
(usage.openai_calls or 0) +
(usage.anthropic_calls or 0) +
(usage.mistral_calls or 0)
)
total_llm_tokens = {}
total_images = usage.stability_calls or 0
# Log current usage summary
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
logger.info(f" └─ Image Calls: {total_images}")
# Validate each operation
for op_idx, operation in enumerate(operations):
provider = operation.get('provider')
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
tokens_requested = operation.get('tokens_requested', 0)
actual_provider_name = operation.get('actual_provider_name')
operation_type = operation.get('operation_type', 'unknown')
display_provider_name = actual_provider_name or provider_name
# Log operation details at debug level (only when needed)
logger.debug(f"[Pre-flight] Operation {op_idx + 1}/{len(operations)}: {operation_type} ({display_provider_name}, {tokens_requested} tokens)")
# Check if this is an LLM provider
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
is_llm_provider = provider_name in llm_providers
# Check unified AI text generation limit for LLM providers
if is_llm_provider:
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
if ai_text_gen_limit == 0:
# Fallback to provider-specific limit
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
# Count this operation as an LLM call
projected_total_llm_calls = total_llm_calls + 1
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
error_info = {
'current_calls': total_llm_calls,
'limit': ai_text_gen_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# Check token limits for this provider
# CRITICAL: Always query fresh from DB for each operation to avoid SQLAlchemy cache issues
# This ensures we get the latest values after subscription renewal, even for cumulative tracking
provider_tokens_key = f"{provider_name}_tokens"
# Try to get fresh value from DB with comprehensive error handling
base_current_tokens = 0
query_succeeded = False
try:
# Validate column name is safe (only allow known provider token columns)
valid_token_columns = ['gemini_tokens', 'openai_tokens', 'anthropic_tokens', 'mistral_tokens']
if provider_tokens_key not in valid_token_columns:
logger.error(f" └─ Invalid provider tokens key: {provider_tokens_key}")
query_succeeded = True # Treat as success with 0 value
else:
# Method 1: Try raw SQL query to completely bypass ORM cache
try:
logger.debug(f" └─ Attempting raw SQL query for {provider_tokens_key}")
sql_query = text(f"""
SELECT {provider_tokens_key}
FROM usage_summaries
WHERE user_id = :user_id
AND billing_period = :period
LIMIT 1
""")
logger.debug(f" └─ SQL: SELECT {provider_tokens_key} FROM usage_summaries WHERE user_id={user_id} AND billing_period={current_period}")
result = self.db.execute(sql_query, {
'user_id': user_id,
'period': current_period
}).first()
if result:
base_current_tokens = result[0] if result[0] is not None else 0
else:
base_current_tokens = 0
query_succeeded = True
logger.debug(f"[Pre-flight] Raw SQL query for {provider_tokens_key}: {base_current_tokens}")
except Exception as sql_error:
logger.error(f" └─ Raw SQL query failed for {provider_tokens_key}: {type(sql_error).__name__}: {sql_error}", exc_info=True)
query_succeeded = False # Will try ORM fallback
# Method 2: Fallback to fresh ORM query if raw SQL fails
if not query_succeeded:
try:
# Expire all cached objects and do fresh query
self.db.expire_all()
fresh_usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if fresh_usage:
# Explicitly refresh to get latest from DB
self.db.refresh(fresh_usage)
base_current_tokens = getattr(fresh_usage, provider_tokens_key, 0) or 0
else:
base_current_tokens = 0
query_succeeded = True
logger.info(f"[Pre-flight Check] ✅ ORM fallback query succeeded for {provider_tokens_key}: {base_current_tokens}")
except Exception as orm_error:
logger.error(f" └─ ORM query also failed: {orm_error}", exc_info=True)
query_succeeded = False
except Exception as e:
logger.error(f" └─ Unexpected error getting tokens from DB for {provider_tokens_key}: {e}", exc_info=True)
base_current_tokens = 0 # Fail safe - assume 0 if we can't query
if not query_succeeded:
logger.warning(f" └─ Both query methods failed, using 0 as fallback")
# Log DB query result at debug level (only when needed for troubleshooting)
logger.debug(f"[Pre-flight] DB query for {display_provider_name} ({provider_tokens_key}): {base_current_tokens} (period: {current_period})")
# Add any projected tokens from previous operations in this validation run
# Note: total_llm_tokens tracks ONLY projected tokens from this run, not base DB value
projected_from_previous = total_llm_tokens.get(provider_tokens_key, 0)
# Current tokens = base from DB + projected from previous operations in this run
current_provider_tokens = base_current_tokens + projected_from_previous
# Log token calculation at debug level
logger.debug(f"[Pre-flight] Token calc for {display_provider_name}: base={base_current_tokens}, projected={projected_from_previous}, total={current_provider_tokens}")
token_limit = limits.get(provider_tokens_key, 0) or 0
if token_limit > 0 and tokens_requested > 0:
projected_tokens = current_provider_tokens + tokens_requested
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
if projected_tokens > token_limit:
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
error_info = {
'current_tokens': current_provider_tokens,
'base_tokens_from_db': base_current_tokens,
'projected_from_previous_ops': projected_from_previous,
'requested_tokens': tokens_requested,
'limit': token_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
# Make error message clearer: show actual DB usage vs projected
if projected_from_previous > 0:
error_msg = (
f"Token limit exceeded for {display_provider_name} "
f"({operation_type}). "
f"Base usage: {base_current_tokens}/{token_limit}, "
f"After previous operations in this workflow: {current_provider_tokens}/{token_limit}, "
f"This operation would add: {tokens_requested}, "
f"Total would be: {projected_tokens} (exceeds by {projected_tokens - token_limit} tokens)"
)
else:
error_msg = (
f"Token limit exceeded for {display_provider_name} "
f"({operation_type}). "
f"Current: {current_provider_tokens}/{token_limit}, "
f"Requested: {tokens_requested}, "
f"Would exceed by: {projected_tokens - token_limit} tokens "
f"({usage_percentage:.1f}% of limit)"
)
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
return False, error_msg, {
'error_type': 'token_limit',
'usage_info': error_info
}
else:
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
# Update cumulative counts for next operation
total_llm_calls = projected_total_llm_calls
# Update cumulative projected tokens from this validation run
# This represents only projected tokens from previous operations in this run
# Base DB value is always queried fresh, so we only track the projection delta
if tokens_requested > 0:
# Add this operation's tokens to cumulative projected tokens
total_llm_tokens[provider_tokens_key] = projected_from_previous + tokens_requested
logger.debug(f"[Pre-flight] Updated projected tokens for {display_provider_name}: {projected_from_previous} + {tokens_requested} = {total_llm_tokens[provider_tokens_key]}")
else:
# No tokens requested, keep existing projected tokens (or 0 if first operation)
total_llm_tokens[provider_tokens_key] = projected_from_previous
# Check image generation limits
elif provider == APIProvider.STABILITY:
image_limit = limits.get('stability_calls', 0) or 0
projected_images = total_images + 1
if image_limit > 0 and projected_images > image_limit:
error_info = {
'current_images': total_images,
'limit': image_limit,
'provider': 'stability',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
'error_type': 'image_limit',
'usage_info': error_info
}
total_images = projected_images
# Check video generation limits
elif provider == APIProvider.VIDEO:
video_limit = limits.get('video_calls', 0) or 0
total_video_calls = usage.video_calls or 0
projected_video_calls = total_video_calls + 1
if video_limit > 0 and projected_video_calls > video_limit:
error_info = {
'current_calls': total_video_calls,
'limit': video_limit,
'provider': 'video',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"Video generation limit would be exceeded. Would use {projected_video_calls} of {video_limit} videos this billing period.", {
'error_type': 'video_limit',
'usage_info': error_info
}
# Check image editing limits
elif provider == APIProvider.IMAGE_EDIT:
image_edit_limit = limits.get('image_edit_calls', 0) or 0
total_image_edit_calls = getattr(usage, 'image_edit_calls', 0) or 0
projected_image_edit_calls = total_image_edit_calls + 1
if image_edit_limit > 0 and projected_image_edit_calls > image_edit_limit:
error_info = {
'current_calls': total_image_edit_calls,
'limit': image_edit_limit,
'provider': 'image_edit',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"Image editing limit would be exceeded. Would use {projected_image_edit_calls} of {image_edit_limit} image edits this billing period.", {
'error_type': 'image_edit_limit',
'usage_info': error_info
}
# Check other provider-specific limits
else:
provider_calls_key = f"{provider_name}_calls"
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
call_limit = limits.get(provider_calls_key, 0) or 0
if call_limit > 0:
projected_calls = current_provider_calls + 1
if projected_calls > call_limit:
error_info = {
'current_calls': current_provider_calls,
'limit': call_limit,
'provider': display_provider_name,
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
'error_type': 'call_limit',
'usage_info': error_info
}
# All checks passed
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
return True, None, None
except Exception as e:
error_type = type(e).__name__
error_message = str(e).lower()
# Handle missing column errors with schema fix and retry
if 'operationalerror' in error_type.lower() or 'operationalerror' in error_message:
if 'no such column' in error_message and 'exa_calls' in error_message:
logger.warning("Missing column detected in limit check, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
schema_utils._checked_usage_summaries_columns = False
from services.subscription.schema_utils import ensure_usage_summaries_columns
ensure_usage_summaries_columns(self.db)
self.db.expire_all()
# Retry the query
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage)
# The retried query succeeded, but re-running the full validation here would
# duplicate the entire function body, so ask the caller to retry instead.
logger.info("[Pre-flight Check] Schema fixed; validation must be retried on the next call")
return False, "Database schema was updated. Please retry the request.", {'error_type': 'schema_update', 'retry': True}
except Exception as retry_err:
logger.error(f"Schema fix and retry failed: {retry_err}")
return False, f"Failed to validate limits: {error_type}: {str(e)}", {}
logger.error(f"[Pre-flight Check] ❌ Error during comprehensive limit check: {error_type}: {str(e)}", exc_info=True)
logger.error(f"[Pre-flight Check] ❌ User: {user_id}, Operations count: {len(operations) if operations else 0}")
return False, f"Failed to validate limits: {error_type}: {str(e)}", {}
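The core of `check_comprehensive_limits` is a cumulative projection: each operation adds its requested tokens to a running per-provider total, and validation fails as soon as base usage plus projections would exceed a limit, before any external call is made. A standalone sketch of that logic (function and field names here are illustrative, not the service's API):

```python
from typing import Dict, List, Optional, Tuple


def preflight_token_check(
    base_usage: Dict[str, int],          # tokens already recorded in the DB, per provider
    limits: Dict[str, int],              # per-provider token limits (0 means unlimited)
    operations: List[Tuple[str, int]],   # (provider_key, tokens_requested) in execution order
) -> Tuple[bool, Optional[str]]:
    """Validate ALL operations before making any external API call."""
    projected: Dict[str, int] = {}       # tokens added by earlier operations in this run
    for provider, tokens in operations:
        limit = limits.get(provider, 0)
        current = base_usage.get(provider, 0) + projected.get(provider, 0)
        if limit > 0 and tokens > 0 and current + tokens > limit:
            return False, (
                f"Token limit exceeded for {provider}: "
                f"{current} + {tokens} > {limit}"
            )
        # Record this operation's tokens so later operations see them
        projected[provider] = projected.get(provider, 0) + tokens
    return True, None
```

Note how the second operation in a workflow can fail even when the first alone would pass; rejecting the whole batch up front is what prevents a wasted first API call.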
"""
Log Wrapping Service
Intelligently wraps API usage logs when they exceed 5000 records.
Aggregates old logs into cumulative records while preserving historical data.
"""
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from loguru import logger
from models.subscription_models import APIUsageLog, APIProvider
class LogWrappingService:
"""Service for wrapping and aggregating API usage logs."""
MAX_LOGS_PER_USER = 5000
AGGREGATION_THRESHOLD_DAYS = 30 # Aggregate logs older than 30 days
def __init__(self, db: Session):
self.db = db
def check_and_wrap_logs(self, user_id: str) -> Dict[str, Any]:
"""
Check if user has exceeded log limit and wrap if necessary.
Returns:
Dict with wrapping status and statistics
"""
try:
# Count total logs for user
total_count = self.db.query(func.count(APIUsageLog.id)).filter(
APIUsageLog.user_id == user_id
).scalar() or 0
if total_count <= self.MAX_LOGS_PER_USER:
return {
'wrapped': False,
'total_logs': total_count,
'max_logs': self.MAX_LOGS_PER_USER,
'message': f'Log count ({total_count}) is within limit ({self.MAX_LOGS_PER_USER})'
}
# Need to wrap logs - aggregate old logs
logger.info(f"[LogWrapping] User {user_id} has {total_count} logs, exceeding limit of {self.MAX_LOGS_PER_USER}. Starting wrap...")
wrap_result = self._wrap_old_logs(user_id, total_count)
return {
'wrapped': True,
'total_logs_before': total_count,
'total_logs_after': wrap_result['logs_remaining'],
'aggregated_logs': wrap_result['aggregated_count'],
'aggregated_periods': wrap_result['periods'],
'message': f'Wrapped {wrap_result["aggregated_count"]} logs into {len(wrap_result["periods"])} aggregated records'
}
except Exception as e:
logger.error(f"[LogWrapping] Error checking/wrapping logs for user {user_id}: {e}", exc_info=True)
return {
'wrapped': False,
'error': str(e),
'message': f'Error wrapping logs: {str(e)}'
}
def _wrap_old_logs(self, user_id: str, total_count: int) -> Dict[str, Any]:
"""
Aggregate old logs into cumulative records.
Strategy:
1. Keep the most recent 4000 logs (detailed)
2. Aggregate the oldest logs beyond that count
3. Create aggregated records grouped by provider and billing period
4. Delete the individual logs that were aggregated
"""
try:
# Calculate how many logs to keep (4000 detailed, rest aggregated)
logs_to_keep = 4000
logs_to_aggregate = total_count - logs_to_keep
# Aggregation is count-based: the oldest (total - logs_to_keep) logs are folded up.
# Get logs to aggregate: oldest logs beyond the keep limit
# Order by timestamp ascending to get oldest first
# We'll keep the most recent logs_to_keep logs, aggregate the rest
logs_to_process = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id
).order_by(APIUsageLog.timestamp.asc()).limit(logs_to_aggregate).all()
if not logs_to_process:
return {
'aggregated_count': 0,
'logs_remaining': total_count,
'periods': []
}
# Group logs by provider and billing period for aggregation
aggregated_data: Dict[str, Dict[str, Any]] = {}
for log in logs_to_process:
# Use provider value as key (e.g., "mistral" for huggingface)
provider_key = log.provider.value
# Special handling: if provider is MISTRAL but we want to show as huggingface
if provider_key == "mistral":
# Check if this is actually huggingface by looking at model or endpoint
# For now, we'll use "mistral" as the key but store actual provider name
provider_display = "huggingface" if "huggingface" in (log.model_used or "").lower() else "mistral"
else:
provider_display = provider_key
period_key = f"{provider_display}_{log.billing_period}"
if period_key not in aggregated_data:
aggregated_data[period_key] = {
'provider': log.provider,
'billing_period': log.billing_period,
'count': 0,
'total_tokens_input': 0,
'total_tokens_output': 0,
'total_tokens': 0,
'total_cost_input': 0.0,
'total_cost_output': 0.0,
'total_cost': 0.0,
'total_response_time': 0.0,
'success_count': 0,
'failed_count': 0,
'oldest_timestamp': log.timestamp,
'newest_timestamp': log.timestamp,
'log_ids': []
}
agg = aggregated_data[period_key]
agg['count'] += 1
agg['total_tokens_input'] += log.tokens_input or 0
agg['total_tokens_output'] += log.tokens_output or 0
agg['total_tokens'] += log.tokens_total or 0
agg['total_cost_input'] += float(log.cost_input or 0.0)
agg['total_cost_output'] += float(log.cost_output or 0.0)
agg['total_cost'] += float(log.cost_total or 0.0)
agg['total_response_time'] += float(log.response_time or 0.0)
if 200 <= log.status_code < 300:
agg['success_count'] += 1
else:
agg['failed_count'] += 1
if log.timestamp:
if log.timestamp < agg['oldest_timestamp']:
agg['oldest_timestamp'] = log.timestamp
if log.timestamp > agg['newest_timestamp']:
agg['newest_timestamp'] = log.timestamp
agg['log_ids'].append(log.id)
# Create aggregated log entries
aggregated_count = 0
periods_created = []
for period_key, agg_data in aggregated_data.items():
# Calculate averages
count = agg_data['count']
avg_response_time = agg_data['total_response_time'] / count if count > 0 else 0.0
# Create aggregated log entry
aggregated_log = APIUsageLog(
user_id=user_id,
provider=agg_data['provider'],
endpoint='[AGGREGATED]',
method='AGGREGATED',
model_used=f"[{count} calls aggregated]",
tokens_input=agg_data['total_tokens_input'],
tokens_output=agg_data['total_tokens_output'],
tokens_total=agg_data['total_tokens'],
cost_input=agg_data['total_cost_input'],
cost_output=agg_data['total_cost_output'],
cost_total=agg_data['total_cost'],
response_time=avg_response_time,
status_code=200 if agg_data['success_count'] > agg_data['failed_count'] else 500,
error_message=f"Aggregated {count} calls: {agg_data['success_count']} success, {agg_data['failed_count']} failed",
retry_count=0,
timestamp=agg_data['oldest_timestamp'], # Use oldest timestamp
billing_period=agg_data['billing_period']
)
self.db.add(aggregated_log)
periods_created.append({
'provider': agg_data['provider'].value,
'billing_period': agg_data['billing_period'],
'count': count,
'period_start': agg_data['oldest_timestamp'].isoformat() if agg_data['oldest_timestamp'] else None,
'period_end': agg_data['newest_timestamp'].isoformat() if agg_data['newest_timestamp'] else None
})
aggregated_count += count
# Delete individual logs that were aggregated
log_ids_to_delete = []
for agg_data in aggregated_data.values():
log_ids_to_delete.extend(agg_data['log_ids'])
if log_ids_to_delete:
self.db.query(APIUsageLog).filter(
APIUsageLog.id.in_(log_ids_to_delete)
).delete(synchronize_session=False)
self.db.commit()
# Get remaining log count
remaining_count = self.db.query(func.count(APIUsageLog.id)).filter(
APIUsageLog.user_id == user_id
).scalar() or 0
logger.info(
f"[LogWrapping] Wrapped {aggregated_count} logs into {len(periods_created)} aggregated records. "
f"Remaining logs: {remaining_count}"
)
return {
'aggregated_count': aggregated_count,
'logs_remaining': remaining_count,
'periods': periods_created
}
except Exception as e:
self.db.rollback()
logger.error(f"[LogWrapping] Error wrapping logs: {e}", exc_info=True)
raise
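The grouping step in `_wrap_old_logs` reduces to a keyed fold over (provider, billing_period): each old log increments one cumulative record. A standalone sketch of that fold over plain dicts (field names mirror the log model but are simplified):

```python
from typing import Any, Dict, List


def aggregate_logs(logs: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Fold individual usage logs into one cumulative record per provider/period."""
    grouped: Dict[str, Dict[str, Any]] = {}
    for log in logs:
        key = f"{log['provider']}_{log['billing_period']}"
        agg = grouped.setdefault(key, {
            'count': 0, 'total_tokens': 0, 'total_cost': 0.0,
            'success_count': 0, 'failed_count': 0,
        })
        agg['count'] += 1
        agg['total_tokens'] += log.get('tokens_total') or 0
        agg['total_cost'] += float(log.get('cost_total') or 0.0)
        # 2xx responses count as successes, everything else as failures
        if 200 <= (log.get('status_code') or 0) < 300:
            agg['success_count'] += 1
        else:
            agg['failed_count'] += 1
    return grouped
```

The real service additionally tracks oldest/newest timestamps per group and records the source log IDs so the originals can be deleted in one bulk statement.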

"""
Enhanced FastAPI Monitoring Middleware
Database-backed monitoring for API calls, errors, performance metrics, and usage tracking.
Includes comprehensive subscription-based usage monitoring and cost tracking.
"""
from fastapi import Request, Response
from fastapi.responses import JSONResponse
import time
import json
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from collections import defaultdict, deque
import asyncio
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy import and_, func
import re
from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
from models.subscription_models import APIProvider
from .usage_tracking_service import UsageTrackingService
from .pricing_service import PricingService
def _get_db_session():
"""
Get a database session with lazy import to survive hot reloads.
Uvicorn's reloader can sometimes clear module-level imports.
"""
from services.database import get_db
return next(get_db())
class DatabaseAPIMonitor:
"""Database-backed API monitoring with usage tracking and subscription management."""
def __init__(self):
self.cache_stats = {
'hits': 0,
'misses': 0,
'hit_rate': 0.0
}
# API provider detection patterns - Updated to match actual endpoints
self.provider_patterns = {
APIProvider.GEMINI: [
r'gemini', r'google.*ai'
],
APIProvider.OPENAI: [r'openai', r'gpt', r'chatgpt'],
APIProvider.ANTHROPIC: [r'anthropic', r'claude'],
APIProvider.MISTRAL: [r'mistral'],
APIProvider.TAVILY: [r'tavily'],
APIProvider.SERPER: [r'serper'],
APIProvider.METAPHOR: [r'metaphor', r'/exa'],
APIProvider.FIRECRAWL: [r'firecrawl']
}
def detect_api_provider(self, path: str, user_agent: str = None) -> Optional[APIProvider]:
"""Detect which API provider is being used based on request details."""
path_lower = path.lower()
user_agent_lower = (user_agent or '').lower()
# Permanently ignore internal route families that must not accrue or check provider usage
if path_lower.startswith('/api/onboarding/') or path_lower.startswith('/api/subscription/'):
return None
for provider, patterns in self.provider_patterns.items():
for pattern in patterns:
if re.search(pattern, path_lower) or re.search(pattern, user_agent_lower):
return provider
return None
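`detect_api_provider` is a first-match regex scan over the lowercased path, with internal route families excluded up front. A minimal sketch of that scan (pattern table abbreviated; keys are plain strings rather than the `APIProvider` enum):

```python
import re
from typing import Optional

# Abbreviated pattern table; the real one covers all providers
PROVIDER_PATTERNS = {
    'gemini': [r'gemini', r'google.*ai'],
    'mistral': [r'mistral'],
    'tavily': [r'tavily'],
}

# Internal routes that must never accrue or check provider usage
IGNORED_PREFIXES = ('/api/onboarding/', '/api/subscription/')


def detect_provider(path: str) -> Optional[str]:
    path = path.lower()
    if path.startswith(IGNORED_PREFIXES):
        return None
    for provider, patterns in PROVIDER_PATTERNS.items():
        if any(re.search(p, path) for p in patterns):
            return provider
    return None
```

Because the scan is first-match, pattern ordering matters when patterns overlap; the exclusion check runs first so onboarding and subscription endpoints short-circuit before any pattern is tried.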
def extract_usage_metrics(self, request_body: str = None, response_body: str = None) -> Dict[str, Any]:
"""Extract usage metrics from request/response bodies."""
metrics = {
'tokens_input': 0,
'tokens_output': 0,
'model_used': None,
'search_count': 0,
'image_count': 0,
'page_count': 0
}
try:
# Try to parse request body for input tokens/content
if request_body:
request_data = json.loads(request_body) if isinstance(request_body, str) else request_body
# Extract model information
if 'model' in request_data:
metrics['model_used'] = request_data['model']
# Estimate input tokens from prompt/content
if 'prompt' in request_data:
metrics['tokens_input'] = self._estimate_tokens(request_data['prompt'])
elif 'messages' in request_data:
total_content = ' '.join([msg.get('content', '') for msg in request_data['messages']])
metrics['tokens_input'] = self._estimate_tokens(total_content)
elif 'input' in request_data:
metrics['tokens_input'] = self._estimate_tokens(str(request_data['input']))
# Count specific request types
if 'query' in request_data or 'search' in request_data:
metrics['search_count'] = 1
if 'image' in request_data or 'generate_image' in request_data:
metrics['image_count'] = 1
if 'url' in request_data or 'crawl' in request_data:
metrics['page_count'] = 1
# Try to parse response body for output tokens
if response_body:
response_data = json.loads(response_body) if isinstance(response_body, str) else response_body
# Extract output content and estimate tokens
if 'text' in response_data:
metrics['tokens_output'] = self._estimate_tokens(response_data['text'])
elif 'content' in response_data:
metrics['tokens_output'] = self._estimate_tokens(str(response_data['content']))
elif 'choices' in response_data and response_data['choices']:
choice = response_data['choices'][0]
if 'message' in choice and 'content' in choice['message']:
metrics['tokens_output'] = self._estimate_tokens(choice['message']['content'])
# Extract actual token usage if provided by API
if 'usage' in response_data:
usage = response_data['usage']
if 'prompt_tokens' in usage:
metrics['tokens_input'] = usage['prompt_tokens']
if 'completion_tokens' in usage:
metrics['tokens_output'] = usage['completion_tokens']
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.debug(f"Could not extract usage metrics: {e}")
return metrics
def _estimate_tokens(self, text: str) -> int:
"""Estimate token count for text (rough approximation)."""
if not text:
return 0
# Rough estimation: 1.3 tokens per word on average
word_count = len(str(text).split())
return int(word_count * 1.3)
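`extract_usage_metrics` prefers the provider-reported `usage` block over the word-count estimate when a response includes one. A standalone sketch of that precedence for output tokens (simplified to two fields; the real method also handles `choices`-style responses):

```python
import json
from typing import Any, Dict


def extract_output_tokens(response_body: str) -> int:
    """Prefer the provider-reported token count; fall back to ~1.3 tokens/word."""
    data: Dict[str, Any] = json.loads(response_body)
    usage = data.get('usage') or {}
    if 'completion_tokens' in usage:
        return usage['completion_tokens']        # authoritative count from the API
    text = data.get('text') or str(data.get('content') or '')
    return int(len(text.split()) * 1.3)          # rough word-based estimate
```

The estimate is deliberately crude (whitespace tokenization ignores subword splitting), which is acceptable here because any response carrying a real `usage` block overrides it.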
async def check_usage_limits_middleware(request: Request, user_id: str, request_body: str = None) -> Optional[JSONResponse]:
"""Check usage limits before processing request."""
if not user_id:
return None
# No special whitelist; onboarding/subscription are ignored by provider detection
db = None
try:
db = _get_db_session()
api_monitor = DatabaseAPIMonitor()
# Detect if this is an API call that should be rate limited
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
if not api_provider:
return None
# Use provided request body or read it if not provided
if request_body is None:
try:
if hasattr(request, '_body'):
request_body = request._body
else:
# Try to read body (this might not work in all cases)
body = await request.body()
request_body = body.decode('utf-8') if body else None
except Exception:
pass
# Estimate tokens needed
tokens_requested = 0
if request_body:
usage_metrics = api_monitor.extract_usage_metrics(request_body)
tokens_requested = usage_metrics.get('tokens_input', 0)
# Check limits
usage_service = UsageTrackingService(db)
can_proceed, message, usage_info = await usage_service.enforce_usage_limits(
user_id=user_id,
provider=api_provider,
tokens_requested=tokens_requested
)
if not can_proceed:
logger.warning(f"Usage limit exceeded for {user_id}: {message}")
return JSONResponse(
status_code=429,
content={
"error": "Usage limit exceeded",
"message": message,
"usage_info": usage_info,
"provider": api_provider.value
}
)
# Warn if approaching limits
if usage_info and (usage_info.get('call_usage_percentage', 0) >= 80 or usage_info.get('cost_usage_percentage', 0) >= 80):
logger.warning(f"User {user_id} approaching usage limits: {usage_info}")
return None
except Exception as e:
logger.error(f"Error checking usage limits: {e}")
# Don't block requests if usage checking fails
return None
finally:
if db is not None:
db.close()
async def monitoring_middleware(request: Request, call_next):
"""Enhanced FastAPI middleware for monitoring API calls with usage tracking."""
start_time = time.time()
# Get database session
db = _get_db_session()
# Extract request details - Enhanced user identification
user_id = None
try:
# PRIORITY 1: Check request.state.user_id (set by API key injection middleware)
if hasattr(request.state, 'user_id') and request.state.user_id:
user_id = request.state.user_id
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
# PRIORITY 2: Check query parameters
elif hasattr(request, 'query_params') and 'user_id' in request.query_params:
user_id = request.query_params['user_id']
elif hasattr(request, 'path_params') and 'user_id' in request.path_params:
user_id = request.path_params['user_id']
# PRIORITY 3: Check headers for user identification
elif 'x-user-id' in request.headers:
user_id = request.headers['x-user-id']
elif 'x-user-email' in request.headers:
user_id = request.headers['x-user-email'] # Use email as user identifier
elif 'x-session-id' in request.headers:
user_id = request.headers['x-session-id'] # Use session as fallback
# Check for authorization header with user info
elif 'authorization' in request.headers:
# Auth middleware should have set request.state.user_id
# If not, this indicates an authentication failure (likely expired token)
# Log at debug level to reduce noise - expired tokens are expected
user_id = None
logger.debug("Monitoring: Auth header present but no user_id in state - token likely expired")
# Final fallback: None (skip usage limits for truly anonymous/unauthenticated)
else:
user_id = None
except Exception as e:
logger.debug(f"Error extracting user ID: {e}")
user_id = None # On error, skip usage limits
# Capture request body for usage tracking (read once, safely)
request_body = None
try:
# Only read body for POST/PUT/PATCH requests to avoid issues
if request.method in ['POST', 'PUT', 'PATCH']:
if hasattr(request, '_body') and request._body:
request_body = request._body.decode('utf-8')
else:
# Read body only if it hasn't been read yet
try:
body = await request.body()
request_body = body.decode('utf-8') if body else None
except Exception as body_error:
logger.debug(f"Could not read request body: {body_error}")
request_body = None
except Exception as e:
logger.debug(f"Error capturing request body: {e}")
request_body = None
# Check usage limits before processing
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
if limit_response:
return limit_response
try:
response = await call_next(request)
status_code = response.status_code
duration = time.time() - start_time
# Capture response body for usage tracking
response_body = None
try:
if hasattr(response, 'body'):
response_body = response.body.decode('utf-8') if response.body else None
elif hasattr(response, '_content'):
response_body = response._content.decode('utf-8') if response._content else None
except Exception:
# Response body is not accessible on streaming responses; skip it
pass
# Track API usage if this is an API call to external providers
api_monitor = DatabaseAPIMonitor()
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
if api_provider and user_id:
logger.info(f"Detected API call: {request.url.path} -> {api_provider.value} for user: {user_id}")
try:
# Extract usage metrics
usage_metrics = api_monitor.extract_usage_metrics(request_body, response_body)
# Track usage with the usage tracking service
usage_service = UsageTrackingService(db)
await usage_service.track_api_usage(
user_id=user_id,
provider=api_provider,
endpoint=request.url.path,
method=request.method,
model_used=usage_metrics.get('model_used'),
tokens_input=usage_metrics.get('tokens_input', 0),
tokens_output=usage_metrics.get('tokens_output', 0),
response_time=duration,
status_code=status_code,
request_size=len(request_body) if request_body else None,
response_size=len(response_body) if response_body else None,
user_agent=request.headers.get('user-agent'),
ip_address=request.client.host if request.client else None,
search_count=usage_metrics.get('search_count', 0),
image_count=usage_metrics.get('image_count', 0),
page_count=usage_metrics.get('page_count', 0)
)
except Exception as usage_error:
logger.error(f"Error tracking API usage: {usage_error}")
# Don't fail the main request if usage tracking fails
return response
except Exception as e:
duration = time.time() - start_time
status_code = 500
# Store minimal error info
logger.error(f"API Error: {request.method} {request.url.path} - {str(e)}")
return JSONResponse(
status_code=500,
content={"error": "Internal server error"}
)
finally:
db.close()
async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
"""Get current monitoring statistics."""
db = None
try:
db = _get_db_session()
# Placeholder to match old API; heavy stats handled elsewhere
return {
'timestamp': datetime.utcnow().isoformat(),
'overview': {
'recent_requests': 0,
'recent_errors': 0,
},
'cache_performance': {'hits': 0, 'misses': 0, 'hit_rate': 0.0},
'recent_errors': [],
'system_health': {'status': 'healthy', 'error_rate': 0.0}
}
finally:
if db is not None:
db.close()
async def get_lightweight_stats() -> Dict[str, Any]:
"""Get lightweight stats for dashboard header."""
db = None
try:
db = _get_db_session()
# Minimal viable placeholder values
now = datetime.utcnow()
return {
'status': 'healthy',
'icon': '🟢',
'recent_requests': 0,
'recent_errors': 0,
'error_rate': 0.0,
'timestamp': now.isoformat()
}
finally:
if db is not None:
db.close()
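The user-identification priority chain in `monitoring_middleware` can be sketched as a pure function over plain dicts, which makes the precedence order easy to verify without FastAPI. The function name is illustrative:

```python
# Sketch of the user-ID resolution priorities used by monitoring_middleware.

def resolve_user_id(state: dict, query: dict, headers: dict):
    """Return the first available user identifier, or None for anonymous."""
    if state.get("user_id"):                  # 1: set by API key injection / auth middleware
        return state["user_id"]
    if "user_id" in query:                    # 2: explicit query or path parameter
        return query["user_id"]
    for header in ("x-user-id", "x-user-email", "x-session-id"):
        if header in headers:                 # 3: header-based fallbacks, in order
            return headers[header]
    return None                               # anonymous: usage limits are skipped
```

Because `request.state` wins over every fallback, identifiers injected upstream (for example by the API key middleware) always take precedence over client-supplied headers.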


@@ -0,0 +1,853 @@
"""
Pre-flight Validation Utility for Multi-Operation Workflows
Provides transparent validation for operations that involve multiple API calls.
Services can use this to validate entire workflows before making any external API calls.
"""
from typing import Dict, Any, List, Optional, Tuple
from fastapi import HTTPException
from loguru import logger
from services.subscription.pricing_service import PricingService
from models.subscription_models import APIProvider
def validate_research_operations(
pricing_service: PricingService,
user_id: str,
gpt_provider: str = "google"
) -> None:
"""
Validate all operations for a research workflow before making ANY API calls.
This prevents wasteful external API calls (like Google Grounding) if subsequent
LLM calls would fail due to token or call limits.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
gpt_provider: GPT provider from env var (defaults to "google")
Returns:
None - raises HTTPException with 429 status if validation fails
"""
try:
# Determine actual provider for LLM calls based on GPT_PROVIDER env var
gpt_provider_lower = gpt_provider.lower()
if gpt_provider_lower == "huggingface":
llm_provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
llm_provider_name = "huggingface"
else:
llm_provider_enum = APIProvider.GEMINI
llm_provider_name = "gemini"
# Estimate tokens for each operation in research workflow
# Google Grounding call: ~1200 tokens (input: ~500 tokens, output: ~700 tokens for research results)
# Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
# Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
# Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles)
# Note: These are conservative estimates. Actual usage may be lower, but we use these for pre-flight validation
# to prevent wasteful API calls if the workflow would exceed limits.
operations_to_validate = [
{
'provider': APIProvider.GEMINI, # Google Grounding uses Gemini
'tokens_requested': 1200, # Reduced from 2000 to more realistic estimate
'actual_provider_name': 'gemini',
'operation_type': 'google_grounding'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'keyword_analysis'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'competitor_analysis'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'content_angle_generation'
}
]
logger.info(f"[Pre-flight Validator] 🚀 Starting Research Workflow Validation")
logger.info(f" ├─ User: {user_id}")
logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
operation_type = usage_info.get('operation_type', 'unknown')
logger.warning(f"[Pre-flight] Research blocked for user {user_id}: {operation_type} ({provider}) - {message}")
# Raise HTTPException immediately - frontend gets immediate response, no API calls made
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ RESEARCH WORKFLOW APPROVED")
logger.info(f" ├─ User: {user_id}")
logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating research operations: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate operations: {str(e)}",
'message': f"Failed to validate operations: {str(e)}"
}
)
def validate_exa_research_operations(
pricing_service: PricingService,
user_id: str,
gpt_provider: str = "google"
) -> None:
"""
Validate all operations for an Exa research workflow before making ANY API calls.
This prevents wasteful external API calls (like Exa search) if subsequent
LLM calls would fail due to token or call limits.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
gpt_provider: GPT provider from env var (defaults to "google")
Returns:
None
If validation fails, raises HTTPException with 429 status
"""
try:
# Determine actual provider for LLM calls based on GPT_PROVIDER env var
gpt_provider_lower = gpt_provider.lower()
if gpt_provider_lower == "huggingface":
llm_provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
llm_provider_name = "huggingface"
else:
llm_provider_enum = APIProvider.GEMINI
llm_provider_name = "gemini"
# Estimate tokens for each operation in Exa research workflow
# Exa Search call: 1 Exa API call (not token-based)
# Keyword analyzer: ~1000 tokens (input: research results, output: structured JSON)
# Competitor analyzer: ~1000 tokens (input: research results, output: structured JSON)
# Content angle generator: ~1000 tokens (input: research results, output: list of angles)
# Note: These are conservative estimates for pre-flight validation
operations_to_validate = [
{
'provider': APIProvider.EXA, # Exa API call
'tokens_requested': 0,
'actual_provider_name': 'exa',
'operation_type': 'exa_neural_search'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'keyword_analysis'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'competitor_analysis'
},
{
'provider': llm_provider_enum,
'tokens_requested': 1000,
'actual_provider_name': llm_provider_name,
'operation_type': 'content_angle_generation'
}
]
logger.info(f"[Pre-flight Validator] 🚀 Starting Exa Research Workflow Validation")
logger.info(f" ├─ User: {user_id}")
logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
operation_type = usage_info.get('operation_type', 'unknown')
logger.error(f"[Pre-flight Validator] ❌ EXA RESEARCH WORKFLOW BLOCKED")
logger.error(f" ├─ User: {user_id}")
logger.error(f" ├─ Blocked at: {operation_type}")
logger.error(f" ├─ Provider: {provider}")
logger.error(f" └─ Reason: {message}")
# Raise HTTPException immediately - frontend gets immediate response, no API calls made
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ EXA RESEARCH WORKFLOW APPROVED")
logger.info(f" ├─ User: {user_id}")
logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating Exa research operations: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate operations: {str(e)}",
'message': f"Failed to validate operations: {str(e)}"
}
)
def validate_image_generation_operations(
pricing_service: PricingService,
user_id: str,
num_images: int = 1
) -> None:
"""
Validate image generation operation(s) before making API calls.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
num_images: Number of images to generate (for multiple variations)
Returns:
None
If validation fails, raises HTTPException with 429 status
"""
try:
# Create validation operations for each image
operations_to_validate = [
{
'provider': APIProvider.STABILITY,
'tokens_requested': 0,
'actual_provider_name': 'stability',
'operation_type': 'image_generation'
}
for _ in range(num_images)
]
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image generation(s) for user {user_id}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Image generation blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Image generation validated for user {user_id}")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating image generation: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate image generation: {str(e)}",
'message': f"Failed to validate image generation: {str(e)}"
}
)
def validate_image_upscale_operations(
pricing_service: PricingService,
user_id: str,
num_images: int = 1
) -> None:
"""
Validate image upscaling before making API calls.
"""
try:
operations_to_validate = [
{
'provider': APIProvider.STABILITY,
'tokens_requested': 0,
'actual_provider_name': 'stability',
'operation_type': 'image_upscale'
}
for _ in range(num_images)
]
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image upscale request(s) for user {user_id}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Image upscale blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Image upscale validated for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating image upscale: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate image upscale: {str(e)}",
'message': f"Failed to validate image upscale: {str(e)}"
}
)
def validate_image_editing_operations(
pricing_service: PricingService,
user_id: str
) -> None:
"""
Validate image editing operation before making API calls.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
Returns:
None - raises HTTPException with 429 status if validation fails
"""
try:
operations_to_validate = [
{
'provider': APIProvider.IMAGE_EDIT,
'tokens_requested': 0,
'actual_provider_name': 'image_edit',
'operation_type': 'image_editing'
}
]
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Image editing blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'image_edit') if usage_info else 'image_edit'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Image editing validated for user {user_id}")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating image editing: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate image editing: {str(e)}",
'message': f"Failed to validate image editing: {str(e)}"
}
)
def validate_image_control_operations(
pricing_service: PricingService,
user_id: str,
num_images: int = 1
) -> None:
"""
Validate image control operations (sketch-to-image, structure control, style transfer) before making API calls.
Control operations use Stability AI for image generation with control inputs, so they use
the same validation as image generation operations.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
num_images: Number of images to generate (for multiple variations)
Returns:
None - raises HTTPException with 429 status if validation fails
"""
try:
# Control operations use Stability AI, same as image generation
operations_to_validate = [
{
'provider': APIProvider.STABILITY,
'tokens_requested': 0,
'actual_provider_name': 'stability',
'operation_type': 'image_generation' # Control ops use image generation limits
}
for _ in range(num_images)
]
logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image control operation(s) for user {user_id}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Image control blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'stability') if usage_info else 'stability'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Image control validated for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating image control: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate image control: {str(e)}",
'message': f"Failed to validate image control: {str(e)}"
}
)
def validate_video_generation_operations(
pricing_service: PricingService,
user_id: str
) -> None:
"""
Validate video generation operation before making API calls.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
Returns:
None - raises HTTPException with 429 status if validation fails
"""
try:
operations_to_validate = [
{
'provider': APIProvider.VIDEO,
'tokens_requested': 0,
'actual_provider_name': 'video',
'operation_type': 'video_generation'
}
]
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Video generation blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'video') if usage_info else 'video'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Video generation validated for user {user_id}")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating video generation: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate video generation: {str(e)}",
'message': f"Failed to validate video generation: {str(e)}"
}
)
def validate_scene_animation_operation(
pricing_service: PricingService,
user_id: str,
) -> None:
"""
Validate the per-scene animation workflow before API calls.
"""
try:
operations_to_validate = [
{
'provider': APIProvider.VIDEO,
'tokens_requested': 0,
'actual_provider_name': 'wavespeed',
'operation_type': 'scene_animation',
}
]
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate,
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Scene animation blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'video') if usage_info else 'video'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details,
}
)
logger.info(f"[Pre-flight Validator] ✅ Scene animation validated for user {user_id}")
# Validation passed - no return needed (function raises HTTPException if validation fails)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating scene animation: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate scene animation: {str(e)}",
'message': f"Failed to validate scene animation: {str(e)}"
}
)
def validate_calendar_generation_operations(
pricing_service: PricingService,
user_id: str,
gpt_provider: str = "google"
) -> None:
"""
Validate calendar generation operations before making API calls.
Args:
pricing_service: PricingService instance
user_id: User ID for subscription checking
gpt_provider: GPT provider from env var (defaults to "google")
Returns:
None - raises HTTPException with 429 status if validation fails
"""
try:
# Determine actual provider for LLM calls based on GPT_PROVIDER env var
gpt_provider_lower = gpt_provider.lower()
if gpt_provider_lower == "huggingface":
llm_provider_enum = APIProvider.MISTRAL
llm_provider_name = "huggingface"
else:
llm_provider_enum = APIProvider.GEMINI
llm_provider_name = "gemini"
# Estimate tokens for 12-step process
# This is a heavy operation involving multiple steps and analysis
operations_to_validate = [
{
'provider': llm_provider_enum,
'tokens_requested': 20000, # Conservative estimate for full calendar generation
'actual_provider_name': llm_provider_name,
'operation_type': 'calendar_generation'
}
]
logger.info(f"[Pre-flight Validator] 🚀 Validating Calendar Generation for user {user_id}")
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate
)
if not can_proceed:
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
logger.warning(f"[Pre-flight Validator] Calendar generation blocked for user {user_id}: {message}")
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details
}
)
logger.info(f"[Pre-flight Validator] ✅ Calendar Generation validated for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating calendar generation: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate calendar generation: {str(e)}",
'message': f"Failed to validate calendar generation: {str(e)}"
}
)
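The pre-flight pattern shared by every validator above can be shown end to end with stubs. `check_comprehensive_limits` and `LimitExceeded` here are hypothetical stand-ins for `PricingService.check_comprehensive_limits` and `fastapi.HTTPException(status_code=429)`:

```python
# Toy end-to-end of pre-flight validation: check every operation in a
# workflow before any external API call is made.

class LimitExceeded(Exception):
    """Stand-in for an HTTP 429 response."""

def check_comprehensive_limits(tokens_available: int, operations: list):
    """Stub: approve only if the whole workflow fits the remaining budget."""
    total = sum(op["tokens_requested"] for op in operations)
    if total > tokens_available:
        return False, f"Workflow needs {total} tokens, {tokens_available} available", {}
    return True, "OK", {}

def validate_workflow(tokens_available: int, operations: list) -> None:
    """Raise before any API call rather than fail mid-workflow."""
    can_proceed, message, _ = check_comprehensive_limits(tokens_available, operations)
    if not can_proceed:
        raise LimitExceeded(message)
```

The key property, mirrored by the real validators, is that a workflow that would fail at its third LLM call is rejected up front, so the expensive first call (Google Grounding, Exa search) is never made.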


@@ -0,0 +1,815 @@
"""
Pricing Service for API Usage Tracking
Manages API pricing, cost calculation, and subscription limits.
"""
from typing import Dict, Any, Optional, List, Tuple, Union
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import text
from loguru import logger
import os
from models.subscription_models import (
APIProviderPricing, SubscriptionPlan, UserSubscription,
UsageSummary, APIUsageLog, APIProvider, SubscriptionTier
)
class PricingService:
"""Service for managing API pricing and cost calculations."""
# Class-level cache shared across all instances (critical for cache invalidation on subscription renewal)
# key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime }
_limits_cache: Dict[str, Dict[str, Any]] = {}
def __init__(self, db: Session):
self.db = db
self._pricing_cache = {}
self._plans_cache = {}
# Cache for schema feature detection (ai_text_generation_calls_limit column)
self._ai_text_gen_col_checked: bool = False
self._ai_text_gen_col_available: bool = False
# ------------------- Billing period helpers -------------------
def _compute_next_period_end(self, start: datetime, cycle: str) -> datetime:
"""Compute the next period end given a start and billing cycle."""
try:
cycle_value = cycle.value if hasattr(cycle, 'value') else str(cycle)
except Exception:
cycle_value = str(cycle)
if cycle_value == 'yearly':
return start + timedelta(days=365)
return start + timedelta(days=30)
def _ensure_subscription_current(self, subscription) -> bool:
"""Auto-advance subscription period if expired and auto_renew is enabled."""
if not subscription:
return False
now = datetime.utcnow()
try:
if subscription.current_period_end and subscription.current_period_end < now:
if getattr(subscription, 'auto_renew', False):
subscription.current_period_start = now
subscription.current_period_end = self._compute_next_period_end(now, subscription.billing_cycle)
# Keep status active if model enum else string
try:
subscription.status = subscription.status.ACTIVE # type: ignore[attr-defined]
except Exception:
setattr(subscription, 'status', 'active')
self.db.commit()
else:
return False
except Exception:
self.db.rollback()
return True
def get_current_billing_period(self, user_id: str) -> Optional[str]:
"""Return current billing period key (YYYY-MM) after ensuring subscription is current."""
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
# Ensure subscription is current (advance if auto_renew)
self._ensure_subscription_current(subscription)
# Continue to use YYYY-MM for summaries
return datetime.now().strftime("%Y-%m")
@classmethod
def clear_user_cache(cls, user_id: str) -> int:
"""Clear all cached limit checks for a specific user. Returns number of entries cleared."""
keys_to_remove = [key for key in cls._limits_cache.keys() if key.startswith(f"{user_id}:")]
for key in keys_to_remove:
del cls._limits_cache[key]
logger.info(f"Cleared {len(keys_to_remove)} cache entries for user {user_id}")
return len(keys_to_remove)
def initialize_default_pricing(self):
"""Initialize default pricing for all API providers."""
# Gemini API Pricing (Updated as of September 2025 - Official Google AI Pricing)
# Source: https://ai.google.dev/gemini-api/docs/pricing
gemini_pricing = [
# Gemini 2.5 Pro - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-pro",
"cost_per_input_token": 0.00000125, # $1.25 per 1M input tokens (prompts <= 200k tokens)
"cost_per_output_token": 0.00001, # $10.00 per 1M output tokens (prompts <= 200k tokens)
"description": "Gemini 2.5 Pro - State-of-the-art multipurpose model for coding and complex reasoning"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-pro-large",
"cost_per_input_token": 0.0000025, # $2.50 per 1M input tokens (prompts > 200k tokens)
"cost_per_output_token": 0.000015, # $15.00 per 1M output tokens (prompts > 200k tokens)
"description": "Gemini 2.5 Pro - Large context model for prompts > 200k tokens"
},
# Gemini 2.5 Flash - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-flash",
"cost_per_input_token": 0.0000003, # $0.30 per 1M input tokens (text/image/video)
"cost_per_output_token": 0.0000025, # $2.50 per 1M output tokens
"description": "Gemini 2.5 Flash - Hybrid reasoning model with 1M token context window"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-flash-audio",
"cost_per_input_token": 0.000001, # $1.00 per 1M input tokens (audio)
"cost_per_output_token": 0.0000025, # $2.50 per 1M output tokens
"description": "Gemini 2.5 Flash - Audio input model"
},
# Gemini 2.5 Flash-Lite - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-flash-lite",
"cost_per_input_token": 0.0000001, # $0.10 per 1M input tokens (text/image/video)
"cost_per_output_token": 0.0000004, # $0.40 per 1M output tokens
"description": "Gemini 2.5 Flash-Lite - Smallest and most cost-effective model for at-scale usage"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-flash-lite-audio",
"cost_per_input_token": 0.0000003, # $0.30 per 1M input tokens (audio)
"cost_per_output_token": 0.0000004, # $0.40 per 1M output tokens
"description": "Gemini 2.5 Flash-Lite - Audio input model"
},
# Gemini 1.5 Flash - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-flash",
"cost_per_input_token": 0.000000075, # $0.075 per 1M input tokens (prompts <= 128k tokens)
"cost_per_output_token": 0.0000003, # $0.30 per 1M output tokens (prompts <= 128k tokens)
"description": "Gemini 1.5 Flash - Fast multimodal model with 1M token context window"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-flash-large",
"cost_per_input_token": 0.00000015, # $0.15 per 1M input tokens (prompts > 128k tokens)
"cost_per_output_token": 0.0000006, # $0.60 per 1M output tokens (prompts > 128k tokens)
"description": "Gemini 1.5 Flash - Large context model for prompts > 128k tokens"
},
# Gemini 1.5 Flash-8B - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-flash-8b",
"cost_per_input_token": 0.0000000375, # $0.0375 per 1M input tokens (prompts <= 128k tokens)
"cost_per_output_token": 0.00000015, # $0.15 per 1M output tokens (prompts <= 128k tokens)
"description": "Gemini 1.5 Flash-8B - Smallest model for lower intelligence use cases"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-flash-8b-large",
"cost_per_input_token": 0.000000075, # $0.075 per 1M input tokens (prompts > 128k tokens)
"cost_per_output_token": 0.0000003, # $0.30 per 1M output tokens (prompts > 128k tokens)
"description": "Gemini 1.5 Flash-8B - Large context model for prompts > 128k tokens"
},
# Gemini 1.5 Pro - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-pro",
"cost_per_input_token": 0.00000125, # $1.25 per 1M input tokens (prompts <= 128k tokens)
"cost_per_output_token": 0.000005, # $5.00 per 1M output tokens (prompts <= 128k tokens)
"description": "Gemini 1.5 Pro - Highest intelligence model with 2M token context window"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-pro-large",
"cost_per_input_token": 0.0000025, # $2.50 per 1M input tokens (prompts > 128k tokens)
"cost_per_output_token": 0.00001, # $10.00 per 1M output tokens (prompts > 128k tokens)
"description": "Gemini 1.5 Pro - Large context model for prompts > 128k tokens"
},
# Gemini Embedding - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-embedding",
"cost_per_input_token": 0.00000015, # $0.15 per 1M input tokens
"cost_per_output_token": 0.0, # No output tokens for embeddings
"description": "Gemini Embedding - Newest embeddings model with higher rate limits"
},
# Grounding with Google Search - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-grounding-search",
"cost_per_request": 0.035, # $35 per 1,000 requests (after free tier)
"cost_per_input_token": 0.0, # No additional token cost for grounding
"cost_per_output_token": 0.0, # No additional token cost for grounding
"description": "Grounding with Google Search - 1,500 RPD free, then $35/1K requests"
}
]
# OpenAI Pricing (estimated, will be updated)
openai_pricing = [
{
"provider": APIProvider.OPENAI,
"model_name": "gpt-4o",
"cost_per_input_token": 0.0000025, # $2.50 per 1M input tokens
"cost_per_output_token": 0.00001, # $10.00 per 1M output tokens
"description": "GPT-4o - Latest OpenAI model"
},
{
"provider": APIProvider.OPENAI,
"model_name": "gpt-4o-mini",
"cost_per_input_token": 0.00000015, # $0.15 per 1M input tokens
"cost_per_output_token": 0.0000006, # $0.60 per 1M output tokens
"description": "GPT-4o Mini - Cost-effective model"
}
]
# Anthropic Pricing (estimated, will be updated)
anthropic_pricing = [
{
"provider": APIProvider.ANTHROPIC,
"model_name": "claude-3.5-sonnet",
"cost_per_input_token": 0.000003, # $3.00 per 1M input tokens
"cost_per_output_token": 0.000015, # $15.00 per 1M output tokens
"description": "Claude 3.5 Sonnet - Anthropic's flagship model"
}
]
# HuggingFace/Mistral Pricing (for GPT-OSS-120B via Groq)
# Default pricing from environment variables or fallback to estimated values
# Based on Groq pricing: ~$1 per 1M input tokens, ~$3 per 1M output tokens
hf_input_cost = float(os.getenv('HUGGINGFACE_INPUT_TOKEN_COST', '0.000001')) # $1 per 1M tokens default
hf_output_cost = float(os.getenv('HUGGINGFACE_OUTPUT_TOKEN_COST', '0.000003')) # $3 per 1M tokens default
mistral_pricing = [
{
"provider": APIProvider.MISTRAL,
"model_name": "openai/gpt-oss-120b:groq",
"cost_per_input_token": hf_input_cost,
"cost_per_output_token": hf_output_cost,
"description": f"GPT-OSS-120B via HuggingFace/Groq (configurable via HUGGINGFACE_INPUT_TOKEN_COST and HUGGINGFACE_OUTPUT_TOKEN_COST env vars)"
},
{
"provider": APIProvider.MISTRAL,
"model_name": "gpt-oss-120b",
"cost_per_input_token": hf_input_cost,
"cost_per_output_token": hf_output_cost,
"description": f"GPT-OSS-120B via HuggingFace/Groq (configurable via HUGGINGFACE_INPUT_TOKEN_COST and HUGGINGFACE_OUTPUT_TOKEN_COST env vars)"
},
{
"provider": APIProvider.MISTRAL,
"model_name": "default",
"cost_per_input_token": hf_input_cost,
"cost_per_output_token": hf_output_cost,
"description": f"HuggingFace default model pricing (configurable via HUGGINGFACE_INPUT_TOKEN_COST and HUGGINGFACE_OUTPUT_TOKEN_COST env vars)"
}
]
# Search API Pricing (estimated)
search_pricing = [
{
"provider": APIProvider.TAVILY,
"model_name": "tavily-search",
"cost_per_request": 0.001, # $0.001 per search
"description": "Tavily AI Search API"
},
{
"provider": APIProvider.SERPER,
"model_name": "serper-search",
"cost_per_request": 0.001, # $0.001 per search
"description": "Serper Google Search API"
},
{
"provider": APIProvider.METAPHOR,
"model_name": "metaphor-search",
"cost_per_request": 0.003, # $0.003 per search
"description": "Metaphor/Exa AI Search API"
},
{
"provider": APIProvider.FIRECRAWL,
"model_name": "firecrawl-extract",
"cost_per_page": 0.002, # $0.002 per page crawled
"description": "Firecrawl Web Extraction API"
},
{
"provider": APIProvider.STABILITY,
"model_name": "stable-diffusion",
"cost_per_image": 0.04, # $0.04 per image
"description": "Stability AI Image Generation"
},
{
"provider": APIProvider.EXA,
"model_name": "exa-search",
"cost_per_request": 0.005, # $0.005 per search (1-25 results)
"description": "Exa Neural Search API"
},
{
"provider": APIProvider.VIDEO,
"model_name": "tencent/HunyuanVideo",
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
"description": "HuggingFace AI Video Generation (HunyuanVideo)"
},
{
"provider": APIProvider.VIDEO,
"model_name": "default",
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
"description": "AI Video Generation default pricing"
},
{
"provider": APIProvider.VIDEO,
"model_name": "kling-v2.5-turbo-std-5s",
"cost_per_request": 0.21,
"description": "WaveSpeed Kling v2.5 Turbo Std Image-to-Video (5 seconds)"
},
{
"provider": APIProvider.VIDEO,
"model_name": "kling-v2.5-turbo-std-10s",
"cost_per_request": 0.42,
"description": "WaveSpeed Kling v2.5 Turbo Std Image-to-Video (10 seconds)"
},
{
"provider": APIProvider.VIDEO,
"model_name": "wavespeed-ai/infinitetalk",
"cost_per_request": 0.30,
"description": "WaveSpeed InfiniteTalk (image + audio to talking avatar video)"
},
# Audio Generation Pricing (Minimax Speech 02 HD via WaveSpeed)
{
"provider": APIProvider.AUDIO,
"model_name": "minimax/speech-02-hd",
"cost_per_input_token": 0.00005, # $0.05 per 1,000 characters (every character is 1 token)
"cost_per_output_token": 0.0, # No output tokens for audio
"cost_per_request": 0.0, # Pricing is per character, not per request
"description": "AI Audio Generation (Text-to-Speech) - Minimax Speech 02 HD via WaveSpeed"
},
{
"provider": APIProvider.AUDIO,
"model_name": "default",
"cost_per_input_token": 0.00005, # $0.05 per 1,000 characters default
"cost_per_output_token": 0.0,
"cost_per_request": 0.0,
"description": "AI Audio Generation default pricing"
}
]
        # Combine all pricing data (video and audio entries are defined inside search_pricing above)
all_pricing = gemini_pricing + openai_pricing + anthropic_pricing + mistral_pricing + search_pricing
# Insert or update pricing data
for pricing_data in all_pricing:
existing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == pricing_data["provider"],
APIProviderPricing.model_name == pricing_data["model_name"]
).first()
if existing:
                # Refresh only HuggingFace/Mistral pricing (sourced from env vars); other providers keep their stored values
if pricing_data["provider"] == APIProvider.MISTRAL:
# Update HuggingFace pricing from env vars
existing.cost_per_input_token = pricing_data["cost_per_input_token"]
existing.cost_per_output_token = pricing_data["cost_per_output_token"]
existing.description = pricing_data["description"]
existing.updated_at = datetime.utcnow()
logger.debug(f"Updated pricing for {pricing_data['provider'].value}:{pricing_data['model_name']}")
else:
pricing = APIProviderPricing(**pricing_data)
self.db.add(pricing)
logger.debug(f"Added new pricing for {pricing_data['provider'].value}:{pricing_data['model_name']}")
self.db.commit()
logger.info("Default API pricing initialized/updated. HuggingFace pricing loaded from env vars if available.")
def initialize_default_plans(self):
"""Initialize default subscription plans."""
plans = [
{
"name": "Free",
"tier": SubscriptionTier.FREE,
"price_monthly": 0.0,
"price_yearly": 0.0,
"gemini_calls_limit": 100,
"openai_calls_limit": 0,
"anthropic_calls_limit": 0,
"mistral_calls_limit": 50,
"tavily_calls_limit": 20,
"serper_calls_limit": 20,
"metaphor_calls_limit": 10,
"firecrawl_calls_limit": 10,
"stability_calls_limit": 5,
"exa_calls_limit": 100,
"video_calls_limit": 0, # No video generation for free tier
"image_edit_calls_limit": 10, # 10 AI image editing calls/month
"audio_calls_limit": 20, # 20 AI audio generation calls/month
"gemini_tokens_limit": 100000,
"monthly_cost_limit": 0.0,
"features": ["basic_content_generation", "limited_research"],
"description": "Perfect for trying out ALwrity"
},
{
"name": "Basic",
"tier": SubscriptionTier.BASIC,
"price_monthly": 29.0,
"price_yearly": 290.0,
"ai_text_generation_calls_limit": 10, # Unified limit for all LLM providers
"gemini_calls_limit": 1000, # Legacy, kept for backwards compatibility (not used for enforcement)
"openai_calls_limit": 500,
"anthropic_calls_limit": 200,
"mistral_calls_limit": 500,
"tavily_calls_limit": 200,
"serper_calls_limit": 200,
"metaphor_calls_limit": 100,
"firecrawl_calls_limit": 100,
"stability_calls_limit": 5,
"exa_calls_limit": 500,
"video_calls_limit": 20, # 20 videos/month for basic plan
"image_edit_calls_limit": 30, # 30 AI image editing calls/month
"audio_calls_limit": 50, # 50 AI audio generation calls/month
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
"mistral_tokens_limit": 20000, # Increased from 5000 for better stability
"monthly_cost_limit": 50.0,
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
"description": "Great for individuals and small teams"
},
{
"name": "Pro",
"tier": SubscriptionTier.PRO,
"price_monthly": 79.0,
"price_yearly": 790.0,
"gemini_calls_limit": 5000,
"openai_calls_limit": 2500,
"anthropic_calls_limit": 1000,
"mistral_calls_limit": 2500,
"tavily_calls_limit": 1000,
"serper_calls_limit": 1000,
"metaphor_calls_limit": 500,
"firecrawl_calls_limit": 500,
"stability_calls_limit": 200,
"exa_calls_limit": 2000,
"video_calls_limit": 50, # 50 videos/month for pro plan
"image_edit_calls_limit": 100, # 100 AI image editing calls/month
"audio_calls_limit": 200, # 200 AI audio generation calls/month
"gemini_tokens_limit": 5000000,
"openai_tokens_limit": 2500000,
"anthropic_tokens_limit": 1000000,
"mistral_tokens_limit": 2500000,
"monthly_cost_limit": 150.0,
"features": ["unlimited_content_generation", "premium_research", "advanced_analytics", "priority_support"],
"description": "Perfect for growing businesses"
},
{
"name": "Enterprise",
"tier": SubscriptionTier.ENTERPRISE,
"price_monthly": 199.0,
"price_yearly": 1990.0,
"gemini_calls_limit": 0, # Unlimited
"openai_calls_limit": 0,
"anthropic_calls_limit": 0,
"mistral_calls_limit": 0,
"tavily_calls_limit": 0,
"serper_calls_limit": 0,
"metaphor_calls_limit": 0,
"firecrawl_calls_limit": 0,
"stability_calls_limit": 0,
"exa_calls_limit": 0, # Unlimited
"video_calls_limit": 0, # Unlimited for enterprise
"image_edit_calls_limit": 0, # Unlimited image editing for enterprise
"audio_calls_limit": 0, # Unlimited audio generation for enterprise
"gemini_tokens_limit": 0,
"openai_tokens_limit": 0,
"anthropic_tokens_limit": 0,
"mistral_tokens_limit": 0,
"monthly_cost_limit": 500.0,
"features": ["unlimited_everything", "white_label", "dedicated_support", "custom_integrations"],
"description": "For large organizations with high-volume needs"
}
]
for plan_data in plans:
existing = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.name == plan_data["name"]
).first()
if not existing:
plan = SubscriptionPlan(**plan_data)
self.db.add(plan)
else:
# Update existing plan with new limits (e.g., image_edit_calls_limit)
# This ensures existing plans get new columns like image_edit_calls_limit
for key, value in plan_data.items():
if key not in ["name", "tier"]: # Don't overwrite name/tier
try:
# Try to set the attribute (works even if column was just added)
setattr(existing, key, value)
                        except Exception as e:
# If attribute doesn't exist yet (column not migrated), skip it
# Schema migration will add it, then this will update it on next run
logger.debug(f"Could not set {key} on plan {existing.name}: {e}")
existing.updated_at = datetime.utcnow()
logger.debug(f"Updated existing plan: {existing.name}")
self.db.commit()
logger.debug("Default subscription plans initialized")
def calculate_api_cost(self, provider: APIProvider, model_name: str,
tokens_input: int = 0, tokens_output: int = 0,
request_count: int = 1, **kwargs) -> Dict[str, float]:
"""Calculate cost for an API call.
Args:
provider: APIProvider enum (e.g., APIProvider.MISTRAL for HuggingFace)
model_name: Model name (e.g., "openai/gpt-oss-120b:groq")
tokens_input: Number of input tokens
tokens_output: Number of output tokens
request_count: Number of requests (default: 1)
**kwargs: Additional parameters (search_count, image_count, page_count, etc.)
Returns:
Dict with cost_input, cost_output, and cost_total
"""
# Get pricing for the provider and model
# Try exact match first
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.model_name == model_name,
APIProviderPricing.is_active == True
).first()
# If not found, try "default" model name for the provider
if not pricing:
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.model_name == "default",
APIProviderPricing.is_active == True
).first()
# If still not found, check for HuggingFace models (provider is MISTRAL)
# Try alternative model name variations
if not pricing and provider == APIProvider.MISTRAL:
# Try with "gpt-oss-120b" (without full path) if model contains it
if "gpt-oss-120b" in model_name.lower():
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.model_name == "gpt-oss-120b",
APIProviderPricing.is_active == True
).first()
# Also try with full model path
if not pricing:
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.model_name == "openai/gpt-oss-120b:groq",
APIProviderPricing.is_active == True
).first()
if not pricing:
# Check if we should use env vars for HuggingFace/Mistral
if provider == APIProvider.MISTRAL:
# Use environment variables for HuggingFace pricing if available
hf_input_cost = float(os.getenv('HUGGINGFACE_INPUT_TOKEN_COST', '0.000001'))
hf_output_cost = float(os.getenv('HUGGINGFACE_OUTPUT_TOKEN_COST', '0.000003'))
logger.info(f"Using HuggingFace pricing from env vars: input={hf_input_cost}, output={hf_output_cost} for model {model_name}")
cost_input = tokens_input * hf_input_cost
cost_output = tokens_output * hf_output_cost
cost_total = cost_input + cost_output
else:
logger.warning(f"No pricing found for {provider.value}:{model_name}, using default estimates")
# Use default estimates
cost_input = tokens_input * 0.000001 # $1 per 1M tokens default
cost_output = tokens_output * 0.000001
cost_total = cost_input + cost_output
else:
# Calculate based on actual pricing from database
logger.debug(f"Using pricing from DB for {provider.value}:{model_name} - input: {pricing.cost_per_input_token}, output: {pricing.cost_per_output_token}")
cost_input = tokens_input * (pricing.cost_per_input_token or 0.0)
cost_output = tokens_output * (pricing.cost_per_output_token or 0.0)
cost_request = request_count * (pricing.cost_per_request or 0.0)
# Handle special cases for non-LLM APIs
cost_search = kwargs.get('search_count', 0) * (pricing.cost_per_search or 0.0)
cost_image = kwargs.get('image_count', 0) * (pricing.cost_per_image or 0.0)
cost_page = kwargs.get('page_count', 0) * (pricing.cost_per_page or 0.0)
cost_total = cost_input + cost_output + cost_request + cost_search + cost_image + cost_page
# Round to 6 decimal places for precision
return {
'cost_input': round(cost_input, 6),
'cost_output': round(cost_output, 6),
'cost_total': round(cost_total, 6)
}
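    # Worked example (hedged sketch, using the seeded gemini-2.5-flash rates above):
    #   cost_input  = 10_000 tokens * $0.0000003 = $0.003
    #   cost_output =  2_000 tokens * $0.0000025 = $0.005
    #   cost_total  = $0.008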
def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get usage limits for a user based on their subscription."""
# CRITICAL: Expire all objects first to ensure fresh data after renewal
self.db.expire_all()
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if not subscription:
# Return free tier limits
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE
).first()
if free_plan:
return self._plan_to_limits_dict(free_plan)
return None
# Ensure current period before returning limits
self._ensure_subscription_current(subscription)
# CRITICAL: Refresh subscription to get latest plan_id, then refresh plan relationship
self.db.refresh(subscription)
# Re-query plan directly to ensure fresh data (bypass relationship cache)
plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.id == subscription.plan_id
).first()
if not plan:
logger.error(f"Plan not found for subscription plan_id={subscription.plan_id}")
return None
# Refresh plan to ensure fresh limits
self.db.refresh(plan)
return self._plan_to_limits_dict(plan)
def _ensure_ai_text_gen_column_detection(self) -> None:
"""Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result."""
if self._ai_text_gen_col_checked:
return
try:
# Try to query the column - if it exists, this will work
self.db.execute(text('SELECT ai_text_generation_calls_limit FROM subscription_plans LIMIT 0'))
self._ai_text_gen_col_available = True
except Exception:
self._ai_text_gen_col_available = False
finally:
self._ai_text_gen_col_checked = True
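    # The probe above depends on "SELECT <col> ... LIMIT 0" raising when the
    # column is absent. Illustrative analogue with the stdlib sqlite3 driver
    # (connection setup assumed; same table/column names as the query above):
    #   try:
    #       conn.execute("SELECT ai_text_generation_calls_limit FROM subscription_plans LIMIT 0")
    #       available = True
    #   except sqlite3.OperationalError:
    #       available = False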
def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]:
"""Convert subscription plan to limits dictionary."""
# Detect if unified AI text generation limit column exists
self._ensure_ai_text_gen_column_detection()
# Use unified AI text generation limit if column exists and is set
ai_text_gen_limit = None
if self._ai_text_gen_col_available:
try:
ai_text_gen_limit = getattr(plan, 'ai_text_generation_calls_limit', None)
# If 0, treat as not set (unlimited for Enterprise or use fallback)
if ai_text_gen_limit == 0:
ai_text_gen_limit = None
            except Exception:
# Column exists but access failed - use fallback
ai_text_gen_limit = None
return {
'plan_name': plan.name,
'tier': plan.tier.value,
'limits': {
# Unified AI text generation limit (applies to all LLM providers)
# If not set, fall back to first non-zero legacy limit for backwards compatibility
'ai_text_generation_calls': ai_text_gen_limit if ai_text_gen_limit is not None else (
plan.gemini_calls_limit if plan.gemini_calls_limit > 0 else
plan.openai_calls_limit if plan.openai_calls_limit > 0 else
plan.anthropic_calls_limit if plan.anthropic_calls_limit > 0 else
plan.mistral_calls_limit if plan.mistral_calls_limit > 0 else 0
),
# Legacy per-provider limits (for backwards compatibility and analytics)
'gemini_calls': plan.gemini_calls_limit,
'openai_calls': plan.openai_calls_limit,
'anthropic_calls': plan.anthropic_calls_limit,
'mistral_calls': plan.mistral_calls_limit,
# Other API limits
'tavily_calls': plan.tavily_calls_limit,
'serper_calls': plan.serper_calls_limit,
'metaphor_calls': plan.metaphor_calls_limit,
'firecrawl_calls': plan.firecrawl_calls_limit,
'stability_calls': plan.stability_calls_limit,
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column
'audio_calls': getattr(plan, 'audio_calls_limit', 0), # Support missing column
# Token limits
'gemini_tokens': plan.gemini_tokens_limit,
'openai_tokens': plan.openai_tokens_limit,
'anthropic_tokens': plan.anthropic_tokens_limit,
'mistral_tokens': plan.mistral_tokens_limit,
'monthly_cost': plan.monthly_cost_limit
},
'features': plan.features or []
}
def check_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits.
Delegates to LimitValidator for actual validation logic.
Args:
user_id: User ID
provider: APIProvider enum (may be MISTRAL for HuggingFace)
tokens_requested: Estimated tokens for the request
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
Returns:
(can_proceed, error_message, usage_info)
"""
from .limit_validation import LimitValidator
validator = LimitValidator(self)
return validator.check_usage_limits(user_id, provider, tokens_requested, actual_provider_name)
def estimate_tokens(self, text: str, provider: APIProvider) -> int:
"""Estimate token count for text based on provider."""
# Get pricing info for token estimation
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.is_active == True
).first()
if pricing and pricing.tokens_per_word:
# Use provider-specific conversion
word_count = len(text.split())
return int(word_count * pricing.tokens_per_word)
else:
# Use default estimation (roughly 1.3 tokens per word for most models)
word_count = len(text.split())
return int(word_count * 1.3)
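    # Fallback example (hedged): with no provider-specific tokens_per_word row,
    # a 200-word prompt estimates to int(200 * 1.3) = 260 tokens.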
    def get_pricing_info(self, provider: APIProvider, model_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Get pricing information for a provider/model."""
query = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.is_active == True
)
if model_name:
query = query.filter(APIProviderPricing.model_name == model_name)
pricing = query.first()
if not pricing:
return None
# Return pricing info as dict
return {
'provider': pricing.provider.value,
'model_name': pricing.model_name,
'cost_per_input_token': pricing.cost_per_input_token,
'cost_per_output_token': pricing.cost_per_output_token,
'cost_per_request': pricing.cost_per_request,
'description': pricing.description
}
def check_comprehensive_limits(
self,
user_id: str,
operations: List[Dict[str, Any]]
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
"""
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
Delegates to LimitValidator for actual validation logic.
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
before making the first external API call.
Args:
user_id: User ID
operations: List of operations to validate, each with:
- 'provider': APIProvider enum
- 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
- 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
- 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")
Returns:
(can_proceed, error_message, error_details)
If can_proceed is False, error_message explains which limit would be exceeded
"""
from .limit_validation import LimitValidator
validator = LimitValidator(self)
return validator.check_comprehensive_limits(user_id, operations)
def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
"""Get pricing configuration for a specific provider and model."""
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.model_name == model_name
).first()
if not pricing:
return None
return {
'provider': pricing.provider.value,
'model_name': pricing.model_name,
'cost_per_input_token': pricing.cost_per_input_token,
'cost_per_output_token': pricing.cost_per_output_token,
'cost_per_request': pricing.cost_per_request,
'cost_per_search': pricing.cost_per_search,
'cost_per_image': pricing.cost_per_image,
'cost_per_page': pricing.cost_per_page,
'description': pricing.description
}

from typing import Set
from sqlalchemy.orm import Session
from sqlalchemy import text
from loguru import logger
_checked_subscription_plan_columns: bool = False
_checked_usage_summaries_columns: bool = False
def ensure_subscription_plan_columns(db: Session) -> None:
"""Ensure required columns exist on subscription_plans for runtime safety.
This is a defensive guard for environments where migrations have not yet
been applied. If columns are missing (e.g., exa_calls_limit), we add them
with a safe default so ORM queries do not fail.
"""
global _checked_subscription_plan_columns
if _checked_subscription_plan_columns:
return
try:
# Discover existing columns using PRAGMA
result = db.execute(text("PRAGMA table_info(subscription_plans)"))
cols: Set[str] = {row[1] for row in result}
logger.debug(f"Schema check: Found {len(cols)} columns in subscription_plans table")
# Columns we may reference in models but might be missing in older DBs
required_columns = {
"exa_calls_limit": "INTEGER DEFAULT 0",
"video_calls_limit": "INTEGER DEFAULT 0",
"image_edit_calls_limit": "INTEGER DEFAULT 0",
"audio_calls_limit": "INTEGER DEFAULT 0",
}
for col_name, ddl in required_columns.items():
if col_name not in cols:
logger.info(f"Adding missing column {col_name} to subscription_plans table")
try:
db.execute(text(f"ALTER TABLE subscription_plans ADD COLUMN {col_name} {ddl}"))
db.commit()
logger.info(f"Successfully added column {col_name}")
except Exception as alter_err:
logger.error(f"Failed to add column {col_name}: {alter_err}")
db.rollback()
# Don't set flag on error - allow retry
raise
else:
logger.debug(f"Column {col_name} already exists")
# Only set flag if we successfully completed the check
_checked_subscription_plan_columns = True
except Exception as e:
logger.error(f"Error ensuring subscription_plan columns: {e}", exc_info=True)
db.rollback()
# Don't set the flag if there was an error, so we retry next time
_checked_subscription_plan_columns = False
raise
def ensure_usage_summaries_columns(db: Session) -> None:
"""Ensure required columns exist on usage_summaries for runtime safety.
This is a defensive guard for environments where migrations have not yet
been applied. If columns are missing (e.g., exa_calls, exa_cost), we add them
with a safe default so ORM queries do not fail.
"""
global _checked_usage_summaries_columns
if _checked_usage_summaries_columns:
return
try:
# Discover existing columns using PRAGMA
result = db.execute(text("PRAGMA table_info(usage_summaries)"))
cols: Set[str] = {row[1] for row in result}
logger.debug(f"Schema check: Found {len(cols)} columns in usage_summaries table")
# Columns we may reference in models but might be missing in older DBs
required_columns = {
"exa_calls": "INTEGER DEFAULT 0",
"exa_cost": "REAL DEFAULT 0.0",
"video_calls": "INTEGER DEFAULT 0",
"video_cost": "REAL DEFAULT 0.0",
"image_edit_calls": "INTEGER DEFAULT 0",
"image_edit_cost": "REAL DEFAULT 0.0",
"audio_calls": "INTEGER DEFAULT 0",
"audio_cost": "REAL DEFAULT 0.0",
}
for col_name, ddl in required_columns.items():
if col_name not in cols:
logger.info(f"Adding missing column {col_name} to usage_summaries table")
try:
db.execute(text(f"ALTER TABLE usage_summaries ADD COLUMN {col_name} {ddl}"))
db.commit()
logger.info(f"Successfully added column {col_name}")
except Exception as alter_err:
logger.error(f"Failed to add column {col_name}: {alter_err}")
db.rollback()
# Don't set flag on error - allow retry
raise
else:
logger.debug(f"Column {col_name} already exists")
# Only set flag if we successfully completed the check
_checked_usage_summaries_columns = True
except Exception as e:
logger.error(f"Error ensuring usage_summaries columns: {e}", exc_info=True)
db.rollback()
# Don't set the flag if there was an error, so we retry next time
_checked_usage_summaries_columns = False
raise
def ensure_all_schema_columns(db: Session) -> None:
"""Ensure all required columns exist in subscription-related tables."""
ensure_subscription_plan_columns(db)
ensure_usage_summaries_columns(db)
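# Standalone sketch of the same guard pattern using the stdlib sqlite3 driver.
# Illustrative only: the helpers above go through SQLAlchemy's Session, and the
# table/column names below are examples, not part of the service contract.
import sqlite3


def ensure_column(conn: sqlite3.Connection, table: str, column: str, ddl: str) -> bool:
    """Add `column` to `table` if it is missing; return True when it was added."""
    existing = {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
    if column in existing:
        return False
    conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {ddl}")
    conn.commit()
    return True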

"""
Usage Tracking Service
Comprehensive tracking of API usage, costs, and subscription limits.
"""
import asyncio
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from loguru import logger
import json
from models.subscription_models import (
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
UserSubscription, UsageStatus
)
from .pricing_service import PricingService
class UsageTrackingService:
"""Service for tracking API usage and managing subscription limits."""
def __init__(self, db: Session):
self.db = db
self.pricing_service = PricingService(db)
# TTL cache (30s) for enforcement results to cut DB chatter
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
    async def track_api_usage(self, user_id: str, provider: APIProvider,
                              endpoint: str, method: str, model_used: Optional[str] = None,
                              tokens_input: int = 0, tokens_output: int = 0,
                              response_time: float = 0.0, status_code: int = 200,
                              request_size: Optional[int] = None, response_size: Optional[int] = None,
                              user_agent: Optional[str] = None, ip_address: Optional[str] = None,
                              error_message: Optional[str] = None, retry_count: int = 0,
                              **kwargs) -> Dict[str, Any]:
"""Track an API usage event and update billing information."""
try:
# Calculate costs
# Use specific model names instead of generic defaults
default_models = {
"gemini": "gemini-2.5-flash", # Use Flash as default (cost-effective)
"openai": "gpt-4o-mini", # Use Mini as default (cost-effective)
"anthropic": "claude-3.5-sonnet", # Use Sonnet as default
"mistral": "openai/gpt-oss-120b:groq" # HuggingFace default model
}
            # For HuggingFace (stored under the MISTRAL enum), use the actual model
            # name when provided, otherwise fall back to the provider default
            if provider == APIProvider.MISTRAL:
                model_name = model_used or default_models["mistral"]
            else:
                model_name = model_used or default_models.get(provider.value, f"{provider.value}-default")
cost_data = self.pricing_service.calculate_api_cost(
provider=provider,
model_name=model_name,
tokens_input=tokens_input,
tokens_output=tokens_output,
request_count=1,
**kwargs
)
# Create usage log entry
billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage_log = APIUsageLog(
user_id=user_id,
provider=provider,
endpoint=endpoint,
method=method,
model_used=model_used,
tokens_input=tokens_input,
tokens_output=tokens_output,
tokens_total=(tokens_input or 0) + (tokens_output or 0),
cost_input=cost_data['cost_input'],
cost_output=cost_data['cost_output'],
cost_total=cost_data['cost_total'],
response_time=response_time,
status_code=status_code,
request_size=request_size,
response_size=response_size,
user_agent=user_agent,
ip_address=ip_address,
error_message=error_message,
retry_count=retry_count,
billing_period=billing_period
)
self.db.add(usage_log)
# Update usage summary
await self._update_usage_summary(
user_id=user_id,
provider=provider,
tokens_used=(tokens_input or 0) + (tokens_output or 0),
cost=cost_data['cost_total'],
billing_period=billing_period,
response_time=response_time,
is_error=status_code >= 400
)
# Check for usage alerts
await self._check_usage_alerts(user_id, provider, billing_period)
self.db.commit()
logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")
return {
'usage_logged': True,
'cost': cost_data['cost_total'],
'tokens_used': (tokens_input or 0) + (tokens_output or 0),
'billing_period': billing_period
}
except Exception as e:
logger.error(f"Error tracking API usage: {str(e)}")
self.db.rollback()
return {
'usage_logged': False,
'error': str(e)
}
async def _update_usage_summary(self, user_id: str, provider: APIProvider,
tokens_used: int, cost: float, billing_period: str,
response_time: float, is_error: bool):
"""Update the usage summary for a user."""
# Get or create usage summary
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=billing_period
)
self.db.add(summary)
        # Update provider-specific counters (guard against NULL values from the DB)
        provider_name = provider.value
        current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
        setattr(summary, f"{provider_name}_calls", current_calls + 1)
        # Update token usage for LLM providers
        if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
            current_tokens = getattr(summary, f"{provider_name}_tokens", 0) or 0
            setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)
        # Update cost
        current_cost = getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
        setattr(summary, f"{provider_name}_cost", current_cost + cost)
        # Update totals (a freshly created summary may not have column defaults applied yet)
        summary.total_calls = (summary.total_calls or 0) + 1
        summary.total_tokens = (summary.total_tokens or 0) + tokens_used
        summary.total_cost = (summary.total_cost or 0.0) + cost
        # Update performance metrics
        if summary.total_calls > 0:
            # Update rolling average response time
            prev_avg = summary.avg_response_time or 0.0
            total_response_time = prev_avg * (summary.total_calls - 1) + response_time
            summary.avg_response_time = total_response_time / summary.total_calls
            # Update error rate: reconstruct the prior error count from the stored
            # percentage (round rather than truncate so the count does not drift low)
            prior_errors = round((summary.error_rate or 0.0) * (summary.total_calls - 1) / 100)
            error_count = prior_errors + (1 if is_error else 0)
            summary.error_rate = (error_count / summary.total_calls) * 100
# Update usage status based on limits
await self._update_usage_status(summary)
summary.updated_at = datetime.utcnow()
async def _update_usage_status(self, summary: UsageSummary):
"""Update usage status based on subscription limits."""
limits = self.pricing_service.get_user_limits(summary.user_id)
if not limits:
return
# Check various limits and determine status
max_usage_percentage = 0.0
        # Check cost limit (guard against NULL values)
        cost_limit = limits['limits'].get('monthly_cost', 0) or 0
        if cost_limit > 0:
            cost_usage_pct = ((summary.total_cost or 0.0) / cost_limit) * 100
max_usage_percentage = max(max_usage_percentage, cost_usage_pct)
# Check call limits for each provider
for provider in APIProvider:
            provider_name = provider.value
            current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
            call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
if call_limit > 0:
call_usage_pct = (current_calls / call_limit) * 100
max_usage_percentage = max(max_usage_percentage, call_usage_pct)
# Update status based on highest usage percentage
if max_usage_percentage >= 100:
summary.usage_status = UsageStatus.LIMIT_REACHED
elif max_usage_percentage >= 80:
summary.usage_status = UsageStatus.WARNING
else:
summary.usage_status = UsageStatus.ACTIVE
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
"""Check if usage alerts should be sent."""
# Get current usage
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
if not summary:
return
# Get user limits
limits = self.pricing_service.get_user_limits(user_id)
if not limits:
return
# Check for alert thresholds (80%, 90%, 100%)
thresholds = [80, 90, 100]
for threshold in thresholds:
# Check if alert already sent for this threshold
existing_alert = self.db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.billing_period == billing_period,
UsageAlert.threshold_percentage == threshold,
UsageAlert.provider == provider,
UsageAlert.is_sent == True
).first()
if existing_alert:
continue
# Check if threshold is reached
            provider_name = provider.value
            current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
            call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
if call_limit > 0:
usage_percentage = (current_calls / call_limit) * 100
if usage_percentage >= threshold:
await self._create_usage_alert(
user_id=user_id,
provider=provider,
threshold=threshold,
current_usage=current_calls,
limit=call_limit,
billing_period=billing_period
)
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
threshold: int, current_usage: int, limit: int,
billing_period: str):
"""Create a usage alert."""
# Determine alert type and severity
if threshold >= 100:
alert_type = "limit_reached"
severity = "error"
title = f"API Limit Reached - {provider.value.title()}"
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
elif threshold >= 90:
alert_type = "usage_warning"
severity = "warning"
title = f"API Usage Warning - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
else:
alert_type = "usage_warning"
severity = "info"
title = f"API Usage Notice - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
alert = UsageAlert(
user_id=user_id,
alert_type=alert_type,
threshold_percentage=threshold,
provider=provider,
title=title,
message=message,
severity=severity,
billing_period=billing_period
)
self.db.add(alert)
logger.info(f"Created usage alert for {user_id}: {title}")
    def get_user_usage_stats(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
"""Get comprehensive usage statistics for a user."""
if not billing_period:
billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get usage summary
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
# Get user limits
limits = self.pricing_service.get_user_limits(user_id)
# Get recent alerts
alerts = self.db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.billing_period == billing_period,
UsageAlert.is_read == False
).order_by(UsageAlert.created_at.desc()).limit(10).all()
if not summary:
# No usage this period - return complete structure with zeros
provider_breakdown = {}
usage_percentages = {}
# Initialize provider breakdown with zeros
for provider in APIProvider:
provider_name = provider.value
provider_breakdown[provider_name] = {
'calls': 0,
'tokens': 0,
'cost': 0.0
}
usage_percentages[f"{provider_name}_calls"] = 0
usage_percentages['cost'] = 0
return {
'billing_period': billing_period,
'usage_status': 'active',
'total_calls': 0,
'total_tokens': 0,
'total_cost': 0.0,
'avg_response_time': 0.0,
'error_rate': 0.0,
'last_updated': datetime.now().isoformat(),
'limits': limits,
'provider_breakdown': provider_breakdown,
'alerts': [],
                'usage_percentages': usage_percentages
}
# Provider breakdown - calculate costs first, then use for percentages
# Only include Gemini and HuggingFace (HuggingFace is stored under MISTRAL enum)
provider_breakdown = {}
# Gemini
gemini_calls = getattr(summary, "gemini_calls", 0) or 0
gemini_tokens = getattr(summary, "gemini_tokens", 0) or 0
gemini_cost = getattr(summary, "gemini_cost", 0.0) or 0.0
# If gemini cost is 0 but there are calls, calculate from usage logs
if gemini_calls > 0 and gemini_cost == 0.0:
gemini_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.GEMINI,
APIUsageLog.billing_period == billing_period
).all()
if gemini_logs:
gemini_cost = sum(float(log.cost_total or 0.0) for log in gemini_logs)
logger.info(f"[UsageStats] Calculated gemini cost from {len(gemini_logs)} logs: ${gemini_cost:.6f}")
provider_breakdown['gemini'] = {
'calls': gemini_calls,
'tokens': gemini_tokens,
'cost': gemini_cost
}
# HuggingFace (stored as MISTRAL in database)
mistral_calls = getattr(summary, "mistral_calls", 0) or 0
mistral_tokens = getattr(summary, "mistral_tokens", 0) or 0
mistral_cost = getattr(summary, "mistral_cost", 0.0) or 0.0
# If mistral (HuggingFace) cost is 0 but there are calls, calculate from usage logs
if mistral_calls > 0 and mistral_cost == 0.0:
mistral_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.MISTRAL,
APIUsageLog.billing_period == billing_period
).all()
if mistral_logs:
mistral_cost = sum(float(log.cost_total or 0.0) for log in mistral_logs)
logger.info(f"[UsageStats] Calculated mistral (HuggingFace) cost from {len(mistral_logs)} logs: ${mistral_cost:.6f}")
provider_breakdown['huggingface'] = {
'calls': mistral_calls,
'tokens': mistral_tokens,
'cost': mistral_cost
}
# Calculate total cost from provider breakdown if summary total_cost is 0
calculated_total_cost = gemini_cost + mistral_cost
summary_total_cost = summary.total_cost or 0.0
# Use calculated cost if summary cost is 0, otherwise use summary cost (it's more accurate)
final_total_cost = summary_total_cost if summary_total_cost > 0 else calculated_total_cost
# If we calculated costs from logs, update the summary for future requests
if calculated_total_cost > 0 and summary_total_cost == 0.0:
logger.info(f"[UsageStats] Updating summary costs: total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}")
summary.total_cost = final_total_cost
summary.gemini_cost = gemini_cost
summary.mistral_cost = mistral_cost
try:
self.db.commit()
except Exception as e:
logger.error(f"[UsageStats] Error updating summary costs: {e}")
self.db.rollback()
# Calculate usage percentages - only for Gemini and HuggingFace
# Use the calculated costs for accurate percentages
usage_percentages = {}
if limits:
# Gemini
gemini_call_limit = limits['limits'].get("gemini_calls", 0) or 0
if gemini_call_limit > 0:
usage_percentages['gemini_calls'] = (gemini_calls / gemini_call_limit) * 100
else:
usage_percentages['gemini_calls'] = 0
# HuggingFace (stored as mistral in database)
mistral_call_limit = limits['limits'].get("mistral_calls", 0) or 0
if mistral_call_limit > 0:
usage_percentages['mistral_calls'] = (mistral_calls / mistral_call_limit) * 100
else:
usage_percentages['mistral_calls'] = 0
# Cost usage percentage - use final_total_cost (calculated from logs if needed)
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
if cost_limit > 0:
usage_percentages['cost'] = (final_total_cost / cost_limit) * 100
else:
usage_percentages['cost'] = 0
return {
'billing_period': billing_period,
'usage_status': summary.usage_status.value if hasattr(summary.usage_status, 'value') else str(summary.usage_status),
'total_calls': summary.total_calls or 0,
'total_tokens': summary.total_tokens or 0,
'total_cost': final_total_cost,
'avg_response_time': summary.avg_response_time or 0.0,
'error_rate': summary.error_rate or 0.0,
'limits': limits,
'provider_breakdown': provider_breakdown,
'alerts': [
{
'id': alert.id,
'type': alert.alert_type,
'title': alert.title,
'message': alert.message,
'severity': alert.severity,
'created_at': alert.created_at.isoformat()
}
for alert in alerts
],
'usage_percentages': usage_percentages,
'last_updated': summary.updated_at.isoformat()
}
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
"""Get usage trends over time."""
        # Calculate billing periods, stepping back one calendar month at a time
        # (a fixed 30-day step can skip or duplicate short months)
        now = datetime.now()
        year, month = now.year, now.month
        periods = []
        for _ in range(months):
            periods.append(f"{year:04d}-{month:02d}")
            month -= 1
            if month == 0:
                month = 12
                year -= 1
        periods.reverse()  # Oldest first
# Get usage summaries for these periods
summaries = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(periods)
).order_by(UsageSummary.billing_period).all()
# Create trends data
trends = {
'periods': periods,
'total_calls': [],
'total_cost': [],
'total_tokens': [],
'provider_trends': {}
}
summary_dict = {s.billing_period: s for s in summaries}
for period in periods:
summary = summary_dict.get(period)
if summary:
trends['total_calls'].append(summary.total_calls or 0)
trends['total_cost'].append(summary.total_cost or 0.0)
trends['total_tokens'].append(summary.total_tokens or 0)
# Provider-specific trends
for provider in APIProvider:
provider_name = provider.value
if provider_name not in trends['provider_trends']:
trends['provider_trends'][provider_name] = {
'calls': [],
'cost': [],
'tokens': []
}
trends['provider_trends'][provider_name]['calls'].append(
getattr(summary, f"{provider_name}_calls", 0) or 0
)
trends['provider_trends'][provider_name]['cost'].append(
getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
)
trends['provider_trends'][provider_name]['tokens'].append(
getattr(summary, f"{provider_name}_tokens", 0) or 0
)
else:
# No data for this period
trends['total_calls'].append(0)
trends['total_cost'].append(0.0)
trends['total_tokens'].append(0)
for provider in APIProvider:
provider_name = provider.value
if provider_name not in trends['provider_trends']:
trends['provider_trends'][provider_name] = {
'calls': [],
'cost': [],
'tokens': []
}
trends['provider_trends'][provider_name]['calls'].append(0)
trends['provider_trends'][provider_name]['cost'].append(0.0)
trends['provider_trends'][provider_name]['tokens'].append(0)
return trends
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
"""Enforce usage limits before making an API call."""
# Check short-lived cache first (30s)
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self._enforce_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
return tuple(cached['result']) # type: ignore
result = self.pricing_service.check_usage_limits(
user_id=user_id,
provider=provider,
tokens_requested=tokens_requested
)
self._enforce_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
try:
billing_period = datetime.now().strftime("%Y-%m")
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
if not summary:
# Nothing to reset
return {"reset": False, "reason": "no_summary"}
# CRITICAL: Reset ALL usage counters to 0 so user gets fresh limits with new/renewed plan
# Clear LIMIT_REACHED status
summary.usage_status = UsageStatus.ACTIVE
# Reset all LLM provider call counters
summary.gemini_calls = 0
summary.openai_calls = 0
summary.anthropic_calls = 0
summary.mistral_calls = 0
# Reset all LLM provider token counters
summary.gemini_tokens = 0
summary.openai_tokens = 0
summary.anthropic_tokens = 0
summary.mistral_tokens = 0
            # Reset search/research provider counters
            summary.tavily_calls = 0
            summary.serper_calls = 0
            summary.metaphor_calls = 0
            summary.firecrawl_calls = 0
            summary.exa_calls = 0
# Reset image generation counters
summary.stability_calls = 0
# Reset video generation counters
summary.video_calls = 0
            # Reset image editing counters
            summary.image_edit_calls = 0
            # Reset audio counters (audio columns are ensured by the schema guard)
            summary.audio_calls = 0
# Reset cost counters
summary.gemini_cost = 0.0
summary.openai_cost = 0.0
summary.anthropic_cost = 0.0
summary.mistral_cost = 0.0
summary.tavily_cost = 0.0
summary.serper_cost = 0.0
summary.metaphor_cost = 0.0
summary.firecrawl_cost = 0.0
summary.stability_cost = 0.0
summary.exa_cost = 0.0
summary.video_cost = 0.0
            summary.image_edit_cost = 0.0
            summary.audio_cost = 0.0
# Reset totals
summary.total_calls = 0
summary.total_tokens = 0
summary.total_cost = 0.0
summary.updated_at = datetime.utcnow()
self.db.commit()
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
return {"reset": True, "counters_reset": True}
except Exception as e:
self.db.rollback()
logger.error(f"Error resetting usage status: {e}")
return {"reset": False, "error": str(e)}