summaryrefslogtreecommitdiff
path: root/Biz
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2025-11-20 23:17:35 -0500
committerBen Sima <ben@bsima.me>2025-11-20 23:17:35 -0500
commitd98cc29446eb14647d62d0372124fb0be08ec5ae (patch)
treef27acd5cef177a31173df4d23d8f637d9fac2fe3 /Biz
parent47f5162b48323b50deda761e9349cdf43406930a (diff)
feat: implement t-144gQry
Diffstat (limited to 'Biz')
-rw-r--r--Biz/PodcastItLater/UI.py12
-rw-r--r--Biz/PodcastItLater/Web.py64
2 files changed, 66 insertions, 10 deletions
diff --git a/Biz/PodcastItLater/UI.py b/Biz/PodcastItLater/UI.py
index 27f5fff..59de7e5 100644
--- a/Biz/PodcastItLater/UI.py
+++ b/Biz/PodcastItLater/UI.py
@@ -6,6 +6,7 @@ Common UI components and utilities shared across web pages.
# : out podcastitlater-ui
# : dep ludic
+import Biz.PodcastItLater.Core as Core
import ludic.html as html
import typing
from ludic.attrs import Attrs
@@ -127,15 +128,6 @@ def create_bootstrap_js() -> html.script:
)
-def is_admin(user: dict[str, typing.Any] | None) -> bool:
- """Check if user is an admin based on email whitelist."""
- if not user:
- return False
- admin_emails = ["ben@bensima.com", "admin@example.com"]
- return user.get("email", "").lower() in [
- email.lower() for email in admin_emails
- ]
-
class PageLayoutAttrs(Attrs):
"""Attributes for PageLayout component."""
@@ -318,7 +310,7 @@ class PageLayout(Component[AnyChildren, PageLayoutAttrs]):
),
classes=["nav-item", "dropdown"],
)
- if user and is_admin(user)
+ if user and Core.is_admin(user)
else html.span(),
classes=["navbar-nav"],
),
diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py
index 7e8e969..e80b0b4 100644
--- a/Biz/PodcastItLater/Web.py
+++ b/Biz/PodcastItLater/Web.py
@@ -36,6 +36,7 @@ import re
import sys
import tempfile
import typing
+from unittest.mock import patch
import urllib.parse
import uvicorn
from datetime import datetime
@@ -3164,6 +3165,68 @@ class TestUsageLimits(BaseWebTest):
self.assertEqual(usage["articles"], 20)
+class TestAccountPage(BaseWebTest):
+ """Test account page functionality."""
+
+ def setUp(self) -> None:
+ """Set up test client with logged-in user."""
+ super().setUp()
+ self.user_id, _ = Core.Database.create_user(
+ "test@example.com",
+ status="active",
+ )
+ self.client.post("/login", data={"email": "test@example.com"})
+
+ def test_account_page_logged_in(self) -> None:
+ """Account page should render for logged-in users."""
+ response = self.client.get("/account")
+ self.assertEqual(response.status_code, 200)
+ self.assertIn("Account Management", response.text)
+ self.assertIn("test@example.com", response.text)
+
+ def test_account_page_logged_out(self) -> None:
+ """Account page should redirect logged-out users."""
+ self.client.post("/logout")
+ response = self.client.get("/account", follow_redirects=False)
+ self.assertEqual(response.status_code, 307)
+ self.assertEqual(
+ response.headers["location"],
+ "/?error=login_required",
+ )
+
+ def test_logout(self) -> None:
+ """Logout should clear session."""
+ response = self.client.post("/logout", follow_redirects=False)
+ self.assertEqual(response.status_code, 303)
+ self.assertEqual(response.headers["location"], "/")
+
+ # Verify session cleared
+ response = self.client.get("/account", follow_redirects=False)
+ self.assertEqual(response.status_code, 307)
+
+ def test_billing_portal_redirect(self) -> None:
+ """Billing portal should redirect to Stripe."""
+ # First set a customer ID
+ Core.Database.set_user_stripe_customer(self.user_id, "cus_test")
+
+ # Mock the create_portal_session method
+ with patch(
+ "Biz.PodcastItLater.Billing.create_portal_session",
+ ) as mock_portal:
+ mock_portal.return_value = "https://billing.stripe.com/test"
+
+ response = self.client.post(
+ "/billing/portal",
+ follow_redirects=False,
+ )
+
+ self.assertEqual(response.status_code, 303)
+ self.assertEqual(
+ response.headers["location"],
+ "https://billing.stripe.com/test",
+ )
+
+
def test() -> None:
"""Run all tests for the web module."""
Test.run(
@@ -3180,6 +3243,7 @@ def test() -> None:
TestEpisodeDeduplication,
TestMetricsTracking,
TestUsageLimits,
+ TestAccountPage,
],
)