diff --git a/backend/models/subscription_models.py b/backend/models/subscription_models.py index 25f9b788..996bc870 100644 --- a/backend/models/subscription_models.py +++ b/backend/models/subscription_models.py @@ -408,3 +408,17 @@ class FraudWarning(Base): reason_notes = Column(Text, nullable=True) meta_info = Column(JSON, nullable=True) created_at = Column(DateTime, default=datetime.utcnow) + + +class StripeWebhookEvent(Base): + """Processed Stripe webhook events for idempotency and replay protection.""" + + __tablename__ = "stripe_webhook_events" + + event_id = Column(String(100), primary_key=True) + event_type = Column(String(100), nullable=False) + status = Column(String(20), nullable=False, default="processing") # processing, processed, failed + error_message = Column(Text, nullable=True) + processed_at = Column(DateTime, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + diff --git a/backend/services/subscription/stripe_service.py b/backend/services/subscription/stripe_service.py index afd0e133..75d370ed 100644 --- a/backend/services/subscription/stripe_service.py +++ b/backend/services/subscription/stripe_service.py @@ -4,7 +4,8 @@ from typing import Optional, Dict, Any from loguru import logger from fastapi import HTTPException from sqlalchemy.orm import Session -from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning +from sqlalchemy.exc import IntegrityError +from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning, StripeWebhookEvent from services.subscription.pricing_service import PricingService from datetime import datetime @@ -216,6 +217,35 @@ class StripeService: logger.error(f"Error creating portal session: {e}") raise HTTPException(status_code=500, detail=str(e)) + def _ensure_webhook_event_table(self) -> None: + """Ensure webhook idempotency table exists before processing events.""" + try: + bind = self.db.get_bind() + if bind is not None: + StripeWebhookEvent.__table__.create(bind=bind, checkfirst=True) + except Exception as e: + logger.warning(f"Failed to ensure stripe_webhook_events table exists: {e}") + + def _mark_webhook_event_status( + self, + event_id: str, + status: str, + error_message: Optional[str] = None, + ) -> None: + """Update persisted webhook event processing status.""" + event_row = self.db.query(StripeWebhookEvent).filter( + StripeWebhookEvent.event_id == event_id + ).first() + if not event_row: + return + + event_row.status = status + event_row.error_message = (error_message or "")[:1000] if error_message else None + if status == "processed": + event_row.processed_at = datetime.utcnow() + + self.db.commit() + async def handle_webhook(self, payload: bytes, sig_header: str): """ Handle Stripe webhooks. @@ -235,25 +265,64 @@ class StripeService: logger.error(f"Invalid signature: {e}") raise HTTPException(status_code=400, detail="Invalid signature") - event_type = event["type"] - data = event["data"]["object"] + event_id = event.get("id") + event_type = event.get("type") + data = event.get("data", {}).get("object", {}) - logger.info(f"Received Stripe webhook: {event_type}") - - if event_type == "checkout.session.completed": - await self._handle_checkout_completed(data) - elif event_type == "invoice.payment_succeeded": - await self._handle_invoice_payment_succeeded(data) - elif event_type == "invoice.payment_failed": - await self._handle_invoice_payment_failed(data) - elif event_type == "customer.subscription.updated": - await self._handle_subscription_updated(data) - elif event_type == "customer.subscription.deleted": - await self._handle_subscription_deleted(data) - elif event_type.startswith("radar.early_fraud_warning."): - await self._handle_early_fraud_warning(data) + if not event_id or not event_type: + logger.error("Stripe webhook missing event id/type") + raise HTTPException(status_code=400, detail="Invalid Stripe event payload") - return {"status": "success"} + # Idempotency guard: persist event id before mutating subscription state. + self._ensure_webhook_event_table() + existing_event = self.db.query(StripeWebhookEvent).filter( + StripeWebhookEvent.event_id == event_id + ).first() + if existing_event: + logger.info(f"Skipping already processed Stripe event: {event_id} ({existing_event.status})") + return {"status": "success", "idempotent": True} + + try: + event_row = StripeWebhookEvent( + event_id=event_id, + event_type=event_type, + status="processing", + ) + self.db.add(event_row) + self.db.commit() + except IntegrityError: + self.db.rollback() + logger.info(f"Skipping duplicate Stripe event insert: {event_id}") + return {"status": "success", "idempotent": True} + + logger.info(f"Received Stripe webhook: {event_type} ({event_id})") + + try: + if event_type == "checkout.session.completed": + await self._handle_checkout_completed(data) + elif event_type == "invoice.payment_succeeded": + await self._handle_invoice_payment_succeeded(data) + elif event_type == "invoice.payment_failed": + await self._handle_invoice_payment_failed(data) + elif event_type == "customer.subscription.updated": + await self._handle_subscription_updated(data) + elif event_type == "customer.subscription.deleted": + await self._handle_subscription_deleted(data) + elif event_type.startswith("radar.early_fraud_warning."): + await self._handle_early_fraud_warning(data) + + self._mark_webhook_event_status(event_id=event_id, status="processed") + return {"status": "success"} + + except Exception as e: + self.db.rollback() + self._mark_webhook_event_status( + event_id=event_id, + status="failed", + error_message=str(e), + ) + logger.error(f"Failed Stripe webhook handling for {event_id}: {e}") + raise async def _handle_checkout_completed(self, session: Dict[str, Any]): """