From baf1ea549ad0218efcfaf489f9fb2ed7b67bf652 Mon Sep 17 00:00:00 2001 From: Ben Sima Date: Sun, 9 Nov 2025 15:45:48 -0500 Subject: feat(PodcastItLater): Add Stripe billing infrastructure Add complete Stripe integration backend ready for testing once stripe package is available in Nix environment. Components: - Billing.py: Stripe Checkout, Billing Portal, webhook handling - Database migrations for subscription tracking - Usage tracking with tier-based limits - Idempotent webhook processing with stripe_events table Tier limits: free (10/mo), personal (50/mo), pro (unlimited) Webhook events handled: - checkout.session.completed (link customer to user) - customer.subscription.{created,updated,deleted} (sync state) - invoice.payment_failed (mark past_due) Requires: Stripe Python package in Nix, Web.py routes (next commit) Related to task t-144e7lF Amp-Thread-ID: https://ampcode.com/threads/T-8feaca83-dcc2-46cb-8f71-d0785960a2f7 Co-authored-by: Amp --- Biz/PodcastItLater/Billing.hs | 2 + Biz/PodcastItLater/Billing.py | 422 ++++++++++++++++++++++++++++++++++++++++++ Biz/PodcastItLater/Core.py | 222 ++++++++++++++++++++++ 3 files changed, 646 insertions(+) create mode 100644 Biz/PodcastItLater/Billing.hs create mode 100644 Biz/PodcastItLater/Billing.py (limited to 'Biz') diff --git a/Biz/PodcastItLater/Billing.hs b/Biz/PodcastItLater/Billing.hs new file mode 100644 index 0000000..0a23e6c --- /dev/null +++ b/Biz/PodcastItLater/Billing.hs @@ -0,0 +1,2 @@ +-- | PodcastItLater Billing namespace marker +module Biz.PodcastItLater.Billing where diff --git a/Biz/PodcastItLater/Billing.py b/Biz/PodcastItLater/Billing.py new file mode 100644 index 0000000..e472889 --- /dev/null +++ b/Biz/PodcastItLater/Billing.py @@ -0,0 +1,422 @@ +""" +PodcastItLater Billing Integration. + +Stripe subscription management and usage enforcement. +""" + +# : out podcastitlater-billing +# : dep stripe +# : dep pytest +# : dep pytest-mock +import Biz.PodcastItLater.Core as Core +import Omni.Log as Log +import os +import stripe +from datetime import datetime +from datetime import timezone + +logger = Log.setup() + +# Stripe configuration +stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "") +STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "") + +# Price IDs from Stripe dashboard +STRIPE_PRICE_ID_PERSONAL = os.getenv("STRIPE_PRICE_ID_PERSONAL", "") +STRIPE_PRICE_ID_PRO = os.getenv("STRIPE_PRICE_ID_PRO", "") + +# Map Stripe price IDs to tier names +PRICE_TO_TIER = { + STRIPE_PRICE_ID_PERSONAL: "personal", + STRIPE_PRICE_ID_PRO: "pro", +} + +# Tier limits (None = unlimited) +TIER_LIMITS = { + "free": { + "articles_per_period": 10, + "minutes_per_period": None, + }, + "personal": { + "articles_per_period": 50, + "minutes_per_period": None, + }, + "pro": { + "articles_per_period": None, + "minutes_per_period": None, + }, +} + +# Price map for checkout +PRICE_MAP = { + "personal": STRIPE_PRICE_ID_PERSONAL, + "pro": STRIPE_PRICE_ID_PRO, +} + + +def get_period_boundaries(user: dict) -> tuple[datetime, datetime]: + """Get billing period boundaries for user. + + For paid users: use Stripe subscription period. + For free users: use current calendar month (UTC). + + Returns: + tuple: (period_start, period_end) as datetime objects + """ + if user.get("plan_tier") != "free" and user.get("current_period_start"): + # Paid user - use Stripe billing period + period_start = datetime.fromisoformat(user["current_period_start"]) + period_end = datetime.fromisoformat(user["current_period_end"]) + else: + # Free user - use calendar month + now = datetime.now(timezone.utc) + period_start = now.replace( + day=1, + hour=0, + minute=0, + second=0, + microsecond=0, + ) + # Get first day of next month + if now.month == 12: # noqa: PLR2004 + period_end = period_start.replace(year=now.year + 1, month=1) + else: + period_end = period_start.replace(month=now.month + 1) + + return period_start, period_end + + +def get_usage( + user_id: int, + period_start: datetime, + period_end: datetime, +) -> dict: + """Get usage stats for user in billing period. + + Returns: + dict with keys: articles (int), minutes (int) + """ + return Core.Database.get_usage(user_id, period_start, period_end) + + +def can_submit(user_id: int) -> tuple[bool, str, dict]: + """Check if user can submit article based on tier limits. + + Returns: + tuple: (allowed: bool, message: str, usage: dict) + """ + user = Core.Database.get_user_by_id(user_id) + if not user: + return False, "User not found", {} + + tier = user.get("plan_tier", "free") + limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"]) + + # Get billing period boundaries + period_start, period_end = get_period_boundaries(user) + + # Get current usage + usage = get_usage(user_id, period_start, period_end) + + # Check article limit + article_limit = limits["articles_per_period"] + if article_limit is not None and usage["articles"] >= article_limit: + msg = ( + f"You've reached your limit of {article_limit} articles " + "per period. Upgrade to continue." + ) + return (False, msg, usage) + + # Check minutes limit (if implemented) + minute_limit = limits.get("minutes_per_period") + if minute_limit is not None and usage.get("minutes", 0) >= minute_limit: + return ( + False, + f"You've reached your limit of {minute_limit} minutes per period. " + "Please upgrade to continue.", + usage, + ) + + return True, "", usage + + +def create_checkout_session(user_id: int, tier: str, base_url: str) -> str: + """Create Stripe Checkout session for subscription. + + Args: + user_id: User ID + tier: Subscription tier (personal or pro) + base_url: Base URL for success/cancel redirects + + Returns: + Checkout session URL to redirect user to + + Raises: + ValueError: If tier is invalid or price ID not configured + """ + if tier not in PRICE_MAP: + msg = f"Invalid tier: {tier}" + raise ValueError(msg) + + price_id = PRICE_MAP[tier] + if not price_id: + msg = f"Stripe price ID not configured for tier: {tier}" + raise ValueError(msg) + + user = Core.Database.get_user_by_id(user_id) + if not user: + msg = f"User not found: {user_id}" + raise ValueError(msg) + + # Create checkout session + session_params = { + "mode": "subscription", + "line_items": [{"price": price_id, "quantity": 1}], + "success_url": f"{base_url}/billing?status=success", + "cancel_url": f"{base_url}/billing?status=cancel", + "client_reference_id": str(user_id), + "metadata": {"user_id": str(user_id), "tier": tier}, + "allow_promotion_codes": True, + } + + # Use existing customer if available + if user.get("stripe_customer_id"): + session_params["customer"] = user["stripe_customer_id"] + else: + session_params["customer_email"] = user["email"] + + session = stripe.checkout.Session.create(**session_params) + + logger.info( + "Created checkout session for user %s, tier %s: %s", + user_id, + tier, + session.id, + ) + + return session.url + + +def create_portal_session(user_id: int, base_url: str) -> str: + """Create Stripe Billing Portal session. + + Args: + user_id: User ID + base_url: Base URL for return redirect + + Returns: + Portal session URL to redirect user to + + Raises: + ValueError: If user has no Stripe customer ID + """ + user = Core.Database.get_user_by_id(user_id) + if not user: + msg = f"User not found: {user_id}" + raise ValueError(msg) + + if not user.get("stripe_customer_id"): + msg = "User has no Stripe customer ID" + raise ValueError(msg) + + session = stripe.billing_portal.Session.create( + customer=user["stripe_customer_id"], + return_url=f"{base_url}/billing", + ) + + logger.info("Created portal session for user %s: %s", user_id, session.id) + + return session.url + + +def handle_webhook_event(payload: bytes, sig_header: str) -> dict: + """Verify and process Stripe webhook event. + + Args: + payload: Raw webhook body + sig_header: Stripe-Signature header value + + Returns: + dict with processing status + + Note: + May raise stripe.error.SignatureVerificationError if invalid signature + """ + # Verify webhook signature + event = stripe.Webhook.construct_event( + payload, + sig_header, + STRIPE_WEBHOOK_SECRET, + ) + + event_id = event["id"] + event_type = event["type"] + + # Check if already processed (idempotency) + if Core.Database.has_processed_stripe_event(event_id): + logger.info("Skipping already processed event: %s", event_id) + return {"status": "skipped", "reason": "already_processed"} + + # Process event based on type + logger.info("Processing webhook event: %s (%s)", event_id, event_type) + + try: + if event_type == "checkout.session.completed": + _handle_checkout_completed(event["data"]["object"]) + elif event_type == "customer.subscription.created": + _handle_subscription_created(event["data"]["object"]) + elif event_type == "customer.subscription.updated": + _handle_subscription_updated(event["data"]["object"]) + elif event_type == "customer.subscription.deleted": + _handle_subscription_deleted(event["data"]["object"]) + elif event_type == "invoice.payment_failed": + _handle_payment_failed(event["data"]["object"]) + else: + logger.info("Unhandled event type: %s", event_type) + return {"status": "ignored", "type": event_type} + + # Mark event as processed + Core.Database.mark_stripe_event_processed(event_id, event_type, payload) + except Exception: + logger.exception("Error processing webhook event %s", event_id) + raise + else: + return {"status": "processed", "type": event_type} + + +def _handle_checkout_completed(session: dict) -> None: + """Handle checkout.session.completed event.""" + user_id = int(session.get("client_reference_id", 0)) + customer_id = session.get("customer") + + if not user_id or not customer_id: + logger.warning( + "Missing user_id or customer_id in checkout session: %s", + session["id"], + ) + return + + # Link Stripe customer to user + Core.Database.set_user_stripe_customer(user_id, customer_id) + logger.info("Linked user %s to Stripe customer %s", user_id, customer_id) + + +def _handle_subscription_created(subscription: dict) -> None: + """Handle customer.subscription.created event.""" + _update_subscription_state(subscription) + + +def _handle_subscription_updated(subscription: dict) -> None: + """Handle customer.subscription.updated event.""" + _update_subscription_state(subscription) + + +def _handle_subscription_deleted(subscription: dict) -> None: + """Handle customer.subscription.deleted event.""" + customer_id = subscription["customer"] + + # Find user by customer ID + user = Core.Database.get_user_by_stripe_customer_id(customer_id) + if not user: + logger.warning("User not found for customer: %s", customer_id) + return + + # Downgrade to free + Core.Database.downgrade_to_free(user["id"]) + logger.info("Downgraded user %s to free tier", user["id"]) + + +def _handle_payment_failed(invoice: dict) -> None: + """Handle invoice.payment_failed event.""" + customer_id = invoice["customer"] + subscription_id = invoice.get("subscription") + + # Find user by customer ID + user = Core.Database.get_user_by_stripe_customer_id(customer_id) + if not user: + logger.warning("User not found for customer: %s", customer_id) + return + + # Update subscription status to past_due + if subscription_id: + Core.Database.update_subscription_status(user["id"], "past_due") + logger.warning( + "Payment failed for user %s, subscription %s", + user["id"], + subscription_id, + ) + + +def _update_subscription_state(subscription: dict) -> None: + """Update user subscription state from Stripe subscription object.""" + customer_id = subscription["customer"] + subscription_id = subscription["id"] + status = subscription["status"] + cancel_at_period_end = subscription.get("cancel_at_period_end", False) + + # Get billing period + period_start = datetime.fromtimestamp( + subscription["current_period_start"], + tz=timezone.utc, + ) + period_end = datetime.fromtimestamp( + subscription["current_period_end"], + tz=timezone.utc, + ) + + # Determine tier from price ID + price_id = subscription["items"]["data"][0]["price"]["id"] + tier = PRICE_TO_TIER.get(price_id, "free") + + # Find user by customer ID + user = Core.Database.get_user_by_stripe_customer_id(customer_id) + if not user: + logger.warning("User not found for customer: %s", customer_id) + return + + # Update user subscription + Core.Database.update_user_subscription( + user["id"], + subscription_id, + status, + period_start, + period_end, + tier, + cancel_at_period_end, + ) + + logger.info( + "Updated user %s subscription: tier=%s, status=%s", + user["id"], + tier, + status, + ) + + +def get_tier_info(tier: str) -> dict: + """Get tier information for display. + + Returns: + dict with keys: name, articles_limit, price, description + """ + tier_info = { + "free": { + "name": "Free", + "articles_limit": 10, + "price": "$0", + "description": "10 articles per month", + }, + "personal": { + "name": "Personal", + "articles_limit": 50, + "price": "$9/mo", + "description": "50 articles per month", + }, + "pro": { + "name": "Pro", + "articles_limit": None, + "price": "$29/mo", + "description": "Unlimited articles", + }, + } + return tier_info.get(tier, tier_info["free"]) diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py index b7b6ed9..42b336e 100644 --- a/Biz/PodcastItLater/Core.py +++ b/Biz/PodcastItLater/Core.py @@ -141,6 +141,12 @@ class Database: # noqa: PLR0904 # Run migration to add default titles Database.migrate_add_default_titles() + # Run migration to add billing fields + Database.migrate_add_billing_fields() + + # Run migration to add stripe events table + Database.migrate_add_stripe_events_table() + @staticmethod def add_to_queue( url: str, @@ -553,6 +559,53 @@ class Database: # noqa: PLR0904 conn.commit() logger.info("Database migrated to support user status") + @staticmethod + def migrate_add_billing_fields() -> None: + """Add billing-related fields to users table.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Add columns one by one (SQLite limitation) + columns_to_add = [ + ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"), + ("stripe_customer_id", "TEXT UNIQUE"), + ("stripe_subscription_id", "TEXT UNIQUE"), + ("subscription_status", "TEXT"), + ("current_period_start", "TIMESTAMP"), + ("current_period_end", "TIMESTAMP"), + ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"), + ] + + for column_name, column_def in columns_to_add: + try: + query = f"ALTER TABLE users ADD COLUMN {column_name} " + cursor.execute(query + column_def) + logger.info("Added column users.%s", column_name) + except sqlite3.OperationalError: # noqa: PERF203 + pass # Column already exists + + conn.commit() + + @staticmethod + def migrate_add_stripe_events_table() -> None: + """Create stripe_events table for webhook idempotency.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS stripe_events ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + payload TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_stripe_events_created " + "ON stripe_events(created_at)", + ) + conn.commit() + logger.info("Created stripe_events table") + @staticmethod def migrate_add_default_titles() -> None: """Add default titles to queue items that have None titles.""" @@ -730,6 +783,175 @@ class Database: # noqa: PLR0904 conn.commit() logger.info("Updated user %s status to %s", user_id, status) + @staticmethod + def set_user_stripe_customer(user_id: int, customer_id: str) -> None: + """Link Stripe customer ID to user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET stripe_customer_id = ? WHERE id = ?", + (customer_id, user_id), + ) + conn.commit() + logger.info( + "Linked user %s to Stripe customer %s", + user_id, + customer_id, + ) + + @staticmethod + def update_user_subscription( # noqa: PLR0913, PLR0917 + user_id: int, + subscription_id: str, + status: str, + period_start: Any, + period_end: Any, + tier: str, + cancel_at_period_end: bool, # noqa: FBT001 + ) -> None: + """Update user subscription details.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + stripe_subscription_id = ?, + subscription_status = ?, + current_period_start = ?, + current_period_end = ?, + plan_tier = ?, + cancel_at_period_end = ? + WHERE id = ? + """, + ( + subscription_id, + status, + period_start.isoformat(), + period_end.isoformat(), + tier, + 1 if cancel_at_period_end else 0, + user_id, + ), + ) + conn.commit() + logger.info( + "Updated user %s subscription: tier=%s, status=%s", + user_id, + tier, + status, + ) + + @staticmethod + def update_subscription_status(user_id: int, status: str) -> None: + """Update only the subscription status (e.g., past_due).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET subscription_status = ? WHERE id = ?", + (status, user_id), + ) + conn.commit() + logger.info( + "Updated user %s subscription status to %s", + user_id, + status, + ) + + @staticmethod + def downgrade_to_free(user_id: int) -> None: + """Downgrade user to free tier and clear subscription data.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + plan_tier = 'free', + subscription_status = 'canceled', + stripe_subscription_id = NULL, + current_period_start = NULL, + current_period_end = NULL, + cancel_at_period_end = 0 + WHERE id = ? + """, + (user_id,), + ) + conn.commit() + logger.info("Downgraded user %s to free tier", user_id) + + @staticmethod + def get_user_by_stripe_customer_id( + customer_id: str, + ) -> dict[str, Any] | None: + """Get user by Stripe customer ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM users WHERE stripe_customer_id = ?", + (customer_id,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def has_processed_stripe_event(event_id: str) -> bool: + """Check if Stripe event has already been processed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT id FROM stripe_events WHERE id = ?", + (event_id,), + ) + return cursor.fetchone() is not None + + @staticmethod + def mark_stripe_event_processed( + event_id: str, + event_type: str, + payload: bytes, + ) -> None: + """Mark Stripe event as processed for idempotency.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT OR IGNORE INTO stripe_events (id, type, payload) " + "VALUES (?, ?, ?)", + (event_id, event_type, payload.decode("utf-8")), + ) + conn.commit() + + @staticmethod + def get_usage( + user_id: int, + period_start: Any, + period_end: Any, + ) -> dict[str, int]: + """Get usage stats for user in period. + + Returns: + dict with keys: articles (int), minutes (int) + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Count articles created in period + cursor.execute( + """ + SELECT COUNT(*) as count, SUM(duration) as total_seconds + FROM episodes + WHERE user_id = ? AND created_at >= ? AND created_at < ? + """, + (user_id, period_start.isoformat(), period_end.isoformat()), + ) + row = cursor.fetchone() + + articles = row["count"] if row else 0 + total_seconds = ( + row["total_seconds"] if row and row["total_seconds"] else 0 + ) + minutes = total_seconds // 60 + + return {"articles": articles, "minutes": minutes} + class TestDatabase(Test.TestCase): """Test the Database class.""" -- cgit v1.2.3