summaryrefslogtreecommitdiff
path: root/Omni/Agent/Memory.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Omni/Agent/Memory.hs')
-rw-r--r--Omni/Agent/Memory.hs197
1 files changed, 195 insertions, 2 deletions
diff --git a/Omni/Agent/Memory.hs b/Omni/Agent/Memory.hs
index 4aaa438..d59104c 100644
--- a/Omni/Agent/Memory.hs
+++ b/Omni/Agent/Memory.hs
@@ -31,6 +31,7 @@ module Omni.Agent.Memory
MessageRole (..),
RelationType (..),
MemoryLink (..),
+ ChatHistoryEntry (..),
-- * User Management
createUser,
@@ -70,6 +71,12 @@ module Omni.Agent.Memory
-- * Embeddings
embedText,
+ -- * Chat History (Semantic Search)
+ saveChatHistoryEntry,
+ searchChatHistorySemantic,
+ backfillChatHistory,
+ getChatHistoryStats,
+
-- * Agent Integration
rememberTool,
recallTool,
@@ -112,10 +119,26 @@ import qualified Omni.Agent.Engine as Engine
import qualified Omni.Test as Test
import System.Directory (createDirectoryIfMissing)
import System.Environment (lookupEnv)
+import qualified System.Environment
import System.FilePath (takeDirectory, (</>))
main :: IO ()
-main = Test.run test
+main = do
+ args <- getArgs'
+ case args of
+ ["backfill"] -> do
+ putText "Running chat history backfill..."
+ count <- backfillChatHistory
+ putText <| "Done! Processed " <> tshow count <> " messages."
+ ["stats"] -> do
+ (total, indexed) <- getChatHistoryStats
+ putText <| "Chat history stats:"
+ putText <| " Total entries: " <> tshow total
+ putText <| " With embeddings: " <> tshow indexed
+ _ -> Test.run test
+
+getArgs' :: IO [String]
+getArgs' = System.Environment.getArgs
test :: Test.Tree
test =
@@ -571,7 +594,54 @@ instance SQL.FromRow MemoryLink where
linkCreatedAt = createdAt
}
+data ChatHistoryEntry = ChatHistoryEntry
+ { cheId :: Text,
+ cheChatId :: Int,
+ cheUserId :: Maybe Text,
+ cheRole :: Text,
+ cheSenderName :: Maybe Text,
+ cheContent :: Text,
+ cheEmbedding :: Maybe (VS.Vector Float),
+ cheCreatedAt :: UTCTime
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON ChatHistoryEntry where
+ toJSON e =
+ Aeson.object
+ [ "id" .= cheId e,
+ "chat_id" .= cheChatId e,
+ "user_id" .= cheUserId e,
+ "role" .= cheRole e,
+ "sender_name" .= cheSenderName e,
+ "content" .= cheContent e,
+ "created_at" .= cheCreatedAt e
+ ]
+
+instance SQL.FromRow ChatHistoryEntry where
+ fromRow = do
+ entryId <- SQL.field
+ chatId <- SQL.field
+ usrId <- SQL.field
+ role <- SQL.field
+ senderName <- SQL.field
+ content <- SQL.field
+ embeddingBlob <- SQL.field
+ createdAt <- SQL.field
+ pure
+ ChatHistoryEntry
+ { cheId = entryId,
+ cheChatId = chatId,
+ cheUserId = usrId,
+ cheRole = role,
+ cheSenderName = senderName,
+ cheContent = content,
+ cheEmbedding = blobToVector </ embeddingBlob,
+ cheCreatedAt = createdAt
+ }
+
-- | Get the path to memory.db
+-- Priority: MEMORY_DB_PATH env var > $HOME/memory.db > _/memory.db
getMemoryDbPath :: IO FilePath
getMemoryDbPath = do
maybeEnv <- lookupEnv "MEMORY_DB_PATH"
@@ -580,7 +650,7 @@ getMemoryDbPath = do
Nothing -> do
home <- lookupEnv "HOME"
case home of
- Just h -> pure (h </> ".local/share/omni/memory.db")
+ Just h -> pure (h </> "memory.db")
Nothing -> pure "_/memory.db"
-- | Run an action with the memory database connection.
@@ -707,6 +777,24 @@ initMemoryDb conn = do
SQL.execute_
conn
"CREATE INDEX IF NOT EXISTS idx_memory_links_type ON memory_links(relation_type)"
+ SQL.execute_
+ conn
+ "CREATE TABLE IF NOT EXISTS chat_history (\
+ \ id TEXT PRIMARY KEY,\
+ \ chat_id INTEGER NOT NULL,\
+ \ user_id TEXT,\
+ \ role TEXT NOT NULL,\
+ \ sender_name TEXT,\
+ \ content TEXT NOT NULL,\
+ \ embedding BLOB,\
+ \ created_at TIMESTAMP NOT NULL\
+ \)"
+ SQL.execute_
+ conn
+ "CREATE INDEX IF NOT EXISTS idx_chat_history_chat ON chat_history(chat_id)"
+ SQL.execute_
+ conn
+ "CREATE INDEX IF NOT EXISTS idx_chat_history_time ON chat_history(created_at)"
-- | Migrate conversation_messages to add sender_name and thread_id columns.
migrateConversationMessages :: SQL.Connection -> IO ()
@@ -1573,3 +1661,108 @@ storeGroupMemory chatId = storeMemory (groupUserId chatId)
-- | Recall memories for a group.
recallGroupMemories :: Int -> Text -> Int -> IO [Memory]
recallGroupMemories chatId = recallMemories (groupUserId chatId)
+
+-- | Save a chat history entry with embedding for semantic search.
+saveChatHistoryEntry :: Int -> Maybe Text -> Text -> Maybe Text -> Text -> UTCTime -> IO ()
+saveChatHistoryEntry chatId usrId role senderName content timestamp = do
+ uuid <- UUID.nextRandom
+ let entryId = UUID.toText uuid
+ embedding <- embedText content
+ withMemoryDb <| \conn ->
+ SQL.execute
+ conn
+ "INSERT OR IGNORE INTO chat_history (id, chat_id, user_id, role, sender_name, content, embedding, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
+ ( entryId,
+ chatId,
+ usrId,
+ role,
+ senderName,
+ content,
+ vectorToBlob </ either (const Nothing) Just embedding,
+ timestamp
+ )
+
+-- | Semantic search over chat history.
+searchChatHistorySemantic :: Text -> Int -> IO [(ChatHistoryEntry, Float)]
+searchChatHistorySemantic query limit = do
+ queryEmbedding <- embedText query
+ case queryEmbedding of
+ Left _ -> pure []
+ Right qEmb -> do
+ allEntries <-
+ withMemoryDb <| \conn ->
+ SQL.query_
+ conn
+ "SELECT id, chat_id, user_id, role, sender_name, content, embedding, created_at \
+ \FROM chat_history WHERE embedding IS NOT NULL ORDER BY created_at DESC LIMIT 10000"
+ let scored =
+ [ (e, cosineSimilarity qEmb emb)
+ | e <- allEntries,
+ Just emb <- [cheEmbedding e]
+ ]
+ sorted = List.sortBy (\(_, s1) (_, s2) -> compare s2 s1) scored
+ pure (take limit sorted)
+
+-- | Get chat history stats.
+getChatHistoryStats :: IO (Int, Int)
+getChatHistoryStats =
+ withMemoryDb <| \conn -> do
+ [[total]] <- SQL.query_ conn "SELECT COUNT(*) FROM chat_history"
+ [[withEmb]] <- SQL.query_ conn "SELECT COUNT(*) FROM chat_history WHERE embedding IS NOT NULL"
+ pure (total, withEmb)
+
+-- | Backfill chat history from conversation_messages table.
+-- Creates embeddings for existing messages that don't have them yet.
+backfillChatHistory :: IO Int
+backfillChatHistory = do
+ putText "Starting chat history backfill from conversation_messages..."
+ messages <- getAllConversationMessages
+ putText <| "Found " <> tshow (length messages) <> " messages to process"
+ processed <- backfillMessages messages 0
+ putText <| "Backfill complete: " <> tshow processed <> " messages processed"
+ pure processed
+
+-- | Get all conversation messages for backfill.
+getAllConversationMessages :: IO [ConversationMessage]
+getAllConversationMessages =
+ withMemoryDb <| \conn ->
+ SQL.query_
+ conn
+ "SELECT id, user_id, chat_id, role, sender_name, content, tokens_estimate, created_at \
+ \FROM conversation_messages ORDER BY created_at ASC"
+
+-- | Backfill messages one by one, showing progress.
+backfillMessages :: [ConversationMessage] -> Int -> IO Int
+backfillMessages [] count = pure count
+backfillMessages (msg : rest) count = do
+ let chatId = cmChatId msg
+ uid = cmUserId msg
+ role = case cmRole msg of
+ UserRole -> "user"
+ AssistantRole -> "assistant"
+ senderName = cmSenderName msg
+ content = cmContent msg
+ timestamp = cmCreatedAt msg
+ alreadyExists <- checkChatHistoryExists chatId timestamp content
+ if alreadyExists
+ then backfillMessages rest count
+ else do
+ when (count `mod` 100 == 0)
+ <| putText
+ <| "Progress: "
+ <> tshow count
+ <> " messages processed..."
+ saveChatHistoryEntry chatId (Just uid) role senderName content timestamp
+ backfillMessages rest (count + 1)
+
+-- | Check if a chat history entry already exists.
+checkChatHistoryExists :: Int -> UTCTime -> Text -> IO Bool
+checkChatHistoryExists chatId timestamp content =
+ withMemoryDb <| \conn -> do
+ results <-
+ SQL.query
+ conn
+ "SELECT 1 FROM chat_history WHERE chat_id = ? AND created_at = ? AND content = ? LIMIT 1"
+ (chatId, timestamp, content) ::
+ IO [[Int]]
+ pure (not (null results))