"""
PodcastItLater Billing Integration.

Stripe subscription management and usage enforcement.
"""

# : out podcastitlater-billing
# : dep stripe
# : dep pytest
# : dep pytest-mock
import Biz.PodcastItLater.Core as Core
import Omni.App as App
import Omni.Log as Log
import Omni.Test as Test
import os
import stripe
import sys
import typing
from datetime import datetime
from datetime import timezone

logger = Log.setup()

# Stripe configuration
stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "")
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "")

# Price IDs from Stripe dashboard
STRIPE_PRICE_ID_PRO = os.getenv("STRIPE_PRICE_ID_PRO", "")

# Map Stripe price IDs to tier names. If the env var is unset this maps ""
# to "pro"; that entry is never hit because empty price IDs are rejected in
# _update_subscription_state before the lookup.
PRICE_TO_TIER = {
    STRIPE_PRICE_ID_PRO: "pro",
}

# Tier limits (None = unlimited)
TIER_LIMITS = {
    "free": {
        "articles_per_period": 10,
        "minutes_per_period": None,
    },
    "pro": {
        "articles_per_period": None,
        "minutes_per_period": None,
    },
}

# Price map for checkout
PRICE_MAP = {
    "pro": STRIPE_PRICE_ID_PRO,
}


def get_period_boundaries(
    user: dict[str, typing.Any],
) -> tuple[datetime, datetime]:
    """Get billing period boundaries for user.

    For paid users: use the Stripe subscription period stored on the user.
    For free users: lifetime (from account creation to far future).

    A non-free user whose stored period is incomplete (start or end missing)
    is treated like a free user instead of crashing on a ``None`` timestamp.

    Args:
        user: User row from ``Core.Database``. Reads ``plan_tier``,
            ``current_period_start``, ``current_period_end`` and
            ``created_at`` (ISO-format strings parsed with
            ``datetime.fromisoformat``).

    Returns:
        tuple: (period_start, period_end) as datetime objects
    """
    period_start_raw = user.get("current_period_start")
    period_end_raw = user.get("current_period_end")
    if user.get("plan_tier") != "free" and period_start_raw and period_end_raw:
        # Paid user - use Stripe billing period
        return (
            datetime.fromisoformat(period_start_raw),
            datetime.fromisoformat(period_end_raw),
        )

    # Free user (or paid user without a usable period) - lifetime limit
    # counted from account creation.
    period_start = datetime.fromisoformat(user["created_at"])
    # Set far future end date (100 years from now)
    now = datetime.now(timezone.utc)
    period_end = now.replace(year=now.year + 100)
    return period_start, period_end


def get_usage(
    user_id: int,
    period_start: datetime,
    period_end: datetime,
) -> dict[str, int]:
    """Get usage stats for user in billing period.

    Returns:
        dict with keys: articles (int), minutes (int)
    """
    return Core.Database.get_usage(user_id, period_start, period_end)


def can_submit(user_id: int) -> tuple[bool, str, dict[str, int]]:
    """Check if user can submit article based on tier limits.

    Returns:
        tuple: (allowed: bool, message: str, usage: dict)
    """
    user = Core.Database.get_user_by_id(user_id)
    if not user:
        return False, "User not found", {}

    tier = user.get("plan_tier", "free")
    # Unknown tiers fall back to free-tier limits.
    limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"])

    # Get billing period boundaries
    period_start, period_end = get_period_boundaries(user)

    # Get current usage
    usage = get_usage(user_id, period_start, period_end)

    # Check article limit
    article_limit = limits["articles_per_period"]  # type: ignore[index]
    if article_limit is not None and usage["articles"] >= article_limit:
        msg = (
            f"You've reached your limit of {article_limit} articles "
            "per period. Upgrade to continue."
        )
        return (False, msg, usage)

    # Check minutes limit (if implemented)
    minute_limit = limits.get("minutes_per_period")  # type: ignore[attr-defined]
    if minute_limit is not None and usage.get("minutes", 0) >= minute_limit:
        return (
            False,
            f"You've reached your limit of {minute_limit} minutes per period. "
            "Please upgrade to continue.",
            usage,
        )

    return True, "", usage


def create_checkout_session(user_id: int, tier: str, base_url: str) -> str:
    """Create Stripe Checkout session for subscription.

    Args:
        user_id: User ID
        tier: Subscription tier (pro)
        base_url: Base URL for success/cancel redirects

    Returns:
        Checkout session URL to redirect user to

    Raises:
        ValueError: If tier is invalid, price ID not configured, or the user
            does not exist
    """
    if tier not in PRICE_MAP:
        msg = f"Invalid tier: {tier}"
        raise ValueError(msg)

    price_id = PRICE_MAP[tier]
    if not price_id:
        msg = f"Stripe price ID not configured for tier: {tier}"
        raise ValueError(msg)

    user = Core.Database.get_user_by_id(user_id)
    if not user:
        msg = f"User not found: {user_id}"
        raise ValueError(msg)

    # Create checkout session. The user ID is sent both as
    # client_reference_id and in metadata; _handle_checkout_completed reads
    # either one.
    session_params = {
        "mode": "subscription",
        "line_items": [{"price": price_id, "quantity": 1}],
        "success_url": f"{base_url}/billing?status=success",
        "cancel_url": f"{base_url}/billing?status=cancel",
        "client_reference_id": str(user_id),
        "metadata": {"user_id": str(user_id), "tier": tier},
        "allow_promotion_codes": True,
    }

    # Use existing customer if available
    if user.get("stripe_customer_id"):
        session_params["customer"] = user["stripe_customer_id"]
    else:
        session_params["customer_email"] = user["email"]

    session = stripe.checkout.Session.create(**session_params)  # type: ignore[arg-type]

    logger.info(
        "Created checkout session for user %s, tier %s: %s",
        user_id,
        tier,
        session.id,
    )

    return session.url  # type: ignore[return-value]


def create_portal_session(user_id: int, base_url: str) -> str:
    """Create Stripe Billing Portal session.

    Args:
        user_id: User ID
        base_url: Base URL for return redirect

    Returns:
        Portal session URL to redirect user to

    Raises:
        ValueError: If the user does not exist or has no Stripe customer ID
    """
    user = Core.Database.get_user_by_id(user_id)
    if not user:
        msg = f"User not found: {user_id}"
        raise ValueError(msg)

    if not user.get("stripe_customer_id"):
        msg = "User has no Stripe customer ID"
        raise ValueError(msg)

    session = stripe.billing_portal.Session.create(
        customer=user["stripe_customer_id"],
        return_url=f"{base_url}/billing",
    )

    logger.info("Created portal session for user %s: %s", user_id, session.id)

    return session.url


def handle_webhook_event(payload: bytes, sig_header: str) -> dict[str, str]:
    """Verify and process Stripe webhook event.

    Args:
        payload: Raw webhook body
        sig_header: Stripe-Signature header value

    Returns:
        dict with processing status: "skipped" (already processed),
        "ignored" (unhandled type, not recorded) or "processed".

    Note:
        May raise stripe.error.SignatureVerificationError if invalid signature.
        Handler exceptions are logged and re-raised without marking the event
        processed, so a redelivery will retry it.
    """
    # Verify webhook signature
    event = stripe.Webhook.construct_event(  # type: ignore[no-untyped-call]
        payload,
        sig_header,
        STRIPE_WEBHOOK_SECRET,
    )

    event_id = event["id"]
    event_type = event["type"]

    # Check if already processed (idempotency)
    if Core.Database.has_processed_stripe_event(event_id):
        logger.info("Skipping already processed event: %s", event_id)
        return {"status": "skipped", "reason": "already_processed"}

    # Process event based on type
    logger.info("Processing webhook event: %s (%s)", event_id, event_type)

    handlers = {
        "checkout.session.completed": _handle_checkout_completed,
        "customer.subscription.created": _handle_subscription_created,
        "customer.subscription.updated": _handle_subscription_updated,
        "customer.subscription.deleted": _handle_subscription_deleted,
        "invoice.payment_failed": _handle_payment_failed,
    }
    handler = handlers.get(event_type)
    if handler is None:
        # Unhandled types are not recorded as processed.
        logger.info("Unhandled event type: %s", event_type)
        return {"status": "ignored", "type": event_type}

    try:
        handler(event["data"]["object"])
        # Mark event as processed
        Core.Database.mark_stripe_event_processed(
            event_id,
            event_type,
            payload,
        )
    except Exception:
        logger.exception("Error processing webhook event %s", event_id)
        raise

    return {"status": "processed", "type": event_type}


def _parse_user_id(raw: typing.Any) -> int:
    """Parse a user ID echoed back by Stripe; return 0 if absent or invalid.

    The value may be null (e.g. sessions not created by
    create_checkout_session) or an arbitrary string, so it must not raise.
    """
    if raw is None:
        return 0
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return 0
    return max(value, 0)


def _handle_checkout_completed(session: dict[str, typing.Any]) -> None:
    """Handle checkout.session.completed event.

    Links the Stripe customer to our user, then syncs the subscription
    immediately. Webhook events are not guaranteed to arrive in order: a
    customer.subscription.created event handled before this one finds no user
    for the customer and is dropped by _update_subscription_state.
    """
    user_id = _parse_user_id(session.get("client_reference_id"))
    if not user_id:
        # create_checkout_session also stores the user ID in metadata.
        metadata = session.get("metadata") or {}
        user_id = _parse_user_id(metadata.get("user_id"))
    customer_id = session.get("customer")

    if not user_id or not customer_id:
        logger.warning(
            "Missing user_id or customer_id in checkout session: %s",
            session.get("id"),
        )
        return

    # Link Stripe customer to user
    Core.Database.set_user_stripe_customer(user_id, customer_id)
    logger.info("Linked user %s to Stripe customer %s", user_id, customer_id)

    # Sync the subscription now so an earlier-delivered
    # customer.subscription.created event cannot leave the user on free.
    subscription_id = session.get("subscription")
    if isinstance(subscription_id, str) and subscription_id:
        subscription = stripe.Subscription.retrieve(subscription_id)
        _update_subscription_state(subscription)  # type: ignore[arg-type]


def _handle_subscription_created(subscription: dict[str, typing.Any]) -> None:
    """Handle customer.subscription.created event."""
    _update_subscription_state(subscription)


def _handle_subscription_updated(subscription: dict[str, typing.Any]) -> None:
    """Handle customer.subscription.updated event."""
    _update_subscription_state(subscription)


def _handle_subscription_deleted(subscription: dict[str, typing.Any]) -> None:
    """Handle customer.subscription.deleted event by downgrading to free."""
    customer_id = subscription["customer"]

    # Find user by customer ID
    user = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not user:
        logger.warning("User not found for customer: %s", customer_id)
        return

    # Downgrade to free
    Core.Database.downgrade_to_free(user["id"])
    logger.info("Downgraded user %s to free tier", user["id"])


def _handle_payment_failed(invoice: dict[str, typing.Any]) -> None:
    """Handle invoice.payment_failed event by marking the user past_due."""
    customer_id = invoice["customer"]
    subscription_id = invoice.get("subscription")
    if not subscription_id:
        # Newer Stripe API versions nest the invoice's subscription under
        # parent.subscription_details instead of a top-level field.
        parent = invoice.get("parent") or {}
        details = parent.get("subscription_details") or {}
        subscription_id = details.get("subscription")

    # Find user by customer ID
    user = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not user:
        logger.warning("User not found for customer: %s", customer_id)
        return

    # Update subscription status to past_due
    if subscription_id:
        Core.Database.update_subscription_status(user["id"], "past_due")
        logger.warning(
            "Payment failed for user %s, subscription %s",
            user["id"],
            subscription_id,
        )


def _update_subscription_state(subscription: dict[str, typing.Any]) -> None:
    """Update user subscription state from Stripe subscription object.

    Logs a warning and returns without writing when required fields, period
    dates, items, price ID or the owning user are missing.
    """
    customer_id = subscription.get("customer")
    subscription_id = subscription.get("id")
    status = subscription.get("status")
    cancel_at_period_end = subscription.get("cancel_at_period_end", False)

    if not customer_id or not subscription_id or not status:
        logger.warning(
            "Missing required fields in subscription: %s",
            subscription_id,
        )
        return

    items = subscription.get("items") or {}
    data = items.get("data") or []
    first_item = data[0] if data else {}

    # Billing period: top-level on older Stripe API versions, on the
    # subscription item in newer ones.
    period_start_ts = subscription.get(
        "current_period_start",
    ) or first_item.get("current_period_start")
    period_end_ts = subscription.get(
        "current_period_end",
    ) or first_item.get("current_period_end")

    if not period_start_ts or not period_end_ts:
        logger.warning(
            "Missing period dates in subscription: %s",
            subscription_id,
        )
        return

    period_start = datetime.fromtimestamp(period_start_ts, tz=timezone.utc)
    period_end = datetime.fromtimestamp(period_end_ts, tz=timezone.utc)

    # Determine tier from price ID of the first item
    if not data:
        logger.warning("No items in subscription: %s", subscription_id)
        return

    price_id = (first_item.get("price") or {}).get("id")
    if not price_id:
        logger.warning("No price ID in subscription: %s", subscription_id)
        return

    # Unknown price IDs map to the free tier.
    tier = PRICE_TO_TIER.get(price_id, "free")

    # Find user by customer ID
    user = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not user:
        logger.warning("User not found for customer: %s", customer_id)
        return

    # Update user subscription
    Core.Database.update_user_subscription(
        user["id"],
        subscription_id,
        status,
        period_start,
        period_end,
        tier,
        cancel_at_period_end,
    )

    logger.info(
        "Updated user %s subscription: tier=%s, status=%s",
        user["id"],
        tier,
        status,
    )


def get_tier_info(tier: str) -> dict[str, typing.Any]:
    """Get tier information for display.

    Unknown tiers return the free tier's info. ``articles_limit`` is read
    from TIER_LIMITS so display and enforcement cannot drift apart.

    Returns:
        dict with keys: name, articles_limit, price, description
    """
    tier_info = {
        "free": {
            "name": "Free",
            "articles_limit": TIER_LIMITS["free"]["articles_per_period"],
            "price": "$0",
            "description": "10 articles total",
        },
        "pro": {
            "name": "Pro",
            "articles_limit": TIER_LIMITS["pro"]["articles_per_period"],
            "price": "$29/mo",
            "description": "Unlimited articles",
        },
    }

    return tier_info.get(tier, tier_info["free"])  # type: ignore[return-value]


# Tests
# ruff: noqa: PLR6301, PLW0603, S101


class TestWebhookHandling(Test.TestCase):
    """Test Stripe webhook handling."""

    def setUp(self) -> None:
        """Set up test database."""
        Core.Database.init_db()

    def tearDown(self) -> None:
        """Clean up test database."""
        Core.Database.teardown()

    def test_subscription_created(self) -> None:
        """Test handling subscription.created webhook with mock data."""
        # Create test user
        user_id, _token = Core.Database.create_user("test@example.com")
        Core.Database.set_user_stripe_customer(user_id, "cus_test123")

        # Mock subscription event
        subscription = {
            "id": "sub_test123",
            "customer": "cus_test123",
            "status": "active",
            "current_period_start": 1700000000,
            "current_period_end": 1702592000,
            "cancel_at_period_end": False,
            "items": {
                "data": [
                    {
                        "price": {
                            "id": "price_test_pro",
                        },
                    },
                ],
            },
        }

        # Temporarily set price mapping for test
        global PRICE_TO_TIER
        old_mapping = PRICE_TO_TIER.copy()
        PRICE_TO_TIER["price_test_pro"] = "pro"

        try:
            _update_subscription_state(subscription)

            # Verify user was updated
            user = Core.Database.get_user_by_id(user_id)
            self.assertIsNotNone(user)
            assert user is not None
            self.assertEqual(user["plan_tier"], "pro")
            self.assertEqual(user["subscription_status"], "active")
            self.assertEqual(user["stripe_subscription_id"], "sub_test123")
        finally:
            PRICE_TO_TIER = old_mapping

    def test_subscription_period_on_items(self) -> None:
        """Period dates nested on subscription items are honoured."""
        user_id, _token = Core.Database.create_user("test@example.com")
        Core.Database.set_user_stripe_customer(user_id, "cus_test321")

        # No top-level period fields; they live on the item instead.
        subscription = {
            "id": "sub_test321",
            "customer": "cus_test321",
            "status": "active",
            "cancel_at_period_end": False,
            "items": {
                "data": [
                    {
                        "current_period_start": 1700000000,
                        "current_period_end": 1702592000,
                        "price": {"id": "price_test_pro"},
                    },
                ],
            },
        }

        global PRICE_TO_TIER
        old_mapping = PRICE_TO_TIER.copy()
        PRICE_TO_TIER["price_test_pro"] = "pro"

        try:
            _update_subscription_state(subscription)

            user = Core.Database.get_user_by_id(user_id)
            self.assertIsNotNone(user)
            assert user is not None
            self.assertEqual(user["plan_tier"], "pro")
            self.assertEqual(user["stripe_subscription_id"], "sub_test321")
        finally:
            PRICE_TO_TIER = old_mapping

    def test_checkout_completed_null_reference_id(self) -> None:
        """Null client_reference_id falls back to metadata user_id."""
        user_id, _token = Core.Database.create_user("test@example.com")

        session = {
            "id": "cs_test123",
            "client_reference_id": None,
            "customer": "cus_test789",
            "metadata": {"user_id": str(user_id)},
        }

        # Must not raise on the null reference ID.
        _handle_checkout_completed(session)

        user = Core.Database.get_user_by_stripe_customer_id("cus_test789")
        self.assertIsNotNone(user)
        assert user is not None
        self.assertEqual(user["id"], user_id)

    def test_webhook_missing_fields(self) -> None:
        """Test handling webhook with missing required fields."""
        # Create test user
        user_id, _token = Core.Database.create_user("test@example.com")
        Core.Database.set_user_stripe_customer(user_id, "cus_test456")

        # Mock subscription with missing current_period_start
        subscription = {
            "id": "sub_test456",
            "customer": "cus_test456",
            "status": "active",
            # Missing current_period_start and current_period_end
            "cancel_at_period_end": False,
            "items": {"data": []},
        }

        # Should not crash, just log warning and return
        _update_subscription_state(subscription)

        # User should remain on free tier
        user = Core.Database.get_user_by_id(user_id)
        self.assertIsNotNone(user)
        assert user is not None
        self.assertEqual(user["plan_tier"], "free")


def main() -> None:
    """Run tests."""
    if len(sys.argv) > 1 and sys.argv[1] == "test":
        os.environ["AREA"] = "Test"
        Test.run(App.Area.Test, [TestWebhookHandling])
    else:
        logger.error("Usage: billing.py test")


if __name__ == "__main__":
    main()