"""
PodcastItLater Billing Integration.

Stripe subscription management and usage enforcement.
"""

# : out podcastitlater-billing
# : dep stripe
# : dep pytest
# : dep pytest-mock
import Biz.PodcastItLater.Core as Core
import Omni.Log as Log
import os
import stripe
from datetime import datetime
from datetime import timezone

logger = Log.setup()

# Stripe configuration
stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "")
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "")

# Price IDs from Stripe dashboard
STRIPE_PRICE_ID_PERSONAL = os.getenv("STRIPE_PRICE_ID_PERSONAL", "")
STRIPE_PRICE_ID_PRO = os.getenv("STRIPE_PRICE_ID_PRO", "")

# Map Stripe price IDs to tier names.
# Unconfigured (empty) price IDs are skipped: otherwise both tiers would share
# the "" key and the later entry ("pro") would silently win.
PRICE_TO_TIER = {
    price_id: tier_name
    for price_id, tier_name in (
        (STRIPE_PRICE_ID_PERSONAL, "personal"),
        (STRIPE_PRICE_ID_PRO, "pro"),
    )
    if price_id
}

# Tier limits (None = unlimited)
TIER_LIMITS = {
    "free": {
        "articles_per_period": 10,
        "minutes_per_period": None,
    },
    "personal": {
        "articles_per_period": 50,
        "minutes_per_period": None,
    },
    "pro": {
        "articles_per_period": None,
        "minutes_per_period": None,
    },
}

# Price map for checkout (values may be "" when the env var is unset;
# create_checkout_session rejects that case explicitly).
PRICE_MAP = {
    "personal": STRIPE_PRICE_ID_PERSONAL,
    "pro": STRIPE_PRICE_ID_PRO,
}


def get_period_boundaries(user: dict) -> tuple[datetime, datetime]:
    """Get billing period boundaries for user.

    For paid users with a recorded Stripe period: use that period (parsed
    from the ISO strings stored on the user record).
    For free users, or paid users missing either period bound: use the
    current calendar month (UTC).

    Returns:
        tuple: (period_start, period_end) as datetime objects
    """
    stored_start = user.get("current_period_start")
    stored_end = user.get("current_period_end")
    has_stripe_period = bool(stored_start and stored_end)

    if user.get("plan_tier") != "free" and has_stripe_period:
        # Paid user - use Stripe billing period
        return (
            datetime.fromisoformat(stored_start),
            datetime.fromisoformat(stored_end),
        )

    # Free user (or no usable Stripe period) - use calendar month
    now = datetime.now(timezone.utc)
    period_start = now.replace(
        day=1,
        hour=0,
        minute=0,
        second=0,
        microsecond=0,
    )
    # First day of next month; December rolls over into next year.
    if now.month == 12:  # noqa: PLR2004
        period_end = period_start.replace(year=now.year + 1, month=1)
    else:
        period_end = period_start.replace(month=now.month + 1)
    return period_start, period_end


def get_usage(
    user_id: int,
    period_start: datetime,
    period_end: datetime,
) -> dict:
    """Get usage stats for user in billing period.

    Thin wrapper over Core.Database.get_usage.

    Returns:
        dict with keys: articles (int), minutes (int)
    """
    return Core.Database.get_usage(user_id, period_start, period_end)


def can_submit(user_id: int) -> tuple[bool, str, dict]:
    """Check if user can submit article based on tier limits.

    Unknown tiers fall back to the "free" limits.

    Returns:
        tuple: (allowed: bool, message: str, usage: dict)
        ``message`` is "" when allowed; ``usage`` is {} if the user is
        not found.
    """
    user = Core.Database.get_user_by_id(user_id)
    if not user:
        return False, "User not found", {}

    tier = user.get("plan_tier", "free")
    limits = TIER_LIMITS.get(tier, TIER_LIMITS["free"])

    # Get billing period boundaries
    period_start, period_end = get_period_boundaries(user)

    # Get current usage
    usage = get_usage(user_id, period_start, period_end)

    # Check article limit
    article_limit = limits["articles_per_period"]
    if article_limit is not None and usage["articles"] >= article_limit:
        msg = (
            f"You've reached your limit of {article_limit} articles "
            "per period. Upgrade to continue."
        )
        return (False, msg, usage)

    # Check minutes limit (if implemented)
    minute_limit = limits.get("minutes_per_period")
    if minute_limit is not None and usage.get("minutes", 0) >= minute_limit:
        return (
            False,
            f"You've reached your limit of {minute_limit} minutes per period. "
            "Please upgrade to continue.",
            usage,
        )

    return True, "", usage


def create_checkout_session(user_id: int, tier: str, base_url: str) -> str:
    """Create Stripe Checkout session for subscription.

    Args:
        user_id: User ID
        tier: Subscription tier (personal or pro)
        base_url: Base URL for success/cancel redirects

    Returns:
        Checkout session URL to redirect user to

    Raises:
        ValueError: If tier is invalid, price ID not configured, or the
            user does not exist
    """
    if tier not in PRICE_MAP:
        msg = f"Invalid tier: {tier}"
        raise ValueError(msg)

    price_id = PRICE_MAP[tier]
    if not price_id:
        msg = f"Stripe price ID not configured for tier: {tier}"
        raise ValueError(msg)

    user = Core.Database.get_user_by_id(user_id)
    if not user:
        msg = f"User not found: {user_id}"
        raise ValueError(msg)

    # Create checkout session. client_reference_id carries our user ID so
    # _handle_checkout_completed can link the resulting Stripe customer.
    session_params = {
        "mode": "subscription",
        "line_items": [{"price": price_id, "quantity": 1}],
        "success_url": f"{base_url}/billing?status=success",
        "cancel_url": f"{base_url}/billing?status=cancel",
        "client_reference_id": str(user_id),
        "metadata": {"user_id": str(user_id), "tier": tier},
        "allow_promotion_codes": True,
    }

    # Use existing customer if available
    if user.get("stripe_customer_id"):
        session_params["customer"] = user["stripe_customer_id"]
    else:
        session_params["customer_email"] = user["email"]

    session = stripe.checkout.Session.create(**session_params)

    logger.info(
        "Created checkout session for user %s, tier %s: %s",
        user_id,
        tier,
        session.id,
    )

    return session.url


def create_portal_session(user_id: int, base_url: str) -> str:
    """Create Stripe Billing Portal session.

    Args:
        user_id: User ID
        base_url: Base URL for return redirect

    Returns:
        Portal session URL to redirect user to

    Raises:
        ValueError: If the user does not exist or has no Stripe customer ID
    """
    user = Core.Database.get_user_by_id(user_id)
    if not user:
        msg = f"User not found: {user_id}"
        raise ValueError(msg)

    if not user.get("stripe_customer_id"):
        msg = "User has no Stripe customer ID"
        raise ValueError(msg)

    session = stripe.billing_portal.Session.create(
        customer=user["stripe_customer_id"],
        return_url=f"{base_url}/billing",
    )

    logger.info("Created portal session for user %s: %s", user_id, session.id)
    return session.url


def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
    """Verify and process Stripe webhook event.

    Events already recorded via Core.Database are skipped. Handled event
    types are recorded as processed only after their handler returns
    without raising; unhandled types are ignored and not recorded.

    Args:
        payload: Raw webhook body
        sig_header: Stripe-Signature header value

    Returns:
        dict with processing status ("skipped", "ignored" or "processed")

    Note:
        May raise stripe.error.SignatureVerificationError if invalid signature
    """
    # Verify webhook signature.
    # NOTE(review): STRIPE_WEBHOOK_SECRET defaults to "" when unset; an empty
    # secret is presumably still accepted as an HMAC key, so forged events
    # signed with "" could verify. Consider failing closed when unset.
    event = stripe.Webhook.construct_event(
        payload,
        sig_header,
        STRIPE_WEBHOOK_SECRET,
    )

    event_id = event["id"]
    event_type = event["type"]

    # Check if already processed (idempotency)
    if Core.Database.has_processed_stripe_event(event_id):
        logger.info("Skipping already processed event: %s", event_id)
        return {"status": "skipped", "reason": "already_processed"}

    # Process event based on type
    logger.info("Processing webhook event: %s (%s)", event_id, event_type)

    try:
        if event_type == "checkout.session.completed":
            _handle_checkout_completed(event["data"]["object"])
        elif event_type == "customer.subscription.created":
            _handle_subscription_created(event["data"]["object"])
        elif event_type == "customer.subscription.updated":
            _handle_subscription_updated(event["data"]["object"])
        elif event_type == "customer.subscription.deleted":
            _handle_subscription_deleted(event["data"]["object"])
        elif event_type == "invoice.payment_failed":
            _handle_payment_failed(event["data"]["object"])
        else:
            logger.info("Unhandled event type: %s", event_type)
            return {"status": "ignored", "type": event_type}

        # Mark event as processed
        Core.Database.mark_stripe_event_processed(event_id, event_type, payload)
    except Exception:
        logger.exception("Error processing webhook event %s", event_id)
        raise
    else:
        return {"status": "processed", "type": event_type}


def _handle_checkout_completed(session: dict) -> None:
    """Handle checkout.session.completed event.

    Links the Stripe customer on the session to the local user whose ID was
    passed as ``client_reference_id`` by create_checkout_session.
    """
    # client_reference_id can be present but null (or not a number), so the
    # old ``int(session.get(..., 0))`` could raise and make the webhook fail
    # on every retry. Treat any unusable value as "missing".
    raw_user_id = session.get("client_reference_id")
    try:
        user_id = int(raw_user_id) if raw_user_id else 0
    except (TypeError, ValueError):
        user_id = 0
    customer_id = session.get("customer")

    if not user_id or not customer_id:
        logger.warning(
            "Missing user_id or customer_id in checkout session: %s",
            session.get("id"),
        )
        return

    # Link Stripe customer to user
    Core.Database.set_user_stripe_customer(user_id, customer_id)
    logger.info("Linked user %s to Stripe customer %s", user_id, customer_id)


def _handle_subscription_created(subscription: dict) -> None:
    """Handle customer.subscription.created event."""
    _update_subscription_state(subscription)


def _handle_subscription_updated(subscription: dict) -> None:
    """Handle customer.subscription.updated event."""
    _update_subscription_state(subscription)


def _handle_subscription_deleted(subscription: dict) -> None:
    """Handle customer.subscription.deleted event by downgrading to free."""
    customer_id = subscription["customer"]

    # Find user by customer ID
    user = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not user:
        logger.warning("User not found for customer: %s", customer_id)
        return

    # Downgrade to free
    Core.Database.downgrade_to_free(user["id"])
    logger.info("Downgraded user %s to free tier", user["id"])


def _handle_payment_failed(invoice: dict) -> None:
    """Handle invoice.payment_failed event by marking the user past_due."""
    customer_id = invoice["customer"]
    # NOTE(review): newer Stripe API versions may move the subscription ID
    # off the top-level invoice object; confirm against the pinned version.
    subscription_id = invoice.get("subscription")

    # Find user by customer ID
    user = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not user:
        logger.warning("User not found for customer: %s", customer_id)
        return

    # Update subscription status to past_due
    if subscription_id:
        Core.Database.update_subscription_status(user["id"], "past_due")
        logger.warning(
            "Payment failed for user %s, subscription %s",
            user["id"],
            subscription_id,
        )


def _update_subscription_state(subscription: dict) -> None:
    """Update user subscription state from Stripe subscription object.

    The tier comes from the first subscription item's price ID via
    PRICE_TO_TIER; unknown price IDs map to "free".
    """
    customer_id = subscription["customer"]
    subscription_id = subscription["id"]
    status = subscription["status"]
    cancel_at_period_end = subscription.get("cancel_at_period_end", False)

    # Get billing period (Stripe sends Unix timestamps).
    # NOTE(review): newer Stripe API versions may report period bounds per
    # subscription item instead of on the subscription; confirm.
    period_start = datetime.fromtimestamp(
        subscription["current_period_start"],
        tz=timezone.utc,
    )
    period_end = datetime.fromtimestamp(
        subscription["current_period_end"],
        tz=timezone.utc,
    )

    # Determine tier from price ID
    price_id = subscription["items"]["data"][0]["price"]["id"]
    tier = PRICE_TO_TIER.get(price_id, "free")

    # Find user by customer ID.
    # NOTE(review): the customer is linked in _handle_checkout_completed; if
    # this event is delivered first, the lookup fails here and the caller
    # still marks the event processed, so this update is never applied.
    user = Core.Database.get_user_by_stripe_customer_id(customer_id)
    if not user:
        logger.warning("User not found for customer: %s", customer_id)
        return

    # Update user subscription
    Core.Database.update_user_subscription(
        user["id"],
        subscription_id,
        status,
        period_start,
        period_end,
        tier,
        cancel_at_period_end,
    )

    logger.info(
        "Updated user %s subscription: tier=%s, status=%s",
        user["id"],
        tier,
        status,
    )


def get_tier_info(tier: str) -> dict:
    """Get tier information for display.

    Article limits are read from TIER_LIMITS so the displayed numbers cannot
    drift from the limits enforced by can_submit. Unknown tiers fall back to
    the "free" entry.

    Returns:
        dict with keys: name, articles_limit, price, description
    """

    def _describe(limit: int | None) -> str:
        # None means unlimited (see TIER_LIMITS).
        if limit is None:
            return "Unlimited articles"
        return f"{limit} articles per month"

    display = {
        "free": ("Free", "$0"),
        "personal": ("Personal", "$9/mo"),
        "pro": ("Pro", "$29/mo"),
    }

    tier_info = {}
    for key, (name, price) in display.items():
        limit = TIER_LIMITS[key]["articles_per_period"]
        tier_info[key] = {
            "name": name,
            "articles_limit": limit,
            "price": price,
            "description": _describe(limit),
        }

    return tier_info.get(tier, tier_info["free"])