"""
Test Script for Subscription System

Tests the core functionality of the usage-based subscription system.

Run this file directly. It executes four checks in order (database tables,
pricing service, usage tracking, limit enforcement), logs a summary, and
exits with status 1 if any check fails.
"""

import sys
import os
from pathlib import Path
import asyncio
import json

# Add the backend directory to Python path. This must happen before the
# project-local imports below so that `services` and `models` resolve.
backend_dir = Path(__file__).parent
sys.path.insert(0, str(backend_dir))

from sqlalchemy.orm import sessionmaker
from loguru import logger

from services.database import engine
from services.pricing_service import PricingService
from services.usage_tracking_service import UsageTrackingService
from models.subscription_models import APIProvider, SubscriptionTier

# One session factory shared by every test instead of rebuilding it per test.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Substrings that identify subscription-related tables (matched
# case-insensitively against the table names reported by the database).
_TABLE_KEYWORDS = ("subscription", "usage", "pricing", "billing")

# Tables that must exist for the subscription system to be usable.
_EXPECTED_TABLES = [
    'api_provider_pricing',
    'api_usage_logs',
    'billing_history',
    'subscription_plans',
    'usage_alerts',
    'usage_summaries',
    'user_subscriptions',
]


def _open_session():
    """Return a new database session; the caller is responsible for closing it."""
    return SessionLocal()


async def test_pricing_service() -> bool:
    """Test the pricing service functionality.

    Exercises cost calculation, user-limit lookup and usage-limit checking
    on ``PricingService`` and logs each result.

    Returns:
        True if every call completed without raising, False otherwise.
    """
    logger.info("🧪 Testing Pricing Service...")

    db = _open_session()
    try:
        pricing_service = PricingService(db)

        # Test cost calculation
        cost_data = pricing_service.calculate_api_cost(
            provider=APIProvider.GEMINI,
            model_name="gemini-2.5-flash",
            tokens_input=1000,
            tokens_output=500,
            request_count=1,
        )
        logger.info(f"✅ Cost calculation: {cost_data}")

        # Test user limits
        limits = pricing_service.get_user_limits("test_user")
        logger.info(f"✅ User limits: {limits}")

        # Test usage limit checking
        can_proceed, message, usage_info = pricing_service.check_usage_limits(
            user_id="test_user",
            provider=APIProvider.GEMINI,
            tokens_requested=100,
        )
        logger.info(f"✅ Usage check: {can_proceed} - {message}")
        logger.info(f" Usage info: {usage_info}")

        return True
    except Exception as e:
        # logger.exception keeps the traceback, which a bare str(e) loses.
        logger.exception(f"❌ Pricing service test failed: {e}")
        return False
    finally:
        db.close()


async def test_usage_tracking() -> bool:
    """Test the usage tracking service.

    Records one API call through ``UsageTrackingService`` and then reads
    back the usage stats and usage trends for the same user.

    Returns:
        True if every call completed without raising, False otherwise.
    """
    logger.info("🧪 Testing Usage Tracking Service...")

    db = _open_session()
    try:
        usage_service = UsageTrackingService(db)

        # Test tracking an API usage
        result = await usage_service.track_api_usage(
            user_id="test_user",
            provider=APIProvider.GEMINI,
            endpoint="/api/generate",
            method="POST",
            model_used="gemini-2.5-flash",
            tokens_input=500,
            tokens_output=300,
            response_time=1.5,
            status_code=200,
        )
        logger.info(f"✅ Usage tracking result: {result}")

        # default=str is for log display only: it keeps the test from failing
        # just because the payload holds a value json cannot encode natively.
        # Test getting usage stats
        stats = usage_service.get_user_usage_stats("test_user")
        logger.info(f"✅ Usage stats: {json.dumps(stats, indent=2, default=str)}")

        # Test usage trends
        trends = usage_service.get_usage_trends("test_user", 3)
        logger.info(f"✅ Usage trends: {json.dumps(trends, indent=2, default=str)}")

        return True
    except Exception as e:
        logger.exception(f"❌ Usage tracking test failed: {e}")
        return False
    finally:
        db.close()


async def test_limit_enforcement() -> bool:
    """Test usage limit enforcement.

    Tracks five API calls for a dedicated user and then asks the service
    whether a further 5000-token request may proceed. The decision itself
    is only logged; it does not affect the returned value.

    Returns:
        True if every call completed without raising, False otherwise.
    """
    logger.info("🧪 Testing Limit Enforcement...")

    db = _open_session()
    try:
        usage_service = UsageTrackingService(db)

        # Test multiple API calls to approach limits
        for i in range(5):
            result = await usage_service.track_api_usage(
                user_id="test_user_limits",
                provider=APIProvider.GEMINI,
                endpoint="/api/generate",
                method="POST",
                model_used="gemini-2.5-flash",
                tokens_input=1000,
                tokens_output=800,
                response_time=2.0,
                status_code=200,
            )
            logger.info(f"Call {i+1}: {result}")

        # Check if limits are being enforced
        can_proceed, message, usage_info = await usage_service.enforce_usage_limits(
            user_id="test_user_limits",
            provider=APIProvider.GEMINI,
            tokens_requested=5000,
        )
        logger.info(f"✅ Limit enforcement: {can_proceed} - {message}")
        logger.info(f" Usage info: {usage_info}")

        return True
    except Exception as e:
        logger.exception(f"❌ Limit enforcement test failed: {e}")
        return False
    finally:
        db.close()


def test_database_tables() -> bool:
    """Test that all subscription tables exist.

    Lists the database tables whose names contain one of the subscription
    keywords, verifies that every expected table is present, and logs the
    row count of the two seed tables.

    Returns:
        True if no expected table is missing and the counts could be read,
        False if a table is missing or any step raised.
    """
    logger.info("🧪 Testing Database Tables...")

    try:
        from sqlalchemy import inspect, text

        # The inspector works on any SQLAlchemy dialect, unlike a direct
        # query against SQLite's sqlite_master catalogue.
        all_tables = inspect(engine).get_table_names()
        found_tables = sorted(
            name
            for name in all_tables
            if any(keyword in name.lower() for keyword in _TABLE_KEYWORDS)
        )
        logger.info(f"Found tables: {found_tables}")

        missing_tables = [t for t in _EXPECTED_TABLES if t not in found_tables]
        if missing_tables:
            logger.error(f"❌ Missing tables: {missing_tables}")
            return False

        # Check table data
        with engine.connect() as conn:
            # Table names come from the fixed tuple below, never from input,
            # so interpolating them into the statement is safe.
            for table in ('subscription_plans', 'api_provider_pricing'):
                count = conn.execute(text(f"SELECT COUNT(*) FROM {table}")).scalar()
                logger.info(f"✅ {table}: {count} records")

        return True
    except Exception as e:
        logger.exception(f"❌ Database tables test failed: {e}")
        return False


async def run_comprehensive_test() -> bool:
    """Run comprehensive test suite.

    Runs the four checks in a fixed order, logs a PASS/FAIL line per check
    plus an overall tally, and prints follow-up instructions when all pass.

    Returns:
        True if every check passed, False otherwise.
    """
    logger.info("🚀 Starting Subscription System Comprehensive Test")
    logger.info("="*60)

    test_results = {}

    # Test 1: Database Tables
    logger.info("\n1. Testing Database Tables...")
    test_results['database_tables'] = test_database_tables()

    # Test 2: Pricing Service
    logger.info("\n2. Testing Pricing Service...")
    test_results['pricing_service'] = await test_pricing_service()

    # Test 3: Usage Tracking
    logger.info("\n3. Testing Usage Tracking...")
    test_results['usage_tracking'] = await test_usage_tracking()

    # Test 4: Limit Enforcement
    logger.info("\n4. Testing Limit Enforcement...")
    test_results['limit_enforcement'] = await test_limit_enforcement()

    # Summary
    logger.info("\n" + "="*60)
    logger.info("TEST RESULTS SUMMARY")
    logger.info("="*60)

    passed = sum(1 for result in test_results.values() if result)
    total = len(test_results)

    for test_name, result in test_results.items():
        status = "✅ PASS" if result else "❌ FAIL"
        logger.info(f"{test_name.upper().replace('_', ' ')}: {status}")

    logger.info(f"\nOverall: {passed}/{total} tests passed")

    if passed != total:
        logger.error("❌ Some tests failed. Please check the errors above.")
        return False

    logger.info("🎉 All tests passed! Subscription system is ready.")

    logger.info("\n" + "="*60)
    logger.info("NEXT STEPS:")
    logger.info("="*60)
    logger.info("1. Start the FastAPI server:")
    logger.info(" cd backend && python start_alwrity_backend.py")
    logger.info("\n2. Test the API endpoints:")
    logger.info(" GET http://localhost:8000/api/subscription/plans")
    logger.info(" GET http://localhost:8000/api/subscription/pricing")
    logger.info(" GET http://localhost:8000/api/subscription/usage/test_user")
    logger.info("\n3. Integrate with your frontend dashboard")
    logger.info("4. Set up user authentication/identification")
    logger.info("5. Configure payment processing (Stripe, etc.)")
    logger.info("="*60)

    return True


if __name__ == "__main__":
    # Run the comprehensive test
    success = asyncio.run(run_comprehensive_test())

    # A non-zero exit status lets shells and CI detect the failure.
    if not success:
        sys.exit(1)

    logger.info("✅ Test completed successfully!")