From 557f700f683c5a179f2ec903080299b3ea1b41a0 Mon Sep 17 00:00:00 2001 From: ajaysi Date: Sun, 29 Mar 2026 12:50:50 +0530 Subject: [PATCH] fix: Resolve APIProvider enum mismatch causing dashboard errors - Fix import path in subscriptions.py (pricing_service location) - Add values_callable to APIUsageLog.provider enum column - Normalize provider values to lowercase in usage trends helpers - Add migration script for existing databases --- .../api/subscription/routes/subscriptions.py | 4 +- backend/models/subscription_models.py | 2 +- backend/scripts/migrate_api_provider_enum.py | 143 ++++++++++++++++++ .../usage_trends_helpers.py | 36 ++++- 4 files changed, 176 insertions(+), 9 deletions(-) create mode 100644 backend/scripts/migrate_api_provider_enum.py diff --git a/backend/api/subscription/routes/subscriptions.py b/backend/api/subscription/routes/subscriptions.py index 10f44034..ef0df4a2 100644 --- a/backend/api/subscription/routes/subscriptions.py +++ b/backend/api/subscription/routes/subscriptions.py @@ -170,7 +170,7 @@ async def get_subscription_status( if getattr(subscription, 'auto_renew', False): # advance period try: - from services.pricing_service import PricingService + from services.subscription.pricing_service import PricingService pricing = PricingService(db) # reuse helper to ensure current pricing._ensure_subscription_current(subscription) @@ -245,7 +245,7 @@ async def get_subscription_status( if subscription.current_period_end < now: if getattr(subscription, 'auto_renew', False): try: - from services.pricing_service import PricingService + from services.subscription.pricing_service import PricingService pricing = PricingService(db) pricing._ensure_subscription_current(subscription) except Exception as e2: diff --git a/backend/models/subscription_models.py b/backend/models/subscription_models.py index 2541f678..482e2c04 100644 --- a/backend/models/subscription_models.py +++ b/backend/models/subscription_models.py @@ -155,7 +155,7 @@ class APIUsageLog(Base): user_id = Column(String(100), nullable=False) # API Details - provider = Column(Enum(APIProvider), nullable=False) + provider = Column(Enum(APIProvider, values_callable=lambda obj: [e.value for e in obj]), nullable=False) endpoint = Column(String(200), nullable=False) method = Column(String(10), nullable=False) model_used = Column(String(100), nullable=True) # e.g., "gemini-2.5-flash" diff --git a/backend/scripts/migrate_api_provider_enum.py b/backend/scripts/migrate_api_provider_enum.py new file mode 100644 index 00000000..6f2087d6 --- /dev/null +++ b/backend/scripts/migrate_api_provider_enum.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +""" +Migration script to normalize APIProvider enum values to lowercase. + +This fixes the issue where the database has uppercase values like "VIDEO", "MISTRAL" +but the enum expects lowercase values like "video", "mistral". + +Run this script once to migrate existing data. +""" + +import os +import sys +from pathlib import Path + +# Add backend to path +backend_dir = Path(__file__).parent +sys.path.insert(0, str(backend_dir)) + +# Load env first +from dotenv import load_dotenv +load_dotenv(backend_dir / '.env') + +from loguru import logger +import sqlite3 +from glob import glob + +# Provider mapping: uppercase -> lowercase +PROVIDER_MAP = { + 'GEMINI': 'gemini', + 'OPENAI': 'openai', + 'ANTHROPIC': 'anthropic', + 'MISTRAL': 'mistral', + 'WAVESPEED': 'wavespeed', + 'TAVILY': 'tavily', + 'SERPER': 'serper', + 'METAPHOR': 'metaphor', + 'FIRECRAWL': 'firecrawl', + 'STABILITY': 'stability', + 'EXA': 'exa', + 'VIDEO': 'video', + 'IMAGE_EDIT': 'image_edit', + 'AUDIO': 'audio', +} + +def normalize_provider_value(value: str) -> str: + """Convert provider value to lowercase if it's uppercase.""" + if not value: + return value + upper_value = value.upper() + if upper_value in PROVIDER_MAP: + return PROVIDER_MAP[upper_value] + # If already lowercase, return as-is + return value + +def migrate_database(db_path: str) -> tuple[int, int]: + """Migrate a single database file. Returns (total_rows, updated_rows).""" + try: + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # Check if api_usage_logs table exists + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='api_usage_logs'") + if not cursor.fetchone(): + conn.close() + return 0, 0 + + # Get all unique provider values + cursor.execute("SELECT DISTINCT provider FROM api_usage_logs") + unique_providers = [row[0] for row in cursor.fetchall()] + + total_rows = 0 + updated_rows = 0 + + # Count total rows + cursor.execute("SELECT COUNT(*) FROM api_usage_logs") + total_rows = cursor.fetchone()[0] + + # Update each provider value that needs normalization + for provider in unique_providers: + normalized = normalize_provider_value(provider) + if provider != normalized: + cursor.execute( + "UPDATE api_usage_logs SET provider = ? WHERE provider = ?", + (normalized, provider) + ) + count = cursor.rowcount + updated_rows += count + logger.info(f" {db_path}: Updated {count} rows with provider '{provider}' -> '{normalized}'") + + conn.commit() + conn.close() + + return total_rows, updated_rows + + except Exception as e: + logger.error(f"Error migrating {db_path}: {e}") + return 0, 0 + +def main(): + """Main migration function.""" + logger.info("=" * 60) + logger.info("APIProvider Enum Normalization Migration") + logger.info("=" * 60) + + # Workspace directory where user databases are stored + workspace_dir = backend_dir / "workspace" + + # Find all database files in workspaces + all_dbs = set() + + if workspace_dir.exists(): + for workspace in workspace_dir.iterdir(): + if workspace.is_dir(): + # Look for database files in workspace subdirectories (database/, or root) + for db_pattern in ["**/*.db", "**/alwrity*.db"]: + all_dbs.update(workspace.glob(db_pattern)) + + logger.info(f"Found {len(all_dbs)} database files to check") + + total_dbs = 0 + total_rows = 0 + total_updated = 0 + + for db_path in sorted(all_dbs): + total_dbs += 1 + rows, updated = migrate_database(str(db_path)) + total_rows += rows + total_updated += updated + + logger.info("=" * 60) + logger.info(f"Migration complete!") + logger.info(f" Databases checked: {total_dbs}") + logger.info(f" Total rows: {total_rows}") + logger.info(f" Rows updated: {total_updated}") + logger.info("=" * 60) + + if total_updated > 0: + logger.warning("Please restart the backend server to ensure changes take effect.") + else: + logger.info("No updates needed - all provider values are already normalized.") + +if __name__ == "__main__": + main() diff --git a/backend/services/subscription/usage_tracking_helpers/usage_trends_helpers.py b/backend/services/subscription/usage_tracking_helpers/usage_trends_helpers.py index 8f2b0c7a..b2c00b8b 100644 --- a/backend/services/subscription/usage_tracking_helpers/usage_trends_helpers.py +++ b/backend/services/subscription/usage_tracking_helpers/usage_trends_helpers.py @@ -11,6 +11,26 @@ from sqlalchemy import func from models.subscription_models import APIProvider, APIUsageLog, UsageStatus, UsageSummary +def _normalize_provider_name(provider_input: Any) -> str | None: + """Safely extract provider name from enum or string, handling both name and value formats.""" + valid_providers = {'gemini', 'openai', 'anthropic', 'mistral', 'wavespeed', + 'tavily', 'serper', 'metaphor', 'firecrawl', 'stability', + 'exa', 'video', 'image_edit', 'audio'} + + try: + if hasattr(provider_input, "value"): + return provider_input.value + elif isinstance(provider_input, str): + name = provider_input.lower() + if "." in name: + name = name.split(".")[-1].lower() + return name + else: + return str(provider_input).lower() + except Exception: + return None + + def build_billing_periods(months: int) -> List[str]: """Build billing period keys (YYYY-MM) from oldest to newest.""" end_date = datetime.now() @@ -35,27 +55,31 @@ def query_usage_summaries(db: Any, user_id: str, periods: List[str]) -> Dict[str def self_heal_summaries_from_logs(db: Any, user_id: str, periods: List[str], summary_dict: Dict[str, Any]) -> None: """Backfill/create usage summaries from aggregated API usage logs.""" try: + from sqlalchemy import cast, String + log_stats = ( db.query( APIUsageLog.billing_period, - APIUsageLog.provider, + cast(APIUsageLog.provider, String).label("provider"), func.count(APIUsageLog.id).label("calls"), func.sum(APIUsageLog.cost_total).label("cost"), func.sum(APIUsageLog.tokens_total).label("tokens"), ) .filter(APIUsageLog.user_id == user_id, APIUsageLog.billing_period.in_(periods)) - .group_by(APIUsageLog.billing_period, APIUsageLog.provider) + .group_by(APIUsageLog.billing_period, cast(APIUsageLog.provider, String)) .all() ) log_data_by_period: Dict[str, Dict[str, Dict[str, float | int]]] = {} - for period, provider_enum, calls, cost, tokens in log_stats: + + for period, provider_str, calls, cost, tokens in log_stats: if period not in log_data_by_period: log_data_by_period[period] = {} - provider_name = provider_enum.value if hasattr(provider_enum, "value") else str(provider_enum).lower() - if "." in provider_name: - provider_name = provider_name.split(".")[-1].lower() + provider_name = _normalize_provider_name(provider_str) + if not provider_name: + logger.warning(f"[UsageStats] Could not normalize provider: '{provider_str}', skipping") + continue if provider_name not in log_data_by_period[period]: log_data_by_period[period][provider_name] = {"calls": 0, "cost": 0.0, "tokens": 0}