diff options
Diffstat (limited to 'Biz/PodcastItLater/Core.py')
| -rw-r--r-- | Biz/PodcastItLater/Core.py | 222 |
1 files changed, 222 insertions, 0 deletions
diff --git a/Biz/PodcastItLater/Core.py b/Biz/PodcastItLater/Core.py index b7b6ed9..42b336e 100644 --- a/Biz/PodcastItLater/Core.py +++ b/Biz/PodcastItLater/Core.py @@ -141,6 +141,12 @@ class Database: # noqa: PLR0904 # Run migration to add default titles Database.migrate_add_default_titles() + # Run migration to add billing fields + Database.migrate_add_billing_fields() + + # Run migration to add stripe events table + Database.migrate_add_stripe_events_table() + @staticmethod def add_to_queue( url: str, @@ -554,6 +560,53 @@ class Database: # noqa: PLR0904 logger.info("Database migrated to support user status") @staticmethod + def migrate_add_billing_fields() -> None: + """Add billing-related fields to users table.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Add columns one by one (SQLite limitation) + columns_to_add = [ + ("plan_tier", "TEXT NOT NULL DEFAULT 'free'"), + ("stripe_customer_id", "TEXT UNIQUE"), + ("stripe_subscription_id", "TEXT UNIQUE"), + ("subscription_status", "TEXT"), + ("current_period_start", "TIMESTAMP"), + ("current_period_end", "TIMESTAMP"), + ("cancel_at_period_end", "INTEGER NOT NULL DEFAULT 0"), + ] + + for column_name, column_def in columns_to_add: + try: + query = f"ALTER TABLE users ADD COLUMN {column_name} " + cursor.execute(query + column_def) + logger.info("Added column users.%s", column_name) + except sqlite3.OperationalError: # noqa: PERF203 + pass # Column already exists + + conn.commit() + + @staticmethod + def migrate_add_stripe_events_table() -> None: + """Create stripe_events table for webhook idempotency.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS stripe_events ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + payload TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_stripe_events_created " + "ON stripe_events(created_at)", + ) + conn.commit() + logger.info("Created stripe_events table") + + @staticmethod def migrate_add_default_titles() -> None: """Add default titles to queue items that have None titles.""" with Database.get_connection() as conn: @@ -730,6 +783,175 @@ class Database: # noqa: PLR0904 conn.commit() logger.info("Updated user %s status to %s", user_id, status) + @staticmethod + def set_user_stripe_customer(user_id: int, customer_id: str) -> None: + """Link Stripe customer ID to user.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET stripe_customer_id = ? WHERE id = ?", + (customer_id, user_id), + ) + conn.commit() + logger.info( + "Linked user %s to Stripe customer %s", + user_id, + customer_id, + ) + + @staticmethod + def update_user_subscription( # noqa: PLR0913, PLR0917 + user_id: int, + subscription_id: str, + status: str, + period_start: Any, + period_end: Any, + tier: str, + cancel_at_period_end: bool, # noqa: FBT001 + ) -> None: + """Update user subscription details.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + stripe_subscription_id = ?, + subscription_status = ?, + current_period_start = ?, + current_period_end = ?, + plan_tier = ?, + cancel_at_period_end = ? + WHERE id = ? + """, + ( + subscription_id, + status, + period_start.isoformat(), + period_end.isoformat(), + tier, + 1 if cancel_at_period_end else 0, + user_id, + ), + ) + conn.commit() + logger.info( + "Updated user %s subscription: tier=%s, status=%s", + user_id, + tier, + status, + ) + + @staticmethod + def update_subscription_status(user_id: int, status: str) -> None: + """Update only the subscription status (e.g., past_due).""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE users SET subscription_status = ? WHERE id = ?", + (status, user_id), + ) + conn.commit() + logger.info( + "Updated user %s subscription status to %s", + user_id, + status, + ) + + @staticmethod + def downgrade_to_free(user_id: int) -> None: + """Downgrade user to free tier and clear subscription data.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE users SET + plan_tier = 'free', + subscription_status = 'canceled', + stripe_subscription_id = NULL, + current_period_start = NULL, + current_period_end = NULL, + cancel_at_period_end = 0 + WHERE id = ? + """, + (user_id,), + ) + conn.commit() + logger.info("Downgraded user %s to free tier", user_id) + + @staticmethod + def get_user_by_stripe_customer_id( + customer_id: str, + ) -> dict[str, Any] | None: + """Get user by Stripe customer ID.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM users WHERE stripe_customer_id = ?", + (customer_id,), + ) + row = cursor.fetchone() + return dict(row) if row is not None else None + + @staticmethod + def has_processed_stripe_event(event_id: str) -> bool: + """Check if Stripe event has already been processed.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT id FROM stripe_events WHERE id = ?", + (event_id,), + ) + return cursor.fetchone() is not None + + @staticmethod + def mark_stripe_event_processed( + event_id: str, + event_type: str, + payload: bytes, + ) -> None: + """Mark Stripe event as processed for idempotency.""" + with Database.get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT OR IGNORE INTO stripe_events (id, type, payload) " + "VALUES (?, ?, ?)", + (event_id, event_type, payload.decode("utf-8")), + ) + conn.commit() + + @staticmethod + def get_usage( + user_id: int, + period_start: Any, + period_end: Any, + ) -> dict[str, int]: + """Get usage stats for user in period. + + Returns: + dict with keys: articles (int), minutes (int) + """ + with Database.get_connection() as conn: + cursor = conn.cursor() + + # Count articles created in period + cursor.execute( + """ + SELECT COUNT(*) as count, SUM(duration) as total_seconds + FROM episodes + WHERE user_id = ? AND created_at >= ? AND created_at < ? + """, + (user_id, period_start.isoformat(), period_end.isoformat()), + ) + row = cursor.fetchone() + + articles = row["count"] if row else 0 + total_seconds = ( + row["total_seconds"] if row and row["total_seconds"] else 0 + ) + minutes = total_seconds // 60 + + return {"articles": articles, "minutes": minutes} + class TestDatabase(Test.TestCase): """Test the Database class.""" |
