diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-12 21:52:57 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-12 21:52:57 -0500 |
| commit | 1b4dc94eb261e3f3cd22dc12fbc1941e2a545cb9 (patch) | |
| tree | edae72e4c59e25dcd15b898792bf0932e29ad0f4 | |
| parent | 862b10aa05ef66af5a88f307e6209ce10185bbcf (diff) | |
feat: add reminder service for todos
Adds a background reminder loop that checks every 5 minutes for overdue
todos and sends Telegram notifications.
Changes:
- Add last_reminded_at column to todos table with auto-migration
- Add listTodosDueForReminder to find overdue, unreminded todos
- Add markReminderSent to update reminder timestamp
- Add user_chats table to map user_id -> chat_id for notifications
- Add recordUserChat called on each message to track chat IDs
- Add reminderLoop forked in runTelegramBot
- 24-hour anti-spam interval between reminders per todo
| -rw-r--r-- | Omni/Agent/Telegram.hs | 75 | ||||
| -rw-r--r-- | Omni/Agent/Tools/Todos.hs | 67 |
2 files changed, 133 insertions, 9 deletions
diff --git a/Omni/Agent/Telegram.hs b/Omni/Agent/Telegram.hs index f1c71e6..27b3ccf 100644 --- a/Omni/Agent/Telegram.hs +++ b/Omni/Agent/Telegram.hs @@ -43,6 +43,12 @@ module Omni.Agent.Telegram checkOllama, pullEmbeddingModel, + -- * Reminders + reminderLoop, + checkAndSendReminders, + recordUserChat, + lookupChatId, + -- * System Prompt telegramSystemPrompt, @@ -62,6 +68,7 @@ import qualified Data.Text as Text import Data.Time (getCurrentTime, utcToLocalTime) import Data.Time.Format (defaultTimeLocale, formatTime) import Data.Time.LocalTime (getCurrentTimeZone) +import qualified Database.SQLite.Simple as SQL import qualified Network.HTTP.Client as HTTPClient import qualified Network.HTTP.Simple as HTTP import qualified Omni.Agent.Engine as Engine @@ -578,12 +585,78 @@ telegramSystemPrompt = "ALWAYS include a text response to the user after using tools. never end your turn with only tool calls." ] +initUserChatsTable :: SQL.Connection -> IO () +initUserChatsTable conn = + SQL.execute_ + conn + "CREATE TABLE IF NOT EXISTS user_chats (\ + \ user_id TEXT PRIMARY KEY,\ + \ chat_id INTEGER NOT NULL,\ + \ last_seen_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\ + \)" + +recordUserChat :: Text -> Int -> IO () +recordUserChat uid chatId = do + now <- getCurrentTime + Memory.withMemoryDb <| \conn -> do + initUserChatsTable conn + SQL.execute + conn + "INSERT INTO user_chats (user_id, chat_id, last_seen_at) \ + \VALUES (?, ?, ?) \ + \ON CONFLICT(user_id) DO UPDATE SET \ + \ chat_id = excluded.chat_id, \ + \ last_seen_at = excluded.last_seen_at" + (uid, chatId, now) + +lookupChatId :: Text -> IO (Maybe Int) +lookupChatId uid = + Memory.withMemoryDb <| \conn -> do + initUserChatsTable conn + rows <- + SQL.query + conn + "SELECT chat_id FROM user_chats WHERE user_id = ?" + (SQL.Only uid) + pure (listToMaybe (map SQL.fromOnly rows)) + +reminderLoop :: TelegramConfig -> IO () +reminderLoop tgConfig = + forever <| do + threadDelay (5 * 60 * 1000000) + checkAndSendReminders tgConfig + +checkAndSendReminders :: TelegramConfig -> IO () +checkAndSendReminders tgConfig = do + todos <- Todos.listTodosDueForReminder + forM_ todos <| \td -> do + mChatId <- lookupChatId (Todos.todoUserId td) + case mChatId of + Nothing -> pure () + Just chatId -> do + let title = Todos.todoTitle td + dueStr = case Todos.todoDueDate td of + Just d -> " (due: " <> tshow d <> ")" + Nothing -> "" + msg = + "⏰ reminder: \"" + <> title + <> "\"" + <> dueStr + <> "\nreply when you finish and i'll mark it complete." + sendMessage tgConfig chatId msg + Todos.markReminderSent (Todos.todoId td) + putText <| "Sent reminder for todo " <> tshow (Todos.todoId td) <> " to chat " <> tshow chatId + -- | Run the Telegram bot main loop. runTelegramBot :: TelegramConfig -> Provider.Provider -> IO () runTelegramBot tgConfig provider = do putText "Starting Telegram bot..." offsetVar <- newTVarIO 0 + _ <- forkIO (reminderLoop tgConfig) + putText "Reminder loop started (checking every 5 minutes)" + let engineCfg = Engine.defaultEngineConfig { Engine.engineOnToolCall = \toolName args -> @@ -639,6 +712,8 @@ handleAuthorizedMessage :: Int -> IO () handleAuthorizedMessage tgConfig provider engineCfg msg uid userName chatId = do + recordUserChat uid chatId + pdfContent <- case tmDocument msg of Just doc | isPdf doc -> do putText <| "Processing PDF: " <> fromMaybe "(unnamed)" (tdFileName doc) diff --git a/Omni/Agent/Tools/Todos.hs b/Omni/Agent/Tools/Todos.hs index 81253c1..4c7d2be 100644 --- a/Omni/Agent/Tools/Todos.hs +++ b/Omni/Agent/Tools/Todos.hs @@ -27,6 +27,11 @@ module Omni.Agent.Tools.Todos completeTodo, deleteTodo, + -- * Reminders + listTodosDueForReminder, + markReminderSent, + reminderInterval, + -- * Database initTodosTable, @@ -40,7 +45,7 @@ import Alpha import Data.Aeson ((.!=), (.:), (.:?), (.=)) import qualified Data.Aeson as Aeson import qualified Data.Text as Text -import Data.Time (UTCTime, getCurrentTime) +import Data.Time (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) import Data.Time.Format (defaultTimeLocale, parseTimeM) import qualified Database.SQLite.Simple as SQL import qualified Omni.Agent.Engine as Engine @@ -75,7 +80,8 @@ test = todoTitle = "Buy milk", todoDueDate = Just now, todoCompleted = False, - todoCreatedAt = now + todoCreatedAt = now, + todoLastRemindedAt = Nothing } case Aeson.decode (Aeson.encode td) of Nothing -> Test.assertFailure "Failed to decode Todo" @@ -93,7 +99,8 @@ data Todo = Todo todoTitle :: Text, todoDueDate :: Maybe UTCTime, todoCompleted :: Bool, - todoCreatedAt :: UTCTime + todoCreatedAt :: UTCTime, + todoLastRemindedAt :: Maybe UTCTime } deriving (Show, Eq, Generic) @@ -105,7 +112,8 @@ instance Aeson.ToJSON Todo where "title" .= todoTitle td, "due_date" .= todoDueDate td, "completed" .= todoCompleted td, - "created_at" .= todoCreatedAt td + "created_at" .= todoCreatedAt td, + "last_reminded_at" .= todoLastRemindedAt td ] instance Aeson.FromJSON Todo where @@ -117,6 +125,7 @@ instance Aeson.FromJSON Todo where <*> (v .:? "due_date") <*> (v .: "completed") <*> (v .: "created_at") + <*> (v .:? "last_reminded_at") instance SQL.FromRow Todo where fromRow = @@ -126,6 +135,7 @@ instance SQL.FromRow Todo where <*> SQL.field <*> SQL.field <*> SQL.field + <*> SQL.field initTodosTable :: SQL.Connection -> IO () initTodosTable conn = do @@ -137,7 +147,8 @@ initTodosTable conn = do \ title TEXT NOT NULL,\ \ due_date TIMESTAMP,\ \ completed INTEGER NOT NULL DEFAULT 0,\ - \ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\ + \ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\ + \ last_reminded_at TIMESTAMP\ \)" SQL.execute_ conn @@ -145,6 +156,14 @@ initTodosTable conn = do SQL.execute_ conn "CREATE INDEX IF NOT EXISTS idx_todos_due ON todos(user_id, due_date)" + migrateTodosTable conn + +migrateTodosTable :: SQL.Connection -> IO () +migrateTodosTable conn = do + cols <- SQL.query_ conn "PRAGMA table_info(todos)" :: IO [(Int, Text, Text, Int, Maybe Text, Int)] + let colNames = map (\(_, name, _, _, _, _) -> name) cols + unless ("last_reminded_at" `elem` colNames) <| do + SQL.execute_ conn "ALTER TABLE todos ADD COLUMN last_reminded_at TIMESTAMP" parseDueDate :: Text -> Maybe UTCTime parseDueDate txt = @@ -172,7 +191,8 @@ createTodo uid title maybeDueDateStr = do todoTitle = title, todoDueDate = dueDate, todoCompleted = False, - todoCreatedAt = now + todoCreatedAt = now, + todoLastRemindedAt = Nothing } listTodos :: Text -> Int -> IO [Todo] @@ -181,7 +201,7 @@ listTodos uid limit = initTodosTable conn SQL.query conn - "SELECT id, user_id, title, due_date, completed, created_at \ + "SELECT id, user_id, title, due_date, completed, created_at, last_reminded_at \ \FROM todos WHERE user_id = ? \ \ORDER BY completed ASC, due_date ASC NULLS LAST, created_at DESC LIMIT ?" (uid, limit) @@ -192,7 +212,7 @@ listPendingTodos uid limit = initTodosTable conn SQL.query conn - "SELECT id, user_id, title, due_date, completed, created_at \ + "SELECT id, user_id, title, due_date, completed, created_at, last_reminded_at \ \FROM todos WHERE user_id = ? AND completed = 0 \ \ORDER BY due_date ASC NULLS LAST, created_at DESC LIMIT ?" (uid, limit) @@ -204,11 +224,40 @@ listOverdueTodos uid = do initTodosTable conn SQL.query conn - "SELECT id, user_id, title, due_date, completed, created_at \ + "SELECT id, user_id, title, due_date, completed, created_at, last_reminded_at \ \FROM todos WHERE user_id = ? AND completed = 0 AND due_date < ? \ \ORDER BY due_date ASC" (uid, now) +reminderInterval :: NominalDiffTime +reminderInterval = 24 * 60 * 60 + +listTodosDueForReminder :: IO [Todo] +listTodosDueForReminder = do + now <- getCurrentTime + let cutoff = addUTCTime (negate reminderInterval) now + Memory.withMemoryDb <| \conn -> do + initTodosTable conn + SQL.query + conn + "SELECT id, user_id, title, due_date, completed, created_at, last_reminded_at \ + \FROM todos \ + \WHERE completed = 0 \ + \ AND due_date IS NOT NULL \ + \ AND due_date < ? \ + \ AND (last_reminded_at IS NULL OR last_reminded_at < ?)" + (now, cutoff) + +markReminderSent :: Int -> IO () +markReminderSent tid = do + now <- getCurrentTime + Memory.withMemoryDb <| \conn -> do + initTodosTable conn + SQL.execute + conn + "UPDATE todos SET last_reminded_at = ? WHERE id = ?" + (now, tid) + completeTodo :: Text -> Int -> IO Bool completeTodo uid tid = Memory.withMemoryDb <| \conn -> do |
