summaryrefslogtreecommitdiff
path: root/Biz/PodcastItLater
diff options
context:
space:
mode:
Diffstat (limited to 'Biz/PodcastItLater')
-rw-r--r--Biz/PodcastItLater/Billing.py37
-rw-r--r--Biz/PodcastItLater/Web.py72
2 files changed, 92 insertions, 17 deletions
diff --git a/Biz/PodcastItLater/Billing.py b/Biz/PodcastItLater/Billing.py
index e472889..025f1aa 100644
--- a/Biz/PodcastItLater/Billing.py
+++ b/Biz/PodcastItLater/Billing.py
@@ -12,6 +12,7 @@ import Biz.PodcastItLater.Core as Core
import Omni.Log as Log
import os
import stripe
+import typing
from datetime import datetime
from datetime import timezone
@@ -54,7 +55,9 @@ PRICE_MAP = {
}
-def get_period_boundaries(user: dict) -> tuple[datetime, datetime]:
+def get_period_boundaries(
+ user: dict[str, typing.Any],
+) -> tuple[datetime, datetime]:
"""Get billing period boundaries for user.
For paid users: use Stripe subscription period.
@@ -90,7 +93,7 @@ def get_usage(
user_id: int,
period_start: datetime,
period_end: datetime,
-) -> dict:
+) -> dict[str, int]:
"""Get usage stats for user in billing period.
Returns:
@@ -99,7 +102,7 @@ def get_usage(
return Core.Database.get_usage(user_id, period_start, period_end)
-def can_submit(user_id: int) -> tuple[bool, str, dict]:
+def can_submit(user_id: int) -> tuple[bool, str, dict[str, int]]:
"""Check if user can submit article based on tier limits.
Returns:
@@ -119,7 +122,7 @@ def can_submit(user_id: int) -> tuple[bool, str, dict]:
usage = get_usage(user_id, period_start, period_end)
# Check article limit
- article_limit = limits["articles_per_period"]
+ article_limit = limits["articles_per_period"] # type: ignore[index]
if article_limit is not None and usage["articles"] >= article_limit:
msg = (
f"You've reached your limit of {article_limit} articles "
@@ -128,7 +131,7 @@ def can_submit(user_id: int) -> tuple[bool, str, dict]:
return (False, msg, usage)
# Check minutes limit (if implemented)
- minute_limit = limits.get("minutes_per_period")
+ minute_limit = limits.get("minutes_per_period") # type: ignore[attr-defined]
if minute_limit is not None and usage.get("minutes", 0) >= minute_limit:
return (
False,
@@ -185,7 +188,7 @@ def create_checkout_session(user_id: int, tier: str, base_url: str) -> str:
else:
session_params["customer_email"] = user["email"]
- session = stripe.checkout.Session.create(**session_params)
+ session = stripe.checkout.Session.create(**session_params) # type: ignore[arg-type]
logger.info(
"Created checkout session for user %s, tier %s: %s",
@@ -194,7 +197,7 @@ def create_checkout_session(user_id: int, tier: str, base_url: str) -> str:
session.id,
)
- return session.url
+ return session.url # type: ignore[return-value]
def create_portal_session(user_id: int, base_url: str) -> str:
@@ -229,7 +232,7 @@ def create_portal_session(user_id: int, base_url: str) -> str:
return session.url
-def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
+def handle_webhook_event(payload: bytes, sig_header: str) -> dict[str, str]:
"""Verify and process Stripe webhook event.
Args:
@@ -243,7 +246,7 @@ def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
May raise stripe.error.SignatureVerificationError if invalid signature
"""
# Verify webhook signature
- event = stripe.Webhook.construct_event(
+ event = stripe.Webhook.construct_event( # type: ignore[no-untyped-call]
payload,
sig_header,
STRIPE_WEBHOOK_SECRET,
@@ -284,7 +287,7 @@ def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
return {"status": "processed", "type": event_type}
-def _handle_checkout_completed(session: dict) -> None:
+def _handle_checkout_completed(session: dict[str, typing.Any]) -> None:
"""Handle checkout.session.completed event."""
user_id = int(session.get("client_reference_id", 0))
customer_id = session.get("customer")
@@ -301,17 +304,17 @@ def _handle_checkout_completed(session: dict) -> None:
logger.info("Linked user %s to Stripe customer %s", user_id, customer_id)
-def _handle_subscription_created(subscription: dict) -> None:
+def _handle_subscription_created(subscription: dict[str, typing.Any]) -> None:
"""Handle customer.subscription.created event."""
_update_subscription_state(subscription)
-def _handle_subscription_updated(subscription: dict) -> None:
+def _handle_subscription_updated(subscription: dict[str, typing.Any]) -> None:
"""Handle customer.subscription.updated event."""
_update_subscription_state(subscription)
-def _handle_subscription_deleted(subscription: dict) -> None:
+def _handle_subscription_deleted(subscription: dict[str, typing.Any]) -> None:
"""Handle customer.subscription.deleted event."""
customer_id = subscription["customer"]
@@ -326,7 +329,7 @@ def _handle_subscription_deleted(subscription: dict) -> None:
logger.info("Downgraded user %s to free tier", user["id"])
-def _handle_payment_failed(invoice: dict) -> None:
+def _handle_payment_failed(invoice: dict[str, typing.Any]) -> None:
"""Handle invoice.payment_failed event."""
customer_id = invoice["customer"]
subscription_id = invoice.get("subscription")
@@ -347,7 +350,7 @@ def _handle_payment_failed(invoice: dict) -> None:
)
-def _update_subscription_state(subscription: dict) -> None:
+def _update_subscription_state(subscription: dict[str, typing.Any]) -> None:
"""Update user subscription state from Stripe subscription object."""
customer_id = subscription["customer"]
subscription_id = subscription["id"]
@@ -393,7 +396,7 @@ def _update_subscription_state(subscription: dict) -> None:
)
-def get_tier_info(tier: str) -> dict:
+def get_tier_info(tier: str) -> dict[str, typing.Any]:
"""Get tier information for display.
Returns:
@@ -419,4 +422,4 @@ def get_tier_info(tier: str) -> dict:
"description": "Unlimited articles",
},
}
- return tier_info.get(tier, tier_info["free"])
+ return tier_info.get(tier, tier_info["free"]) # type: ignore[return-value]
diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py
index 2f86b16..c469874 100644
--- a/Biz/PodcastItLater/Web.py
+++ b/Biz/PodcastItLater/Web.py
@@ -15,8 +15,10 @@ Provides ludic + htmx interface and RSS feed generation.
# : dep pytest-asyncio
# : dep pytest-mock
# : dep starlette
+# : dep stripe
import Biz.EmailAgent
import Biz.PodcastItLater.Admin as Admin
+import Biz.PodcastItLater.Billing as Billing
import Biz.PodcastItLater.Core as Core
import html as html_module
import httpx
@@ -1123,6 +1125,76 @@ app.get("/admin")(Admin.admin_queue_status)
app.post("/queue/{job_id}/retry")(Admin.retry_queue_item)
+@app.get("/billing")
+def billing_page(request: Request) -> Response:
+ """Display billing page with current plan and upgrade options."""
+ user_id = request.session.get("user_id")
+ if not user_id:
+ return RedirectResponse(url="/?error=login_required")
+
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ return RedirectResponse(url="/?error=user_not_found")
+
+ tier = user.get("plan_tier", "free")
+ tier_info = Billing.get_tier_info(tier)
+
+ # Get current usage
+ period_start, period_end = Billing.get_period_boundaries(user)
+ Billing.get_usage(user_id, period_start, period_end)
+
+ # Billing page component to be implemented
+ return Response(f"<h1>Billing - Current plan: {tier_info['name']}</h1>")
+
+
+@app.post("/billing/checkout")
+def billing_checkout(request: Request, data: FormData) -> Response:
+ """Create Stripe Checkout session."""
+ user_id = request.session.get("user_id")
+ if not user_id:
+ return Response("Unauthorized", status_code=401)
+
+ tier = data.get("tier", "personal")
+ if tier not in {"personal", "pro"}:
+ return Response("Invalid tier", status_code=400)
+
+ try:
+ checkout_url = Billing.create_checkout_session(user_id, tier, BASE_URL)
+ return RedirectResponse(url=checkout_url)
+ except ValueError as e:
+ logger.exception("Checkout error")
+ return Response(f"Error: {e!s}", status_code=400)
+
+
+@app.post("/billing/portal")
+def billing_portal(request: Request) -> Response:
+ """Create Stripe Billing Portal session."""
+ user_id = request.session.get("user_id")
+ if not user_id:
+ return Response("Unauthorized", status_code=401)
+
+ try:
+ portal_url = Billing.create_portal_session(user_id, BASE_URL)
+ return RedirectResponse(url=portal_url)
+ except ValueError as e:
+ logger.exception("Portal error")
+ return Response(f"Error: {e!s}", status_code=400)
+
+
+@app.post("/stripe/webhook")
+async def stripe_webhook(request: Request) -> Response:
+ """Handle Stripe webhook events."""
+ payload = await request.body()
+ sig_header = request.headers.get("stripe-signature", "")
+
+ try:
+ result = Billing.handle_webhook_event(payload, sig_header)
+ return Response(f"OK: {result['status']}", status_code=200)
+ except Exception as e:
+ logger.exception("Webhook error")
+ return Response(f"Error: {e!s}", status_code=400)
+
+
@app.post("/queue/{job_id}/cancel")
def cancel_queue_item(request: Request, job_id: int) -> Response:
"""Cancel a pending queue item."""