Move Stripe plan price mapping to env with startup validation

This commit is contained in:
ي
2026-03-04 20:41:47 +05:30
parent 2318fd8a48
commit fc96e1218a
4 changed files with 120 additions and 12 deletions

View File

@@ -199,6 +199,26 @@ You can customize the server behavior with these environment variables:
- `PORT`: Server port (default: 8000)
- `RELOAD`: Enable auto-reload (default: true)
Subscription billing (Stripe) variables used in deployment:
- `STRIPE_SECRET_KEY`: Stripe API secret key (`sk_test_...` for test, `sk_live_...` for live).
- `STRIPE_WEBHOOK_SECRET`: Stripe webhook signing secret for `/api/subscription/webhook`.
- `STRIPE_MODE`: Stripe mode selector (`test` or `live`). Recommended to set explicitly in each environment.
- `STRIPE_PLAN_PRICE_MAPPING_TEST`: JSON mapping for test mode price IDs.
- `STRIPE_PLAN_PRICE_MAPPING_LIVE`: JSON mapping for live mode price IDs.
- `STRIPE_PLAN_PRICE_MAPPING`: Optional fallback JSON mapping used when mode-specific variable is not provided.
Required mapping keys validated at startup:
- `basic.monthly`
- `pro.monthly`
Example mapping value:
```json
{"basic":{"monthly":"price_123"},"pro":{"monthly":"price_456"}}
```
Example:
```bash
HOST=127.0.0.1 PORT=8080 python start_alwrity_backend.py

View File

@@ -1,3 +1,4 @@
import json
import os
import stripe
from typing import Optional, Dict, Any
@@ -8,11 +9,84 @@ from models.subscription_models import UserSubscription, SubscriptionPlan, Subsc
from services.subscription.pricing_service import PricingService
from datetime import datetime
STRIPE_PLAN_PRICE_MAPPING = {
(SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value): "price_1T2lWHR2EuR7zQJepLIVQ1EJ",
(SubscriptionTier.PRO.value, BillingCycle.MONTHLY.value): "price_1T2ljDR2EuR7zQJeuS317KCj",
REQUIRED_STRIPE_PLAN_KEYS = {
(SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value),
(SubscriptionTier.PRO.value, BillingCycle.MONTHLY.value),
}
def _detect_stripe_mode() -> str:
configured_mode = os.getenv("STRIPE_MODE", "").strip().lower()
if configured_mode in {"test", "live"}:
return configured_mode
secret_key = os.getenv("STRIPE_SECRET_KEY", "").strip()
if secret_key.startswith("sk_live_"):
return "live"
if secret_key.startswith("sk_test_"):
return "test"
# Default to test when mode cannot be derived.
return "test"
def _normalize_stripe_plan_price_mapping(raw_mapping: Dict[str, Any]) -> Dict[tuple[str, str], str]:
normalized_mapping: Dict[tuple[str, str], str] = {}
for tier, billing_cycle_map in raw_mapping.items():
if not isinstance(billing_cycle_map, dict):
raise RuntimeError(
"Stripe plan mapping must be nested JSON in the form "
'{"basic": {"monthly": "price_..."}}.'
)
for billing_cycle, price_id in billing_cycle_map.items():
if not isinstance(price_id, str) or not price_id.strip():
raise RuntimeError(
f"Invalid Stripe price id for tier={tier}, billing_cycle={billing_cycle}."
)
normalized_mapping[(tier, billing_cycle)] = price_id.strip()
return normalized_mapping
def _load_stripe_plan_price_mapping() -> Dict[tuple[str, str], str]:
stripe_mode = _detect_stripe_mode()
mode_var_name = f"STRIPE_PLAN_PRICE_MAPPING_{stripe_mode.upper()}"
raw_mapping_json = os.getenv(mode_var_name) or os.getenv("STRIPE_PLAN_PRICE_MAPPING")
if not raw_mapping_json:
raise RuntimeError(
"Missing Stripe plan mapping configuration. Set "
f"{mode_var_name} (recommended) or STRIPE_PLAN_PRICE_MAPPING."
)
try:
parsed_mapping = json.loads(raw_mapping_json)
except json.JSONDecodeError as exc:
raise RuntimeError(
f"Invalid JSON in {mode_var_name}/STRIPE_PLAN_PRICE_MAPPING: {exc.msg}"
) from exc
if not isinstance(parsed_mapping, dict):
raise RuntimeError("Stripe plan mapping must decode to a JSON object.")
mapping = _normalize_stripe_plan_price_mapping(parsed_mapping)
missing_keys = REQUIRED_STRIPE_PLAN_KEYS - set(mapping.keys())
if missing_keys:
missing = ", ".join(
sorted([f"{tier}:{billing_cycle}" for tier, billing_cycle in missing_keys])
)
raise RuntimeError(
"Stripe plan mapping is missing required tier/cycle combinations: "
f"{missing}."
)
return mapping
STRIPE_PLAN_PRICE_MAPPING = _load_stripe_plan_price_mapping()
STRIPE_PRICE_TO_PLAN = {
price_id: {"tier": SubscriptionTier(tier), "billing_cycle": BillingCycle(billing_cycle)}
for (tier, billing_cycle), price_id in STRIPE_PLAN_PRICE_MAPPING.items()