diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-19 10:51:13 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-19 10:51:13 -0500 |
| commit | a960a0de7abc50abec51262de8a5871048817f1f (patch) | |
| tree | 973f4fbc8d75031bbb01ed1aec62a23c85fa2919 /Omni/Agent/Memory.hs | |
| parent | 533e4209192298de4808c58f6ea6244e4bed5768 (diff) | |
Add semantic search for chat history
- Add chat_history table with embeddings in memory.db
- Add saveChatHistoryEntry for live message ingestion
- Add searchChatHistorySemantic for vector similarity search
- Update search_chat_history tool to use semantic search
- Add backfill command: run.sh Omni/Agent/Memory.hs backfill
- Add stats command: run.sh Omni/Agent/Memory.hs stats
- Change default memory.db path to ~/memory.db
- Wire Telegram message handling to save to chat_history async
Diffstat (limited to 'Omni/Agent/Memory.hs')
| -rw-r--r-- | Omni/Agent/Memory.hs | 197 |
1 files changed, 195 insertions, 2 deletions
diff --git a/Omni/Agent/Memory.hs b/Omni/Agent/Memory.hs index 4aaa438..d59104c 100644 --- a/Omni/Agent/Memory.hs +++ b/Omni/Agent/Memory.hs @@ -31,6 +31,7 @@ module Omni.Agent.Memory MessageRole (..), RelationType (..), MemoryLink (..), + ChatHistoryEntry (..), -- * User Management createUser, @@ -70,6 +71,12 @@ module Omni.Agent.Memory -- * Embeddings embedText, + -- * Chat History (Semantic Search) + saveChatHistoryEntry, + searchChatHistorySemantic, + backfillChatHistory, + getChatHistoryStats, + -- * Agent Integration rememberTool, recallTool, @@ -112,10 +119,26 @@ import qualified Omni.Agent.Engine as Engine import qualified Omni.Test as Test import System.Directory (createDirectoryIfMissing) import System.Environment (lookupEnv) +import qualified System.Environment import System.FilePath (takeDirectory, (</>)) main :: IO () -main = Test.run test +main = do + args <- getArgs' + case args of + ["backfill"] -> do + putText "Running chat history backfill..." + count <- backfillChatHistory + putText <| "Done! Processed " <> tshow count <> " messages." 
+ ["stats"] -> do + (total, indexed) <- getChatHistoryStats + putText <| "Chat history stats:" + putText <| " Total entries: " <> tshow total + putText <| " With embeddings: " <> tshow indexed + _ -> Test.run test + +getArgs' :: IO [String] +getArgs' = System.Environment.getArgs test :: Test.Tree test = @@ -571,7 +594,54 @@ instance SQL.FromRow MemoryLink where linkCreatedAt = createdAt } +data ChatHistoryEntry = ChatHistoryEntry + { cheId :: Text, + cheChatId :: Int, + cheUserId :: Maybe Text, + cheRole :: Text, + cheSenderName :: Maybe Text, + cheContent :: Text, + cheEmbedding :: Maybe (VS.Vector Float), + cheCreatedAt :: UTCTime + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON ChatHistoryEntry where + toJSON e = + Aeson.object + [ "id" .= cheId e, + "chat_id" .= cheChatId e, + "user_id" .= cheUserId e, + "role" .= cheRole e, + "sender_name" .= cheSenderName e, + "content" .= cheContent e, + "created_at" .= cheCreatedAt e + ] + +instance SQL.FromRow ChatHistoryEntry where + fromRow = do + entryId <- SQL.field + chatId <- SQL.field + usrId <- SQL.field + role <- SQL.field + senderName <- SQL.field + content <- SQL.field + embeddingBlob <- SQL.field + createdAt <- SQL.field + pure + ChatHistoryEntry + { cheId = entryId, + cheChatId = chatId, + cheUserId = usrId, + cheRole = role, + cheSenderName = senderName, + cheContent = content, + cheEmbedding = blobToVector </ embeddingBlob, + cheCreatedAt = createdAt + } + -- | Get the path to memory.db +-- Priority: MEMORY_DB_PATH env var > $HOME/memory.db > _/memory.db getMemoryDbPath :: IO FilePath getMemoryDbPath = do maybeEnv <- lookupEnv "MEMORY_DB_PATH" @@ -580,7 +650,7 @@ getMemoryDbPath = do Nothing -> do home <- lookupEnv "HOME" case home of - Just h -> pure (h </> ".local/share/omni/memory.db") + Just h -> pure (h </> "memory.db") Nothing -> pure "_/memory.db" -- | Run an action with the memory database connection. 
@@ -707,6 +777,24 @@ initMemoryDb conn = do SQL.execute_ conn "CREATE INDEX IF NOT EXISTS idx_memory_links_type ON memory_links(relation_type)" + SQL.execute_ + conn + "CREATE TABLE IF NOT EXISTS chat_history (\ + \ id TEXT PRIMARY KEY,\ + \ chat_id INTEGER NOT NULL,\ + \ user_id TEXT,\ + \ role TEXT NOT NULL,\ + \ sender_name TEXT,\ + \ content TEXT NOT NULL,\ + \ embedding BLOB,\ + \ created_at TIMESTAMP NOT NULL\ + \)" + SQL.execute_ + conn + "CREATE INDEX IF NOT EXISTS idx_chat_history_chat ON chat_history(chat_id)" + SQL.execute_ + conn + "CREATE INDEX IF NOT EXISTS idx_chat_history_time ON chat_history(created_at)" -- | Migrate conversation_messages to add sender_name and thread_id columns. migrateConversationMessages :: SQL.Connection -> IO () @@ -1573,3 +1661,108 @@ storeGroupMemory chatId = storeMemory (groupUserId chatId) -- | Recall memories for a group. recallGroupMemories :: Int -> Text -> Int -> IO [Memory] recallGroupMemories chatId = recallMemories (groupUserId chatId) + +-- | Save a chat history entry with embedding for semantic search. +saveChatHistoryEntry :: Int -> Maybe Text -> Text -> Maybe Text -> Text -> UTCTime -> IO () +saveChatHistoryEntry chatId usrId role senderName content timestamp = do + uuid <- UUID.nextRandom + let entryId = UUID.toText uuid + embedding <- embedText content + withMemoryDb <| \conn -> + SQL.execute + conn + "INSERT OR IGNORE INTO chat_history (id, chat_id, user_id, role, sender_name, content, embedding, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + ( entryId, + chatId, + usrId, + role, + senderName, + content, + vectorToBlob </ either (const Nothing) Just embedding, + timestamp + ) + +-- | Semantic search over chat history. 
+--
+-- Embeds @query@ with 'embedText', loads up to 10000 of the most recent
+-- @chat_history@ rows that have an embedding, scores each one with
+-- 'cosineSimilarity' and returns at most @limit@ entries, best match first.
+-- Yields an empty list when the query cannot be embedded.
+searchChatHistorySemantic :: Text -> Int -> IO [(ChatHistoryEntry, Float)]
+searchChatHistorySemantic query limit =
+  -- 'take' returns [] for a non-positive limit anyway, so bail out before
+  -- paying for an embedding call and a table scan.
+  if limit <= 0
+    then pure []
+    else do
+      queryEmbedding <- embedText query
+      case queryEmbedding of
+        Left _ -> pure []
+        Right qEmb -> do
+          candidates <-
+            withMemoryDb <| \conn ->
+              SQL.query_
+                conn
+                "SELECT id, chat_id, user_id, role, sender_name, content, embedding, created_at \
+                \FROM chat_history WHERE embedding IS NOT NULL ORDER BY created_at DESC LIMIT 10000"
+          let scored =
+                [ (entry, cosineSimilarity qEmb emb)
+                  | entry <- candidates,
+                    Just emb <- [cheEmbedding entry]
+                ]
+              -- The arguments to 'compare' are flipped to sort descending.
+              ranked = List.sortBy (\(_, a) (_, b) -> compare b a) scored
+          pure (take limit ranked)
+
+-- | Get chat history stats.
+--
+-- Returns @(total, indexed)@: the number of rows in @chat_history@ and the
+-- number of those rows whose @embedding@ is not NULL.
+getChatHistoryStats :: IO (Int, Int)
+getChatHistoryStats =
+  withMemoryDb <| \conn -> do
+    -- Both counts come from one statement so they describe the same table
+    -- state; COUNT(embedding) only counts rows where the column is not NULL.
+    rows <-
+      SQL.query_
+        conn
+        "SELECT COUNT(*), COUNT(embedding) FROM chat_history" ::
+        IO [(Int, Int)]
+    pure <| case rows of
+      [counts] -> counts
+      -- An aggregate without GROUP BY yields exactly one row, so this
+      -- branch only exists to keep the match total.
+      _ -> (0, 0)
+
+-- | Backfill chat history from the conversation_messages table.
+--
+-- Every message is handed to 'backfillMessages', which skips messages that
+-- are already present in chat_history and stores the rest through
+-- 'saveChatHistoryEntry'. Returns the number of messages newly stored.
+backfillChatHistory :: IO Int
+backfillChatHistory = do
+  putText "Starting chat history backfill from conversation_messages..."
+  messages <- getAllConversationMessages
+  putText <| "Found " <> tshow (length messages) <> " messages to process"
+  processed <- backfillMessages messages 0
+  putText <| "Backfill complete: " <> tshow processed <> " messages processed"
+  pure processed
+
+-- | Get all conversation messages for backfill, oldest first.
+getAllConversationMessages :: IO [ConversationMessage]
+getAllConversationMessages =
+  withMemoryDb <| \conn ->
+    SQL.query_
+      conn
+      "SELECT id, user_id, chat_id, role, sender_name, content, tokens_estimate, created_at \
+      \FROM conversation_messages ORDER BY created_at ASC"
+
+-- | Backfill messages one by one, showing progress.
+--
+-- Walks @messages@ in order. A message for which 'checkChatHistoryExists'
+-- reports a match is skipped; every other message is stored with
+-- 'saveChatHistoryEntry'. @count@ is the number stored so far and its final
+-- value is returned; skipped messages do not increase it.
+backfillMessages :: [ConversationMessage] -> Int -> IO Int
+backfillMessages messages count = case messages of
+  [] -> pure count
+  msg : rest -> do
+    seen <- checkChatHistoryExists (cmChatId msg) (cmCreatedAt msg) (cmContent msg)
+    if seen
+      then backfillMessages rest count
+      else do
+        -- Report whenever the stored count is a multiple of 100, zero included.
+        when (count `mod` 100 == 0)
+          <| putText
+          <| "Progress: "
+          <> tshow count
+          <> " messages processed..."
+        saveChatHistoryEntry
+          (cmChatId msg)
+          (Just (cmUserId msg))
+          (roleLabel (cmRole msg))
+          (cmSenderName msg)
+          (cmContent msg)
+          (cmCreatedAt msg)
+        backfillMessages rest (count + 1)
+  where
+    -- Text stored in the role column for each message role.
+    roleLabel UserRole = "user"
+    roleLabel AssistantRole = "assistant"
+
+-- | Report whether chat_history already holds a row with the same chat id,
+-- creation timestamp and content.
+--
+-- The embedding column is not consulted, so a row stored without an
+-- embedding still counts as existing.
+checkChatHistoryExists :: Int -> UTCTime -> Text -> IO Bool
+checkChatHistoryExists chatId timestamp content =
+  withMemoryDb <| \conn -> do
+    matches <-
+      SQL.query
+        conn
+        "SELECT 1 FROM chat_history WHERE chat_id = ? AND created_at = ? AND content = ? LIMIT 1"
+        (chatId, timestamp, content) ::
+        IO [[Int]]
+    pure <| case matches of
+      [] -> False
+      _ -> True
|
