diff options
Diffstat (limited to 'Biz/PodcastItLater/Billing.py')
| -rw-r--r-- | Biz/PodcastItLater/Billing.py | 581 |
1 files changed, 581 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Billing.py b/Biz/PodcastItLater/Billing.py new file mode 100644 index 0000000..9f3739d --- /dev/null +++ b/Biz/PodcastItLater/Billing.py @@ -0,0 +1,581 @@ +""" +PodcastItLater Billing Integration. + +Stripe subscription management and usage enforcement. +""" + +# : out podcastitlater-billing +# : dep stripe +# : dep pytest +# : dep pytest-mock +import Biz.PodcastItLater.Core as Core +import json +import logging +import Omni.App as App +import Omni.Log as Log +import Omni.Test as Test +import os +import stripe +import sys +import typing +from datetime import datetime +from datetime import timezone + +logger = logging.getLogger(__name__) +Log.setup(logger) + +# Stripe configuration +stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "") +STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "") + +# Price IDs from Stripe dashboard +STRIPE_PRICE_ID_PAID = os.getenv("STRIPE_PRICE_ID_PAID", "") + +# Map Stripe price IDs to tier names +PRICE_TO_TIER = { + STRIPE_PRICE_ID_PAID: "paid", +} + +# Tier limits (None = unlimited) +TIER_LIMITS: dict[str, dict[str, int | None]] = { + "free": { + "articles_per_period": 10, + "minutes_per_period": None, + }, + "paid": { + "articles_per_period": None, + "minutes_per_period": None, + }, +} + +# Price map for checkout +PRICE_MAP = { + "paid": STRIPE_PRICE_ID_PAID, +} + + +def get_period_boundaries( + user: dict[str, typing.Any], +) -> tuple[datetime, datetime]: + """Get billing period boundaries for user. + + For paid users: use Stripe subscription period. + For free users: lifetime (from account creation to far future). + + Returns: + tuple: (period_start, period_end) as datetime objects + """ + if user.get("plan_tier") != "free" and user.get("current_period_start"): + # Paid user - use Stripe billing period + period_start = datetime.fromisoformat(user["current_period_start"]) + period_end = datetime.fromisoformat(user["current_period_end"]) + else: + # Free user - lifetime limit from account creation + period_start = datetime.fromisoformat(user["created_at"]) + # Set far future end date (100 years from now) + now = datetime.now(timezone.utc) + period_end = now.replace(year=now.year + 100) + + return period_start, period_end + + +def get_usage( + user_id: int, + period_start: datetime, + period_end: datetime, +) -> dict[str, int]: + """Get usage stats for user in billing period. + + Returns: + dict with keys: articles (int), minutes (int) + """ + return Core.Database.get_usage(user_id, period_start, period_end) + + +def can_submit(user_id: int) -> tuple[bool, str, dict[str, int]]: + """Check if user can submit article based on tier limits. + + Returns: + tuple: (allowed: bool, message: str, usage: dict) + """ + user = Core.Database.get_user_by_id(user_id) + if not user: + return False, "User not found", {} + + tier = user.get("plan_tier", "free") + limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"]) + + # Get billing period boundaries + period_start, period_end = get_period_boundaries(user) + + # Get current usage + usage = get_usage(user_id, period_start, period_end) + + # Check article limit + article_limit = limits.get("articles_per_period") + if article_limit is not None and usage["articles"] >= article_limit: + msg = ( + f"You've reached your limit of {article_limit} articles " + "per period. Upgrade to continue." + ) + return (False, msg, usage) + + # Check minutes limit (if implemented) + minute_limit = limits.get("minutes_per_period") + if minute_limit is not None and usage.get("minutes", 0) >= minute_limit: + return ( + False, + f"You've reached your limit of {minute_limit} minutes per period. " + "Please upgrade to continue.", + usage, + ) + + return True, "", usage + + +def create_checkout_session(user_id: int, tier: str, base_url: str) -> str: + """Create Stripe Checkout session for subscription. + + Args: + user_id: User ID + tier: Subscription tier (paid) + base_url: Base URL for success/cancel redirects + + Returns: + Checkout session URL to redirect user to + + Raises: + ValueError: If tier is invalid or price ID not configured + """ + if tier not in PRICE_MAP: + msg = f"Invalid tier: {tier}" + raise ValueError(msg) + + price_id = PRICE_MAP[tier] + if not price_id: + msg = f"Stripe price ID not configured for tier: {tier}" + raise ValueError(msg) + + user = Core.Database.get_user_by_id(user_id) + if not user: + msg = f"User not found: {user_id}" + raise ValueError(msg) + + # Create checkout session + session_params = { + "mode": "subscription", + "line_items": [{"price": price_id, "quantity": 1}], + "success_url": f"{base_url}/?status=success", + "cancel_url": f"{base_url}/?status=cancel", + "client_reference_id": str(user_id), + "metadata": {"user_id": str(user_id), "tier": tier}, + "allow_promotion_codes": True, + } + + # Use existing customer if available + if user.get("stripe_customer_id"): + session_params["customer"] = user["stripe_customer_id"] + else: + session_params["customer_email"] = user["email"] + + session = stripe.checkout.Session.create(**session_params) # type: ignore[arg-type] + + logger.info( + "Created checkout session for user %s, tier %s: %s", + user_id, + tier, + session.id, + ) + + return session.url # type: ignore[return-value] + + +def create_portal_session(user_id: int, base_url: str) -> str: + """Create Stripe Billing Portal session. + + Args: + user_id: User ID + base_url: Base URL for return redirect + + Returns: + Portal session URL to redirect user to + + Raises: + ValueError: If user has no Stripe customer ID or portal not configured + """ + user = Core.Database.get_user_by_id(user_id) + if not user: + msg = f"User not found: {user_id}" + raise ValueError(msg) + + if not user.get("stripe_customer_id"): + msg = "User has no Stripe customer ID" + raise ValueError(msg) + + session = stripe.billing_portal.Session.create( + customer=user["stripe_customer_id"], + return_url=f"{base_url}/account", + ) + + logger.info( + "Created portal session for user %s: %s", + user_id, + session.id, + ) + + return session.url + + +def handle_webhook_event(payload: bytes, sig_header: str) -> dict[str, str]: + """Verify and process Stripe webhook event. + + Args: + payload: Raw webhook body + sig_header: Stripe-Signature header value + + Returns: + dict with processing status + + Note: + May raise stripe.error.SignatureVerificationError if invalid signature + """ + # Verify webhook signature (skip in test mode if secret not configured) + if STRIPE_WEBHOOK_SECRET: + event = stripe.Webhook.construct_event( # type: ignore[no-untyped-call] + payload, + sig_header, + STRIPE_WEBHOOK_SECRET, + ) + else: + # Test mode without signature verification + logger.warning( + "Webhook signature verification skipped (no STRIPE_WEBHOOK_SECRET)", + ) + event = json.loads(payload.decode("utf-8")) + + event_id = event["id"] + event_type = event["type"] + + # Check if already processed (idempotency) + if Core.Database.has_processed_stripe_event(event_id): + logger.info("Skipping already processed event: %s", event_id) + return {"status": "skipped", "reason": "already_processed"} + + # Process event based on type + logger.info("Processing webhook event: %s (%s)", event_id, event_type) + + try: + if event_type == "checkout.session.completed": + _handle_checkout_completed(event["data"]["object"]) + elif event_type == "customer.subscription.created": + _handle_subscription_created(event["data"]["object"]) + elif event_type == "customer.subscription.updated": + _handle_subscription_updated(event["data"]["object"]) + elif event_type == "customer.subscription.deleted": + _handle_subscription_deleted(event["data"]["object"]) + elif event_type == "invoice.payment_failed": + _handle_payment_failed(event["data"]["object"]) + else: + logger.info("Unhandled event type: %s", event_type) + return {"status": "ignored", "type": event_type} + + # Mark event as processed + Core.Database.mark_stripe_event_processed(event_id, event_type, payload) + except Exception: + logger.exception("Error processing webhook event %s", event_id) + raise + else: + return {"status": "processed", "type": event_type} + + +def _handle_checkout_completed(session: dict[str, typing.Any]) -> None: + """Handle checkout.session.completed event.""" + client_ref = session.get("client_reference_id") or session.get( + "metadata", + {}, + ).get("user_id") + customer_id = session.get("customer") + + if not client_ref or not customer_id: + logger.warning( + "Missing user_id or customer_id in checkout session: %s", + session.get("id", "unknown"), + ) + return + + try: + user_id = int(client_ref) + except (ValueError, TypeError): + logger.warning( + "Invalid user_id in checkout session: %s", + client_ref, + ) + return + + # Link Stripe customer to user + Core.Database.set_user_stripe_customer(user_id, customer_id) + logger.info("Linked user %s to Stripe customer %s", user_id, customer_id) + + +def _handle_subscription_created(subscription: dict[str, typing.Any]) -> None: + """Handle customer.subscription.created event.""" + _update_subscription_state(subscription) + + +def _handle_subscription_updated(subscription: dict[str, typing.Any]) -> None: + """Handle customer.subscription.updated event.""" + _update_subscription_state(subscription) + + +def _handle_subscription_deleted(subscription: dict[str, typing.Any]) -> None: + """Handle customer.subscription.deleted event.""" + customer_id = subscription["customer"] + + # Find user by customer ID + user = Core.Database.get_user_by_stripe_customer_id(customer_id) + if not user: + logger.warning("User not found for customer: %s", customer_id) + return + + # Downgrade to free + Core.Database.downgrade_to_free(user["id"]) + logger.info("Downgraded user %s to free tier", user["id"]) + + +def _handle_payment_failed(invoice: dict[str, typing.Any]) -> None: + """Handle invoice.payment_failed event.""" + customer_id = invoice["customer"] + subscription_id = invoice.get("subscription") + + # Find user by customer ID + user = Core.Database.get_user_by_stripe_customer_id(customer_id) + if not user: + logger.warning("User not found for customer: %s", customer_id) + return + + # Update subscription status to past_due + if subscription_id: + Core.Database.update_subscription_status(user["id"], "past_due") + logger.warning( + "Payment failed for user %s, subscription %s", + user["id"], + subscription_id, + ) + + +def _update_subscription_state(subscription: dict[str, typing.Any]) -> None: + """Update user subscription state from Stripe subscription object.""" + customer_id = subscription.get("customer") + subscription_id = subscription.get("id") + status = subscription.get("status") + cancel_at_period_end = subscription.get("cancel_at_period_end", False) + + if not customer_id or not subscription_id or not status: + logger.warning( + "Missing required fields in subscription: %s", + subscription_id, + ) + return + + # Get billing period - try multiple field names for API compatibility + period_start_ts = ( + subscription.get("current_period_start") + or subscription.get("billing_cycle_anchor") + or subscription.get("start_date") + ) + period_end_ts = subscription.get("current_period_end") + + if not period_start_ts: + logger.warning( + "Missing period start in subscription: %s", + subscription_id, + ) + return + + period_start = datetime.fromtimestamp(period_start_ts, tz=timezone.utc) + + # Calculate period end if not provided (assume monthly) + december = 12 + january = 1 + if not period_end_ts: + if period_start.month == december: + period_end = period_start.replace( + year=period_start.year + 1, + month=january, + ) + else: + period_end = period_start.replace(month=period_start.month + 1) + else: + period_end = datetime.fromtimestamp(period_end_ts, tz=timezone.utc) + + # Determine tier from price ID + items = subscription.get("items", {}) + data = items.get("data", []) + if not data: + logger.warning("No items in subscription: %s", subscription_id) + return + + price_id = data[0].get("price", {}).get("id") + if not price_id: + logger.warning("No price ID in subscription: %s", subscription_id) + return + + tier = PRICE_TO_TIER.get(price_id, "free") + + # Find user by customer ID + user = Core.Database.get_user_by_stripe_customer_id(customer_id) + if not user: + logger.warning("User not found for customer: %s", customer_id) + return + + # Update user subscription + Core.Database.update_user_subscription( + user["id"], + subscription_id, + status, + period_start, + period_end, + tier, + cancel_at_period_end, + ) + + logger.info( + "Updated user %s subscription: tier=%s, status=%s", + user["id"], + tier, + status, + ) + + +def get_tier_info(tier: str) -> dict[str, typing.Any]: + """Get tier information for display. + + Returns: + dict with keys: name, articles_limit, price, description + """ + tier_info: dict[str, dict[str, typing.Any]] = { + "free": { + "name": "Free", + "articles_limit": 10, + "price": "$0", + "description": "10 articles total", + }, + "paid": { + "name": "Paid", + "articles_limit": None, + "price": "$12/mo", + "description": "Unlimited articles", + }, + } + return tier_info.get(tier, tier_info["free"]) + + +# Tests +# ruff: noqa: PLR6301, PLW0603, S101 + + +class TestWebhookHandling(Test.TestCase): + """Test Stripe webhook handling.""" + + def setUp(self) -> None: + """Set up test database.""" + Core.Database.init_db() + + def tearDown(self) -> None: + """Clean up test database.""" + Core.Database.teardown() + + def test_full_checkout_flow(self) -> None: + """Test complete checkout flow from session to subscription.""" + # Create test user + user_id, _token = Core.Database.create_user("test@example.com") + + # Temporarily set price mapping for test + global PRICE_TO_TIER + old_mapping = PRICE_TO_TIER.copy() + PRICE_TO_TIER["price_test_paid"] = "paid" + + try: + # Step 1: Handle checkout.session.completed + checkout_session = { + "id": "cs_test123", + "client_reference_id": str(user_id), + "customer": "cus_test123", + "metadata": {"user_id": str(user_id), "tier": "paid"}, + } + _handle_checkout_completed(checkout_session) + + # Verify customer was linked + user = Core.Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + assert user is not None + self.assertEqual(user["stripe_customer_id"], "cus_test123") + + # Step 2: Handle customer.subscription.created + # (newer API uses billing_cycle_anchor instead of current_period_*) + subscription = { + "id": "sub_test123", + "customer": "cus_test123", + "status": "active", + "billing_cycle_anchor": 1700000000, + "cancel_at_period_end": False, + "items": { + "data": [ + { + "price": { + "id": "price_test_paid", + }, + }, + ], + }, + } + _update_subscription_state(subscription) + + # Verify subscription was created and user upgraded + user = Core.Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + assert user is not None + self.assertEqual(user["plan_tier"], "paid") + self.assertEqual(user["subscription_status"], "active") + self.assertEqual(user["stripe_subscription_id"], "sub_test123") + self.assertEqual(user["stripe_customer_id"], "cus_test123") + finally: + PRICE_TO_TIER = old_mapping + + def test_webhook_missing_fields(self) -> None: + """Test handling webhook with missing required fields.""" + # Create test user + user_id, _token = Core.Database.create_user("test@example.com") + Core.Database.set_user_stripe_customer(user_id, "cus_test456") + + # Mock subscription with missing current_period_start + subscription = { + "id": "sub_test456", + "customer": "cus_test456", + "status": "active", + # Missing current_period_start and current_period_end + "cancel_at_period_end": False, + "items": {"data": []}, + } + + # Should not crash, just log warning and return + _update_subscription_state(subscription) + + # User should remain on free tier + user = Core.Database.get_user_by_id(user_id) + self.assertIsNotNone(user) + assert user is not None + self.assertEqual(user["plan_tier"], "free") + + +def main() -> None: + """Run tests.""" + if len(sys.argv) > 1 and sys.argv[1] == "test": + os.environ["AREA"] = "Test" + Test.run(App.Area.Test, [TestWebhookHandling]) + else: + logger.error("Usage: billing.py test") + + +if __name__ == "__main__": + main() |
