summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Sima <ben@bensima.com>2025-12-19 10:51:13 -0500
committerBen Sima <ben@bensima.com>2025-12-19 10:51:13 -0500
commita960a0de7abc50abec51262de8a5871048817f1f (patch)
tree973f4fbc8d75031bbb01ed1aec62a23c85fa2919
parent533e4209192298de4808c58f6ea6244e4bed5768 (diff)
Add semantic search for chat history
- Add chat_history table with embeddings in memory.db - Add saveChatHistoryEntry for live message ingestion - Add searchChatHistorySemantic for vector similarity search - Update search_chat_history tool to use semantic search - Add backfill command: run.sh Omni/Agent/Memory.hs backfill - Add stats command: run.sh Omni/Agent/Memory.hs stats - Change default memory.db path to ~/memory.db - Wire Telegram message handling to save to chat_history async
-rw-r--r--Omni/Agent/Memory.hs197
-rw-r--r--Omni/Agent/Telegram.hs11
-rw-r--r--Omni/Agent/Tools/AvaLogs.hs65
3 files changed, 245 insertions, 28 deletions
diff --git a/Omni/Agent/Memory.hs b/Omni/Agent/Memory.hs
index 4aaa438..d59104c 100644
--- a/Omni/Agent/Memory.hs
+++ b/Omni/Agent/Memory.hs
@@ -31,6 +31,7 @@ module Omni.Agent.Memory
MessageRole (..),
RelationType (..),
MemoryLink (..),
+ ChatHistoryEntry (..),
-- * User Management
createUser,
@@ -70,6 +71,12 @@ module Omni.Agent.Memory
-- * Embeddings
embedText,
+ -- * Chat History (Semantic Search)
+ saveChatHistoryEntry,
+ searchChatHistorySemantic,
+ backfillChatHistory,
+ getChatHistoryStats,
+
-- * Agent Integration
rememberTool,
recallTool,
@@ -112,10 +119,26 @@ import qualified Omni.Agent.Engine as Engine
import qualified Omni.Test as Test
import System.Directory (createDirectoryIfMissing)
import System.Environment (lookupEnv)
+import qualified System.Environment
import System.FilePath (takeDirectory, (</>))
main :: IO ()
-main = Test.run test
+main = do
+ args <- getArgs'
+ case args of
+ ["backfill"] -> do
+ putText "Running chat history backfill..."
+ count <- backfillChatHistory
+ putText <| "Done! Processed " <> tshow count <> " messages."
+ ["stats"] -> do
+ (total, indexed) <- getChatHistoryStats
+ putText <| "Chat history stats:"
+ putText <| " Total entries: " <> tshow total
+ putText <| " With embeddings: " <> tshow indexed
+ _ -> Test.run test
+
+getArgs' :: IO [String]
+getArgs' = System.Environment.getArgs
test :: Test.Tree
test =
@@ -571,7 +594,54 @@ instance SQL.FromRow MemoryLink where
linkCreatedAt = createdAt
}
+data ChatHistoryEntry = ChatHistoryEntry
+ { cheId :: Text,
+ cheChatId :: Int,
+ cheUserId :: Maybe Text,
+ cheRole :: Text,
+ cheSenderName :: Maybe Text,
+ cheContent :: Text,
+ cheEmbedding :: Maybe (VS.Vector Float),
+ cheCreatedAt :: UTCTime
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON ChatHistoryEntry where
+ toJSON e =
+ Aeson.object
+ [ "id" .= cheId e,
+ "chat_id" .= cheChatId e,
+ "user_id" .= cheUserId e,
+ "role" .= cheRole e,
+ "sender_name" .= cheSenderName e,
+ "content" .= cheContent e,
+ "created_at" .= cheCreatedAt e
+ ]
+
+instance SQL.FromRow ChatHistoryEntry where
+ fromRow = do
+ entryId <- SQL.field
+ chatId <- SQL.field
+ usrId <- SQL.field
+ role <- SQL.field
+ senderName <- SQL.field
+ content <- SQL.field
+ embeddingBlob <- SQL.field
+ createdAt <- SQL.field
+ pure
+ ChatHistoryEntry
+ { cheId = entryId,
+ cheChatId = chatId,
+ cheUserId = usrId,
+ cheRole = role,
+ cheSenderName = senderName,
+ cheContent = content,
+ cheEmbedding = blobToVector </ embeddingBlob,
+ cheCreatedAt = createdAt
+ }
+
-- | Get the path to memory.db
+-- Priority: MEMORY_DB_PATH env var > $HOME/memory.db > _/memory.db
getMemoryDbPath :: IO FilePath
getMemoryDbPath = do
maybeEnv <- lookupEnv "MEMORY_DB_PATH"
@@ -580,7 +650,7 @@ getMemoryDbPath = do
Nothing -> do
home <- lookupEnv "HOME"
case home of
- Just h -> pure (h </> ".local/share/omni/memory.db")
+ Just h -> pure (h </> "memory.db")
Nothing -> pure "_/memory.db"
-- | Run an action with the memory database connection.
@@ -707,6 +777,24 @@ initMemoryDb conn = do
SQL.execute_
conn
"CREATE INDEX IF NOT EXISTS idx_memory_links_type ON memory_links(relation_type)"
+ SQL.execute_
+ conn
+ "CREATE TABLE IF NOT EXISTS chat_history (\
+ \ id TEXT PRIMARY KEY,\
+ \ chat_id INTEGER NOT NULL,\
+ \ user_id TEXT,\
+ \ role TEXT NOT NULL,\
+ \ sender_name TEXT,\
+ \ content TEXT NOT NULL,\
+ \ embedding BLOB,\
+ \ created_at TIMESTAMP NOT NULL\
+ \)"
+ SQL.execute_
+ conn
+ "CREATE INDEX IF NOT EXISTS idx_chat_history_chat ON chat_history(chat_id)"
+ SQL.execute_
+ conn
+ "CREATE INDEX IF NOT EXISTS idx_chat_history_time ON chat_history(created_at)"
-- | Migrate conversation_messages to add sender_name and thread_id columns.
migrateConversationMessages :: SQL.Connection -> IO ()
@@ -1573,3 +1661,108 @@ storeGroupMemory chatId = storeMemory (groupUserId chatId)
-- | Recall memories for a group.
recallGroupMemories :: Int -> Text -> Int -> IO [Memory]
recallGroupMemories chatId = recallMemories (groupUserId chatId)
+
+-- | Save a chat history entry with embedding for semantic search.
+saveChatHistoryEntry :: Int -> Maybe Text -> Text -> Maybe Text -> Text -> UTCTime -> IO ()
+saveChatHistoryEntry chatId usrId role senderName content timestamp = do
+ uuid <- UUID.nextRandom
+ let entryId = UUID.toText uuid
+ embedding <- embedText content
+ withMemoryDb <| \conn ->
+ SQL.execute
+ conn
+ "INSERT OR IGNORE INTO chat_history (id, chat_id, user_id, role, sender_name, content, embedding, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
+ ( entryId,
+ chatId,
+ usrId,
+ role,
+ senderName,
+ content,
+ vectorToBlob </ either (const Nothing) Just embedding,
+ timestamp
+ )
+
+-- | Semantic search over chat history.
+searchChatHistorySemantic :: Text -> Int -> IO [(ChatHistoryEntry, Float)]
+searchChatHistorySemantic query limit = do
+ queryEmbedding <- embedText query
+ case queryEmbedding of
+ Left _ -> pure []
+ Right qEmb -> do
+ allEntries <-
+ withMemoryDb <| \conn ->
+ SQL.query_
+ conn
+ "SELECT id, chat_id, user_id, role, sender_name, content, embedding, created_at \
+ \FROM chat_history WHERE embedding IS NOT NULL ORDER BY created_at DESC LIMIT 10000"
+ let scored =
+ [ (e, cosineSimilarity qEmb emb)
+ | e <- allEntries,
+ Just emb <- [cheEmbedding e]
+ ]
+ sorted = List.sortBy (\(_, s1) (_, s2) -> compare s2 s1) scored
+ pure (take limit sorted)
+
+-- | Get chat history stats.
+getChatHistoryStats :: IO (Int, Int)
+getChatHistoryStats =
+ withMemoryDb <| \conn -> do
+ [[total]] <- SQL.query_ conn "SELECT COUNT(*) FROM chat_history"
+ [[withEmb]] <- SQL.query_ conn "SELECT COUNT(*) FROM chat_history WHERE embedding IS NOT NULL"
+ pure (total, withEmb)
+
+-- | Backfill chat history from conversation_messages table.
+-- Creates embeddings for existing messages that don't have them yet.
+backfillChatHistory :: IO Int
+backfillChatHistory = do
+ putText "Starting chat history backfill from conversation_messages..."
+ messages <- getAllConversationMessages
+ putText <| "Found " <> tshow (length messages) <> " messages to process"
+ processed <- backfillMessages messages 0
+ putText <| "Backfill complete: " <> tshow processed <> " messages processed"
+ pure processed
+
+-- | Get all conversation messages for backfill.
+getAllConversationMessages :: IO [ConversationMessage]
+getAllConversationMessages =
+ withMemoryDb <| \conn ->
+ SQL.query_
+ conn
+ "SELECT id, user_id, chat_id, role, sender_name, content, tokens_estimate, created_at \
+ \FROM conversation_messages ORDER BY created_at ASC"
+
+-- | Backfill messages one by one, showing progress.
+backfillMessages :: [ConversationMessage] -> Int -> IO Int
+backfillMessages [] count = pure count
+backfillMessages (msg : rest) count = do
+ let chatId = cmChatId msg
+ uid = cmUserId msg
+ role = case cmRole msg of
+ UserRole -> "user"
+ AssistantRole -> "assistant"
+ senderName = cmSenderName msg
+ content = cmContent msg
+ timestamp = cmCreatedAt msg
+ alreadyExists <- checkChatHistoryExists chatId timestamp content
+ if alreadyExists
+ then backfillMessages rest count
+ else do
+ when (count `mod` 100 == 0)
+ <| putText
+ <| "Progress: "
+ <> tshow count
+ <> " messages processed..."
+ saveChatHistoryEntry chatId (Just uid) role senderName content timestamp
+ backfillMessages rest (count + 1)
+
+-- | Check if a chat history entry already exists.
+checkChatHistoryExists :: Int -> UTCTime -> Text -> IO Bool
+checkChatHistoryExists chatId timestamp content =
+ withMemoryDb <| \conn -> do
+ results <-
+ SQL.query
+ conn
+ "SELECT 1 FROM chat_history WHERE chat_id = ? AND created_at = ? AND content = ? LIMIT 1"
+ (chatId, timestamp, content) ::
+ IO [[Int]]
+ pure (not (null results))
diff --git a/Omni/Agent/Telegram.hs b/Omni/Agent/Telegram.hs
index e59570a..de84e49 100644
--- a/Omni/Agent/Telegram.hs
+++ b/Omni/Agent/Telegram.hs
@@ -867,6 +867,9 @@ handleAuthorizedMessageContinued tgConfig provider engineCfg msg uid userName ch
Memory.getConversationContext uid chatId maxConversationTokens
putText <| "Conversation context: " <> tshow contextTokens <> " tokens"
+ now <- getCurrentTime
+ _ <- forkIO <| Memory.saveChatHistoryEntry chatId (Just uid) "user" (Just userName) userMessage now
+
processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMessage conversationContext
handleAuthorizedMessageBatch ::
@@ -999,6 +1002,9 @@ handleAuthorizedMessageBatch tgConfig provider engineCfg msg uid userName chatId
Memory.getConversationContext uid chatId maxConversationTokens
putText <| "Conversation context: " <> tshow contextTokens <> " tokens"
+ now <- getCurrentTime
+ _ <- forkIO <| Memory.saveChatHistoryEntry chatId (Just uid) "user" (Just userName) userMessage now
+
processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMessage conversationContext
processEngagedMessage ::
@@ -1212,6 +1218,11 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe
then void <| Memory.saveGroupMessage chatId threadId Memory.AssistantRole "Ava" response
else void <| Memory.saveMessage uid chatId Memory.AssistantRole Nothing response
+ unless (Text.null response) <| do
+ nowResp <- getCurrentTime
+ _ <- forkIO <| Memory.saveChatHistoryEntry chatId (Just uid) "assistant" (Just "Ava") response nowResp
+ pure ()
+
if Text.null response
then do
if isGroup
diff --git a/Omni/Agent/Tools/AvaLogs.hs b/Omni/Agent/Tools/AvaLogs.hs
index 84c9db8..4c8ce11 100644
--- a/Omni/Agent/Tools/AvaLogs.hs
+++ b/Omni/Agent/Tools/AvaLogs.hs
@@ -24,6 +24,7 @@ import qualified Data.Text as Text
import qualified Data.Time as Time
import qualified Omni.Agent.AuditLog as AuditLog
import qualified Omni.Agent.Engine as Engine
+import qualified Omni.Agent.Memory as Memory
main :: IO ()
main = putText "Omni.Agent.Tools.AvaLogs - no standalone execution"
@@ -114,9 +115,9 @@ searchChatHistoryTool =
Engine.Tool
{ Engine.toolName = "search_chat_history",
Engine.toolDescription =
- "Search your conversation history for specific content. "
+ "Search your conversation history using semantic similarity. "
<> "Use this to find what was said in past conversations, recall context, "
- <> "or find when something was discussed. Searches message content.",
+ <> "or find when something was discussed. Finds semantically related messages.",
Engine.toolJsonSchema =
Aeson.object
[ "type" .= ("object" :: Text),
@@ -125,12 +126,7 @@ searchChatHistoryTool =
[ "query"
.= Aeson.object
[ "type" .= ("string" :: Text),
- "description" .= ("Text to search for in chat history" :: Text)
- ],
- "days_back"
- .= Aeson.object
- [ "type" .= ("integer" :: Text),
- "description" .= ("How many days back to search (default: 7)" :: Text)
+ "description" .= ("What to search for (semantic search)" :: Text)
],
"max_results"
.= Aeson.object
@@ -151,12 +147,6 @@ executeSearchHistory v = do
_ -> ""
_ -> ""
- let daysBack = case v of
- Aeson.Object obj -> case KeyMap.lookup "days_back" obj of
- Just (Aeson.Number n) -> round n
- _ -> 7
- _ -> 7
-
let maxResults = case v of
Aeson.Object obj -> case KeyMap.lookup "max_results" obj of
Just (Aeson.Number n) -> round n
@@ -166,18 +156,41 @@ executeSearchHistory v = do
if Text.null query
then pure <| Aeson.object ["error" .= ("query is required" :: Text)]
else do
- today <- Time.utctDay </ Time.getCurrentTime
- let days = [Time.addDays (negate i) today | i <- [0 .. daysBack - 1]]
- allEntries <- concat </ traverse AuditLog.readAvaLogs days
- let matches = filter (matchesQuery query) allEntries
- limited = take maxResults matches
- pure
- <| Aeson.object
- [ "query" .= query,
- "days_searched" .= daysBack,
- "total_matches" .= length matches,
- "results" .= map formatSearchResult limited
- ]
+ (_total, indexed) <- Memory.getChatHistoryStats
+ if indexed > 0
+ then do
+ results <- Memory.searchChatHistorySemantic query maxResults
+ pure
+ <| Aeson.object
+ [ "query" .= query,
+ "search_type" .= ("semantic" :: Text),
+ "indexed_messages" .= indexed,
+ "results" .= map formatSemanticResult results
+ ]
+ else do
+ today <- Time.utctDay </ Time.getCurrentTime
+ let days = [Time.addDays (negate i) today | i <- [0 .. 6]]
+ allEntries <- concat </ traverse AuditLog.readAvaLogs days
+ let matches = filter (matchesQuery query) allEntries
+ limited = take maxResults matches
+ pure
+ <| Aeson.object
+ [ "query" .= query,
+ "search_type" .= ("keyword_fallback" :: Text),
+ "note" .= ("no indexed history - run backfill" :: Text),
+ "total_matches" .= length matches,
+ "results" .= map formatSearchResult limited
+ ]
+
+formatSemanticResult :: (Memory.ChatHistoryEntry, Float) -> Aeson.Value
+formatSemanticResult (entry, score) =
+ Aeson.object
+ [ "timestamp" .= Time.formatTime Time.defaultTimeLocale "%Y-%m-%d %H:%M:%S" (Memory.cheCreatedAt entry),
+ "role" .= Memory.cheRole entry,
+ "sender" .= Memory.cheSenderName entry,
+ "similarity" .= score,
+ "content" .= Text.take 500 (Memory.cheContent entry)
+ ]
matchesQuery :: Text -> AuditLog.AuditLogEntry -> Bool
matchesQuery query entry =