summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater
diff options
context:
space:
mode:
Diffstat (limited to 'Biz/PodcastItLater')
-rw-r--r--Biz/PodcastItLater/Billing.hs2
-rw-r--r--Biz/PodcastItLater/Billing.py422
-rw-r--r--Biz/PodcastItLater/Core.py222
3 files changed, 646 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Billing.hs b/Biz/PodcastItLater/Billing.hs
new file mode 100644
index 0000000..0a23e6c
--- /dev/null
+++ b/Biz/PodcastItLater/Billing.hs
@@ -0,0 +1,2 @@
+-- | PodcastItLater Billing namespace marker
+module Biz.PodcastItLater.Billing where
diff --git a/Biz/PodcastItLater/Billing.py b/Biz/PodcastItLater/Billing.py
new file mode 100644
index 0000000..e472889
--- /dev/null
+++ b/Biz/PodcastItLater/Billing.py
@@ -0,0 +1,422 @@
+"""
+PodcastItLater Billing Integration.
+
+Stripe subscription management and usage enforcement.
+"""
+
+# : out podcastitlater-billing
+# : dep stripe
+# : dep pytest
+# : dep pytest-mock
+import Biz.PodcastItLater.Core as Core
+import Omni.Log as Log
+import os
+import stripe
+from datetime import datetime
+from datetime import timezone
+
+logger = Log.setup()
+
+# Stripe configuration
+stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "")
+STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "")
+
+# Price IDs from Stripe dashboard
+STRIPE_PRICE_ID_PERSONAL = os.getenv("STRIPE_PRICE_ID_PERSONAL", "")
+STRIPE_PRICE_ID_PRO = os.getenv("STRIPE_PRICE_ID_PRO", "")
+
+# Map Stripe price IDs to tier names
+PRICE_TO_TIER = {
+ STRIPE_PRICE_ID_PERSONAL: "personal",
+ STRIPE_PRICE_ID_PRO: "pro",
+}
+
+# Tier limits (None = unlimited)
+TIER_LIMITS = {
+ "free": {
+ "articles_per_period": 10,
+ "minutes_per_period": None,
+ },
+ "personal": {
+ "articles_per_period": 50,
+ "minutes_per_period": None,
+ },
+ "pro": {
+ "articles_per_period": None,
+ "minutes_per_period": None,
+ },
+}
+
+# Price map for checkout
+PRICE_MAP = {
+ "personal": STRIPE_PRICE_ID_PERSONAL,
+ "pro": STRIPE_PRICE_ID_PRO,
+}
+
+
+def get_period_boundaries(user: dict) -> tuple[datetime, datetime]:
+ """Get billing period boundaries for user.
+
+ For paid users: use Stripe subscription period.
+ For free users: use current calendar month (UTC).
+
+ Returns:
+ tuple: (period_start, period_end) as datetime objects
+ """
+ if user.get("plan_tier") != "free" and user.get("current_period_start"):
+ # Paid user - use Stripe billing period
+ period_start = datetime.fromisoformat(user["current_period_start"])
+ period_end = datetime.fromisoformat(user["current_period_end"])
+ else:
+ # Free user - use calendar month
+ now = datetime.now(timezone.utc)
+ period_start = now.replace(
+ day=1,
+ hour=0,
+ minute=0,
+ second=0,
+ microsecond=0,
+ )
+ # Get first day of next month
+ if now.month == 12: # noqa: PLR2004
+ period_end = period_start.replace(year=now.year + 1, month=1)
+ else:
+ period_end = period_start.replace(month=now.month + 1)
+
+ return period_start, period_end
+
+
+def get_usage(
+ user_id: int,
+ period_start: datetime,
+ period_end: datetime,
+) -> dict:
+ """Get usage stats for user in billing period.
+
+ Returns:
+ dict with keys: articles (int), minutes (int)
+ """
+ return Core.Database.get_usage(user_id, period_start, period_end)
+
+
+def can_submit(user_id: int) -> tuple[bool, str, dict]:
+ """Check if user can submit article based on tier limits.
+
+ Returns:
+ tuple: (allowed: bool, message: str, usage: dict)
+ """
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ return False, "User not found", {}
+
+ tier = user.get("plan_tier", "free")
+ limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"])
+
+ # Get billing period boundaries
+ period_start, period_end = get_period_boundaries(user)
+
+ # Get current usage
+ usage = get_usage(user_id, period_start, period_end)
+
+ # Check article limit
+ article_limit = limits["articles_per_period"]
+ if article_limit is not None and usage["articles"] >= article_limit:
+ msg = (
+ f"You've reached your limit of {article_limit} articles "
+ "per period. Upgrade to continue."
+ )
+ return (False, msg, usage)
+
+ # Check minutes limit (if implemented)
+ minute_limit = limits.get("minutes_per_period")
+ if minute_limit is not None and usage.get("minutes", 0) >= minute_limit:
+ return (
+ False,
+ f"You've reached your limit of {minute_limit} minutes per period. "
+ "Please upgrade to continue.",
+ usage,
+ )
+
+ return True, "", usage
+
+
+def create_checkout_session(user_id: int, tier: str, base_url: str) -> str:
+ """Create Stripe Checkout session for subscription.
+
+ Args:
+ user_id: User ID
+ tier: Subscription tier (personal or pro)
+ base_url: Base URL for success/cancel redirects
+
+ Returns:
+ Checkout session URL to redirect user to
+
+ Raises:
+ ValueError: If tier is invalid or price ID not configured
+ """
+ if tier not in PRICE_MAP:
+ msg = f"Invalid tier: {tier}"
+ raise ValueError(msg)
+
+ price_id = PRICE_MAP[tier]
+ if not price_id:
+ msg = f"Stripe price ID not configured for tier: {tier}"
+ raise ValueError(msg)
+
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ msg = f"User not found: {user_id}"
+ raise ValueError(msg)
+
+ # Create checkout session
+ session_params = {
+ "mode": "subscription",
+ "line_items": [{"price": price_id, "quantity": 1}],
+ "success_url": f"{base_url}/billing?status=success",
+ "cancel_url": f"{base_url}/billing?status=cancel",
+ "client_reference_id": str(user_id),
+ "metadata": {"user_id": str(user_id), "tier": tier},
+ "allow_promotion_codes": True,
+ }
+
+ # Use existing customer if available
+ if user.get("stripe_customer_id"):
+ session_params["customer"] = user["stripe_customer_id"]
+ else:
+ session_params["customer_email"] = user["email"]
+
+ session = stripe.checkout.Session.create(**session_params)
+
+ logger.info(
+ "Created checkout session for user %s, tier %s: %s",
+ user_id,
+ tier,
+ session.id,
+ )
+
+ return session.url
+
+
+def create_portal_session(user_id: int, base_url: str) -> str:
+ """Create Stripe Billing Portal session.
+
+ Args:
+ user_id: User ID
+ base_url: Base URL for return redirect
+
+ Returns:
+ Portal session URL to redirect user to
+
+ Raises:
+ ValueError: If user has no Stripe customer ID
+ """
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ msg = f"User not found: {user_id}"
+ raise ValueError(msg)
+
+ if not user.get("stripe_customer_id"):
+ msg = "User has no Stripe customer ID"
+ raise ValueError(msg)
+
+ session = stripe.billing_portal.Session.create(
+ customer=user["stripe_customer_id"],
+ return_url=f"{base_url}/billing",
+ )
+
+ logger.info("Created portal session for user %s: %s", user_id, session.id)
+
+ return session.url
+
+
+def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
+ """Verify and process Stripe webhook event.
+
+ Args:
+ payload: Raw webhook body
+ sig_header: Stripe-Signature header value
+
+ Returns:
+ dict with processing status
+
+ Note:
+ May raise stripe.error.SignatureVerificationError if invalid signature
+ """
+ # Verify webhook signature
+ event = stripe.Webhook.construct_event(
+ payload,
+ sig_header,
+ STRIPE_WEBHOOK_SECRET,
+ )
+
+ event_id = event["id"]
+ event_type = event["type"]
+
+ # Check if already processed (idempotency)
+ if Core.Database.has_processed_stripe_event(event_id):
+ logger.info("Skipping already processed event: %s", event_id)
+ return {"status": "skipped", "reason": "already_processed"}
+
+ # Process event based on type
+ logger.info("Processing webhook event: %s (%s)", event_id, event_type)
+
+ try:
+ if event_type == "checkout.session.completed":
+ _handle_checkout_completed(event["data"]["object"])
+ elif event_type == "customer.subscription.created":
+ _handle_subscription_created(event["data"]["object"])
+ elif event_type == "customer.subscription.updated":
+ _handle_subscription_updated(event["data"]["object"])
+ elif event_type == "customer.subscription.deleted":
+ _handle_subscription_deleted(event["data"]["object"])
+ elif event_type == "invoice.payment_failed":
+ _handle_payment_failed(event["data"]["object"])
+ else:
+ logger.info("Unhandled event type: %s", event_type)
+ return {"status": "ignored", "type": event_type}
+
+ # Mark event as processed
+ Core.Database.mark_stripe_event_processed(event_id, event_type, payload)
+ except Exception:
+ logger.exception("Error processing webhook event %s", event_id)
+ raise
+ else:
+ return {"status": "processed", "type": event_type}
+
+
+def _handle_checkout_completed(session: dict) -> None:
+ """Handle checkout.session.completed event."""
+ user_id = int(session.get("client_reference_id", 0))
+ customer_id = session.get("customer")
+
+ if not user_id or not customer_id:
+ logger.warning(
+ "Missing user_id or customer_id in checkout session: %s",
+ session["id"],
+ )
+ return
+
+ # Link Stripe customer to user
+ Core.Database.set_user_stripe_customer(user_id, customer_id)
+ logger.info("Linked user %s to Stripe customer %s", user_id, customer_id)
+
+
+def _handle_subscription_created(subscription: dict) -> None:
+ """Handle customer.subscription.created event."""
+ _update_subscription_state(subscription)
+
+
+def _handle_subscription_updated(subscription: dict) -> None:
+ """Handle customer.subscription.updated event."""
+ _update_subscription_state(subscription)
+
+
+def _handle_subscription_deleted(subscription: dict) -> None:
+ """Handle customer.subscription.deleted event."""
+ customer_id = subscription["customer"]
+
+ # Find user by customer ID
+ user = Core.Database.get_user_by_stripe_customer_id(customer_id)
+ if not user:
+ logger.warning("User not found for customer: %s", customer_id)
+ return
+
+ # Downgrade to free
+ Core.Database.downgrade_to_free(user["id"])
+ logger.info("Downgraded user %s to free tier", user["id"])
+
+
+def _handle_payment_failed(invoice: dict) -> None:
+ """Handle invoice.payment_failed event."""
+ customer_id = invoice["customer"]
+ subscription_id = invoice.get("subscription")
+
+ # Find user by customer ID
+ user = Core.Database.get_user_by_stripe_customer_id(customer_id)
+ if not user:
+ logger.warning("User not found for customer: %s", customer_id)
+ return
+
+ # Update subscription status to past_due
+ if subscription_id:
+ Core.Database.update_subscription_status(user["id"], "past_due")
+ logger.warning(
+ "Payment failed for user %s, subscription %s",
+ user["id"],
+ subscription_id,
+ )
+
+
+def _update_subscription_state(subscription: dict) -> None:
+ """Update user subscription state from Stripe subscription object."""
+ customer_id = subscription["customer"]
+ subscription_id = subscription["id"]
+ status = subscription["status"]
+ cancel_at_period_end = subscription.get("cancel_at_period_end", False)
+
+ # Get billing period
+ period_start = datetime.fromtimestamp(
+ subscription["current_period_start"],
+ tz=timezone.utc,
+ )
+ period_end = datetime.fromtimestamp(
+ subscription["current_period_end"],
+ tz=timezone.utc,
+ )
+
+ # Determine tier from price ID
+ price_id = subscription["items"]["data"][0]["price"]["id"]
+ tier = PRICE_TO_TIER.get(price_id, "free")
+
+ # Find user by customer ID
+ user = Core.Database.get_user_by_stripe_customer_id(customer_id)
+ if not user:
+ logger.warning("User not found for customer: %s", customer_id)
+ return
+
+ # Update user subscription
+ Core.Database.update_user_subscription(
+ user["id"],
+ subscription_id,
+ status,
+ period_start,
+ period_end,
+ tier,
+ cancel_at_period_end,
+ )
+
+ logger.info(
+ "Updated user %s subscription: tier=%s, status=%s",
+ user["id"],
+ tier,
+ status,
+ )
+
+
+def get_tier_info(tier: str) -> dict:
+ """Get tier information for display.
+
+ Returns:
+ dict with keys: name, articles_limit, price, description
+ """
+ tier_info = {
+ "free": {
+ "name": "Free",
+ "articles_limit": 10,
+ "price": "$0",
+ "description": "10 articles per month",
+ },
+ "personal": {
+ "name": "Personal",
+ "articles_limit": 50,
+ "price": "$9/mo",
+ "description": "50 articles per month",
+ },
+ "pro": {
+ "name": "Pro",
+ "articles_limit": None,
+ "price": "$29/mo",
+ "description": "Unlimited articles",
+ },
+ }
+ return tier_info.get(tier, tier_info["free"])
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py
index b7b6ed9..42b336e 100644
--- a/Biz/PodcastItLater/Core.py
+++ b/Biz/PodcastItLater/Core.py
@@ -141,6 +141,12 @@ class Database: # noqa: PLR0904
# Run migration to add default titles
Database.migrate_add_default_titles()
+ # Run migration to add billing fields
+ Database.migrate_add_billing_fields()
+
+ # Run migration to add stripe events table
+ Database.migrate_add_stripe_events_table()
+
@staticmethod
def add_to_queue(
url: str,
@@ -554,6 +560,53 @@ class Database: # noqa: PLR0904
logger.info("Database migrated to support user status")
@staticmethod
+ def migrate_add_billing_fields() -> None:
+ """Add billing-related fields to users table."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Add columns one by one (SQLite limitation)
+ columns_to_add = [
+ ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"),
+ ("stripe_customer_id", "TEXT UNIQUE"),
+ ("stripe_subscription_id", "TEXT UNIQUE"),
+ ("subscription_status", "TEXT"),
+ ("current_period_start", "TIMESTAMP"),
+ ("current_period_end", "TIMESTAMP"),
+ ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"),
+ ]
+
+ for column_name, column_def in columns_to_add:
+ try:
+ query = f"ALTER TABLE users ADD COLUMN {column_name} "
+ cursor.execute(query + column_def)
+ logger.info("Added column users.%s", column_name)
+ except sqlite3.OperationalError: # noqa: PERF203
+ pass # Column already exists
+
+ conn.commit()
+
+ @staticmethod
+ def migrate_add_stripe_events_table() -> None:
+ """Create stripe_events table for webhook idempotency."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS stripe_events (
+ id TEXT PRIMARY KEY,
+ type TEXT NOT NULL,
+ payload TEXT NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )
+ """)
+ cursor.execute(
+ "CREATE INDEX IF NOT EXISTS idx_stripe_events_created "
+ "ON stripe_events(created_at)",
+ )
+ conn.commit()
+ logger.info("Created stripe_events table")
+
+ @staticmethod
def migrate_add_default_titles() -> None:
"""Add default titles to queue items that have None titles."""
with Database.get_connection() as conn:
@@ -730,6 +783,175 @@ class Database: # noqa: PLR0904
conn.commit()
logger.info("Updated user %s status to %s", user_id, status)
+ @staticmethod
+ def set_user_stripe_customer(user_id: int, customer_id: str) -> None:
+ """Link Stripe customer ID to user."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET stripe_customer_id = ? WHERE id = ?",
+ (customer_id, user_id),
+ )
+ conn.commit()
+ logger.info(
+ "Linked user %s to Stripe customer %s",
+ user_id,
+ customer_id,
+ )
+
+ @staticmethod
+ def update_user_subscription( # noqa: PLR0913, PLR0917
+ user_id: int,
+ subscription_id: str,
+ status: str,
+ period_start: Any,
+ period_end: Any,
+ tier: str,
+ cancel_at_period_end: bool, # noqa: FBT001
+ ) -> None:
+ """Update user subscription details."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE users SET
+ stripe_subscription_id = ?,
+ subscription_status = ?,
+ current_period_start = ?,
+ current_period_end = ?,
+ plan_tier = ?,
+ cancel_at_period_end = ?
+ WHERE id = ?
+ """,
+ (
+ subscription_id,
+ status,
+ period_start.isoformat(),
+ period_end.isoformat(),
+ tier,
+ 1 if cancel_at_period_end else 0,
+ user_id,
+ ),
+ )
+ conn.commit()
+ logger.info(
+ "Updated user %s subscription: tier=%s, status=%s",
+ user_id,
+ tier,
+ status,
+ )
+
+ @staticmethod
+ def update_subscription_status(user_id: int, status: str) -> None:
+ """Update only the subscription status (e.g., past_due)."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE users SET subscription_status = ? WHERE id = ?",
+ (status, user_id),
+ )
+ conn.commit()
+ logger.info(
+ "Updated user %s subscription status to %s",
+ user_id,
+ status,
+ )
+
+ @staticmethod
+ def downgrade_to_free(user_id: int) -> None:
+ """Downgrade user to free tier and clear subscription data."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE users SET
+ plan_tier = 'free',
+ subscription_status = 'canceled',
+ stripe_subscription_id = NULL,
+ current_period_start = NULL,
+ current_period_end = NULL,
+ cancel_at_period_end = 0
+ WHERE id = ?
+ """,
+ (user_id,),
+ )
+ conn.commit()
+ logger.info("Downgraded user %s to free tier", user_id)
+
+ @staticmethod
+ def get_user_by_stripe_customer_id(
+ customer_id: str,
+ ) -> dict[str, Any] | None:
+ """Get user by Stripe customer ID."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT * FROM users WHERE stripe_customer_id = ?",
+ (customer_id,),
+ )
+ row = cursor.fetchone()
+ return dict(row) if row is not None else None
+
+ @staticmethod
+ def has_processed_stripe_event(event_id: str) -> bool:
+ """Check if Stripe event has already been processed."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT id FROM stripe_events WHERE id = ?",
+ (event_id,),
+ )
+ return cursor.fetchone() is not None
+
+ @staticmethod
+ def mark_stripe_event_processed(
+ event_id: str,
+ event_type: str,
+ payload: bytes,
+ ) -> None:
+ """Mark Stripe event as processed for idempotency."""
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "INSERT OR IGNORE INTO stripe_events (id, type, payload) "
+ "VALUES (?, ?, ?)",
+ (event_id, event_type, payload.decode("utf-8")),
+ )
+ conn.commit()
+
+ @staticmethod
+ def get_usage(
+ user_id: int,
+ period_start: Any,
+ period_end: Any,
+ ) -> dict[str, int]:
+ """Get usage stats for user in period.
+
+ Returns:
+ dict with keys: articles (int), minutes (int)
+ """
+ with Database.get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Count articles created in period
+ cursor.execute(
+ """
+ SELECT COUNT(*) as count, SUM(duration) as total_seconds
+ FROM episodes
+ WHERE user_id = ? AND created_at >= ? AND created_at < ?
+ """,
+ (user_id, period_start.isoformat(), period_end.isoformat()),
+ )
+ row = cursor.fetchone()
+
+ articles = row["count"] if row else 0
+ total_seconds = (
+ row["total_seconds"] if row and row["total_seconds"] else 0
+ )
+ minutes = total_seconds // 60
+
+ return {"articles": articles, "minutes": minutes}
+
class TestDatabase(Test.TestCase):
"""Test the Database class."""