summaryrefslogtreecommitdiff
path: root/Biz
diff options
context:
space:
mode:
Diffstat (limited to 'Biz')
-rw-r--r--Biz/PodcastItLater/Admin.py23
-rw-r--r--Biz/PodcastItLater/Core.py34
-rwxr-xr-x[-rw-r--r--]Biz/PodcastItLater/Test.py0
-rw-r--r--Biz/PodcastItLater/UI.py193
-rw-r--r--Biz/PodcastItLater/Web.py331
-rwxr-xr-xBiz/Que/Host.hs73
6 files changed, 462 insertions, 192 deletions
diff --git a/Biz/PodcastItLater/Admin.py b/Biz/PodcastItLater/Admin.py
index 10a8e58..6faf7fb 100644
--- a/Biz/PodcastItLater/Admin.py
+++ b/Biz/PodcastItLater/Admin.py
@@ -795,7 +795,7 @@ def admin_queue_status(request: Request) -> AdminView | Response | html.div:
def retry_queue_item(request: Request, job_id: int) -> Response:
"""Retry a failed queue item."""
try:
- # Check if user owns this job
+ # Check if user owns this job or is admin
user_id = request.session.get("user_id")
if not user_id:
return Response("Unauthorized", status_code=401)
@@ -803,15 +803,30 @@ def retry_queue_item(request: Request, job_id: int) -> Response:
job = Core.Database.get_job_by_id(
job_id,
)
- if job is None or job.get("user_id") != user_id:
+ if job is None:
+ return Response("Job not found", status_code=404)
+
+ # Check ownership or admin status
+ user = Core.Database.get_user_by_id(user_id)
+ if job.get("user_id") != user_id and not Core.is_admin(user):
return Response("Forbidden", status_code=403)
Core.Database.retry_job(job_id)
- # Redirect back to admin view
+
+ # Check if request is from admin page via referer header
+ is_from_admin = "/admin" in request.headers.get("referer", "")
+
+ # Redirect to admin if from admin page, trigger update otherwise
+ if is_from_admin:
+ return Response(
+ "",
+ status_code=200,
+ headers={"HX-Redirect": "/admin"},
+ )
return Response(
"",
status_code=200,
- headers={"HX-Redirect": "/admin"},
+ headers={"HX-Trigger": "queue-updated"},
)
except (ValueError, KeyError) as e:
return Response(
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py
index 8d31956..a2c1b3b 100644
--- a/Biz/PodcastItLater/Core.py
+++ b/Biz/PodcastItLater/Core.py
@@ -1785,6 +1785,40 @@ class TestQueueOperations(Test.TestCase):
self.assertEqual(counts.get("processing", 0), 1)
self.assertEqual(counts.get("error", 0), 1)
+ def test_queue_position(self) -> None:
+ """Verify queue position calculation."""
+ # Add multiple pending jobs
+ job1 = Database.add_to_queue(
+ "https://example.com/1",
+ "test@example.com",
+ self.user_id,
+ )
+ time.sleep(0.01)
+ job2 = Database.add_to_queue(
+ "https://example.com/2",
+ "test@example.com",
+ self.user_id,
+ )
+ time.sleep(0.01)
+ job3 = Database.add_to_queue(
+ "https://example.com/3",
+ "test@example.com",
+ self.user_id,
+ )
+
+ # Check positions
+ self.assertEqual(Database.get_queue_position(job1), 1)
+ self.assertEqual(Database.get_queue_position(job2), 2)
+ self.assertEqual(Database.get_queue_position(job3), 3)
+
+ # Move job 2 to processing
+ Database.update_job_status(job2, "processing")
+
+ # Check positions (job 3 should now be 2nd pending job)
+ self.assertEqual(Database.get_queue_position(job1), 1)
+ self.assertIsNone(Database.get_queue_position(job2))
+ self.assertEqual(Database.get_queue_position(job3), 2)
+
class TestEpisodeManagement(Test.TestCase):
"""Test episode management functionality."""
diff --git a/Biz/PodcastItLater/Test.py b/Biz/PodcastItLater/Test.py
index b2a1d24..b2a1d24 100644..100755
--- a/Biz/PodcastItLater/Test.py
+++ b/Biz/PodcastItLater/Test.py
diff --git a/Biz/PodcastItLater/UI.py b/Biz/PodcastItLater/UI.py
index 27f5fff..bdf7a5b 100644
--- a/Biz/PodcastItLater/UI.py
+++ b/Biz/PodcastItLater/UI.py
@@ -6,6 +6,7 @@ Common UI components and utilities shared across web pages.
# : out podcastitlater-ui
# : dep ludic
+import Biz.PodcastItLater.Core as Core
import ludic.html as html
import typing
from ludic.attrs import Attrs
@@ -127,15 +128,6 @@ def create_bootstrap_js() -> html.script:
)
-def is_admin(user: dict[str, typing.Any] | None) -> bool:
- """Check if user is an admin based on email whitelist."""
- if not user:
- return False
- admin_emails = ["ben@bensima.com", "admin@example.com"]
- return user.get("email", "").lower() in [
- email.lower() for email in admin_emails
- ]
-
class PageLayoutAttrs(Attrs):
"""Attributes for PageLayout component."""
@@ -318,7 +310,7 @@ class PageLayout(Component[AnyChildren, PageLayoutAttrs]):
),
classes=["nav-item", "dropdown"],
)
- if user and is_admin(user)
+ if user and Core.is_admin(user)
else html.span(),
classes=["navbar-nav"],
),
@@ -407,6 +399,173 @@ class PageLayout(Component[AnyChildren, PageLayoutAttrs]):
)
+class AccountPageAttrs(Attrs):
+ """Attributes for AccountPage component."""
+
+ user: dict[str, typing.Any]
+ usage: dict[str, int]
+ limits: dict[str, int | None]
+ portal_url: str | None
+
+
+class AccountPage(Component[AnyChildren, AccountPageAttrs]):
+ """Account management page component."""
+
+ @override
+ def render(self) -> PageLayout:
+ user = self.attrs["user"]
+ usage = self.attrs["usage"]
+ limits = self.attrs["limits"]
+ portal_url = self.attrs["portal_url"]
+
+ plan_tier = user.get("plan_tier", "free")
+ is_paid = plan_tier == "paid"
+
+ article_limit = limits.get("articles_per_period")
+ article_usage = usage.get("articles", 0)
+
+ limit_text = "Unlimited" if article_limit is None else str(article_limit)
+
+ return PageLayout(
+ html.div(
+ html.div(
+ html.div(
+ html.div(
+ html.div(
+ html.h2(
+ html.i(
+ classes=[
+ "bi",
+ "bi-person-circle",
+ "me-2"
+ ]
+ ),
+ "My Account",
+ classes=["card-title", "mb-4"],
+ ),
+ # User Info Section
+ html.div(
+ html.h5("Profile", classes=["mb-3"]),
+ html.p(
+ html.strong("Email: "),
+ user.get("email", ""),
+ classes=["mb-2"],
+ ),
+ html.p(
+ html.strong("Member since: "),
+ user.get("created_at", "").split("T")[0],
+ classes=["mb-4"],
+ ),
+ classes=["mb-5"],
+ ),
+ # Subscription Section
+ html.div(
+ html.h5("Subscription", classes=["mb-3"]),
+ html.div(
+ html.div(
+ html.strong("Current Plan"),
+ html.span(
+ plan_tier.title(),
+ classes=[
+ "badge",
+ "bg-success"
+ if is_paid
+ else "bg-secondary",
+ "ms-2",
+ ],
+ ),
+ classes=[
+ "d-flex",
+ "align-items-center",
+ "mb-3",
+ ],
+ ),
+ # Usage Stats
+ html.div(
+ html.p(
+ "Usage this period:",
+ classes=["mb-2", "text-muted"],
+ ),
+ html.div(
+ html.div(
+ f"{article_usage} / {limit_text}",
+ classes=["mb-1"],
+ ),
+ html.div(
+ html.div(
+ classes=["progress-bar"],
+ role="progressbar",
+ style={
+ "width": f"{min(100, (article_usage / article_limit * 100))}%"
+ } if article_limit else {"width": "0%"},
+ ),
+ classes=["progress", "mb-3"],
+ style={"height": "10px"},
+ ) if article_limit else html.div(),
+ classes=["mb-3"],
+ ),
+ ),
+ # Actions
+ html.div(
+ html.form(
+ html.button(
+ html.i(classes=["bi", "bi-credit-card", "me-2"]),
+ "Manage Subscription",
+ type="submit",
+ classes=["btn", "btn-outline-primary"],
+ ),
+ method="post",
+ action=portal_url,
+ ) if is_paid and portal_url else
+ html.a(
+ html.i(classes=["bi", "bi-star-fill", "me-2"]),
+ "Upgrade to Pro",
+ href="/pricing",
+ classes=["btn", "btn-primary"],
+ ),
+ classes=["d-flex", "gap-2"],
+ ),
+ classes=["card", "card-body", "bg-light"],
+ ),
+ classes=["mb-5"],
+ ),
+ # Logout Section
+ html.div(
+ html.form(
+ html.button(
+ html.i(
+ classes=[
+ "bi",
+ "bi-box-arrow-right",
+ "me-2"
+ ]
+ ),
+ "Log Out",
+ type="submit",
+ classes=["btn", "btn-outline-danger"],
+ ),
+ action="/logout",
+ method="post",
+ ),
+ classes=["border-top", "pt-4"],
+ ),
+ classes=["card-body", "p-4"],
+ ),
+ classes=["card", "shadow-sm"],
+ ),
+ classes=["col-lg-8", "mx-auto"],
+ ),
+ classes=["row"],
+ ),
+ ),
+ user=user,
+ current_page="account",
+ page_title="Account - PodcastItLater",
+ error=None,
+ meta_tags=[],
+ )
+
+
class PricingPageAttrs(Attrs):
"""Attributes for PricingPage component."""
@@ -422,12 +581,7 @@ class PricingPage(Component[AnyChildren, PricingPageAttrs]):
current_tier = user.get("plan_tier", "free") if user else "free"
return PageLayout(
- user=user,
- current_page="pricing",
- page_title="Pricing - PodcastItLater",
- error=None,
- meta_tags=[],
- children=[
+ html.div(
html.div(
html.h2("Simple Pricing", classes=["text-center", "mb-5"]),
html.div(
@@ -547,5 +701,10 @@ class PricingPage(Component[AnyChildren, PricingPageAttrs]):
),
classes=["container", "py-3"],
),
- ],
+ ),
+ user=user,
+ current_page="pricing",
+ page_title="Pricing - PodcastItLater",
+ error=None,
+ meta_tags=[],
)
diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py
index 7e8e969..30ebf53 100644
--- a/Biz/PodcastItLater/Web.py
+++ b/Biz/PodcastItLater/Web.py
@@ -36,6 +36,7 @@ import re
import sys
import tempfile
import typing
+from unittest.mock import patch
import urllib.parse
import uvicorn
from datetime import datetime
@@ -378,6 +379,11 @@ class QueueStatus(Component[AnyChildren, QueueStatusAttrs]):
badge_class = status_classes.get(item["status"], "bg-secondary")
icon_class = status_icons.get(item["status"], "bi-question-circle")
+ # Get queue position for pending items
+ queue_pos = None
+ if item["status"] == "pending":
+ queue_pos = Core.Database.get_queue_position(item["id"])
+
queue_items.append(
html.div(
html.div(
@@ -429,6 +435,14 @@ class QueueStatus(Component[AnyChildren, QueueStatusAttrs]):
f"Created: {item['created_at']}",
classes=["text-muted", "d-block", "mt-1"],
),
+ # Display queue position if available
+ html.small(
+ html.i(classes=["bi", "bi-hourglass-split", "me-1"]),
+ f"Position in queue: #{queue_pos}",
+ classes=["text-info", "d-block", "mt-1"],
+ )
+ if queue_pos
+ else html.span(),
*(
[
html.div(
@@ -456,6 +470,27 @@ class QueueStatus(Component[AnyChildren, QueueStatusAttrs]):
),
# Add cancel button for pending jobs, remove for others
html.div(
+ # Retry button for error items
+ html.button(
+ html.i(classes=["bi", "bi-arrow-clockwise", "me-1"]),
+ "Retry",
+ hx_post=f"/queue/{item['id']}/retry",
+ hx_trigger="click",
+ hx_on=(
+ "htmx:afterRequest: "
+ "if(event.detail.successful) "
+ "htmx.trigger('body', 'queue-updated')"
+ ),
+ classes=[
+ "btn",
+ "btn-sm",
+ "btn-outline-primary",
+ "mt-2",
+ "me-2",
+ ],
+ )
+ if item["status"] == "error"
+ else html.span(),
html.button(
html.i(classes=["bi", "bi-x-lg", "me-1"]),
"Cancel",
@@ -1003,6 +1038,57 @@ def upgrade(request: Request) -> RedirectResponse:
return RedirectResponse(url="/pricing?error=checkout_failed")
+@app.get("/account")
+def account(request: Request) -> UI.AccountPage | RedirectResponse:
+ """Display account management page."""
+ user_id = request.session.get("user_id")
+ if not user_id:
+ return RedirectResponse(url="/?error=login_required")
+
+ user = Core.Database.get_user_by_id(user_id)
+ if not user:
+ request.session.clear()
+ return RedirectResponse(url="/?error=user_not_found")
+
+ # Get usage stats
+ period_start, period_end = Billing.get_period_boundaries(user)
+ usage = Billing.get_usage(user_id, period_start, period_end)
+
+ # Get limits
+ tier = user.get("plan_tier", "free")
+ limits = Billing.TIER_LIMITS.get(tier, Billing.TIER_LIMITS["free"])
+
+ return UI.AccountPage(
+ user=user,
+ usage=usage,
+ limits=limits,
+ portal_url="/billing/portal" if tier == "paid" else None,
+ )
+
+
+@app.post("/logout")
+def logout(request: Request) -> RedirectResponse:
+ """Log out user."""
+ request.session.clear()
+ return RedirectResponse(url="/", status_code=303)
+
+
+@app.post("/billing/portal")
+def billing_portal(request: Request) -> RedirectResponse:
+ """Redirect to Stripe billing portal."""
+ user_id = request.session.get("user_id")
+ if not user_id:
+ return RedirectResponse(url="/?error=login_required")
+
+ try:
+ portal_url = Billing.create_portal_session(user_id, BASE_URL)
+ return RedirectResponse(url=portal_url, status_code=303)
+ except ValueError as e:
+ logger.warning("Failed to create portal session: %s", e)
+ # If user has no customer ID (e.g. free tier), redirect to pricing
+ return RedirectResponse(url="/pricing")
+
+
def _handle_test_login(email: str, request: Request) -> Response:
"""Handle login in test mode."""
# Special handling for demo account
@@ -1148,7 +1234,7 @@ def verify_magic_link(request: Request) -> Response:
@app.get("/account")
-def account_page(request: Request) -> UI.PageLayout | RedirectResponse:
+def account_page(request: Request) -> UI.AccountPage | RedirectResponse:
"""Account management page."""
user_id = request.session.get("user_id")
if not user_id:
@@ -1158,165 +1244,19 @@ def account_page(request: Request) -> UI.PageLayout | RedirectResponse:
if not user:
return RedirectResponse(url="/?error=user_not_found")
- # Get subscription details
+ # Get usage stats
+ period_start, period_end = Billing.get_period_boundaries(user)
+ usage = Billing.get_usage(user["id"], period_start, period_end)
+
+ # Get limits
tier = user.get("plan_tier", "free")
- tier_info = Billing.get_tier_info(tier)
- subscription_status = user.get("subscription_status", "")
- cancel_at_period_end = user.get("cancel_at_period_end", 0) == 1
-
- return UI.PageLayout(
- html.h2(
- html.i(
- classes=["bi", "bi-person-circle", "me-2"],
- ),
- "Account Management",
- classes=["mb-4"],
- ),
- html.div(
- html.h4(
- html.i(classes=["bi", "bi-envelope-fill", "me-2"]),
- "Account Information",
- classes=["card-header", "bg-transparent"],
- ),
- html.div(
- html.div(
- html.strong("Email: "),
- user["email"],
- classes=["mb-2"],
- ),
- html.div(
- html.strong("Account Created: "),
- user["created_at"],
- classes=["mb-2"],
- ),
- classes=["card-body"],
- ),
- classes=["card", "mb-4"],
- ),
- html.div(
- html.h4(
- html.i(
- classes=["bi", "bi-credit-card-fill", "me-2"],
- ),
- "Subscription",
- classes=["card-header", "bg-transparent"],
- ),
- html.div(
- html.div(
- html.strong("Plan: "),
- tier_info["name"],
- f" ({tier_info['price']})",
- classes=["mb-2"],
- ),
- html.div(
- html.strong("Status: "),
- subscription_status.title()
- if subscription_status
- else "Active",
- classes=["mb-2"],
- )
- if tier == "paid"
- else html.div(),
- html.div(
- html.i(
- classes=[
- "bi",
- "bi-info-circle",
- "me-1",
- ],
- ),
- "Your subscription will cancel at the end "
- "of the billing period.",
- classes=[
- "alert",
- "alert-warning",
- "mt-2",
- "mb-2",
- ],
- )
- if cancel_at_period_end
- else html.div(),
- html.div(
- html.strong("Features: "),
- tier_info["description"],
- classes=["mb-3"],
- ),
- html.div(
- html.a(
- html.i(
- classes=[
- "bi",
- "bi-arrow-up-circle",
- "me-1",
- ],
- ),
- "Upgrade to Paid Plan",
- href="#",
- hx_post="/billing/checkout",
- hx_vals='{"tier": "paid"}',
- classes=[
- "btn",
- "btn-success",
- "me-2",
- ],
- )
- if tier == "free"
- else html.form(
- html.button(
- html.i(
- classes=[
- "bi",
- "bi-gear-fill",
- "me-1",
- ],
- ),
- "Manage Subscription",
- type="submit",
- classes=[
- "btn",
- "btn-primary",
- "me-2",
- ],
- ),
- method="post",
- action="/billing/portal",
- ),
- ),
- classes=["card-body"],
- ),
- classes=["card", "mb-4"],
- ),
- html.div(
- html.h4(
- html.i(classes=["bi", "bi-sliders", "me-2"]),
- "Actions",
- classes=["card-header", "bg-transparent"],
- ),
- html.div(
- html.a(
- html.i(
- classes=[
- "bi",
- "bi-box-arrow-right",
- "me-1",
- ],
- ),
- "Logout",
- href="/logout",
- classes=[
- "btn",
- "btn-outline-secondary",
- "mb-2",
- "me-2",
- ],
- ),
- classes=["card-body"],
- ),
- classes=["card", "mb-4"],
- ),
+ limits = Billing.TIER_LIMITS.get(tier, Billing.TIER_LIMITS["free"])
+
+ return UI.AccountPage(
user=user,
- current_page="account",
- error=None,
+ usage=usage,
+ limits=limits,
+ portal_url="/billing/portal" if tier == "paid" else None,
)
@@ -3164,6 +3104,80 @@ class TestUsageLimits(BaseWebTest):
self.assertEqual(usage["articles"], 20)
+class TestAccountPage(BaseWebTest):
+ """Test account page functionality."""
+
+ def setUp(self) -> None:
+ """Set up test with user."""
+ super().setUp()
+ self.user_id, _ = Core.Database.create_user(
+ "test@example.com",
+ status="active",
+ )
+ self.client.post("/login", data={"email": "test@example.com"})
+
+ def test_account_page_logged_in(self) -> None:
+ """Account page should render for logged-in users."""
+ # Create some usage to verify stats are shown
+ ep_id = Core.Database.create_episode(
+ title="Test Episode",
+ audio_url="https://example.com/audio.mp3",
+ duration=300,
+ content_length=1000,
+ user_id=self.user_id,
+ author="Test Author",
+ original_url="https://example.com/article",
+ original_url_hash=Core.hash_url("https://example.com/article"),
+ )
+ Core.Database.add_episode_to_user(self.user_id, ep_id)
+
+ response = self.client.get("/account")
+
+ self.assertEqual(response.status_code, 200)
+ self.assertIn("My Account", response.text)
+ self.assertIn("test@example.com", response.text)
+ self.assertIn("1 / 10", response.text) # Usage / Limit for free tier
+
+ def test_account_page_login_required(self) -> None:
+ """Should redirect to login if not logged in."""
+ self.client.post("/logout")
+ response = self.client.get("/account", follow_redirects=False)
+ self.assertEqual(response.status_code, 307)
+ self.assertEqual(response.headers["location"], "/?error=login_required")
+
+ def test_logout(self) -> None:
+ """Logout should clear session."""
+ response = self.client.post("/logout", follow_redirects=False)
+ self.assertEqual(response.status_code, 303)
+ self.assertEqual(response.headers["location"], "/")
+
+ # Verify session cleared
+ response = self.client.get("/account", follow_redirects=False)
+ self.assertEqual(response.status_code, 307)
+
+ def test_billing_portal_redirect(self) -> None:
+ """Billing portal should redirect to Stripe."""
+ # First set a customer ID
+ Core.Database.set_user_stripe_customer(self.user_id, "cus_test")
+
+ # Mock the create_portal_session method
+ with patch(
+ "Biz.PodcastItLater.Billing.create_portal_session",
+ ) as mock_portal:
+ mock_portal.return_value = "https://billing.stripe.com/test"
+
+ response = self.client.post(
+ "/billing/portal",
+ follow_redirects=False,
+ )
+
+ self.assertEqual(response.status_code, 303)
+ self.assertEqual(
+ response.headers["location"],
+ "https://billing.stripe.com/test",
+ )
+
+
def test() -> None:
"""Run all tests for the web module."""
Test.run(
@@ -3180,6 +3194,7 @@ def test() -> None:
TestEpisodeDeduplication,
TestMetricsTracking,
TestUsageLimits,
+ TestAccountPage,
],
)
diff --git a/Biz/Que/Host.hs b/Biz/Que/Host.hs
index 834ce0e..8d826b4 100755
--- a/Biz/Que/Host.hs
+++ b/Biz/Que/Host.hs
@@ -33,6 +33,7 @@ import qualified Control.Exception as Exception
import Data.HashMap.Lazy (HashMap)
import qualified Data.HashMap.Lazy as HashMap
import Network.HTTP.Media ((//), (/:))
+import Network.Socket (SockAddr (..))
import qualified Network.Wai.Handler.Warp as Warp
import qualified Omni.Cli as Cli
import qualified Omni.Log as Log
@@ -75,7 +76,30 @@ Usage:
|]
test :: Test.Tree
-test = Test.group "Biz.Que.Host" [Test.unit "id" <| 1 @=? (1 :: Integer)]
+test =
+ Test.group
+ "Biz.Que.Host"
+ [ Test.unit "id" <| 1 @=? (1 :: Integer),
+ Test.unit "putQue requires auth for '_'" <| do
+ st <- atomically <| STM.newTVar mempty
+ let cfg = Envy.defConfig
+ let handlers = paths cfg
+
+ -- Case 1: No auth, should fail
+ let nonLocalHost = SockAddrInet 0 0
+ let handler1 = putQue handlers nonLocalHost Nothing "_" "testq" "body"
+ res1 <- Servant.runHandler (runReaderT handler1 st)
+ case res1 of
+ Left err -> if errHTTPCode err == 401 then pure () else Test.assertFailure ("Expected 401, got " <> show err)
+ Right _ -> Test.assertFailure "Expected failure, got success"
+
+ -- Case 2: Correct auth, should succeed
+ let handler2 = putQue handlers nonLocalHost (Just "admin-key") "_" "testq" "body"
+ res2 <- Servant.runHandler (runReaderT handler2 st)
+ case res2 of
+ Left err -> Test.assertFailure (show err)
+ Right _ -> pure ()
+ ]
type App = ReaderT AppState Servant.Handler
@@ -125,23 +149,31 @@ data Paths path = Paths
:- Get '[JSON] NoContent,
dash ::
path
- :- "_"
+ :- RemoteHost
+ :> Header "Authorization" Text
+ :> "_"
:> "dash"
:> Get '[JSON] Ques,
getQue ::
path
- :- Capture "ns" Text
+ :- RemoteHost
+ :> Header "Authorization" Text
+ :> Capture "ns" Text
:> Capture "quename" Text
:> Get '[PlainText, HTML, OctetStream] Message,
getStream ::
path
- :- Capture "ns" Text
+ :- RemoteHost
+ :> Header "Authorization" Text
+ :> Capture "ns" Text
:> Capture "quename" Text
:> "stream"
:> StreamGet NoFraming OctetStream (SourceIO Message),
putQue ::
path
- :- Capture "ns" Text
+ :- RemoteHost
+ :> Header "Authorization" Text
+ :> Capture "ns" Text
:> Capture "quepath" Text
:> ReqBody '[PlainText, HTML, OctetStream] Text
:> Post '[PlainText, HTML, OctetStream] NoContent
@@ -149,15 +181,15 @@ data Paths path = Paths
deriving (Generic)
paths :: Config -> Paths (AsServerT App)
-paths _ =
- -- TODO revive authkey stuff
- -- - read Authorization header, compare with queSkey
- -- - Only allow my IP or localhost to publish to '_' namespace
+paths Config {..} =
Paths
{ home =
throwError <| err301 {errHeaders = [("Location", "/_/index")]},
- dash = gets,
- getQue = \ns qn -> do
+ dash = \rh mAuth -> do
+ checkAuth queSkey rh mAuth "_"
+ gets,
+ getQue = \rh mAuth ns qn -> do
+ checkAuth queSkey rh mAuth ns
guardNs ns ["pub", "_"]
modify <| upsertNamespace ns
q <- que ns qn
@@ -165,7 +197,8 @@ paths _ =
|> liftIO
+> Go.tap
|> liftIO,
- getStream = \ns qn -> do
+ getStream = \rh mAuth ns qn -> do
+ checkAuth queSkey rh mAuth ns
guardNs ns ["pub", "_"]
modify <| upsertNamespace ns
q <- que ns qn
@@ -174,7 +207,8 @@ paths _ =
+> Go.tap
|> Source.fromAction (const False) -- peek chan instead of False?
|> pure,
- putQue = \ns qp body -> do
+ putQue = \rh mAuth ns qp body -> do
+ checkAuth queSkey rh mAuth ns
guardNs ns ["pub", "_"]
modify <| upsertNamespace ns
q <- que ns qp
@@ -188,6 +222,19 @@ paths _ =
>> pure NoContent
}
+checkAuth :: Text -> SockAddr -> Maybe Text -> Text -> App ()
+checkAuth skey rh mAuth ns = do
+ let authorized = mAuth == Just skey
+ let isLocal = isLocalhost rh
+ when (ns == "_" && not (authorized || isLocal)) <| do
+ throwError err401 {errBody = "Authorized access only for '_' namespace"}
+
+isLocalhost :: SockAddr -> Bool
+isLocalhost (SockAddrInet _ h) = h == 0x0100007f -- 127.0.0.1
+isLocalhost (SockAddrInet6 _ _ (0, 0, 0, 1) _) = True -- ::1
+isLocalhost (SockAddrUnix _) = True
+isLocalhost _ = False
+
-- | Given `guardNs ns whitelist`, if `ns` is not in the `whitelist`
-- list, return a 405 error.
guardNs :: (Applicative a, MonadError ServerError a) => Text -> [Text] -> a ()