diff options
Diffstat (limited to 'Biz')
| -rw-r--r-- | Biz/PodcastItLater/UI.py | 12 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Web.py | 64 |
2 files changed, 66 insertions, 10 deletions
diff --git a/Biz/PodcastItLater/UI.py b/Biz/PodcastItLater/UI.py index 27f5fff..59de7e5 100644 --- a/Biz/PodcastItLater/UI.py +++ b/Biz/PodcastItLater/UI.py @@ -6,6 +6,7 @@ Common UI components and utilities shared across web pages. # : out podcastitlater-ui # : dep ludic +import Biz.PodcastItLater.Core as Core import ludic.html as html import typing from ludic.attrs import Attrs @@ -127,15 +128,6 @@ def create_bootstrap_js() -> html.script: ) -def is_admin(user: dict[str, typing.Any] | None) -> bool: - """Check if user is an admin based on email whitelist.""" - if not user: - return False - admin_emails = ["ben@bensima.com", "admin@example.com"] - return user.get("email", "").lower() in [ - email.lower() for email in admin_emails - ] - class PageLayoutAttrs(Attrs): """Attributes for PageLayout component.""" @@ -318,7 +310,7 @@ class PageLayout(Component[AnyChildren, PageLayoutAttrs]): ), classes=["nav-item", "dropdown"], ) - if user and is_admin(user) + if user and Core.is_admin(user) else html.span(), classes=["navbar-nav"], ), diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py index 7e8e969..e80b0b4 100644 --- a/Biz/PodcastItLater/Web.py +++ b/Biz/PodcastItLater/Web.py @@ -36,6 +36,7 @@ import re import sys import tempfile import typing +from unittest.mock import patch import urllib.parse import uvicorn from datetime import datetime @@ -3164,6 +3165,68 @@ class TestUsageLimits(BaseWebTest): self.assertEqual(usage["articles"], 20) +class TestAccountPage(BaseWebTest): + """Test account page functionality.""" + + def setUp(self) -> None: + """Set up test client with logged-in user.""" + super().setUp() + self.user_id, _ = Core.Database.create_user( + "test@example.com", + status="active", + ) + self.client.post("/login", data={"email": "test@example.com"}) + + def test_account_page_logged_in(self) -> None: + """Account page should render for logged-in users.""" + response = self.client.get("/account") + self.assertEqual(response.status_code, 200) + self.assertIn("Account Management", response.text) + self.assertIn("test@example.com", response.text) + + def test_account_page_logged_out(self) -> None: + """Account page should redirect logged-out users.""" + self.client.post("/logout") + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + self.assertEqual( + response.headers["location"], + "/?error=login_required", + ) + + def test_logout(self) -> None: + """Logout should clear session.""" + response = self.client.post("/logout", follow_redirects=False) + self.assertEqual(response.status_code, 303) + self.assertEqual(response.headers["location"], "/") + + # Verify session cleared + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + + def test_billing_portal_redirect(self) -> None: + """Billing portal should redirect to Stripe.""" + # First set a customer ID + Core.Database.set_user_stripe_customer(self.user_id, "cus_test") + + # Mock the create_portal_session method + with patch( + "Biz.PodcastItLater.Billing.create_portal_session", + ) as mock_portal: + mock_portal.return_value = "https://billing.stripe.com/test" + + response = self.client.post( + "/billing/portal", + follow_redirects=False, + ) + + self.assertEqual(response.status_code, 303) + self.assertEqual( + response.headers["location"], + "https://billing.stripe.com/test", + ) + + def test() -> None: """Run all tests for the web module.""" Test.run( @@ -3180,6 +3243,7 @@ def test() -> None: TestEpisodeDeduplication, TestMetricsTracking, TestUsageLimits, + TestAccountPage, ], ) |
