- Fix text selection menu not showing: wire contentRef via inputRef on multiline TextField
- Fix blog title not truncating: add min-w-0 for flex item overflow
- Fix outline generation 500: escape curly braces in f-string prompt template
- Fix content generation 'NoneType not callable': replace SessionLocal() with get_session_for_user(), add db param to MediumBlogGenerator, fix signature mismatch in database_task_manager
- Fix writing assistant suggest 500: add auth + user_id to API endpoint and service, replace sync requests with httpx.AsyncClient
- Fix hallucination detector 404: explicitly include router in main.py and app.py
- Fix missing error_data in task failure responses
- Hide CopilotKit web inspector button
- Remove hardcoded fallback suggestions from SmartTypingAssist
- Fix stale closure refs in SmartTypingAssist handleTypingChange
- Add two-column editor layout, stats bar, section hover menu
- Various subscription, billing, and research module improvements
696 lines
28 KiB
Python
696 lines
28 KiB
Python
import json
|
|
import os
|
|
import stripe
|
|
from typing import Optional, Dict, Any
|
|
from loguru import logger
|
|
from fastapi import HTTPException
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.exc import IntegrityError
|
|
from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning, ProcessedStripeEvent
|
|
from services.subscription.pricing_service import PricingService
|
|
from datetime import datetime, timedelta
|
|
|
|
# (tier, billing_cycle) value pairs that the Stripe price mapping must define;
# _load_stripe_plan_price_mapping() raises RuntimeError when any is absent.
REQUIRED_STRIPE_PLAN_KEYS = {
    (required_tier.value, BillingCycle.MONTHLY.value)
    for required_tier in (SubscriptionTier.BASIC, SubscriptionTier.PRO)
}
|
|
|
|
|
|
def _is_truthy_env(var_name: str) -> bool:
    """Report whether the environment variable *var_name* is set to a truthy flag.

    The value is compared case-insensitively after trimming surrounding
    whitespace against "1", "true", "yes" and "on". An unset variable is
    treated as an empty string and therefore counts as False.
    """
    normalized_value = os.environ.get(var_name, "").strip().lower()
    return normalized_value in ("1", "true", "yes", "on")
|
|
|
|
|
|
def _detect_stripe_mode() -> str:
    """Decide whether Stripe runs in "test" or "live" mode.

    Precedence:
      1. STRIPE_MODE, if it is "test" or "live" (case-insensitive, trimmed).
      2. The STRIPE_SECRET_KEY prefix: "sk_live_" -> "live", "sk_test_" -> "test".
      3. "test" when neither source is conclusive.
    """
    explicit_mode = os.getenv("STRIPE_MODE", "").strip().lower()
    if explicit_mode in ("test", "live"):
        return explicit_mode

    secret = os.getenv("STRIPE_SECRET_KEY", "").strip()
    for key_prefix, inferred_mode in (("sk_live_", "live"), ("sk_test_", "test")):
        if secret.startswith(key_prefix):
            return inferred_mode

    # Nothing conclusive configured: fall back to the safe choice.
    return "test"
|
|
|
|
|
|
def _normalize_stripe_plan_price_mapping(raw_mapping: Dict[str, Any]) -> Dict[tuple[str, str], str]:
    """Flatten a nested ``{tier: {billing_cycle: price_id}}`` mapping.

    Args:
        raw_mapping: Decoded JSON object, e.g. ``{"basic": {"monthly": "price_..."}}``.

    Returns:
        A dict keyed by ``(tier, billing_cycle)`` whose values are the price ids
        with surrounding whitespace removed.

    Raises:
        RuntimeError: if a tier maps to something other than a dict, or a price
            id is not a non-blank string.
    """
    flattened: Dict[tuple[str, str], str] = {}

    for tier_name, cycles in raw_mapping.items():
        if not isinstance(cycles, dict):
            raise RuntimeError(
                "Stripe plan mapping must be nested JSON in the form "
                '{"basic": {"monthly": "price_..."}}.'
            )

        for cycle_name, candidate in cycles.items():
            # Non-strings collapse to "" so they fail the same blank check.
            cleaned = candidate.strip() if isinstance(candidate, str) else ""
            if not cleaned:
                raise RuntimeError(
                    f"Invalid Stripe price id for tier={tier_name}, billing_cycle={cycle_name}."
                )
            flattened[(tier_name, cycle_name)] = cleaned

    return flattened
|
|
|
|
|
|
def _load_stripe_plan_price_mapping() -> Dict[tuple[str, str], str]:
    """Read, validate and flatten the Stripe plan -> price id mapping from the environment.

    The mode-specific variable ``STRIPE_PLAN_PRICE_MAPPING_<MODE>`` (MODE is
    ``_detect_stripe_mode()`` upper-cased) wins over the generic
    ``STRIPE_PLAN_PRICE_MAPPING``. The value must be a nested JSON object such as
    ``{"basic": {"monthly": "price_..."}}`` and must cover every pair in
    ``REQUIRED_STRIPE_PLAN_KEYS``.

    Raises:
        RuntimeError: if no mapping is configured, it is not valid JSON, it is
            not a JSON object, it is malformed, or a required tier/cycle is missing.
    """
    env_var = f"STRIPE_PLAN_PRICE_MAPPING_{_detect_stripe_mode().upper()}"
    raw_json = os.getenv(env_var) or os.getenv("STRIPE_PLAN_PRICE_MAPPING")
    if not raw_json:
        raise RuntimeError(
            "Missing Stripe plan mapping configuration. Set "
            f"{env_var} (recommended) or STRIPE_PLAN_PRICE_MAPPING."
        )

    try:
        decoded = json.loads(raw_json)
    except json.JSONDecodeError as exc:
        raise RuntimeError(
            f"Invalid JSON in {env_var}/STRIPE_PLAN_PRICE_MAPPING: {exc.msg}"
        ) from exc

    if not isinstance(decoded, dict):
        raise RuntimeError("Stripe plan mapping must decode to a JSON object.")

    price_mapping = _normalize_stripe_plan_price_mapping(decoded)

    absent_pairs = REQUIRED_STRIPE_PLAN_KEYS.difference(price_mapping)
    if absent_pairs:
        absent_labels = sorted(f"{tier}:{cycle}" for tier, cycle in absent_pairs)
        raise RuntimeError(
            "Stripe plan mapping is missing required tier/cycle combinations: "
            f"{', '.join(absent_labels)}."
        )

    return price_mapping
|
|
|
|
|
|
# (tier, billing_cycle) -> Stripe price id, loaded and validated once at import;
# a bad configuration makes the import fail here.
STRIPE_PLAN_PRICE_MAPPING = _load_stripe_plan_price_mapping()

# Reverse index used when Stripe reports a price id:
# price id -> {"tier": SubscriptionTier, "billing_cycle": BillingCycle}.
# If two plans share one price id, the pair iterated last wins.
STRIPE_PRICE_TO_PLAN = dict(
    (
        mapped_price_id,
        {"tier": SubscriptionTier(tier_value), "billing_cycle": BillingCycle(cycle_value)},
    )
    for (tier_value, cycle_value), mapped_price_id in STRIPE_PLAN_PRICE_MAPPING.items()
)
|
|
|
|
class StripeService:
    """Stripe billing integration: checkout, customer portal and webhook processing.

    An instance is bound to a SQLAlchemy session (``db``). That session is used to
    read and write ``UserSubscription``, ``SubscriptionPlan``,
    ``ProcessedStripeEvent`` and ``FraudWarning`` rows. Stripe price ids are
    resolved through the module-level ``STRIPE_PLAN_PRICE_MAPPING`` and
    ``STRIPE_PRICE_TO_PLAN`` tables. Timestamps are stored as naive UTC datetimes.
    """

    def __init__(self, db: Session):
        """Bind the service to *db* and configure the global ``stripe.api_key``.

        Raises:
            HTTPException: 500 when ``REQUIRE_STRIPE_CHECKOUT`` is truthy but
                ``STRIPE_SECRET_KEY`` is not set.
        """
        self.db = db
        self.api_key = os.getenv("STRIPE_SECRET_KEY")
        self.webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET")
        self.require_stripe_checkout = _is_truthy_env("REQUIRE_STRIPE_CHECKOUT")
        if not self.api_key:
            if self.require_stripe_checkout:
                raise HTTPException(
                    status_code=500,
                    detail=(
                        "REQUIRE_STRIPE_CHECKOUT=true but STRIPE_SECRET_KEY is missing. "
                        "Configure STRIPE_SECRET_KEY to enable Stripe checkout."
                    ),
                )
            logger.warning("STRIPE_SECRET_KEY is not set. Stripe integration will not work.")
        else:
            # Sets the key on the stripe module itself (process-wide), not per instance.
            stripe.api_key = self.api_key

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _metadata_user_id_query(user_id: str) -> str:
        """Build a Stripe Search query for customers whose ``metadata.user_id`` is *user_id*.

        Single quotes are backslash-escaped so an id containing ``'`` cannot close
        the quoted value early and change the query.
        """
        escaped_user_id = str(user_id).replace("'", "\\'")
        return f"metadata['user_id']:'{escaped_user_id}'"

    @staticmethod
    def _utc_from_timestamp(timestamp: Optional[int]) -> Optional[datetime]:
        """Convert a Unix timestamp to a naive UTC datetime; return None when it is missing/0.

        Naive UTC matches the ``datetime.utcnow()`` values written elsewhere in this
        class. The previous ``datetime.fromtimestamp`` used the server's local zone.
        """
        if not timestamp:
            return None
        return datetime.utcfromtimestamp(timestamp)

    @staticmethod
    def _invoice_subscription_id(invoice: Dict[str, Any]) -> Optional[str]:
        """Return the Stripe subscription id an invoice belongs to, if any.

        Reads ``invoice["subscription"]`` first. It then falls back to
        ``invoice["parent"]["subscription_details"]["subscription"]``, where newer
        Stripe API versions place it (NOTE(review): confirm against the API version
        pinned for this account).
        """
        subscription_id = invoice.get("subscription")
        if subscription_id:
            return subscription_id
        parent = invoice.get("parent") or {}
        subscription_details = parent.get("subscription_details") or {}
        return subscription_details.get("subscription")

    @staticmethod
    def _invoice_line_period(invoice: Dict[str, Any]) -> Optional[tuple[int, int]]:
        """Return ``(start, end)`` Unix timestamps of the invoice's first line, or None.

        Returns None instead of raising when the invoice has no line items or the
        first line has no complete period.
        """
        line_items = (invoice.get("lines") or {}).get("data") or []
        if not line_items:
            return None
        period = line_items[0].get("period") or {}
        start, end = period.get("start"), period.get("end")
        if not start or not end:
            return None
        return start, end

    @staticmethod
    def _subscription_period(subscription_obj: Dict[str, Any]) -> Optional[tuple[int, int]]:
        """Return ``(start, end)`` Unix timestamps of a subscription's current period, or None.

        Candidates are tried in this order:
          1. A ``current_period`` ``{"start", "end"}`` object (the only shape the
             previous implementation read).
          2. The top-level ``current_period_start`` / ``current_period_end`` fields.
          3. The same fields on the first subscription item, where newer Stripe API
             versions report them.
        """
        legacy_period = subscription_obj.get("current_period") or {}
        candidates = [
            (legacy_period.get("start"), legacy_period.get("end")),
            (subscription_obj.get("current_period_start"), subscription_obj.get("current_period_end")),
        ]
        item_list = (subscription_obj.get("items") or {}).get("data") or []
        if item_list:
            first_item = item_list[0]
            candidates.append(
                (first_item.get("current_period_start"), first_item.get("current_period_end"))
            )
        for start, end in candidates:
            if start and end:
                return start, end
        return None

    def _find_subscription_for_invoice(
        self, subscription_id: str, customer_id: Optional[str]
    ) -> Optional[UserSubscription]:
        """Find the local ``UserSubscription`` for an invoice.

        Matches on ``stripe_subscription_id`` first and only then on
        ``stripe_customer_id``. The customer lookup is skipped when the invoice has
        no customer: ``column == None`` compiles to ``IS NULL`` and would match an
        unrelated user who has no Stripe customer yet.
        """
        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.stripe_subscription_id == subscription_id
        ).first()
        if subscription is None and customer_id:
            subscription = self.db.query(UserSubscription).filter(
                UserSubscription.stripe_customer_id == customer_id
            ).first()
        return subscription

    def _clear_user_caches(self, user_id: str, reason: str) -> None:
        """Best-effort invalidation of cached subscription data for *user_id*.

        Clears the PricingService user cache and the dashboard cache, then expires
        all objects in the session so later reads hit the database. Cache failures
        are logged as warnings, never raised. *reason* only appears in log messages.
        """
        try:
            PricingService.clear_user_cache(user_id)
            logger.info(f"Cleared subscription cache for user {user_id} after {reason}")
        except Exception as cache_err:
            logger.warning(f"Failed to clear user cache after {reason} for user {user_id}: {cache_err}")
        try:
            # Imported lazily, as before (presumably to avoid an import cycle).
            from api.subscription.cache import clear_dashboard_cache
            clear_dashboard_cache(user_id)
        except Exception as dash_cache_err:
            logger.warning(f"Failed to clear dashboard cache after {reason} for user {user_id}: {dash_cache_err}")
        self.db.expire_all()

    # ------------------------------------------------------------------
    # Plan / price resolution
    # ------------------------------------------------------------------

    def _get_price_id_for_plan(self, tier: SubscriptionTier, billing_cycle: BillingCycle) -> str:
        """Return the configured Stripe price id for *tier* / *billing_cycle*.

        Raises:
            HTTPException: 400 if no price is configured for the combination.
        """
        key = (tier.value, billing_cycle.value)
        price_id = STRIPE_PLAN_PRICE_MAPPING.get(key)
        if not price_id:
            logger.error(f"No Stripe price configured for tier={tier.value}, billing_cycle={billing_cycle.value}")
            raise HTTPException(status_code=400, detail="Payment plan is not configured")
        return price_id

    def _get_plan_for_price_id(self, price_id: str) -> tuple[SubscriptionPlan, BillingCycle]:
        """Map a Stripe price id to the cheapest active plan of its tier and its billing cycle.

        Raises:
            HTTPException: 400 if the price id is unknown or no active plan exists
                for its tier.
        """
        mapping = STRIPE_PRICE_TO_PLAN.get(price_id)
        if not mapping:
            logger.error(f"Unknown Stripe price_id: {price_id}")
            raise HTTPException(status_code=400, detail="Unknown payment price configuration")
        tier = mapping["tier"]
        billing_cycle = mapping["billing_cycle"]
        plan = (
            self.db.query(SubscriptionPlan)
            .filter(SubscriptionPlan.tier == tier, SubscriptionPlan.is_active == True)  # noqa: E712 - SQL expression
            .order_by(SubscriptionPlan.price_monthly)
            .first()
        )
        if not plan:
            logger.error(f"No subscription plan found for tier={tier.value}")
            raise HTTPException(status_code=400, detail="Subscription plan not found for payment price")
        return plan, billing_cycle

    # ------------------------------------------------------------------
    # Customers, checkout and portal
    # ------------------------------------------------------------------

    def _get_or_create_customer(self, user_id: str, email: Optional[str] = None) -> str:
        """Return the Stripe customer id for *user_id*, creating the customer if needed.

        Lookup order:
          1. The id already stored on the user's ``UserSubscription``.
          2. A Stripe customer with the given *email* (when provided).
          3. A Stripe customer whose ``metadata.user_id`` equals *user_id*.
          4. A newly created Stripe customer.
        A found or created id is saved on the subscription row when one exists.
        Errors while searching are logged and the method falls through to creation.

        Raises:
            HTTPException: 500 if the Stripe customer cannot be created.
        """
        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id
        ).first()

        if subscription and subscription.stripe_customer_id:
            return subscription.stripe_customer_id

        try:
            if email:
                existing_customers = stripe.Customer.list(email=email, limit=1)
                if existing_customers and len(existing_customers.data) > 0:
                    customer = existing_customers.data[0]
                    if subscription:
                        subscription.stripe_customer_id = customer.id
                        self.db.commit()
                    return customer.id

            existing_customers = stripe.Customer.search(
                query=self._metadata_user_id_query(user_id),
                limit=1,
            )
            if existing_customers and len(existing_customers.data) > 0:
                customer = existing_customers.data[0]
                if subscription:
                    subscription.stripe_customer_id = customer.id
                    self.db.commit()
                return customer.id

        except Exception as e:
            logger.error(f"Error searching Stripe customer: {e}")

        try:
            customer_data: Dict[str, Any] = {"metadata": {"user_id": user_id}}
            if email:
                customer_data["email"] = email

            customer = stripe.Customer.create(**customer_data)

            if subscription:
                subscription.stripe_customer_id = customer.id
            # With no subscription row the id is not stored locally. The customer
            # is still found again later through the metadata search above.

            self.db.commit()
            return customer.id
        except Exception as e:
            logger.error(f"Error creating Stripe customer: {e}")
            raise HTTPException(status_code=500, detail="Failed to create payment profile")

    def create_checkout_session(
        self,
        user_id: str,
        tier: SubscriptionTier,
        billing_cycle: BillingCycle,
        success_url: str,
        cancel_url: str,
        user_email: Optional[str] = None,
    ) -> str:
        """Create a Stripe Checkout Session for a subscription and return its URL.

        ``quantity=1`` is sent on the line item unless the Stripe price is metered.
        If the price cannot be inspected, quantity 1 is used.

        Raises:
            HTTPException: 500 if Stripe is not configured or session creation
                fails. 400 or 500 errors from plan/customer resolution propagate.
        """
        if not self.api_key:
            raise HTTPException(status_code=500, detail="Payment service not configured")

        price_id = self._get_price_id_for_plan(tier, billing_cycle)
        customer_id = self._get_or_create_customer(user_id, user_email)

        line_item: Dict[str, Any] = {"price": price_id}
        try:
            price = stripe.Price.retrieve(price_id)
            recurring = getattr(price, "recurring", None)
            usage_type = None
            if recurring:
                if isinstance(recurring, dict):
                    usage_type = recurring.get("usage_type")
                else:
                    usage_type = getattr(recurring, "usage_type", None)
            if usage_type != "metered":
                line_item["quantity"] = 1
            else:
                logger.info(f"Detected metered price {price_id}; omitting quantity in Checkout line item")
        except Exception as e:
            logger.error(f"Error inspecting Stripe price {price_id}: {e}")
            line_item["quantity"] = 1

        try:
            checkout_session = stripe.checkout.Session.create(
                customer=customer_id,
                payment_method_types=["card"],
                line_items=[line_item],
                mode="subscription",
                success_url=success_url,
                cancel_url=cancel_url,
                metadata={
                    "user_id": user_id,
                    "price_id": price_id,
                },
                subscription_data={
                    "metadata": {
                        "user_id": user_id,
                    }
                },
                allow_promotion_codes=True,
            )
            return checkout_session.url
        except Exception as e:
            logger.error(f"Error creating checkout session: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    def create_portal_session(self, user_id: str, return_url: str) -> str:
        """Create a Stripe Customer Portal session for *user_id* and return its URL.

        Uses the stored customer id. Otherwise it searches Stripe by
        ``metadata.user_id`` and saves the result on the subscription row.

        Raises:
            HTTPException: 400 if the user has no Stripe customer. 500 if Stripe
                is not configured, the customer search fails, or the portal
                session cannot be created.
        """
        if not self.api_key:
            raise HTTPException(status_code=500, detail="Payment service not configured")

        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id
        ).first()

        customer_id = subscription.stripe_customer_id if subscription else None
        if not customer_id:
            try:
                customers = stripe.Customer.search(
                    query=self._metadata_user_id_query(user_id), limit=1
                )
                found_customer_id = (
                    customers.data[0].id if customers and len(customers.data) > 0 else None
                )
                if found_customer_id and subscription:
                    # Update DB while we're at it
                    subscription.stripe_customer_id = found_customer_id
                    self.db.commit()
            except Exception as e:
                logger.error(f"Error finding customer for portal: {e}")
                raise HTTPException(status_code=500, detail="Failed to access billing portal")

            # Raised outside the try so the 400 is not turned into a 500 by the
            # catch-all above (the previous implementation had this bug).
            if not found_customer_id:
                raise HTTPException(status_code=400, detail="No billing profile found for this user")
            customer_id = found_customer_id

        try:
            portal_session = stripe.billing_portal.Session.create(
                customer=customer_id,
                return_url=return_url,
            )
            return portal_session.url
        except Exception as e:
            logger.error(f"Error creating portal session: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    # ------------------------------------------------------------------
    # Webhooks
    # ------------------------------------------------------------------

    async def handle_webhook(self, payload: bytes, sig_header: str):
        """Verify a Stripe webhook delivery and process it at most once successfully.

        Flow:
          1. Verify the signature with ``STRIPE_WEBHOOK_SECRET``. If that is unset,
             log a warning and return None.
          2. Record the event in ``ProcessedStripeEvent``; skip it if already "processed".
          3. Dispatch by event type, then mark the event "processed".
        If a handler raises, the transaction is rolled back, the event is marked
        "failed" with the error text, and the exception is re-raised.

        Returns:
            ``{"status": "success"}``, or None when no webhook secret is configured.

        Raises:
            HTTPException: 400 for an invalid payload, signature or missing event id.
            IntegrityError: when a concurrent insert of the same event won the race
                but has not finished processing.
            Exception: whatever a handler raised.
        """
        if not self.webhook_secret:
            logger.warning("STRIPE_WEBHOOK_SECRET not set. Ignoring webhook.")
            return

        try:
            event = stripe.Webhook.construct_event(
                payload, sig_header, self.webhook_secret
            )
        except ValueError as e:
            logger.error(f"Invalid payload: {e}")
            raise HTTPException(status_code=400, detail="Invalid payload")
        except stripe.error.SignatureVerificationError as e:
            logger.error(f"Invalid signature: {e}")
            raise HTTPException(status_code=400, detail="Invalid signature")

        event_id = event.get("id")
        event_type = event["type"]
        data = event["data"]["object"]

        if not event_id:
            logger.error("Stripe webhook event missing id")
            raise HTTPException(status_code=400, detail="Missing event id")

        now = datetime.utcnow()
        processed_event = self.db.query(ProcessedStripeEvent).filter(
            ProcessedStripeEvent.event_id == event_id
        ).first()

        if processed_event and processed_event.status == "processed":
            logger.info(f"Skipping already processed Stripe event {event_id}")
            return {"status": "success"}

        if processed_event:
            # Retry of a previously failed or interrupted delivery.
            processed_event.status = "processing"
            processed_event.processing_started_at = now
            processed_event.last_error = None
            processed_event.attempt_count = (processed_event.attempt_count or 0) + 1
        else:
            processed_event = ProcessedStripeEvent(
                event_id=event_id,
                event_type=event_type,
                status="processing",
                received_at=now,
                processing_started_at=now,
                attempt_count=1,
            )
            self.db.add(processed_event)

        try:
            self.db.commit()
        except IntegrityError:
            # Another delivery inserted the same event id first.
            self.db.rollback()
            existing_event = self.db.query(ProcessedStripeEvent).filter(
                ProcessedStripeEvent.event_id == event_id
            ).first()
            if existing_event and existing_event.status == "processed":
                logger.info(f"Skipping already processed Stripe event {event_id} after race")
                return {"status": "success"}
            raise

        logger.info(f"Received Stripe webhook: {event_type}")

        try:
            if event_type == "checkout.session.completed":
                await self._handle_checkout_completed(data)
            elif event_type == "invoice.payment_succeeded":
                await self._handle_invoice_payment_succeeded(data)
            elif event_type == "invoice.payment_failed":
                await self._handle_invoice_payment_failed(data)
            elif event_type == "customer.subscription.updated":
                await self._handle_subscription_updated(data)
            elif event_type == "customer.subscription.deleted":
                await self._handle_subscription_deleted(data)
            elif event_type.startswith("radar.early_fraud_warning."):
                await self._handle_early_fraud_warning(data)

            processed_event.status = "processed"
            processed_event.processed_at = datetime.utcnow()
            processed_event.last_error = None
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            failed_event = self.db.query(ProcessedStripeEvent).filter(
                ProcessedStripeEvent.event_id == event_id
            ).first()
            if failed_event:
                failed_event.status = "failed"
                failed_event.last_error = str(e)[:2000]
                failed_event.processed_at = datetime.utcnow()
                self.db.commit()
            raise

        return {"status": "success"}

    async def _handle_checkout_completed(self, session: Dict[str, Any]):
        """Activate the purchased plan after a ``checkout.session.completed`` event.

        Reads ``user_id`` from the session metadata and the price from the Stripe
        subscription's first item, then upserts the local subscription. Failures in
        that step are logged and re-raised. Previously they were swallowed, so the
        webhook was marked "processed" and Stripe never retried a paid checkout that
        had not been applied. Cache clearing afterwards is best-effort.
        """
        user_id = (session.get("metadata") or {}).get("user_id")
        customer_id = session.get("customer")
        subscription_id = session.get("subscription")

        if not user_id:
            logger.error("No user_id in checkout session metadata")
            return

        logger.info(f"Checkout completed for user {user_id}")

        if not subscription_id:
            return

        try:
            # Retrieve subscription details to get the plan/price
            sub = stripe.Subscription.retrieve(subscription_id)
            price_id = sub['items']['data'][0]['price']['id']

            self._update_user_subscription(
                user_id,
                stripe_customer_id=customer_id,
                stripe_subscription_id=subscription_id,
                status="active",
                price_id=price_id,
            )
        except Exception as e:
            logger.error(f"Error processing checkout subscription: {e}")
            # Let handle_webhook mark the event "failed" so the delivery can be retried.
            raise

        # Clear PricingService cache so next status check returns updated limits
        try:
            PricingService.clear_user_cache(user_id)
        except Exception as cache_err:
            logger.warning(f"Failed to clear user cache after checkout for user {user_id}: {cache_err}")
        try:
            from api.subscription.cache import clear_dashboard_cache
            clear_dashboard_cache(user_id)
            logger.info(f"Cleared dashboard cache for user {user_id} after checkout")
        except Exception as cache_err:
            logger.warning(f"Failed to clear cache after checkout for user {user_id}: {cache_err}")

        # Expire all SQLAlchemy objects to force fresh reads
        self.db.expire_all()
        logger.info(f"Expired all SQLAlchemy objects for user {user_id} after checkout")

    async def _handle_invoice_payment_succeeded(self, invoice: Dict[str, Any]):
        """Mark the matching subscription active after a successful invoice payment.

        Also copies the invoice's first line period into the subscription's current
        period (as naive UTC) when present, then clears the user's caches. Invoices
        without a subscription id are ignored.
        """
        subscription_id = self._invoice_subscription_id(invoice)
        customer_id = invoice.get("customer")

        if not subscription_id:
            return

        subscription = self._find_subscription_for_invoice(subscription_id, customer_id)
        if not subscription:
            return

        user_id = subscription.user_id
        logger.info(f"Payment succeeded for user {user_id}")
        subscription.status = UsageStatus.ACTIVE
        subscription.is_active = True
        subscription.auto_renew = True

        line_period = self._invoice_line_period(invoice)
        if line_period:
            subscription.current_period_start = self._utc_from_timestamp(line_period[0])
            subscription.current_period_end = self._utc_from_timestamp(line_period[1])
        self.db.commit()

        self._clear_user_caches(user_id, "payment success")

    async def _handle_invoice_payment_failed(self, invoice: Dict[str, Any]):
        """Mark the matching subscription past-due and inactive after a failed invoice payment.

        Invoices without a subscription id are ignored. The user's caches are
        cleared so the downgrade is visible right away, as the other
        status-changing handlers already do.
        """
        subscription_id = self._invoice_subscription_id(invoice)
        customer_id = invoice.get("customer")

        if not subscription_id:
            return

        subscription = self._find_subscription_for_invoice(subscription_id, customer_id)
        if not subscription:
            return

        user_id = subscription.user_id
        logger.warning(f"Payment failed for user {user_id}")
        subscription.status = UsageStatus.PAST_DUE
        subscription.is_active = False
        self.db.commit()

        self._clear_user_caches(user_id, "payment failure")

    async def _handle_subscription_updated(self, subscription_obj: Dict[str, Any]):
        """Sync local state with a ``customer.subscription.updated`` event.

        Stripe status mapping:
          - "active" / "trialing" -> ACTIVE, and the current period is refreshed.
          - "past_due" / "unpaid" / "incomplete" / "incomplete_expired" -> PAST_DUE.
          - "canceled" -> CANCELLED with auto-renew off.
        Other statuses leave the status fields untouched. Caches are cleared afterwards.
        """
        stripe_sub_id = subscription_obj.get("id")
        status = subscription_obj.get("status")

        if not stripe_sub_id:
            # Without an id, "== None" would match every row lacking a Stripe subscription.
            return

        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.stripe_subscription_id == stripe_sub_id
        ).first()

        if not subscription:
            return

        user_id = subscription.user_id
        logger.info(f"Subscription {stripe_sub_id} updated to {status}")
        if status in ("active", "trialing"):
            subscription.status = UsageStatus.ACTIVE
            subscription.is_active = True
            subscription.auto_renew = True
            # Previously this read a "current_period" key and fell back to the
            # 1970 epoch for missing values.
            period = self._subscription_period(subscription_obj)
            if period:
                subscription.current_period_start = self._utc_from_timestamp(period[0])
                subscription.current_period_end = self._utc_from_timestamp(period[1])
        elif status in ("past_due", "unpaid", "incomplete", "incomplete_expired"):
            subscription.status = UsageStatus.PAST_DUE
            subscription.is_active = False
        elif status == "canceled":
            subscription.status = UsageStatus.CANCELLED
            subscription.is_active = False
            subscription.auto_renew = False

        self.db.commit()

        self._clear_user_caches(user_id, "subscription update")

    async def _handle_subscription_deleted(self, subscription_obj: Dict[str, Any]):
        """Mark the local subscription cancelled after ``customer.subscription.deleted``.

        Also clears the user's caches so the cancellation is visible right away.
        """
        stripe_sub_id = subscription_obj.get("id")
        if not stripe_sub_id:
            return

        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.stripe_subscription_id == stripe_sub_id
        ).first()

        if not subscription:
            return

        user_id = subscription.user_id
        logger.info(f"Subscription {stripe_sub_id} deleted")
        subscription.status = UsageStatus.CANCELLED
        subscription.is_active = False
        subscription.auto_renew = False
        self.db.commit()

        self._clear_user_caches(user_id, "subscription deletion")

    async def _handle_early_fraud_warning(self, warning_obj: Dict[str, Any]):
        """Record or refresh a ``FraudWarning`` row for a ``radar.early_fraud_warning.*`` event.

        When an API key and charge id are available, the charge is fetched to fill
        in amount, currency and ``metadata.user_id``. Otherwise the warning's own
        fields are used. Existing rows are updated and reset to status "open". A new
        row is only created when the warning references a charge.
        """
        efw_id = warning_obj.get("id")
        if not efw_id:
            return

        charge_id = warning_obj.get("charge")
        payment_intent_id = warning_obj.get("payment_intent")
        created_ts = warning_obj.get("created")
        created_at = datetime.utcfromtimestamp(created_ts) if created_ts else datetime.utcnow()

        amount = 0
        currency = ""
        user_id = None
        charge_data: Dict[str, Any] = {}

        if charge_id and self.api_key:
            try:
                charge = stripe.Charge.retrieve(charge_id)
                charge_data = charge.to_dict() if hasattr(charge, "to_dict") else dict(charge)
                amount = charge_data.get("amount") or 0
                currency = charge_data.get("currency") or ""
                metadata = charge_data.get("metadata") or {}
                user_id = metadata.get("user_id")
            except Exception as e:
                logger.error(f"Error retrieving charge for early fraud warning {efw_id}: {e}")

        if not amount:
            amount = warning_obj.get("amount") or 0
        if not currency:
            currency = warning_obj.get("currency") or ""

        existing = self.db.query(FraudWarning).filter(FraudWarning.id == efw_id).first()

        metadata_payload: Dict[str, Any] = {
            "early_fraud_warning": warning_obj,
        }
        if charge_data:
            metadata_payload["charge"] = charge_data

        if existing:
            existing.charge_id = charge_id or existing.charge_id
            existing.payment_intent_id = payment_intent_id or existing.payment_intent_id
            if user_id:
                existing.user_id = user_id
            if amount:
                existing.amount = amount
            if currency:
                existing.currency = currency
            existing.status = "open"
            existing.meta_info = metadata_payload
        else:
            if not charge_id:
                return
            warning = FraudWarning(
                id=efw_id,
                charge_id=charge_id,
                payment_intent_id=payment_intent_id,
                user_id=user_id,
                amount=amount or 0,
                currency=currency or "",
                status="open",
                action="none",
                meta_info=metadata_payload,
                created_at=created_at,
            )
            self.db.add(warning)

        self.db.commit()

    def _update_user_subscription(
        self,
        user_id: str,
        stripe_customer_id: str,
        stripe_subscription_id: str,
        status: str,
        price_id: str,
    ):
        """Create or update the user's ``UserSubscription`` for the plan behind *price_id*.

        The billing period restarts now and runs 365 days (yearly) or 30 days
        (otherwise). Later invoice/subscription events overwrite it with Stripe's
        real period. *status* "active" maps to ACTIVE/is_active; anything else maps
        to SUSPENDED.

        Raises:
            HTTPException: 400 from ``_get_plan_for_price_id`` for an unknown price.
        """
        plan, billing_cycle = self._get_plan_for_price_id(price_id)

        subscription = (
            self.db.query(UserSubscription)
            .filter(UserSubscription.user_id == user_id)
            .first()
        )

        now = datetime.utcnow()
        # Calculate billing period end based on cycle
        if billing_cycle == BillingCycle.YEARLY:
            period_end = now + timedelta(days=365)
        else:
            period_end = now + timedelta(days=30)

        if not subscription:
            subscription = UserSubscription(
                user_id=user_id,
                plan_id=plan.id,
                billing_cycle=billing_cycle,
                current_period_start=now,
                current_period_end=period_end,
                status=UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED,
                is_active=status == "active",
                auto_renew=True,
            )
            self.db.add(subscription)
        else:
            subscription.plan_id = plan.id
            subscription.billing_cycle = billing_cycle
            subscription.is_active = status == "active"
            subscription.status = UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED
            # Reset billing period on upgrade/plan change
            subscription.current_period_start = now
            subscription.current_period_end = period_end
            subscription.auto_renew = True

        subscription.stripe_customer_id = stripe_customer_id
        subscription.stripe_subscription_id = stripe_subscription_id

        self.db.commit()