diff options
Diffstat (limited to 'Biz')
| -rw-r--r-- | Biz/PodcastItLater/Core.py | 14 | ||||
| -rw-r--r-- | Biz/PodcastItLater/Web.py | 179 |
2 files changed, 189 insertions, 4 deletions
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py index f1f89e9..8d31956 100644 --- a/Biz/PodcastItLater/Core.py +++ b/Biz/PodcastItLater/Core.py @@ -1382,18 +1382,24 @@ class Database: # noqa: PLR0904 ) -> dict[str, int]: """Get usage stats for user in period. + Counts episodes added to user's feed (via user_episodes table) + during the billing period, regardless of who created them. + Returns: dict with keys: articles (int), minutes (int) """ with Database.get_connection() as conn: cursor = conn.cursor() - # Count articles created in period + # Count articles added to user's feed in period + # Uses user_episodes junction table to track when episodes + # were added, which correctly handles shared/existing episodes cursor.execute( """ - SELECT COUNT(*) as count, SUM(duration) as total_seconds - FROM episodes - WHERE user_id = ? AND created_at >= ? AND created_at < ? + SELECT COUNT(*) as count, SUM(e.duration) as total_seconds + FROM user_episodes ue + JOIN episodes e ON e.id = ue.episode_id + WHERE ue.user_id = ? AND ue.added_at >= ? AND ue.added_at < ? """, (user_id, period_start.isoformat(), period_end.isoformat()), ) diff --git a/Biz/PodcastItLater/Web.py b/Biz/PodcastItLater/Web.py index 7a9f63a..4a8e57e 100644 --- a/Biz/PodcastItLater/Web.py +++ b/Biz/PodcastItLater/Web.py @@ -2956,6 +2956,184 @@ class TestMetricsTracking(BaseWebTest): self.assertGreater(len(played_metrics), 0) +class TestUsageLimits(BaseWebTest): + """Test usage tracking and limit enforcement.""" + + def setUp(self) -> None: + """Set up test with free tier user.""" + super().setUp() + + # Create free tier user + self.user_id, self.token = Core.Database.create_user( + "free@example.com", + status="active", + ) + # Login + self.client.post("/login", data={"email": "free@example.com"}) + + def test_usage_counts_episodes_added_to_feed(self) -> None: + """Usage should count episodes added via user_episodes table.""" + user = Core.Database.get_user_by_id(self.user_id) + self.assertIsNotNone(user) + assert user is not None # type narrowing # noqa: S101 + period_start, period_end = Billing.get_period_boundaries(user) + + # Initially no usage + usage = Billing.get_usage(self.user_id, period_start, period_end) + self.assertEqual(usage["articles"], 0) + + # Add an episode to user's feed + ep_id = Core.Database.create_episode( + title="Test Episode", + audio_url="https://example.com/test.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test", + original_url="https://example.com/article", + original_url_hash=Core.hash_url("https://example.com/article"), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Usage should now be 1 + usage = Billing.get_usage(self.user_id, period_start, period_end) + self.assertEqual(usage["articles"], 1) + + def test_usage_counts_existing_episodes_correctly(self) -> None: + """Adding existing episodes should count toward usage.""" + # Create another user who creates an episode + other_user_id, _ = Core.Database.create_user("other@example.com") + ep_id = Core.Database.create_episode( + title="Other User Episode", + audio_url="https://example.com/other.mp3", + duration=400, + content_length=1200, + user_id=other_user_id, + author="Other", + original_url="https://example.com/other-article", + original_url_hash=Core.hash_url( + "https://example.com/other-article", + ), + ) + Core.Database.add_episode_to_user(other_user_id, ep_id) + + # Free user adds it to their feed + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Check usage for free user + user = Core.Database.get_user_by_id(self.user_id) + self.assertIsNotNone(user) + assert user is not None # type narrowing # noqa: S101 + period_start, period_end = Billing.get_period_boundaries(user) + usage = Billing.get_usage(self.user_id, period_start, period_end) + + # Should count as 1 article for free user + self.assertEqual(usage["articles"], 1) + + def test_free_tier_limit_enforcement(self) -> None: + """Free tier users should be blocked at 10 articles.""" + # Add 10 episodes (the free tier limit) + for i in range(10): + ep_id = Core.Database.create_episode( + title=f"Episode {i}", + audio_url=f"https://example.com/ep{i}.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test", + original_url=f"https://example.com/article{i}", + original_url_hash=Core.hash_url( + f"https://example.com/article{i}", + ), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Try to submit 11th article + response = self.client.post( + "/submit", + data={"url": "https://example.com/article11"}, + ) + + # Should be blocked + self.assertEqual(response.status_code, 200) + self.assertIn("Limit reached", response.text) + self.assertIn("10", response.text) + self.assertIn("Upgrade", response.text) + + def test_can_submit_blocks_at_limit(self) -> None: + """can_submit should return False at limit.""" + # Add 10 episodes + for i in range(10): + ep_id = Core.Database.create_episode( + title=f"Episode {i}", + audio_url=f"https://example.com/ep{i}.mp3", + duration=300, + content_length=1000, + user_id=self.user_id, + author="Test", + original_url=f"https://example.com/article{i}", + original_url_hash=Core.hash_url( + f"https://example.com/article{i}", + ), + ) + Core.Database.add_episode_to_user(self.user_id, ep_id) + + # Check can_submit + allowed, msg, usage = Billing.can_submit(self.user_id) + + self.assertFalse(allowed) + self.assertIn("10", msg) + self.assertIn("limit", msg.lower()) + self.assertEqual(usage["articles"], 10) + + def test_paid_tier_unlimited(self) -> None: + """Paid tier should have no article limits.""" + # Create a paid tier user directly + paid_user_id, _ = Core.Database.create_user("paid@example.com") + + # Simulate paid subscription via update_user_subscription + now = datetime.now(timezone.utc) + period_start = now + december = 12 + january = 1 + period_end = now.replace( + month=now.month + 1 if now.month < december else january, + ) + + Core.Database.update_user_subscription( + paid_user_id, + subscription_id="sub_test123", + status="active", + period_start=period_start, + period_end=period_end, + tier="paid", + cancel_at_period_end=False, + ) + + # Add 20 episodes (more than free limit) + for i in range(20): + ep_id = Core.Database.create_episode( + title=f"Episode {i}", + audio_url=f"https://example.com/ep{i}.mp3", + duration=300, + content_length=1000, + user_id=paid_user_id, + author="Test", + original_url=f"https://example.com/article{i}", + original_url_hash=Core.hash_url( + f"https://example.com/article{i}", + ), + ) + Core.Database.add_episode_to_user(paid_user_id, ep_id) + + # Should still be allowed to submit + allowed, msg, usage = Billing.can_submit(paid_user_id) + + self.assertTrue(allowed) + self.assertEqual(msg, "") + self.assertEqual(usage["articles"], 20) + + def test() -> None: """Run all tests for the web module.""" Test.run( @@ -2971,6 +3149,7 @@ def test() -> None: TestPublicFeed, TestEpisodeDeduplication, TestMetricsTracking, + TestUsageLimits, ], ) |
