Research component integration, Copilotkit implementation, SEO copilotkit implementation, Wix SEO metadata complete, Wix SEO metadata review
This commit is contained in:
725
backend/services/subscription/limit_validation.py
Normal file
725
backend/services/subscription/limit_validation.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""
|
||||
Limit Validation Module
|
||||
Handles subscription limit checking and validation logic.
|
||||
Extracted from pricing_service.py for better modularity.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List, Tuple, TYPE_CHECKING
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import text
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import (
|
||||
UserSubscription, UsageSummary, SubscriptionPlan,
|
||||
APIProvider, SubscriptionTier
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pricing_service import PricingService
|
||||
|
||||
|
||||
class LimitValidator:
|
||||
"""Validates subscription limits for API usage."""
|
||||
|
||||
def __init__(self, pricing_service: 'PricingService'):
|
||||
"""
|
||||
Initialize limit validator with reference to PricingService.
|
||||
|
||||
Args:
|
||||
pricing_service: Instance of PricingService to access helper methods and cache
|
||||
"""
|
||||
self.pricing_service = pricing_service
|
||||
self.db = pricing_service.db
|
||||
|
||||
def check_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
provider: APIProvider enum (may be MISTRAL for HuggingFace)
|
||||
tokens_requested: Estimated tokens for the request
|
||||
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
|
||||
|
||||
Returns:
|
||||
(can_proceed, error_message, usage_info)
|
||||
"""
|
||||
try:
|
||||
# Use actual_provider_name if provided, otherwise use enum value
|
||||
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
|
||||
display_provider_name = actual_provider_name or provider.value
|
||||
|
||||
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
|
||||
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self.pricing_service._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user subscription first to check expiration
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
|
||||
|
||||
# Check subscription expiration (STRICT: deny if expired)
|
||||
if subscription:
|
||||
if subscription.current_period_end < now:
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
|
||||
# Subscription expired - check if auto_renew is enabled
|
||||
if not getattr(subscription, 'auto_renew', False):
|
||||
# Expired and no auto-renew - deny access
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
|
||||
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
|
||||
'expired': True,
|
||||
'period_end': subscription.current_period_end.isoformat()
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
# Try to auto-renew
|
||||
if not self.pricing_service._ensure_subscription_current(subscription):
|
||||
# Auto-renew failed - deny access
|
||||
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
|
||||
'expired': True,
|
||||
'auto_renew_failed': True
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Get user limits with error handling (STRICT: fail on errors)
|
||||
# CRITICAL: Expire SQLAlchemy objects to ensure we get fresh plan data after renewal
|
||||
try:
|
||||
# Force expire subscription and plan objects to avoid stale cache
|
||||
if subscription and subscription.plan_id:
|
||||
plan_obj = self.db.query(SubscriptionPlan).filter(SubscriptionPlan.id == subscription.plan_id).first()
|
||||
if plan_obj:
|
||||
self.db.expire(plan_obj)
|
||||
logger.debug(f"[Subscription Check] Expired plan object to ensure fresh limits after renewal")
|
||||
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
if limits:
|
||||
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
|
||||
# Log token limits for debugging
|
||||
token_limits = limits.get('limits', {})
|
||||
logger.debug(f"[Subscription Check] Token limits: gemini={token_limits.get('gemini_tokens')}, mistral={token_limits.get('mistral_tokens')}, openai={token_limits.get('openai_tokens')}, anthropic={token_limits.get('anthropic_tokens')}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
|
||||
except Exception as e:
|
||||
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
|
||||
# STRICT: Fail closed - deny request if we can't check limits
|
||||
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
|
||||
|
||||
if not limits:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
|
||||
limits = self.pricing_service._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
# No subscription and no free tier - deny access
|
||||
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
# CRITICAL: Use fresh queries to avoid SQLAlchemy cache after renewal
|
||||
try:
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Expire all objects to force fresh read from DB (critical after renewal)
|
||||
self.db.expire_all()
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
# Map result to UsageSummary object
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
except Exception as sql_error:
|
||||
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
|
||||
# Fallback to ORM query
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage summary for {user_id}: {e}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to retrieve usage summary: {str(e)}", {}
|
||||
|
||||
# Check call limits with error handling
|
||||
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
# Use display_provider_name for error messages, but provider.value for DB queries
|
||||
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
|
||||
|
||||
# For LLM text generation providers, check against unified total_calls limit
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
if is_llm_provider:
|
||||
# Use unified AI text generation limit (total_calls across all LLM providers)
|
||||
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
|
||||
|
||||
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
|
||||
if ai_text_gen_limit == 0:
|
||||
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
|
||||
current_total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
|
||||
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
|
||||
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
|
||||
'provider': display_provider_name, # Use display name for consistency
|
||||
'usage_info': {
|
||||
'provider': display_provider_name, # Use display name for user-facing info
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'type': 'ai_text_generation',
|
||||
'breakdown': {
|
||||
'gemini': usage.gemini_calls or 0,
|
||||
'openai': usage.openai_calls or 0,
|
||||
'anthropic': usage.anthropic_calls or 0,
|
||||
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
|
||||
}
|
||||
}
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
|
||||
else:
|
||||
# For non-LLM providers, check provider-specific limit
|
||||
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if call_limit > 0 and current_calls >= call_limit:
|
||||
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
|
||||
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
|
||||
'current_calls': current_calls,
|
||||
'limit': call_limit,
|
||||
'usage_percentage': 100.0,
|
||||
'provider': display_provider_name # Use display name for consistency
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking call limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check token limits for LLM providers with error handling
|
||||
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
|
||||
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
|
||||
'provider': display_provider_name, # Use display name in error details
|
||||
'usage_info': {
|
||||
'provider': display_provider_name,
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'type': 'tokens'
|
||||
}
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking token limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check cost limits with error handling
|
||||
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if cost_limit > 0 and usage.total_cost >= cost_limit:
|
||||
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
|
||||
'current_cost': usage.total_cost,
|
||||
'limit': cost_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking cost limits: {e}")
|
||||
# Continue to success case
|
||||
|
||||
# Calculate usage percentages for warnings
|
||||
try:
|
||||
# Determine which call variables to use based on provider type
|
||||
if is_llm_provider:
|
||||
# Use unified LLM call tracking
|
||||
current_call_count = current_total_llm_calls
|
||||
call_limit_value = ai_text_gen_limit
|
||||
else:
|
||||
# Use provider-specific call tracking
|
||||
current_call_count = current_calls
|
||||
call_limit_value = call_limit
|
||||
|
||||
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
|
||||
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
|
||||
result = (True, "Within limits", {
|
||||
'current_calls': current_call_count,
|
||||
'call_limit': call_limit_value,
|
||||
'call_usage_percentage': call_usage_pct,
|
||||
'current_cost': usage.total_cost,
|
||||
'cost_limit': cost_limit,
|
||||
'cost_usage_percentage': cost_usage_pct
|
||||
})
|
||||
self.pricing_service._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating usage percentages: {e}")
|
||||
# Return basic success
|
||||
return True, "Within limits", {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
|
||||
# STRICT: Fail closed - deny requests if subscription system fails
|
||||
return False, f"Subscription check error: {str(e)}", {}
|
||||
|
||||
def check_comprehensive_limits(
|
||||
self,
|
||||
user_id: str,
|
||||
operations: List[Dict[str, Any]]
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
|
||||
|
||||
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
|
||||
before making the first external API call.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
operations: List of operations to validate, each with:
|
||||
- 'provider': APIProvider enum
|
||||
- 'tokens_requested': int (estimated tokens for LLM calls, 0 for non-LLM)
|
||||
- 'actual_provider_name': Optional[str] (e.g., "huggingface" when provider is MISTRAL)
|
||||
- 'operation_type': str (e.g., "google_grounding", "llm_call", "image_generation")
|
||||
|
||||
Returns:
|
||||
(can_proceed, error_message, error_details)
|
||||
If can_proceed is False, error_message explains which limit would be exceeded
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
|
||||
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
|
||||
|
||||
# Get current usage and limits once
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")
|
||||
|
||||
# Explicitly expire any cached objects and refresh from DB to ensure fresh data
|
||||
self.db.expire_all()
|
||||
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
# CRITICAL: Explicitly refresh from database to get latest values (clears SQLAlchemy cache)
|
||||
if usage:
|
||||
self.db.refresh(usage)
|
||||
|
||||
# Log what we actually read from database
|
||||
if usage:
|
||||
logger.info(f"[Pre-flight Check] 📊 Usage Summary from DB (Period: {current_period}):")
|
||||
logger.info(f" ├─ Gemini: {usage.gemini_tokens or 0} tokens / {usage.gemini_calls or 0} calls")
|
||||
logger.info(f" ├─ Mistral/HF: {usage.mistral_tokens or 0} tokens / {usage.mistral_calls or 0} calls")
|
||||
logger.info(f" ├─ Total Tokens: {usage.total_tokens or 0}")
|
||||
logger.info(f" └─ Usage Status: {usage.usage_status.value if usage.usage_status else 'N/A'}")
|
||||
else:
|
||||
logger.info(f"[Pre-flight Check] 📊 No usage summary found for period {current_period} (will create new)")
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
|
||||
# Get user limits
|
||||
limits_dict = self.pricing_service.get_user_limits(user_id)
|
||||
if not limits_dict:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
limits_dict = self.pricing_service._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
limits = limits_dict.get('limits', {})
|
||||
|
||||
# Track cumulative usage across all operations
|
||||
total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
total_llm_tokens = {}
|
||||
total_images = usage.stability_calls or 0
|
||||
|
||||
# Log current usage summary
|
||||
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
|
||||
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
|
||||
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
|
||||
logger.info(f" └─ Image Calls: {total_images}")
|
||||
|
||||
# Validate each operation
|
||||
for op_idx, operation in enumerate(operations):
|
||||
provider = operation.get('provider')
|
||||
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
|
||||
tokens_requested = operation.get('tokens_requested', 0)
|
||||
actual_provider_name = operation.get('actual_provider_name')
|
||||
operation_type = operation.get('operation_type', 'unknown')
|
||||
|
||||
display_provider_name = actual_provider_name or provider_name
|
||||
|
||||
logger.error(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
|
||||
logger.error(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
|
||||
logger.error(f" ├─ Operation Index: {op_idx}")
|
||||
logger.error(f" └─ Estimated Tokens Requested: {tokens_requested}")
|
||||
|
||||
# Check if this is an LLM provider
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
# Check unified AI text generation limit for LLM providers
|
||||
if is_llm_provider:
|
||||
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
|
||||
if ai_text_gen_limit == 0:
|
||||
# Fallback to provider-specific limit
|
||||
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Count this operation as an LLM call
|
||||
projected_total_llm_calls = total_llm_calls + 1
|
||||
|
||||
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
|
||||
error_info = {
|
||||
'current_calls': total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# Check token limits for this provider
|
||||
# CRITICAL: Always query fresh from DB for each operation to avoid SQLAlchemy cache issues
|
||||
# This ensures we get the latest values after subscription renewal, even for cumulative tracking
|
||||
provider_tokens_key = f"{provider_name}_tokens"
|
||||
|
||||
# Try to get fresh value from DB with comprehensive error handling
|
||||
base_current_tokens = 0
|
||||
query_succeeded = False
|
||||
|
||||
try:
|
||||
# Validate column name is safe (only allow known provider token columns)
|
||||
valid_token_columns = ['gemini_tokens', 'openai_tokens', 'anthropic_tokens', 'mistral_tokens']
|
||||
|
||||
if provider_tokens_key not in valid_token_columns:
|
||||
logger.error(f" └─ Invalid provider tokens key: {provider_tokens_key}")
|
||||
query_succeeded = True # Treat as success with 0 value
|
||||
else:
|
||||
# Method 1: Try raw SQL query to completely bypass ORM cache
|
||||
try:
|
||||
logger.debug(f" └─ Attempting raw SQL query for {provider_tokens_key}")
|
||||
sql_query = text(f"""
|
||||
SELECT {provider_tokens_key}
|
||||
FROM usage_summaries
|
||||
WHERE user_id = :user_id
|
||||
AND billing_period = :period
|
||||
LIMIT 1
|
||||
""")
|
||||
|
||||
logger.debug(f" └─ SQL: SELECT {provider_tokens_key} FROM usage_summaries WHERE user_id={user_id} AND billing_period={current_period}")
|
||||
|
||||
result = self.db.execute(sql_query, {
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
}).first()
|
||||
|
||||
if result:
|
||||
base_current_tokens = result[0] if result[0] is not None else 0
|
||||
logger.error(f"[Pre-flight Check] ✅ Raw SQL query returned result: {result[0]} -> {base_current_tokens}")
|
||||
else:
|
||||
base_current_tokens = 0
|
||||
logger.error(f"[Pre-flight Check] ⚠️ Raw SQL query returned None (no rows found)")
|
||||
|
||||
query_succeeded = True
|
||||
logger.error(f"[Pre-flight Check] ✅ Raw SQL query succeeded for {provider_tokens_key}: {base_current_tokens}")
|
||||
|
||||
except Exception as sql_error:
|
||||
logger.error(f" └─ Raw SQL query failed for {provider_tokens_key}: {type(sql_error).__name__}: {sql_error}", exc_info=True)
|
||||
query_succeeded = False # Will try ORM fallback
|
||||
|
||||
# Method 2: Fallback to fresh ORM query if raw SQL fails
|
||||
if not query_succeeded:
|
||||
try:
|
||||
# Expire all cached objects and do fresh query
|
||||
self.db.expire_all()
|
||||
fresh_usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if fresh_usage:
|
||||
# Explicitly refresh to get latest from DB
|
||||
self.db.refresh(fresh_usage)
|
||||
base_current_tokens = getattr(fresh_usage, provider_tokens_key, 0) or 0
|
||||
else:
|
||||
base_current_tokens = 0
|
||||
|
||||
query_succeeded = True
|
||||
logger.info(f"[Pre-flight Check] ✅ ORM fallback query succeeded for {provider_tokens_key}: {base_current_tokens}")
|
||||
|
||||
except Exception as orm_error:
|
||||
logger.error(f" └─ ORM query also failed: {orm_error}", exc_info=True)
|
||||
query_succeeded = False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" └─ Unexpected error getting tokens from DB for {provider_tokens_key}: {e}", exc_info=True)
|
||||
base_current_tokens = 0 # Fail safe - assume 0 if we can't query
|
||||
|
||||
if not query_succeeded:
|
||||
logger.warning(f" └─ Both query methods failed, using 0 as fallback")
|
||||
|
||||
# CRITICAL LOG: Always log what we got from DB - this helps debug renewal issues
|
||||
# Use ERROR level to ensure it shows even if INFO is filtered
|
||||
logger.error(f"[Pre-flight Check] 🔍 Fresh DB Query for {display_provider_name}:")
|
||||
logger.error(f" ├─ Column: {provider_tokens_key}")
|
||||
logger.error(f" ├─ Billing Period: {current_period}")
|
||||
logger.error(f" ├─ User ID: {user_id}")
|
||||
logger.error(f" ├─ Method: {'Raw SQL' if query_succeeded and base_current_tokens >= 0 else 'ORM' if query_succeeded else 'Failed - using 0'}")
|
||||
logger.error(f" └─ Value from DB: {base_current_tokens}")
|
||||
|
||||
# Add any projected tokens from previous operations in this validation run
|
||||
# Note: total_llm_tokens tracks ONLY projected tokens from this run, not base DB value
|
||||
projected_from_previous = total_llm_tokens.get(provider_tokens_key, 0)
|
||||
|
||||
# Current tokens = base from DB + projected from previous operations in this run
|
||||
current_provider_tokens = base_current_tokens + projected_from_previous
|
||||
|
||||
# Use ERROR level to ensure visibility
|
||||
logger.error(f"[Pre-flight Check] 📊 Token Calculation for {display_provider_name}:")
|
||||
logger.error(f" ├─ Base from DB (fresh query): {base_current_tokens}")
|
||||
logger.error(f" ├─ Projected from previous ops in this run: {projected_from_previous}")
|
||||
logger.error(f" └─ Total current tokens (base + projected): {current_provider_tokens}")
|
||||
|
||||
# Also check the initial usage object to see if it's being used incorrectly
|
||||
if usage and hasattr(usage, provider_tokens_key):
|
||||
initial_usage_value = getattr(usage, provider_tokens_key, 0) or 0
|
||||
logger.error(f" ⚠️ Initial usage object value: {initial_usage_value} (this should NOT be used for fresh query)")
|
||||
|
||||
token_limit = limits.get(provider_tokens_key, 0) or 0
|
||||
|
||||
if token_limit > 0 and tokens_requested > 0:
|
||||
projected_tokens = current_provider_tokens + tokens_requested
|
||||
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
|
||||
|
||||
if projected_tokens > token_limit:
|
||||
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
|
||||
error_info = {
|
||||
'current_tokens': current_provider_tokens,
|
||||
'base_tokens_from_db': base_current_tokens,
|
||||
'projected_from_previous_ops': projected_from_previous,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
# Make error message clearer: show actual DB usage vs projected
|
||||
if projected_from_previous > 0:
|
||||
error_msg = (
|
||||
f"Token limit exceeded for {display_provider_name} "
|
||||
f"({operation_type}). "
|
||||
f"Base usage: {base_current_tokens}/{token_limit}, "
|
||||
f"After previous operations in this workflow: {current_provider_tokens}/{token_limit}, "
|
||||
f"This operation would add: {tokens_requested}, "
|
||||
f"Total would be: {projected_tokens} (exceeds by {projected_tokens - token_limit} tokens)"
|
||||
)
|
||||
else:
|
||||
error_msg = (
|
||||
f"Token limit exceeded for {display_provider_name} "
|
||||
f"({operation_type}). "
|
||||
f"Current: {current_provider_tokens}/{token_limit}, "
|
||||
f"Requested: {tokens_requested}, "
|
||||
f"Would exceed by: {projected_tokens - token_limit} tokens "
|
||||
f"({usage_percentage:.1f}% of limit)"
|
||||
)
|
||||
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
|
||||
return False, error_msg, {
|
||||
'error_type': 'token_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
else:
|
||||
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
|
||||
|
||||
# Update cumulative counts for next operation
|
||||
total_llm_calls = projected_total_llm_calls
|
||||
# Update cumulative projected tokens from this validation run
|
||||
# This represents only projected tokens from previous operations in this run
|
||||
# Base DB value is always queried fresh, so we only track the projection delta
|
||||
old_projected = total_llm_tokens.get(provider_tokens_key, 0)
|
||||
if tokens_requested > 0:
|
||||
# Add this operation's tokens to cumulative projected tokens
|
||||
total_llm_tokens[provider_tokens_key] = projected_from_previous + tokens_requested
|
||||
logger.error(f"[Pre-flight Check] 📝 Updated cumulative projected tokens for {display_provider_name}:")
|
||||
logger.error(f" ├─ Previous projected: {projected_from_previous}")
|
||||
logger.error(f" ├─ This operation requested: {tokens_requested}")
|
||||
logger.error(f" ├─ New cumulative projected: {total_llm_tokens[provider_tokens_key]}")
|
||||
logger.error(f" └─ Old value in dict was: {old_projected}")
|
||||
else:
|
||||
# No tokens requested, keep existing projected tokens (or 0 if first operation)
|
||||
total_llm_tokens[provider_tokens_key] = projected_from_previous
|
||||
logger.error(f"[Pre-flight Check] 📝 No tokens requested, keeping projected at: {projected_from_previous}")
|
||||
|
||||
# Check image generation limits
|
||||
elif provider == APIProvider.STABILITY:
|
||||
image_limit = limits.get('stability_calls', 0) or 0
|
||||
projected_images = total_images + 1
|
||||
|
||||
if image_limit > 0 and projected_images > image_limit:
|
||||
error_info = {
|
||||
'current_images': total_images,
|
||||
'limit': image_limit,
|
||||
'provider': 'stability',
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
|
||||
'error_type': 'image_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
total_images = projected_images
|
||||
|
||||
# Check other provider-specific limits
|
||||
else:
|
||||
provider_calls_key = f"{provider_name}_calls"
|
||||
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
|
||||
call_limit = limits.get(provider_calls_key, 0) or 0
|
||||
|
||||
if call_limit > 0:
|
||||
projected_calls = current_provider_calls + 1
|
||||
if projected_calls > call_limit:
|
||||
error_info = {
|
||||
'current_calls': current_provider_calls,
|
||||
'limit': call_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# All checks passed
|
||||
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
|
||||
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
|
||||
return True, None, None
|
||||
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
error_message = str(e)
|
||||
logger.error(f"[Pre-flight Check] ❌ Error during comprehensive limit check: {error_type}: {error_message}", exc_info=True)
|
||||
logger.error(f"[Pre-flight Check] ❌ User: {user_id}, Operations count: {len(operations) if operations else 0}")
|
||||
return False, f"Failed to validate limits: {error_type}: {error_message}", {}
|
||||
|
||||
@@ -44,15 +44,17 @@ def validate_research_operations(
|
||||
llm_provider_name = "gemini"
|
||||
|
||||
# Estimate tokens for each operation in research workflow
|
||||
# Google Grounding call: ~2000 tokens (input + output)
|
||||
# Google Grounding call: ~1200 tokens (input: ~500 tokens, output: ~700 tokens for research results)
|
||||
# Keyword analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
|
||||
# Competitor analyzer: ~1000 tokens (input: 3000 chars research, output: structured JSON)
|
||||
# Content angle generator: ~1000 tokens (input: 3000 chars research, output: list of angles)
|
||||
# Note: These are conservative estimates. Actual usage may be lower, but we use these for pre-flight validation
|
||||
# to prevent wasteful API calls if the workflow would exceed limits.
|
||||
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.GEMINI, # Google Grounding uses Gemini
|
||||
'tokens_requested': 2000,
|
||||
'tokens_requested': 1200, # Reduced from 2000 to more realistic estimate
|
||||
'actual_provider_name': 'gemini',
|
||||
'operation_type': 'google_grounding'
|
||||
},
|
||||
@@ -126,6 +128,120 @@ def validate_research_operations(
|
||||
)
|
||||
|
||||
|
||||
def validate_exa_research_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str,
|
||||
gpt_provider: str = "google"
|
||||
) -> None:
|
||||
"""
|
||||
Validate all operations for an Exa research workflow before making ANY API calls.
|
||||
|
||||
This prevents wasteful external API calls (like Exa search) if subsequent
|
||||
LLM calls would fail due to token or call limits.
|
||||
|
||||
Args:
|
||||
pricing_service: PricingService instance
|
||||
user_id: User ID for subscription checking
|
||||
gpt_provider: GPT provider from env var (defaults to "google")
|
||||
|
||||
Returns:
|
||||
None
|
||||
If validation fails, raises HTTPException with 429 status
|
||||
"""
|
||||
try:
|
||||
# Determine actual provider for LLM calls based on GPT_PROVIDER env var
|
||||
gpt_provider_lower = gpt_provider.lower()
|
||||
if gpt_provider_lower == "huggingface":
|
||||
llm_provider_enum = APIProvider.MISTRAL # Maps to HuggingFace
|
||||
llm_provider_name = "huggingface"
|
||||
else:
|
||||
llm_provider_enum = APIProvider.GEMINI
|
||||
llm_provider_name = "gemini"
|
||||
|
||||
# Estimate tokens for each operation in Exa research workflow
|
||||
# Exa Search call: 1 Exa API call (not token-based)
|
||||
# Keyword analyzer: ~1000 tokens (input: research results, output: structured JSON)
|
||||
# Competitor analyzer: ~1000 tokens (input: research results, output: structured JSON)
|
||||
# Content angle generator: ~1000 tokens (input: research results, output: list of angles)
|
||||
# Note: These are conservative estimates for pre-flight validation
|
||||
|
||||
operations_to_validate = [
|
||||
{
|
||||
'provider': APIProvider.EXA, # Exa API call
|
||||
'tokens_requested': 0,
|
||||
'actual_provider_name': 'exa',
|
||||
'operation_type': 'exa_neural_search'
|
||||
},
|
||||
{
|
||||
'provider': llm_provider_enum,
|
||||
'tokens_requested': 1000,
|
||||
'actual_provider_name': llm_provider_name,
|
||||
'operation_type': 'keyword_analysis'
|
||||
},
|
||||
{
|
||||
'provider': llm_provider_enum,
|
||||
'tokens_requested': 1000,
|
||||
'actual_provider_name': llm_provider_name,
|
||||
'operation_type': 'competitor_analysis'
|
||||
},
|
||||
{
|
||||
'provider': llm_provider_enum,
|
||||
'tokens_requested': 1000,
|
||||
'actual_provider_name': llm_provider_name,
|
||||
'operation_type': 'content_angle_generation'
|
||||
}
|
||||
]
|
||||
|
||||
logger.info(f"[Pre-flight Validator] 🚀 Starting Exa Research Workflow Validation")
|
||||
logger.info(f" ├─ User: {user_id}")
|
||||
logger.info(f" ├─ LLM Provider: {llm_provider_name} (GPT_PROVIDER={gpt_provider})")
|
||||
logger.info(f" └─ Operations to validate: {len(operations_to_validate)}")
|
||||
|
||||
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
|
||||
user_id=user_id,
|
||||
operations=operations_to_validate
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
usage_info = error_details.get('usage_info', {}) if error_details else {}
|
||||
provider = usage_info.get('provider', llm_provider_name) if usage_info else llm_provider_name
|
||||
operation_type = usage_info.get('operation_type', 'unknown')
|
||||
|
||||
logger.error(f"[Pre-flight Validator] ❌ EXA RESEARCH WORKFLOW BLOCKED")
|
||||
logger.error(f" ├─ User: {user_id}")
|
||||
logger.error(f" ├─ Blocked at: {operation_type}")
|
||||
logger.error(f" ├─ Provider: {provider}")
|
||||
logger.error(f" └─ Reason: {message}")
|
||||
|
||||
# Raise HTTPException immediately - frontend gets immediate response, no API calls made
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': message,
|
||||
'message': message,
|
||||
'provider': provider,
|
||||
'usage_info': usage_info if usage_info else error_details
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[Pre-flight Validator] ✅ EXA RESEARCH WORKFLOW APPROVED")
|
||||
logger.info(f" ├─ User: {user_id}")
|
||||
logger.info(f" └─ All {len(operations_to_validate)} operations validated - proceeding with API calls")
|
||||
# Validation passed - no return needed (function raises HTTPException if validation fails)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Validator] Error validating Exa research operations: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
'error': f"Failed to validate operations: {str(e)}",
|
||||
'message': f"Failed to validate operations: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_image_generation_operations(
|
||||
pricing_service: PricingService,
|
||||
user_id: str
|
||||
|
||||
@@ -258,6 +258,12 @@ class PricingService:
|
||||
"model_name": "stable-diffusion",
|
||||
"cost_per_image": 0.04, # $0.04 per image
|
||||
"description": "Stability AI Image Generation"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.EXA,
|
||||
"model_name": "exa-search",
|
||||
"cost_per_request": 0.005, # $0.005 per search (1-25 results)
|
||||
"description": "Exa Neural Search API"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -296,6 +302,7 @@ class PricingService:
|
||||
"metaphor_calls_limit": 10,
|
||||
"firecrawl_calls_limit": 10,
|
||||
"stability_calls_limit": 5,
|
||||
"exa_calls_limit": 100,
|
||||
"gemini_tokens_limit": 100000,
|
||||
"monthly_cost_limit": 0.0,
|
||||
"features": ["basic_content_generation", "limited_research"],
|
||||
@@ -316,10 +323,11 @@ class PricingService:
|
||||
"metaphor_calls_limit": 100,
|
||||
"firecrawl_calls_limit": 100,
|
||||
"stability_calls_limit": 5,
|
||||
"gemini_tokens_limit": 2000,
|
||||
"openai_tokens_limit": 2000,
|
||||
"anthropic_tokens_limit": 2000,
|
||||
"mistral_tokens_limit": 2000,
|
||||
"exa_calls_limit": 500,
|
||||
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"mistral_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"monthly_cost_limit": 50.0,
|
||||
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
|
||||
"description": "Great for individuals and small teams"
|
||||
@@ -338,6 +346,7 @@ class PricingService:
|
||||
"metaphor_calls_limit": 500,
|
||||
"firecrawl_calls_limit": 500,
|
||||
"stability_calls_limit": 200,
|
||||
"exa_calls_limit": 2000,
|
||||
"gemini_tokens_limit": 5000000,
|
||||
"openai_tokens_limit": 2500000,
|
||||
"anthropic_tokens_limit": 1000000,
|
||||
@@ -360,6 +369,7 @@ class PricingService:
|
||||
"metaphor_calls_limit": 0,
|
||||
"firecrawl_calls_limit": 0,
|
||||
"stability_calls_limit": 0,
|
||||
"exa_calls_limit": 0, # Unlimited
|
||||
"gemini_tokens_limit": 0,
|
||||
"openai_tokens_limit": 0,
|
||||
"anthropic_tokens_limit": 0,
|
||||
@@ -423,11 +433,14 @@ class PricingService:
|
||||
def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get usage limits for a user based on their subscription."""
|
||||
|
||||
# CRITICAL: Expire all objects first to ensure fresh data after renewal
|
||||
self.db.expire_all()
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
|
||||
if not subscription:
|
||||
# Return free tier limits
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
@@ -439,7 +452,23 @@ class PricingService:
|
||||
|
||||
# Ensure current period before returning limits
|
||||
self._ensure_subscription_current(subscription)
|
||||
return self._plan_to_limits_dict(subscription.plan)
|
||||
|
||||
# CRITICAL: Refresh subscription to get latest plan_id, then refresh plan relationship
|
||||
self.db.refresh(subscription)
|
||||
|
||||
# Re-query plan directly to ensure fresh data (bypass relationship cache)
|
||||
plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.id == subscription.plan_id
|
||||
).first()
|
||||
|
||||
if not plan:
|
||||
logger.error(f"Plan not found for subscription plan_id={subscription.plan_id}")
|
||||
return None
|
||||
|
||||
# Refresh plan to ensure fresh limits
|
||||
self.db.refresh(plan)
|
||||
|
||||
return self._plan_to_limits_dict(plan)
|
||||
|
||||
def _ensure_ai_text_gen_column_detection(self) -> None:
|
||||
"""Detect at runtime whether ai_text_generation_calls_limit column exists and cache the result."""
|
||||
@@ -508,290 +537,20 @@ class PricingService:
|
||||
tokens_requested: int = 0, actual_provider_name: Optional[str] = None) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Check if user can make an API call within their limits.
|
||||
|
||||
Delegates to LimitValidator for actual validation logic.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
provider: APIProvider enum (may be MISTRAL for HuggingFace)
|
||||
tokens_requested: Estimated tokens for the request
|
||||
actual_provider_name: Optional actual provider name (e.g., "huggingface" when provider is MISTRAL)
|
||||
"""
|
||||
try:
|
||||
# Use actual_provider_name if provided, otherwise use enum value
|
||||
# This fixes cases where HuggingFace maps to MISTRAL enum but should show as "huggingface" in errors
|
||||
display_provider_name = actual_provider_name or provider.value
|
||||
|
||||
logger.debug(f"[Subscription Check] Starting limit check for user {user_id}, provider {display_provider_name}, tokens {tokens_requested}")
|
||||
|
||||
# Short TTL cache to reduce DB reads under sustained traffic
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self._limits_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
logger.debug(f"[Subscription Check] Using cached result for {user_id}:{provider.value}")
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
# Get user subscription first to check expiration
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.debug(f"[Subscription Check] Found subscription for user {user_id}: plan_id={subscription.plan_id}, period_end={subscription.current_period_end}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No active subscription found for user {user_id}")
|
||||
|
||||
# Check subscription expiration (STRICT: deny if expired)
|
||||
if subscription:
|
||||
if subscription.current_period_end < now:
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}: period_end={subscription.current_period_end}, now={now}")
|
||||
# Subscription expired - check if auto_renew is enabled
|
||||
if not getattr(subscription, 'auto_renew', False):
|
||||
# Expired and no auto-renew - deny access
|
||||
logger.warning(f"[Subscription Check] Subscription expired for user {user_id}, auto_renew=False, denying access")
|
||||
result = (False, "Subscription expired. Please renew your subscription to continue using the service.", {
|
||||
'expired': True,
|
||||
'period_end': subscription.current_period_end.isoformat()
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
# Try to auto-renew
|
||||
if not self._ensure_subscription_current(subscription):
|
||||
# Auto-renew failed - deny access
|
||||
result = (False, "Subscription expired and auto-renewal failed. Please renew manually.", {
|
||||
'expired': True,
|
||||
'auto_renew_failed': True
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
# Get user limits with error handling (STRICT: fail on errors)
|
||||
try:
|
||||
limits = self.get_user_limits(user_id)
|
||||
if limits:
|
||||
logger.debug(f"[Subscription Check] Retrieved limits for user {user_id}: plan={limits.get('plan_name')}, tier={limits.get('tier')}")
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] No limits found for user {user_id}, checking free tier")
|
||||
except Exception as e:
|
||||
logger.error(f"[Subscription Check] Error getting user limits for {user_id}: {e}", exc_info=True)
|
||||
# STRICT: Fail closed - deny request if we can't check limits
|
||||
return False, f"Failed to retrieve subscription limits: {str(e)}", {}
|
||||
|
||||
if not limits:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
logger.info(f"[Subscription Check] Assigning free tier to user {user_id}")
|
||||
limits = self._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
# No subscription and no free tier - deny access
|
||||
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
try:
|
||||
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage summary for {user_id}: {e}")
|
||||
self.db.rollback()
|
||||
# STRICT: Fail closed on DB error
|
||||
return False, f"Failed to retrieve usage summary: {str(e)}", {}
|
||||
|
||||
# Check call limits with error handling
|
||||
# NOTE: call_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
# Use display_provider_name for error messages, but provider.value for DB queries
|
||||
provider_name = provider.value # For DB field names (e.g., "mistral_calls", "mistral_tokens")
|
||||
|
||||
# For LLM text generation providers, check against unified total_calls limit
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
if is_llm_provider:
|
||||
# Use unified AI text generation limit (total_calls across all LLM providers)
|
||||
ai_text_gen_limit = limits['limits'].get('ai_text_generation_calls', 0) or 0
|
||||
|
||||
# If unified limit not set, fall back to provider-specific limit for backwards compatibility
|
||||
if ai_text_gen_limit == 0:
|
||||
ai_text_gen_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Calculate total LLM provider calls (sum of gemini + openai + anthropic + mistral)
|
||||
current_total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
|
||||
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
|
||||
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'usage_percentage': (current_total_llm_calls / ai_text_gen_limit) * 100 if ai_text_gen_limit > 0 else 0,
|
||||
'provider': display_provider_name, # Use display name for consistency
|
||||
'usage_info': {
|
||||
'provider': display_provider_name, # Use display name for user-facing info
|
||||
'current_calls': current_total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'type': 'ai_text_generation',
|
||||
'breakdown': {
|
||||
'gemini': usage.gemini_calls or 0,
|
||||
'openai': usage.openai_calls or 0,
|
||||
'anthropic': usage.anthropic_calls or 0,
|
||||
'mistral': usage.mistral_calls or 0 # DB field name (not display name)
|
||||
}
|
||||
}
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] AI text generation limit check passed for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit if ai_text_gen_limit > 0 else 'unlimited'} (provider: {display_provider_name})")
|
||||
else:
|
||||
# For non-LLM providers, check provider-specific limit
|
||||
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if call_limit > 0 and current_calls >= call_limit:
|
||||
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
|
||||
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
|
||||
'current_calls': current_calls,
|
||||
'limit': call_limit,
|
||||
'usage_percentage': 100.0,
|
||||
'provider': display_provider_name # Use display name for consistency
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
else:
|
||||
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking call limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check token limits for LLM providers with error handling
|
||||
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
|
||||
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100,
|
||||
'provider': display_provider_name, # Use display name in error details
|
||||
'usage_info': {
|
||||
'provider': display_provider_name,
|
||||
'current_tokens': current_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'type': 'tokens'
|
||||
}
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking token limits: {e}")
|
||||
# Continue to next check
|
||||
|
||||
# Check cost limits with error handling
|
||||
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
|
||||
try:
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
|
||||
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
|
||||
if cost_limit > 0 and usage.total_cost >= cost_limit:
|
||||
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
|
||||
'current_cost': usage.total_cost,
|
||||
'limit': cost_limit,
|
||||
'usage_percentage': 100.0
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking cost limits: {e}")
|
||||
# Continue to success case
|
||||
|
||||
# Calculate usage percentages for warnings
|
||||
try:
|
||||
# Determine which call variables to use based on provider type
|
||||
if is_llm_provider:
|
||||
# Use unified LLM call tracking
|
||||
current_call_count = current_total_llm_calls
|
||||
call_limit_value = ai_text_gen_limit
|
||||
else:
|
||||
# Use provider-specific call tracking
|
||||
current_call_count = current_calls
|
||||
call_limit_value = call_limit
|
||||
|
||||
call_usage_pct = (current_call_count / max(call_limit_value, 1)) * 100 if call_limit_value > 0 else 0
|
||||
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
|
||||
result = (True, "Within limits", {
|
||||
'current_calls': current_call_count,
|
||||
'call_limit': call_limit_value,
|
||||
'call_usage_percentage': call_usage_pct,
|
||||
'current_cost': usage.total_cost,
|
||||
'cost_limit': cost_limit,
|
||||
'cost_usage_percentage': cost_usage_pct
|
||||
})
|
||||
self._limits_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating usage percentages: {e}")
|
||||
# Return basic success
|
||||
return True, "Within limits", {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in check_usage_limits for {user_id}: {e}")
|
||||
# STRICT: Fail closed - deny requests if subscription system fails
|
||||
return False, f"Subscription check error: {str(e)}", {}
|
||||
Returns:
|
||||
(can_proceed, error_message, usage_info)
|
||||
"""
|
||||
from .limit_validation import LimitValidator
|
||||
validator = LimitValidator(self)
|
||||
return validator.check_usage_limits(user_id, provider, tokens_requested, actual_provider_name)
|
||||
|
||||
def estimate_tokens(self, text: str, provider: APIProvider) -> int:
|
||||
"""Estimate token count for text based on provider."""
|
||||
@@ -827,6 +586,16 @@ class PricingService:
|
||||
if not pricing:
|
||||
return None
|
||||
|
||||
# Return pricing info as dict
|
||||
return {
|
||||
'provider': pricing.provider.value,
|
||||
'model_name': pricing.model_name,
|
||||
'cost_per_input_token': pricing.cost_per_input_token,
|
||||
'cost_per_output_token': pricing.cost_per_output_token,
|
||||
'cost_per_request': pricing.cost_per_request,
|
||||
'description': pricing.description
|
||||
}
|
||||
|
||||
def check_comprehensive_limits(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -835,6 +604,7 @@ class PricingService:
|
||||
"""
|
||||
Comprehensive pre-flight validation that checks ALL limits before making ANY API calls.
|
||||
|
||||
Delegates to LimitValidator for actual validation logic.
|
||||
This prevents wasteful API calls by validating that ALL subsequent operations will succeed
|
||||
before making the first external API call.
|
||||
|
||||
@@ -850,202 +620,9 @@ class PricingService:
|
||||
(can_proceed, error_message, error_details)
|
||||
If can_proceed is False, error_message explains which limit would be exceeded
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[Pre-flight Check] 🔍 Starting comprehensive validation for user {user_id}")
|
||||
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
|
||||
|
||||
# Get current usage and limits once
|
||||
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
try:
|
||||
usage = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
self.db.add(usage)
|
||||
self.db.commit()
|
||||
except Exception as create_error:
|
||||
logger.error(f"Error creating usage summary: {create_error}")
|
||||
self.db.rollback()
|
||||
return False, f"Failed to create usage summary: {str(create_error)}", {}
|
||||
|
||||
# Get user limits
|
||||
limits_dict = self.get_user_limits(user_id)
|
||||
if not limits_dict:
|
||||
# No subscription found - check for free tier
|
||||
free_plan = self.db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.tier == SubscriptionTier.FREE,
|
||||
SubscriptionPlan.is_active == True
|
||||
).first()
|
||||
if free_plan:
|
||||
limits_dict = self._plan_to_limits_dict(free_plan)
|
||||
else:
|
||||
return False, "No subscription plan found. Please subscribe to a plan.", {}
|
||||
|
||||
limits = limits_dict.get('limits', {})
|
||||
|
||||
# Track cumulative usage across all operations
|
||||
total_llm_calls = (
|
||||
(usage.gemini_calls or 0) +
|
||||
(usage.openai_calls or 0) +
|
||||
(usage.anthropic_calls or 0) +
|
||||
(usage.mistral_calls or 0)
|
||||
)
|
||||
total_llm_tokens = {}
|
||||
total_images = usage.stability_calls or 0
|
||||
|
||||
# Log current usage summary
|
||||
logger.info(f"[Pre-flight Check] 📊 Current Usage Summary:")
|
||||
logger.info(f" └─ Total LLM Calls: {total_llm_calls}")
|
||||
logger.info(f" └─ Gemini Tokens: {usage.gemini_tokens or 0}, Mistral/HF Tokens: {usage.mistral_tokens or 0}")
|
||||
logger.info(f" └─ Image Calls: {total_images}")
|
||||
|
||||
# Validate each operation
|
||||
for op_idx, operation in enumerate(operations):
|
||||
provider = operation.get('provider')
|
||||
provider_name = provider.value if hasattr(provider, 'value') else str(provider)
|
||||
tokens_requested = operation.get('tokens_requested', 0)
|
||||
actual_provider_name = operation.get('actual_provider_name')
|
||||
operation_type = operation.get('operation_type', 'unknown')
|
||||
|
||||
display_provider_name = actual_provider_name or provider_name
|
||||
|
||||
logger.info(f"[Pre-flight Check] ✅ Operation {op_idx + 1}/{len(operations)}: {operation_type}")
|
||||
logger.info(f" ├─ Provider: {display_provider_name} (enum: {provider_name})")
|
||||
logger.info(f" └─ Estimated Tokens: {tokens_requested}")
|
||||
|
||||
# Check if this is an LLM provider
|
||||
llm_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
is_llm_provider = provider_name in llm_providers
|
||||
|
||||
# Check unified AI text generation limit for LLM providers
|
||||
if is_llm_provider:
|
||||
ai_text_gen_limit = limits.get('ai_text_generation_calls', 0) or 0
|
||||
if ai_text_gen_limit == 0:
|
||||
# Fallback to provider-specific limit
|
||||
ai_text_gen_limit = limits.get(f"{provider_name}_calls", 0) or 0
|
||||
|
||||
# Count this operation as an LLM call
|
||||
projected_total_llm_calls = total_llm_calls + 1
|
||||
|
||||
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
|
||||
error_info = {
|
||||
'current_calls': total_llm_calls,
|
||||
'limit': ai_text_gen_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"AI text generation call limit would be exceeded. Would use {projected_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# Check token limits for this provider
|
||||
# Use cumulative projected tokens from previous operations, or current from DB if first operation
|
||||
provider_tokens_key = f"{provider_name}_tokens"
|
||||
if provider_tokens_key in total_llm_tokens:
|
||||
# Use cumulative projected tokens from previous operations
|
||||
current_provider_tokens = total_llm_tokens[provider_tokens_key]
|
||||
logger.info(f" └─ Using cumulative projected tokens: {current_provider_tokens}")
|
||||
else:
|
||||
# First operation for this provider - get current from database
|
||||
current_provider_tokens = getattr(usage, provider_tokens_key, 0) or 0
|
||||
total_llm_tokens[provider_tokens_key] = current_provider_tokens
|
||||
logger.info(f" └─ Current tokens from DB: {current_provider_tokens}")
|
||||
|
||||
token_limit = limits.get(provider_tokens_key, 0) or 0
|
||||
|
||||
if token_limit > 0 and tokens_requested > 0:
|
||||
projected_tokens = current_provider_tokens + tokens_requested
|
||||
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
|
||||
|
||||
if projected_tokens > token_limit:
|
||||
usage_percentage = (projected_tokens / token_limit) * 100 if token_limit > 0 else 0
|
||||
error_info = {
|
||||
'current_tokens': current_provider_tokens,
|
||||
'requested_tokens': tokens_requested,
|
||||
'limit': token_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
error_msg = (
|
||||
f"Token limit exceeded for {display_provider_name} "
|
||||
f"({operation_type}). "
|
||||
f"Current: {current_provider_tokens}/{token_limit}, "
|
||||
f"Requested: {tokens_requested}, "
|
||||
f"Would exceed by: {projected_tokens - token_limit} tokens "
|
||||
f"({usage_percentage:.1f}% of limit)"
|
||||
)
|
||||
logger.error(f"[Pre-flight Check] ❌ BLOCKED: {error_msg}")
|
||||
return False, error_msg, {
|
||||
'error_type': 'token_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
else:
|
||||
logger.info(f" └─ ✅ Token limit check passed: {projected_tokens} <= {token_limit}")
|
||||
|
||||
# Update cumulative counts for next operation
|
||||
total_llm_calls = projected_total_llm_calls
|
||||
total_llm_tokens[provider_tokens_key] += tokens_requested
|
||||
logger.info(f" └─ Updated cumulative tokens for {display_provider_name}: {total_llm_tokens[provider_tokens_key]}")
|
||||
|
||||
# Check image generation limits
|
||||
elif provider == APIProvider.STABILITY:
|
||||
image_limit = limits.get('stability_calls', 0) or 0
|
||||
projected_images = total_images + 1
|
||||
|
||||
if image_limit > 0 and projected_images > image_limit:
|
||||
error_info = {
|
||||
'current_images': total_images,
|
||||
'limit': image_limit,
|
||||
'provider': 'stability',
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"Image generation limit would be exceeded. Would use {projected_images} of {image_limit} images this billing period.", {
|
||||
'error_type': 'image_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
total_images = projected_images
|
||||
|
||||
# Check other provider-specific limits
|
||||
else:
|
||||
provider_calls_key = f"{provider_name}_calls"
|
||||
current_provider_calls = getattr(usage, provider_calls_key, 0) or 0
|
||||
call_limit = limits.get(provider_calls_key, 0) or 0
|
||||
|
||||
if call_limit > 0:
|
||||
projected_calls = current_provider_calls + 1
|
||||
if projected_calls > call_limit:
|
||||
error_info = {
|
||||
'current_calls': current_provider_calls,
|
||||
'limit': call_limit,
|
||||
'provider': display_provider_name,
|
||||
'operation_type': operation_type,
|
||||
'operation_index': op_idx
|
||||
}
|
||||
return False, f"API call limit would be exceeded for {display_provider_name}. Would use {projected_calls} of {call_limit} calls this billing period.", {
|
||||
'error_type': 'call_limit',
|
||||
'usage_info': error_info
|
||||
}
|
||||
|
||||
# All checks passed
|
||||
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")
|
||||
logger.info(f"[Pre-flight Check] ✅ User {user_id} is cleared to proceed with API calls")
|
||||
return True, None, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Pre-flight Check] Error during comprehensive limit check: {e}", exc_info=True)
|
||||
return False, f"Failed to validate limits: {str(e)}", {}
|
||||
from .limit_validation import LimitValidator
|
||||
validator = LimitValidator(self)
|
||||
return validator.check_comprehensive_limits(user_id, operations)
|
||||
|
||||
def get_pricing_for_provider_model(self, provider: APIProvider, model_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get pricing configuration for a specific provider and model."""
|
||||
|
||||
39
backend/services/subscription/schema_utils.py
Normal file
39
backend/services/subscription/schema_utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Set
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
_checked_subscription_plan_columns: bool = False
|
||||
|
||||
|
||||
def ensure_subscription_plan_columns(db: Session) -> None:
|
||||
"""Ensure required columns exist on subscription_plans for runtime safety.
|
||||
|
||||
This is a defensive guard for environments where migrations have not yet
|
||||
been applied. If columns are missing (e.g., exa_calls_limit), we add them
|
||||
with a safe default so ORM queries do not fail.
|
||||
"""
|
||||
global _checked_subscription_plan_columns
|
||||
if _checked_subscription_plan_columns:
|
||||
return
|
||||
|
||||
try:
|
||||
# Discover existing columns
|
||||
result = db.execute("PRAGMA table_info(subscription_plans)")
|
||||
cols: Set[str] = {row[1] for row in result}
|
||||
|
||||
# Columns we may reference in models but might be missing in older DBs
|
||||
required_columns = {
|
||||
"exa_calls_limit": "INTEGER DEFAULT 0",
|
||||
}
|
||||
|
||||
for col_name, ddl in required_columns.items():
|
||||
if col_name not in cols:
|
||||
db.execute(f"ALTER TABLE subscription_plans ADD COLUMN {col_name} {ddl}")
|
||||
db.commit()
|
||||
except Exception:
|
||||
# Do not block app if pragma/alter fails; let normal errors surface
|
||||
db.rollback()
|
||||
finally:
|
||||
_checked_subscription_plan_columns = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user