fix: WYSIWYG editor, content generation, and writing assistant bug fixes

- Fix text selection menu not showing: wire contentRef via inputRef on multiline TextField
- Fix blog title not truncating: add min-w-0 for flex item overflow
- Fix outline generation 500: escape curly braces in f-string prompt template
- Fix content generation 'NoneType not callable': replace SessionLocal() with get_session_for_user(), add db param to MediumBlogGenerator, fix signature mismatch in database_task_manager
- Fix writing assistant suggest 500: add auth + user_id to API endpoint and service, replace sync requests with httpx.AsyncClient
- Fix hallucination detector 404: explicitly include router in main.py and app.py
- Fix missing error_data in task failure responses
- Hide CopilotKit web inspector button
- Remove hardcoded fallback suggestions from SmartTypingAssist
- Fix stale closure refs in SmartTypingAssist handleTypingChange
- Add two-column editor layout, stats bar, section hover menu
- Various subscription, billing, and research module improvements
This commit is contained in:
ajaysi
2026-05-14 09:11:30 +05:30
parent 7385100017
commit 928c2f20aa
113 changed files with 4344 additions and 10064 deletions

View File

@@ -12,7 +12,7 @@ from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from models.subscription_models import APIProvider, UsageAlert
from models.subscription_models import APIProvider, UsageAlert, UserSubscription
class SubscriptionErrorType(Enum):
USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
@@ -248,6 +248,18 @@ class SubscriptionExceptionHandler:
return
try:
# Get billing period from subscription, fallback to calendar month
billing_period = datetime.now().strftime("%Y-%m") # default
try:
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == error.user_id,
UserSubscription.is_active == True
).first()
if subscription and subscription.current_period_start:
billing_period = subscription.current_period_start.strftime("%Y-%m")
except:
pass # Use default calendar period
alert = UsageAlert(
user_id=error.user_id,
alert_type="system_error",
@@ -256,7 +268,7 @@ class SubscriptionExceptionHandler:
title=f"System Error: {error.error_type.value}",
message=error.message,
severity=error.severity.value,
billing_period=datetime.now().strftime("%Y-%m")
billing_period=billing_period
)
self.db.add(alert)

View File

@@ -157,39 +157,38 @@ class LimitValidator:
user_tier = limits.get('tier', 'free') if limits else 'free'
# Get current usage for this billing period with error handling
# Use targeted expiry instead of expire_all() to avoid nuking the entire session cache
# Use subscription period, not calendar month
current_period = self.pricing_service.get_current_billing_period(user_id)
# Only expire specific objects that might have changed after renewal
# (subscription was already checked above; plan was expired above)
# The usage record is the main object we need fresh, and we query it directly below
if subscription:
self.db.expire(subscription)
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
usage = None
try:
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Only expire specific objects that might have changed after renewal
# (subscription was already checked above; plan was expired above)
# The usage record is the main object we need fresh, and we query it directly below
if subscription:
self.db.expire(subscription)
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
usage = None
try:
from sqlalchemy import text
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
# Map result to UsageSummary object
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
except Exception as sql_error:
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
# Fallback to ORM query
from sqlalchemy import text
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
# Map result to UsageSummary object
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
except Exception as sql_error:
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
# Fallback to ORM query
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if usage:
self.db.refresh(usage) # Ensure fresh data
if not usage:
# First usage this period, create summary
@@ -448,7 +447,7 @@ class LimitValidator:
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
# Get current usage and limits once
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
current_period = self.pricing_service.get_current_billing_period(user_id)
logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")

View File

@@ -67,15 +67,56 @@ class PricingService:
self.db.rollback()
return True
def get_current_billing_period(self, user_id: str) -> Optional[str]:
"""Return current billing period key (YYYY-MM) after ensuring subscription is current."""
def get_current_billing_period(self, user_id: str) -> str:
"""Return current billing period key (YYYY-MM) based on subscription, not calendar.
Maintains backward compatibility with existing calendar-month data."""
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
# Ensure subscription is current (advance if auto_renew)
self._ensure_subscription_current(subscription)
# Continue to use YYYY-MM for summaries
# Use subscription's billing period, NOT calendar month
if subscription and subscription.current_period_start:
sub_period = subscription.current_period_start.strftime("%Y-%m")
# Check if usage data exists for this subscription period
from models.subscription_models import UsageSummary
usage_exists = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == sub_period
).first()
if usage_exists:
return sub_period
# If no data for subscription period, check for calendar month data
# This handles backward compatibility for existing users
calendar_period = datetime.now().strftime("%Y-%m")
if calendar_period != sub_period:
calendar_usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == calendar_period
).first()
if calendar_usage:
logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {sub_period} has no data)")
return calendar_period
return sub_period
# Fallback: Check if user has any usage summary and use that period
from models.subscription_models import UsageSummary
latest_summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id
).order_by(UsageSummary.billing_period.desc()).first()
if latest_summary:
logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period}")
return latest_summary.billing_period
# Last fallback to calendar month for free tier / no data
return datetime.now().strftime("%Y-%m")
@classmethod
@@ -830,6 +871,7 @@ class PricingService:
'serper_calls': plan.serper_calls_limit,
'metaphor_calls': plan.metaphor_calls_limit,
'firecrawl_calls': plan.firecrawl_calls_limit,
'exa_calls': getattr(plan, 'exa_calls_limit', 0), # Exa research API
'stability_calls': plan.stability_calls_limit,
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column

View File

@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning, ProcessedStripeEvent
from services.subscription.pricing_service import PricingService
from datetime import datetime
from datetime import datetime, timedelta
REQUIRED_STRIPE_PLAN_KEYS = {
(SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value),
@@ -421,10 +421,6 @@ class StripeService:
try:
sub = stripe.Subscription.retrieve(subscription_id)
price_id = sub['items']['data'][0]['price']['id']
# Map price_id to internal plan_id
# Note: You need a way to map Stripe Price IDs to your Plan IDs.
# For now, we'll assume the metadata or a lookup.
# Ideally, store price_id in SubscriptionPlan table or config.
# Update DB
self._update_user_subscription(
@@ -434,6 +430,24 @@ class StripeService:
status="active",
price_id=price_id
)
# Clear PricingService cache so next status check returns updated limits
try:
from services.subscription import PricingService
PricingService.clear_user_cache(user_id)
except Exception as cache_err:
logger.warning(f"Failed to clear user cache after checkout for user {user_id}: {cache_err}")
try:
from api.subscription.cache import clear_dashboard_cache
clear_dashboard_cache(user_id)
logger.info(f"Cleared dashboard cache for user {user_id} after checkout")
except Exception as cache_err:
logger.warning(f"Failed to clear cache after checkout for user {user_id}: {cache_err}")
# Expire all SQLAlchemy objects to force fresh reads
self.db.expire_all()
logger.info(f"Expired all SQLAlchemy objects for user {user_id} after checkout")
except Exception as e:
logger.error(f"Error processing checkout subscription: {e}")
@@ -457,11 +471,28 @@ class StripeService:
logger.info(f"Payment succeeded for user {subscription.user_id}")
subscription.status = UsageStatus.ACTIVE
subscription.is_active = True
# Update period end based on invoice lines period
subscription.auto_renew = True
# Update period start/end based on invoice lines period
if invoice.get('lines'):
period_start = invoice['lines']['data'][0]['period']['start']
period_end = invoice['lines']['data'][0]['period']['end']
subscription.current_period_start = datetime.fromtimestamp(period_start)
subscription.current_period_end = datetime.fromtimestamp(period_end)
self.db.commit()
# Clear PricingService cache so next status check returns updated limits
try:
from services.subscription import PricingService
PricingService.clear_user_cache(subscription.user_id)
logger.info(f"Cleared subscription cache for user {subscription.user_id} after payment success")
except Exception as cache_err:
logger.warning(f"Failed to clear user cache after payment success for user {subscription.user_id}: {cache_err}")
try:
from api.subscription.cache import clear_dashboard_cache
clear_dashboard_cache(subscription.user_id)
except Exception as dash_cache_err:
logger.warning(f"Failed to clear dashboard cache after payment success for user {subscription.user_id}: {dash_cache_err}")
self.db.expire_all()
async def _handle_invoice_payment_failed(self, invoice: Dict[str, Any]):
subscription_id = invoice.get("subscription")
@@ -497,6 +528,12 @@ class StripeService:
if status in ["active", "trialing"]:
subscription.status = UsageStatus.ACTIVE
subscription.is_active = True
subscription.auto_renew = True
# Update period boundaries from Stripe event
current_period = subscription_obj.get("current_period", {})
if current_period:
subscription.current_period_start = datetime.fromtimestamp(current_period.get("start", 0))
subscription.current_period_end = datetime.fromtimestamp(current_period.get("end", 0))
elif status in ["past_due", "unpaid", "incomplete", "incomplete_expired"]:
subscription.status = UsageStatus.PAST_DUE
subscription.is_active = False
@@ -506,6 +543,20 @@ class StripeService:
subscription.auto_renew = False
self.db.commit()
# Clear PricingService cache so next status check returns updated limits
try:
from services.subscription import PricingService
PricingService.clear_user_cache(subscription.user_id)
logger.info(f"Cleared subscription cache for user {subscription.user_id} after subscription update")
except Exception as cache_err:
logger.warning(f"Failed to clear user cache after subscription update for user {subscription.user_id}: {cache_err}")
try:
from api.subscription.cache import clear_dashboard_cache
clear_dashboard_cache(subscription.user_id)
except Exception as dash_cache_err:
logger.warning(f"Failed to clear dashboard cache after subscription update for user {subscription.user_id}: {dash_cache_err}")
self.db.expire_all()
async def _handle_subscription_deleted(self, subscription_obj: Dict[str, Any]):
"""
@@ -610,6 +661,11 @@ class StripeService:
)
now = datetime.utcnow()
# Calculate billing period end based on cycle
if billing_cycle == BillingCycle.YEARLY:
period_end = now + timedelta(days=365)
else:
period_end = now + timedelta(days=30)
if not subscription:
subscription = UserSubscription(
@@ -617,7 +673,7 @@ class StripeService:
plan_id=plan.id,
billing_cycle=billing_cycle,
current_period_start=now,
current_period_end=now,
current_period_end=period_end,
status=UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED,
is_active=status == "active",
auto_renew=True,
@@ -627,6 +683,11 @@ class StripeService:
subscription.plan_id = plan.id
subscription.billing_cycle = billing_cycle
subscription.is_active = status == "active"
subscription.status = UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED
# Reset billing period on upgrade/plan change
subscription.current_period_start = now
subscription.current_period_end = period_end
subscription.auto_renew = True
subscription.stripe_customer_id = stripe_customer_id
subscription.stripe_subscription_id = stripe_subscription_id

View File

@@ -0,0 +1,21 @@
"""
Usage tracking modules package.
Split from the monolithic usage_tracking_service.py for better maintainability.
"""
from .historical_usage import get_all_historical_usage, get_current_period_usage, get_usage_for_period
from .usage_stats import get_user_usage_stats
from .usage_trends import get_usage_trends
from .limits_enforcement import enforce_usage_limits
from .alerts import check_usage_alerts, create_usage_alert
__all__ = [
'get_all_historical_usage',
'get_current_period_usage',
'get_usage_for_period',
'get_user_usage_stats',
'get_usage_trends',
'enforce_usage_limits',
'check_usage_alerts',
'create_usage_alert',
]

View File

@@ -0,0 +1,101 @@
"""
Usage alert functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Dict, Any
from sqlalchemy.orm import Session
from loguru import logger
from models.subscription_models import UsageAlert, UsageSummary, APIProvider, UsageStatus
def check_usage_alerts(user_id: str, provider: APIProvider,
billing_period: str, db: Session, pricing_service):
"""Check if usage alerts should be sent."""
# Get current usage
period_keys = {'billing_period': billing_period, 'lookup_periods': [billing_period]}
summary = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
).first()
if not summary:
return
# Get user limits
limits = pricing_service.get_user_limits(user_id)
if not limits:
return
# Check for alert thresholds (80%, 90%, 100%)
thresholds = [80, 90, 100]
for threshold in thresholds:
# Check if alert already sent for this threshold
existing_alert = db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.billing_period == billing_period,
UsageAlert.threshold_percentage == threshold,
UsageAlert.provider == provider,
UsageAlert.is_sent == True
).first()
if existing_alert:
continue
# Check if threshold is reached
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0:
usage_percentage = (current_calls / call_limit) * 100
if usage_percentage >= threshold:
create_usage_alert(
user_id=user_id,
provider=provider,
threshold=threshold,
current_usage=current_calls,
limit=call_limit,
billing_period=billing_period,
db=db
)
def create_usage_alert(user_id: str, provider: APIProvider,
threshold: int, current_usage: int, limit: int,
billing_period: str, db: Session):
"""Create a usage alert."""
# Determine alert type and severity
if threshold >= 100:
alert_type = "limit_reached"
severity = "error"
title = f"API Limit Reached - {provider.value.title()}"
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
elif threshold >= 90:
alert_type = "usage_warning"
severity = "warning"
title = f"API Usage Warning - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
else:
alert_type = "usage_warning"
severity = "info"
title = f"API Usage Notice - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
alert = UsageAlert(
user_id=user_id,
alert_type=alert_type,
threshold_percentage=threshold,
provider=provider,
title=title,
message=message,
severity=severity,
billing_period=billing_period
)
db.add(alert)
logger.info(f"Created usage alert for {user_id}: {title}")

View File

@@ -0,0 +1,250 @@
"""
Historical usage aggregation functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Dict, Any
from sqlalchemy.orm import Session
from loguru import logger
from datetime import datetime
from models.subscription_models import UsageSummary, UsageStatus
# Shared provider mapping: DB column → frontend key
PROVIDER_MAPPING = {
'gemini_calls': 'gemini',
'openai_calls': 'openai',
'anthropic_calls': 'anthropic',
'mistral_calls': 'huggingface', # HuggingFace stored as mistral
'wavespeed_calls': 'wavespeed',
'exa_calls': 'exa',
'tavily_calls': 'tavily',
'serper_calls': 'serper',
'firecrawl_calls': 'firecrawl',
'metaphor_calls': 'metaphor',
'stability_calls': 'stability',
'video_calls': 'video',
'image_edit_calls': 'image_edit',
'audio_calls': 'audio',
}
def _build_provider_breakdown(summaries: list, mapping: dict) -> dict:
"""Build provider_breakdown dict from a list of UsageSummary records."""
breakdown = {}
for db_col, frontend_key in mapping.items():
total = sum(getattr(s, db_col, 0) or 0 for s in summaries)
breakdown[frontend_key] = {'calls': total, 'cost': 0, 'tokens': 0}
return breakdown
def _build_usage_percentages(provider_breakdown: dict, limits: dict) -> dict:
"""Build usage_percentages dict from provider_breakdown and per-period limits."""
pcts = {}
if not limits or not limits.get('limits'):
return pcts
limit_map = {
'gemini_calls': ('gemini', 'gemini_calls'),
'huggingface_calls': ('huggingface', 'mistral_calls'),
'stability_calls': ('stability', 'stability_calls'),
'video_calls': ('video', 'video_calls'),
'audio_calls': ('audio', 'audio_calls'),
'image_edit_calls': ('image_edit', 'image_edit_calls'),
'wavespeed_calls': ('wavespeed', 'wavespeed_calls'),
'tavily_calls': ('tavily', 'tavily_calls'),
'serper_calls': ('serper', 'serper_calls'),
'firecrawl_calls': ('firecrawl', 'firecrawl_calls'),
'metaphor_calls': ('metaphor', 'metaphor_calls'),
'exa_calls': ('exa', 'exa_calls'),
}
for pct_key, (bk_key, limit_key) in limit_map.items():
used = provider_breakdown.get(bk_key, {}).get('calls', 0)
limit_val = limits.get('limits', {}).get(limit_key, 0) or 0
if limit_val > 0:
pcts[pct_key] = (used / limit_val) * 100
# Cost percentage
total_cost = provider_breakdown.get('total_cost', 0)
cost_limit = limits.get('limits', {}).get('monthly_cost', 0) or 0
if cost_limit > 0:
pcts['cost'] = (total_cost / cost_limit) * 100
return pcts
def _summaries_usage_status(summaries: list) -> str:
"""Derive overall usage_status from a list of summaries."""
status = 'active'
for s in summaries:
try:
st = s.usage_status.value
except Exception:
st = str(s.usage_status)
if st == 'limit_reached':
return 'limit_reached'
if st == 'warning' and status != 'limit_reached':
status = 'warning'
return status
def _empty_usage_response(billing_period: str, limits: dict) -> Dict[str, Any]:
"""Return a zeroed UsageStats-shaped response."""
return {
'billing_period': billing_period,
'usage_status': 'active',
'total_calls': 0,
'total_tokens': 0,
'total_cost': 0.0,
'avg_response_time': 0.0,
'error_rate': 0.0,
'limits': limits,
'provider_breakdown': {},
'usage_percentages': {},
'historical_breakdown': [],
'last_updated': datetime.now().isoformat()
}
def get_all_historical_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
"""Get ALL historical usage data aggregated across all billing periods."""
all_summaries = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id
).order_by(UsageSummary.billing_period.desc()).all()
limits = pricing_service.get_user_limits(user_id)
if not all_summaries:
return _empty_usage_response('all', limits)
# Aggregate
total_calls = sum(s.total_calls or 0 for s in all_summaries)
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
provider_breakdown = _build_provider_breakdown(all_summaries, PROVIDER_MAPPING)
# Historical breakdown per period
historical_breakdown = []
for s in all_summaries:
try:
status_val = s.usage_status.value
except Exception:
status_val = str(s.usage_status)
historical_breakdown.append({
'billing_period': s.billing_period,
'total_calls': s.total_calls or 0,
'total_tokens': s.total_tokens or 0,
'total_cost': float(s.total_cost or 0),
'usage_status': status_val,
'updated_at': s.updated_at.isoformat() if s.updated_at else None
})
return {
'billing_period': 'all',
'usage_status': _summaries_usage_status(all_summaries),
'total_calls': total_calls,
'total_tokens': total_tokens,
'total_cost': round(total_cost, 2),
'avg_response_time': round(avg_response_time, 2),
'error_rate': round(error_rate, 2),
'limits': limits,
'provider_breakdown': provider_breakdown,
'usage_percentages': {}, # misleading for all-time vs per-period limits
'historical_breakdown': historical_breakdown,
'last_updated': datetime.now().isoformat()
}
def get_current_period_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
"""Get current billing period usage data with correct per-period limit percentages.
Returns a UsageStats-shaped dict with provider_breakdown and usage_percentages
computed against the plan's per-period limits.
"""
current_period = pricing_service.get_current_billing_period(user_id)
limits = pricing_service.get_user_limits(user_id)
summary = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
result = _empty_usage_response(current_period, limits)
result['usage_percentages'] = _build_usage_percentages({}, limits)
return result
provider_breakdown = _build_provider_breakdown([summary], PROVIDER_MAPPING)
usage_percentages = _build_usage_percentages(provider_breakdown, limits)
try:
status_val = summary.usage_status.value
except Exception:
status_val = str(summary.usage_status)
return {
'billing_period': current_period,
'usage_status': status_val,
'total_calls': summary.total_calls or 0,
'total_tokens': summary.total_tokens or 0,
'total_cost': round(float(summary.total_cost or 0), 2),
'avg_response_time': summary.avg_response_time or 0.0,
'error_rate': summary.error_rate or 0.0,
'limits': limits,
'provider_breakdown': provider_breakdown,
'usage_percentages': usage_percentages,
'historical_breakdown': [],
'last_updated': datetime.now().isoformat()
}
def get_usage_for_period(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
"""Get usage data for a specific billing period.
Returns a UsageStats-shaped dict with that period's provider_breakdown
and usage_percentages computed against plan limits.
"""
limits = pricing_service.get_user_limits(user_id)
summary = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
if not summary:
result = _empty_usage_response(billing_period, limits)
result['usage_percentages'] = _build_usage_percentages({}, limits)
return result
provider_breakdown = _build_provider_breakdown([summary], PROVIDER_MAPPING)
usage_percentages = _build_usage_percentages(provider_breakdown, limits)
try:
status_val = summary.usage_status.value
except Exception:
status_val = str(summary.usage_status)
return {
'billing_period': billing_period,
'usage_status': status_val,
'total_calls': summary.total_calls or 0,
'total_tokens': summary.total_tokens or 0,
'total_cost': round(float(summary.total_cost or 0), 2),
'avg_response_time': summary.avg_response_time or 0.0,
'error_rate': summary.error_rate or 0.0,
'limits': limits,
'provider_breakdown': provider_breakdown,
'usage_percentages': usage_percentages,
'historical_breakdown': [],
'last_updated': datetime.now().isoformat()
}

View File

@@ -0,0 +1,38 @@
"""
Usage limit enforcement functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Tuple, Dict, Any
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from loguru import logger
from models.subscription_models import APIProvider
from services.subscription.pricing_service import PricingService
def enforce_usage_limits(user_id: str, provider: APIProvider,
tokens_requested: int, db: Session,
pricing_service: PricingService) -> Tuple[bool, str, Dict[str, Any]]:
"""Enforce usage limits before making an API call."""
# Check short-lived cache first (30s)
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
# This would need access to self._enforce_cache
# For now, keeping the structure
result = pricing_service.check_usage_limits(
user_id=user_id,
provider=provider,
tokens_requested=tokens_requested
)
# Cache the result
# self._enforce_cache[cache_key] = {
# 'result': result,
# 'expires_at': now + timedelta(seconds=30)
# }
return tuple(result)

View File

@@ -0,0 +1,29 @@
"""
Usage statistics functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Dict, Any
from sqlalchemy.orm import Session
from loguru import logger
from datetime import datetime
from models.subscription_models import UsageSummary, UsageStatus, APIProvider
from services.subscription.usage_tracking_modules.historical_usage import get_all_historical_usage, get_usage_for_period
def get_user_usage_stats(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
"""Get comprehensive usage statistics for a user.
When no billing_period is specified, returns ALL historical usage data.
When a specific period is given, returns only that period's data."""
if not user_id:
logger.error("get_user_usage_stats called without user_id")
raise ValueError("user_id is required")
# If no billing_period requested, return ALL historical data
if not billing_period:
return get_all_historical_usage(user_id, db, pricing_service)
# Return data for the specific billing period
return get_usage_for_period(user_id, billing_period, db, pricing_service)

View File

@@ -0,0 +1,18 @@
"""
Usage trends functions.
Extracted from usage_tracking_service.py for better maintainability.
"""
from typing import Dict, Any
from sqlalchemy.orm import Session
from loguru import logger
def get_usage_trends(user_id: str, months: int, db: Session) -> Dict[str, Any]:
"""Get usage trends over time with self-healing from logs."""
from services.subscription.usage_tracking_helpers import build_billing_periods, query_usage_summaries, self_heal_summaries_from_logs, build_usage_trends_response
periods = build_billing_periods(months)
summary_dict = query_usage_summaries(db, user_id, periods)
self_heal_summaries_from_logs(db, user_id, periods, summary_dict)
return build_usage_trends_response(periods, summary_dict)

View File

@@ -1,41 +1,60 @@
"""
Usage Tracking Service
Comprehensive tracking of API usage, costs, and subscription limits.
Usage Tracking Service - Refactored into modular components.
This file now serves as a facade that delegates to specialized modules
in the usage_tracking_modules package.
Modules:
- historical_usage: Functions for aggregating historical usage data
- usage_stats: Functions for getting user usage statistics
- usage_trends: Functions for usage trend analysis
- limit_enforcement: Functions for enforcing usage limits
- alerts: Functions for usage alerts
"""
# Ensure Optional is available in global scope for dynamic imports
from typing import Optional
import asyncio
from typing import Dict, Any, List, Tuple
from datetime import datetime, timedelta
from typing import Dict, Any, Tuple, Optional
from sqlalchemy.orm import Session
from sqlalchemy import desc
from sqlalchemy import text
from loguru import logger
import json
from api.subscription.cache import clear_dashboard_cache
from datetime import datetime, timedelta
import time
from models.subscription_models import (
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
UserSubscription, UsageStatus
APIProvider, UsageStatus, UserSubscription,
UsageSummary, APIUsageLog, UsageAlert
)
from .pricing_service import PricingService
from .provider_detection import detect_actual_provider
from .usage_tracking_helpers import (
build_billing_periods,
build_default_usage_percentages,
build_empty_usage_response,
from services.subscription.pricing_service import PricingService
from services.subscription.provider_detection import detect_actual_provider
from services.subscription.usage_tracking_helpers import (
build_provider_breakdown,
build_usage_trends_response,
build_default_usage_percentages,
calculate_final_total_cost,
maybe_persist_reconciled_costs,
build_usage_trends_response,
build_billing_periods,
query_usage_summaries,
reset_usage_summary_counters,
self_heal_summaries_from_logs,
reset_usage_summary_counters,
)
# Import clear_dashboard_cache lazily to avoid circular import
def _clear_dashboard_cache_for_user(user_id: str):
from api.subscription.cache import clear_dashboard_cache as _clear
return _clear(user_id)
from .usage_tracking_modules import (
get_all_historical_usage,
get_current_period_usage,
get_usage_for_period,
get_user_usage_stats,
get_usage_trends,
enforce_usage_limits,
check_usage_alerts,
create_usage_alert,
)
class UsageTrackingService:
"""Service for tracking API usage and managing subscription limits."""
"""Service for tracking API usage and managing billing information."""
def __init__(self, db: Session):
self.db = db
@@ -43,13 +62,14 @@ class UsageTrackingService:
# TTL cache (30s) for enforcement results to cut DB chatter
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
"""Return authoritative billing period lookup keys. Always uses calendar month for consistency."""
"""Return authoritative billing period lookup keys. Always uses subscription period for consistency.
Maintains backward compatibility with existing calendar-month data."""
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id
).first()
# If caller explicitly requested a billing period, use it
if billing_period:
return {
@@ -58,26 +78,125 @@ class UsageTrackingService:
"period_start": subscription.current_period_start if subscription else None,
"period_end": subscription.current_period_end if subscription else None,
}
# ALWAYS use current calendar month for billing period to ensure consistency
# This prevents data loss when subscription spans month boundaries
current_period = datetime.now().strftime("%Y-%m")
# Get subscription period if available
subscription_period = None
if subscription and subscription.current_period_start:
subscription_period = subscription.current_period_start.strftime("%Y-%m")
# Get calendar period
calendar_period = datetime.now().strftime("%Y-%m")
# Check which period has usage data
from models.subscription_models import UsageSummary
if subscription_period:
# Check if data exists for subscription period
sub_data = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == subscription_period
).first()
if sub_data:
# Use subscription period (has data)
return {
"billing_period": subscription_period,
"lookup_periods": [subscription_period],
"period_start": subscription.current_period_start,
"period_end": subscription.current_period_end,
}
# No data for subscription period, check calendar period (backward compatibility)
if calendar_period != subscription_period:
cal_data = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == calendar_period
).first()
if cal_data:
logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {subscription_period} has no data)")
return {
"billing_period": calendar_period,
"lookup_periods": [calendar_period],
"period_start": None,
"period_end": None,
}
# No data in either period, use subscription period
return {
"billing_period": subscription_period,
"lookup_periods": [subscription_period],
"period_start": subscription.current_period_start,
"period_end": subscription.current_period_end,
}
# No subscription, check for any existing data
latest_summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id
).order_by(UsageSummary.billing_period.desc()).first()
if latest_summary:
logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period} for user {user_id}")
return {
"billing_period": latest_summary.billing_period,
"lookup_periods": [latest_summary.billing_period],
"period_start": None,
"period_end": None,
}
# Last fallback to calendar month for free tier / no subscription
return {
"billing_period": current_period,
"lookup_periods": [current_period],
"period_start": subscription.current_period_start if subscription else None,
"period_end": subscription.current_period_end if subscription else None,
"billing_period": calendar_period,
"lookup_periods": [calendar_period],
"period_start": None,
"period_end": None,
}
# Delegate to modular functions
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
"""Get comprehensive usage statistics for a user."""
return get_user_usage_stats(user_id, billing_period, self.db, self.pricing_service)
def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
"""Get ALL historical usage data aggregated across all billing periods."""
return get_all_historical_usage(user_id, self.db, self.pricing_service)
def get_current_period_usage(self, user_id: str) -> Dict[str, Any]:
"""Get current billing period usage with correct per-period limit percentages."""
return get_current_period_usage(user_id, self.db, self.pricing_service)
def get_usage_for_period(self, user_id: str, billing_period: str) -> Dict[str, Any]:
"""Get usage for a specific billing period."""
return get_usage_for_period(user_id, billing_period, self.db, self.pricing_service)
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
"""Get usage trends over time with self-healing from logs."""
return get_usage_trends(user_id, months, self.db)
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
"""Enforce usage limits before making an API call."""
return enforce_usage_limits(user_id, provider, tokens_requested, self.db, self.pricing_service)
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
"""Check if usage alerts should be sent."""
check_usage_alerts(user_id, provider, billing_period, self.db, self.pricing_service)
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
threshold: int, current_usage: int, limit: int,
billing_period: str):
"""Create a usage alert."""
create_usage_alert(user_id, provider, threshold, current_usage, limit, billing_period, self.db)
# Keep the track_api_usage method here as it's the core functionality
async def track_api_usage(self, user_id: str, provider: APIProvider,
endpoint: str, method: str, model_used: str = None,
tokens_input: int = 0, tokens_output: int = 0,
response_time: float = 0.0, status_code: int = 200,
request_size: int = None, response_size: int = None,
user_agent: str = None, ip_address: str = None,
error_message: str = None, retry_count: int = 0,
**kwargs) -> Dict[str, Any]:
endpoint: str, method: str, model_used: str = None,
tokens_input: int = 0, tokens_output: int = 0,
response_time: float = 0.0, status_code: int = 200,
request_size: int = None, response_size: int = None,
user_agent: str = None, ip_address: str = None,
error_message: str = None, retry_count: int = 0,
**kwargs) -> Dict[str, Any]:
"""Track an API usage event and update billing information."""
try:
@@ -165,394 +284,81 @@ class UsageTrackingService:
# Invalidate dashboard cache so header stats update immediately
try:
clear_dashboard_cache(user_id)
_clear_dashboard_cache_for_user(user_id)
except Exception as cache_err:
logger.debug(f"Could not clear dashboard cache: {cache_err}")
logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")
logger.warning(f"Failed to clear dashboard cache: {cache_err}")
return {
'usage_logged': True,
'cost': cost_data['cost_total'],
'tokens_used': (tokens_input or 0) + (tokens_output or 0),
'billing_period': billing_period
"success": True,
"cost": cost_data['cost_total'],
"tokens": (tokens_input or 0) + (tokens_output or 0),
"billing_period": billing_period
}
except Exception as e:
logger.error(f"Error tracking API usage: {str(e)}")
logger.error(f"Failed to track API usage: {e}")
self.db.rollback()
return {
'usage_logged': False,
'error': str(e)
"success": False,
"error": str(e)
}
async def _update_usage_summary(self, user_id: str, provider: APIProvider,
tokens_used: int, cost: float, billing_period: str,
response_time: float, is_error: bool):
"""Update the usage summary for a user."""
tokens_used: int, cost: float,
billing_period: str,
response_time: float = 0.0,
is_error: bool = False):
"""Update or create usage summary for the billing period."""
# Get or create usage summary
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
# Get or create summary
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
UsageSummary.billing_period == billing_period
).first()
if not summary:
logger.info(f"[UsageTracking] Creating new UsageSummary for user={user_id}, period={period_keys['billing_period']}")
summary = UsageSummary(
user_id=user_id,
billing_period=period_keys["billing_period"]
billing_period=billing_period,
usage_status=UsageStatus.ACTIVE,
total_calls=0,
total_tokens=0,
total_cost=0.0
)
self.db.add(summary)
else:
logger.debug(f"[UsageTracking] Found existing UsageSummary for user={user_id}, period={summary.billing_period}, calls={summary.total_calls}")
# Update provider-specific counters
# Update counts
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_used
summary.total_cost = (summary.total_cost or 0.0) + cost
# Update provider-specific counts
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
setattr(summary, f"{provider_name}_calls", current_calls + 1)
# Update token usage for LLM providers
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL, APIProvider.WAVESPEED]:
current_tokens = getattr(summary, f"{provider_name}_tokens", 0)
setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)
# Update provider-specific tokens
tokens_attr = f"{provider_name}_tokens"
if hasattr(summary, tokens_attr):
current_tokens = getattr(summary, tokens_attr, 0) or 0
setattr(summary, tokens_attr, current_tokens + tokens_used)
# Update cost
current_cost = getattr(summary, f"{provider_name}_cost", 0.0)
setattr(summary, f"{provider_name}_cost", current_cost + cost)
# Update provider-specific cost
cost_attr = f"{provider_name}_cost"
if hasattr(summary, cost_attr):
current_cost = getattr(summary, cost_attr, 0.0) or 0.0
setattr(summary, cost_attr, current_cost + cost)
# Update totals
summary.total_calls += 1
summary.total_tokens += tokens_used
summary.total_cost += cost
# Update response time (rolling average)
if response_time > 0:
current_avg = summary.avg_response_time or 0.0
current_calls = summary.total_calls or 1
summary.avg_response_time = ((current_avg * (current_calls - 1)) + response_time) / current_calls
# Update performance metrics
if summary.total_calls > 0:
# Update average response time
total_response_time = summary.avg_response_time * (summary.total_calls - 1) + response_time
summary.avg_response_time = total_response_time / summary.total_calls
# Update error rate
if is_error:
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100) + 1
summary.error_rate = (error_count / summary.total_calls) * 100
else:
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100)
summary.error_rate = (error_count / summary.total_calls) * 100
# Update usage status based on limits
await self._update_usage_status(summary)
# Update error rate
if is_error:
summary.error_count = (summary.error_count or 0) + 1
total_calls = summary.total_calls or 1
summary.error_rate = (summary.error_count / total_calls) * 100
summary.updated_at = datetime.utcnow()
async def _update_usage_status(self, summary: UsageSummary):
"""Update usage status based on subscription limits."""
limits = self.pricing_service.get_user_limits(summary.user_id)
if not limits:
return
# Check various limits and determine status
max_usage_percentage = 0.0
# Check cost limit
cost_limit = limits['limits'].get('monthly_cost', 0)
if cost_limit > 0:
cost_usage_pct = (summary.total_cost / cost_limit) * 100
max_usage_percentage = max(max_usage_percentage, cost_usage_pct)
# Check call limits for each provider
for provider in APIProvider:
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0:
call_usage_pct = (current_calls / call_limit) * 100
max_usage_percentage = max(max_usage_percentage, call_usage_pct)
# Update status based on highest usage percentage
if max_usage_percentage >= 100:
summary.usage_status = UsageStatus.LIMIT_REACHED
elif max_usage_percentage >= 80:
summary.usage_status = UsageStatus.WARNING
else:
summary.usage_status = UsageStatus.ACTIVE
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
"""Check if usage alerts should be sent."""
# Get current usage
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
).first()
if not summary:
return
# Get user limits
limits = self.pricing_service.get_user_limits(user_id)
if not limits:
return
# Check for alert thresholds (80%, 90%, 100%)
thresholds = [80, 90, 100]
for threshold in thresholds:
# Check if alert already sent for this threshold
existing_alert = self.db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.billing_period == billing_period,
UsageAlert.threshold_percentage == threshold,
UsageAlert.provider == provider,
UsageAlert.is_sent == True
).first()
if existing_alert:
continue
# Check if threshold is reached
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0:
usage_percentage = (current_calls / call_limit) * 100
if usage_percentage >= threshold:
await self._create_usage_alert(
user_id=user_id,
provider=provider,
threshold=threshold,
current_usage=current_calls,
limit=call_limit,
billing_period=billing_period
)
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
threshold: int, current_usage: int, limit: int,
billing_period: str):
"""Create a usage alert."""
# Determine alert type and severity
if threshold >= 100:
alert_type = "limit_reached"
severity = "error"
title = f"API Limit Reached - {provider.value.title()}"
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
elif threshold >= 90:
alert_type = "usage_warning"
severity = "warning"
title = f"API Usage Warning - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
else:
alert_type = "usage_warning"
severity = "info"
title = f"API Usage Notice - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
alert = UsageAlert(
user_id=user_id,
alert_type=alert_type,
threshold_percentage=threshold,
provider=provider,
title=title,
message=message,
severity=severity,
billing_period=billing_period
)
self.db.add(alert)
logger.info(f"Created usage alert for {user_id}: {title}")
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
"""Get comprehensive usage statistics for a user."""
if not user_id:
logger.error("get_user_usage_stats called without user_id")
raise ValueError("user_id is required")
requested_billing_period = billing_period
period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period)
billing_period = period_keys["billing_period"]
logger.debug(f"[get_user_usage_stats] user={user_id}, billing_period={billing_period}, lookup_periods={period_keys['lookup_periods']}")
# Get usage summary
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
).first()
if summary:
logger.debug(f"[get_user_usage_stats] Found summary: period={summary.billing_period}, calls={summary.total_calls}, cost={summary.total_cost}")
else:
logger.debug(f"[get_user_usage_stats] No summary found for user={user_id}, period={billing_period}")
# Get user limits
limits = self.pricing_service.get_user_limits(user_id)
# Get recent alerts
alerts = self.db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.billing_period == billing_period,
UsageAlert.is_read == False
).order_by(UsageAlert.created_at.desc()).limit(10).all()
if not summary:
# If no summary exists for current period, we should initialize it
# This handles the "start of month" case where a user logs in but hasn't made calls yet
if not requested_billing_period:
logger.info(f"Initializing empty UsageSummary for user {user_id} in period {billing_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=billing_period,
usage_status=UsageStatus.ACTIVE,
total_calls=0,
total_tokens=0,
total_cost=0.0
)
try:
self.db.add(summary)
self.db.commit()
self.db.refresh(summary)
except Exception as e:
logger.error(f"Failed to initialize summary: {e}")
self.db.rollback()
# Fallback to zero-struct return if DB write fails
pass
if not summary: # Still no summary after attempt
return build_empty_usage_response(
billing_period=billing_period,
limits=limits,
providers=APIProvider,
)
# Provider breakdown - calculate costs first, then use for percentages
# Only include Gemini and HuggingFace (HuggingFace is stored under MISTRAL enum)
provider_breakdown, resolved_costs, core_counts = build_provider_breakdown(
db=self.db,
user_id=user_id,
billing_period=billing_period,
summary=summary,
)
summary_total_cost = summary.total_cost or 0.0
calculated_total_cost, final_total_cost = calculate_final_total_cost(
summary_total_cost=summary_total_cost,
resolved_costs=resolved_costs,
)
maybe_persist_reconciled_costs(
db=self.db,
summary=summary,
summary_total_cost=summary_total_cost,
calculated_total_cost=calculated_total_cost,
final_total_cost=final_total_cost,
resolved_costs=resolved_costs,
)
# Calculate usage percentages - only for Gemini and HuggingFace
# Use the calculated costs for accurate percentages
usage_percentages = build_default_usage_percentages(APIProvider)
if limits:
# Gemini
gemini_call_limit = limits['limits'].get("gemini_calls", 0) or 0
if gemini_call_limit > 0:
usage_percentages['gemini_calls'] = (core_counts['gemini_calls'] / gemini_call_limit) * 100
# HuggingFace (stored as mistral in database)
mistral_call_limit = limits['limits'].get("mistral_calls", 0) or 0
if mistral_call_limit > 0:
usage_percentages['mistral_calls'] = (core_counts['mistral_calls'] / mistral_call_limit) * 100
# Cost usage percentage - use final_total_cost (calculated from logs if needed)
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
if cost_limit > 0:
usage_percentages['cost'] = (final_total_cost / cost_limit) * 100
return {
'billing_period': billing_period,
'usage_status': summary.usage_status.value if hasattr(summary.usage_status, 'value') else str(summary.usage_status),
'total_calls': summary.total_calls or 0,
'total_tokens': summary.total_tokens or 0,
'total_cost': final_total_cost,
'avg_response_time': summary.avg_response_time or 0.0,
'error_rate': summary.error_rate or 0.0,
'limits': limits,
'provider_breakdown': provider_breakdown,
'alerts': [
{
'id': alert.id,
'type': alert.alert_type,
'title': alert.title,
'message': alert.message,
'severity': alert.severity,
'created_at': alert.created_at.isoformat()
}
for alert in alerts
],
'usage_percentages': usage_percentages,
'last_updated': summary.updated_at.isoformat()
}
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
"""Get usage trends over time with self-healing from logs."""
periods = build_billing_periods(months)
summary_dict = query_usage_summaries(self.db, user_id, periods)
self_heal_summaries_from_logs(self.db, user_id, periods, summary_dict)
return build_usage_trends_response(periods, summary_dict)
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
"""Enforce usage limits before making an API call."""
# Check short-lived cache first (30s)
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self._enforce_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
return tuple(cached['result']) # type: ignore
result = self.pricing_service.check_usage_limits(
user_id=user_id,
provider=provider,
tokens_requested=tokens_requested
)
self._enforce_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
period_keys = self._get_authoritative_billing_period_keys(user_id)
billing_period = period_keys["billing_period"]
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
).first()
if not summary:
return {"reset": False, "reason": "no_summary"}
try:
reset_usage_summary_counters(summary)
self.db.commit()
# Invalidate dashboard cache so header stats update after reset
try:
clear_dashboard_cache(user_id)
except Exception as cache_err:
logger.debug(f"Could not clear dashboard cache: {cache_err}")
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
return {"reset": True, "counters_reset": True}
except Exception as e:
self.db.rollback()
logger.error(f"Error resetting usage status: {e}")
return {"reset": False, "error": str(e)}