diff options
| author | Ben Sima <ben@bsima.me> | 2025-11-09 15:45:48 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bsima.me> | 2025-11-09 15:45:48 -0500 |
| commit | baf1ea549ad0218efcfaf489f9fb2ed7b67bf652 (patch) | |
| tree | d3b29471af0e2f6749e67ff1f64f1769350c9621 /Biz/PodcastItLater/Billing.py | |
| parent | fce44e9449c2305e32544c43a9a35a5c423daad3 (diff) | |
feat(PodcastItLater): Add Stripe billing infrastructure
Add complete Stripe integration backend ready for testing once stripe
package is available in Nix environment.
Components: - Billing.py: Stripe Checkout, Billing Portal, webhook
handling - Database migrations for subscription tracking - Usage
tracking with tier-based limits - Idempotent webhook processing with
stripe_events table
Tier limits: free (10/mo), personal (50/mo), pro (unlimited)
Webhook events handled: - checkout.session.completed (link customer to
user) - customer.subscription.{created,updated,deleted} (sync state)
- invoice.payment_failed (mark past_due)
Requires: Stripe Python package in Nix, Web.py routes (next commit)
Related to task t-144e7lF
Amp-Thread-ID:
https://ampcode.com/threads/T-8feaca83-dcc2-46cb-8f71-d0785960a2f7
Co-authored-by: Amp <amp@ampcode.com>
Diffstat (limited to 'Biz/PodcastItLater/Billing.py')
| -rw-r--r-- | Biz/PodcastItLater/Billing.py | 422 |
1 files changed, 422 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Billing.py b/Biz/PodcastItLater/Billing.py new file mode 100644 index 0000000..e472889 --- /dev/null +++ b/Biz/PodcastItLater/Billing.py @@ -0,0 +1,422 @@ +""" +PodcastItLater Billing Integration. + +Stripe subscription management and usage enforcement. +""" + +# : out podcastitlater-billing +# : dep stripe +# : dep pytest +# : dep pytest-mock +import Biz.PodcastItLater.Core as Core +import Omni.Log as Log +import os +import stripe +from datetime import datetime +from datetime import timezone + +logger = Log.setup() + +# Stripe configuration +stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "") +STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "") + +# Price IDs from Stripe dashboard +STRIPE_PRICE_ID_PERSONAL = os.getenv("STRIPE_PRICE_ID_PERSONAL", "") +STRIPE_PRICE_ID_PRO = os.getenv("STRIPE_PRICE_ID_PRO", "") + +# Map Stripe price IDs to tier names +PRICE_TO_TIER = { + STRIPE_PRICE_ID_PERSONAL: "personal", + STRIPE_PRICE_ID_PRO: "pro", +} + +# Tier limits (None = unlimited) +TIER_LIMITS = { + "free": { + "articles_per_period": 10, + "minutes_per_period": None, + }, + "personal": { + "articles_per_period": 50, + "minutes_per_period": None, + }, + "pro": { + "articles_per_period": None, + "minutes_per_period": None, + }, +} + +# Price map for checkout +PRICE_MAP = { + "personal": STRIPE_PRICE_ID_PERSONAL, + "pro": STRIPE_PRICE_ID_PRO, +} + + +def get_period_boundaries(user: dict) -> tuple[datetime, datetime]: + """Get billing period boundaries for user. + + For paid users: use Stripe subscription period. + For free users: use current calendar month (UTC). + + Returns: + tuple: (period_start, period_end) as datetime objects + """ + if user.get("plan_tier") != "free" and user.get("current_period_start"): + # Paid user - use Stripe billing period + period_start = datetime.fromisoformat(user["current_period_start"]) + period_end = datetime.fromisoformat(user["current_period_end"]) + else: + # Free user - use calendar month + now = datetime.now(timezone.utc) + period_start = now.replace( + day=1, + hour=0, + minute=0, + second=0, + microsecond=0, + ) + # Get first day of next month + if now.month == 12: # noqa: PLR2004 + period_end = period_start.replace(year=now.year + 1, month=1) + else: + period_end = period_start.replace(month=now.month + 1) + + return period_start, period_end + + +def get_usage( + user_id: int, + period_start: datetime, + period_end: datetime, +) -> dict: + """Get usage stats for user in billing period. + + Returns: + dict with keys: articles (int), minutes (int) + """ + return Core.Database.get_usage(user_id, period_start, period_end) + + +def can_submit(user_id: int) -> tuple[bool, str, dict]: + """Check if user can submit article based on tier limits. + + Returns: + tuple: (allowed: bool, message: str, usage: dict) + """ + user = Core.Database.get_user_by_id(user_id) + if not user: + return False, "User not found", {} + + tier = user.get("plan_tier", "free") + limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"]) + + # Get billing period boundaries + period_start, period_end = get_period_boundaries(user) + + # Get current usage + usage = get_usage(user_id, period_start, period_end) + + # Check article limit + article_limit = limits["articles_per_period"] + if article_limit is not None and usage["articles"] >= article_limit: + msg = ( + f"You've reached your limit of {article_limit} articles " + "per period. Upgrade to continue." + ) + return (False, msg, usage) + + # Check minutes limit (if implemented) + minute_limit = limits.get("minutes_per_period") + if minute_limit is not None and usage.get("minutes", 0) >= minute_limit: + return ( + False, + f"You've reached your limit of {minute_limit} minutes per period. " + "Please upgrade to continue.", + usage, + ) + + return True, "", usage + + +def create_checkout_session(user_id: int, tier: str, base_url: str) -> str: + """Create Stripe Checkout session for subscription. + + Args: + user_id: User ID + tier: Subscription tier (personal or pro) + base_url: Base URL for success/cancel redirects + + Returns: + Checkout session URL to redirect user to + + Raises: + ValueError: If tier is invalid or price ID not configured + """ + if tier not in PRICE_MAP: + msg = f"Invalid tier: {tier}" + raise ValueError(msg) + + price_id = PRICE_MAP[tier] + if not price_id: + msg = f"Stripe price ID not configured for tier: {tier}" + raise ValueError(msg) + + user = Core.Database.get_user_by_id(user_id) + if not user: + msg = f"User not found: {user_id}" + raise ValueError(msg) + + # Create checkout session + session_params = { + "mode": "subscription", + "line_items": [{"price": price_id, "quantity": 1}], + "success_url": f"{base_url}/billing?status=success", + "cancel_url": f"{base_url}/billing?status=cancel", + "client_reference_id": str(user_id), + "metadata": {"user_id": str(user_id), "tier": tier}, + "allow_promotion_codes": True, + } + + # Use existing customer if available + if user.get("stripe_customer_id"): + session_params["customer"] = user["stripe_customer_id"] + else: + session_params["customer_email"] = user["email"] + + session = stripe.checkout.Session.create(**session_params) + + logger.info( + "Created checkout session for user %s, tier %s: %s", + user_id, + tier, + session.id, + ) + + return session.url + + +def create_portal_session(user_id: int, base_url: str) -> str: + """Create Stripe Billing Portal session. + + Args: + user_id: User ID + base_url: Base URL for return redirect + + Returns: + Portal session URL to redirect user to + + Raises: + ValueError: If user has no Stripe customer ID + """ + user = Core.Database.get_user_by_id(user_id) + if not user: + msg = f"User not found: {user_id}" + raise ValueError(msg) + + if not user.get("stripe_customer_id"): + msg = "User has no Stripe customer ID" + raise ValueError(msg) + + session = stripe.billing_portal.Session.create( + customer=user["stripe_customer_id"], + return_url=f"{base_url}/billing", + ) + + logger.info("Created portal session for user %s: %s", user_id, session.id) + + return session.url + + +def handle_webhook_event(payload: bytes, sig_header: str) -> dict: + """Verify and process Stripe webhook event. + + Args: + payload: Raw webhook body + sig_header: Stripe-Signature header value + + Returns: + dict with processing status + + Note: + May raise stripe.error.SignatureVerificationError if invalid signature + """ + # Verify webhook signature + event = stripe.Webhook.construct_event( + payload, + sig_header, + STRIPE_WEBHOOK_SECRET, + ) + + event_id = event["id"] + event_type = event["type"] + + # Check if already processed (idempotency) + if Core.Database.has_processed_stripe_event(event_id): + logger.info("Skipping already processed event: %s", event_id) + return {"status": "skipped", "reason": "already_processed"} + + # Process event based on type + logger.info("Processing webhook event: %s (%s)", event_id, event_type) + + try: + if event_type == "checkout.session.completed": + _handle_checkout_completed(event["data"]["object"]) + elif event_type == "customer.subscription.created": + _handle_subscription_created(event["data"]["object"]) + elif event_type == "customer.subscription.updated": + _handle_subscription_updated(event["data"]["object"]) + elif event_type == "customer.subscription.deleted": + _handle_subscription_deleted(event["data"]["object"]) + elif event_type == "invoice.payment_failed": + _handle_payment_failed(event["data"]["object"]) + else: + logger.info("Unhandled event type: %s", event_type) + return {"status": "ignored", "type": event_type} + + # Mark event as processed + Core.Database.mark_stripe_event_processed(event_id, event_type, payload) + except Exception: + logger.exception("Error processing webhook event %s", event_id) + raise + else: + return {"status": "processed", "type": event_type} + + +def _handle_checkout_completed(session: dict) -> None: + """Handle checkout.session.completed event.""" + user_id = int(session.get("client_reference_id", 0)) + customer_id = session.get("customer") + + if not user_id or not customer_id: + logger.warning( + "Missing user_id or customer_id in checkout session: %s", + session["id"], + ) + return + + # Link Stripe customer to user + Core.Database.set_user_stripe_customer(user_id, customer_id) + logger.info("Linked user %s to Stripe customer %s", user_id, customer_id) + + +def _handle_subscription_created(subscription: dict) -> None: + """Handle customer.subscription.created event.""" + _update_subscription_state(subscription) + + +def _handle_subscription_updated(subscription: dict) -> None: + """Handle customer.subscription.updated event.""" + _update_subscription_state(subscription) + + +def _handle_subscription_deleted(subscription: dict) -> None: + """Handle customer.subscription.deleted event.""" + customer_id = subscription["customer"] + + # Find user by customer ID + user = Core.Database.get_user_by_stripe_customer_id(customer_id) + if not user: + logger.warning("User not found for customer: %s", customer_id) + return + + # Downgrade to free + Core.Database.downgrade_to_free(user["id"]) + logger.info("Downgraded user %s to free tier", user["id"]) + + +def _handle_payment_failed(invoice: dict) -> None: + """Handle invoice.payment_failed event.""" + customer_id = invoice["customer"] + subscription_id = invoice.get("subscription") + + # Find user by customer ID + user = Core.Database.get_user_by_stripe_customer_id(customer_id) + if not user: + logger.warning("User not found for customer: %s", customer_id) + return + + # Update subscription status to past_due + if subscription_id: + Core.Database.update_subscription_status(user["id"], "past_due") + logger.warning( + "Payment failed for user %s, subscription %s", + user["id"], + subscription_id, + ) + + +def _update_subscription_state(subscription: dict) -> None: + """Update user subscription state from Stripe subscription object.""" + customer_id = subscription["customer"] + subscription_id = subscription["id"] + status = subscription["status"] + cancel_at_period_end = subscription.get("cancel_at_period_end", False) + + # Get billing period + period_start = datetime.fromtimestamp( + subscription["current_period_start"], + tz=timezone.utc, + ) + period_end = datetime.fromtimestamp( + subscription["current_period_end"], + tz=timezone.utc, + ) + + # Determine tier from price ID + price_id = subscription["items"]["data"][0]["price"]["id"] + tier = PRICE_TO_TIER.get(price_id, "free") + + # Find user by customer ID + user = Core.Database.get_user_by_stripe_customer_id(customer_id) + if not user: + logger.warning("User not found for customer: %s", customer_id) + return + + # Update user subscription + Core.Database.update_user_subscription( + user["id"], + subscription_id, + status, + period_start, + period_end, + tier, + cancel_at_period_end, + ) + + logger.info( + "Updated user %s subscription: tier=%s, status=%s", + user["id"], + tier, + status, + ) + + +def get_tier_info(tier: str) -> dict: + """Get tier information for display. + + Returns: + dict with keys: name, articles_limit, price, description + """ + tier_info = { + "free": { + "name": "Free", + "articles_limit": 10, + "price": "$0", + "description": "10 articles per month", + }, + "personal": { + "name": "Personal", + "articles_limit": 50, + "price": "$9/mo", + "description": "50 articles per month", + }, + "pro": { + "name": "Pro", + "articles_limit": None, + "price": "$29/mo", + "description": "Unlimited articles", + }, + } + return tier_info.get(tier, tier_info["free"]) |
