{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NoImplicitPrelude #-}

-- | Cross-agent shared memory system with vector similarity search.
--
-- Provides persistent memory that is:
--
-- - Shared across all agents (Telegram, researcher, coder, etc.)
-- - Private per user (users can't see each other's memories)
-- - Searchable via semantic similarity using embeddings
--
-- Memories are stored in SQLite. Embeddings are requested from Ollama and
-- ranked in-process with cosine similarity.
--
-- : out omni-agent-memory
-- : dep aeson
-- : dep http-conduit
-- : dep sqlite-simple
-- : dep uuid
-- : dep vector
-- : dep directory
-- : dep bytestring
module Omni.Agent.Memory
  ( -- * Types
    User (..),
    Memory (..),
    MemorySource (..),

    -- * User Management
    createUser,
    getUser,
    getUserByTelegramId,
    getOrCreateUserByTelegramId,

    -- * Memory Operations
    storeMemory,
    recallMemories,
    forgetMemory,
    getAllMemoriesForUser,
    updateMemoryAccess,

    -- * Embeddings
    embedText,

    -- * Agent Integration
    rememberTool,
    recallTool,
    formatMemoriesForPrompt,
    runAgentWithMemory,

    -- * Database
    withMemoryDb,
    initMemoryDb,
    getMemoryDbPath,

    -- * Testing
    main,
    test,
  )
where

import Alpha
import Data.Aeson ((.!=), (.:), (.:?), (.=))
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.List as List
import qualified Data.Text as Text
import qualified Data.Text.Encoding as TE
import Data.Time (UTCTime, getCurrentTime)
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID
import qualified Data.Vector.Storable as VS
import qualified Database.SQLite.Simple as SQL
import Database.SQLite.Simple.FromField ()
import qualified Database.SQLite.Simple.ToField as SQL
import Foreign.Storable ()
import qualified Network.HTTP.Simple as HTTP
import qualified Omni.Agent.Engine as Engine
import qualified Omni.Test as Test
import System.Directory (createDirectoryIfMissing)
import System.Environment (lookupEnv)
import System.FilePath (takeDirectory, (</>))

-- | Entry point: runs the module's test tree.
main :: IO ()
main = Test.run test

-- | Unit tests for the JSON instances, prompt formatting, vector helpers
-- and tool names. None of these tests open the database or call Ollama.
test :: Test.Tree
test =
  Test.group
    "Omni.Agent.Memory"
    [ Test.unit "User JSON roundtrip" <| do
        now <- getCurrentTime
        let user =
              User
                { userId = "test-uuid",
                  userTelegramId = Just 12345,
                  userEmail = Nothing,
                  userName = "Test User",
                  userCreatedAt = now
                }
        case Aeson.decode (Aeson.encode user) of
          Nothing -> Test.assertFailure "Failed to decode User"
          Just decoded -> userName decoded Test.@=? "Test User",
      Test.unit "Memory JSON roundtrip" <| do
        now <- getCurrentTime
        let mem =
              Memory
                { memoryId = "mem-uuid",
                  memoryUserId = "user-uuid",
                  memoryContent = "User is an AI engineer",
                  memoryEmbedding = Nothing,
                  memorySource =
                    MemorySource
                      { sourceAgent = "telegram",
                        sourceSession = Nothing,
                        sourceContext = "User mentioned in chat"
                      },
                  memoryConfidence = 0.9,
                  memoryCreatedAt = now,
                  memoryLastAccessedAt = now,
                  memoryTags = ["profession", "ai"]
                }
        case Aeson.decode (Aeson.encode mem) of
          Nothing -> Test.assertFailure "Failed to decode Memory"
          Just decoded -> memoryContent decoded Test.@=? "User is an AI engineer",
      Test.unit "MemorySource JSON roundtrip" <| do
        let src =
              MemorySource
                { sourceAgent = "researcher",
                  sourceSession = Just "session-123",
                  sourceContext = "Extracted from conversation"
                }
        case Aeson.decode (Aeson.encode src) of
          Nothing -> Test.assertFailure "Failed to decode MemorySource"
          Just decoded -> sourceAgent decoded Test.@=? "researcher",
      Test.unit "formatMemoriesForPrompt formats correctly" <| do
        now <- getCurrentTime
        let mem1 =
              Memory
                { memoryId = "1",
                  memoryUserId = "u",
                  memoryContent = "User is an AI engineer",
                  memoryEmbedding = Nothing,
                  memorySource = MemorySource "telegram" Nothing "chat",
                  memoryConfidence = 0.9,
                  memoryCreatedAt = now,
                  memoryLastAccessedAt = now,
                  memoryTags = []
                }
            mem2 =
              Memory
                { memoryId = "2",
                  memoryUserId = "u",
                  memoryContent = "User prefers Haskell",
                  memoryEmbedding = Nothing,
                  memorySource = MemorySource "coder" Nothing "code review",
                  memoryConfidence = 0.8,
                  memoryCreatedAt = now,
                  memoryLastAccessedAt = now,
                  memoryTags = []
                }
            formatted = formatMemoriesForPrompt [mem1, mem2]
        ("AI engineer" `Text.isInfixOf` formatted) Test.@=? True
        ("Haskell" `Text.isInfixOf` formatted) Test.@=? True,
      Test.unit "cosineSimilarity identical vectors" <| do
        let v1 = VS.fromList [1.0, 0.0, 0.0 :: Float]
            v2 = VS.fromList [1.0, 0.0, 0.0 :: Float]
        abs (cosineSimilarity v1 v2 - 1.0) < 0.0001 Test.@=? True,
      Test.unit "cosineSimilarity orthogonal vectors" <| do
        let v1 = VS.fromList [1.0, 0.0, 0.0 :: Float]
            v2 = VS.fromList [0.0, 1.0, 0.0 :: Float]
        abs (cosineSimilarity v1 v2) < 0.0001 Test.@=? True,
      Test.unit "cosineSimilarity opposite vectors" <| do
        let v1 = VS.fromList [1.0, 0.0, 0.0 :: Float]
            v2 = VS.fromList [-1.0, 0.0, 0.0 :: Float]
        abs (cosineSimilarity v1 v2 + 1.0) < 0.0001 Test.@=? True,
      Test.unit "vectorToBlob and blobToVector roundtrip" <| do
        let v = VS.fromList [0.1, 0.2, 0.3, 0.4, 0.5 :: Float]
            blob = vectorToBlob v
            v' = blobToVector blob
        VS.length v Test.@=? VS.length v'
        VS.toList v Test.@=? VS.toList v',
      Test.unit "rememberTool has correct schema" <| do
        let tool = rememberTool "test-user-id"
        Engine.toolName tool Test.@=? "remember",
      Test.unit "recallTool has correct schema" <| do
        let tool = recallTool "test-user-id"
        Engine.toolName tool Test.@=? "recall"
    ]

-- | User record for multi-user memory system.
data User = User
  { userId :: Text,
    userTelegramId :: Maybe Int,
    userEmail :: Maybe Text,
    userName :: Text,
    userCreatedAt :: UTCTime
  }
  deriving (Show, Eq, Generic)

-- | Encodes a user as a JSON object with snake_case keys.
instance Aeson.ToJSON User where
  toJSON u =
    Aeson.object
      [ "id" .= userId u,
        "telegram_id" .= userTelegramId u,
        "email" .= userEmail u,
        "name" .= userName u,
        "created_at" .= userCreatedAt u
      ]

-- | Inverse of the 'Aeson.ToJSON' instance. @telegram_id@ and @email@ are
-- optional; every other key is required.
instance Aeson.FromJSON User where
  parseJSON =
    Aeson.withObject "User" <| \v ->
      (User <$> v .: "id")
        <*> (v .:? "telegram_id")
        <*> (v .:? "email")
        <*> (v .: "name")
        <*> (v .: "created_at")

-- | Reads five columns in record-field order: id, telegram id, email,
-- name, created-at.
instance SQL.FromRow User where
  fromRow =
    User
      <$> SQL.field
      <*> SQL.field
      <*> SQL.field
      <*> SQL.field
      <*> SQL.field

-- | Writes the fields in the same order 'SQL.FromRow' reads them.
instance SQL.ToRow User where
  toRow u =
    [ SQL.toField (userId u),
      SQL.toField (userTelegramId u),
      SQL.toField (userEmail u),
      SQL.toField (userName u),
      SQL.toField (userCreatedAt u)
    ]

-- | Source information for a memory.
data MemorySource = MemorySource
  { sourceAgent :: Text,
    sourceSession :: Maybe Text,
    sourceContext :: Text
  }
  deriving (Show, Eq, Generic)

-- | Encodes the source as a JSON object with keys @agent@, @session@ and
-- @context@.
instance Aeson.ToJSON MemorySource where
  toJSON s =
    Aeson.object
      [ "agent" .= sourceAgent s,
        "session" .= sourceSession s,
        "context" .= sourceContext s
      ]

-- | Inverse of the 'Aeson.ToJSON' instance; only @session@ is optional.
instance Aeson.FromJSON MemorySource where
  parseJSON =
    Aeson.withObject "MemorySource" <| \v ->
      (MemorySource <$> v .: "agent")
        <*> (v .:? "session")
        <*> (v .: "context")

-- | A memory stored in the system.
data Memory = Memory
  { memoryId :: Text,
    memoryUserId :: Text,
    memoryContent :: Text,
    memoryEmbedding :: Maybe (VS.Vector Float),
    memorySource :: MemorySource,
    memoryConfidence :: Double,
    memoryCreatedAt :: UTCTime,
    memoryLastAccessedAt :: UTCTime,
    memoryTags :: [Text]
  }
  deriving (Show, Eq, Generic)

-- | Encodes a memory as JSON. The embedding is deliberately left out of
-- the encoded object.
instance Aeson.ToJSON Memory where
  toJSON m =
    Aeson.object
      [ "id" .= memoryId m,
        "user_id" .= memoryUserId m,
        "content" .= memoryContent m,
        "source" .= memorySource m,
        "confidence" .= memoryConfidence m,
        "created_at" .= memoryCreatedAt m,
        "last_accessed_at" .= memoryLastAccessedAt m,
        "tags" .= memoryTags m
      ]

-- | Decodes a memory from JSON. The embedding is always 'Nothing' (it is
-- never serialised); @confidence@ defaults to 0.8 and @tags@ to @[]@.
instance Aeson.FromJSON Memory where
  parseJSON =
    Aeson.withObject "Memory" <| \v ->
      (Memory <$> v .: "id")
        <*> (v .: "user_id")
        <*> (v .: "content")
        <*> pure Nothing
        <*> (v .: "source")
        <*> (v .:? "confidence" .!= 0.8)
        <*> (v .: "created_at")
        <*> (v .: "last_accessed_at")
        <*> (v .:? "tags" .!= [])

-- SQLite instances for Memory (partial - embedding handled separately)

-- | Reads eleven columns in this order: id, user_id, content, embedding,
-- source_agent, source_session, source_context, confidence, created_at,
-- last_accessed_at, tags.
instance SQL.FromRow Memory where
  fromRow = do
    mid <- SQL.field
    uid <- SQL.field
    content <- SQL.field
    embeddingBlob <- SQL.field
    agent <- SQL.field
    session <- SQL.field
    context <- SQL.field
    confidence <- SQL.field
    createdAt <- SQL.field
    lastAccessedAt <- SQL.field
    tagsJson <- SQL.field
    -- NOTE(review): these three bindings were reconstructed from a damaged
    -- source line; confirm them against version control.
    let -- A NULL embedding column becomes 'Nothing'.
        embedding = blobToVector <$> (embeddingBlob :: Maybe BS.ByteString)
        source = MemorySource agent session context
        -- Tags are a JSON array stored as text; NULL or undecodable text
        -- yields no tags.
        tags =
          fromMaybe
            []
            ( (tagsJson :: Maybe Text)
                >>= (Aeson.decode <. BL.fromStrict <. TE.encodeUtf8)
            )
    pure
      Memory
        { memoryId = mid,
          memoryUserId = uid,
          memoryContent = content,
          memoryEmbedding = embedding,
          memorySource = source,
          memoryConfidence = confidence,
          memoryCreatedAt = createdAt,
          memoryLastAccessedAt = lastAccessedAt,
          memoryTags = tags
        }

-- | Get the path to memory.db
--
-- Resolution order: the @MEMORY_DB_PATH@ environment variable, then
-- @$HOME/.local/share/omni/memory.db@, then the relative path
-- @_/memory.db@ when @HOME@ is unset.
getMemoryDbPath :: IO FilePath
getMemoryDbPath = do
  maybeEnv <- lookupEnv "MEMORY_DB_PATH"
  case maybeEnv of
    Just p -> pure p
    Nothing -> do
      home <- lookupEnv "HOME"
      case home of
        Just h -> pure (h </> ".local/share/omni/memory.db")
        Nothing -> pure "_/memory.db"

-- | Run an action with the memory database connection.
withMemoryDb :: (SQL.Connection -> IO a) -> IO a
withMemoryDb action = do
  path <- getMemoryDbPath
  -- The parent directory has to exist before SQLite can create the file.
  createDirectoryIfMissing True (takeDirectory path)
  SQL.withConnection path runInitialised
  where
    -- Ensure the schema is present, then hand the connection to the caller.
    runInitialised conn = do
      initMemoryDb conn
      action conn

-- | Initialize the memory database schema.
--
-- Runs, in order: the foreign-key pragma, the two @CREATE TABLE IF NOT
-- EXISTS@ statements, and the two @CREATE INDEX IF NOT EXISTS@ statements.
initMemoryDb :: SQL.Connection -> IO ()
initMemoryDb conn =
  traverse_
    (SQL.execute_ conn)
    [ "PRAGMA foreign_keys = ON",
      "CREATE TABLE IF NOT EXISTS users (\
      \ id TEXT PRIMARY KEY,\
      \ telegram_id INTEGER UNIQUE,\
      \ email TEXT UNIQUE,\
      \ name TEXT NOT NULL,\
      \ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\
      \)",
      "CREATE TABLE IF NOT EXISTS memories (\
      \ id TEXT PRIMARY KEY,\
      \ user_id TEXT NOT NULL REFERENCES users(id),\
      \ content TEXT NOT NULL,\
      \ embedding BLOB,\
      \ source_agent TEXT NOT NULL,\
      \ source_session TEXT,\
      \ source_context TEXT,\
      \ confidence REAL DEFAULT 0.8,\
      \ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\
      \ last_accessed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\
      \ tags TEXT\
      \)",
      "CREATE INDEX IF NOT EXISTS idx_memories_user ON memories(user_id)",
      "CREATE INDEX IF NOT EXISTS idx_memories_agent ON memories(source_agent)"
    ]

-- | Create a new user.
--
-- Generates a random UUID for the id, stamps the current time, inserts
-- the row and returns the record that was written. The email is always
-- left unset.
createUser :: Text -> Maybe Int -> IO User
createUser name telegramId = do
  freshId <- UUID.nextRandom
  createdAt <- getCurrentTime
  let newUser =
        User
          { userId = UUID.toText freshId,
            userTelegramId = telegramId,
            userEmail = Nothing,
            userName = name,
            userCreatedAt = createdAt
          }
  withMemoryDb <| \conn ->
    SQL.execute
      conn
      "INSERT INTO users (id, telegram_id, email, name, created_at) VALUES (?, ?, ?, ?, ?)"
      newUser
  pure newUser

-- | Get a user by ID.
--
-- Returns 'Nothing' when no row has that id.
getUser :: Text -> IO (Maybe User)
getUser uid =
  withMemoryDb <| \conn -> do
    rows <-
      SQL.query
        conn
        "SELECT id, telegram_id, email, name, created_at FROM users WHERE id = ?"
        (SQL.Only uid)
    pure (listToMaybe rows)

-- | Get a user by Telegram ID.
getUserByTelegramId :: Int -> IO (Maybe User)
getUserByTelegramId tid =
  withMemoryDb <| \conn -> do
    results <-
      SQL.query
        conn
        "SELECT id, telegram_id, email, name, created_at FROM users WHERE telegram_id = ?"
        (SQL.Only tid)
    pure (listToMaybe results)

-- | Get or create a user by Telegram ID.
--
-- The given name is only used when a new user has to be created.
getOrCreateUserByTelegramId :: Int -> Text -> IO User
getOrCreateUserByTelegramId tid name = do
  existing <- getUserByTelegramId tid
  case existing of
    Just user -> pure user
    Nothing -> createUser name (Just tid)

-- | Store a memory for a user.
storeMemory :: Text -> Text -> MemorySource -> IO Memory
storeMemory uid content source = storeMemoryWithTags uid content source []

-- | Store a memory with tags.
--
-- The content is embedded with 'embedText'. If embedding fails, the memory
-- is still stored, with no embedding. Confidence is fixed at 0.8, and both
-- timestamps are set to the current time.
storeMemoryWithTags :: Text -> Text -> MemorySource -> [Text] -> IO Memory
storeMemoryWithTags uid content source tags = do
  uuid <- UUID.nextRandom
  now <- getCurrentTime
  embedding <- embedText content
  let mem =
        Memory
          { memoryId = UUID.toText uuid,
            memoryUserId = uid,
            memoryContent = content,
            memoryEmbedding = either (const Nothing) Just embedding,
            memorySource = source,
            memoryConfidence = 0.8,
            memoryCreatedAt = now,
            memoryLastAccessedAt = now,
            memoryTags = tags
          }
  -- NOTE(review): everything after the fourth parameter was reconstructed
  -- from a damaged source line; confirm against version control. The
  -- parameters follow the column order of the INSERT. Tags are written as
  -- JSON text, which is what the FromRow instance for Memory decodes.
  withMemoryDb <| \conn ->
    SQL.execute
      conn
      "INSERT INTO memories (id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
      ( ( memoryId mem,
          memoryUserId mem,
          memoryContent mem,
          vectorToBlob <$> memoryEmbedding mem,
          sourceAgent (memorySource mem),
          sourceSession (memorySource mem),
          sourceContext (memorySource mem)
        )
          SQL.:. ( memoryConfidence mem,
                   memoryCreatedAt mem,
                   memoryLastAccessedAt mem,
                   TE.decodeUtf8 (BL.toStrict (Aeson.encode (memoryTags mem)))
                 )
      )
  pure mem

-- | Recall up to @limit@ memories for a user, most similar to the query
-- first.
--
-- The query is embedded with 'embedText' and compared to every stored
-- memory of the user that has an embedding; memories without one are
-- skipped. The last-accessed timestamp of each returned memory is updated.
-- If the query cannot be embedded, the most recently accessed memories
-- are returned instead.
--
-- A negative @limit@ is treated as 0 on both paths. Without the clamp the
-- fallback would return every row, because SQLite reads a negative LIMIT
-- as "no limit".
recallMemories :: Text -> Text -> Int -> IO [Memory]
recallMemories uid query limit = do
  let cap = max 0 limit
  queryEmbedding <- embedText query
  case queryEmbedding of
    Left _ -> recallMemoriesByRecency uid cap
    Right qEmb -> do
      allMems <- getAllMemoriesForUser uid
      let scored =
            [ (m, cosineSimilarity qEmb emb)
              | m <- allMems,
                Just emb <- [memoryEmbedding m]
            ]
          -- Arguments are swapped in 'compare' to sort by descending score.
          sorted = List.sortBy (\(_, s1) (_, s2) -> compare s2 s1) scored
          topN = take cap sorted
      now <- getCurrentTime
      traverse_ (updateMemoryAccess now <. memoryId <. fst) topN
      pure (map fst topN)

-- | Recall memories by recency (fallback when embedding fails).
recallMemoriesByRecency :: Text -> Int -> IO [Memory]
recallMemoriesByRecency uid limit =
  withMemoryDb <| \conn -> do
    SQL.query
      conn
      "SELECT id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags \
      \FROM memories WHERE user_id = ? ORDER BY last_accessed_at DESC LIMIT ?"
      (uid, limit)

-- | Get all memories for a user.
getAllMemoriesForUser :: Text -> IO [Memory]
getAllMemoriesForUser uid =
  withMemoryDb <| \conn ->
    SQL.query
      conn
      "SELECT id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags \
      \FROM memories WHERE user_id = ?"
      (SQL.Only uid)

-- | Delete a memory.
forgetMemory :: Text -> IO ()
forgetMemory mid =
  withMemoryDb <| \conn ->
    SQL.execute conn "DELETE FROM memories WHERE id = ?" (SQL.Only mid)

-- | Update memory's last accessed timestamp.
updateMemoryAccess :: UTCTime -> Text -> IO ()
updateMemoryAccess now mid =
  withMemoryDb <| \conn ->
    SQL.execute conn "UPDATE memories SET last_accessed_at = ? WHERE id = ?" (now, mid)

-- | Embed text using Ollama's nomic-embed-text model.
--
-- POSTs the text to the @/api/embeddings@ endpoint and reads the
-- @embedding@ array out of the JSON response. Request exceptions, non-2xx
-- statuses and malformed responses are all reported as 'Left'.
embedText :: Text -> IO (Either Text (VS.Vector Float))
embedText content = do
  -- NOTE(review): this line and the next were reconstructed from a damaged
  -- source line; the environment variable name is a guess to be confirmed
  -- against version control.
  ollamaUrl <- fromMaybe "http://localhost:11434" <$> lookupEnv "OLLAMA_URL"
  let url = ollamaUrl <> "/api/embeddings"
  req0 <- HTTP.parseRequest url
  let body =
        Aeson.object
          [ "model" .= ("nomic-embed-text" :: Text),
            "prompt" .= content
          ]
      req =
        HTTP.setRequestMethod "POST"
          <| HTTP.setRequestHeader "Content-Type" ["application/json"]
          <| HTTP.setRequestBodyLBS (Aeson.encode body)
          <| req0
  result <- try (HTTP.httpLBS req)
  case result of
    Left (e :: SomeException) -> pure (Left ("Embedding request failed: " <> tshow e))
    Right response -> do
      let status = HTTP.getResponseStatusCode response
      if status >= 200 && status < 300
        then case Aeson.decode (HTTP.getResponseBody response) of
          Just (Aeson.Object obj) -> case KeyMap.lookup "embedding" obj of
            Just (Aeson.Array arr) ->
              -- Array entries that are not JSON numbers are dropped.
              let floats = [f | Aeson.Number n <- toList arr, let f = realToFrac n]
               in pure (Right (VS.fromList floats))
            _ -> pure (Left "No embedding in response")
          _ -> pure (Left "Failed to parse embedding response")
        else pure (Left ("Embedding HTTP error: " <> tshow status))

-- | Convert a vector to a blob for storage.
--
-- The floats are reinterpreted as raw bytes with 'VS.unsafeCast', with no
-- byte-order conversion.
vectorToBlob :: VS.Vector Float -> BS.ByteString
vectorToBlob v =
  let bytes = VS.unsafeCast v :: VS.Vector Word8
   in BS.pack (VS.toList bytes)

-- | Convert a blob back to a vector.
--
-- Inverse of 'vectorToBlob': the bytes are reinterpreted as floats with
-- 'VS.unsafeCast'.
blobToVector :: BS.ByteString -> VS.Vector Float
blobToVector bs =
  let bytes = VS.fromList (BS.unpack bs) :: VS.Vector Word8
   in VS.unsafeCast bytes

-- | Calculate cosine similarity between two vectors.
--
-- Returns 0 when the lengths differ or when either vector has zero
-- magnitude.
cosineSimilarity :: VS.Vector Float -> VS.Vector Float -> Float
cosineSimilarity v1 v2
  | VS.length v1 /= VS.length v2 = 0
  | otherwise =
      let dot = VS.sum (VS.zipWith (*) v1 v2)
          mag1 = sqrt (VS.sum (VS.map (\x -> x * x) v1))
          mag2 = sqrt (VS.sum (VS.map (\x -> x * x) v2))
       in if mag1 == 0 || mag2 == 0
            then 0
            else dot / (mag1 * mag2)

-- | Format memories for inclusion in a prompt.
--
-- Produces a header followed by one bullet per memory, each naming the
-- agent that recorded it. An empty list yields a fixed placeholder
-- sentence.
formatMemoriesForPrompt :: [Memory] -> Text
formatMemoriesForPrompt [] = "No prior context available."
formatMemoriesForPrompt mems =
  Text.unlines
    [ "Known context about this user:",
      "",
      Text.unlines (map formatMem mems)
    ]
  where
    formatMem m =
      "- " <> memoryContent m <> " (via " <> sourceAgent (memorySource m) <> ")"

-- | Run an agent with memory context.
-- Recalls relevant memories for the user and injects them into the system prompt.
--
-- Up to 10 memories are recalled, using the user prompt as the query. The
-- @remember@ and @recall@ tools, bound to this user, are appended to the
-- agent's tool list.
runAgentWithMemory ::
  User ->
  Engine.EngineConfig ->
  Engine.AgentConfig ->
  Text ->
  IO (Either Text Engine.AgentResult)
runAgentWithMemory user engineCfg agentCfg userPrompt = do
  memories <- recallMemories (userId user) userPrompt 10
  let memoryContext = formatMemoriesForPrompt memories
      enhancedPrompt =
        Engine.agentSystemPrompt agentCfg
          <> "\n\n## Known about this user\n"
          <> memoryContext
      enhancedConfig =
        agentCfg
          { Engine.agentSystemPrompt = enhancedPrompt,
            Engine.agentTools =
              Engine.agentTools agentCfg
                <> [rememberTool (userId user), recallTool (userId user)]
          }
  Engine.runAgent engineCfg enhancedConfig userPrompt

-- | Tool for agents to store memories about users.
--
-- The tool is bound to the given user id: everything it stores belongs
-- to that user.
rememberTool :: Text -> Engine.Tool
rememberTool uid =
  Engine.Tool
    { Engine.toolName = "remember",
      Engine.toolDescription =
        "Store a piece of information about the user for future reference. "
          <> "Use this when the user shares personal facts, preferences, or context "
          <> "that would be useful to recall in future conversations.",
      Engine.toolJsonSchema =
        Aeson.object
          [ "type" .= ("object" :: Text),
            "properties"
              .= Aeson.object
                [ "content"
                    .= Aeson.object
                      [ "type" .= ("string" :: Text),
                        "description" .= ("The information to remember about the user" :: Text)
                      ],
                  "context"
                    .= Aeson.object
                      [ "type" .= ("string" :: Text),
                        "description" .= ("How/why this was learned (e.g., 'user mentioned in chat')" :: Text)
                      ],
                  "tags"
                    .= Aeson.object
                      [ "type" .= ("array" :: Text),
                        "items" .= Aeson.object ["type" .= ("string" :: Text)],
                        "description" .= ("Optional tags for categorization" :: Text)
                      ]
                ],
            "required" .= (["content", "context"] :: [Text])
          ],
      Engine.toolExecute = executeRemember uid
    }

-- | Handler behind 'rememberTool'.
--
-- Parses the arguments, stores the memory with source agent @"agent@" and
-- no session, and reports the new memory id. Unparseable arguments yield
-- an object with a single @error@ key.
executeRemember :: Text -> Aeson.Value -> IO Aeson.Value
executeRemember uid v =
  case Aeson.fromJSON v of
    Aeson.Error e -> pure (Aeson.object ["error" .= Text.pack e])
    Aeson.Success (args :: RememberArgs) -> do
      let source =
            MemorySource
              { sourceAgent = "agent",
                sourceSession = Nothing,
                sourceContext = rememberContext args
              }
      mem <- storeMemoryWithTags uid (rememberContent args) source (rememberTags args)
      pure
        ( Aeson.object
            [ "success" .= True,
              "memory_id" .= memoryId mem,
              "message" .= ("Remembered: " <> rememberContent args)
            ]
        )

-- | Tool for agents to recall memories about users.
--
-- The tool is bound to the given user id: it only searches that user's
-- memories.
recallTool :: Text -> Engine.Tool
recallTool uid =
  Engine.Tool
    { Engine.toolName = "recall",
      Engine.toolDescription =
        "Search your memory for information about the user. "
          <> "Use this to retrieve previously stored facts, preferences, or context.",
      Engine.toolJsonSchema =
        Aeson.object
          [ "type" .= ("object" :: Text),
            "properties"
              .= Aeson.object
                [ "query"
                    .= Aeson.object
                      [ "type" .= ("string" :: Text),
                        "description" .= ("What to search for in memory" :: Text)
                      ],
                  "limit"
                    .= Aeson.object
                      [ "type" .= ("integer" :: Text),
                        "description" .= ("Maximum memories to return (default: 5)" :: Text)
                      ]
                ],
            "required" .= (["query"] :: [Text])
          ],
      Engine.toolExecute = executeRecall uid
    }

-- | Handler behind 'recallTool'.
--
-- Parses the arguments, runs 'recallMemories', and returns the count
-- together with the content, confidence, source agent and tags of each
-- hit. Unparseable arguments yield an object with a single @error@ key.
executeRecall :: Text -> Aeson.Value -> IO Aeson.Value
executeRecall uid v =
  case Aeson.fromJSON v of
    Aeson.Error e -> pure (Aeson.object ["error" .= Text.pack e])
    Aeson.Success (args :: RecallArgs) -> do
      mems <- recallMemories uid (recallQuery args) (recallLimit args)
      pure
        ( Aeson.object
            [ "success" .= True,
              "count" .= length mems,
              "memories"
                .= map
                  ( \m ->
                      Aeson.object
                        [ "content" .= memoryContent m,
                          "confidence" .= memoryConfidence m,
                          "source" .= sourceAgent (memorySource m),
                          "tags" .= memoryTags m
                        ]
                  )
                  mems
            ]
        )

-- Helper for parsing remember args
data RememberArgs = RememberArgs
  { rememberContent :: Text,
    rememberContext :: Text,
    rememberTags :: [Text]
  }
  deriving (Generic)

-- | @content@ is required; @context@ defaults to @"agent observation"@
-- and @tags@ to @[]@.
instance Aeson.FromJSON RememberArgs where
  parseJSON =
    Aeson.withObject "RememberArgs" <| \v ->
      (RememberArgs <$> v .: "content")
        <*> (v .:? "context" .!= "agent observation")
        <*> (v .:? "tags" .!= [])

-- Helper for parsing recall args
data RecallArgs = RecallArgs
  { recallQuery :: Text,
    recallLimit :: Int
  }
  deriving (Generic)

-- | @query@ is required; @limit@ defaults to 5.
instance Aeson.FromJSON RecallArgs where
  parseJSON =
    Aeson.withObject "RecallArgs" <| \v ->
      (RecallArgs <$> v .: "query")
        <*> (v .:? "limit" .!= 5)