Base code
This commit is contained in:
276
backend/test/test_subscription_system.py
Normal file
276
backend/test/test_subscription_system.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Test Script for Subscription System
|
||||
Tests the core functionality of the usage-based subscription system.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from loguru import logger
|
||||
|
||||
from services.database import engine
|
||||
from services.pricing_service import PricingService
|
||||
from services.usage_tracking_service import UsageTrackingService
|
||||
from models.subscription_models import APIProvider, SubscriptionTier
|
||||
|
||||
async def test_pricing_service():
|
||||
"""Test the pricing service functionality."""
|
||||
|
||||
logger.info("🧪 Testing Pricing Service...")
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Test cost calculation
|
||||
cost_data = pricing_service.calculate_api_cost(
|
||||
provider=APIProvider.GEMINI,
|
||||
model_name="gemini-2.5-flash",
|
||||
tokens_input=1000,
|
||||
tokens_output=500,
|
||||
request_count=1
|
||||
)
|
||||
|
||||
logger.info(f"✅ Cost calculation: {cost_data}")
|
||||
|
||||
# Test user limits
|
||||
limits = pricing_service.get_user_limits("test_user")
|
||||
logger.info(f"✅ User limits: {limits}")
|
||||
|
||||
# Test usage limit checking
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id="test_user",
|
||||
provider=APIProvider.GEMINI,
|
||||
tokens_requested=100
|
||||
)
|
||||
|
||||
logger.info(f"✅ Usage check: {can_proceed} - {message}")
|
||||
logger.info(f" Usage info: {usage_info}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Pricing service test failed: {e}")
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def test_usage_tracking():
|
||||
"""Test the usage tracking service."""
|
||||
|
||||
logger.info("🧪 Testing Usage Tracking Service...")
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
usage_service = UsageTrackingService(db)
|
||||
|
||||
# Test tracking an API usage
|
||||
result = await usage_service.track_api_usage(
|
||||
user_id="test_user",
|
||||
provider=APIProvider.GEMINI,
|
||||
endpoint="/api/generate",
|
||||
method="POST",
|
||||
model_used="gemini-2.5-flash",
|
||||
tokens_input=500,
|
||||
tokens_output=300,
|
||||
response_time=1.5,
|
||||
status_code=200
|
||||
)
|
||||
|
||||
logger.info(f"✅ Usage tracking result: {result}")
|
||||
|
||||
# Test getting usage stats
|
||||
stats = usage_service.get_user_usage_stats("test_user")
|
||||
logger.info(f"✅ Usage stats: {json.dumps(stats, indent=2)}")
|
||||
|
||||
# Test usage trends
|
||||
trends = usage_service.get_usage_trends("test_user", 3)
|
||||
logger.info(f"✅ Usage trends: {json.dumps(trends, indent=2)}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Usage tracking test failed: {e}")
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def test_limit_enforcement():
|
||||
"""Test usage limit enforcement."""
|
||||
|
||||
logger.info("🧪 Testing Limit Enforcement...")
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
usage_service = UsageTrackingService(db)
|
||||
|
||||
# Test multiple API calls to approach limits
|
||||
for i in range(5):
|
||||
result = await usage_service.track_api_usage(
|
||||
user_id="test_user_limits",
|
||||
provider=APIProvider.GEMINI,
|
||||
endpoint="/api/generate",
|
||||
method="POST",
|
||||
model_used="gemini-2.5-flash",
|
||||
tokens_input=1000,
|
||||
tokens_output=800,
|
||||
response_time=2.0,
|
||||
status_code=200
|
||||
)
|
||||
logger.info(f"Call {i+1}: {result}")
|
||||
|
||||
# Check if limits are being enforced
|
||||
can_proceed, message, usage_info = await usage_service.enforce_usage_limits(
|
||||
user_id="test_user_limits",
|
||||
provider=APIProvider.GEMINI,
|
||||
tokens_requested=5000
|
||||
)
|
||||
|
||||
logger.info(f"✅ Limit enforcement: {can_proceed} - {message}")
|
||||
logger.info(f" Usage info: {usage_info}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Limit enforcement test failed: {e}")
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def test_database_tables():
|
||||
"""Test that all subscription tables exist."""
|
||||
|
||||
logger.info("🧪 Testing Database Tables...")
|
||||
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Check for subscription tables
|
||||
tables_query = text("""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND (
|
||||
name LIKE '%subscription%' OR
|
||||
name LIKE '%usage%' OR
|
||||
name LIKE '%pricing%' OR
|
||||
name LIKE '%billing%'
|
||||
)
|
||||
ORDER BY name
|
||||
""")
|
||||
|
||||
result = conn.execute(tables_query)
|
||||
tables = result.fetchall()
|
||||
|
||||
expected_tables = [
|
||||
'api_provider_pricing',
|
||||
'api_usage_logs',
|
||||
'billing_history',
|
||||
'subscription_plans',
|
||||
'usage_alerts',
|
||||
'usage_summaries',
|
||||
'user_subscriptions'
|
||||
]
|
||||
|
||||
found_tables = [t[0] for t in tables]
|
||||
logger.info(f"Found tables: {found_tables}")
|
||||
|
||||
missing_tables = [t for t in expected_tables if t not in found_tables]
|
||||
if missing_tables:
|
||||
logger.error(f"❌ Missing tables: {missing_tables}")
|
||||
return False
|
||||
|
||||
# Check table data
|
||||
for table in ['subscription_plans', 'api_provider_pricing']:
|
||||
count_query = text(f"SELECT COUNT(*) FROM {table}")
|
||||
result = conn.execute(count_query)
|
||||
count = result.fetchone()[0]
|
||||
logger.info(f"✅ {table}: {count} records")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Database tables test failed: {e}")
|
||||
return False
|
||||
|
||||
async def run_comprehensive_test():
|
||||
"""Run comprehensive test suite."""
|
||||
|
||||
logger.info("🚀 Starting Subscription System Comprehensive Test")
|
||||
logger.info("="*60)
|
||||
|
||||
test_results = {}
|
||||
|
||||
# Test 1: Database Tables
|
||||
logger.info("\n1. Testing Database Tables...")
|
||||
test_results['database_tables'] = test_database_tables()
|
||||
|
||||
# Test 2: Pricing Service
|
||||
logger.info("\n2. Testing Pricing Service...")
|
||||
test_results['pricing_service'] = await test_pricing_service()
|
||||
|
||||
# Test 3: Usage Tracking
|
||||
logger.info("\n3. Testing Usage Tracking...")
|
||||
test_results['usage_tracking'] = await test_usage_tracking()
|
||||
|
||||
# Test 4: Limit Enforcement
|
||||
logger.info("\n4. Testing Limit Enforcement...")
|
||||
test_results['limit_enforcement'] = await test_limit_enforcement()
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("TEST RESULTS SUMMARY")
|
||||
logger.info("="*60)
|
||||
|
||||
passed = sum(1 for result in test_results.values() if result)
|
||||
total = len(test_results)
|
||||
|
||||
for test_name, result in test_results.items():
|
||||
status = "✅ PASS" if result else "❌ FAIL"
|
||||
logger.info(f"{test_name.upper().replace('_', ' ')}: {status}")
|
||||
|
||||
logger.info(f"\nOverall: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
logger.info("🎉 All tests passed! Subscription system is ready.")
|
||||
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("NEXT STEPS:")
|
||||
logger.info("="*60)
|
||||
logger.info("1. Start the FastAPI server:")
|
||||
logger.info(" cd backend && python start_alwrity_backend.py")
|
||||
logger.info("\n2. Test the API endpoints:")
|
||||
logger.info(" GET http://localhost:8000/api/subscription/plans")
|
||||
logger.info(" GET http://localhost:8000/api/subscription/pricing")
|
||||
logger.info(" GET http://localhost:8000/api/subscription/usage/test_user")
|
||||
logger.info("\n3. Integrate with your frontend dashboard")
|
||||
logger.info("4. Set up user authentication/identification")
|
||||
logger.info("5. Configure payment processing (Stripe, etc.)")
|
||||
logger.info("="*60)
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.error("❌ Some tests failed. Please check the errors above.")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the comprehensive test
|
||||
success = asyncio.run(run_comprehensive_test())
|
||||
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("✅ Test completed successfully!")
|
||||
Reference in New Issue
Block a user