Base code
backend/services/subscription/README.md (new file, 185 lines)
@@ -0,0 +1,185 @@
# Subscription Services Package

## Overview

This package consolidates all subscription, billing, and usage-tracking services and middleware into a single, well-organized module. It follows the same architectural pattern as the onboarding package for consistency and maintainability.

## Package Structure

```
backend/services/subscription/
├── __init__.py                 # Package exports
├── pricing_service.py          # API pricing and cost calculations
├── usage_tracking_service.py   # Usage tracking and limits
├── limit_validation.py         # Subscription limit checking (extracted from pricing_service.py)
├── exception_handler.py        # Exception handling
├── monitoring_middleware.py    # API monitoring with usage tracking
└── README.md                   # This documentation
```

## Services

### PricingService
- **File**: `pricing_service.py`
- **Purpose**: Manages API pricing, cost calculation, and subscription limits
- **Key Features**:
  - Dynamic pricing based on API provider and model
  - Cost calculation for input/output tokens
  - Subscription limit enforcement
  - Billing period management

### UsageTrackingService
- **File**: `usage_tracking_service.py`
- **Purpose**: Comprehensive tracking of API usage, costs, and subscription limits
- **Key Features**:
  - Real-time usage tracking
  - Cost calculation and billing
  - Usage limit enforcement with TTL caching
  - Usage alerts and notifications

### SubscriptionExceptionHandler
- **File**: `exception_handler.py`
- **Purpose**: Centralized exception handling for subscription-related errors
- **Key Features**:
  - Custom exception types
  - Error-handling decorators
  - Consistent error responses

### Monitoring Middleware
- **File**: `monitoring_middleware.py`
- **Purpose**: FastAPI middleware for API monitoring and usage tracking
- **Key Features**:
  - Request/response monitoring
  - Usage tracking integration
  - Performance metrics
  - Database API monitoring

## Usage

### Import Pattern

Always use the consolidated package for subscription-related imports:

```python
# ✅ Correct - use the consolidated package
from services.subscription import PricingService, UsageTrackingService
from services.subscription import SubscriptionExceptionHandler
from services.subscription import check_usage_limits_middleware

# ❌ Incorrect - old scattered imports
from services.pricing_service import PricingService
from services.usage_tracking_service import UsageTrackingService
from middleware.monitoring_middleware import check_usage_limits_middleware
```

### Service Initialization

```python
from services.subscription import PricingService, UsageTrackingService
from services.database import get_db

# Get a database session
db = next(get_db())

# Initialize services
pricing_service = PricingService(db)
usage_service = UsageTrackingService(db)
```

### Middleware Registration

```python
from services.subscription import monitoring_middleware

# Register the middleware on the FastAPI app
app.middleware("http")(monitoring_middleware)
```

## Database Models

The subscription services use the following database models (defined in `backend/models/subscription_models.py`):

- `APIProvider` - API provider enumeration
- `SubscriptionPlan` - Subscription plan definitions
- `UserSubscription` - User subscription records
- `UsageSummary` - Usage summary by billing period
- `APIUsageLog` - Individual API usage logs
- `APIProviderPricing` - Pricing configuration
- `UsageAlert` - Usage limit alerts
- `SubscriptionTier` - Subscription tier definitions
- `BillingCycle` - Billing cycle enumeration
- `UsageStatus` - Usage status enumeration
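The per-period usage lookup that these models back can be sketched with plain SQL against a minimal stand-in table (the table and column names mirror the raw query used in `limit_validation.py`; the real `usage_summaries` schema has many more columns):

```python
import sqlite3

# Minimal stand-in for the usage_summaries table (the real schema is larger).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE usage_summaries ("
    "user_id TEXT, billing_period TEXT, gemini_calls INTEGER, mistral_calls INTEGER)"
)
conn.execute("INSERT INTO usage_summaries VALUES ('user-1', '2024-06', 12, 3)")

# Billing periods are keyed as "YYYY-MM", matching strftime("%Y-%m") in the services.
period = "2024-06"
row = conn.execute(
    "SELECT gemini_calls, mistral_calls FROM usage_summaries "
    "WHERE user_id = ? AND billing_period = ? LIMIT 1",
    ("user-1", period),
).fetchone()
print(row)  # (12, 3)
```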
## Key Features

### 1. Database-Only Persistence
- All data stored in database tables
- No file-based storage
- User-isolated data access

### 2. TTL Caching
- In-memory caching for performance
- 30-second TTL for usage limit checks
- 10-minute TTL for dashboard data
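The 30-second limit-check cache follows a simple expiring-dictionary pattern. This is a minimal sketch, not the package's actual implementation; the real cache in `PricingService` stores result tuples under `"user_id:provider"` keys:

```python
from datetime import datetime, timedelta

_limits_cache = {}  # key -> {"result": ..., "expires_at": datetime}

def cache_put(key, result, ttl_seconds=30):
    """Store a result with an absolute expiry timestamp."""
    _limits_cache[key] = {
        "result": result,
        "expires_at": datetime.utcnow() + timedelta(seconds=ttl_seconds),
    }

def cache_get(key):
    """Return the cached result, or None if missing or expired."""
    entry = _limits_cache.get(key)
    if entry and entry["expires_at"] > datetime.utcnow():
        return entry["result"]
    return None

cache_put("user-1:gemini", (True, "", {}))
print(cache_get("user-1:gemini"))  # (True, '', {})
```

Expired entries are simply ignored on read rather than evicted, which keeps the hot path a single dictionary lookup.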
### 3. Real-time Monitoring
- Live API usage tracking
- Performance metrics collection
- Error rate monitoring

### 4. Flexible Pricing
- Per-provider pricing configuration
- Model-specific pricing
- Dynamic cost calculation
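Per-token pricing typically reduces to separate per-1K-token rates for input and output. The rates below are made-up illustrations, not the project's actual pricing:

```python
def calculate_cost(input_tokens: int, output_tokens: int,
                   input_price_per_1k: float, output_price_per_1k: float) -> float:
    """Cost = tokens / 1000 * per-1K rate, summed over input and output."""
    return ((input_tokens / 1000) * input_price_per_1k
            + (output_tokens / 1000) * output_price_per_1k)

# 1,000 input tokens at $0.25/1K plus 500 output tokens at $0.75/1K
print(calculate_cost(1000, 500, 0.25, 0.75))  # 0.625
```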
## Error Handling

The package provides comprehensive error handling:

```python
from services.subscription import (
    SubscriptionException,
    UsageLimitExceededException,
    PricingException,
    TrackingException
)

try:
    # Subscription operation
    pass
except UsageLimitExceededException as e:
    # Handle usage limit exceeded
    pass
except PricingException as e:
    # Handle pricing error
    pass
```

## Configuration

The services use environment variables for configuration:

- `SUBSCRIPTION_DASHBOARD_NOCACHE` - Bypass the dashboard cache
- `ENABLE_ALPHA` - Enable alpha features (default: false)
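Both flags are read from the environment. A common pattern for boolean flags looks like this (the exact set of truthy values the services accept is an assumption, not confirmed from their source):

```python
import os

def env_flag(name: str, default: bool = False) -> bool:
    """Interpret an environment variable as a boolean flag."""
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

os.environ["SUBSCRIPTION_DASHBOARD_NOCACHE"] = "true"
os.environ.pop("ENABLE_ALPHA", None)  # ensure unset for the demo

print(env_flag("SUBSCRIPTION_DASHBOARD_NOCACHE"))  # True
print(env_flag("ENABLE_ALPHA"))                    # False (default)
```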
## Migration from Old Structure

This package consolidates the following previously scattered files:

- `backend/services/pricing_service.py` → `subscription/pricing_service.py`
- `backend/services/usage_tracking_service.py` → `subscription/usage_tracking_service.py`
- `backend/services/subscription_exception_handler.py` → `subscription/exception_handler.py`
- `backend/middleware/monitoring_middleware.py` → `subscription/monitoring_middleware.py`

Limit-checking logic was additionally extracted from `pricing_service.py` into `subscription/limit_validation.py`.

## Benefits

1. **Single Package**: All subscription logic in one location
2. **Clear Ownership**: Easy to find subscription-related code
3. **Better Organization**: Follows the same pattern as onboarding
4. **Easier Maintenance**: Single source of truth for billing logic
5. **Consistent Architecture**: Matches the onboarding consolidation

## Related Packages

- `services.onboarding` - Onboarding and user setup
- `models.subscription_models` - Database models
- `api.subscription_api` - API endpoints
backend/services/subscription/__init__.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# Subscription Services Package
# Consolidated subscription-related services and middleware

from .pricing_service import PricingService
from .usage_tracking_service import UsageTrackingService
from .exception_handler import (
    SubscriptionException,
    SubscriptionExceptionHandler,
    UsageLimitExceededException,
    PricingException,
    TrackingException,
    handle_usage_limit_error,
    handle_pricing_error,
    handle_tracking_error,
)
from .monitoring_middleware import (
    DatabaseAPIMonitor,
    check_usage_limits_middleware,
    monitoring_middleware,
    get_monitoring_stats,
    get_lightweight_stats,
)

__all__ = [
    "PricingService",
    "UsageTrackingService",
    "SubscriptionException",
    "SubscriptionExceptionHandler",
    "UsageLimitExceededException",
    "PricingException",
    "TrackingException",
    "handle_usage_limit_error",
    "handle_pricing_error",
    "handle_tracking_error",
    "DatabaseAPIMonitor",
    "check_usage_limits_middleware",
    "monitoring_middleware",
    "get_monitoring_stats",
    "get_lightweight_stats",
]
backend/services/subscription/exception_handler.py (new file, 412 lines)
@@ -0,0 +1,412 @@
"""
Comprehensive Exception Handling and Logging for Subscription System
Provides robust error handling, logging, and monitoring for the usage-based subscription system.
"""

import functools
import traceback
from datetime import datetime
from typing import Dict, Any, Optional, Union
from enum import Enum
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError

from models.subscription_models import APIProvider, UsageAlert


class SubscriptionErrorType(Enum):
    USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
    PRICING_ERROR = "pricing_error"
    TRACKING_ERROR = "tracking_error"
    DATABASE_ERROR = "database_error"
    API_PROVIDER_ERROR = "api_provider_error"
    AUTHENTICATION_ERROR = "authentication_error"
    BILLING_ERROR = "billing_error"
    CONFIGURATION_ERROR = "configuration_error"


class SubscriptionErrorSeverity(Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


class SubscriptionException(Exception):
    """Base exception for subscription system errors."""

    def __init__(
        self,
        message: str,
        error_type: SubscriptionErrorType,
        severity: SubscriptionErrorSeverity = SubscriptionErrorSeverity.MEDIUM,
        user_id: Optional[str] = None,
        provider: Optional[APIProvider] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        self.message = message
        self.error_type = error_type
        self.severity = severity
        self.user_id = user_id
        self.provider = provider
        self.context = context or {}
        self.original_error = original_error
        self.timestamp = datetime.utcnow()

        super().__init__(message)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the exception to a dictionary for logging/storage."""
        return {
            "message": self.message,
            "error_type": self.error_type.value,
            "severity": self.severity.value,
            "user_id": self.user_id,
            "provider": self.provider.value if self.provider else None,
            "context": self.context,
            "timestamp": self.timestamp.isoformat(),
            "original_error": str(self.original_error) if self.original_error else None,
            "traceback": traceback.format_exc() if self.original_error else None
        }


class UsageLimitExceededException(SubscriptionException):
    """Exception raised when usage limits are exceeded."""

    def __init__(
        self,
        message: str,
        user_id: str,
        provider: APIProvider,
        limit_type: str,
        current_usage: Union[int, float],
        limit_value: Union[int, float],
        context: Optional[Dict[str, Any]] = None
    ):
        context = context or {}
        context.update({
            "limit_type": limit_type,
            "current_usage": current_usage,
            "limit_value": limit_value,
            "usage_percentage": (current_usage / max(limit_value, 1)) * 100
        })

        super().__init__(
            message=message,
            error_type=SubscriptionErrorType.USAGE_LIMIT_EXCEEDED,
            severity=SubscriptionErrorSeverity.HIGH,
            user_id=user_id,
            provider=provider,
            context=context
        )


class PricingException(SubscriptionException):
    """Exception raised for pricing calculation errors."""

    def __init__(
        self,
        message: str,
        provider: Optional[APIProvider] = None,
        model_name: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        context = context or {}
        if model_name:
            context["model_name"] = model_name

        super().__init__(
            message=message,
            error_type=SubscriptionErrorType.PRICING_ERROR,
            severity=SubscriptionErrorSeverity.MEDIUM,
            provider=provider,
            context=context,
            original_error=original_error
        )


class TrackingException(SubscriptionException):
    """Exception raised for usage tracking errors."""

    def __init__(
        self,
        message: str,
        user_id: Optional[str] = None,
        provider: Optional[APIProvider] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        super().__init__(
            message=message,
            error_type=SubscriptionErrorType.TRACKING_ERROR,
            severity=SubscriptionErrorSeverity.MEDIUM,
            user_id=user_id,
            provider=provider,
            context=context,
            original_error=original_error
        )


class SubscriptionExceptionHandler:
    """Comprehensive exception handler for the subscription system."""

    def __init__(self, db: Optional[Session] = None):
        self.db = db
        self.logger = self._setup_logging()

    def _setup_logging(self):
        """Set up structured logging for subscription errors."""
        from utils.logger_utils import get_service_logger
        return get_service_logger("subscription_exception_handler")

    def handle_exception(
        self,
        error: Union[Exception, SubscriptionException],
        context: Optional[Dict[str, Any]] = None,
        log_level: str = "error"
    ) -> Dict[str, Any]:
        """Handle and log subscription system exceptions."""

        context = context or {}

        # Wrap plain exceptions in a SubscriptionException
        if not isinstance(error, SubscriptionException):
            error = SubscriptionException(
                message=str(error),
                error_type=self._classify_error(error),
                severity=self._determine_severity(error),
                context=context,
                original_error=error
            )

        # Log the error
        error_data = error.to_dict()
        error_data.update(context)

        log_message = f"Subscription Error: {error.message}"

        if log_level == "critical":
            logger.critical(log_message, extra={"error_data": error_data})
        elif log_level == "error":
            logger.error(log_message, extra={"error_data": error_data})
        elif log_level == "warning":
            logger.warning(log_message, extra={"error_data": error_data})
        else:
            logger.info(log_message, extra={"error_data": error_data})

        # Store high/critical errors in the database for alerting
        if error.severity in [SubscriptionErrorSeverity.HIGH, SubscriptionErrorSeverity.CRITICAL]:
            self._store_error_alert(error)

        # Return a formatted error response
        return self._format_error_response(error)

    def _classify_error(self, error: Exception) -> SubscriptionErrorType:
        """Classify an exception into a subscription error type."""

        error_str = str(error).lower()
        error_type_name = type(error).__name__.lower()

        if "limit" in error_str or "exceeded" in error_str:
            return SubscriptionErrorType.USAGE_LIMIT_EXCEEDED
        elif "pricing" in error_str or "cost" in error_str:
            return SubscriptionErrorType.PRICING_ERROR
        elif "tracking" in error_str or "usage" in error_str:
            return SubscriptionErrorType.TRACKING_ERROR
        elif "database" in error_str or "sql" in error_type_name:
            return SubscriptionErrorType.DATABASE_ERROR
        elif "api" in error_str or "provider" in error_str:
            return SubscriptionErrorType.API_PROVIDER_ERROR
        elif "auth" in error_str or "permission" in error_str:
            return SubscriptionErrorType.AUTHENTICATION_ERROR
        elif "billing" in error_str or "payment" in error_str:
            return SubscriptionErrorType.BILLING_ERROR
        else:
            return SubscriptionErrorType.CONFIGURATION_ERROR

    def _determine_severity(self, error: Exception) -> SubscriptionErrorSeverity:
        """Determine the severity of an error."""

        error_str = str(error).lower()

        # Critical errors
        if isinstance(error, (SQLAlchemyError, ConnectionError)):
            return SubscriptionErrorSeverity.CRITICAL

        # High severity errors
        if "limit exceeded" in error_str or "unauthorized" in error_str:
            return SubscriptionErrorSeverity.HIGH

        # Medium severity errors
        if "pricing" in error_str or "tracking" in error_str:
            return SubscriptionErrorSeverity.MEDIUM

        # Default to low
        return SubscriptionErrorSeverity.LOW

    def _store_error_alert(self, error: SubscriptionException):
        """Store critical errors as alerts in the database."""

        if not self.db or not error.user_id:
            return

        try:
            alert = UsageAlert(
                user_id=error.user_id,
                alert_type="system_error",
                threshold_percentage=0,
                provider=error.provider,
                title=f"System Error: {error.error_type.value}",
                message=error.message,
                severity=error.severity.value,
                billing_period=datetime.now().strftime("%Y-%m")
            )

            self.db.add(alert)
            self.db.commit()

        except Exception as e:
            logger.error(f"Failed to store error alert: {e}")

    def _format_error_response(self, error: SubscriptionException) -> Dict[str, Any]:
        """Format the error for an API response."""

        response = {
            "success": False,
            "error": {
                "type": error.error_type.value,
                "message": error.message,
                "severity": error.severity.value,
                "timestamp": error.timestamp.isoformat()
            }
        }

        # Add context for debugging (non-sensitive info only)
        if error.context:
            safe_context = {
                k: v for k, v in error.context.items()
                if k not in ["password", "token", "key", "secret"]
            }
            response["error"]["context"] = safe_context

        # Add a user-friendly message based on the error type
        user_messages = {
            SubscriptionErrorType.USAGE_LIMIT_EXCEEDED:
                "You have reached your usage limit. Please upgrade your plan or wait for the next billing cycle.",
            SubscriptionErrorType.PRICING_ERROR:
                "There was an issue calculating the cost for this request. Please try again.",
            SubscriptionErrorType.TRACKING_ERROR:
                "Unable to track usage for this request. Please contact support if this persists.",
            SubscriptionErrorType.DATABASE_ERROR:
                "A database error occurred. Please try again later.",
            SubscriptionErrorType.API_PROVIDER_ERROR:
                "There was an issue with the API provider. Please try again.",
            SubscriptionErrorType.AUTHENTICATION_ERROR:
                "Authentication failed. Please check your credentials.",
            SubscriptionErrorType.BILLING_ERROR:
                "There was a billing-related error. Please contact support.",
            SubscriptionErrorType.CONFIGURATION_ERROR:
                "System configuration error. Please contact support."
        }

        response["error"]["user_message"] = user_messages.get(
            error.error_type,
            "An unexpected error occurred. Please try again or contact support."
        )

        return response


# Utility functions for common error scenarios
def handle_usage_limit_error(
    user_id: str,
    provider: APIProvider,
    limit_type: str,
    current_usage: Union[int, float],
    limit_value: Union[int, float],
    db: Optional[Session] = None
) -> Dict[str, Any]:
    """Handle usage limit exceeded errors."""

    handler = SubscriptionExceptionHandler(db)
    error = UsageLimitExceededException(
        message=f"Usage limit exceeded for {limit_type}",
        user_id=user_id,
        provider=provider,
        limit_type=limit_type,
        current_usage=current_usage,
        limit_value=limit_value
    )

    return handler.handle_exception(error, log_level="warning")


def handle_pricing_error(
    message: str,
    provider: Optional[APIProvider] = None,
    model_name: Optional[str] = None,
    original_error: Optional[Exception] = None,
    db: Optional[Session] = None
) -> Dict[str, Any]:
    """Handle pricing calculation errors."""

    handler = SubscriptionExceptionHandler(db)
    error = PricingException(
        message=message,
        provider=provider,
        model_name=model_name,
        original_error=original_error
    )

    return handler.handle_exception(error)


def handle_tracking_error(
    message: str,
    user_id: Optional[str] = None,
    provider: Optional[APIProvider] = None,
    original_error: Optional[Exception] = None,
    db: Optional[Session] = None
) -> Dict[str, Any]:
    """Handle usage tracking errors."""

    handler = SubscriptionExceptionHandler(db)
    error = TrackingException(
        message=message,
        user_id=user_id,
        provider=provider,
        original_error=original_error
    )

    return handler.handle_exception(error)


def log_usage_event(
    user_id: str,
    provider: APIProvider,
    action: str,
    details: Optional[Dict[str, Any]] = None
):
    """Log usage events for monitoring and debugging."""

    details = details or {}
    log_data = {
        "user_id": user_id,
        "provider": provider.value,
        "action": action,
        "timestamp": datetime.utcnow().isoformat(),
        **details
    }

    logger.info(f"Usage Tracking: {action}", extra={"usage_data": log_data})


# Decorator for automatic exception handling
def handle_subscription_errors(db: Optional[Session] = None):
    """Decorator to automatically handle subscription-related exceptions."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:  # SubscriptionException subclasses are caught here too
                handler = SubscriptionExceptionHandler(db)
                return handler.handle_exception(e)

        return wrapper
    return decorator
backend/services/subscription/limit_validation.py (new file, 799 lines)
@@ -0,0 +1,799 @@
|
||||
"""
|
||||
Limit Validation Module
|
||||
Handles subscription limit checking and validation logic.
|
||||
Extracted from pricing_service.py for better modularity.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List, Tuple, TYPE_CHECKING
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import text
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import (
|
||||
UserSubscription, UsageSummary, SubscriptionPlan,
|
||||
APIProvider, SubscriptionTier
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pricing_service import PricingService
|
||||
|
||||
|
||||
class LimitValidator:
|
||||
"""Validates subscription limits for API usage."""
|
||||
|
||||
def __init__(self, pricing_service: 'PricingService'):
|
||||
"""
|
||||
Initialize limit validator with reference to PricingService.
|
||||
|
||||
Args:
|
||||
pricing_service: Instance of PricingService to access helper methods and cache
|
||||
"""
|
||||
self.pricing_service = pricing_service
|
||||
self.db = pricing_service.db
|
||||
|
||||
def check_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
provider: APIProvider enum (may be MISTRAL for HuggingFace)
|
||||
tokens_requested: Estimated tokens for the request
|
||||
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
|
||||
|
||||
Returns:
|
||||
(can_proceed, error_message, usage_info)
|
||||
"""
|
||||
try:
|
||||
# Use actual_provider_name if provided, otherwise use enum value
|
||||
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
|
||||
display_provider_name = actual_provider_name or provider.value
|
||||
|
||||
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
|
||||
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self.pricing_service._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user subscription first to check expiration
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
|
||||
|
||||
# Check subscription expiration (STRICT: deny if expired)
|
||||
if subscription:
|
||||
if subscription.current_period_end < now:
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
|
||||
# Subscription expired - check if auto_renew is enabled
|
||||
if not getattr(subscription, 'auto_renew', False):
|
||||
# Expired and no auto-renew - deny access
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
|
||||
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
|
||||
'expired': True,
|
||||
'period_end': subscription.current_period_end.isoformat()
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
# Try to auto-renew
|
||||
if not self.pricing_service._ensure_subscription_current(subscription):
|
||||
# Auto-renew failed - deny access
|
||||
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
|
||||
'expired': True,
|
||||
'auto_renew_failed': True
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Get user limits with error handling (STRICT: fail on errors)
|
||||
# CRITICAL: Expire SQLAlchemy objects to ensure we get fresh plan data after renewal
|
||||
try:
|
||||
# Force expire subscription and plan objects to avoid stale cache
|
||||
if subscription and subscription.plan_id:
|
||||
plan_obj = self.db.query(SubscriptionPlan).filter(SubscriptionPlan.id == subscription.plan_id).first()
|
||||
if plan_obj:
|
||||
self.db.expire(plan_obj)
|
||||
logger.debug(f"[Subscription Check] Expired plan object to ensure fresh limits after renewal")
|
||||
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
if limits:
|
||||
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
|
||||
# Log token limits for debugging
|
||||
token_limits = limits.get('limits', {})
|
||||
logger.debug(f"[Subscription Check] Token limits: gemini={token_limits.get('gemini_tokens')}, mistral={token_limits.get('mistral_tokens')}, openai={token_limits.get('openai_tokens')}, anthropic={token_limits.get('anthropic_tokens')}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
|
||||
except Exception as e:
|
||||
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
|
||||
# STRICT: Fail closed - deny request if we can't check limits
|
||||
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
|
||||
|
||||
if not limits:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
|
||||
limits = self.pricing_service._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
# No subscription and no free tier - deny access
|
||||
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
# CRITICAL: Use fresh queries to avoid SQLAlchemy cache after renewal
|
||||
try:
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Expire all objects to force fresh read from DB (critical after renewal)
|
||||
self.db.expire_all()
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
# Map result to UsageSummary object
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
except Exception as sql_error:
|
||||
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
|
||||
# Fallback to ORM query
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage summary for {user_id}: {e}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to retrieve usage summary: {str(e)}", {}
|
||||
|
||||
            # Check call limits with error handling
            # NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
            try:
                # Use display_provider_name for error messages, but provider.value for DB queries
                provider_name = provider.value  # For DB field names (e.g., "mistral_calls", "mistral_tokens")

                # For LLM text generation providers, check against the unified total_calls limit
                llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
                is_llm_provider = provider_name in llm_providers

                if is_llm_provider:
                    # Use the unified AI text generation limit (total_calls across all LLM providers)
                    ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0

                    # If the unified limit is not set, fall back to the provider-specific limit for backwards compatibility
                    if ai_text_gen_limit == 0:
                        ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0

                    # Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
                    current_total_llm_calls = (
                        (usage.gemini_calls or 0) +
                        (usage.openai_calls or 0) +
                        (usage.anthropic_calls or 0) +
                        (usage.mistral_calls or 0)
                    )

                    # Only enforce the limit if it is > 0 (0 means unlimited for Enterprise)
                    if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
                        logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
                        result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
                            'current_calls': current_total_llm_calls,
                            'limit': ai_text_gen_limit,
                            'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
                            'provider': display_provider_name,  # Use display name for consistency
                            'usage_info': {
                                'provider': display_provider_name,  # Use display name for user-facing info
                                'current_calls': current_total_llm_calls,
                                'limit': ai_text_gen_limit,
                                'type': 'ai_text_generation',
                                'breakdown': {
                                    'gemini': usage.gemini_calls or 0,
                                    'openai': usage.openai_calls or 0,
                                    'anthropic': usage.anthropic_calls or 0,
                                    'mistral': usage.mistral_calls or 0  # DB field name (not display name)
                                }
                            }
                        })
                        self.pricing_service._limits_cache[cache_key] = {
                            'result': result,
                            'expires_at': now + timedelta(seconds=30)
                        }
                        return result
                    else:
                        logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
                else:
                    # For non-LLM providers, check the provider-specific limit
                    current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
                    call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0

                    # Only enforce the limit if it is > 0 (0 means unlimited for Enterprise)
                    if call_limit > 0 and current_calls >= call_limit:
                        logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
                        result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
                            'current_calls': current_calls,
                            'limit': call_limit,
                            'usage_percentage': (current_calls / call_limit) * 100,
                            'provider': display_provider_name  # Use display name for consistency
                        })
                        self.pricing_service._limits_cache[cache_key] = {
                            'result': result,
                            'expires_at': now + timedelta(seconds=30)
                        }
                        return result
                    else:
                        logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
            except Exception as e:
                logger.error(f"Error checking call limits: {e}")
                # Continue to the next check

            # Check token limits for LLM providers with error handling
            # NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
            try:
                if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
                    current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
                    token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0

                    # Only enforce the limit if it is > 0 (0 means unlimited for Enterprise)
                    if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
                        result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
                            'current_tokens': current_tokens,
                            'requested_tokens': tokens_requested,
                            'limit': token_limit,
                            'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
                            'provider': display_provider_name,  # Use display name in error details
                            'usage_info': {
                                'provider': display_provider_name,
                                'current_tokens': current_tokens,
                                'requested_tokens': tokens_requested,
                                'limit': token_limit,
                                'type': 'tokens'
                            }
                        })
                        self.pricing_service._limits_cache[cache_key] = {
                            'result': result,
                            'expires_at': now + timedelta(seconds=30)
                        }
                        return result
            except Exception as e:
                logger.error(f"Error checking token limits: {e}")
                # Continue to the next check

            # Check cost limits with error handling
            # NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
            try:
                cost_limit = limits['limits'].get('monthly_cost', 0) or 0
                # Only enforce the limit if it is > 0 (0 means unlimited for Enterprise)
                if cost_limit > 0 and usage.total_cost >= cost_limit:
                    result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
                        'current_cost': usage.total_cost,
                        'limit': cost_limit,
                        'usage_percentage': 100.0
                    })
                    self.pricing_service._limits_cache[cache_key] = {
                        'result': result,
                        'expires_at': now + timedelta(seconds=30)
                    }
                    return result
            except Exception as e:
                logger.error(f"Error checking cost limits: {e}")
                # Continue to the success case

            # Calculate usage percentages for warnings
            try:
                # Determine which call variables to use based on provider type
                if is_llm_provider:
                    # Use unified LLM call tracking
                    current_call_count = current_total_llm_calls
                    call_limit_value = ai_text_gen_limit
                else:
                    # Use provider-specific call tracking
                    current_call_count = current_calls
                    call_limit_value = call_limit

                call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
                cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
                result = (True, "Within limits", {
                    'current_calls': current_call_count,
                    'call_limit': call_limit_value,
                    'call_usage_percentage': call_usage_pct,
                    'current_cost': usage.total_cost,
                    'cost_limit': cost_limit,
                    'cost_usage_percentage': cost_usage_pct
                })
                self.pricing_service._limits_cache[cache_key] = {
                    'result': result,
                    'expires_at': now + timedelta(seconds=30)
                }
                return result
            except Exception as e:
                logger.error(f"Error calculating usage percentages: {e}")
                # Return basic success
                return True, "Within limits", {}

        except Exception as e:
            logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
            # STRICT: Fail closed - deny requests if the subscription system fails
            return False, f"Subscription check error: {str(e)}", {}

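Taken together, the branches above implement two rules worth calling out: a limit of 0 means unlimited (Enterprise plans), and all LLM providers draw from one shared call budget. Below is a minimal standalone sketch of that logic, with a plain dict standing in for the `usage_summaries` row; the helper name and data shapes are illustrative, not part of this service:

```python
# Illustrative sketch of the unified LLM call-limit check used above.
# A limit of 0 is treated as "unlimited" (Enterprise plans); otherwise the
# sum of all LLM provider calls is compared against the shared budget.

LLM_PROVIDERS = ('gemini', 'openai', 'anthropic', 'mistral')

def llm_calls_within_limit(usage: dict, limit: int) -> tuple[bool, int]:
    """Return (allowed, current_total) for the unified AI text generation limit."""
    # NULL-able DB columns map to None, hence the `or 0` guard
    current_total = sum(usage.get(f"{p}_calls", 0) or 0 for p in LLM_PROVIDERS)
    if limit <= 0:  # 0 means unlimited
        return True, current_total
    return current_total < limit, current_total

# Example: 90 + 5 + 0 + 10 = 105 calls against a limit of 100
usage = {'gemini_calls': 90, 'openai_calls': 5, 'anthropic_calls': None, 'mistral_calls': 10}
allowed, total = llm_calls_within_limit(usage, 100)   # blocked: 105 >= 100
unlimited, _ = llm_calls_within_limit(usage, 0)       # allowed: limit 0 = unlimited
```

Note that the real check also falls back to the provider-specific `<provider>_calls` limit when the unified `ai_text_generation_calls` limit is unset.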
    def check_comprehensive_limits(
        self,
        user_id: str,
        operations: List[Dict[str, Any]]
    ) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
        """
        Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.

        This prevents wasteful API calls by validating that ALL subsequent operations will succeed
        before making the first external API call.

        Args:
            user_id: User ID
            operations: List of operations to validate, each with:
                - 'provider': APIProvider enum
                - 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
                - 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
                - 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")

        Returns:
            (can_proceed, error_message, error_details)
            If can_proceed is False, error_message explains which limit would be exceeded
        """
        try:
            logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
            logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")

            # Get current usage and limits once
            current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

            logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")

            # Ensure the schema columns exist before querying
            try:
                from services.subscription.schema_utils import ensure_usage_summaries_columns
                ensure_usage_summaries_columns(self.db)
            except Exception as schema_err:
                logger.warning(f"Schema check failed, will retry on query error: {schema_err}")

            # Explicitly expire any cached objects and refresh from the DB to ensure fresh data
            self.db.expire_all()

            try:
                usage = self.db.query(UsageSummary).filter(
                    UsageSummary.user_id == user_id,
                    UsageSummary.billing_period == current_period
                ).first()

                # CRITICAL: Explicitly refresh from the database to get the latest values (clears the SQLAlchemy cache)
                if usage:
                    self.db.refresh(usage)
            except Exception as query_err:
                error_str = str(query_err).lower()
                if 'no such column' in error_str and 'exa_calls' in error_str:
                    logger.warning("Missing column detected in usage query, fixing schema and retrying...")
                    import services.subscription.schema_utils as schema_utils
                    schema_utils._checked_usage_summaries_columns = False
                    from services.subscription.schema_utils import ensure_usage_summaries_columns
                    ensure_usage_summaries_columns(self.db)
                    self.db.expire_all()
                    # Retry the query
                    usage = self.db.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()
                    if usage:
                        self.db.refresh(usage)
                else:
                    raise

            # Log what was actually read from the database
            if usage:
                logger.info(f"[Pre-flight Check] 📊 Usage Summary from DB (Period: {current_period}):")
                logger.info(f"  ├─ Gemini: {usage.gemini_tokens or 0} tokens / {usage.gemini_calls or 0} calls")
                logger.info(f"  ├─ Mistral/HF: {usage.mistral_tokens or 0} tokens / {usage.mistral_calls or 0} calls")
                logger.info(f"  ├─ Total Tokens: {usage.total_tokens or 0}")
                logger.info(f"  └─ Usage Status: {usage.usage_status.value if usage.usage_status else 'N/A'}")
            else:
                logger.info(f"[Pre-flight Check] 📊 No usage summary found for period {current_period} (will create new)")

            if not usage:
                # First usage this period, create the summary row
                try:
                    usage = UsageSummary(
                        user_id=user_id,
                        billing_period=current_period
                    )
                    self.db.add(usage)
                    self.db.commit()
                except Exception as create_error:
                    logger.error(f"Error creating usage summary: {create_error}")
                    self.db.rollback()
                    return False, f"Failed to create usage summary: {str(create_error)}", {}

            # Get the user's limits
            limits_dict = self.pricing_service.get_user_limits(user_id)
            if not limits_dict:
                # No subscription found - check for the free tier
                free_plan = self.db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.tier == SubscriptionTier.FREE,
                    SubscriptionPlan.is_active == True
                ).first()
                if free_plan:
                    limits_dict = self.pricing_service._plan_to_limits_dict(free_plan)
                else:
                    return False, "No subscription plan found. Please subscribe to a plan.", {}

            limits = limits_dict.get('limits', {})

            # Track cumulative usage across all operations
            total_llm_calls = (
                (usage.gemini_calls or 0) +
                (usage.openai_calls or 0) +
                (usage.anthropic_calls or 0) +
                (usage.mistral_calls or 0)
            )
            total_llm_tokens = {}
            total_images = usage.stability_calls or 0

            # Log the current usage summary
            logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
            logger.info(f"  └─ Total LLM Calls: {total_llm_calls}")
            logger.info(f"  └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
            logger.info(f"  └─ Image Calls: {total_images}")

            # Validate each operation
            for op_idx, operation in enumerate(operations):
                provider = operation.get('provider')
                provider_name = provider.value if hasattr(provider, 'value') else str(provider)
                tokens_requested = operation.get('tokens_requested', 0)
                actual_provider_name = operation.get('actual_provider_name')
                operation_type = operation.get('operation_type', 'unknown')

                display_provider_name = actual_provider_name or provider_name

                # Log operation details at debug level (only when needed)
                logger.debug(f"[Pre-flight] Operation {op_idx + 1}/{len(operations)}: {operation_type} ({display_provider_name}, {tokens_requested} tokens)")

                # Check if this is an LLM provider
                llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
                is_llm_provider = provider_name in llm_providers

                # Check the unified AI text generation limit for LLM providers
                if is_llm_provider:
                    ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
                    if ai_text_gen_limit == 0:
                        # Fall back to the provider-specific limit
                        ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0

                    # Count this operation as an LLM call
                    projected_total_llm_calls = total_llm_calls + 1

                    if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
                        error_info = {
                            'current_calls': total_llm_calls,
                            'limit': ai_text_gen_limit,
                            'provider': display_provider_name,
                            'operation_type': operation_type,
                            'operation_index': op_idx
                        }
                        return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
                            'error_type': 'call_limit',
                            'usage_info': error_info
                        }

                    # Check token limits for this provider
                    # CRITICAL: Always query fresh from the DB for each operation to avoid SQLAlchemy cache issues
                    # This ensures we get the latest values after subscription renewal, even for cumulative tracking
                    provider_tokens_key = f"{provider_name}_tokens"

                    # Try to get a fresh value from the DB with comprehensive error handling
                    base_current_tokens = 0
                    query_succeeded = False

                    try:
                        # Validate that the column name is safe (only allow known provider token columns)
                        valid_token_columns = ['gemini_tokens', 'openai_tokens', 'anthropic_tokens', 'mistral_tokens']

                        if provider_tokens_key not in valid_token_columns:
                            logger.error(f"  └─ Invalid provider tokens key: {provider_tokens_key}")
                            query_succeeded = True  # Treat as success with a 0 value
                        else:
                            # Method 1: Try a raw SQL query to completely bypass the ORM cache
                            try:
                                logger.debug(f"  └─ Attempting raw SQL query for {provider_tokens_key}")
                                sql_query = text(f"""
                                    SELECT {provider_tokens_key}
                                    FROM usage_summaries
                                    WHERE user_id = :user_id
                                    AND billing_period = :period
                                    LIMIT 1
                                """)

                                logger.debug(f"  └─ SQL: SELECT {provider_tokens_key} FROM usage_summaries WHERE user_id={user_id} AND billing_period={current_period}")

                                result = self.db.execute(sql_query, {
                                    'user_id': user_id,
                                    'period': current_period
                                }).first()

                                if result:
                                    base_current_tokens = result[0] if result[0] is not None else 0
                                else:
                                    base_current_tokens = 0

                                query_succeeded = True
                                logger.debug(f"[Pre-flight] Raw SQL query for {provider_tokens_key}: {base_current_tokens}")

                            except Exception as sql_error:
                                logger.error(f"  └─ Raw SQL query failed for {provider_tokens_key}: {type(sql_error).__name__}: {sql_error}", exc_info=True)
                                query_succeeded = False  # Will try the ORM fallback

                        # Method 2: Fall back to a fresh ORM query if the raw SQL failed
                        if not query_succeeded:
                            try:
                                # Expire all cached objects and do a fresh query
                                self.db.expire_all()
                                fresh_usage = self.db.query(UsageSummary).filter(
                                    UsageSummary.user_id == user_id,
                                    UsageSummary.billing_period == current_period
                                ).first()

                                if fresh_usage:
                                    # Explicitly refresh to get the latest from the DB
                                    self.db.refresh(fresh_usage)
                                    base_current_tokens = getattr(fresh_usage, provider_tokens_key, 0) or 0
                                else:
                                    base_current_tokens = 0

                                query_succeeded = True
                                logger.info(f"[Pre-flight Check] ✅ ORM fallback query succeeded for {provider_tokens_key}: {base_current_tokens}")

                            except Exception as orm_error:
                                logger.error(f"  └─ ORM query also failed: {orm_error}", exc_info=True)
                                query_succeeded = False

                    except Exception as e:
                        logger.error(f"  └─ Unexpected error getting tokens from DB for {provider_tokens_key}: {e}", exc_info=True)
                        base_current_tokens = 0  # Fail safe - assume 0 if we can't query

                    if not query_succeeded:
                        logger.warning(f"  └─ Both query methods failed, using 0 as fallback")

                    # Log the DB query result at debug level (only when needed for troubleshooting)
                    logger.debug(f"[Pre-flight] DB query for {display_provider_name} ({provider_tokens_key}): {base_current_tokens} (period: {current_period})")

                    # Add any projected tokens from previous operations in this validation run
                    # Note: total_llm_tokens tracks ONLY projected tokens from this run, not the base DB value
                    projected_from_previous = total_llm_tokens.get(provider_tokens_key, 0)

                    # Current tokens = base from DB + projected from previous operations in this run
                    current_provider_tokens = base_current_tokens + projected_from_previous

                    # Log the token calculation at debug level
                    logger.debug(f"[Pre-flight] Token calc for {display_provider_name}: base={base_current_tokens}, projected={projected_from_previous}, total={current_provider_tokens}")

                    token_limit = limits.get(provider_tokens_key, 0) or 0

                    if token_limit > 0 and tokens_requested > 0:
                        projected_tokens = current_provider_tokens + tokens_requested
                        logger.info(f"  └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")

                        if projected_tokens > token_limit:
                            usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
                            error_info = {
                                'current_tokens': current_provider_tokens,
                                'base_tokens_from_db': base_current_tokens,
                                'projected_from_previous_ops': projected_from_previous,
                                'requested_tokens': tokens_requested,
                                'limit': token_limit,
                                'provider': display_provider_name,
                                'operation_type': operation_type,
                                'operation_index': op_idx
                            }
                            # Make the error message clearer: show actual DB usage vs. projected
                            if projected_from_previous > 0:
                                error_msg = (
                                    f"Token limit exceeded for {display_provider_name} "
                                    f"({operation_type}). "
                                    f"Base usage: {base_current_tokens}/{token_limit}, "
                                    f"After previous operations in this workflow: {current_provider_tokens}/{token_limit}, "
                                    f"This operation would add: {tokens_requested}, "
                                    f"Total would be: {projected_tokens} (exceeds by {projected_tokens - token_limit} tokens)"
                                )
                            else:
                                error_msg = (
                                    f"Token limit exceeded for {display_provider_name} "
                                    f"({operation_type}). "
                                    f"Current: {current_provider_tokens}/{token_limit}, "
                                    f"Requested: {tokens_requested}, "
                                    f"Would exceed by: {projected_tokens - token_limit} tokens "
                                    f"({usage_percentage:.1f}% of limit)"
                                )
                            logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
                            return False, error_msg, {
                                'error_type': 'token_limit',
                                'usage_info': error_info
                            }
                        else:
                            logger.info(f"  └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")

                    # Update cumulative counts for the next operation
                    total_llm_calls = projected_total_llm_calls
                    # Update the cumulative projected tokens from this validation run
                    # This represents only projected tokens from previous operations in this run
                    # The base DB value is always queried fresh, so we only track the projection delta
                    if tokens_requested > 0:
                        # Add this operation's tokens to the cumulative projected tokens
                        total_llm_tokens[provider_tokens_key] = projected_from_previous + tokens_requested
                        logger.debug(f"[Pre-flight] Updated projected tokens for {display_provider_name}: {projected_from_previous} + {tokens_requested} = {total_llm_tokens[provider_tokens_key]}")
                    else:
                        # No tokens requested, keep the existing projected tokens (or 0 if first operation)
                        total_llm_tokens[provider_tokens_key] = projected_from_previous

                # Check image generation limits
                elif provider == APIProvider.STABILITY:
                    image_limit = limits.get('stability_calls', 0) or 0
                    projected_images = total_images + 1

                    if image_limit > 0 and projected_images > image_limit:
                        error_info = {
                            'current_images': total_images,
                            'limit': image_limit,
                            'provider': 'stability',
                            'operation_type': operation_type,
                            'operation_index': op_idx
                        }
                        return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
                            'error_type': 'image_limit',
                            'usage_info': error_info
                        }

                    total_images = projected_images

                # Check video generation limits
                elif provider == APIProvider.VIDEO:
                    video_limit = limits.get('video_calls', 0) or 0
                    total_video_calls = usage.video_calls or 0
                    projected_video_calls = total_video_calls + 1

                    if video_limit > 0 and projected_video_calls > video_limit:
                        error_info = {
                            'current_calls': total_video_calls,
                            'limit': video_limit,
                            'provider': 'video',
                            'operation_type': operation_type,
                            'operation_index': op_idx
                        }
                        return False, f"Video generation limit would be exceeded. Would use {projected_video_calls} of {video_limit} videos this billing period.", {
                            'error_type': 'video_limit',
                            'usage_info': error_info
                        }

                # Check image editing limits
                elif provider == APIProvider.IMAGE_EDIT:
                    image_edit_limit = limits.get('image_edit_calls', 0) or 0
                    total_image_edit_calls = getattr(usage, 'image_edit_calls', 0) or 0
                    projected_image_edit_calls = total_image_edit_calls + 1

                    if image_edit_limit > 0 and projected_image_edit_calls > image_edit_limit:
                        error_info = {
                            'current_calls': total_image_edit_calls,
                            'limit': image_edit_limit,
                            'provider': 'image_edit',
                            'operation_type': operation_type,
                            'operation_index': op_idx
                        }
                        return False, f"Image editing limit would be exceeded. Would use {projected_image_edit_calls} of {image_edit_limit} image edits this billing period.", {
                            'error_type': 'image_edit_limit',
                            'usage_info': error_info
                        }

                # Check other provider-specific limits
                else:
                    provider_calls_key = f"{provider_name}_calls"
                    current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
                    call_limit = limits.get(provider_calls_key, 0) or 0

                    if call_limit > 0:
                        projected_calls = current_provider_calls + 1
                        if projected_calls > call_limit:
                            error_info = {
                                'current_calls': current_provider_calls,
                                'limit': call_limit,
                                'provider': display_provider_name,
                                'operation_type': operation_type,
                                'operation_index': op_idx
                            }
                            return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
                                'error_type': 'call_limit',
                                'usage_info': error_info
                            }

            # All checks passed
            logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
            logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
            return True, None, None

        except Exception as e:
            error_type = type(e).__name__
            error_message = str(e).lower()

            # Handle missing-column errors with a schema fix and retry
            if 'operationalerror' in error_type.lower() or 'operationalerror' in error_message:
                if 'no such column' in error_message and 'exa_calls' in error_message:
                    logger.warning("Missing column detected in limit check, attempting schema fix...")
                    try:
                        import services.subscription.schema_utils as schema_utils
                        schema_utils._checked_usage_summaries_columns = False
                        from services.subscription.schema_utils import ensure_usage_summaries_columns
                        ensure_usage_summaries_columns(self.db)
                        self.db.expire_all()

                        # Retry the query
                        usage = self.db.query(UsageSummary).filter(
                            UsageSummary.user_id == user_id,
                            UsageSummary.billing_period == current_period
                        ).first()

                        if usage:
                            self.db.refresh(usage)

                        # Re-running the full validation here would duplicate the entire function body,
                        # so log the fix and ask the caller to retry; the next call should succeed.
                        logger.info("[Pre-flight Check] Schema fixed, validation must be retried on the next call")
                        return False, "Database schema was updated. Please try again.", {'error_type': 'schema_update', 'retry': True}
                    except Exception as retry_err:
                        logger.error(f"Schema fix and retry failed: {retry_err}")
                        return False, f"Failed to validate limits: {error_type}: {str(e)}", {}

            logger.error(f"[Pre-flight Check] ❌ Error during comprehensive limit check: {error_type}: {str(e)}", exc_info=True)
            logger.error(f"[Pre-flight Check] ❌ User: {user_id}, Operations count: {len(operations) if operations else 0}")
            return False, f"Failed to validate limits: {error_type}: {str(e)}", {}
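Per the docstring above, callers assemble an `operations` list and validate the whole batch before the first external call. The sketch below is a hypothetical caller-side illustration of that cumulative projection: provider strings stand in for `APIProvider` enum values, limits and usage are plain dicts, and `preflight_tokens` is an invented helper, not this service's API:

```python
# Hypothetical sketch: validate a batch of operations against per-provider
# token budgets before making any API calls. Each operation's estimated
# tokens are projected onto the ones validated after it, mirroring the
# cumulative tracking in check_comprehensive_limits.

def preflight_tokens(operations, base_usage, limits):
    """Return (ok, error) after projecting all operations cumulatively."""
    projected = {}  # token column -> tokens added by earlier ops in this run
    for idx, op in enumerate(operations):
        key = f"{op['provider']}_tokens"
        limit = limits.get(key, 0) or 0          # 0 means unlimited
        requested = op.get('tokens_requested', 0)
        current = (base_usage.get(key, 0) or 0) + projected.get(key, 0)
        if limit > 0 and requested > 0 and current + requested > limit:
            return False, f"operation {idx} ({op['operation_type']}) would exceed {key}: {current + requested}/{limit}"
        projected[key] = projected.get(key, 0) + requested
    return True, None

ops = [
    {'provider': 'gemini', 'tokens_requested': 3000, 'operation_type': 'llm_call'},
    {'provider': 'gemini', 'tokens_requested': 3000, 'operation_type': 'llm_call'},
]
# 5000 already used + 3000 + 3000 = 11000 > 10000: the second op is rejected,
# so neither external call would be made.
ok, err = preflight_tokens(ops, {'gemini_tokens': 5000}, {'gemini_tokens': 10000})
```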
231
backend/services/subscription/log_wrapping_service.py
Normal file
@@ -0,0 +1,231 @@
"""
Log Wrapping Service
Intelligently wraps API usage logs when they exceed 5000 records.
Aggregates old logs into cumulative records while preserving historical data.
"""

from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from loguru import logger

from models.subscription_models import APIUsageLog, APIProvider


class LogWrappingService:
    """Service for wrapping and aggregating API usage logs."""

    MAX_LOGS_PER_USER = 5000
    AGGREGATION_THRESHOLD_DAYS = 30  # Aggregate logs older than 30 days

    def __init__(self, db: Session):
        self.db = db

    def check_and_wrap_logs(self, user_id: str) -> Dict[str, Any]:
        """
        Check if the user has exceeded the log limit and wrap if necessary.

        Returns:
            Dict with wrapping status and statistics
        """
        try:
            # Count the total logs for this user
            total_count = self.db.query(func.count(APIUsageLog.id)).filter(
                APIUsageLog.user_id == user_id
            ).scalar() or 0

            if total_count <= self.MAX_LOGS_PER_USER:
                return {
                    'wrapped': False,
                    'total_logs': total_count,
                    'max_logs': self.MAX_LOGS_PER_USER,
                    'message': f'Log count ({total_count}) is within limit ({self.MAX_LOGS_PER_USER})'
                }

            # Need to wrap logs - aggregate the old ones
            logger.info(f"[LogWrapping] User {user_id} has {total_count} logs, exceeding limit of {self.MAX_LOGS_PER_USER}. Starting wrap...")

            wrap_result = self._wrap_old_logs(user_id, total_count)

            return {
                'wrapped': True,
                'total_logs_before': total_count,
                'total_logs_after': wrap_result['logs_remaining'],
                'aggregated_logs': wrap_result['aggregated_count'],
                'aggregated_periods': wrap_result['periods'],
                'message': f'Wrapped {wrap_result["aggregated_count"]} logs into {len(wrap_result["periods"])} aggregated records'
            }

        except Exception as e:
            logger.error(f"[LogWrapping] Error checking/wrapping logs for user {user_id}: {e}", exc_info=True)
            return {
                'wrapped': False,
                'error': str(e),
                'message': f'Error wrapping logs: {str(e)}'
            }

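The thresholds used by `check_and_wrap_logs` and `_wrap_old_logs` (wrap only above 5000 logs, keep the newest 4000 as detailed rows) can be sketched as a pure function; the helper below is illustrative only:

```python
# Sketch of the wrapping thresholds used by this service: wrapping triggers
# only above MAX_LOGS_PER_USER, and the oldest (total - 4000) logs are
# rolled into aggregate records so the newest 4000 stay detailed.

MAX_LOGS_PER_USER = 5000
LOGS_TO_KEEP = 4000

def logs_to_aggregate(total_count: int) -> int:
    """How many of the oldest logs get rolled into aggregate records."""
    if total_count <= MAX_LOGS_PER_USER:
        return 0  # within limit, nothing to wrap
    return total_count - LOGS_TO_KEEP

# 6500 logs -> aggregate the oldest 2500, keep the newest 4000 detailed
n = logs_to_aggregate(6500)
```

Keeping the wrap trigger (5000) above the keep count (4000) gives hysteresis: once wrapped, the user has roughly 1000 logs of headroom before the next wrap.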
    def _wrap_old_logs(self, user_id: str, total_count: int) -> Dict[str, Any]:
        """
        Aggregate old logs into cumulative records.

        Strategy:
        1. Keep the most recent 4000 logs (detailed)
        2. Aggregate logs older than 30 days or the oldest logs beyond 4000
        3. Create aggregated records grouped by provider and billing period
        4. Delete the individual logs that were aggregated
        """
        try:
            # Calculate how many logs to keep (4000 detailed, rest aggregated)
            logs_to_keep = 4000
            logs_to_aggregate = total_count - logs_to_keep
            if logs_to_aggregate <= 0:
                return {
                    'aggregated_count': 0,
                    'logs_remaining': total_count,
                    'periods': []
                }

            # Cutoff date (30 days ago). NOTE: currently informational; the
            # query below aggregates strictly by count, not by age.
            cutoff_date = datetime.utcnow() - timedelta(days=self.AGGREGATION_THRESHOLD_DAYS)

            # Get logs to aggregate: the oldest logs beyond the keep limit.
            # Order by timestamp ascending so the oldest come first; the most
            # recent `logs_to_keep` logs stay untouched.
            logs_to_process = self.db.query(APIUsageLog).filter(
                APIUsageLog.user_id == user_id
            ).order_by(APIUsageLog.timestamp.asc()).limit(logs_to_aggregate).all()

            if not logs_to_process:
                return {
                    'aggregated_count': 0,
                    'logs_remaining': total_count,
                    'periods': []
                }

            # Group logs by provider and billing period for aggregation
            aggregated_data: Dict[str, Dict[str, Any]] = {}

            for log in logs_to_process:
                # Use the provider value as the key (e.g., "mistral")
                provider_key = log.provider.value
                if provider_key == "mistral":
                    # MISTRAL may actually back HuggingFace calls; infer the
                    # display name from the model, but keep the stored provider.
                    provider_display = "huggingface" if "huggingface" in (log.model_used or "").lower() else "mistral"
                else:
                    provider_display = provider_key

                period_key = f"{provider_display}_{log.billing_period}"

                if period_key not in aggregated_data:
                    aggregated_data[period_key] = {
                        'provider': log.provider,
                        'billing_period': log.billing_period,
                        'count': 0,
                        'total_tokens_input': 0,
                        'total_tokens_output': 0,
                        'total_tokens': 0,
                        'total_cost_input': 0.0,
                        'total_cost_output': 0.0,
                        'total_cost': 0.0,
                        'total_response_time': 0.0,
                        'success_count': 0,
                        'failed_count': 0,
                        'oldest_timestamp': log.timestamp,
                        'newest_timestamp': log.timestamp,
                        'log_ids': []
                    }

                agg = aggregated_data[period_key]
                agg['count'] += 1
                agg['total_tokens_input'] += log.tokens_input or 0
                agg['total_tokens_output'] += log.tokens_output or 0
                agg['total_tokens'] += log.tokens_total or 0
                agg['total_cost_input'] += float(log.cost_input or 0.0)
                agg['total_cost_output'] += float(log.cost_output or 0.0)
                agg['total_cost'] += float(log.cost_total or 0.0)
                agg['total_response_time'] += float(log.response_time or 0.0)

                if 200 <= log.status_code < 300:
                    agg['success_count'] += 1
                else:
                    agg['failed_count'] += 1

                if log.timestamp:
                    if log.timestamp < agg['oldest_timestamp']:
                        agg['oldest_timestamp'] = log.timestamp
                    if log.timestamp > agg['newest_timestamp']:
                        agg['newest_timestamp'] = log.timestamp

                agg['log_ids'].append(log.id)

            # Create aggregated log entries
            aggregated_count = 0
            periods_created = []

            for period_key, agg_data in aggregated_data.items():
                # Calculate averages
                count = agg_data['count']
                avg_response_time = agg_data['total_response_time'] / count if count > 0 else 0.0

                # Create the aggregated log entry
                aggregated_log = APIUsageLog(
                    user_id=user_id,
                    provider=agg_data['provider'],
                    endpoint='[AGGREGATED]',
                    method='AGGREGATED',
                    model_used=f"[{count} calls aggregated]",
                    tokens_input=agg_data['total_tokens_input'],
                    tokens_output=agg_data['total_tokens_output'],
                    tokens_total=agg_data['total_tokens'],
                    cost_input=agg_data['total_cost_input'],
                    cost_output=agg_data['total_cost_output'],
                    cost_total=agg_data['total_cost'],
                    response_time=avg_response_time,
                    status_code=200 if agg_data['success_count'] > agg_data['failed_count'] else 500,
                    error_message=f"Aggregated {count} calls: {agg_data['success_count']} success, {agg_data['failed_count']} failed",
                    retry_count=0,
                    timestamp=agg_data['oldest_timestamp'],  # Use the oldest timestamp
                    billing_period=agg_data['billing_period']
                )

                self.db.add(aggregated_log)
                periods_created.append({
                    'provider': agg_data['provider'].value,
                    'billing_period': agg_data['billing_period'],
                    'count': count,
                    'period_start': agg_data['oldest_timestamp'].isoformat() if agg_data['oldest_timestamp'] else None,
                    'period_end': agg_data['newest_timestamp'].isoformat() if agg_data['newest_timestamp'] else None
                })

                aggregated_count += count

            # Delete the individual logs that were aggregated
            log_ids_to_delete = []
            for agg_data in aggregated_data.values():
                log_ids_to_delete.extend(agg_data['log_ids'])

            if log_ids_to_delete:
                self.db.query(APIUsageLog).filter(
                    APIUsageLog.id.in_(log_ids_to_delete)
                ).delete(synchronize_session=False)

            self.db.commit()

            # Get the remaining log count
            remaining_count = self.db.query(func.count(APIUsageLog.id)).filter(
                APIUsageLog.user_id == user_id
            ).scalar() or 0

            logger.info(
                f"[LogWrapping] Wrapped {aggregated_count} logs into {len(periods_created)} aggregated records. "
                f"Remaining logs: {remaining_count}"
            )

            return {
                'aggregated_count': aggregated_count,
                'logs_remaining': remaining_count,
                'periods': periods_created
            }

        except Exception as e:
            self.db.rollback()
            logger.error(f"[LogWrapping] Error wrapping logs: {e}", exc_info=True)
            raise

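The grouping step above can be sketched without SQLAlchemy. The following is a minimal pure-Python version of the provider/billing-period bucketing; the dict keys (`provider`, `billing_period`, `tokens_total`) are simplified stand-ins for the `APIUsageLog` columns used in the method:

```python
from collections import defaultdict

def bucket_logs(logs):
    """Group usage-log dicts by (provider, billing_period) and sum tokens.

    Each entry in `logs` is a dict with 'provider', 'billing_period', and
    'tokens_total' keys - a simplified stand-in for APIUsageLog rows.
    """
    buckets = defaultdict(lambda: {'count': 0, 'total_tokens': 0})
    for log in logs:
        key = f"{log['provider']}_{log['billing_period']}"
        buckets[key]['count'] += 1
        buckets[key]['total_tokens'] += log.get('tokens_total') or 0
    return dict(buckets)

logs = [
    {'provider': 'gemini', 'billing_period': '2024-06', 'tokens_total': 120},
    {'provider': 'gemini', 'billing_period': '2024-06', 'tokens_total': 80},
    {'provider': 'exa', 'billing_period': '2024-06', 'tokens_total': 0},
]
# bucket_logs(logs)['gemini_2024-06'] -> {'count': 2, 'total_tokens': 200}
```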
388
backend/services/subscription/monitoring_middleware.py
Normal file
@@ -0,0 +1,388 @@
"""
Enhanced FastAPI Monitoring Middleware

Database-backed monitoring for API calls, errors, performance metrics, and usage tracking.
Includes comprehensive subscription-based usage monitoring and cost tracking.
"""

from fastapi import Request, Response
from fastapi.responses import JSONResponse
import time
import json
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from collections import defaultdict, deque
import asyncio
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy import and_, func
import re

from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
from models.subscription_models import APIProvider
from .usage_tracking_service import UsageTrackingService
from .pricing_service import PricingService


def _get_db_session():
    """
    Get a database session with a lazy import to survive hot reloads.
    Uvicorn's reloader can sometimes clear module-level imports.
    """
    from services.database import get_db
    return next(get_db())

class DatabaseAPIMonitor:
    """Database-backed API monitoring with usage tracking and subscription management."""

    def __init__(self):
        self.cache_stats = {
            'hits': 0,
            'misses': 0,
            'hit_rate': 0.0
        }
        # API provider detection patterns, matched against the request path and user agent
        self.provider_patterns = {
            APIProvider.GEMINI: [r'gemini', r'google.*ai'],
            APIProvider.OPENAI: [r'openai', r'gpt', r'chatgpt'],
            APIProvider.ANTHROPIC: [r'anthropic', r'claude'],
            APIProvider.MISTRAL: [r'mistral'],
            APIProvider.TAVILY: [r'tavily'],
            APIProvider.SERPER: [r'serper'],
            APIProvider.METAPHOR: [r'metaphor', r'/exa'],
            APIProvider.FIRECRAWL: [r'firecrawl']
        }

    def detect_api_provider(self, path: str, user_agent: Optional[str] = None) -> Optional[APIProvider]:
        """Detect which API provider is being used based on request details."""
        path_lower = path.lower()
        user_agent_lower = (user_agent or '').lower()

        # Permanently ignore internal route families that must not accrue or check provider usage
        if path_lower.startswith('/api/onboarding/') or path_lower.startswith('/api/subscription/'):
            return None

        for provider, patterns in self.provider_patterns.items():
            for pattern in patterns:
                if re.search(pattern, path_lower) or re.search(pattern, user_agent_lower):
                    return provider

        return None

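The detection logic above is plain regex matching with an internal-route short circuit. A self-contained sketch (string keys stand in for the `APIProvider` enum, and only a subset of the patterns is reproduced):

```python
import re
from typing import Optional

# Simplified stand-in for the APIProvider pattern table above.
PROVIDER_PATTERNS = {
    'gemini': [r'gemini', r'google.*ai'],
    'mistral': [r'mistral'],
    'exa': [r'metaphor', r'/exa'],
}

# Internal route families that never accrue provider usage.
IGNORED_PREFIXES = ('/api/onboarding/', '/api/subscription/')

def detect_provider(path: str) -> Optional[str]:
    """Return the first provider whose pattern matches the path, else None."""
    path_lower = path.lower()
    if path_lower.startswith(IGNORED_PREFIXES):
        return None
    for provider, patterns in PROVIDER_PATTERNS.items():
        if any(re.search(p, path_lower) for p in patterns):
            return provider
    return None
```

Because the first matching provider wins, pattern order matters when patterns overlap (e.g. a path containing both "gemini" and "mistral").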
    def extract_usage_metrics(self, request_body: Optional[str] = None, response_body: Optional[str] = None) -> Dict[str, Any]:
        """Extract usage metrics from request/response bodies."""
        metrics = {
            'tokens_input': 0,
            'tokens_output': 0,
            'model_used': None,
            'search_count': 0,
            'image_count': 0,
            'page_count': 0
        }

        try:
            # Parse the request body for the model and input-token estimates
            if request_body:
                request_data = json.loads(request_body) if isinstance(request_body, str) else request_body

                # Extract model information
                if 'model' in request_data:
                    metrics['model_used'] = request_data['model']

                # Estimate input tokens from prompt/messages/input content
                if 'prompt' in request_data:
                    metrics['tokens_input'] = self._estimate_tokens(request_data['prompt'])
                elif 'messages' in request_data:
                    total_content = ' '.join([msg.get('content', '') for msg in request_data['messages']])
                    metrics['tokens_input'] = self._estimate_tokens(total_content)
                elif 'input' in request_data:
                    metrics['tokens_input'] = self._estimate_tokens(str(request_data['input']))

                # Count specific request types
                if 'query' in request_data or 'search' in request_data:
                    metrics['search_count'] = 1
                if 'image' in request_data or 'generate_image' in request_data:
                    metrics['image_count'] = 1
                if 'url' in request_data or 'crawl' in request_data:
                    metrics['page_count'] = 1

            # Parse the response body for output-token estimates
            if response_body:
                response_data = json.loads(response_body) if isinstance(response_body, str) else response_body

                # Extract output content and estimate tokens
                if 'text' in response_data:
                    metrics['tokens_output'] = self._estimate_tokens(response_data['text'])
                elif 'content' in response_data:
                    metrics['tokens_output'] = self._estimate_tokens(str(response_data['content']))
                elif 'choices' in response_data and response_data['choices']:
                    choice = response_data['choices'][0]
                    if 'message' in choice and 'content' in choice['message']:
                        metrics['tokens_output'] = self._estimate_tokens(choice['message']['content'])

                # Prefer actual token usage when the API reports it
                if 'usage' in response_data:
                    usage = response_data['usage']
                    if 'prompt_tokens' in usage:
                        metrics['tokens_input'] = usage['prompt_tokens']
                    if 'completion_tokens' in usage:
                        metrics['tokens_output'] = usage['completion_tokens']

        except (json.JSONDecodeError, KeyError, TypeError) as e:
            logger.debug(f"Could not extract usage metrics: {e}")

        return metrics

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count for text (rough approximation)."""
        if not text:
            return 0
        # Rough estimation: ~1.3 tokens per word on average
        word_count = len(str(text).split())
        return int(word_count * 1.3)

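The heuristic in `_estimate_tokens` is just a word count scaled by 1.3 and truncated to an integer. A standalone version of the same arithmetic:

```python
def estimate_tokens(text: str) -> int:
    """Rough token estimate: ~1.3 tokens per whitespace-separated word.

    Mirrors the _estimate_tokens heuristic above; real tokenizers
    (BPE, SentencePiece) will of course differ.
    """
    if not text:
        return 0
    return int(len(str(text).split()) * 1.3)
```

Note that `int()` truncates, so 4 words give `int(5.2) == 5` tokens, and a single word gives `int(1.3) == 1`.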
async def check_usage_limits_middleware(request: Request, user_id: str, request_body: Optional[str] = None) -> Optional[JSONResponse]:
    """Check usage limits before processing a request."""
    if not user_id:
        return None

    # No special whitelist here; onboarding/subscription routes are already
    # ignored by provider detection.
    db = None
    try:
        db = _get_db_session()
        api_monitor = DatabaseAPIMonitor()

        # Detect whether this is an API call that should be rate limited
        api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
        if not api_provider:
            return None

        # Use the provided request body, or read it if not provided
        if request_body is None:
            try:
                if hasattr(request, '_body'):
                    body = request._body
                    request_body = body.decode('utf-8') if isinstance(body, bytes) else body
                else:
                    # Reading the body here may not work in all cases
                    body = await request.body()
                    request_body = body.decode('utf-8') if body else None
            except Exception:
                pass

        # Estimate the tokens needed
        tokens_requested = 0
        if request_body:
            usage_metrics = api_monitor.extract_usage_metrics(request_body)
            tokens_requested = usage_metrics.get('tokens_input', 0)

        # Check limits
        usage_service = UsageTrackingService(db)
        can_proceed, message, usage_info = await usage_service.enforce_usage_limits(
            user_id=user_id,
            provider=api_provider,
            tokens_requested=tokens_requested
        )

        if not can_proceed:
            logger.warning(f"Usage limit exceeded for {user_id}: {message}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Usage limit exceeded",
                    "message": message,
                    "usage_info": usage_info,
                    "provider": api_provider.value
                }
            )

        # Warn if approaching limits
        if usage_info.get('call_usage_percentage', 0) >= 80 or usage_info.get('cost_usage_percentage', 0) >= 80:
            logger.warning(f"User {user_id} approaching usage limits: {usage_info}")

        return None

    except Exception as e:
        logger.error(f"Error checking usage limits: {e}")
        # Don't block requests if usage checking fails
        return None
    finally:
        if db is not None:
            db.close()

async def monitoring_middleware(request: Request, call_next):
    """Enhanced FastAPI middleware for monitoring API calls with usage tracking."""
    start_time = time.time()

    # Get a database session
    db = _get_db_session()

    # Extract request details, with layered user identification
    user_id = None
    try:
        # PRIORITY 1: request.state.user_id (set by the API key injection middleware)
        if hasattr(request.state, 'user_id') and request.state.user_id:
            user_id = request.state.user_id
            logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")

        # PRIORITY 2: query or path parameters
        elif hasattr(request, 'query_params') and 'user_id' in request.query_params:
            user_id = request.query_params['user_id']
        elif hasattr(request, 'path_params') and 'user_id' in request.path_params:
            user_id = request.path_params['user_id']

        # PRIORITY 3: headers carrying user identification
        elif 'x-user-id' in request.headers:
            user_id = request.headers['x-user-id']
        elif 'x-user-email' in request.headers:
            user_id = request.headers['x-user-email']  # Use email as the user identifier
        elif 'x-session-id' in request.headers:
            user_id = request.headers['x-session-id']  # Use session as a fallback

        # Authorization header present but no user in state
        elif 'authorization' in request.headers:
            # Auth middleware should have set request.state.user_id. If it did
            # not, this indicates an authentication failure (likely an expired
            # token); log at debug level to reduce noise.
            user_id = None
            logger.debug("Monitoring: Auth header present but no user_id in state - token likely expired")

        # Final fallback: None (skip usage limits for truly anonymous/unauthenticated requests)
        else:
            user_id = None

    except Exception as e:
        logger.debug(f"Error extracting user ID: {e}")
        user_id = None  # On error, skip usage limits

    # Capture the request body for usage tracking (read once, safely)
    request_body = None
    try:
        # Only read the body for POST/PUT/PATCH requests to avoid issues
        if request.method in ['POST', 'PUT', 'PATCH']:
            if hasattr(request, '_body') and request._body:
                request_body = request._body.decode('utf-8')
            else:
                # Read the body only if it hasn't been read yet
                try:
                    body = await request.body()
                    request_body = body.decode('utf-8') if body else None
                except Exception as body_error:
                    logger.debug(f"Could not read request body: {body_error}")
                    request_body = None
    except Exception as e:
        logger.debug(f"Error capturing request body: {e}")
        request_body = None

    # Check usage limits before processing
    limit_response = await check_usage_limits_middleware(request, user_id, request_body)
    if limit_response:
        return limit_response

    try:
        response = await call_next(request)
        status_code = response.status_code
        duration = time.time() - start_time

        # Capture the response body for usage tracking
        response_body = None
        try:
            if hasattr(response, 'body'):
                response_body = response.body.decode('utf-8') if response.body else None
            elif hasattr(response, '_content'):
                response_body = response._content.decode('utf-8') if response._content else None
        except Exception:
            pass

        # Track API usage if this is a call to an external provider
        api_monitor = DatabaseAPIMonitor()
        api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
        if api_provider and user_id:
            logger.info(f"Detected API call: {request.url.path} -> {api_provider.value} for user: {user_id}")
            try:
                # Extract usage metrics
                usage_metrics = api_monitor.extract_usage_metrics(request_body, response_body)

                # Track usage with the usage tracking service
                usage_service = UsageTrackingService(db)
                await usage_service.track_api_usage(
                    user_id=user_id,
                    provider=api_provider,
                    endpoint=request.url.path,
                    method=request.method,
                    model_used=usage_metrics.get('model_used'),
                    tokens_input=usage_metrics.get('tokens_input', 0),
                    tokens_output=usage_metrics.get('tokens_output', 0),
                    response_time=duration,
                    status_code=status_code,
                    request_size=len(request_body) if request_body else None,
                    response_size=len(response_body) if response_body else None,
                    user_agent=request.headers.get('user-agent'),
                    ip_address=request.client.host if request.client else None,
                    search_count=usage_metrics.get('search_count', 0),
                    image_count=usage_metrics.get('image_count', 0),
                    page_count=usage_metrics.get('page_count', 0)
                )
            except Exception as usage_error:
                # Don't fail the main request if usage tracking fails
                logger.error(f"Error tracking API usage: {usage_error}")

        return response

    except Exception as e:
        duration = time.time() - start_time
        status_code = 500

        # Store minimal error info
        logger.error(f"API Error: {request.method} {request.url.path} - {str(e)}")

        return JSONResponse(
            status_code=500,
            content={"error": "Internal server error"}
        )
    finally:
        db.close()

async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
    """Get current monitoring statistics."""
    db = None
    try:
        db = _get_db_session()
        # Placeholder to match the old API; heavy stats are handled elsewhere
        return {
            'timestamp': datetime.utcnow().isoformat(),
            'overview': {
                'recent_requests': 0,
                'recent_errors': 0,
            },
            'cache_performance': {'hits': 0, 'misses': 0, 'hit_rate': 0.0},
            'recent_errors': [],
            'system_health': {'status': 'healthy', 'error_rate': 0.0}
        }
    finally:
        if db is not None:
            db.close()


async def get_lightweight_stats() -> Dict[str, Any]:
    """Get lightweight stats for the dashboard header."""
    db = None
    try:
        db = _get_db_session()
        # Minimal viable placeholder values
        now = datetime.utcnow()
        return {
            'status': 'healthy',
            'icon': '🟢',
            'recent_requests': 0,
            'recent_errors': 0,
            'error_rate': 0.0,
            'timestamp': now.isoformat()
        }
    finally:
        if db is not None:
            db.close()
853
backend/services/subscription/preflight_validator.py
Normal file
@@ -0,0 +1,853 @@
"""
Pre-flight Validation Utility for Multi-Operation Workflows

Provides transparent validation for operations that involve multiple API calls.
Services can use this to validate entire workflows before making any external API calls.
"""

from typing import Dict, Any, List, Optional, Tuple
from fastapi import HTTPException
from loguru import logger

from services.subscription.pricing_service import PricingService
from models.subscription_models import APIProvider


def validate_research_operations(
    pricing_service: PricingService,
    user_id: str,
    gpt_provider: str = "google"
) -> None:
    """
    Validate all operations for a research workflow before making ANY API calls.

    This prevents wasteful external API calls (like Google Grounding) if subsequent
    LLM calls would fail due to token or call limits.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking
        gpt_provider: GPT provider from the env var (defaults to "google")

    Returns:
        None

    Raises:
        HTTPException: 429 if any operation would exceed limits
    """
    try:
        # Determine the actual provider for LLM calls from the GPT_PROVIDER env var
        gpt_provider_lower = gpt_provider.lower()
        if gpt_provider_lower == "huggingface":
            llm_provider_enum = APIProvider.MISTRAL  # Maps to HuggingFace
            llm_provider_name = "huggingface"
        else:
            llm_provider_enum = APIProvider.GEMINI
            llm_provider_name = "gemini"

        # Token estimates per operation in the research workflow:
        #   Google Grounding call:    ~1200 tokens (~500 input, ~700 output for research results)
        #   Keyword analyzer:         ~1000 tokens (input: 3000 chars research, output: structured JSON)
        #   Competitor analyzer:      ~1000 tokens (input: 3000 chars research, output: structured JSON)
        #   Content angle generator:  ~1000 tokens (input: 3000 chars research, output: list of angles)
        # These are conservative estimates; actual usage may be lower. They are
        # used only for pre-flight validation, to prevent wasteful API calls
        # when the workflow would exceed limits.

        operations_to_validate = [
            {
                'provider': APIProvider.GEMINI,  # Google Grounding uses Gemini
                'tokens_requested': 1200,  # Reduced from 2000 to a more realistic estimate
                'actual_provider_name': 'gemini',
                'operation_type': 'google_grounding'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'keyword_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'competitor_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'content_angle_generation'
            }
        ]

        logger.info("[Pre-flight Validator] 🚀 Starting Research Workflow Validation")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
        logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
            operation_type = usage_info.get('operation_type', 'unknown')

            logger.warning(f"[Pre-flight] Research blocked for user {user_id}: {operation_type} ({provider}) - {message}")

            # Raise immediately - the frontend gets a response before any API call is made
            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info("[Pre-flight Validator] ✅ RESEARCH WORKFLOW APPROVED")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
        # Validation passed - nothing to return (failures raise HTTPException above)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating research operations: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate operations: {str(e)}",
                'message': f"Failed to validate operations: {str(e)}"
            }
        )

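The `GPT_PROVIDER` branch above collapses to a small mapping: "huggingface" is billed under the MISTRAL enum member, and everything else falls through to GEMINI. A standalone sketch of that mapping (the `Provider` enum here is a simplified stand-in for `APIProvider`):

```python
from enum import Enum

class Provider(Enum):
    # Simplified stand-in for the APIProvider enum used in the validators.
    GEMINI = "gemini"
    MISTRAL = "mistral"

def resolve_llm_provider(gpt_provider: str):
    """Map a GPT_PROVIDER env value to (enum member, display name).

    Mirrors the branch in validate_research_operations: 'huggingface'
    calls are billed under MISTRAL; anything else defaults to GEMINI.
    """
    if gpt_provider.lower() == "huggingface":
        return Provider.MISTRAL, "huggingface"
    return Provider.GEMINI, "gemini"
```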
def validate_exa_research_operations(
    pricing_service: PricingService,
    user_id: str,
    gpt_provider: str = "google"
) -> None:
    """
    Validate all operations for an Exa research workflow before making ANY API calls.

    This prevents wasteful external API calls (like Exa search) if subsequent
    LLM calls would fail due to token or call limits.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking
        gpt_provider: GPT provider from the env var (defaults to "google")

    Returns:
        None

    Raises:
        HTTPException: 429 if validation fails
    """
    try:
        # Determine the actual provider for LLM calls from the GPT_PROVIDER env var
        gpt_provider_lower = gpt_provider.lower()
        if gpt_provider_lower == "huggingface":
            llm_provider_enum = APIProvider.MISTRAL  # Maps to HuggingFace
            llm_provider_name = "huggingface"
        else:
            llm_provider_enum = APIProvider.GEMINI
            llm_provider_name = "gemini"

        # Token estimates per operation in the Exa research workflow:
        #   Exa Search call:          1 Exa API call (not token-based)
        #   Keyword analyzer:         ~1000 tokens (input: research results, output: structured JSON)
        #   Competitor analyzer:      ~1000 tokens (input: research results, output: structured JSON)
        #   Content angle generator:  ~1000 tokens (input: research results, output: list of angles)
        # These are conservative estimates used only for pre-flight validation.

        operations_to_validate = [
            {
                'provider': APIProvider.EXA,  # Exa API call
                'tokens_requested': 0,
                'actual_provider_name': 'exa',
                'operation_type': 'exa_neural_search'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'keyword_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'competitor_analysis'
            },
            {
                'provider': llm_provider_enum,
                'tokens_requested': 1000,
                'actual_provider_name': llm_provider_name,
                'operation_type': 'content_angle_generation'
            }
        ]

        logger.info("[Pre-flight Validator] 🚀 Starting Exa Research Workflow Validation")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
        logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
            operation_type = usage_info.get('operation_type', 'unknown')

            logger.error("[Pre-flight Validator] ❌ EXA RESEARCH WORKFLOW BLOCKED")
            logger.error(f" ├─ User: {user_id}")
            logger.error(f" ├─ Blocked at: {operation_type}")
            logger.error(f" ├─ Provider: {provider}")
            logger.error(f" └─ Reason: {message}")

            # Raise immediately - the frontend gets a response before any API call is made
            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info("[Pre-flight Validator] ✅ EXA RESEARCH WORKFLOW APPROVED")
        logger.info(f" ├─ User: {user_id}")
        logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
        # Validation passed - nothing to return (failures raise HTTPException above)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating Exa research operations: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate operations: {str(e)}",
                'message': f"Failed to validate operations: {str(e)}"
            }
        )

def validate_image_generation_operations(
    pricing_service: PricingService,
    user_id: str,
    num_images: int = 1
) -> None:
    """
    Validate image generation operation(s) before making API calls.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking
        num_images: Number of images to generate (for multiple variations)

    Returns:
        None

    Raises:
        HTTPException: 429 if validation fails
    """
    try:
        # Create one validation operation per image
        operations_to_validate = [
            {
                'provider': APIProvider.STABILITY,
                'tokens_requested': 0,
                'actual_provider_name': 'stability',
                'operation_type': 'image_generation'
            }
            for _ in range(num_images)
        ]

        logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image generation(s) for user {user_id}")

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            logger.error(f"[Pre-flight Validator] Image generation blocked for user {user_id}: {message}")

            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', 'stability') if usage_info else 'stability'

            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ Image generation validated for user {user_id}")
        # Validation passed - nothing to return (failures raise HTTPException above)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating image generation: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate image generation: {str(e)}",
                'message': f"Failed to validate image generation: {str(e)}"
            }
        )

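The per-image operations list above is a simple comprehension producing one descriptor per requested image. A standalone version of that shape (string provider instead of the `APIProvider` enum, field names taken from the dicts above):

```python
def build_image_operations(num_images: int):
    """Build one validation descriptor per requested image.

    Dict shape mirrors the operations passed to
    check_comprehensive_limits; 'stability' stands in for
    APIProvider.STABILITY.
    """
    return [
        {
            'provider': 'stability',
            'tokens_requested': 0,  # image calls are call-counted, not token-counted
            'actual_provider_name': 'stability',
            'operation_type': 'image_generation',
        }
        for _ in range(num_images)
    ]
```

Validating N descriptors up front means a request for three variations is rejected before the first image is generated, rather than failing midway through.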
def validate_image_upscale_operations(
    pricing_service: PricingService,
    user_id: str,
    num_images: int = 1
) -> None:
    """
    Validate image upscaling before making API calls.
    """
    try:
        operations_to_validate = [
            {
                'provider': APIProvider.STABILITY,
                'tokens_requested': 0,
                'actual_provider_name': 'stability',
                'operation_type': 'image_upscale'
            }
            for _ in range(num_images)
        ]

        logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image upscale request(s) for user {user_id}")

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            logger.error(f"[Pre-flight Validator] Image upscale blocked for user {user_id}: {message}")

            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', 'stability') if usage_info else 'stability'

            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ Image upscale validated for user {user_id}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating image upscale: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate image upscale: {str(e)}",
                'message': f"Failed to validate image upscale: {str(e)}"
            }
        )

def validate_image_editing_operations(
    pricing_service: PricingService,
    user_id: str
) -> None:
    """
    Validate image editing operation before making API calls.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking

    Returns:
        None - raises HTTPException with 429 status if validation fails
    """
    try:
        operations_to_validate = [
            {
                'provider': APIProvider.IMAGE_EDIT,
                'tokens_requested': 0,
                'actual_provider_name': 'image_edit',
                'operation_type': 'image_editing'
            }
        ]

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            logger.error(f"[Pre-flight Validator] Image editing blocked for user {user_id}: {message}")

            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', 'image_edit') if usage_info else 'image_edit'

            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ Image editing validated for user {user_id}")
        # Validation passed - no return needed (the function raises HTTPException on failure)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating image editing: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate image editing: {str(e)}",
                'message': f"Failed to validate image editing: {str(e)}"
            }
        )

def validate_image_control_operations(
    pricing_service: PricingService,
    user_id: str,
    num_images: int = 1
) -> None:
    """
    Validate image control operations (sketch-to-image, structure control, style transfer)
    before making API calls.

    Control operations use Stability AI for image generation with control inputs,
    so they use the same validation as image generation operations.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking
        num_images: Number of images to generate (for multiple variations)

    Returns:
        None - raises HTTPException with 429 status if validation fails
    """
    try:
        # Control operations use Stability AI, same as image generation
        operations_to_validate = [
            {
                'provider': APIProvider.STABILITY,
                'tokens_requested': 0,
                'actual_provider_name': 'stability',
                'operation_type': 'image_generation'  # Control ops use image generation limits
            }
            for _ in range(num_images)
        ]

        logger.info(f"[Pre-flight Validator] 🚀 Validating {num_images} image control operation(s) for user {user_id}")

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            logger.error(f"[Pre-flight Validator] Image control blocked for user {user_id}: {message}")

            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', 'stability') if usage_info else 'stability'

            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ Image control validated for user {user_id}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating image control: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate image control: {str(e)}",
                'message': f"Failed to validate image control: {str(e)}"
            }
        )

def validate_video_generation_operations(
    pricing_service: PricingService,
    user_id: str
) -> None:
    """
    Validate video generation operation before making API calls.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking

    Returns:
        None - raises HTTPException with 429 status if validation fails
    """
    try:
        operations_to_validate = [
            {
                'provider': APIProvider.VIDEO,
                'tokens_requested': 0,
                'actual_provider_name': 'video',
                'operation_type': 'video_generation'
            }
        ]

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            logger.error(f"[Pre-flight Validator] Video generation blocked for user {user_id}: {message}")

            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', 'video') if usage_info else 'video'

            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ Video generation validated for user {user_id}")
        # Validation passed - no return needed (the function raises HTTPException on failure)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating video generation: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate video generation: {str(e)}",
                'message': f"Failed to validate video generation: {str(e)}"
            }
        )

def validate_scene_animation_operation(
    pricing_service: PricingService,
    user_id: str,
) -> None:
    """
    Validate the per-scene animation workflow before API calls.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking

    Returns:
        None - raises HTTPException with 429 status if validation fails
    """
    try:
        operations_to_validate = [
            {
                'provider': APIProvider.VIDEO,
                'tokens_requested': 0,
                'actual_provider_name': 'wavespeed',
                'operation_type': 'scene_animation',
            }
        ]

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate,
        )

        if not can_proceed:
            logger.error(f"[Pre-flight Validator] Scene animation blocked for user {user_id}: {message}")
            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', 'video') if usage_info else 'video'
            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details,
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ Scene animation validated for user {user_id}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating scene animation: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate scene animation: {str(e)}",
                'message': f"Failed to validate scene animation: {str(e)}"
            }
        )

def validate_calendar_generation_operations(
    pricing_service: PricingService,
    user_id: str,
    gpt_provider: str = "google"
) -> None:
    """
    Validate calendar generation operations before making API calls.

    Args:
        pricing_service: PricingService instance
        user_id: User ID for subscription checking
        gpt_provider: GPT provider from env var (defaults to "google")

    Returns:
        None - raises HTTPException with 429 status if validation fails
    """
    try:
        # Determine the actual provider for LLM calls based on the GPT_PROVIDER env var
        gpt_provider_lower = gpt_provider.lower()
        if gpt_provider_lower == "huggingface":
            llm_provider_enum = APIProvider.MISTRAL
            llm_provider_name = "huggingface"
        else:
            llm_provider_enum = APIProvider.GEMINI
            llm_provider_name = "gemini"

        # Estimate tokens for the 12-step process: this is a heavy operation
        # involving multiple steps and analysis.
        operations_to_validate = [
            {
                'provider': llm_provider_enum,
                'tokens_requested': 20000,  # Conservative estimate for full calendar generation
                'actual_provider_name': llm_provider_name,
                'operation_type': 'calendar_generation'
            }
        ]

        logger.info(f"[Pre-flight Validator] 🚀 Validating Calendar Generation for user {user_id}")

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        if not can_proceed:
            usage_info = error_details.get('usage_info', {}) if error_details else {}
            provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name

            logger.warning(f"[Pre-flight Validator] Calendar generation blocked for user {user_id}: {message}")

            raise HTTPException(
                status_code=429,
                detail={
                    'error': message,
                    'message': message,
                    'provider': provider,
                    'usage_info': usage_info if usage_info else error_details
                }
            )

        logger.info(f"[Pre-flight Validator] ✅ Calendar Generation validated for user {user_id}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating calendar generation: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                'error': f"Failed to validate calendar generation: {str(e)}",
                'message': f"Failed to validate calendar generation: {str(e)}"
            }
        )
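All of the validators above follow the same pre-flight pattern: build a list of operation dicts, ask the pricing service whether the user's limits allow them, and raise a 429 before any provider call is made. A minimal, self-contained sketch of that pattern (`RateLimitError` stands in for FastAPI's `HTTPException(status_code=429, ...)`, and `PricingServiceStub` for `PricingService`; neither is part of this commit):

```python
# Illustrative sketch of the pre-flight validation pattern.

class RateLimitError(Exception):
    """Stand-in for HTTPException with a 429 status."""
    def __init__(self, detail):
        super().__init__(detail.get('error', 'rate limited'))
        self.status_code = 429
        self.detail = detail


class PricingServiceStub:
    """Pretends the user has a fixed number of image generations left."""
    def __init__(self, remaining=1):
        self.remaining = remaining

    def check_comprehensive_limits(self, user_id, operations):
        # Same (can_proceed, message, error_details) shape the validators expect
        if len(operations) > self.remaining:
            return False, "Image generation limit reached", {
                'usage_info': {'provider': 'stability'}
            }
        return True, "OK", None


def validate_image_generation(pricing_service, user_id, num_images=1):
    # One validation operation per requested image, as in the real validators
    operations = [{'provider': 'stability', 'operation_type': 'image_generation'}
                  for _ in range(num_images)]
    can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
        user_id=user_id, operations=operations)
    if not can_proceed:
        usage_info = error_details.get('usage_info', {}) if error_details else {}
        raise RateLimitError({'error': message, 'message': message,
                              'usage_info': usage_info or error_details})


svc = PricingServiceStub(remaining=1)
validate_image_generation(svc, "user-1", num_images=1)  # within limits: no exception
try:
    validate_image_generation(svc, "user-1", num_images=2)
except RateLimitError as exc:
    print(exc.status_code, exc.detail['error'])
```

Failing fast here means a request that would exceed the subscription's quota is rejected before the upstream provider bills for it.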
815
backend/services/subscription/pricing_service.py
Normal file
@@ -0,0 +1,815 @@
"""
Pricing Service for API Usage Tracking
Manages API pricing, cost calculation, and subscription limits.
"""

from typing import Dict, Any, Optional, List, Tuple, Union
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import text
from loguru import logger
import os

from models.subscription_models import (
    APIProviderPricing, SubscriptionPlan, UserSubscription,
    UsageSummary, APIUsageLog, APIProvider, SubscriptionTier
)

class PricingService:
    """Service for managing API pricing and cost calculations."""

    # Class-level cache shared across all instances (critical for cache
    # invalidation on subscription renewal).
    # key: f"{user_id}:{provider}", value: {'result': (bool, str, dict), 'expires_at': datetime}
    _limits_cache: Dict[str, Dict[str, Any]] = {}

    def __init__(self, db: Session):
        self.db = db
        self._pricing_cache = {}
        self._plans_cache = {}
        # Cache for schema feature detection (ai_text_generation_calls_limit column)
        self._ai_text_gen_col_checked: bool = False
        self._ai_text_gen_col_available: bool = False

    # ------------------- Billing period helpers -------------------
    def _compute_next_period_end(self, start: datetime, cycle: str) -> datetime:
        """Compute the next period end given a start and billing cycle."""
        try:
            cycle_value = cycle.value if hasattr(cycle, 'value') else str(cycle)
        except Exception:
            cycle_value = str(cycle)
        if cycle_value == 'yearly':
            return start + timedelta(days=365)
        return start + timedelta(days=30)

    def _ensure_subscription_current(self, subscription) -> bool:
        """Auto-advance the subscription period if expired and auto_renew is enabled."""
        if not subscription:
            return False
        now = datetime.utcnow()
        try:
            if subscription.current_period_end and subscription.current_period_end < now:
                if getattr(subscription, 'auto_renew', False):
                    subscription.current_period_start = now
                    subscription.current_period_end = self._compute_next_period_end(now, subscription.billing_cycle)
                    # Keep status active, whether the model uses an enum or a plain string
                    try:
                        subscription.status = subscription.status.ACTIVE  # type: ignore[attr-defined]
                    except Exception:
                        setattr(subscription, 'status', 'active')
                    self.db.commit()
                else:
                    return False
        except Exception:
            self.db.rollback()
        return True

    def get_current_billing_period(self, user_id: str) -> Optional[str]:
        """Return the current billing period key (YYYY-MM) after ensuring the subscription is current."""
        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id,
            UserSubscription.is_active == True
        ).first()
        # Ensure the subscription is current (advance the period if auto_renew is enabled)
        self._ensure_subscription_current(subscription)
        # Continue to use YYYY-MM keys for summaries
        return datetime.now().strftime("%Y-%m")

    @classmethod
    def clear_user_cache(cls, user_id: str) -> int:
        """Clear all cached limit checks for a specific user. Returns the number of entries cleared."""
        keys_to_remove = [key for key in cls._limits_cache.keys() if key.startswith(f"{user_id}:")]
        for key in keys_to_remove:
            del cls._limits_cache[key]
        logger.info(f"Cleared {len(keys_to_remove)} cache entries for user {user_id}")
        return len(keys_to_remove)

    def initialize_default_pricing(self):
        """Initialize default pricing for all API providers."""

        # Gemini API pricing (updated September 2025 - official Google AI pricing)
        # Source: https://ai.google.dev/gemini-api/docs/pricing
        gemini_pricing = [
            # Gemini 2.5 Pro - standard tier
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-2.5-pro",
                "cost_per_input_token": 0.00000125,  # $1.25 per 1M input tokens (prompts <= 200k tokens)
                "cost_per_output_token": 0.00001,  # $10.00 per 1M output tokens (prompts <= 200k tokens)
                "description": "Gemini 2.5 Pro - State-of-the-art multipurpose model for coding and complex reasoning"
            },
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-2.5-pro-large",
                "cost_per_input_token": 0.0000025,  # $2.50 per 1M input tokens (prompts > 200k tokens)
                "cost_per_output_token": 0.000015,  # $15.00 per 1M output tokens (prompts > 200k tokens)
                "description": "Gemini 2.5 Pro - Large context model for prompts > 200k tokens"
            },
            # Gemini 2.5 Flash - standard tier
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-2.5-flash",
                "cost_per_input_token": 0.0000003,  # $0.30 per 1M input tokens (text/image/video)
                "cost_per_output_token": 0.0000025,  # $2.50 per 1M output tokens
                "description": "Gemini 2.5 Flash - Hybrid reasoning model with 1M token context window"
            },
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-2.5-flash-audio",
                "cost_per_input_token": 0.000001,  # $1.00 per 1M input tokens (audio)
                "cost_per_output_token": 0.0000025,  # $2.50 per 1M output tokens
                "description": "Gemini 2.5 Flash - Audio input model"
            },
            # Gemini 2.5 Flash-Lite - standard tier
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-2.5-flash-lite",
                "cost_per_input_token": 0.0000001,  # $0.10 per 1M input tokens (text/image/video)
                "cost_per_output_token": 0.0000004,  # $0.40 per 1M output tokens
                "description": "Gemini 2.5 Flash-Lite - Smallest and most cost-effective model for at-scale usage"
            },
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-2.5-flash-lite-audio",
                "cost_per_input_token": 0.0000003,  # $0.30 per 1M input tokens (audio)
                "cost_per_output_token": 0.0000004,  # $0.40 per 1M output tokens
                "description": "Gemini 2.5 Flash-Lite - Audio input model"
            },
            # Gemini 1.5 Flash - standard tier
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-1.5-flash",
                "cost_per_input_token": 0.000000075,  # $0.075 per 1M input tokens (prompts <= 128k tokens)
                "cost_per_output_token": 0.0000003,  # $0.30 per 1M output tokens (prompts <= 128k tokens)
                "description": "Gemini 1.5 Flash - Fast multimodal model with 1M token context window"
            },
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-1.5-flash-large",
                "cost_per_input_token": 0.00000015,  # $0.15 per 1M input tokens (prompts > 128k tokens)
                "cost_per_output_token": 0.0000006,  # $0.60 per 1M output tokens (prompts > 128k tokens)
                "description": "Gemini 1.5 Flash - Large context model for prompts > 128k tokens"
            },
            # Gemini 1.5 Flash-8B - standard tier
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-1.5-flash-8b",
                "cost_per_input_token": 0.0000000375,  # $0.0375 per 1M input tokens (prompts <= 128k tokens)
                "cost_per_output_token": 0.00000015,  # $0.15 per 1M output tokens (prompts <= 128k tokens)
                "description": "Gemini 1.5 Flash-8B - Smallest model for lower intelligence use cases"
            },
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-1.5-flash-8b-large",
                "cost_per_input_token": 0.000000075,  # $0.075 per 1M input tokens (prompts > 128k tokens)
                "cost_per_output_token": 0.0000003,  # $0.30 per 1M output tokens (prompts > 128k tokens)
                "description": "Gemini 1.5 Flash-8B - Large context model for prompts > 128k tokens"
            },
            # Gemini 1.5 Pro - standard tier
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-1.5-pro",
                "cost_per_input_token": 0.00000125,  # $1.25 per 1M input tokens (prompts <= 128k tokens)
                "cost_per_output_token": 0.000005,  # $5.00 per 1M output tokens (prompts <= 128k tokens)
                "description": "Gemini 1.5 Pro - Highest intelligence model with 2M token context window"
            },
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-1.5-pro-large",
                "cost_per_input_token": 0.0000025,  # $2.50 per 1M input tokens (prompts > 128k tokens)
                "cost_per_output_token": 0.00001,  # $10.00 per 1M output tokens (prompts > 128k tokens)
                "description": "Gemini 1.5 Pro - Large context model for prompts > 128k tokens"
            },
            # Gemini Embedding - standard tier
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-embedding",
                "cost_per_input_token": 0.00000015,  # $0.15 per 1M input tokens
                "cost_per_output_token": 0.0,  # No output tokens for embeddings
                "description": "Gemini Embedding - Newest embeddings model with higher rate limits"
            },
            # Grounding with Google Search - standard tier
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-grounding-search",
                "cost_per_request": 0.035,  # $35 per 1,000 requests (after free tier)
                "cost_per_input_token": 0.0,  # No additional token cost for grounding
                "cost_per_output_token": 0.0,  # No additional token cost for grounding
                "description": "Grounding with Google Search - 1,500 RPD free, then $35/1K requests"
            }
        ]

        # OpenAI pricing (estimated, will be updated)
        openai_pricing = [
            {
                "provider": APIProvider.OPENAI,
                "model_name": "gpt-4o",
                "cost_per_input_token": 0.0000025,  # $2.50 per 1M input tokens
                "cost_per_output_token": 0.00001,  # $10.00 per 1M output tokens
                "description": "GPT-4o - Latest OpenAI model"
            },
            {
                "provider": APIProvider.OPENAI,
                "model_name": "gpt-4o-mini",
                "cost_per_input_token": 0.00000015,  # $0.15 per 1M input tokens
                "cost_per_output_token": 0.0000006,  # $0.60 per 1M output tokens
                "description": "GPT-4o Mini - Cost-effective model"
            }
        ]

        # Anthropic pricing (estimated, will be updated)
        anthropic_pricing = [
            {
                "provider": APIProvider.ANTHROPIC,
                "model_name": "claude-3.5-sonnet",
                "cost_per_input_token": 0.000003,  # $3.00 per 1M input tokens
                "cost_per_output_token": 0.000015,  # $15.00 per 1M output tokens
                "description": "Claude 3.5 Sonnet - Anthropic's flagship model"
            }
        ]

        # HuggingFace/Mistral pricing (for GPT-OSS-120B via Groq).
        # Defaults come from environment variables, falling back to estimates
        # based on Groq pricing: ~$1 per 1M input tokens, ~$3 per 1M output tokens.
        hf_input_cost = float(os.getenv('HUGGINGFACE_INPUT_TOKEN_COST', '0.000001'))  # $1 per 1M tokens default
        hf_output_cost = float(os.getenv('HUGGINGFACE_OUTPUT_TOKEN_COST', '0.000003'))  # $3 per 1M tokens default

        mistral_pricing = [
            {
                "provider": APIProvider.MISTRAL,
                "model_name": "openai/gpt-oss-120b:groq",
                "cost_per_input_token": hf_input_cost,
                "cost_per_output_token": hf_output_cost,
                "description": "GPT-OSS-120B via HuggingFace/Groq (configurable via HUGGINGFACE_INPUT_TOKEN_COST and HUGGINGFACE_OUTPUT_TOKEN_COST env vars)"
            },
            {
                "provider": APIProvider.MISTRAL,
                "model_name": "gpt-oss-120b",
                "cost_per_input_token": hf_input_cost,
                "cost_per_output_token": hf_output_cost,
                "description": "GPT-OSS-120B via HuggingFace/Groq (configurable via HUGGINGFACE_INPUT_TOKEN_COST and HUGGINGFACE_OUTPUT_TOKEN_COST env vars)"
            },
            {
                "provider": APIProvider.MISTRAL,
                "model_name": "default",
                "cost_per_input_token": hf_input_cost,
                "cost_per_output_token": hf_output_cost,
                "description": "HuggingFace default model pricing (configurable via HUGGINGFACE_INPUT_TOKEN_COST and HUGGINGFACE_OUTPUT_TOKEN_COST env vars)"
            }
        ]

        # Search API Pricing (estimated)
        search_pricing = [
            {
                "provider": APIProvider.TAVILY,
                "model_name": "tavily-search",
                "cost_per_request": 0.001,  # $0.001 per search
                "description": "Tavily AI Search API"
            },
            {
                "provider": APIProvider.SERPER,
                "model_name": "serper-search",
                "cost_per_request": 0.001,  # $0.001 per search
                "description": "Serper Google Search API"
            },
            {
                "provider": APIProvider.METAPHOR,
                "model_name": "metaphor-search",
                "cost_per_request": 0.003,  # $0.003 per search
                "description": "Metaphor/Exa AI Search API"
            },
            {
                "provider": APIProvider.FIRECRAWL,
                "model_name": "firecrawl-extract",
                "cost_per_page": 0.002,  # $0.002 per page crawled
                "description": "Firecrawl Web Extraction API"
            },
            {
                "provider": APIProvider.STABILITY,
                "model_name": "stable-diffusion",
                "cost_per_image": 0.04,  # $0.04 per image
                "description": "Stability AI Image Generation"
            },
            {
                "provider": APIProvider.EXA,
                "model_name": "exa-search",
                "cost_per_request": 0.005,  # $0.005 per search (1-25 results)
                "description": "Exa Neural Search API"
            },
            {
                "provider": APIProvider.VIDEO,
                "model_name": "tencent/HunyuanVideo",
                "cost_per_request": 0.10,  # $0.10 per video generation (estimated)
                "description": "HuggingFace AI Video Generation (HunyuanVideo)"
            },
            {
                "provider": APIProvider.VIDEO,
                "model_name": "default",
                "cost_per_request": 0.10,  # $0.10 per video generation (estimated)
                "description": "AI Video Generation default pricing"
            },
            {
                "provider": APIProvider.VIDEO,
                "model_name": "kling-v2.5-turbo-std-5s",
                "cost_per_request": 0.21,
                "description": "WaveSpeed Kling v2.5 Turbo Std Image-to-Video (5 seconds)"
            },
            {
                "provider": APIProvider.VIDEO,
                "model_name": "kling-v2.5-turbo-std-10s",
                "cost_per_request": 0.42,
                "description": "WaveSpeed Kling v2.5 Turbo Std Image-to-Video (10 seconds)"
            },
            {
                "provider": APIProvider.VIDEO,
                "model_name": "wavespeed-ai/infinitetalk",
                "cost_per_request": 0.30,
                "description": "WaveSpeed InfiniteTalk (image + audio to talking avatar video)"
            },
            # Audio Generation Pricing (Minimax Speech 02 HD via WaveSpeed)
            {
                "provider": APIProvider.AUDIO,
                "model_name": "minimax/speech-02-hd",
                "cost_per_input_token": 0.00005,  # $0.05 per 1,000 characters (every character is 1 token)
                "cost_per_output_token": 0.0,  # No output tokens for audio
                "cost_per_request": 0.0,  # Pricing is per character, not per request
                "description": "AI Audio Generation (Text-to-Speech) - Minimax Speech 02 HD via WaveSpeed"
            },
            {
                "provider": APIProvider.AUDIO,
                "model_name": "default",
                "cost_per_input_token": 0.00005,  # $0.05 per 1,000 characters default
                "cost_per_output_token": 0.0,
                "cost_per_request": 0.0,
                "description": "AI Audio Generation default pricing"
            }
        ]

        # Combine all pricing data (video and audio pricing entries are included in the search_pricing list)
        all_pricing = gemini_pricing + openai_pricing + anthropic_pricing + mistral_pricing + search_pricing

        # Insert or update pricing data
        for pricing_data in all_pricing:
            existing = self.db.query(APIProviderPricing).filter(
                APIProviderPricing.provider == pricing_data["provider"],
                APIProviderPricing.model_name == pricing_data["model_name"]
            ).first()

            if existing:
                # Update existing pricing (especially for HuggingFace if env vars changed)
                if pricing_data["provider"] == APIProvider.MISTRAL:
                    # Update HuggingFace pricing from env vars
                    existing.cost_per_input_token = pricing_data["cost_per_input_token"]
                    existing.cost_per_output_token = pricing_data["cost_per_output_token"]
                    existing.description = pricing_data["description"]
                    existing.updated_at = datetime.utcnow()
                    logger.debug(f"Updated pricing for {pricing_data['provider'].value}:{pricing_data['model_name']}")
            else:
                pricing = APIProviderPricing(**pricing_data)
                self.db.add(pricing)
                logger.debug(f"Added new pricing for {pricing_data['provider'].value}:{pricing_data['model_name']}")

        self.db.commit()
        logger.info("Default API pricing initialized/updated. HuggingFace pricing loaded from env vars if available.")

    def initialize_default_plans(self):
        """Initialize default subscription plans."""

        plans = [
            {
                "name": "Free",
                "tier": SubscriptionTier.FREE,
                "price_monthly": 0.0,
                "price_yearly": 0.0,
                "gemini_calls_limit": 100,
                "openai_calls_limit": 0,
                "anthropic_calls_limit": 0,
                "mistral_calls_limit": 50,
                "tavily_calls_limit": 20,
                "serper_calls_limit": 20,
                "metaphor_calls_limit": 10,
                "firecrawl_calls_limit": 10,
                "stability_calls_limit": 5,
                "exa_calls_limit": 100,
                "video_calls_limit": 0,  # No video generation for free tier
                "image_edit_calls_limit": 10,  # 10 AI image editing calls/month
                "audio_calls_limit": 20,  # 20 AI audio generation calls/month
                "gemini_tokens_limit": 100000,
                "monthly_cost_limit": 0.0,
                "features": ["basic_content_generation", "limited_research"],
                "description": "Perfect for trying out ALwrity"
            },
            {
                "name": "Basic",
                "tier": SubscriptionTier.BASIC,
                "price_monthly": 29.0,
                "price_yearly": 290.0,
                "ai_text_generation_calls_limit": 10,  # Unified limit for all LLM providers
                "gemini_calls_limit": 1000,  # Legacy, kept for backwards compatibility (not used for enforcement)
                "openai_calls_limit": 500,
                "anthropic_calls_limit": 200,
                "mistral_calls_limit": 500,
                "tavily_calls_limit": 200,
                "serper_calls_limit": 200,
                "metaphor_calls_limit": 100,
                "firecrawl_calls_limit": 100,
                "stability_calls_limit": 5,
                "exa_calls_limit": 500,
                "video_calls_limit": 20,  # 20 videos/month for basic plan
                "image_edit_calls_limit": 30,  # 30 AI image editing calls/month
                "audio_calls_limit": 50,  # 50 AI audio generation calls/month
                "gemini_tokens_limit": 20000,  # Increased from 5000 for better stability
                "openai_tokens_limit": 20000,  # Increased from 5000 for better stability
                "anthropic_tokens_limit": 20000,  # Increased from 5000 for better stability
                "mistral_tokens_limit": 20000,  # Increased from 5000 for better stability
                "monthly_cost_limit": 50.0,
                "features": ["full_content_generation", "advanced_research", "basic_analytics"],
                "description": "Great for individuals and small teams"
            },
            {
                "name": "Pro",
                "tier": SubscriptionTier.PRO,
                "price_monthly": 79.0,
                "price_yearly": 790.0,
                "gemini_calls_limit": 5000,
                "openai_calls_limit": 2500,
                "anthropic_calls_limit": 1000,
                "mistral_calls_limit": 2500,
                "tavily_calls_limit": 1000,
                "serper_calls_limit": 1000,
                "metaphor_calls_limit": 500,
                "firecrawl_calls_limit": 500,
                "stability_calls_limit": 200,
                "exa_calls_limit": 2000,
                "video_calls_limit": 50,  # 50 videos/month for pro plan
                "image_edit_calls_limit": 100,  # 100 AI image editing calls/month
                "audio_calls_limit": 200,  # 200 AI audio generation calls/month
                "gemini_tokens_limit": 5000000,
                "openai_tokens_limit": 2500000,
                "anthropic_tokens_limit": 1000000,
                "mistral_tokens_limit": 2500000,
                "monthly_cost_limit": 150.0,
                "features": ["unlimited_content_generation", "premium_research", "advanced_analytics", "priority_support"],
                "description": "Perfect for growing businesses"
            },
            {
                "name": "Enterprise",
                "tier": SubscriptionTier.ENTERPRISE,
                "price_monthly": 199.0,
                "price_yearly": 1990.0,
                "gemini_calls_limit": 0,  # Unlimited
                "openai_calls_limit": 0,
                "anthropic_calls_limit": 0,
                "mistral_calls_limit": 0,
                "tavily_calls_limit": 0,
                "serper_calls_limit": 0,
                "metaphor_calls_limit": 0,
                "firecrawl_calls_limit": 0,
                "stability_calls_limit": 0,
                "exa_calls_limit": 0,  # Unlimited
                "video_calls_limit": 0,  # Unlimited for enterprise
                "image_edit_calls_limit": 0,  # Unlimited image editing for enterprise
                "audio_calls_limit": 0,  # Unlimited audio generation for enterprise
                "gemini_tokens_limit": 0,
                "openai_tokens_limit": 0,
                "anthropic_tokens_limit": 0,
                "mistral_tokens_limit": 0,
                "monthly_cost_limit": 500.0,
                "features": ["unlimited_everything", "white_label", "dedicated_support", "custom_integrations"],
                "description": "For large organizations with high-volume needs"
            }
        ]

        for plan_data in plans:
            existing = self.db.query(SubscriptionPlan).filter(
                SubscriptionPlan.name == plan_data["name"]
            ).first()

            if not existing:
                plan = SubscriptionPlan(**plan_data)
                self.db.add(plan)
            else:
                # Update the existing plan with new limits (e.g., image_edit_calls_limit).
                # This ensures existing plans pick up newly added columns.
                for key, value in plan_data.items():
                    if key not in ["name", "tier"]:  # Don't overwrite name/tier
                        try:
                            # Try to set the attribute (works even if the column was just added)
                            setattr(existing, key, value)
                        except Exception as e:
                            # If the attribute doesn't exist yet (column not migrated), skip it.
                            # Schema migration will add it; the next run will then update it.
                            logger.debug(f"Could not set {key} on plan {existing.name}: {e}")
                existing.updated_at = datetime.utcnow()
                logger.debug(f"Updated existing plan: {existing.name}")

        self.db.commit()
        logger.debug("Default subscription plans initialized")

    def calculate_api_cost(self, provider: APIProvider, model_name: str,
                           tokens_input: int = 0, tokens_output: int = 0,
                           request_count: int = 1, **kwargs) -> Dict[str, float]:
        """Calculate cost for an API call.

        Args:
            provider: APIProvider enum (e.g., APIProvider.MISTRAL for HuggingFace)
            model_name: Model name (e.g., "openai/gpt-oss-120b:groq")
            tokens_input: Number of input tokens
            tokens_output: Number of output tokens
            request_count: Number of requests (default: 1)
            **kwargs: Additional parameters (search_count, image_count, page_count, etc.)

        Returns:
            Dict with cost_input, cost_output, and cost_total
        """

        # Get pricing for the provider and model.
        # Try an exact match first.
        pricing = self.db.query(APIProviderPricing).filter(
            APIProviderPricing.provider == provider,
            APIProviderPricing.model_name == model_name,
            APIProviderPricing.is_active == True
        ).first()

        # If not found, try the "default" model name for the provider
        if not pricing:
            pricing = self.db.query(APIProviderPricing).filter(
                APIProviderPricing.provider == provider,
                APIProviderPricing.model_name == "default",
                APIProviderPricing.is_active == True
            ).first()

        # If still not found, check for HuggingFace models (provider is MISTRAL)
        # and try alternative model name variations.
        if not pricing and provider == APIProvider.MISTRAL:
            # Try the short name "gpt-oss-120b" (without the full path) if the model contains it
            if "gpt-oss-120b" in model_name.lower():
                pricing = self.db.query(APIProviderPricing).filter(
                    APIProviderPricing.provider == provider,
                    APIProviderPricing.model_name == "gpt-oss-120b",
                    APIProviderPricing.is_active == True
                ).first()

            # Also try the full model path
            if not pricing:
                pricing = self.db.query(APIProviderPricing).filter(
                    APIProviderPricing.provider == provider,
                    APIProviderPricing.model_name == "openai/gpt-oss-120b:groq",
                    APIProviderPricing.is_active == True
                ).first()

        if not pricing:
            # Check if we should use env vars for HuggingFace/Mistral
            if provider == APIProvider.MISTRAL:
                # Use environment variables for HuggingFace pricing if available
                hf_input_cost = float(os.getenv('HUGGINGFACE_INPUT_TOKEN_COST', '0.000001'))
                hf_output_cost = float(os.getenv('HUGGINGFACE_OUTPUT_TOKEN_COST', '0.000003'))
                logger.info(f"Using HuggingFace pricing from env vars: input={hf_input_cost}, output={hf_output_cost} for model {model_name}")
                cost_input = tokens_input * hf_input_cost
                cost_output = tokens_output * hf_output_cost
                cost_total = cost_input + cost_output
            else:
                logger.warning(f"No pricing found for {provider.value}:{model_name}, using default estimates")
                # Use default estimates ($1 per 1M tokens)
                cost_input = tokens_input * 0.000001
                cost_output = tokens_output * 0.000001
                cost_total = cost_input + cost_output
        else:
            # Calculate based on actual pricing from the database
            logger.debug(f"Using pricing from DB for {provider.value}:{model_name} - input: {pricing.cost_per_input_token}, output: {pricing.cost_per_output_token}")
            cost_input = tokens_input * (pricing.cost_per_input_token or 0.0)
            cost_output = tokens_output * (pricing.cost_per_output_token or 0.0)
            cost_request = request_count * (pricing.cost_per_request or 0.0)

            # Handle special cases for non-LLM APIs
            cost_search = kwargs.get('search_count', 0) * (pricing.cost_per_search or 0.0)
            cost_image = kwargs.get('image_count', 0) * (pricing.cost_per_image or 0.0)
            cost_page = kwargs.get('page_count', 0) * (pricing.cost_per_page or 0.0)

            cost_total = cost_input + cost_output + cost_request + cost_search + cost_image + cost_page

        # Round to 6 decimal places for precision
        return {
            'cost_input': round(cost_input, 6),
            'cost_output': round(cost_output, 6),
            'cost_total': round(cost_total, 6)
        }

    def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get usage limits for a user based on their subscription."""

        # CRITICAL: Expire all objects first to ensure fresh data after renewal
        self.db.expire_all()

        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id,
            UserSubscription.is_active == True
        ).first()

        if not subscription:
            # Return free tier limits
            free_plan = self.db.query(SubscriptionPlan).filter(
                SubscriptionPlan.tier == SubscriptionTier.FREE
            ).first()
            if free_plan:
                return self._plan_to_limits_dict(free_plan)
            return None

        # Ensure the current period before returning limits
        self._ensure_subscription_current(subscription)

        # CRITICAL: Refresh the subscription to get the latest plan_id, then refresh the plan relationship
        self.db.refresh(subscription)

        # Re-query the plan directly to ensure fresh data (bypass the relationship cache)
        plan = self.db.query(SubscriptionPlan).filter(
            SubscriptionPlan.id == subscription.plan_id
        ).first()

        if not plan:
            logger.error(f"Plan not found for subscription plan_id={subscription.plan_id}")
            return None

        # Refresh the plan to ensure fresh limits
        self.db.refresh(plan)

        return self._plan_to_limits_dict(plan)

    def _ensure_ai_text_gen_column_detection(self) -> None:
        """Detect at runtime whether the ai_text_generation_calls_limit column exists and cache the result."""
        if self._ai_text_gen_col_checked:
            return
        try:
            # Try to query the column - if it exists, this will work
            self.db.execute(text('SELECT ai_text_generation_calls_limit FROM subscription_plans LIMIT 0'))
            self._ai_text_gen_col_available = True
        except Exception:
            self._ai_text_gen_col_available = False
        finally:
            self._ai_text_gen_col_checked = True

    def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]:
        """Convert a subscription plan to a limits dictionary."""
        # Detect whether the unified AI text generation limit column exists
        self._ensure_ai_text_gen_column_detection()

        # Use the unified AI text generation limit if the column exists and is set
        ai_text_gen_limit = None
        if self._ai_text_gen_col_available:
            try:
                ai_text_gen_limit = getattr(plan, 'ai_text_generation_calls_limit', None)
                # If 0, treat as not set (unlimited for Enterprise, or use the fallback)
                if ai_text_gen_limit == 0:
                    ai_text_gen_limit = None
            except Exception:
                # Column exists but access failed - use the fallback
                ai_text_gen_limit = None

        return {
            'plan_name': plan.name,
            'tier': plan.tier.value,
            'limits': {
                # Unified AI text generation limit (applies to all LLM providers).
                # If not set, fall back to the first non-zero legacy limit for backwards compatibility.
                'ai_text_generation_calls': ai_text_gen_limit if ai_text_gen_limit is not None else (
                    plan.gemini_calls_limit if plan.gemini_calls_limit > 0 else
                    plan.openai_calls_limit if plan.openai_calls_limit > 0 else
                    plan.anthropic_calls_limit if plan.anthropic_calls_limit > 0 else
                    plan.mistral_calls_limit if plan.mistral_calls_limit > 0 else 0
                ),
                # Legacy per-provider limits (for backwards compatibility and analytics)
                'gemini_calls': plan.gemini_calls_limit,
                'openai_calls': plan.openai_calls_limit,
                'anthropic_calls': plan.anthropic_calls_limit,
                'mistral_calls': plan.mistral_calls_limit,
                # Other API limits
                'tavily_calls': plan.tavily_calls_limit,
                'serper_calls': plan.serper_calls_limit,
                'metaphor_calls': plan.metaphor_calls_limit,
                'firecrawl_calls': plan.firecrawl_calls_limit,
                'stability_calls': plan.stability_calls_limit,
                'video_calls': getattr(plan, 'video_calls_limit', 0),  # Support missing column
                'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0),  # Support missing column
                'audio_calls': getattr(plan, 'audio_calls_limit', 0),  # Support missing column
                # Token limits
                'gemini_tokens': plan.gemini_tokens_limit,
                'openai_tokens': plan.openai_tokens_limit,
                'anthropic_tokens': plan.anthropic_tokens_limit,
                'mistral_tokens': plan.mistral_tokens_limit,
                'monthly_cost': plan.monthly_cost_limit
            },
            'features': plan.features or []
        }

    def check_usage_limits(self, user_id: str, provider: APIProvider,
                           tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
        """Check whether a user can make an API call within their limits.

        Delegates to LimitValidator for the actual validation logic.

        Args:
            user_id: User ID
            provider: APIProvider enum (may be MISTRAL for HuggingFace)
            tokens_requested: Estimated tokens for the request
            actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)

        Returns:
            (can_proceed, error_message, usage_info)
        """
        from .limit_validation import LimitValidator
        validator = LimitValidator(self)
        return validator.check_usage_limits(user_id, provider, tokens_requested, actual_provider_name)

    def estimate_tokens(self, text: str, provider: APIProvider) -> int:
        """Estimate the token count for text based on the provider."""

        # Get pricing info for token estimation
        pricing = self.db.query(APIProviderPricing).filter(
            APIProviderPricing.provider == provider,
            APIProviderPricing.is_active == True
        ).first()

        word_count = len(text.split())
        if pricing and pricing.tokens_per_word:
            # Use the provider-specific conversion
            return int(word_count * pricing.tokens_per_word)
        # Use the default estimation (roughly 1.3 tokens per word for most models)
        return int(word_count * 1.3)

    def get_pricing_info(self, provider: APIProvider, model_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Get pricing information for a provider/model."""

        query = self.db.query(APIProviderPricing).filter(
            APIProviderPricing.provider == provider,
            APIProviderPricing.is_active == True
        )

        if model_name:
            query = query.filter(APIProviderPricing.model_name == model_name)

        pricing = query.first()

        if not pricing:
            return None

        # Return pricing info as a dict
        return {
            'provider': pricing.provider.value,
            'model_name': pricing.model_name,
            'cost_per_input_token': pricing.cost_per_input_token,
            'cost_per_output_token': pricing.cost_per_output_token,
            'cost_per_request': pricing.cost_per_request,
            'description': pricing.description
        }

    def check_comprehensive_limits(
        self,
        user_id: str,
        operations: List[Dict[str, Any]]
    ) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
        """
        Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.

        Delegates to LimitValidator for the actual validation logic.
        This prevents wasteful API calls by validating that ALL subsequent operations will succeed
        before making the first external API call.

        Args:
            user_id: User ID
            operations: List of operations to validate, each with:
                - 'provider': APIProvider enum
                - 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
                - 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
                - 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")

        Returns:
            (can_proceed, error_message, error_details)
            If can_proceed is False, error_message explains which limit would be exceeded
        """
        from .limit_validation import LimitValidator
        validator = LimitValidator(self)
        return validator.check_comprehensive_limits(user_id, operations)

    def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
        """Get the pricing configuration for a specific provider and model."""
        pricing = self.db.query(APIProviderPricing).filter(
            APIProviderPricing.provider == provider,
            APIProviderPricing.model_name == model_name
        ).first()

        if not pricing:
            return None

        return {
            'provider': pricing.provider.value,
            'model_name': pricing.model_name,
            'cost_per_input_token': pricing.cost_per_input_token,
            'cost_per_output_token': pricing.cost_per_output_token,
            'cost_per_request': pricing.cost_per_request,
            'cost_per_search': pricing.cost_per_search,
            'cost_per_image': pricing.cost_per_image,
            'cost_per_page': pricing.cost_per_page,
            'description': pricing.description
        }
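The cost math in `calculate_api_cost` above reduces to per-unit rates multiplied by counts, rounded to 6 decimal places. A minimal standalone sketch of the same arithmetic (hypothetical `calculate_cost` helper, example rates matching the HuggingFace env-var defaults, no database lookup):

```python
def calculate_cost(tokens_input: int, tokens_output: int,
                   cost_per_input_token: float, cost_per_output_token: float,
                   cost_per_request: float = 0.0, request_count: int = 1) -> dict:
    """Per-unit rate times count, rounded to 6 decimal places, as in the service above."""
    cost_input = tokens_input * cost_per_input_token
    cost_output = tokens_output * cost_per_output_token
    cost_total = cost_input + cost_output + request_count * cost_per_request
    return {
        "cost_input": round(cost_input, 6),
        "cost_output": round(cost_output, 6),
        "cost_total": round(cost_total, 6),
    }

# Example with the default env-var rates: $1 per 1M input tokens, $3 per 1M output tokens
print(calculate_cost(1000, 500, 0.000001, 0.000003))
# cost_input 0.001, cost_output 0.0015, cost_total 0.0025
```

The rounding step mirrors the service's return value, so downstream billing code sees stable 6-decimal figures regardless of float noise in the multiplications.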
122
backend/services/subscription/schema_utils.py
Normal file
@@ -0,0 +1,122 @@
from typing import Set
from sqlalchemy.orm import Session
from sqlalchemy import text
from loguru import logger


_checked_subscription_plan_columns: bool = False
_checked_usage_summaries_columns: bool = False


def ensure_subscription_plan_columns(db: Session) -> None:
    """Ensure required columns exist on subscription_plans for runtime safety.

    This is a defensive guard for environments where migrations have not yet
    been applied. If columns are missing (e.g., exa_calls_limit), we add them
    with a safe default so ORM queries do not fail.
    """
    global _checked_subscription_plan_columns
    if _checked_subscription_plan_columns:
        return

    try:
        # Discover existing columns using PRAGMA
        result = db.execute(text("PRAGMA table_info(subscription_plans)"))
        cols: Set[str] = {row[1] for row in result}

        logger.debug(f"Schema check: Found {len(cols)} columns in subscription_plans table")

        # Columns we may reference in models but might be missing in older DBs
        required_columns = {
            "exa_calls_limit": "INTEGER DEFAULT 0",
            "video_calls_limit": "INTEGER DEFAULT 0",
            "image_edit_calls_limit": "INTEGER DEFAULT 0",
            "audio_calls_limit": "INTEGER DEFAULT 0",
        }

        for col_name, ddl in required_columns.items():
            if col_name not in cols:
                logger.info(f"Adding missing column {col_name} to subscription_plans table")
                try:
                    db.execute(text(f"ALTER TABLE subscription_plans ADD COLUMN {col_name} {ddl}"))
                    db.commit()
                    logger.info(f"Successfully added column {col_name}")
                except Exception as alter_err:
                    logger.error(f"Failed to add column {col_name}: {alter_err}")
                    db.rollback()
                    # Don't set the flag on error - allow retry
                    raise
            else:
                logger.debug(f"Column {col_name} already exists")

        # Only set the flag if we successfully completed the check
        _checked_subscription_plan_columns = True
    except Exception as e:
        logger.error(f"Error ensuring subscription_plan columns: {e}", exc_info=True)
        db.rollback()
        # Don't set the flag if there was an error, so we retry next time
        _checked_subscription_plan_columns = False
        raise


def ensure_usage_summaries_columns(db: Session) -> None:
    """Ensure required columns exist on usage_summaries for runtime safety.

    This is a defensive guard for environments where migrations have not yet
    been applied. If columns are missing (e.g., exa_calls, exa_cost), we add them
    with a safe default so ORM queries do not fail.
    """
    global _checked_usage_summaries_columns
    if _checked_usage_summaries_columns:
        return

    try:
        # Discover existing columns using PRAGMA
        result = db.execute(text("PRAGMA table_info(usage_summaries)"))
        cols: Set[str] = {row[1] for row in result}

        logger.debug(f"Schema check: Found {len(cols)} columns in usage_summaries table")

        # Columns we may reference in models but might be missing in older DBs
        required_columns = {
            "exa_calls": "INTEGER DEFAULT 0",
            "exa_cost": "REAL DEFAULT 0.0",
            "video_calls": "INTEGER DEFAULT 0",
            "video_cost": "REAL DEFAULT 0.0",
            "image_edit_calls": "INTEGER DEFAULT 0",
            "image_edit_cost": "REAL DEFAULT 0.0",
            "audio_calls": "INTEGER DEFAULT 0",
            "audio_cost": "REAL DEFAULT 0.0",
        }

        for col_name, ddl in required_columns.items():
            if col_name not in cols:
                logger.info(f"Adding missing column {col_name} to usage_summaries table")
                try:
                    db.execute(text(f"ALTER TABLE usage_summaries ADD COLUMN {col_name} {ddl}"))
                    db.commit()
                    logger.info(f"Successfully added column {col_name}")
                except Exception as alter_err:
                    logger.error(f"Failed to add column {col_name}: {alter_err}")
                    db.rollback()
                    # Don't set the flag on error - allow retry
                    raise
            else:
                logger.debug(f"Column {col_name} already exists")

        # Only set the flag if we successfully completed the check
        _checked_usage_summaries_columns = True
    except Exception as e:
        logger.error(f"Error ensuring usage_summaries columns: {e}", exc_info=True)
        db.rollback()
        # Don't set the flag if there was an error, so we retry next time
        _checked_usage_summaries_columns = False
        raise


def ensure_all_schema_columns(db: Session) -> None:
    """Ensure all required columns exist in subscription-related tables."""
    ensure_subscription_plan_columns(db)
    ensure_usage_summaries_columns(db)
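The defensive-migration pattern in `schema_utils.py` (inspect `PRAGMA table_info`, then `ALTER TABLE ... ADD COLUMN` for anything missing) can be exercised against a throwaway SQLite database. A minimal sketch using the stdlib `sqlite3` module, with a hypothetical `ensure_columns` helper standing in for the SQLAlchemy version:

```python
import sqlite3


def ensure_columns(conn: sqlite3.Connection, table: str, required: dict) -> None:
    """Add any columns from `required` (name -> DDL fragment) missing from `table`."""
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
    existing = {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
    for name, ddl in required.items():
        if name not in existing:
            conn.execute(f"ALTER TABLE {table} ADD COLUMN {name} {ddl}")
    conn.commit()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE subscription_plans (id INTEGER PRIMARY KEY)")
ensure_columns(conn, "subscription_plans", {"exa_calls_limit": "INTEGER DEFAULT 0"})
cols = {row[1] for row in conn.execute("PRAGMA table_info(subscription_plans)")}
print(cols)  # now includes 'exa_calls_limit' alongside 'id'
```

As in the module above, the helper is idempotent: re-running it finds the column already present and issues no further DDL, which is why it is safe to call on every startup.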
644
backend/services/subscription/usage_tracking_service.py
Normal file
@@ -0,0 +1,644 @@
"""
|
||||
Usage Tracking Service
|
||||
Comprehensive tracking of API usage, costs, and subscription limits.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
from models.subscription_models import (
|
||||
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
|
||||
UserSubscription, UsageStatus
|
||||
)
|
||||
from .pricing_service import PricingService
|
||||
|
||||
class UsageTrackingService:
|
||||
"""Service for tracking API usage and managing subscription limits."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.pricing_service = PricingService(db)
|
||||
# TTL cache (30s) for enforcement results to cut DB chatter
|
||||
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
|
||||
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async def track_api_usage(self, user_id: str, provider: APIProvider,
|
||||
endpoint: str, method: str, model_used: str = None,
|
||||
tokens_input: int = 0, tokens_output: int = 0,
|
||||
response_time: float = 0.0, status_code: int = 200,
|
||||
request_size: int = None, response_size: int = None,
|
||||
user_agent: str = None, ip_address: str = None,
|
||||
error_message: str = None, retry_count: int = 0,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
"""Track an API usage event and update billing information."""
|
||||
|
||||
try:
|
||||
# Calculate costs
|
||||
# Use specific model names instead of generic defaults
|
||||
default_models = {
|
||||
"gemini": "gemini-2.5-flash", # Use Flash as default (cost-effective)
|
||||
"openai": "gpt-4o-mini", # Use Mini as default (cost-effective)
|
||||
"anthropic": "claude-3.5-sonnet", # Use Sonnet as default
|
||||
"mistral": "openai/gpt-oss-120b:groq" # HuggingFace default model
|
            }

            # For HuggingFace (stored as MISTRAL), use the actual model name or default
            if provider == APIProvider.MISTRAL:
                # HuggingFace models - try to match the actual model name from model_used
                if model_used:
                    model_name = model_used
                else:
                    model_name = default_models.get("mistral", "openai/gpt-oss-120b:groq")
            else:
                model_name = model_used or default_models.get(provider.value, f"{provider.value}-default")

            cost_data = self.pricing_service.calculate_api_cost(
                provider=provider,
                model_name=model_name,
                tokens_input=tokens_input,
                tokens_output=tokens_output,
                request_count=1,
                **kwargs
            )

            # Create usage log entry
            billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
            usage_log = APIUsageLog(
                user_id=user_id,
                provider=provider,
                endpoint=endpoint,
                method=method,
                model_used=model_used,
                tokens_input=tokens_input,
                tokens_output=tokens_output,
                tokens_total=(tokens_input or 0) + (tokens_output or 0),
                cost_input=cost_data['cost_input'],
                cost_output=cost_data['cost_output'],
                cost_total=cost_data['cost_total'],
                response_time=response_time,
                status_code=status_code,
                request_size=request_size,
                response_size=response_size,
                user_agent=user_agent,
                ip_address=ip_address,
                error_message=error_message,
                retry_count=retry_count,
                billing_period=billing_period
            )

            self.db.add(usage_log)

            # Update the per-period usage summary
            await self._update_usage_summary(
                user_id=user_id,
                provider=provider,
                tokens_used=(tokens_input or 0) + (tokens_output or 0),
                cost=cost_data['cost_total'],
                billing_period=billing_period,
                response_time=response_time,
                is_error=status_code >= 400
            )

            # Check for usage alerts
            await self._check_usage_alerts(user_id, provider, billing_period)

            self.db.commit()

            logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")

            return {
                'usage_logged': True,
                'cost': cost_data['cost_total'],
                'tokens_used': (tokens_input or 0) + (tokens_output or 0),
                'billing_period': billing_period
            }

        except Exception as e:
            logger.error(f"Error tracking API usage: {str(e)}")
            self.db.rollback()
            return {
                'usage_logged': False,
                'error': str(e)
            }
    async def _update_usage_summary(self, user_id: str, provider: APIProvider,
                                    tokens_used: int, cost: float, billing_period: str,
                                    response_time: float, is_error: bool):
        """Update the per-period usage summary for a user."""

        # Get or create the usage summary for this billing period
        summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()

        if not summary:
            summary = UsageSummary(
                user_id=user_id,
                billing_period=billing_period
            )
            self.db.add(summary)

        # Update the provider-specific call counter
        provider_name = provider.value
        current_calls = getattr(summary, f"{provider_name}_calls", 0)
        setattr(summary, f"{provider_name}_calls", current_calls + 1)

        # Update token usage for LLM providers
        if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
            current_tokens = getattr(summary, f"{provider_name}_tokens", 0)
            setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)

        # Update cost
        current_cost = getattr(summary, f"{provider_name}_cost", 0.0)
        setattr(summary, f"{provider_name}_cost", current_cost + cost)

        # Update totals
        summary.total_calls += 1
        summary.total_tokens += tokens_used
        summary.total_cost += cost

        # Update performance metrics
        if summary.total_calls > 0:
            # Running average response time
            total_response_time = summary.avg_response_time * (summary.total_calls - 1) + response_time
            summary.avg_response_time = total_response_time / summary.total_calls

            # Recover the error count from the stored rate, then recompute the rate
            error_count = int(summary.error_rate * (summary.total_calls - 1) / 100)
            if is_error:
                error_count += 1
            summary.error_rate = (error_count / summary.total_calls) * 100

        # Update usage status based on subscription limits
        await self._update_usage_status(summary)

        summary.updated_at = datetime.utcnow()
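The running-average and error-rate updates above recover each aggregate from its previous value instead of re-querying the logs. The arithmetic can be isolated as a pure function; the name and tuple return below are illustrative, not part of the service:

```python
def update_running_metrics(avg_response_time: float, error_rate: float,
                           total_calls: int, response_time: float,
                           is_error: bool) -> tuple:
    """Fold one new call into the stored aggregates.

    total_calls is the count *after* the new call has been added;
    error_rate is a percentage (0-100), as stored on UsageSummary.
    """
    # Running mean: undo the old divisor, add the new sample, re-divide.
    total_time = avg_response_time * (total_calls - 1) + response_time
    new_avg = total_time / total_calls

    # Recover the integer error count from the stored percentage,
    # bump it if this call failed, then re-express as a percentage.
    error_count = int(error_rate * (total_calls - 1) / 100)
    if is_error:
        error_count += 1
    new_error_rate = (error_count / total_calls) * 100
    return new_avg, new_error_rate
```

This avoids a per-request aggregate query at the cost of trusting the stored values; the `int()` round-trip assumes `error_rate` was last written by this same formula.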
    async def _update_usage_status(self, summary: UsageSummary):
        """Update usage status based on subscription limits."""

        limits = self.pricing_service.get_user_limits(summary.user_id)
        if not limits:
            return

        # Track the highest usage percentage across all limits
        max_usage_percentage = 0.0

        # Check the cost limit
        cost_limit = limits['limits'].get('monthly_cost', 0)
        if cost_limit > 0:
            cost_usage_pct = (summary.total_cost / cost_limit) * 100
            max_usage_percentage = max(max_usage_percentage, cost_usage_pct)

        # Check the call limit for each provider
        for provider in APIProvider:
            provider_name = provider.value
            current_calls = getattr(summary, f"{provider_name}_calls", 0)
            call_limit = limits['limits'].get(f"{provider_name}_calls", 0)

            if call_limit > 0:
                call_usage_pct = (current_calls / call_limit) * 100
                max_usage_percentage = max(max_usage_percentage, call_usage_pct)

        # Update status based on the highest usage percentage
        if max_usage_percentage >= 100:
            summary.usage_status = UsageStatus.LIMIT_REACHED
        elif max_usage_percentage >= 80:
            summary.usage_status = UsageStatus.WARNING
        else:
            summary.usage_status = UsageStatus.ACTIVE
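The status decision in `_update_usage_status` reduces to taking the worst usage percentage across all limits and bucketing it at the 80% and 100% thresholds. A minimal sketch, with plain strings standing in for the `UsageStatus` enum:

```python
def classify_usage_status(usage_percentages) -> str:
    """Map the highest usage percentage onto a status bucket,
    mirroring the 80% / 100% thresholds used by the service."""
    worst = max(usage_percentages, default=0.0)
    if worst >= 100:
        return "limit_reached"
    if worst >= 80:
        return "warning"
    return "active"
```

`max(..., default=0.0)` keeps a user with no configured limits in the "active" bucket.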
    async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
        """Check whether usage alerts should be created for this provider."""

        # Get current usage
        summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()

        if not summary:
            return

        # Get user limits
        limits = self.pricing_service.get_user_limits(user_id)
        if not limits:
            return

        # Alert thresholds, as percentages of the call limit
        thresholds = [80, 90, 100]

        for threshold in thresholds:
            # Skip if an alert was already sent for this threshold
            existing_alert = self.db.query(UsageAlert).filter(
                UsageAlert.user_id == user_id,
                UsageAlert.billing_period == billing_period,
                UsageAlert.threshold_percentage == threshold,
                UsageAlert.provider == provider,
                UsageAlert.is_sent == True
            ).first()

            if existing_alert:
                continue

            # Check whether the threshold has been reached
            provider_name = provider.value
            current_calls = getattr(summary, f"{provider_name}_calls", 0)
            call_limit = limits['limits'].get(f"{provider_name}_calls", 0)

            if call_limit > 0:
                usage_percentage = (current_calls / call_limit) * 100

                if usage_percentage >= threshold:
                    await self._create_usage_alert(
                        user_id=user_id,
                        provider=provider,
                        threshold=threshold,
                        current_usage=current_calls,
                        limit=call_limit,
                        billing_period=billing_period
                    )
    async def _create_usage_alert(self, user_id: str, provider: APIProvider,
                                  threshold: int, current_usage: int, limit: int,
                                  billing_period: str):
        """Create a usage alert."""

        # Determine alert type and severity
        if threshold >= 100:
            alert_type = "limit_reached"
            severity = "error"
            title = f"API Limit Reached - {provider.value.title()}"
            message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
        elif threshold >= 90:
            alert_type = "usage_warning"
            severity = "warning"
            title = f"API Usage Warning - {provider.value.title()}"
            message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
        else:
            alert_type = "usage_warning"
            severity = "info"
            title = f"API Usage Notice - {provider.value.title()}"
            message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."

        alert = UsageAlert(
            user_id=user_id,
            alert_type=alert_type,
            threshold_percentage=threshold,
            provider=provider,
            title=title,
            message=message,
            severity=severity,
            billing_period=billing_period
        )

        self.db.add(alert)
        logger.info(f"Created usage alert for {user_id}: {title}")
    def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
        """Get comprehensive usage statistics for a user."""

        if not billing_period:
            billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

        # Get the usage summary
        summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()

        # Get user limits
        limits = self.pricing_service.get_user_limits(user_id)

        # Get recent unread alerts
        alerts = self.db.query(UsageAlert).filter(
            UsageAlert.user_id == user_id,
            UsageAlert.billing_period == billing_period,
            UsageAlert.is_read == False
        ).order_by(UsageAlert.created_at.desc()).limit(10).all()

        if not summary:
            # No usage this period - return the complete structure with zeros
            provider_breakdown = {}
            usage_percentages = {}

            for provider in APIProvider:
                provider_name = provider.value
                provider_breakdown[provider_name] = {
                    'calls': 0,
                    'tokens': 0,
                    'cost': 0.0
                }
                usage_percentages[f"{provider_name}_calls"] = 0

            usage_percentages['cost'] = 0

            return {
                'billing_period': billing_period,
                'usage_status': 'active',
                'total_calls': 0,
                'total_tokens': 0,
                'total_cost': 0.0,
                'avg_response_time': 0.0,
                'error_rate': 0.0,
                'last_updated': datetime.now().isoformat(),
                'limits': limits,
                'provider_breakdown': provider_breakdown,
                'alerts': [],
                'usage_percentages': usage_percentages
            }

        # Provider breakdown - calculate costs first, then use them for percentages.
        # Only Gemini and HuggingFace are included (HuggingFace is stored under the MISTRAL enum).
        provider_breakdown = {}

        # Gemini
        gemini_calls = getattr(summary, "gemini_calls", 0) or 0
        gemini_tokens = getattr(summary, "gemini_tokens", 0) or 0
        gemini_cost = getattr(summary, "gemini_cost", 0.0) or 0.0

        # If the Gemini cost is 0 but there are calls, recalculate it from the usage logs
        if gemini_calls > 0 and gemini_cost == 0.0:
            gemini_logs = self.db.query(APIUsageLog).filter(
                APIUsageLog.user_id == user_id,
                APIUsageLog.provider == APIProvider.GEMINI,
                APIUsageLog.billing_period == billing_period
            ).all()
            if gemini_logs:
                gemini_cost = sum(float(log.cost_total or 0.0) for log in gemini_logs)
                logger.info(f"[UsageStats] Calculated gemini cost from {len(gemini_logs)} logs: ${gemini_cost:.6f}")

        provider_breakdown['gemini'] = {
            'calls': gemini_calls,
            'tokens': gemini_tokens,
            'cost': gemini_cost
        }

        # HuggingFace (stored as MISTRAL in the database)
        mistral_calls = getattr(summary, "mistral_calls", 0) or 0
        mistral_tokens = getattr(summary, "mistral_tokens", 0) or 0
        mistral_cost = getattr(summary, "mistral_cost", 0.0) or 0.0

        # If the HuggingFace cost is 0 but there are calls, recalculate it from the usage logs
        if mistral_calls > 0 and mistral_cost == 0.0:
            mistral_logs = self.db.query(APIUsageLog).filter(
                APIUsageLog.user_id == user_id,
                APIUsageLog.provider == APIProvider.MISTRAL,
                APIUsageLog.billing_period == billing_period
            ).all()
            if mistral_logs:
                mistral_cost = sum(float(log.cost_total or 0.0) for log in mistral_logs)
                logger.info(f"[UsageStats] Calculated mistral (HuggingFace) cost from {len(mistral_logs)} logs: ${mistral_cost:.6f}")

        provider_breakdown['huggingface'] = {
            'calls': mistral_calls,
            'tokens': mistral_tokens,
            'cost': mistral_cost
        }

        # Prefer the summary's total cost; fall back to the cost recalculated from logs
        calculated_total_cost = gemini_cost + mistral_cost
        summary_total_cost = summary.total_cost or 0.0
        final_total_cost = summary_total_cost if summary_total_cost > 0 else calculated_total_cost

        # If costs were recalculated from logs, persist them so future requests skip the recalculation
        if calculated_total_cost > 0 and summary_total_cost == 0.0:
            logger.info(f"[UsageStats] Updating summary costs: total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}")
            summary.total_cost = final_total_cost
            summary.gemini_cost = gemini_cost
            summary.mistral_cost = mistral_cost
            try:
                self.db.commit()
            except Exception as e:
                logger.error(f"[UsageStats] Error updating summary costs: {e}")
                self.db.rollback()

        # Calculate usage percentages - only for Gemini and HuggingFace,
        # using the recalculated costs for accuracy
        usage_percentages = {}
        if limits:
            # Gemini
            gemini_call_limit = limits['limits'].get("gemini_calls", 0) or 0
            if gemini_call_limit > 0:
                usage_percentages['gemini_calls'] = (gemini_calls / gemini_call_limit) * 100
            else:
                usage_percentages['gemini_calls'] = 0

            # HuggingFace (stored as mistral in the database)
            mistral_call_limit = limits['limits'].get("mistral_calls", 0) or 0
            if mistral_call_limit > 0:
                usage_percentages['mistral_calls'] = (mistral_calls / mistral_call_limit) * 100
            else:
                usage_percentages['mistral_calls'] = 0

            # Cost usage percentage - use final_total_cost (recalculated from logs if needed)
            cost_limit = limits['limits'].get('monthly_cost', 0) or 0
            if cost_limit > 0:
                usage_percentages['cost'] = (final_total_cost / cost_limit) * 100
            else:
                usage_percentages['cost'] = 0

        return {
            'billing_period': billing_period,
            'usage_status': summary.usage_status.value if hasattr(summary.usage_status, 'value') else str(summary.usage_status),
            'total_calls': summary.total_calls or 0,
            'total_tokens': summary.total_tokens or 0,
            'total_cost': final_total_cost,
            'avg_response_time': summary.avg_response_time or 0.0,
            'error_rate': summary.error_rate or 0.0,
            'limits': limits,
            'provider_breakdown': provider_breakdown,
            'alerts': [
                {
                    'id': alert.id,
                    'type': alert.alert_type,
                    'title': alert.title,
                    'message': alert.message,
                    'severity': alert.severity,
                    'created_at': alert.created_at.isoformat()
                }
                for alert in alerts
            ],
            'usage_percentages': usage_percentages,
            'last_updated': summary.updated_at.isoformat()
        }
    def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
        """Get usage trends over time."""

        # Build the list of billing periods (approximate: 30 days per month)
        end_date = datetime.now()
        periods = []
        for i in range(months):
            period_date = end_date - timedelta(days=30 * i)
            periods.append(period_date.strftime("%Y-%m"))

        periods.reverse()  # Oldest first

        # Get usage summaries for these periods
        summaries = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period.in_(periods)
        ).order_by(UsageSummary.billing_period).all()

        # Build the trends data
        trends = {
            'periods': periods,
            'total_calls': [],
            'total_cost': [],
            'total_tokens': [],
            'provider_trends': {}
        }

        summary_dict = {s.billing_period: s for s in summaries}

        for period in periods:
            summary = summary_dict.get(period)

            if summary:
                trends['total_calls'].append(summary.total_calls or 0)
                trends['total_cost'].append(summary.total_cost or 0.0)
                trends['total_tokens'].append(summary.total_tokens or 0)

                # Provider-specific trends
                for provider in APIProvider:
                    provider_name = provider.value
                    if provider_name not in trends['provider_trends']:
                        trends['provider_trends'][provider_name] = {
                            'calls': [],
                            'cost': [],
                            'tokens': []
                        }

                    trends['provider_trends'][provider_name]['calls'].append(
                        getattr(summary, f"{provider_name}_calls", 0) or 0
                    )
                    trends['provider_trends'][provider_name]['cost'].append(
                        getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
                    )
                    trends['provider_trends'][provider_name]['tokens'].append(
                        getattr(summary, f"{provider_name}_tokens", 0) or 0
                    )
            else:
                # No data for this period
                trends['total_calls'].append(0)
                trends['total_cost'].append(0.0)
                trends['total_tokens'].append(0)

                for provider in APIProvider:
                    provider_name = provider.value
                    if provider_name not in trends['provider_trends']:
                        trends['provider_trends'][provider_name] = {
                            'calls': [],
                            'cost': [],
                            'tokens': []
                        }

                    trends['provider_trends'][provider_name]['calls'].append(0)
                    trends['provider_trends'][provider_name]['cost'].append(0.0)
                    trends['provider_trends'][provider_name]['tokens'].append(0)

        return trends
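Note that `timedelta(days=30 * i)` only approximates a calendar month, so near month boundaries the generated `YYYY-MM` labels can skip or repeat a period. A calendar-accurate alternative can be sketched as follows (a hypothetical helper, not part of the service):

```python
from datetime import date

def last_n_billing_periods(n: int, today: date) -> list:
    """Return the last n calendar months as 'YYYY-MM' strings, oldest first."""
    year, month = today.year, today.month
    periods = []
    for _ in range(n):
        periods.append(f"{year:04d}-{month:02d}")
        month -= 1
        if month == 0:
            month, year = 12, year - 1
    return list(reversed(periods))
```

Walking the (year, month) pair backwards avoids day arithmetic entirely, so every month appears exactly once regardless of its length.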
    async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
                                   tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
        """Enforce usage limits before making an API call."""
        # Check the short-lived cache first (30s TTL)
        cache_key = f"{user_id}:{provider.value}"
        now = datetime.utcnow()
        cached = self._enforce_cache.get(cache_key)
        if cached and cached.get('expires_at') and cached['expires_at'] > now:
            return tuple(cached['result'])  # type: ignore

        result = self.pricing_service.check_usage_limits(
            user_id=user_id,
            provider=provider,
            tokens_requested=tokens_requested
        )
        self._enforce_cache[cache_key] = {
            'result': result,
            'expires_at': now + timedelta(seconds=30)
        }
        return result
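The enforcement cache above is a plain dict of `{'result': ..., 'expires_at': datetime}` entries with an absolute expiry timestamp. The store/lookup pattern in isolation (helper names are illustrative):

```python
from datetime import datetime, timedelta

def cache_store(cache: dict, key: str, result, ttl_seconds: int = 30) -> None:
    """Store result under key with an absolute expiry timestamp."""
    cache[key] = {
        'result': result,
        'expires_at': datetime.utcnow() + timedelta(seconds=ttl_seconds),
    }

def cache_lookup(cache: dict, key: str):
    """Return the cached result if the entry exists and has not expired."""
    entry = cache.get(key)
    if entry and entry['expires_at'] > datetime.utcnow():
        return entry['result']
    return None
```

Expired entries are simply ignored rather than evicted; for a per-user, per-provider key space this keeps the cache small enough that lazy expiry is acceptable.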
    async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
        """Reset usage status and counters for the current billing period (after plan renewal or change)."""
        try:
            billing_period = datetime.now().strftime("%Y-%m")
            summary = self.db.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == billing_period
            ).first()

            if not summary:
                # Nothing to reset
                return {"reset": False, "reason": "no_summary"}

            # CRITICAL: Reset ALL usage counters to 0 so the user gets fresh limits
            # with the new/renewed plan, and clear any LIMIT_REACHED status
            summary.usage_status = UsageStatus.ACTIVE

            # Reset LLM provider call counters
            summary.gemini_calls = 0
            summary.openai_calls = 0
            summary.anthropic_calls = 0
            summary.mistral_calls = 0

            # Reset LLM provider token counters
            summary.gemini_tokens = 0
            summary.openai_tokens = 0
            summary.anthropic_tokens = 0
            summary.mistral_tokens = 0

            # Reset search/research provider counters
            summary.tavily_calls = 0
            summary.serper_calls = 0
            summary.metaphor_calls = 0
            summary.firecrawl_calls = 0

            # Reset image generation counters
            summary.stability_calls = 0

            # Reset video generation counters
            summary.video_calls = 0

            # Reset image editing counters
            summary.image_edit_calls = 0

            # Reset cost counters
            summary.gemini_cost = 0.0
            summary.openai_cost = 0.0
            summary.anthropic_cost = 0.0
            summary.mistral_cost = 0.0
            summary.tavily_cost = 0.0
            summary.serper_cost = 0.0
            summary.metaphor_cost = 0.0
            summary.firecrawl_cost = 0.0
            summary.stability_cost = 0.0
            summary.exa_cost = 0.0
            summary.video_cost = 0.0
            summary.image_edit_cost = 0.0

            # Reset totals
            summary.total_calls = 0
            summary.total_tokens = 0
            summary.total_cost = 0.0

            summary.updated_at = datetime.utcnow()
            self.db.commit()

            logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
            return {"reset": True, "counters_reset": True}
        except Exception as e:
            self.db.rollback()
            logger.error(f"Error resetting usage status: {e}")
            return {"reset": False, "error": str(e)}