summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater/Billing.py
diff options
context:
space:
mode:
Diffstat (limited to 'Biz/PodcastItLater/Billing.py')
-rw-r--r--Biz/PodcastItLater/Billing.py581
1 files changed, 581 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Billing.py b/Biz/PodcastItLater/Billing.py
new file mode 100644
index 0000000..9f3739d
--- /dev/null
+++ b/Biz/PodcastItLater/Billing.py
@@ -0,0 +1,581 @@
+"""
+PodcastItLater Billing Integration.
+
+Stripe subscription management and usage enforcement.
+"""
+
+# : out podcastitlater-billing
+# : dep stripe
+# : dep pytest
+# : dep pytest-mock
+import Biz.PodcastItLater.Core as Core
+import json
+import logging
+import Omni.App as App
+import Omni.Log as Log
+import Omni.Test as Test
+import os
+import stripe
+import sys
+import typing
+from datetime import datetime
+from datetime import timezone
+
+logger = logging.getLogger(__name__)
+Log.setup(logger)
+
+# Stripe configuration
+stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "")
+STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "")
+
+# Price IDs from Stripe dashboard
+STRIPE_PRICE_ID_PAID = os.getenv("STRIPE_PRICE_ID_PAID", "")
+
+# Map Stripe price IDs to tier names
+PRICE_TO_TIER = {
+ STRIPE_PRICE_ID_PAID: "paid",
+}
+
+# Tier limits (None = unlimited)
+TIER_LIMITS: dict[str, dict[str, int | None]] = {
+ "free": {
+ "articles_per_period": 10,
+ "minutes_per_period": None,
+ },
+ "paid": {
+ "articles_per_period": None,
+ "minutes_per_period": None,
+ },
+}
+
+# Price map for checkout
+PRICE_MAP = {
+ "paid": STRIPE_PRICE_ID_PAID,
+}
+
+
+def get_period_boundaries(
+ user: dict[str, typing.Any],
+) -> tuple[datetime, datetime]:
+ """Get billing period boundaries for user.
+
+ For paid users: use Stripe subscription period.
+ For free users: lifetime (from account creation to far future).
+
+ Returns:
+ tuple: (period_start, period_end) as datetime objects
+ """
+ if user.get("plan_tier") != "free" and user.get("current_period_start"):
+ # Paid user - use Stripe billing period
+ period_start = datetime.fromisoformat(user["current_period_start"])
+ period_end = datetime.fromisoformat(user["current_period_end"])
+ else:
+ # Free user - lifetime limit from account creation
+ period_start = datetime.fromisoformat(user["created_at"])
+ # Set far future end date (100 years from now)
+ now = datetime.now(timezone.utc)
+ period_end = now.replace(year=now.year + 100)
+
+ return period_start, period_end
+
+
+def get_usage(
+ user_id: int,
+ period_start: datetime,
+ period_end: datetime,
+) -> dict[str, int]:
+ """Get usage stats for user in billing period.
+
+ Returns:
+ dict with keys: articles (int), minutes (int)
+ """
+ return Core.Database.get_usage(user_id, period_start, period_end)
+
+
+def can_submit(user_id: int) -> tuple[bool, str, dict[str, int]]:
+ """Check if user can submit article based on tier limits.
+
+ Returns:
+ tuple: (allowed: bool, message: str, usage: dict)
+ """
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ return False, "User not found", {}
+
+ tier = user.get("plan_tier", "free")
+ limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"])
+
+ # Get billing period boundaries
+ period_start, period_end = get_period_boundaries(user)
+
+ # Get current usage
+ usage = get_usage(user_id, period_start, period_end)
+
+ # Check article limit
+ article_limit = limits.get("articles_per_period")
+ if article_limit is not None and usage["articles"] >= article_limit:
+ msg = (
+ f"You've reached your limit of {article_limit} articles "
+ "per period. Upgrade to continue."
+ )
+ return (False, msg, usage)
+
+ # Check minutes limit (if implemented)
+ minute_limit = limits.get("minutes_per_period")
+ if minute_limit is not None and usage.get("minutes", 0) >= minute_limit:
+ return (
+ False,
+ f"You've reached your limit of {minute_limit} minutes per period. "
+ "Please upgrade to continue.",
+ usage,
+ )
+
+ return True, "", usage
+
+
+def create_checkout_session(user_id: int, tier: str, base_url: str) -> str:
+ """Create Stripe Checkout session for subscription.
+
+ Args:
+ user_id: User ID
+ tier: Subscription tier (paid)
+ base_url: Base URL for success/cancel redirects
+
+ Returns:
+ Checkout session URL to redirect user to
+
+ Raises:
+ ValueError: If tier is invalid or price ID not configured
+ """
+ if tier not in PRICE_MAP:
+ msg = f"Invalid tier: {tier}"
+ raise ValueError(msg)
+
+ price_id = PRICE_MAP[tier]
+ if not price_id:
+ msg = f"Stripe price ID not configured for tier: {tier}"
+ raise ValueError(msg)
+
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ msg = f"User not found: {user_id}"
+ raise ValueError(msg)
+
+ # Create checkout session
+ session_params = {
+ "mode": "subscription",
+ "line_items": [{"price": price_id, "quantity": 1}],
+ "success_url": f"{base_url}/?status=success",
+ "cancel_url": f"{base_url}/?status=cancel",
+ "client_reference_id": str(user_id),
+ "metadata": {"user_id": str(user_id), "tier": tier},
+ "allow_promotion_codes": True,
+ }
+
+ # Use existing customer if available
+ if user.get("stripe_customer_id"):
+ session_params["customer"] = user["stripe_customer_id"]
+ else:
+ session_params["customer_email"] = user["email"]
+
+ session = stripe.checkout.Session.create(**session_params) # type: ignore[arg-type]
+
+ logger.info(
+ "Created checkout session for user %s, tier %s: %s",
+ user_id,
+ tier,
+ session.id,
+ )
+
+ return session.url # type: ignore[return-value]
+
+
+def create_portal_session(user_id: int, base_url: str) -> str:
+ """Create Stripe Billing Portal session.
+
+ Args:
+ user_id: User ID
+ base_url: Base URL for return redirect
+
+ Returns:
+ Portal session URL to redirect user to
+
+ Raises:
+ ValueError: If user has no Stripe customer ID or portal not configured
+ """
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ msg = f"User not found: {user_id}"
+ raise ValueError(msg)
+
+ if not user.get("stripe_customer_id"):
+ msg = "User has no Stripe customer ID"
+ raise ValueError(msg)
+
+ session = stripe.billing_portal.Session.create(
+ customer=user["stripe_customer_id"],
+ return_url=f"{base_url}/account",
+ )
+
+ logger.info(
+ "Created portal session for user %s: %s",
+ user_id,
+ session.id,
+ )
+
+ return session.url
+
+
+def handle_webhook_event(payload: bytes, sig_header: str) -> dict[str, str]:
+ """Verify and process Stripe webhook event.
+
+ Args:
+ payload: Raw webhook body
+ sig_header: Stripe-Signature header value
+
+ Returns:
+ dict with processing status
+
+ Note:
+ May raise stripe.error.SignatureVerificationError if invalid signature
+ """
+ # Verify webhook signature (skip in test mode if secret not configured)
+ if STRIPE_WEBHOOK_SECRET:
+ event = stripe.Webhook.construct_event( # type: ignore[no-untyped-call]
+ payload,
+ sig_header,
+ STRIPE_WEBHOOK_SECRET,
+ )
+ else:
+ # Test mode without signature verification
+ logger.warning(
+ "Webhook signature verification skipped (no STRIPE_WEBHOOK_SECRET)",
+ )
+ event = json.loads(payload.decode("utf-8"))
+
+ event_id = event["id"]
+ event_type = event["type"]
+
+ # Check if already processed (idempotency)
+ if Core.Database.has_processed_stripe_event(event_id):
+ logger.info("Skipping already processed event: %s", event_id)
+ return {"status": "skipped", "reason": "already_processed"}
+
+ # Process event based on type
+ logger.info("Processing webhook event: %s (%s)", event_id, event_type)
+
+ try:
+ if event_type == "checkout.session.completed":
+ _handle_checkout_completed(event["data"]["object"])
+ elif event_type == "customer.subscription.created":
+ _handle_subscription_created(event["data"]["object"])
+ elif event_type == "customer.subscription.updated":
+ _handle_subscription_updated(event["data"]["object"])
+ elif event_type == "customer.subscription.deleted":
+ _handle_subscription_deleted(event["data"]["object"])
+ elif event_type == "invoice.payment_failed":
+ _handle_payment_failed(event["data"]["object"])
+ else:
+ logger.info("Unhandled event type: %s", event_type)
+ return {"status": "ignored", "type": event_type}
+
+ # Mark event as processed
+ Core.Database.mark_stripe_event_processed(event_id, event_type, payload)
+ except Exception:
+ logger.exception("Error processing webhook event %s", event_id)
+ raise
+ else:
+ return {"status": "processed", "type": event_type}
+
+
+def _handle_checkout_completed(session: dict[str, typing.Any]) -> None:
+ """Handle checkout.session.completed event."""
+ client_ref = session.get("client_reference_id") or session.get(
+ "metadata",
+ {},
+ ).get("user_id")
+ customer_id = session.get("customer")
+
+ if not client_ref or not customer_id:
+ logger.warning(
+ "Missing user_id or customer_id in checkout session: %s",
+ session.get("id", "unknown"),
+ )
+ return
+
+ try:
+ user_id = int(client_ref)
+ except (ValueError, TypeError):
+ logger.warning(
+ "Invalid user_id in checkout session: %s",
+ client_ref,
+ )
+ return
+
+ # Link Stripe customer to user
+ Core.Database.set_user_stripe_customer(user_id, customer_id)
+ logger.info("Linked user %s to Stripe customer %s", user_id, customer_id)
+
+
+def _handle_subscription_created(subscription: dict[str, typing.Any]) -> None:
+ """Handle customer.subscription.created event."""
+ _update_subscription_state(subscription)
+
+
+def _handle_subscription_updated(subscription: dict[str, typing.Any]) -> None:
+ """Handle customer.subscription.updated event."""
+ _update_subscription_state(subscription)
+
+
+def _handle_subscription_deleted(subscription: dict[str, typing.Any]) -> None:
+ """Handle customer.subscription.deleted event."""
+ customer_id = subscription["customer"]
+
+ # Find user by customer ID
+ user = Core.Database.get_user_by_stripe_customer_id(customer_id)
+ if not user:
+ logger.warning("User not found for customer: %s", customer_id)
+ return
+
+ # Downgrade to free
+ Core.Database.downgrade_to_free(user["id"])
+ logger.info("Downgraded user %s to free tier", user["id"])
+
+
+def _handle_payment_failed(invoice: dict[str, typing.Any]) -> None:
+ """Handle invoice.payment_failed event."""
+ customer_id = invoice["customer"]
+ subscription_id = invoice.get("subscription")
+
+ # Find user by customer ID
+ user = Core.Database.get_user_by_stripe_customer_id(customer_id)
+ if not user:
+ logger.warning("User not found for customer: %s", customer_id)
+ return
+
+ # Update subscription status to past_due
+ if subscription_id:
+ Core.Database.update_subscription_status(user["id"], "past_due")
+ logger.warning(
+ "Payment failed for user %s, subscription %s",
+ user["id"],
+ subscription_id,
+ )
+
+
+def _update_subscription_state(subscription: dict[str, typing.Any]) -> None:
+ """Update user subscription state from Stripe subscription object."""
+ customer_id = subscription.get("customer")
+ subscription_id = subscription.get("id")
+ status = subscription.get("status")
+ cancel_at_period_end = subscription.get("cancel_at_period_end", False)
+
+ if not customer_id or not subscription_id or not status:
+ logger.warning(
+ "Missing required fields in subscription: %s",
+ subscription_id,
+ )
+ return
+
+ # Get billing period - try multiple field names for API compatibility
+ period_start_ts = (
+ subscription.get("current_period_start")
+ or subscription.get("billing_cycle_anchor")
+ or subscription.get("start_date")
+ )
+ period_end_ts = subscription.get("current_period_end")
+
+ if not period_start_ts:
+ logger.warning(
+ "Missing period start in subscription: %s",
+ subscription_id,
+ )
+ return
+
+ period_start = datetime.fromtimestamp(period_start_ts, tz=timezone.utc)
+
+ # Calculate period end if not provided (assume monthly)
+ december = 12
+ january = 1
+ if not period_end_ts:
+ if period_start.month == december:
+ period_end = period_start.replace(
+ year=period_start.year + 1,
+ month=january,
+ )
+ else:
+ period_end = period_start.replace(month=period_start.month + 1)
+ else:
+ period_end = datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
+
+ # Determine tier from price ID
+ items = subscription.get("items", {})
+ data = items.get("data", [])
+ if not data:
+ logger.warning("No items in subscription: %s", subscription_id)
+ return
+
+ price_id = data[0].get("price", {}).get("id")
+ if not price_id:
+ logger.warning("No price ID in subscription: %s", subscription_id)
+ return
+
+ tier = PRICE_TO_TIER.get(price_id, "free")
+
+ # Find user by customer ID
+ user = Core.Database.get_user_by_stripe_customer_id(customer_id)
+ if not user:
+ logger.warning("User not found for customer: %s", customer_id)
+ return
+
+ # Update user subscription
+ Core.Database.update_user_subscription(
+ user["id"],
+ subscription_id,
+ status,
+ period_start,
+ period_end,
+ tier,
+ cancel_at_period_end,
+ )
+
+ logger.info(
+ "Updated user %s subscription: tier=%s, status=%s",
+ user["id"],
+ tier,
+ status,
+ )
+
+
+def get_tier_info(tier: str) -> dict[str, typing.Any]:
+ """Get tier information for display.
+
+ Returns:
+ dict with keys: name, articles_limit, price, description
+ """
+ tier_info: dict[str, dict[str, typing.Any]] = {
+ "free": {
+ "name": "Free",
+ "articles_limit": 10,
+ "price": "$0",
+ "description": "10 articles total",
+ },
+ "paid": {
+ "name": "Paid",
+ "articles_limit": None,
+ "price": "$12/mo",
+ "description": "Unlimited articles",
+ },
+ }
+ return tier_info.get(tier, tier_info["free"])
+
+
+# Tests
+# ruff: noqa: PLR6301, PLW0603, S101
+
+
+class TestWebhookHandling(Test.TestCase):
+ """Test Stripe webhook handling."""
+
+ def setUp(self) -> None:
+ """Set up test database."""
+ Core.Database.init_db()
+
+ def tearDown(self) -> None:
+ """Clean up test database."""
+ Core.Database.teardown()
+
+ def test_full_checkout_flow(self) -> None:
+ """Test complete checkout flow from session to subscription."""
+ # Create test user
+ user_id, _token = Core.Database.create_user("test@example.com")
+
+ # Temporarily set price mapping for test
+ global PRICE_TO_TIER
+ old_mapping = PRICE_TO_TIER.copy()
+ PRICE_TO_TIER["price_test_paid"] = "paid"
+
+ try:
+ # Step 1: Handle checkout.session.completed
+ checkout_session = {
+ "id": "cs_test123",
+ "client_reference_id": str(user_id),
+ "customer": "cus_test123",
+ "metadata": {"user_id": str(user_id), "tier": "paid"},
+ }
+ _handle_checkout_completed(checkout_session)
+
+ # Verify customer was linked
+ user = Core.Database.get_user_by_id(user_id)
+ self.assertIsNotNone(user)
+ assert user is not None
+ self.assertEqual(user["stripe_customer_id"], "cus_test123")
+
+ # Step 2: Handle customer.subscription.created
+ # (newer API uses billing_cycle_anchor instead of current_period_*)
+ subscription = {
+ "id": "sub_test123",
+ "customer": "cus_test123",
+ "status": "active",
+ "billing_cycle_anchor": 1700000000,
+ "cancel_at_period_end": False,
+ "items": {
+ "data": [
+ {
+ "price": {
+ "id": "price_test_paid",
+ },
+ },
+ ],
+ },
+ }
+ _update_subscription_state(subscription)
+
+ # Verify subscription was created and user upgraded
+ user = Core.Database.get_user_by_id(user_id)
+ self.assertIsNotNone(user)
+ assert user is not None
+ self.assertEqual(user["plan_tier"], "paid")
+ self.assertEqual(user["subscription_status"], "active")
+ self.assertEqual(user["stripe_subscription_id"], "sub_test123")
+ self.assertEqual(user["stripe_customer_id"], "cus_test123")
+ finally:
+ PRICE_TO_TIER = old_mapping
+
+ def test_webhook_missing_fields(self) -> None:
+ """Test handling webhook with missing required fields."""
+ # Create test user
+ user_id, _token = Core.Database.create_user("test@example.com")
+ Core.Database.set_user_stripe_customer(user_id, "cus_test456")
+
+ # Mock subscription with missing current_period_start
+ subscription = {
+ "id": "sub_test456",
+ "customer": "cus_test456",
+ "status": "active",
+ # Missing current_period_start and current_period_end
+ "cancel_at_period_end": False,
+ "items": {"data": []},
+ }
+
+ # Should not crash, just log warning and return
+ _update_subscription_state(subscription)
+
+ # User should remain on free tier
+ user = Core.Database.get_user_by_id(user_id)
+ self.assertIsNotNone(user)
+ assert user is not None
+ self.assertEqual(user["plan_tier"], "free")
+
+
+def main() -> None:
+ """Run tests."""
+ if len(sys.argv) > 1 and sys.argv[1] == "test":
+ os.environ["AREA"] = "Test"
+ Test.run(App.Area.Test, [TestWebhookHandling])
+ else:
+ logger.error("Usage: billing.py test")
+
+
+if __name__ == "__main__":
+ main()