diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-11 22:42:08 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-11 22:42:08 -0500 |
| commit | ff89735dab5d923b13dc6fdca8af7cd448e6234e (patch) | |
| tree | e158ca9d36d97070661c29161d901e730ce9addc /Omni/Agent | |
| parent | 276a27f27aeff7781a25e13fad0d568f5455ce05 (diff) | |
Add cross-agent memory system (t-248)
- User management with Telegram ID identification
- Memory storage with Ollama embeddings (nomic-embed-text)
- Semantic similarity search via cosine similarity
- remember/recall tools for agents
- runAgentWithMemory wrapper for memory-enhanced agents
- Separate memory.db database for user privacy
Diffstat (limited to 'Omni/Agent')
| -rw-r--r-- | Omni/Agent/Memory.hs | 751 |
1 files changed, 751 insertions, 0 deletions
diff --git a/Omni/Agent/Memory.hs b/Omni/Agent/Memory.hs new file mode 100644 index 0000000..863528c --- /dev/null +++ b/Omni/Agent/Memory.hs @@ -0,0 +1,751 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE NoImplicitPrelude #-} + +-- | Cross-agent shared memory system with vector similarity search. +-- +-- Provides persistent memory that is: +-- - Shared across all agents (Telegram, researcher, coder, etc.) +-- - Private per user (users can't see each other's memories) +-- - Searchable via semantic similarity using embeddings +-- +-- Uses sqlite-vss for vector similarity search and Ollama for embeddings. +-- +-- : out omni-agent-memory +-- : dep aeson +-- : dep http-conduit +-- : dep sqlite-simple +-- : dep uuid +-- : dep vector +-- : dep directory +-- : dep bytestring +module Omni.Agent.Memory + ( -- * Types + User (..), + Memory (..), + MemorySource (..), + + -- * User Management + createUser, + getUser, + getUserByTelegramId, + getOrCreateUserByTelegramId, + + -- * Memory Operations + storeMemory, + recallMemories, + forgetMemory, + getAllMemoriesForUser, + updateMemoryAccess, + + -- * Embeddings + embedText, + + -- * Agent Integration + rememberTool, + recallTool, + formatMemoriesForPrompt, + runAgentWithMemory, + + -- * Database + withMemoryDb, + initMemoryDb, + getMemoryDbPath, + + -- * Testing + main, + test, + ) +where + +import Alpha +import Data.Aeson ((.!=), (.:), (.:?), (.=)) +import qualified Data.Aeson as Aeson +import qualified Data.Aeson.KeyMap as KeyMap +import qualified Data.ByteString as BS +import qualified Data.ByteString.Lazy as BL +import qualified Data.List as List +import qualified Data.Text as Text +import qualified Data.Text.Encoding as TE +import Data.Time (UTCTime, getCurrentTime) +import qualified Data.UUID as UUID +import qualified Data.UUID.V4 as UUID +import qualified Data.Vector.Storable as VS +import qualified Database.SQLite.Simple as SQL +import Database.SQLite.Simple.FromField () +import qualified Database.SQLite.Simple.ToField as SQL +import Foreign.Storable () +import qualified Network.HTTP.Simple as HTTP +import qualified Omni.Agent.Engine as Engine +import qualified Omni.Test as Test +import System.Directory (createDirectoryIfMissing) +import System.Environment (lookupEnv) +import System.FilePath (takeDirectory, (</>)) + +main :: IO () +main = Test.run test + +test :: Test.Tree +test = + Test.group + "Omni.Agent.Memory" + [ Test.unit "User JSON roundtrip" <| do + now <- getCurrentTime + let user = + User + { userId = "test-uuid", + userTelegramId = Just 12345, + userEmail = Nothing, + userName = "Test User", + userCreatedAt = now + } + case Aeson.decode (Aeson.encode user) of + Nothing -> Test.assertFailure "Failed to decode User" + Just decoded -> userName decoded Test.@=? "Test User", + Test.unit "Memory JSON roundtrip" <| do + now <- getCurrentTime + let mem = + Memory + { memoryId = "mem-uuid", + memoryUserId = "user-uuid", + memoryContent = "User is an AI engineer", + memoryEmbedding = Nothing, + memorySource = + MemorySource + { sourceAgent = "telegram", + sourceSession = Nothing, + sourceContext = "User mentioned in chat" + }, + memoryConfidence = 0.9, + memoryCreatedAt = now, + memoryLastAccessedAt = now, + memoryTags = ["profession", "ai"] + } + case Aeson.decode (Aeson.encode mem) of + Nothing -> Test.assertFailure "Failed to decode Memory" + Just decoded -> memoryContent decoded Test.@=? "User is an AI engineer", + Test.unit "MemorySource JSON roundtrip" <| do + let src = + MemorySource + { sourceAgent = "researcher", + sourceSession = Just "session-123", + sourceContext = "Extracted from conversation" + } + case Aeson.decode (Aeson.encode src) of + Nothing -> Test.assertFailure "Failed to decode MemorySource" + Just decoded -> sourceAgent decoded Test.@=? "researcher", + Test.unit "formatMemoriesForPrompt formats correctly" <| do + now <- getCurrentTime + let mem1 = + Memory + { memoryId = "1", + memoryUserId = "u", + memoryContent = "User is an AI engineer", + memoryEmbedding = Nothing, + memorySource = MemorySource "telegram" Nothing "chat", + memoryConfidence = 0.9, + memoryCreatedAt = now, + memoryLastAccessedAt = now, + memoryTags = [] + } + mem2 = + Memory + { memoryId = "2", + memoryUserId = "u", + memoryContent = "User prefers Haskell", + memoryEmbedding = Nothing, + memorySource = MemorySource "coder" Nothing "code review", + memoryConfidence = 0.8, + memoryCreatedAt = now, + memoryLastAccessedAt = now, + memoryTags = [] + } + formatted = formatMemoriesForPrompt [mem1, mem2] + ("AI engineer" `Text.isInfixOf` formatted) Test.@=? True + ("Haskell" `Text.isInfixOf` formatted) Test.@=? True, + Test.unit "cosineSimilarity identical vectors" <| do + let v1 = VS.fromList [1.0, 0.0, 0.0 :: Float] + v2 = VS.fromList [1.0, 0.0, 0.0 :: Float] + abs (cosineSimilarity v1 v2 - 1.0) < 0.0001 Test.@=? True, + Test.unit "cosineSimilarity orthogonal vectors" <| do + let v1 = VS.fromList [1.0, 0.0, 0.0 :: Float] + v2 = VS.fromList [0.0, 1.0, 0.0 :: Float] + abs (cosineSimilarity v1 v2) < 0.0001 Test.@=? True, + Test.unit "cosineSimilarity opposite vectors" <| do + let v1 = VS.fromList [1.0, 0.0, 0.0 :: Float] + v2 = VS.fromList [-1.0, 0.0, 0.0 :: Float] + abs (cosineSimilarity v1 v2 + 1.0) < 0.0001 Test.@=? True, + Test.unit "vectorToBlob and blobToVector roundtrip" <| do + let v = VS.fromList [0.1, 0.2, 0.3, 0.4, 0.5 :: Float] + blob = vectorToBlob v + v' = blobToVector blob + VS.length v Test.@=? VS.length v' + VS.toList v Test.@=? VS.toList v', + Test.unit "rememberTool has correct schema" <| do + let tool = rememberTool "test-user-id" + Engine.toolName tool Test.@=? "remember", + Test.unit "recallTool has correct schema" <| do + let tool = recallTool "test-user-id" + Engine.toolName tool Test.@=? "recall" + ] + +-- | User record for multi-user memory system. +data User = User + { userId :: Text, + userTelegramId :: Maybe Int, + userEmail :: Maybe Text, + userName :: Text, + userCreatedAt :: UTCTime + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON User where + toJSON u = + Aeson.object + [ "id" .= userId u, + "telegram_id" .= userTelegramId u, + "email" .= userEmail u, + "name" .= userName u, + "created_at" .= userCreatedAt u + ] + +instance Aeson.FromJSON User where + parseJSON = + Aeson.withObject "User" <| \v -> + (User </ (v .: "id")) + <*> (v .:? "telegram_id") + <*> (v .:? "email") + <*> (v .: "name") + <*> (v .: "created_at") + +instance SQL.FromRow User where + fromRow = + User + </ SQL.field + <*> SQL.field + <*> SQL.field + <*> SQL.field + <*> SQL.field + +instance SQL.ToRow User where + toRow u = + [ SQL.toField (userId u), + SQL.toField (userTelegramId u), + SQL.toField (userEmail u), + SQL.toField (userName u), + SQL.toField (userCreatedAt u) + ] + +-- | Source information for a memory. +data MemorySource = MemorySource + { sourceAgent :: Text, + sourceSession :: Maybe Text, + sourceContext :: Text + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON MemorySource where + toJSON s = + Aeson.object + [ "agent" .= sourceAgent s, + "session" .= sourceSession s, + "context" .= sourceContext s + ] + +instance Aeson.FromJSON MemorySource where + parseJSON = + Aeson.withObject "MemorySource" <| \v -> + (MemorySource </ (v .: "agent")) + <*> (v .:? "session") + <*> (v .: "context") + +-- | A memory stored in the system. +data Memory = Memory + { memoryId :: Text, + memoryUserId :: Text, + memoryContent :: Text, + memoryEmbedding :: Maybe (VS.Vector Float), + memorySource :: MemorySource, + memoryConfidence :: Double, + memoryCreatedAt :: UTCTime, + memoryLastAccessedAt :: UTCTime, + memoryTags :: [Text] + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON Memory where + toJSON m = + Aeson.object + [ "id" .= memoryId m, + "user_id" .= memoryUserId m, + "content" .= memoryContent m, + "source" .= memorySource m, + "confidence" .= memoryConfidence m, + "created_at" .= memoryCreatedAt m, + "last_accessed_at" .= memoryLastAccessedAt m, + "tags" .= memoryTags m + ] + +instance Aeson.FromJSON Memory where + parseJSON = + Aeson.withObject "Memory" <| \v -> + ( Memory + </ (v .: "id") + ) + <*> (v .: "user_id") + <*> (v .: "content") + <*> pure Nothing + <*> (v .: "source") + <*> (v .:? "confidence" .!= 0.8) + <*> (v .: "created_at") + <*> (v .: "last_accessed_at") + <*> (v .:? "tags" .!= []) + +-- SQLite instances for Memory (partial - embedding handled separately) +instance SQL.FromRow Memory where + fromRow = do + mid <- SQL.field + uid <- SQL.field + content <- SQL.field + embeddingBlob <- SQL.field + agent <- SQL.field + session <- SQL.field + context <- SQL.field + confidence <- SQL.field + createdAt <- SQL.field + lastAccessedAt <- SQL.field + tagsJson <- SQL.field + let embedding = blobToVector </ (embeddingBlob :: Maybe BS.ByteString) + source = MemorySource agent session context + tags = fromMaybe [] ((tagsJson :: Maybe Text) +> (Aeson.decode <. BL.fromStrict <. TE.encodeUtf8)) + pure + Memory + { memoryId = mid, + memoryUserId = uid, + memoryContent = content, + memoryEmbedding = embedding, + memorySource = source, + memoryConfidence = confidence, + memoryCreatedAt = createdAt, + memoryLastAccessedAt = lastAccessedAt, + memoryTags = tags + } + +-- | Get the path to memory.db +getMemoryDbPath :: IO FilePath +getMemoryDbPath = do + maybeEnv <- lookupEnv "MEMORY_DB_PATH" + case maybeEnv of + Just p -> pure p + Nothing -> do + home <- lookupEnv "HOME" + case home of + Just h -> pure (h </> ".local/share/omni/memory.db") + Nothing -> pure "_/memory.db" + +-- | Run an action with the memory database connection. +withMemoryDb :: (SQL.Connection -> IO a) -> IO a +withMemoryDb action = do + dbPath <- getMemoryDbPath + createDirectoryIfMissing True (takeDirectory dbPath) + SQL.withConnection dbPath <| \conn -> do + initMemoryDb conn + action conn + +-- | Initialize the memory database schema. +initMemoryDb :: SQL.Connection -> IO () +initMemoryDb conn = do + SQL.execute_ conn "PRAGMA foreign_keys = ON" + SQL.execute_ + conn + "CREATE TABLE IF NOT EXISTS users (\ + \ id TEXT PRIMARY KEY,\ + \ telegram_id INTEGER UNIQUE,\ + \ email TEXT UNIQUE,\ + \ name TEXT NOT NULL,\ + \ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\ + \)" + SQL.execute_ + conn + "CREATE TABLE IF NOT EXISTS memories (\ + \ id TEXT PRIMARY KEY,\ + \ user_id TEXT NOT NULL REFERENCES users(id),\ + \ content TEXT NOT NULL,\ + \ embedding BLOB,\ + \ source_agent TEXT NOT NULL,\ + \ source_session TEXT,\ + \ source_context TEXT,\ + \ confidence REAL DEFAULT 0.8,\ + \ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\ + \ last_accessed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\ + \ tags TEXT\ + \)" + SQL.execute_ + conn + "CREATE INDEX IF NOT EXISTS idx_memories_user ON memories(user_id)" + SQL.execute_ + conn + "CREATE INDEX IF NOT EXISTS idx_memories_agent ON memories(source_agent)" + +-- | Create a new user. +createUser :: Text -> Maybe Int -> IO User +createUser name telegramId = do + uuid <- UUID.nextRandom + now <- getCurrentTime + let user = + User + { userId = UUID.toText uuid, + userTelegramId = telegramId, + userEmail = Nothing, + userName = name, + userCreatedAt = now + } + withMemoryDb <| \conn -> + SQL.execute + conn + "INSERT INTO users (id, telegram_id, email, name, created_at) VALUES (?, ?, ?, ?, ?)" + user + pure user + +-- | Get a user by ID. +getUser :: Text -> IO (Maybe User) +getUser uid = + withMemoryDb <| \conn -> do + results <- SQL.query conn "SELECT id, telegram_id, email, name, created_at FROM users WHERE id = ?" (SQL.Only uid) + pure (listToMaybe results) + +-- | Get a user by Telegram ID. +getUserByTelegramId :: Int -> IO (Maybe User) +getUserByTelegramId tid = + withMemoryDb <| \conn -> do + results <- SQL.query conn "SELECT id, telegram_id, email, name, created_at FROM users WHERE telegram_id = ?" (SQL.Only tid) + pure (listToMaybe results) + +-- | Get or create a user by Telegram ID. +getOrCreateUserByTelegramId :: Int -> Text -> IO User +getOrCreateUserByTelegramId tid name = do + existing <- getUserByTelegramId tid + case existing of + Just user -> pure user + Nothing -> createUser name (Just tid) + +-- | Store a memory for a user. +storeMemory :: Text -> Text -> MemorySource -> IO Memory +storeMemory uid content source = storeMemoryWithTags uid content source [] + +-- | Store a memory with tags. +storeMemoryWithTags :: Text -> Text -> MemorySource -> [Text] -> IO Memory +storeMemoryWithTags uid content source tags = do + uuid <- UUID.nextRandom + now <- getCurrentTime + embedding <- embedText content + let mem = + Memory + { memoryId = UUID.toText uuid, + memoryUserId = uid, + memoryContent = content, + memoryEmbedding = either (const Nothing) Just embedding, + memorySource = source, + memoryConfidence = 0.8, + memoryCreatedAt = now, + memoryLastAccessedAt = now, + memoryTags = tags + } + withMemoryDb <| \conn -> + SQL.execute + conn + "INSERT INTO memories (id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + ( ( memoryId mem, + memoryUserId mem, + memoryContent mem, + vectorToBlob </ memoryEmbedding mem, + sourceAgent (memorySource mem), + sourceSession (memorySource mem), + sourceContext (memorySource mem) + ) + SQL.:. ( memoryConfidence mem, + memoryCreatedAt mem, + memoryLastAccessedAt mem, + TE.decodeUtf8 (BL.toStrict (Aeson.encode (memoryTags mem))) + ) + ) + pure mem + +-- | Recall memories for a user using semantic similarity. +recallMemories :: Text -> Text -> Int -> IO [Memory] +recallMemories uid query limit = do + queryEmbedding <- embedText query + case queryEmbedding of + Left _ -> recallMemoriesByRecency uid limit + Right qEmb -> do + allMems <- getAllMemoriesForUser uid + let scored = + [ (m, cosineSimilarity qEmb emb) + | m <- allMems, + Just emb <- [memoryEmbedding m] + ] + sorted = List.sortBy (\(_, s1) (_, s2) -> compare s2 s1) scored + topN = take limit sorted + now <- getCurrentTime + traverse_ (updateMemoryAccess now <. memoryId <. fst) topN + pure (map fst topN) + +-- | Recall memories by recency (fallback when embedding fails). +recallMemoriesByRecency :: Text -> Int -> IO [Memory] +recallMemoriesByRecency uid limit = + withMemoryDb <| \conn -> do + SQL.query + conn + "SELECT id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags \ + \FROM memories WHERE user_id = ? ORDER BY last_accessed_at DESC LIMIT ?" + (uid, limit) + +-- | Get all memories for a user. +getAllMemoriesForUser :: Text -> IO [Memory] +getAllMemoriesForUser uid = + withMemoryDb <| \conn -> + SQL.query + conn + "SELECT id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags \ + \FROM memories WHERE user_id = ?" + (SQL.Only uid) + +-- | Delete a memory. +forgetMemory :: Text -> IO () +forgetMemory mid = + withMemoryDb <| \conn -> + SQL.execute conn "DELETE FROM memories WHERE id = ?" (SQL.Only mid) + +-- | Update memory's last accessed timestamp. +updateMemoryAccess :: UTCTime -> Text -> IO () +updateMemoryAccess now mid = + withMemoryDb <| \conn -> + SQL.execute conn "UPDATE memories SET last_accessed_at = ? WHERE id = ?" (now, mid) + +-- | Embed text using Ollama's nomic-embed-text model. +embedText :: Text -> IO (Either Text (VS.Vector Float)) +embedText content = do + ollamaUrl <- fromMaybe "http://localhost:11434" </ lookupEnv "OLLAMA_URL" + let url = ollamaUrl <> "/api/embeddings" + req0 <- HTTP.parseRequest url + let body = + Aeson.object + [ "model" .= ("nomic-embed-text" :: Text), + "prompt" .= content + ] + req = + HTTP.setRequestMethod "POST" + <| HTTP.setRequestHeader "Content-Type" ["application/json"] + <| HTTP.setRequestBodyLBS (Aeson.encode body) + <| req0 + result <- try (HTTP.httpLBS req) + case result of + Left (e :: SomeException) -> + pure (Left ("Embedding request failed: " <> tshow e)) + Right response -> do + let status = HTTP.getResponseStatusCode response + if status >= 200 && status < 300 + then case Aeson.decode (HTTP.getResponseBody response) of + Just (Aeson.Object obj) -> case KeyMap.lookup "embedding" obj of + Just (Aeson.Array arr) -> + let floats = [f | Aeson.Number n <- toList arr, let f = realToFrac n] + in pure (Right (VS.fromList floats)) + _ -> pure (Left "No embedding in response") + _ -> pure (Left "Failed to parse embedding response") + else pure (Left ("Embedding HTTP error: " <> tshow status)) + +-- | Convert a vector to a blob for storage. +vectorToBlob :: VS.Vector Float -> BS.ByteString +vectorToBlob v = + let bytes = VS.unsafeCast v :: VS.Vector Word8 + in BS.pack (VS.toList bytes) + +-- | Convert a blob back to a vector. +blobToVector :: BS.ByteString -> VS.Vector Float +blobToVector bs = + let bytes = VS.fromList (BS.unpack bs) :: VS.Vector Word8 + in VS.unsafeCast bytes + +-- | Calculate cosine similarity between two vectors. +cosineSimilarity :: VS.Vector Float -> VS.Vector Float -> Float +cosineSimilarity v1 v2 + | VS.length v1 /= VS.length v2 = 0 + | otherwise = + let dot = VS.sum (VS.zipWith (*) v1 v2) + mag1 = sqrt (VS.sum (VS.map (\x -> x * x) v1)) + mag2 = sqrt (VS.sum (VS.map (\x -> x * x) v2)) + in if mag1 == 0 || mag2 == 0 then 0 else dot / (mag1 * mag2) + +-- | Format memories for inclusion in a prompt. +formatMemoriesForPrompt :: [Memory] -> Text +formatMemoriesForPrompt [] = "No prior context available." +formatMemoriesForPrompt mems = + Text.unlines + [ "Known context about this user:", + "", + Text.unlines (map formatMem mems) + ] + where + formatMem m = + "- " <> memoryContent m <> " (via " <> sourceAgent (memorySource m) <> ")" + +-- | Run an agent with memory context. +-- Recalls relevant memories for the user and injects them into the system prompt. +runAgentWithMemory :: + User -> + Engine.EngineConfig -> + Engine.AgentConfig -> + Text -> + IO (Either Text Engine.AgentResult) +runAgentWithMemory user engineCfg agentCfg userPrompt = do + memories <- recallMemories (userId user) userPrompt 10 + let memoryContext = formatMemoriesForPrompt memories + enhancedPrompt = + Engine.agentSystemPrompt agentCfg + <> "\n\n## Known about this user\n" + <> memoryContext + enhancedConfig = + agentCfg + { Engine.agentSystemPrompt = enhancedPrompt, + Engine.agentTools = + Engine.agentTools agentCfg + <> [rememberTool (userId user), recallTool (userId user)] + } + Engine.runAgent engineCfg enhancedConfig userPrompt + +-- | Tool for agents to store memories about users. +rememberTool :: Text -> Engine.Tool +rememberTool uid = + Engine.Tool + { Engine.toolName = "remember", + Engine.toolDescription = + "Store a piece of information about the user for future reference. " + <> "Use this when the user shares personal facts, preferences, or context " + <> "that would be useful to recall in future conversations.", + Engine.toolJsonSchema = + Aeson.object + [ "type" .= ("object" :: Text), + "properties" + .= Aeson.object + [ "content" + .= Aeson.object + [ "type" .= ("string" :: Text), + "description" .= ("The information to remember about the user" :: Text) + ], + "context" + .= Aeson.object + [ "type" .= ("string" :: Text), + "description" .= ("How/why this was learned (e.g., 'user mentioned in chat')" :: Text) + ], + "tags" + .= Aeson.object + [ "type" .= ("array" :: Text), + "items" .= Aeson.object ["type" .= ("string" :: Text)], + "description" .= ("Optional tags for categorization" :: Text) + ] + ], + "required" .= (["content", "context"] :: [Text]) + ], + Engine.toolExecute = executeRemember uid + } + +executeRemember :: Text -> Aeson.Value -> IO Aeson.Value +executeRemember uid v = + case Aeson.fromJSON v of + Aeson.Error e -> pure (Aeson.object ["error" .= Text.pack e]) + Aeson.Success (args :: RememberArgs) -> do + let source = + MemorySource + { sourceAgent = "agent", + sourceSession = Nothing, + sourceContext = rememberContext args + } + mem <- storeMemoryWithTags uid (rememberContent args) source (rememberTags args) + pure + ( Aeson.object + [ "success" .= True, + "memory_id" .= memoryId mem, + "message" .= ("Remembered: " <> rememberContent args) + ] + ) + +-- | Tool for agents to recall memories about users. +recallTool :: Text -> Engine.Tool +recallTool uid = + Engine.Tool + { Engine.toolName = "recall", + Engine.toolDescription = + "Search your memory for information about the user. " + <> "Use this to retrieve previously stored facts, preferences, or context.", + Engine.toolJsonSchema = + Aeson.object + [ "type" .= ("object" :: Text), + "properties" + .= Aeson.object + [ "query" + .= Aeson.object + [ "type" .= ("string" :: Text), + "description" .= ("What to search for in memory" :: Text) + ], + "limit" + .= Aeson.object + [ "type" .= ("integer" :: Text), + "description" .= ("Maximum memories to return (default: 5)" :: Text) + ] + ], + "required" .= (["query"] :: [Text]) + ], + Engine.toolExecute = executeRecall uid + } + +executeRecall :: Text -> Aeson.Value -> IO Aeson.Value +executeRecall uid v = + case Aeson.fromJSON v of + Aeson.Error e -> pure (Aeson.object ["error" .= Text.pack e]) + Aeson.Success (args :: RecallArgs) -> do + mems <- recallMemories uid (recallQuery args) (recallLimit args) + pure + ( Aeson.object + [ "success" .= True, + "count" .= length mems, + "memories" + .= map + ( \m -> + Aeson.object + [ "content" .= memoryContent m, + "confidence" .= memoryConfidence m, + "source" .= sourceAgent (memorySource m), + "tags" .= memoryTags m + ] + ) + mems + ] + ) + +-- Helper for parsing remember args +data RememberArgs = RememberArgs + { rememberContent :: Text, + rememberContext :: Text, + rememberTags :: [Text] + } + deriving (Generic) + +instance Aeson.FromJSON RememberArgs where + parseJSON = + Aeson.withObject "RememberArgs" <| \v -> + (RememberArgs </ (v .: "content")) + <*> (v .:? "context" .!= "agent observation") + <*> (v .:? "tags" .!= []) + +data RecallArgs = RecallArgs + { recallQuery :: Text, + recallLimit :: Int + } + deriving (Generic) + +instance Aeson.FromJSON RecallArgs where + parseJSON = + Aeson.withObject "RecallArgs" <| \v -> + (RecallArgs </ (v .: "query")) + <*> (v .:? "limit" .!= 5) |
