diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-19 10:51:13 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-19 10:51:13 -0500 |
| commit | a960a0de7abc50abec51262de8a5871048817f1f (patch) | |
| tree | 973f4fbc8d75031bbb01ed1aec62a23c85fa2919 /Omni/Agent | |
| parent | 533e4209192298de4808c58f6ea6244e4bed5768 (diff) | |
Add semantic search for chat history
- Add chat_history table with embeddings in memory.db
- Add saveChatHistoryEntry for live message ingestion
- Add searchChatHistorySemantic for vector similarity search
- Update search_chat_history tool to use semantic search
- Add backfill command: run.sh Omni/Agent/Memory.hs backfill
- Add stats command: run.sh Omni/Agent/Memory.hs stats
- Change default memory.db path to ~/memory.db
- Wire Telegram message handling to save to chat_history async
Diffstat (limited to 'Omni/Agent')
| -rw-r--r-- | Omni/Agent/Memory.hs | 197 | ||||
| -rw-r--r-- | Omni/Agent/Telegram.hs | 11 | ||||
| -rw-r--r-- | Omni/Agent/Tools/AvaLogs.hs | 65 |
3 files changed, 245 insertions, 28 deletions
diff --git a/Omni/Agent/Memory.hs b/Omni/Agent/Memory.hs index 4aaa438..d59104c 100644 --- a/Omni/Agent/Memory.hs +++ b/Omni/Agent/Memory.hs @@ -31,6 +31,7 @@ module Omni.Agent.Memory MessageRole (..), RelationType (..), MemoryLink (..), + ChatHistoryEntry (..), -- * User Management createUser, @@ -70,6 +71,12 @@ module Omni.Agent.Memory -- * Embeddings embedText, + -- * Chat History (Semantic Search) + saveChatHistoryEntry, + searchChatHistorySemantic, + backfillChatHistory, + getChatHistoryStats, + -- * Agent Integration rememberTool, recallTool, @@ -112,10 +119,26 @@ import qualified Omni.Agent.Engine as Engine import qualified Omni.Test as Test import System.Directory (createDirectoryIfMissing) import System.Environment (lookupEnv) +import qualified System.Environment import System.FilePath (takeDirectory, (</>)) main :: IO () -main = Test.run test +main = do + args <- getArgs' + case args of + ["backfill"] -> do + putText "Running chat history backfill..." + count <- backfillChatHistory + putText <| "Done! Processed " <> tshow count <> " messages." + ["stats"] -> do + (total, indexed) <- getChatHistoryStats + putText <| "Chat history stats:" + putText <| " Total entries: " <> tshow total + putText <| " With embeddings: " <> tshow indexed + _ -> Test.run test + +getArgs' :: IO [String] +getArgs' = System.Environment.getArgs test :: Test.Tree test = @@ -571,7 +594,54 @@ instance SQL.FromRow MemoryLink where linkCreatedAt = createdAt } +data ChatHistoryEntry = ChatHistoryEntry + { cheId :: Text, + cheChatId :: Int, + cheUserId :: Maybe Text, + cheRole :: Text, + cheSenderName :: Maybe Text, + cheContent :: Text, + cheEmbedding :: Maybe (VS.Vector Float), + cheCreatedAt :: UTCTime + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON ChatHistoryEntry where + toJSON e = + Aeson.object + [ "id" .= cheId e, + "chat_id" .= cheChatId e, + "user_id" .= cheUserId e, + "role" .= cheRole e, + "sender_name" .= cheSenderName e, + "content" .= cheContent e, + "created_at" .= cheCreatedAt e + ] + +instance SQL.FromRow ChatHistoryEntry where + fromRow = do + entryId <- SQL.field + chatId <- SQL.field + usrId <- SQL.field + role <- SQL.field + senderName <- SQL.field + content <- SQL.field + embeddingBlob <- SQL.field + createdAt <- SQL.field + pure + ChatHistoryEntry + { cheId = entryId, + cheChatId = chatId, + cheUserId = usrId, + cheRole = role, + cheSenderName = senderName, + cheContent = content, + cheEmbedding = blobToVector </ embeddingBlob, + cheCreatedAt = createdAt + } + -- | Get the path to memory.db +-- Priority: MEMORY_DB_PATH env var > $HOME/memory.db > _/memory.db getMemoryDbPath :: IO FilePath getMemoryDbPath = do maybeEnv <- lookupEnv "MEMORY_DB_PATH" @@ -580,7 +650,7 @@ getMemoryDbPath = do Nothing -> do home <- lookupEnv "HOME" case home of - Just h -> pure (h </> ".local/share/omni/memory.db") + Just h -> pure (h </> "memory.db") Nothing -> pure "_/memory.db" -- | Run an action with the memory database connection. @@ -707,6 +777,24 @@ initMemoryDb conn = do SQL.execute_ conn "CREATE INDEX IF NOT EXISTS idx_memory_links_type ON memory_links(relation_type)" + SQL.execute_ + conn + "CREATE TABLE IF NOT EXISTS chat_history (\ + \ id TEXT PRIMARY KEY,\ + \ chat_id INTEGER NOT NULL,\ + \ user_id TEXT,\ + \ role TEXT NOT NULL,\ + \ sender_name TEXT,\ + \ content TEXT NOT NULL,\ + \ embedding BLOB,\ + \ created_at TIMESTAMP NOT NULL\ + \)" + SQL.execute_ + conn + "CREATE INDEX IF NOT EXISTS idx_chat_history_chat ON chat_history(chat_id)" + SQL.execute_ + conn + "CREATE INDEX IF NOT EXISTS idx_chat_history_time ON chat_history(created_at)" -- | Migrate conversation_messages to add sender_name and thread_id columns. migrateConversationMessages :: SQL.Connection -> IO () @@ -1573,3 +1661,108 @@ storeGroupMemory chatId = storeMemory (groupUserId chatId) -- | Recall memories for a group. recallGroupMemories :: Int -> Text -> Int -> IO [Memory] recallGroupMemories chatId = recallMemories (groupUserId chatId) + +-- | Save a chat history entry with embedding for semantic search. +saveChatHistoryEntry :: Int -> Maybe Text -> Text -> Maybe Text -> Text -> UTCTime -> IO () +saveChatHistoryEntry chatId usrId role senderName content timestamp = do + uuid <- UUID.nextRandom + let entryId = UUID.toText uuid + embedding <- embedText content + withMemoryDb <| \conn -> + SQL.execute + conn + "INSERT OR IGNORE INTO chat_history (id, chat_id, user_id, role, sender_name, content, embedding, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + ( entryId, + chatId, + usrId, + role, + senderName, + content, + vectorToBlob </ either (const Nothing) Just embedding, + timestamp + ) + +-- | Semantic search over chat history. +searchChatHistorySemantic :: Text -> Int -> IO [(ChatHistoryEntry, Float)] +searchChatHistorySemantic query limit = do + queryEmbedding <- embedText query + case queryEmbedding of + Left _ -> pure [] + Right qEmb -> do + allEntries <- + withMemoryDb <| \conn -> + SQL.query_ + conn + "SELECT id, chat_id, user_id, role, sender_name, content, embedding, created_at \ + \FROM chat_history WHERE embedding IS NOT NULL ORDER BY created_at DESC LIMIT 10000" + let scored = + [ (e, cosineSimilarity qEmb emb) + | e <- allEntries, + Just emb <- [cheEmbedding e] + ] + sorted = List.sortBy (\(_, s1) (_, s2) -> compare s2 s1) scored + pure (take limit sorted) + +-- | Get chat history stats. +getChatHistoryStats :: IO (Int, Int) +getChatHistoryStats = + withMemoryDb <| \conn -> do + [[total]] <- SQL.query_ conn "SELECT COUNT(*) FROM chat_history" + [[withEmb]] <- SQL.query_ conn "SELECT COUNT(*) FROM chat_history WHERE embedding IS NOT NULL" + pure (total, withEmb) + +-- | Backfill chat history from conversation_messages table. +-- Creates embeddings for existing messages that don't have them yet. +backfillChatHistory :: IO Int +backfillChatHistory = do + putText "Starting chat history backfill from conversation_messages..." + messages <- getAllConversationMessages + putText <| "Found " <> tshow (length messages) <> " messages to process" + processed <- backfillMessages messages 0 + putText <| "Backfill complete: " <> tshow processed <> " messages processed" + pure processed + +-- | Get all conversation messages for backfill. +getAllConversationMessages :: IO [ConversationMessage] +getAllConversationMessages = + withMemoryDb <| \conn -> + SQL.query_ + conn + "SELECT id, user_id, chat_id, role, sender_name, content, tokens_estimate, created_at \ + \FROM conversation_messages ORDER BY created_at ASC" + +-- | Backfill messages one by one, showing progress. +backfillMessages :: [ConversationMessage] -> Int -> IO Int +backfillMessages [] count = pure count +backfillMessages (msg : rest) count = do + let chatId = cmChatId msg + uid = cmUserId msg + role = case cmRole msg of + UserRole -> "user" + AssistantRole -> "assistant" + senderName = cmSenderName msg + content = cmContent msg + timestamp = cmCreatedAt msg + alreadyExists <- checkChatHistoryExists chatId timestamp content + if alreadyExists + then backfillMessages rest count + else do + when (count `mod` 100 == 0) + <| putText + <| "Progress: " + <> tshow count + <> " messages processed..." + saveChatHistoryEntry chatId (Just uid) role senderName content timestamp + backfillMessages rest (count + 1) + +-- | Check if a chat history entry already exists. +checkChatHistoryExists :: Int -> UTCTime -> Text -> IO Bool +checkChatHistoryExists chatId timestamp content = + withMemoryDb <| \conn -> do + results <- + SQL.query + conn + "SELECT 1 FROM chat_history WHERE chat_id = ? AND created_at = ? AND content = ? LIMIT 1" + (chatId, timestamp, content) :: + IO [[Int]] + pure (not (null results)) diff --git a/Omni/Agent/Telegram.hs b/Omni/Agent/Telegram.hs index e59570a..de84e49 100644 --- a/Omni/Agent/Telegram.hs +++ b/Omni/Agent/Telegram.hs @@ -867,6 +867,9 @@ handleAuthorizedMessageContinued tgConfig provider engineCfg msg uid userName ch Memory.getConversationContext uid chatId maxConversationTokens putText <| "Conversation context: " <> tshow contextTokens <> " tokens" + now <- getCurrentTime + _ <- forkIO <| Memory.saveChatHistoryEntry chatId (Just uid) "user" (Just userName) userMessage now + processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMessage conversationContext handleAuthorizedMessageBatch :: @@ -999,6 +1002,9 @@ handleAuthorizedMessageBatch tgConfig provider engineCfg msg uid userName chatId Memory.getConversationContext uid chatId maxConversationTokens putText <| "Conversation context: " <> tshow contextTokens <> " tokens" + now <- getCurrentTime + _ <- forkIO <| Memory.saveChatHistoryEntry chatId (Just uid) "user" (Just userName) userMessage now + processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMessage conversationContext processEngagedMessage :: @@ -1212,6 +1218,11 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe then void <| Memory.saveGroupMessage chatId threadId Memory.AssistantRole "Ava" response else void <| Memory.saveMessage uid chatId Memory.AssistantRole Nothing response + unless (Text.null response) <| do + nowResp <- getCurrentTime + _ <- forkIO <| Memory.saveChatHistoryEntry chatId (Just uid) "assistant" (Just "Ava") response nowResp + pure () + if Text.null response then do if isGroup diff --git a/Omni/Agent/Tools/AvaLogs.hs b/Omni/Agent/Tools/AvaLogs.hs index 84c9db8..4c8ce11 100644 --- a/Omni/Agent/Tools/AvaLogs.hs +++ b/Omni/Agent/Tools/AvaLogs.hs @@ -24,6 +24,7 @@ import qualified Data.Text as Text import qualified Data.Time as Time import qualified Omni.Agent.AuditLog as AuditLog import qualified Omni.Agent.Engine as Engine +import qualified Omni.Agent.Memory as Memory main :: IO () main = putText "Omni.Agent.Tools.AvaLogs - no standalone execution" @@ -114,9 +115,9 @@ searchChatHistoryTool = Engine.Tool { Engine.toolName = "search_chat_history", Engine.toolDescription = - "Search your conversation history for specific content. " + "Search your conversation history using semantic similarity. " <> "Use this to find what was said in past conversations, recall context, " - <> "or find when something was discussed. Searches message content.", + <> "or find when something was discussed. Finds semantically related messages.", Engine.toolJsonSchema = Aeson.object [ "type" .= ("object" :: Text), @@ -125,12 +126,7 @@ searchChatHistoryTool = [ "query" .= Aeson.object [ "type" .= ("string" :: Text), - "description" .= ("Text to search for in chat history" :: Text) - ], - "days_back" - .= Aeson.object - [ "type" .= ("integer" :: Text), - "description" .= ("How many days back to search (default: 7)" :: Text) + "description" .= ("What to search for (semantic search)" :: Text) ], "max_results" .= Aeson.object @@ -151,12 +147,6 @@ executeSearchHistory v = do _ -> "" _ -> "" - let daysBack = case v of - Aeson.Object obj -> case KeyMap.lookup "days_back" obj of - Just (Aeson.Number n) -> round n - _ -> 7 - _ -> 7 - let maxResults = case v of Aeson.Object obj -> case KeyMap.lookup "max_results" obj of Just (Aeson.Number n) -> round n @@ -166,18 +156,41 @@ executeSearchHistory v = do if Text.null query then pure <| Aeson.object ["error" .= ("query is required" :: Text)] else do - today <- Time.utctDay </ Time.getCurrentTime - let days = [Time.addDays (negate i) today | i <- [0 .. daysBack - 1]] - allEntries <- concat </ traverse AuditLog.readAvaLogs days - let matches = filter (matchesQuery query) allEntries - limited = take maxResults matches - pure - <| Aeson.object - [ "query" .= query, - "days_searched" .= daysBack, - "total_matches" .= length matches, - "results" .= map formatSearchResult limited - ] + (_total, indexed) <- Memory.getChatHistoryStats + if indexed > 0 + then do + results <- Memory.searchChatHistorySemantic query maxResults + pure + <| Aeson.object + [ "query" .= query, + "search_type" .= ("semantic" :: Text), + "indexed_messages" .= indexed, + "results" .= map formatSemanticResult results + ] + else do + today <- Time.utctDay </ Time.getCurrentTime + let days = [Time.addDays (negate i) today | i <- [0 .. 6]] + allEntries <- concat </ traverse AuditLog.readAvaLogs days + let matches = filter (matchesQuery query) allEntries + limited = take maxResults matches + pure + <| Aeson.object + [ "query" .= query, + "search_type" .= ("keyword_fallback" :: Text), + "note" .= ("no indexed history - run backfill" :: Text), + "total_matches" .= length matches, + "results" .= map formatSearchResult limited + ] + +formatSemanticResult :: (Memory.ChatHistoryEntry, Float) -> Aeson.Value +formatSemanticResult (entry, score) = + Aeson.object + [ "timestamp" .= Time.formatTime Time.defaultTimeLocale "%Y-%m-%d %H:%M:%S" (Memory.cheCreatedAt entry), + "role" .= Memory.cheRole entry, + "sender" .= Memory.cheSenderName entry, + "similarity" .= score, + "content" .= Text.take 500 (Memory.cheContent entry) + ] matchesQuery :: Text -> AuditLog.AuditLogEntry -> Bool matchesQuery query entry = |
