summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater/Core.py
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2025-11-09 15:45:48 -0500
committerBen Sima <ben@bsima.me>2025-11-09 15:45:48 -0500
commitbaf1ea549ad0218efcfaf489f9fb2ed7b67bf652 (patch)
treed3b29471af0e2f6749e67ff1f64f1769350c9621 /Biz/PodcastItLater/Core.py
parentfce44e9449c2305e32544c43a9a35a5c423daad3 (diff)
feat(PodcastItLater): Add Stripe billing infrastructure
Add complete Stripe integration backend ready for testing once stripe package is available in Nix environment. Components: - Billing.py: Stripe Checkout, Billing Portal, webhook handling - Database migrations for subscription tracking - Usage tracking with tier-based limits - Idempotent webhook processing with stripe_events table Tier limits: free (10/mo), personal (50/mo), pro (unlimited) Webhook events handled: - checkout.session.completed (link customer to user) - customer.subscription.{created,updated,deleted} (sync state) - invoice.payment_failed (mark past_due) Requires: Stripe Python package in Nix, Web.py routes (next commit) Related to task t-144e7lF Amp-Thread-ID: https://ampcode.com/threads/T-8feaca83-dcc2-46cb-8f71-d0785960a2f7 Co-authored-by: Amp <amp@ampcode.com>
Diffstat (limited to 'Biz/PodcastItLater/Core.py')
-rw-r--r--Biz/PodcastItLater/Core.py222
1 files changed, 222 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py
index b7b6ed9..42b336e 100644
--- a/Biz/PodcastItLater/Core.py
+++ b/Biz/PodcastItLater/Core.py
@@ -141,6 +141,12 @@ class Database: # noqa: PLR0904
# Run migration to add default titles
Database.migrate_add_default_titles()
+ # Run migration to add billing fields
+ Database.migrate_add_billing_fields()
+
+ # Run migration to add stripe events table
+ Database.migrate_add_stripe_events_table()
+
@staticmethod
def add_to_queue(
url: str,
@@ -554,6 +560,53 @@ class Database: # noqa: PLR0904
logger.info("Database migrated to support user status")
@staticmethod
+ def migrate_add_billing_fields() -> None:
+ """Add billing-related fields to users table."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Add columns one by one (SQLite limitation)
+ columns_to_add = [
+ ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"),
+ ("stripe_customer_id", "TEXT UNIQUE"),
+ ("stripe_subscription_id", "TEXT UNIQUE"),
+ ("subscription_status", "TEXT"),
+ ("current_period_start", "TIMESTAMP"),
+ ("current_period_end", "TIMESTAMP"),
+ ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"),
+ ]
+
+ for column_name, column_def in columns_to_add:
+ try:
+ query = f"ALTER TABLE users ADD COLUMN {column_name} "
+ cursor.execute(query + column_def)
+ logger.info("Added column users.%s", column_name)
+ except sqlite3.OperationalError: # noqa: PERF203
+ pass # Column already exists
+
+ conn.commit()
+
+ @staticmethod
+ def migrate_add_stripe_events_table() -> None:
+ """Create stripe_events table for webhook idempotency."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS stripe_events (
+ id TEXT PRIMARY KEY,
+ type TEXT NOT NULL,
+ payload TEXT NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """)
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_stripe_events_created "
+ "ON stripe_events(created_at)",
+ )
+ conn.commit()
+ logger.info("Created stripe_events table")
+
+ @staticmethod
def migrate_add_default_titles() -> None:
"""Add default titles to queue items that have None titles."""
with Database.get_connection() as conn:
@@ -730,6 +783,175 @@ class Database: # noqa: PLR0904
conn.commit()
logger.info("Updated user %s status to %s", user_id, status)
+ @staticmethod
+ def set_user_stripe_customer(user_id: int, customer_id: str) -> None:
+ """Link Stripe customer ID to user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET stripe_customer_id = ? WHERE id = ?",
+ (customer_id, user_id),
+ )
+ conn.commit()
+ logger.info(
+ "Linked user %s to Stripe customer %s",
+ user_id,
+ customer_id,
+ )
+
+ @staticmethod
+ def update_user_subscription( # noqa: PLR0913, PLR0917
+ user_id: int,
+ subscription_id: str,
+ status: str,
+ period_start: Any,
+ period_end: Any,
+ tier: str,
+ cancel_at_period_end: bool, # noqa: FBT001
+ ) -> None:
+ """Update user subscription details."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE users SET
+ stripe_subscription_id = ?,
+ subscription_status = ?,
+ current_period_start = ?,
+ current_period_end = ?,
+ plan_tier = ?,
+ cancel_at_period_end = ?
+ WHERE id = ?
+ """,
+ (
+ subscription_id,
+ status,
+ period_start.isoformat(),
+ period_end.isoformat(),
+ tier,
+ 1 if cancel_at_period_end else 0,
+ user_id,
+ ),
+ )
+ conn.commit()
+ logger.info(
+ "Updated user %s subscription: tier=%s, status=%s",
+ user_id,
+ tier,
+ status,
+ )
+
+ @staticmethod
+ def update_subscription_status(user_id: int, status: str) -> None:
+ """Update only the subscription status (e.g., past_due)."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET subscription_status = ? WHERE id = ?",
+ (status, user_id),
+ )
+ conn.commit()
+ logger.info(
+ "Updated user %s subscription status to %s",
+ user_id,
+ status,
+ )
+
+ @staticmethod
+ def downgrade_to_free(user_id: int) -> None:
+ """Downgrade user to free tier and clear subscription data."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE users SET
+ plan_tier = 'free',
+ subscription_status = 'canceled',
+ stripe_subscription_id = NULL,
+ current_period_start = NULL,
+ current_period_end = NULL,
+ cancel_at_period_end = 0
+ WHERE id = ?
+ """,
+ (user_id,),
+ )
+ conn.commit()
+ logger.info("Downgraded user %s to free tier", user_id)
+
+ @staticmethod
+ def get_user_by_stripe_customer_id(
+ customer_id: str,
+ ) -> dict[str, Any] | None:
+ """Get user by Stripe customer ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM users WHERE stripe_customer_id = ?",
+ (customer_id,),
+ )
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def has_processed_stripe_event(event_id: str) -> bool:
+ """Check if Stripe event has already been processed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT id FROM stripe_events WHERE id = ?",
+ (event_id,),
+ )
+ return cursor.fetchone() is not None
+
+ @staticmethod
+ def mark_stripe_event_processed(
+ event_id: str,
+ event_type: str,
+ payload: bytes,
+ ) -> None:
+ """Mark Stripe event as processed for idempotency."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT OR IGNORE INTO stripe_events (id, type, payload) "
+ "VALUES (?, ?, ?)",
+ (event_id, event_type, payload.decode("utf-8")),
+ )
+ conn.commit()
+
+ @staticmethod
+ def get_usage(
+ user_id: int,
+ period_start: Any,
+ period_end: Any,
+ ) -> dict[str, int]:
+ """Get usage stats for user in period.
+
+ Returns:
+ dict with keys: articles (int), minutes (int)
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Count articles created in period
+ cursor.execute(
+ """
+ SELECT COUNT(*) as count, SUM(duration) as total_seconds
+ FROM episodes
+ WHERE user_id = ? AND created_at >= ? AND created_at < ?
+ """,
+ (user_id, period_start.isoformat(), period_end.isoformat()),
+ )
+ row = cursor.fetchone()
+
+ articles = row["count"] if row else 0
+ total_seconds = (
+ row["total_seconds"] if row and row["total_seconds"] else 0
+ )
+ minutes = total_seconds // 60
+
+ return {"articles": articles, "minutes": minutes}
+
class TestDatabase(Test.TestCase):
"""Test the Database class."""