Files
ALwrity/backend/scripts/migrate_api_provider_enum.py
ajaysi 557f700f68 fix: Resolve APIProvider enum mismatch causing dashboard errors
- Fix import path in subscriptions.py (pricing_service location)
- Add values_callable to APIUsageLog.provider enum column
- Normalize provider values to lowercase in usage trends helpers
- Add migration script for existing databases
2026-03-29 12:50:50 +05:30

144 lines
4.4 KiB
Python

#!/usr/bin/env python3
"""
Migration script to normalize APIProvider enum values to lowercase.
This fixes the issue where the database has uppercase values like "VIDEO", "MISTRAL"
but the enum expects lowercase values like "video", "mistral".
Run this script once to migrate existing data.
"""
import os
import sys
from pathlib import Path
# Add backend to path
backend_dir = Path(__file__).parent
sys.path.insert(0, str(backend_dir))
# Load env first
from dotenv import load_dotenv
load_dotenv(backend_dir / '.env')
from loguru import logger
import sqlite3
from glob import glob
# Provider mapping: uppercase -> lowercase
PROVIDER_MAP = {
'GEMINI': 'gemini',
'OPENAI': 'openai',
'ANTHROPIC': 'anthropic',
'MISTRAL': 'mistral',
'WAVESPEED': 'wavespeed',
'TAVILY': 'tavily',
'SERPER': 'serper',
'METAPHOR': 'metaphor',
'FIRECRAWL': 'firecrawl',
'STABILITY': 'stability',
'EXA': 'exa',
'VIDEO': 'video',
'IMAGE_EDIT': 'image_edit',
'AUDIO': 'audio',
}
def normalize_provider_value(value: str) -> str:
"""Convert provider value to lowercase if it's uppercase."""
if not value:
return value
upper_value = value.upper()
if upper_value in PROVIDER_MAP:
return PROVIDER_MAP[upper_value]
# If already lowercase, return as-is
return value
def migrate_database(db_path: str) -> tuple[int, int]:
"""Migrate a single database file. Returns (total_rows, updated_rows)."""
try:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Check if api_usage_logs table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='api_usage_logs'")
if not cursor.fetchone():
conn.close()
return 0, 0
# Get all unique provider values
cursor.execute("SELECT DISTINCT provider FROM api_usage_logs")
unique_providers = [row[0] for row in cursor.fetchall()]
total_rows = 0
updated_rows = 0
# Count total rows
cursor.execute("SELECT COUNT(*) FROM api_usage_logs")
total_rows = cursor.fetchone()[0]
# Update each provider value that needs normalization
for provider in unique_providers:
normalized = normalize_provider_value(provider)
if provider != normalized:
cursor.execute(
"UPDATE api_usage_logs SET provider = ? WHERE provider = ?",
(normalized, provider)
)
count = cursor.rowcount
updated_rows += count
logger.info(f" {db_path}: Updated {count} rows with provider '{provider}' -> '{normalized}'")
conn.commit()
conn.close()
return total_rows, updated_rows
except Exception as e:
logger.error(f"Error migrating {db_path}: {e}")
return 0, 0
def main():
"""Main migration function."""
logger.info("=" * 60)
logger.info("APIProvider Enum Normalization Migration")
logger.info("=" * 60)
# Workspace directory where user databases are stored
workspace_dir = backend_dir / "workspace"
# Find all database files in workspaces
all_dbs = set()
if workspace_dir.exists():
for workspace in workspace_dir.iterdir():
if workspace.is_dir():
# Look for database files in workspace subdirectories (database/, or root)
for db_pattern in ["**/*.db", "**/alwrity*.db"]:
all_dbs.update(workspace.glob(db_pattern))
logger.info(f"Found {len(all_dbs)} database files to check")
total_dbs = 0
total_rows = 0
total_updated = 0
for db_path in sorted(all_dbs):
total_dbs += 1
rows, updated = migrate_database(str(db_path))
total_rows += rows
total_updated += updated
logger.info("=" * 60)
logger.info(f"Migration complete!")
logger.info(f" Databases checked: {total_dbs}")
logger.info(f" Total rows: {total_rows}")
logger.info(f" Rows updated: {total_updated}")
logger.info("=" * 60)
if total_updated > 0:
logger.warning("Please restart the backend server to ensure changes take effect.")
else:
logger.info("No updates needed - all provider values are already normalized.")
if __name__ == "__main__":
main()