Fix voice clone NotSupportedError and improve subscription services

This commit is contained in:
ajaysi
2026-04-22 12:27:51 +05:30
parent 641143a7d6
commit cbd68fa43f
13 changed files with 221 additions and 72 deletions

View File

@@ -107,6 +107,20 @@ class LimitValidator:
}
return result
# Helper: Check if a limit should be enforced based on tier
def should_enforce_limit(limit_value: int, tier: str) -> bool:
"""
Determine if a limit should be enforced.
- Free tier: 0 means DISABLED (not unlimited)
- Basic/Pro/Enterprise: 0 means UNLIMITED
"""
if tier == 'free':
# Free tier: 0 means disabled
return limit_value > 0
else:
# Basic/Pro/Enterprise: 0 means unlimited
return limit_value > 0
# Get user limits with error handling (STRICT: fail on errors)
# CRITICAL: Expire SQLAlchemy objects to ensure we get fresh plan data after renewal
try:
@@ -144,6 +158,9 @@ class LimitValidator:
logger.warning(f"[Subscription Check] No subscription or free tier found for user {user_id}, denying access")
return False, "No subscription plan found. Please subscribe to a plan.", {}
# Extract tier for limit enforcement logic
user_tier = limits.get('tier', 'free') if limits else 'free'
# Get current usage for this billing period with error handling
# Use targeted expiry instead of expire_all() to avoid nuking the entire session cache
try:
@@ -245,8 +262,8 @@ class LimitValidator:
(usage.mistral_calls or 0)
)
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if ai_text_gen_limit > 0 and current_total_llm_calls >= ai_text_gen_limit:
# Enforce limit based on tier (Free: 0=disabled, others: 0=unlimited)
if should_enforce_limit(ai_text_gen_limit, user_tier) and current_total_llm_calls >= ai_text_gen_limit:
logger.error(f"[Subscription Check] AI text generation call limit exceeded for user {user_id}: {current_total_llm_calls}/{ai_text_gen_limit} (provider: {display_provider_name})")
result = (False, f"AI text generation call limit reached. Used {current_total_llm_calls} of {ai_text_gen_limit} total AI text generation calls this billing period.", {
'current_calls': current_total_llm_calls,
@@ -278,8 +295,8 @@ class LimitValidator:
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if call_limit > 0 and current_calls >= call_limit:
# Enforce limit based on tier (Free: 0=disabled, others: 0=unlimited)
if should_enforce_limit(call_limit, user_tier) and current_calls >= call_limit:
logger.error(f"[Subscription Check] Call limit exceeded for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit}")
result = (False, f"API call limit reached for {display_provider_name}. Used {current_calls} of {call_limit} calls this billing period.", {
'current_calls': current_calls,
@@ -296,7 +313,13 @@ class LimitValidator:
logger.debug(f"[Subscription Check] Call limit check passed for user {user_id}, provider {display_provider_name}: {current_calls}/{call_limit if call_limit > 0 else 'unlimited'}")
except Exception as e:
logger.error(f"Error checking call limits: {e}")
# Continue to next check
# Fail closed - deny if we can't verify the limit
result = (False, f"Unable to verify call limit: {str(e)}", {})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Check token limits for LLM providers with error handling
# NOTE: token_limit = 0 means UNLIMITED (Enterprise plans)
@@ -305,8 +328,8 @@ class LimitValidator:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0) or 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
# Enforce limit based on tier (Free: 0=disabled, others: 0=unlimited)
if should_enforce_limit(token_limit, user_tier) and (current_tokens + tokens_requested) > token_limit:
result = (False, f"Token limit would be exceeded for {display_provider_name}. Current: {current_tokens}, Requested: {tokens_requested}, Limit: {token_limit}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
@@ -328,14 +351,19 @@ class LimitValidator:
return result
except Exception as e:
logger.error(f"Error checking token limits: {e}")
# Continue to next check
# Fail closed - deny if we can't verify the limit
result = (False, f"Unable to verify token limit: {str(e)}", {})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Check cost limits with error handling
# NOTE: cost_limit = 0 means UNLIMITED (Enterprise plans)
try:
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
# Only enforce limit if limit > 0 (0 means unlimited for Enterprise)
if cost_limit > 0 and usage.total_cost >= cost_limit:
# Enforce limit based on tier (Free: 0=disabled, others: 0=unlimited)
if should_enforce_limit(cost_limit, user_tier) and usage.total_cost >= cost_limit:
result = (False, f"Monthly cost limit reached. Current cost: ${usage.total_cost:.2f}, Limit: ${cost_limit:.2f}", {
'current_cost': usage.total_cost,
'limit': cost_limit,
@@ -348,7 +376,13 @@ class LimitValidator:
return result
except Exception as e:
logger.error(f"Error checking cost limits: {e}")
# Continue to success case
# Fail closed - deny if we can't verify the limit
result = (False, f"Unable to verify cost limit: {str(e)}", {})
self.pricing_service._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Calculate usage percentages for warnings
try:
@@ -503,6 +537,7 @@ class LimitValidator:
return False, "No subscription plan found. Please subscribe to a plan.", {}
limits = limits_dict.get('limits', {})
tier = limits_dict.get('tier', 'free')
# Track cumulative usage across all operations
total_llm_calls = (
@@ -547,7 +582,8 @@ class LimitValidator:
# Count this operation as an LLM call
projected_total_llm_calls = total_llm_calls + 1
if ai_text_gen_limit > 0 and projected_total_llm_calls > ai_text_gen_limit:
# Enforce limit based on tier (Free: 0=disabled, others: 0=unlimited)
if should_enforce_limit(ai_text_gen_limit, tier) and projected_total_llm_calls > ai_text_gen_limit:
error_info = {
'current_calls': total_llm_calls,
'limit': ai_text_gen_limit,
@@ -654,7 +690,8 @@ class LimitValidator:
token_limit = limits.get(provider_tokens_key, 0) or 0
if token_limit > 0 and tokens_requested > 0:
# Enforce limit based on tier (Free: 0=disabled, others: 0=unlimited)
if should_enforce_limit(token_limit, tier) and tokens_requested > 0:
projected_tokens = current_provider_tokens + tokens_requested
logger.info(f" └─ Token Check: {current_provider_tokens} (current) + {tokens_requested} (requested) = {projected_tokens} (total) / {token_limit} (limit)")
@@ -716,7 +753,8 @@ class LimitValidator:
image_limit = limits.get('stability_calls', 0) or 0
projected_images = total_images + 1
if image_limit > 0 and projected_images > image_limit:
# Enforce limit based on tier (Free: 0=disabled, others: 0=unlimited)
if should_enforce_limit(image_limit, tier) and projected_images > image_limit:
error_info = {
'current_images': total_images,
'limit': image_limit,
@@ -737,7 +775,8 @@ class LimitValidator:
total_video_calls = usage.video_calls or 0
projected_video_calls = total_video_calls + 1
if video_limit > 0 and projected_video_calls > video_limit:
# Enforce limit based on tier (Free: 0=disabled, others: 0=unlimited)
if should_enforce_limit(video_limit, tier) and projected_video_calls > video_limit:
error_info = {
'current_calls': total_video_calls,
'limit': video_limit,
@@ -756,7 +795,8 @@ class LimitValidator:
total_image_edit_calls = getattr(usage, 'image_edit_calls', 0) or 0
projected_image_edit_calls = total_image_edit_calls + 1
if image_edit_limit > 0 and projected_image_edit_calls > image_edit_limit:
# Enforce limit based on tier (Free: 0=disabled, others: 0=unlimited)
if should_enforce_limit(image_edit_limit, tier) and projected_image_edit_calls > image_edit_limit:
error_info = {
'current_calls': total_image_edit_calls,
'limit': image_edit_limit,
@@ -789,6 +829,25 @@ class LimitValidator:
'error_type': 'call_limit',
'usage_info': error_info
}
# Check WaveSpeed combined limit if actual_provider is WaveSpeed
if actual_provider_name == 'wavespeed':
wavespeed_limit = limits.get('wavespeed_calls', 0) or 0
if should_enforce_limit(wavespeed_limit, tier):
wavespeed_usage = usage.wavespeed_calls or 0
projected_wavespeed = wavespeed_usage + 1
if projected_wavespeed > wavespeed_limit:
error_info = {
'current_calls': wavespeed_usage,
'limit': wavespeed_limit,
'provider': 'wavespeed',
'operation_type': operation_type,
'operation_index': op_idx
}
return False, f"WaveSpeed API limit would be exceeded. Would use {projected_wavespeed} of {wavespeed_limit} WaveSpeed calls this billing period.", {
'error_type': 'wavespeed_limit',
'usage_info': error_info
}
# All checks passed
logger.info(f"[Pre-flight Check] ✅ All {len(operations)} operation(s) validated successfully")