diff options
Diffstat (limited to 'Biz/PodcastItLater')
| -rw-r--r-- | Biz/PodcastItLater/Admin.py | 23 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Core.py | 44 | ||||
| -rwxr-xr-x[-rw-r--r--] | Biz/PodcastItLater/Test.py | 0 | ||||
| -rw-r--r-- | Biz/PodcastItLater/UI.py | 58 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Web.py | 168 |
5 files changed, 220 insertions, 73 deletions
diff --git a/Biz/PodcastItLater/Admin.py b/Biz/PodcastItLater/Admin.py index 10a8e58..6faf7fb 100644 --- a/Biz/PodcastItLater/Admin.py +++ b/Biz/PodcastItLater/Admin.py @@ -795,7 +795,7 @@ def admin_queue_status(request: Request) -> AdminView | Response | html.div: def retry_queue_item(request: Request, job_id: int) -> Response: """Retry a failed queue item.""" try: - # Check if user owns this job + # Check if user owns this job or is admin user_id = request.session.get("user_id") if not user_id: return Response("Unauthorized", status_code=401) @@ -803,15 +803,30 @@ def retry_queue_item(request: Request, job_id: int) -> Response: job = Core.Database.get_job_by_id( job_id, ) - if job is None or job.get("user_id") != user_id: + if job is None: + return Response("Job not found", status_code=404) + + # Check ownership or admin status + user = Core.Database.get_user_by_id(user_id) + if job.get("user_id") != user_id and not Core.is_admin(user): return Response("Forbidden", status_code=403) Core.Database.retry_job(job_id) - # Redirect back to admin view + + # Check if request is from admin page via referer header + is_from_admin = "/admin" in request.headers.get("referer", "") + + # Redirect to admin if from admin page, trigger update otherwise + if is_from_admin: + return Response( + "", + status_code=200, + headers={"HX-Redirect": "/admin"}, + ) return Response( "", status_code=200, - headers={"HX-Redirect": "/admin"}, + headers={"HX-Trigger": "queue-updated"}, ) except (ValueError, KeyError) as e: return Response( diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py index 738531f..4a017cc 100644 --- a/Biz/PodcastItLater/Core.py +++ b/Biz/PodcastItLater/Core.py @@ -373,7 +373,10 @@ class Database: # noqa: PLR0904 SELECT id, url, email, status, created_at, error_message, title, author FROM queue - WHERE status IN ('pending', 'processing', 'extracting', 'synthesizing', 'uploading', 'error') + WHERE status IN ( + 'pending', 'processing', 'extracting', + 'synthesizing', 'uploading', 'error' + ) ORDER BY created_at DESC LIMIT 20 """) @@ -888,7 +891,10 @@ class Database: # noqa: PLR0904 title, author FROM queue WHERE user_id = ? AND - status IN ('pending', 'processing', 'extracting', 'synthesizing', 'uploading', 'error') + status IN ( + 'pending', 'processing', 'extracting', + 'synthesizing', 'uploading', 'error' + ) ORDER BY created_at DESC LIMIT 20 """, @@ -1785,6 +1791,40 @@ class TestQueueOperations(Test.TestCase): self.assertEqual(counts.get("processing", 0), 1) self.assertEqual(counts.get("error", 0), 1) + def test_queue_position(self) -> None: + """Verify queue position calculation.""" + # Add multiple pending jobs + job1 = Database.add_to_queue( + "https://example.com/1", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job2 = Database.add_to_queue( + "https://example.com/2", + "test@example.com", + self.user_id, + ) + time.sleep(0.01) + job3 = Database.add_to_queue( + "https://example.com/3", + "test@example.com", + self.user_id, + ) + + # Check positions + self.assertEqual(Database.get_queue_position(job1), 1) + self.assertEqual(Database.get_queue_position(job2), 2) + self.assertEqual(Database.get_queue_position(job3), 3) + + # Move job 2 to processing + Database.update_job_status(job2, "processing") + + # Check positions (job 3 should now be 2nd pending job) + self.assertEqual(Database.get_queue_position(job1), 1) + self.assertIsNone(Database.get_queue_position(job2)) + self.assertEqual(Database.get_queue_position(job3), 2) + class TestEpisodeManagement(Test.TestCase): """Test episode management functionality.""" diff --git a/Biz/PodcastItLater/Test.py b/Biz/PodcastItLater/Test.py index b2a1d24..b2a1d24 100644..100755 --- a/Biz/PodcastItLater/Test.py +++ b/Biz/PodcastItLater/Test.py diff --git a/Biz/PodcastItLater/UI.py b/Biz/PodcastItLater/UI.py index 4650708..bdf7a5b 100644 --- a/Biz/PodcastItLater/UI.py +++ b/Biz/PodcastItLater/UI.py @@ -6,6 +6,7 @@ Common UI components and utilities shared across web pages. # : out podcastitlater-ui # : dep ludic +import Biz.PodcastItLater.Core as Core import ludic.html as html import typing from ludic.attrs import Attrs @@ -127,15 +128,6 @@ def create_bootstrap_js() -> html.script: ) -def is_admin(user: dict[str, typing.Any] | None) -> bool: - """Check if user is an admin based on email whitelist.""" - if not user: - return False - admin_emails = ["ben@bensima.com", "admin@example.com"] - return user.get("email", "").lower() in [ - email.lower() for email in admin_emails - ] - class PageLayoutAttrs(Attrs): """Attributes for PageLayout component.""" @@ -318,7 +310,7 @@ class PageLayout(Component[AnyChildren, PageLayoutAttrs]): ), classes=["nav-item", "dropdown"], ) - if user and is_admin(user) + if user and Core.is_admin(user) else html.span(), classes=["navbar-nav"], ), @@ -435,12 +427,7 @@ class AccountPage(Component[AnyChildren, AccountPageAttrs]): limit_text = "Unlimited" if article_limit is None else str(article_limit) return PageLayout( - user=user, - current_page="account", - page_title="Account - PodcastItLater", - error=None, - meta_tags=[], - children=[ + html.div( html.div( html.div( html.div( @@ -448,8 +435,8 @@ class AccountPage(Component[AnyChildren, AccountPageAttrs]): html.h2( html.i( classes=[ - "bi", - "bi-person-circle", + "bi", + "bi-person-circle", "me-2" ] ), @@ -481,8 +468,8 @@ class AccountPage(Component[AnyChildren, AccountPageAttrs]): plan_tier.title(), classes=[ "badge", - "bg-success" - if is_paid + "bg-success" + if is_paid else "bg-secondary", "ms-2", ], @@ -527,9 +514,9 @@ class AccountPage(Component[AnyChildren, AccountPageAttrs]): type="submit", classes=["btn", "btn-outline-primary"], ), - method="POST", + method="post", action=portal_url, - ) if is_paid and portal_url else + ) if is_paid and portal_url else html.a( html.i(classes=["bi", "bi-star-fill", "me-2"]), "Upgrade to Pro", @@ -548,8 +535,8 @@ class AccountPage(Component[AnyChildren, AccountPageAttrs]): html.button( html.i( classes=[ - "bi", - "bi-box-arrow-right", + "bi", + "bi-box-arrow-right", "me-2" ] ), @@ -558,7 +545,7 @@ class AccountPage(Component[AnyChildren, AccountPageAttrs]): classes=["btn", "btn-outline-danger"], ), action="/logout", - method="POST", + method="post", ), classes=["border-top", "pt-4"], ), @@ -570,7 +557,12 @@ class AccountPage(Component[AnyChildren, AccountPageAttrs]): ), classes=["row"], ), - ], + ), + user=user, + current_page="account", + page_title="Account - PodcastItLater", + error=None, + meta_tags=[], ) @@ -589,12 +581,7 @@ class PricingPage(Component[AnyChildren, PricingPageAttrs]): current_tier = user.get("plan_tier", "free") if user else "free" return PageLayout( - user=user, - current_page="pricing", - page_title="Pricing - PodcastItLater", - error=None, - meta_tags=[], - children=[ + html.div( html.div( html.h2("Simple Pricing", classes=["text-center", "mb-5"]), html.div( @@ -714,5 +701,10 @@ class PricingPage(Component[AnyChildren, PricingPageAttrs]): ), classes=["container", "py-3"], ), - ], + ), + user=user, + current_page="pricing", + page_title="Pricing - PodcastItLater", + error=None, + meta_tags=[], ) diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py index 044cdca..0f095d3 100644 --- a/Biz/PodcastItLater/Web.py +++ b/Biz/PodcastItLater/Web.py @@ -54,6 +54,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import RedirectResponse from starlette.testclient import TestClient from typing import override +from unittest.mock import patch logger = logging.getLogger(__name__) Log.setup(logger) @@ -384,6 +385,11 @@ class QueueStatus(Component[AnyChildren, QueueStatusAttrs]): badge_class = status_classes.get(item["status"], "bg-secondary") icon_class = status_icons.get(item["status"], "bi-question-circle") + # Get queue position for pending items + queue_pos = None + if item["status"] == "pending": + queue_pos = Core.Database.get_queue_position(item["id"]) + queue_items.append( html.div( html.div( @@ -435,6 +441,16 @@ class QueueStatus(Component[AnyChildren, QueueStatusAttrs]): f"Created: {item['created_at']}", classes=["text-muted", "d-block", "mt-1"], ), + # Display queue position if available + html.small( + html.i( + classes=["bi", "bi-hourglass-split", "me-1"], + ), + f"Position in queue: #{queue_pos}", + classes=["text-info", "d-block", "mt-1"], + ) + if queue_pos + else html.span(), *( [ html.div( @@ -462,6 +478,33 @@ class QueueStatus(Component[AnyChildren, QueueStatusAttrs]): ), # Add cancel button for pending jobs, remove for others html.div( + # Retry button for error items + html.button( + html.i( + classes=[ + "bi", + "bi-arrow-clockwise", + "me-1", + ], + ), + "Retry", + hx_post=f"/queue/{item['id']}/retry", + hx_trigger="click", + hx_on=( + "htmx:afterRequest: " + "if(event.detail.successful) " + "htmx.trigger('body', 'queue-updated')" + ), + classes=[ + "btn", + "btn-sm", + "btn-outline-primary", + "mt-2", + "me-2", + ], + ) + if item["status"] == "error" + else html.span(), html.button( html.i(classes=["bi", "bi-x-lg", "me-1"]), "Cancel", @@ -1009,6 +1052,57 @@ def upgrade(request: Request) -> RedirectResponse: return RedirectResponse(url="/pricing?error=checkout_failed") +@app.get("/account") +def account(request: Request) -> UI.AccountPage | RedirectResponse: + """Display account management page.""" + user_id = request.session.get("user_id") + if not user_id: + return RedirectResponse(url="/?error=login_required") + + user = Core.Database.get_user_by_id(user_id) + if not user: + request.session.clear() + return RedirectResponse(url="/?error=user_not_found") + + # Get usage stats + period_start, period_end = Billing.get_period_boundaries(user) + usage = Billing.get_usage(user_id, period_start, period_end) + + # Get limits + tier = user.get("plan_tier", "free") + limits = Billing.TIER_LIMITS.get(tier, Billing.TIER_LIMITS["free"]) + + return UI.AccountPage( + user=user, + usage=usage, + limits=limits, + portal_url="/billing/portal" if tier == "paid" else None, + ) + + +@app.post("/logout") +def logout(request: Request) -> RedirectResponse: + """Log out user.""" + request.session.clear() + return RedirectResponse(url="/", status_code=303) + + +@app.post("/billing/portal") +def billing_portal(request: Request) -> RedirectResponse: + """Redirect to Stripe billing portal.""" + user_id = request.session.get("user_id") + if not user_id: + return RedirectResponse(url="/?error=login_required") + + try: + portal_url = Billing.create_portal_session(user_id, BASE_URL) + return RedirectResponse(url=portal_url, status_code=303) + except ValueError as e: + logger.warning("Failed to create portal session: %s", e) + # If user has no customer ID (e.g. free tier), redirect to pricing + return RedirectResponse(url="/pricing") + + def _handle_test_login(email: str, request: Request) -> Response: """Handle login in test mode.""" # Special handling for demo account @@ -1167,11 +1261,11 @@ def account_page(request: Request) -> UI.AccountPage | RedirectResponse: # Get usage stats period_start, period_end = Billing.get_period_boundaries(user) usage = Billing.get_usage(user["id"], period_start, period_end) - + # Get limits tier = user.get("plan_tier", "free") limits = Billing.TIER_LIMITS.get(tier, Billing.TIER_LIMITS["free"]) - + return UI.AccountPage( user=user, usage=usage, @@ -1180,17 +1274,6 @@ def account_page(request: Request) -> UI.AccountPage | RedirectResponse: ) -@app.get("/logout") -def logout(request: Request) -> Response: - """Handle logout.""" - request.session.clear() - return Response( - "", - status_code=302, - headers={"Location": "/"}, - ) - - @app.post("/submit") def submit_article( # noqa: PLR0911, PLR0914 request: Request, @@ -1565,21 +1648,6 @@ def billing_checkout(request: Request, data: FormData) -> Response: return Response(f"Error: {e!s}", status_code=400) -@app.post("/billing/portal") -def billing_portal(request: Request) -> Response | RedirectResponse: - """Create Stripe Billing Portal session.""" - user_id = request.session.get("user_id") - if not user_id: - return Response("Unauthorized", status_code=401) - - try: - portal_url = Billing.create_portal_session(user_id, BASE_URL) - return RedirectResponse(url=portal_url, status_code=303) - except Exception: - logger.exception("Portal error - ensure Stripe portal is configured") - return Response("Portal not configured", status_code=500) - - @app.post("/stripe/webhook") async def stripe_webhook(request: Request) -> Response: """Handle Stripe webhook events.""" @@ -3036,9 +3104,9 @@ class TestAccountPage(BaseWebTest): ) self.client.post("/login", data={"email": "test@example.com"}) - def test_account_page_shows_usage(self) -> None: - """Account page should show usage stats.""" - # Create some usage + def test_account_page_logged_in(self) -> None: + """Account page should render for logged-in users.""" + # Create some usage to verify stats are shown ep_id = Core.Database.create_episode( title="Test Episode", audio_url="https://example.com/audio.mp3", @@ -3052,19 +3120,51 @@ class TestAccountPage(BaseWebTest): Core.Database.add_episode_to_user(self.user_id, ep_id) response = self.client.get("/account") - + self.assertEqual(response.status_code, 200) - self.assertIn("1 / 10", response.text) # Usage / Limit for free tier self.assertIn("My Account", response.text) self.assertIn("test@example.com", response.text) + self.assertIn("1 / 10", response.text) # Usage / Limit for free tier def test_account_page_login_required(self) -> None: """Should redirect to login if not logged in.""" - self.client.get("/logout") + self.client.post("/logout") response = self.client.get("/account", follow_redirects=False) self.assertEqual(response.status_code, 307) self.assertEqual(response.headers["location"], "/?error=login_required") + def test_logout(self) -> None: + """Logout should clear session.""" + response = self.client.post("/logout", follow_redirects=False) + self.assertEqual(response.status_code, 303) + self.assertEqual(response.headers["location"], "/") + + # Verify session cleared + response = self.client.get("/account", follow_redirects=False) + self.assertEqual(response.status_code, 307) + + def test_billing_portal_redirect(self) -> None: + """Billing portal should redirect to Stripe.""" + # First set a customer ID + Core.Database.set_user_stripe_customer(self.user_id, "cus_test") + + # Mock the create_portal_session method + with patch( + "Biz.PodcastItLater.Billing.create_portal_session", + ) as mock_portal: + mock_portal.return_value = "https://billing.stripe.com/test" + + response = self.client.post( + "/billing/portal", + follow_redirects=False, + ) + + self.assertEqual(response.status_code, 303) + self.assertEqual( + response.headers["location"], + "https://billing.stripe.com/test", + ) + def test() -> None: """Run all tests for the web module.""" |
