diff options
Diffstat (limited to 'Biz/PodcastItLater/Web.py')
| -rw-r--r-- | Biz/PodcastItLater/Web.py | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py index 7e8e969..e80b0b4 100644 --- a/Biz/PodcastItLater/Web.py +++ b/Biz/PodcastItLater/Web.py @@ -36,6 +36,7 @@ import re import sys import tempfile import typing +from unittest.mock import patch import urllib.parse import uvicorn from datetime import datetime @@ -3164,6 +3165,68 @@ class TestUsageLimits(BaseWebTest): self.assertEqual(usage["articles"], 20) +class TestAccountPage(BaseWebTest): + """Test account page functionality.""" + + def setUp(self) -> None: + """Set up test client with logged-in user.""" + super().setUp() + self.user_id, _ = Core.Database.create_user( + "test@example.com", + status="active", + ) + self.client.post("/login", data={"email": "test@example.com"}) + + def test_account_page_logged_in(self) -> None: + """Account page should render for logged-in users.""" + response = self.client.get("/account") + self.assertEqual(response.status_code, 200) + self.assertIn("Account Management", response.text) + self.assertIn("test@example.com", response.text) + + def test_account_page_logged_out(self) -> None: + """Account page should redirect logged-out users.""" + self.client.post("/logout") + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + self.assertEqual( + response.headers["location"], + "/?error=login_required", + ) + + def test_logout(self) -> None: + """Logout should clear session.""" + response = self.client.post("/logout", follow_redirects=False) + self.assertEqual(response.status_code, 303) + self.assertEqual(response.headers["location"], "/") + + # Verify session cleared + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + + def test_billing_portal_redirect(self) -> None: + """Billing portal should redirect to Stripe.""" + # First set a customer ID + Core.Database.set_user_stripe_customer(self.user_id, "cus_test") + + # Mock the create_portal_session method + with patch( + "Biz.PodcastItLater.Billing.create_portal_session", + ) as mock_portal: + mock_portal.return_value = "https://billing.stripe.com/test" + + response = self.client.post( + "/billing/portal", + follow_redirects=False, + ) + + self.assertEqual(response.status_code, 303) + self.assertEqual( + response.headers["location"], + "https://billing.stripe.com/test", + ) + + def test() -> None: """Run all tests for the web module.""" Test.run( @@ -3180,6 +3243,7 @@ def test() -> None: TestEpisodeDeduplication, TestMetricsTracking, TestUsageLimits, + TestAccountPage, ], ) |
