summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater/Billing.py
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2025-11-09 15:45:48 -0500
committerBen Sima <ben@bsima.me>2025-11-09 15:45:48 -0500
commitbaf1ea549ad0218efcfaf489f9fb2ed7b67bf652 (patch)
treed3b29471af0e2f6749e67ff1f64f1769350c9621 /Biz/PodcastItLater/Billing.py
parentfce44e9449c2305e32544c43a9a35a5c423daad3 (diff)
feat(PodcastItLater): Add Stripe billing infrastructure
Add complete Stripe integration backend ready for testing once stripe package is available in Nix environment. Components: - Billing.py: Stripe Checkout, Billing Portal, webhook handling - Database migrations for subscription tracking - Usage tracking with tier-based limits - Idempotent webhook processing with stripe_events table Tier limits: free (10/mo), personal (50/mo), pro (unlimited) Webhook events handled: - checkout.session.completed (link customer to user) - customer.subscription.{created,updated,deleted} (sync state) - invoice.payment_failed (mark past_due) Requires: Stripe Python package in Nix, Web.py routes (next commit) Related to task t-144e7lF Amp-Thread-ID: https://ampcode.com/threads/T-8feaca83-dcc2-46cb-8f71-d0785960a2f7 Co-authored-by: Amp <amp@ampcode.com>
Diffstat (limited to 'Biz/PodcastItLater/Billing.py')
-rw-r--r--Biz/PodcastItLater/Billing.py422
1 files changed, 422 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Billing.py b/Biz/PodcastItLater/Billing.py
new file mode 100644
index 0000000..e472889
--- /dev/null
+++ b/Biz/PodcastItLater/Billing.py
@@ -0,0 +1,422 @@
+"""
+PodcastItLater Billing Integration.
+
+Stripe subscription management and usage enforcement.
+"""
+
+# : out podcastitlater-billing
+# : dep stripe
+# : dep pytest
+# : dep pytest-mock
+import Biz.PodcastItLater.Core as Core
+import Omni.Log as Log
+import os
+import stripe
+from datetime import datetime
+from datetime import timezone
+
+logger = Log.setup()
+
+# Stripe configuration
+stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "")
+STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "")
+
+# Price IDs from Stripe dashboard
+STRIPE_PRICE_ID_PERSONAL = os.getenv("STRIPE_PRICE_ID_PERSONAL", "")
+STRIPE_PRICE_ID_PRO = os.getenv("STRIPE_PRICE_ID_PRO", "")
+
+# Map Stripe price IDs to tier names
+PRICE_TO_TIER = {
+ STRIPE_PRICE_ID_PERSONAL: "personal",
+ STRIPE_PRICE_ID_PRO: "pro",
+}
+
+# Tier limits (None = unlimited)
+TIER_LIMITS = {
+ "free": {
+ "articles_per_period": 10,
+ "minutes_per_period": None,
+ },
+ "personal": {
+ "articles_per_period": 50,
+ "minutes_per_period": None,
+ },
+ "pro": {
+ "articles_per_period": None,
+ "minutes_per_period": None,
+ },
+}
+
+# Price map for checkout
+PRICE_MAP = {
+ "personal": STRIPE_PRICE_ID_PERSONAL,
+ "pro": STRIPE_PRICE_ID_PRO,
+}
+
+
+def get_period_boundaries(user: dict) -> tuple[datetime, datetime]:
+ """Get billing period boundaries for user.
+
+ For paid users: use Stripe subscription period.
+ For free users: use current calendar month (UTC).
+
+ Returns:
+ tuple: (period_start, period_end) as datetime objects
+ """
+ if user.get("plan_tier") != "free" and user.get("current_period_start"):
+ # Paid user - use Stripe billing period
+ period_start = datetime.fromisoformat(user["current_period_start"])
+ period_end = datetime.fromisoformat(user["current_period_end"])
+ else:
+ # Free user - use calendar month
+ now = datetime.now(timezone.utc)
+ period_start = now.replace(
+ day=1,
+ hour=0,
+ minute=0,
+ second=0,
+ microsecond=0,
+ )
+ # Get first day of next month
+ if now.month == 12: # noqa: PLR2004
+ period_end = period_start.replace(year=now.year + 1, month=1)
+ else:
+ period_end = period_start.replace(month=now.month + 1)
+
+ return period_start, period_end
+
+
+def get_usage(
+ user_id: int,
+ period_start: datetime,
+ period_end: datetime,
+) -> dict:
+ """Get usage stats for user in billing period.
+
+ Returns:
+ dict with keys: articles (int), minutes (int)
+ """
+ return Core.Database.get_usage(user_id, period_start, period_end)
+
+
+def can_submit(user_id: int) -> tuple[bool, str, dict]:
+ """Check if user can submit article based on tier limits.
+
+ Returns:
+ tuple: (allowed: bool, message: str, usage: dict)
+ """
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ return False, "User not found", {}
+
+ tier = user.get("plan_tier", "free")
+ limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"])
+
+ # Get billing period boundaries
+ period_start, period_end = get_period_boundaries(user)
+
+ # Get current usage
+ usage = get_usage(user_id, period_start, period_end)
+
+ # Check article limit
+ article_limit = limits["articles_per_period"]
+ if article_limit is not None and usage["articles"] >= article_limit:
+ msg = (
+ f"You've reached your limit of {article_limit} articles "
+ "per period. Upgrade to continue."
+ )
+ return (False, msg, usage)
+
+ # Check minutes limit (if implemented)
+ minute_limit = limits.get("minutes_per_period")
+ if minute_limit is not None and usage.get("minutes", 0) >= minute_limit:
+ return (
+ False,
+ f"You've reached your limit of {minute_limit} minutes per period. "
+ "Please upgrade to continue.",
+ usage,
+ )
+
+ return True, "", usage
+
+
+def create_checkout_session(user_id: int, tier: str, base_url: str) -> str:
+ """Create Stripe Checkout session for subscription.
+
+ Args:
+ user_id: User ID
+ tier: Subscription tier (personal or pro)
+ base_url: Base URL for success/cancel redirects
+
+ Returns:
+ Checkout session URL to redirect user to
+
+ Raises:
+ ValueError: If tier is invalid or price ID not configured
+ """
+ if tier not in PRICE_MAP:
+ msg = f"Invalid tier: {tier}"
+ raise ValueError(msg)
+
+ price_id = PRICE_MAP[tier]
+ if not price_id:
+ msg = f"Stripe price ID not configured for tier: {tier}"
+ raise ValueError(msg)
+
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ msg = f"User not found: {user_id}"
+ raise ValueError(msg)
+
+ # Create checkout session
+ session_params = {
+ "mode": "subscription",
+ "line_items": [{"price": price_id, "quantity": 1}],
+ "success_url": f"{base_url}/billing?status=success",
+ "cancel_url": f"{base_url}/billing?status=cancel",
+ "client_reference_id": str(user_id),
+ "metadata": {"user_id": str(user_id), "tier": tier},
+ "allow_promotion_codes": True,
+ }
+
+ # Use existing customer if available
+ if user.get("stripe_customer_id"):
+ session_params["customer"] = user["stripe_customer_id"]
+ else:
+ session_params["customer_email"] = user["email"]
+
+ session = stripe.checkout.Session.create(**session_params)
+
+ logger.info(
+ "Created checkout session for user %s, tier %s: %s",
+ user_id,
+ tier,
+ session.id,
+ )
+
+ return session.url
+
+
+def create_portal_session(user_id: int, base_url: str) -> str:
+ """Create Stripe Billing Portal session.
+
+ Args:
+ user_id: User ID
+ base_url: Base URL for return redirect
+
+ Returns:
+ Portal session URL to redirect user to
+
+ Raises:
+ ValueError: If user has no Stripe customer ID
+ """
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ msg = f"User not found: {user_id}"
+ raise ValueError(msg)
+
+ if not user.get("stripe_customer_id"):
+ msg = "User has no Stripe customer ID"
+ raise ValueError(msg)
+
+ session = stripe.billing_portal.Session.create(
+ customer=user["stripe_customer_id"],
+ return_url=f"{base_url}/billing",
+ )
+
+ logger.info("Created portal session for user %s: %s", user_id, session.id)
+
+ return session.url
+
+
+def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
+ """Verify and process Stripe webhook event.
+
+ Args:
+ payload: Raw webhook body
+ sig_header: Stripe-Signature header value
+
+ Returns:
+ dict with processing status
+
+ Note:
+ May raise stripe.error.SignatureVerificationError if invalid signature
+ """
+ # Verify webhook signature
+ event = stripe.Webhook.construct_event(
+ payload,
+ sig_header,
+ STRIPE_WEBHOOK_SECRET,
+ )
+
+ event_id = event["id"]
+ event_type = event["type"]
+
+ # Check if already processed (idempotency)
+ if Core.Database.has_processed_stripe_event(event_id):
+ logger.info("Skipping already processed event: %s", event_id)
+ return {"status": "skipped", "reason": "already_processed"}
+
+ # Process event based on type
+ logger.info("Processing webhook event: %s (%s)", event_id, event_type)
+
+ try:
+ if event_type == "checkout.session.completed":
+ _handle_checkout_completed(event["data"]["object"])
+ elif event_type == "customer.subscription.created":
+ _handle_subscription_created(event["data"]["object"])
+ elif event_type == "customer.subscription.updated":
+ _handle_subscription_updated(event["data"]["object"])
+ elif event_type == "customer.subscription.deleted":
+ _handle_subscription_deleted(event["data"]["object"])
+ elif event_type == "invoice.payment_failed":
+ _handle_payment_failed(event["data"]["object"])
+ else:
+ logger.info("Unhandled event type: %s", event_type)
+ return {"status": "ignored", "type": event_type}
+
+ # Mark event as processed
+ Core.Database.mark_stripe_event_processed(event_id, event_type, payload)
+ except Exception:
+ logger.exception("Error processing webhook event %s", event_id)
+ raise
+ else:
+ return {"status": "processed", "type": event_type}
+
+
+def _handle_checkout_completed(session: dict) -> None:
+ """Handle checkout.session.completed event."""
+ user_id = int(session.get("client_reference_id", 0))
+ customer_id = session.get("customer")
+
+ if not user_id or not customer_id:
+ logger.warning(
+ "Missing user_id or customer_id in checkout session: %s",
+ session["id"],
+ )
+ return
+
+ # Link Stripe customer to user
+ Core.Database.set_user_stripe_customer(user_id, customer_id)
+ logger.info("Linked user %s to Stripe customer %s", user_id, customer_id)
+
+
+def _handle_subscription_created(subscription: dict) -> None:
+ """Handle customer.subscription.created event."""
+ _update_subscription_state(subscription)
+
+
+def _handle_subscription_updated(subscription: dict) -> None:
+ """Handle customer.subscription.updated event."""
+ _update_subscription_state(subscription)
+
+
+def _handle_subscription_deleted(subscription: dict) -> None:
+ """Handle customer.subscription.deleted event."""
+ customer_id = subscription["customer"]
+
+ # Find user by customer ID
+ user = Core.Database.get_user_by_stripe_customer_id(customer_id)
+ if not user:
+ logger.warning("User not found for customer: %s", customer_id)
+ return
+
+ # Downgrade to free
+ Core.Database.downgrade_to_free(user["id"])
+ logger.info("Downgraded user %s to free tier", user["id"])
+
+
+def _handle_payment_failed(invoice: dict) -> None:
+ """Handle invoice.payment_failed event."""
+ customer_id = invoice["customer"]
+ subscription_id = invoice.get("subscription")
+
+ # Find user by customer ID
+ user = Core.Database.get_user_by_stripe_customer_id(customer_id)
+ if not user:
+ logger.warning("User not found for customer: %s", customer_id)
+ return
+
+ # Update subscription status to past_due
+ if subscription_id:
+ Core.Database.update_subscription_status(user["id"], "past_due")
+ logger.warning(
+ "Payment failed for user %s, subscription %s",
+ user["id"],
+ subscription_id,
+ )
+
+
+def _update_subscription_state(subscription: dict) -> None:
+ """Update user subscription state from Stripe subscription object."""
+ customer_id = subscription["customer"]
+ subscription_id = subscription["id"]
+ status = subscription["status"]
+ cancel_at_period_end = subscription.get("cancel_at_period_end", False)
+
+ # Get billing period
+ period_start = datetime.fromtimestamp(
+ subscription["current_period_start"],
+ tz=timezone.utc,
+ )
+ period_end = datetime.fromtimestamp(
+ subscription["current_period_end"],
+ tz=timezone.utc,
+ )
+
+ # Determine tier from price ID
+ price_id = subscription["items"]["data"][0]["price"]["id"]
+ tier = PRICE_TO_TIER.get(price_id, "free")
+
+ # Find user by customer ID
+ user = Core.Database.get_user_by_stripe_customer_id(customer_id)
+ if not user:
+ logger.warning("User not found for customer: %s", customer_id)
+ return
+
+ # Update user subscription
+ Core.Database.update_user_subscription(
+ user["id"],
+ subscription_id,
+ status,
+ period_start,
+ period_end,
+ tier,
+ cancel_at_period_end,
+ )
+
+ logger.info(
+ "Updated user %s subscription: tier=%s, status=%s",
+ user["id"],
+ tier,
+ status,
+ )
+
+
+def get_tier_info(tier: str) -> dict:
+ """Get tier information for display.
+
+ Returns:
+ dict with keys: name, articles_limit, price, description
+ """
+ tier_info = {
+ "free": {
+ "name": "Free",
+ "articles_limit": 10,
+ "price": "$0",
+ "description": "10 articles per month",
+ },
+ "personal": {
+ "name": "Personal",
+ "articles_limit": 50,
+ "price": "$9/mo",
+ "description": "50 articles per month",
+ },
+ "pro": {
+ "name": "Pro",
+ "articles_limit": None,
+ "price": "$29/mo",
+ "description": "Unlimited articles",
+ },
+ }
+ return tier_info.get(tier, tier_info["free"])