diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-13 11:37:10 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-13 11:37:10 -0500 |
| commit | 0936eb15144e2fc15b073e989d6c5d700dc47435 (patch) | |
| tree | 20e77aa1205dea43f398bf54deac12759fd54b7c /Omni/Agent | |
| parent | ed629a3335c6c5a172322a8d7387f0c6990b0ae5 (diff) | |
Add knowledge graph with typed relations to Memory module
- Add RelationType with 6 relation types
- Add MemoryLink type and memory_links table
- Add graph functions: linkMemories, getMemoryLinks, queryGraph
- Add link_memories and query_graph agent tools
- Wire up graph tools to Telegram bot
- Include memory ID in recall results for linking
- Fix streaming usage parsing for cost tracking
Closes t-255
Amp-Thread-ID: https://ampcode.com/threads/T-019b181f-d6cd-70de-8857-c445baef7508
Co-authored-by: Amp <amp@ampcode.com>
Diffstat (limited to 'Omni/Agent')
| -rw-r--r-- | Omni/Agent/Memory.hs | 410 | ||||
| -rw-r--r-- | Omni/Agent/Provider.hs | 18 | ||||
| -rw-r--r-- | Omni/Agent/Telegram.hs | 4 |
3 files changed, 419 insertions, 13 deletions
diff --git a/Omni/Agent/Memory.hs b/Omni/Agent/Memory.hs index 136ac1e..0a050b7 100644 --- a/Omni/Agent/Memory.hs +++ b/Omni/Agent/Memory.hs @@ -29,6 +29,8 @@ module Omni.Agent.Memory ConversationMessage (..), ConversationSummary (..), MessageRole (..), + RelationType (..), + MemoryLink (..), -- * User Management createUser, @@ -43,6 +45,12 @@ module Omni.Agent.Memory getAllMemoriesForUser, updateMemoryAccess, + -- * Knowledge Graph + linkMemories, + getMemoryLinks, + getLinkedMemories, + queryGraph, + -- * Conversation History saveMessage, getRecentMessages, @@ -56,6 +64,8 @@ module Omni.Agent.Memory -- * Agent Integration rememberTool, recallTool, + linkMemoriesTool, + queryGraphTool, formatMemoriesForPrompt, runAgentWithMemory, @@ -198,7 +208,38 @@ test = Engine.toolName tool Test.@=? "remember", Test.unit "recallTool has correct schema" <| do let tool = recallTool "test-user-id" - Engine.toolName tool Test.@=? "recall" + Engine.toolName tool Test.@=? "recall", + Test.unit "RelationType JSON roundtrip" <| do + let types = [Contradicts, Supports, Elaborates, Supersedes, Related, ContingentOn] + forM_ types <| \rt -> + case Aeson.decode (Aeson.encode rt) of + Nothing -> Test.assertFailure ("Failed to decode RelationType: " <> show rt) + Just decoded -> decoded Test.@=? rt, + Test.unit "MemoryLink JSON roundtrip" <| do + now <- getCurrentTime + let memLink = + MemoryLink + { linkFromMemoryId = "mem-1", + linkToMemoryId = "mem-2", + linkRelationType = Contradicts, + linkCreatedAt = now + } + case Aeson.decode (Aeson.encode memLink) of + Nothing -> Test.assertFailure "Failed to decode MemoryLink" + Just decoded -> do + linkFromMemoryId decoded Test.@=? "mem-1" + linkToMemoryId decoded Test.@=? "mem-2" + linkRelationType decoded Test.@=? Contradicts, + Test.unit "relationTypeToText and textToRelationType roundtrip" <| do + let types = [Contradicts, Supports, Elaborates, Supersedes, Related, ContingentOn] + forM_ types <| \rt -> + textToRelationType (relationTypeToText rt) Test.@=? Just rt, + Test.unit "linkMemoriesTool has correct schema" <| do + let tool = linkMemoriesTool "test-user-id" + Engine.toolName tool Test.@=? "link_memories", + Test.unit "queryGraphTool has correct schema" <| do + let tool = queryGraphTool "test-user-id" + Engine.toolName tool Test.@=? "query_graph" ] -- | User record for multi-user memory system. @@ -433,6 +474,93 @@ instance SQL.FromRow ConversationSummary where <*> SQL.field <*> SQL.field +-- | Relation types for the knowledge graph. +data RelationType + = Contradicts + | Supports + | Elaborates + | Supersedes + | Related + | ContingentOn + deriving (Show, Eq, Generic, Ord) + +instance Aeson.ToJSON RelationType where + toJSON Contradicts = Aeson.String "contradicts" + toJSON Supports = Aeson.String "supports" + toJSON Elaborates = Aeson.String "elaborates" + toJSON Supersedes = Aeson.String "supersedes" + toJSON Related = Aeson.String "related" + toJSON ContingentOn = Aeson.String "contingent_on" + +instance Aeson.FromJSON RelationType where + parseJSON = + Aeson.withText "RelationType" <| \case + "contradicts" -> pure Contradicts + "supports" -> pure Supports + "elaborates" -> pure Elaborates + "supersedes" -> pure Supersedes + "related" -> pure Related + "contingent_on" -> pure ContingentOn + _ -> empty + +relationTypeToText :: RelationType -> Text +relationTypeToText Contradicts = "contradicts" +relationTypeToText Supports = "supports" +relationTypeToText Elaborates = "elaborates" +relationTypeToText Supersedes = "supersedes" +relationTypeToText Related = "related" +relationTypeToText ContingentOn = "contingent_on" + +textToRelationType :: Text -> Maybe RelationType +textToRelationType "contradicts" = Just Contradicts +textToRelationType "supports" = Just Supports +textToRelationType "elaborates" = Just Elaborates +textToRelationType "supersedes" = Just Supersedes +textToRelationType "related" = Just Related +textToRelationType "contingent_on" = Just ContingentOn +textToRelationType _ = Nothing + +-- | A link between two memories in the knowledge graph. +data MemoryLink = MemoryLink + { linkFromMemoryId :: Text, + linkToMemoryId :: Text, + linkRelationType :: RelationType, + linkCreatedAt :: UTCTime + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON MemoryLink where + toJSON l = + Aeson.object + [ "from_memory_id" .= linkFromMemoryId l, + "to_memory_id" .= linkToMemoryId l, + "relation_type" .= linkRelationType l, + "created_at" .= linkCreatedAt l + ] + +instance Aeson.FromJSON MemoryLink where + parseJSON = + Aeson.withObject "MemoryLink" <| \v -> + (MemoryLink </ (v .: "from_memory_id")) + <*> (v .: "to_memory_id") + <*> (v .: "relation_type") + <*> (v .: "created_at") + +instance SQL.FromRow MemoryLink where + fromRow = do + fromId <- SQL.field + toId <- SQL.field + relTypeText <- SQL.field + createdAt <- SQL.field + let relType = fromMaybe Related (textToRelationType relTypeText) + pure + MemoryLink + { linkFromMemoryId = fromId, + linkToMemoryId = toId, + linkRelationType = relType, + linkCreatedAt = createdAt + } + -- | Get the path to memory.db getMemoryDbPath :: IO FilePath getMemoryDbPath = do @@ -549,6 +677,24 @@ initMemoryDb conn = do SQL.execute_ conn "CREATE INDEX IF NOT EXISTS idx_todos_due ON todos(user_id, due_date)" + SQL.execute_ + conn + "CREATE TABLE IF NOT EXISTS memory_links (\ + \ from_memory_id TEXT NOT NULL REFERENCES memories(id) ON DELETE CASCADE,\ + \ to_memory_id TEXT NOT NULL REFERENCES memories(id) ON DELETE CASCADE,\ + \ relation_type TEXT NOT NULL,\ + \ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\ + \ PRIMARY KEY (from_memory_id, to_memory_id, relation_type)\ + \)" + SQL.execute_ + conn + "CREATE INDEX IF NOT EXISTS idx_memory_links_from ON memory_links(from_memory_id)" + SQL.execute_ + conn + "CREATE INDEX IF NOT EXISTS idx_memory_links_to ON memory_links(to_memory_id)" + SQL.execute_ + conn + "CREATE INDEX IF NOT EXISTS idx_memory_links_type ON memory_links(relation_type)" -- | Migrate conversation_messages to add sender_name column. migrateConversationMessages :: SQL.Connection -> IO () @@ -694,6 +840,91 @@ updateMemoryAccess now mid = withMemoryDb <| \conn -> SQL.execute conn "UPDATE memories SET last_accessed_at = ? WHERE id = ?" (now, mid) +-- | Create a link between two memories. +linkMemories :: Text -> Text -> RelationType -> IO MemoryLink +linkMemories fromId toId relType = do + now <- getCurrentTime + withMemoryDb <| \conn -> + SQL.execute + conn + "INSERT OR REPLACE INTO memory_links (from_memory_id, to_memory_id, relation_type, created_at) VALUES (?, ?, ?, ?)" + (fromId, toId, relationTypeToText relType, now) + pure + MemoryLink + { linkFromMemoryId = fromId, + linkToMemoryId = toId, + linkRelationType = relType, + linkCreatedAt = now + } + +-- | Get all links from a memory. +getMemoryLinks :: Text -> IO [MemoryLink] +getMemoryLinks memId = + withMemoryDb <| \conn -> + SQL.query + conn + "SELECT from_memory_id, to_memory_id, relation_type, created_at \ + \FROM memory_links WHERE from_memory_id = ? OR to_memory_id = ?" + (memId, memId) + +-- | Get memories linked to a given memory with their content. +getLinkedMemories :: Text -> Maybe RelationType -> IO [(MemoryLink, Memory)] +getLinkedMemories memId maybeRelType = do + links <- getMemoryLinks memId + let filteredLinks = case maybeRelType of + Nothing -> links + Just rt -> filter (\l -> linkRelationType l == rt) links + mems <- traverse loadMemory filteredLinks + pure [(l, m) | (l, Just m) <- zip filteredLinks mems] + where + loadMemory memLink = do + let targetId = + if linkFromMemoryId memLink == memId + then linkToMemoryId memLink + else linkFromMemoryId memLink + withMemoryDb <| \conn -> do + results <- + SQL.query + conn + "SELECT id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags \ + \FROM memories WHERE id = ?" + (SQL.Only targetId) + pure (listToMaybe results) + +-- | Query the knowledge graph by traversing links from a starting memory. +-- Returns all memories reachable within the given depth. +queryGraph :: Text -> Int -> Maybe RelationType -> IO [(Memory, [MemoryLink])] +queryGraph startMemId maxDepth maybeRelType = do + startMem <- getMemoryById startMemId + case startMem of + Nothing -> pure [] + Just mem -> go [startMemId] [(mem, [])] 0 + where + go :: [Text] -> [(Memory, [MemoryLink])] -> Int -> IO [(Memory, [MemoryLink])] + go _ acc depth | depth >= maxDepth = pure acc + go visitedIds acc depth = do + let currentIds = map (memoryId <. fst) acc + newIds = filter (`notElem` visitedIds) currentIds + if null newIds + then pure acc + else do + newLinked <- concat </ traverse (`getLinkedMemories` maybeRelType) newIds + let newMems = [(m, [l]) | (l, m) <- newLinked, memoryId m `notElem` visitedIds] + newVisited = visitedIds <> map (memoryId <. fst) newMems + go newVisited (acc <> newMems) (depth + 1) + +-- | Get a memory by ID. +getMemoryById :: Text -> IO (Maybe Memory) +getMemoryById memId = + withMemoryDb <| \conn -> do + results <- + SQL.query + conn + "SELECT id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags \ + \FROM memories WHERE id = ?" + (SQL.Only memId) + pure (listToMaybe results) + -- | Embed text using Ollama's nomic-embed-text model. embedText :: Text -> IO (Either Text (VS.Vector Float)) embedText content = do @@ -781,7 +1012,11 @@ runAgentWithMemory user engineCfg agentCfg userPrompt = do { Engine.agentSystemPrompt = enhancedPrompt, Engine.agentTools = Engine.agentTools agentCfg - <> [rememberTool (userId user), recallTool (userId user)] + <> [ rememberTool (userId user), + recallTool (userId user), + linkMemoriesTool (userId user), + queryGraphTool (userId user) + ] } Engine.runAgent engineCfg enhancedConfig userPrompt @@ -884,7 +1119,8 @@ executeRecall uid v = .= map ( \m -> Aeson.object - [ "content" .= memoryContent m, + [ "id" .= memoryId m, + "content" .= memoryContent m, "confidence" .= memoryConfidence m, "source" .= sourceAgent (memorySource m), "tags" .= memoryTags m @@ -921,6 +1157,174 @@ instance Aeson.FromJSON RecallArgs where (RecallArgs </ (v .: "query")) <*> (v .:? "limit" .!= 5) +-- | Tool for agents to link memories in the knowledge graph. +linkMemoriesTool :: Text -> Engine.Tool +linkMemoriesTool _uid = + Engine.Tool + { Engine.toolName = "link_memories", + Engine.toolDescription = + "Create a typed relationship between two memories. " + <> "Use this to connect related information. Relation types:\n" + <> "- contradicts: conflicting information\n" + <> "- supports: evidence that reinforces another memory\n" + <> "- elaborates: adds detail to an existing memory\n" + <> "- supersedes: newer info replaces older\n" + <> "- related: general topical connection\n" + <> "- contingent_on: depends on another fact being true", + Engine.toolJsonSchema = + Aeson.object + [ "type" .= ("object" :: Text), + "properties" + .= Aeson.object + [ "from_memory_id" + .= Aeson.object + [ "type" .= ("string" :: Text), + "description" .= ("ID of the source memory" :: Text) + ], + "to_memory_id" + .= Aeson.object + [ "type" .= ("string" :: Text), + "description" .= ("ID of the target memory" :: Text) + ], + "relation_type" + .= Aeson.object + [ "type" .= ("string" :: Text), + "enum" .= (["contradicts", "supports", "elaborates", "supersedes", "related", "contingent_on"] :: [Text]), + "description" .= ("Type of relationship between memories" :: Text) + ] + ], + "required" .= (["from_memory_id", "to_memory_id", "relation_type"] :: [Text]) + ], + Engine.toolExecute = executeLinkMemories + } + +executeLinkMemories :: Aeson.Value -> IO Aeson.Value +executeLinkMemories v = + case Aeson.fromJSON v of + Aeson.Error e -> pure (Aeson.object ["error" .= Text.pack e]) + Aeson.Success (args :: LinkMemoriesArgs) -> do + case textToRelationType (linkArgsRelationType args) of + Nothing -> + pure + ( Aeson.object + [ "success" .= False, + "error" .= ("Invalid relation type: " <> linkArgsRelationType args) + ] + ) + Just relType -> do + memLink <- linkMemories (linkArgsFromId args) (linkArgsToId args) relType + pure + ( Aeson.object + [ "success" .= True, + "message" + .= ( "Linked memory " + <> linkFromMemoryId memLink + <> " -> " + <> linkToMemoryId memLink + <> " (" + <> relationTypeToText (linkRelationType memLink) + <> ")" + ) + ] + ) + +data LinkMemoriesArgs = LinkMemoriesArgs + { linkArgsFromId :: Text, + linkArgsToId :: Text, + linkArgsRelationType :: Text + } + deriving (Generic) + +instance Aeson.FromJSON LinkMemoriesArgs where + parseJSON = + Aeson.withObject "LinkMemoriesArgs" <| \v -> + (LinkMemoriesArgs </ (v .: "from_memory_id")) + <*> (v .: "to_memory_id") + <*> (v .: "relation_type") + +-- | Tool for agents to query the memory knowledge graph. +queryGraphTool :: Text -> Engine.Tool +queryGraphTool _uid = + Engine.Tool + { Engine.toolName = "query_graph", + Engine.toolDescription = + "Explore the knowledge graph to find related memories. " + <> "Given a starting memory, traverse links to find connected memories. " + <> "Useful for understanding context and finding contradictions or supporting evidence.", + Engine.toolJsonSchema = + Aeson.object + [ "type" .= ("object" :: Text), + "properties" + .= Aeson.object + [ "memory_id" + .= Aeson.object + [ "type" .= ("string" :: Text), + "description" .= ("ID of the memory to start from" :: Text) + ], + "depth" + .= Aeson.object + [ "type" .= ("integer" :: Text), + "description" .= ("How many link hops to traverse (default: 2)" :: Text) + ], + "relation_type" + .= Aeson.object + [ "type" .= ("string" :: Text), + "enum" .= (["contradicts", "supports", "elaborates", "supersedes", "related", "contingent_on"] :: [Text]), + "description" .= ("Optional: filter by relation type" :: Text) + ] + ], + "required" .= (["memory_id"] :: [Text]) + ], + Engine.toolExecute = executeQueryGraph + } + +executeQueryGraph :: Aeson.Value -> IO Aeson.Value +executeQueryGraph v = + case Aeson.fromJSON v of + Aeson.Error e -> pure (Aeson.object ["error" .= Text.pack e]) + Aeson.Success (args :: QueryGraphArgs) -> do + let maybeRelType = queryArgsRelationType args +> textToRelationType + results <- queryGraph (queryArgsMemoryId args) (queryArgsDepth args) maybeRelType + pure + ( Aeson.object + [ "success" .= True, + "count" .= length results, + "memories" + .= map + ( \(m, links) -> + Aeson.object + [ "id" .= memoryId m, + "content" .= memoryContent m, + "links" + .= map + ( \l -> + Aeson.object + [ "from" .= linkFromMemoryId l, + "to" .= linkToMemoryId l, + "relation" .= linkRelationType l + ] + ) + links + ] + ) + results + ] + ) + +data QueryGraphArgs = QueryGraphArgs + { queryArgsMemoryId :: Text, + queryArgsDepth :: Int, + queryArgsRelationType :: Maybe Text + } + deriving (Generic) + +instance Aeson.FromJSON QueryGraphArgs where + parseJSON = + Aeson.withObject "QueryGraphArgs" <| \v -> + (QueryGraphArgs </ (v .: "memory_id")) + <*> (v .:? "depth" .!= 2) + <*> (v .:? "relation_type") + -- | Estimate token count for text (rough: ~4 chars per token). estimateTokens :: Text -> Int estimateTokens t = max 1 (Text.length t `div` 4) diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs index fd6920d..1bb4f04 100644 --- a/Omni/Agent/Provider.hs +++ b/Omni/Agent/Provider.hs @@ -589,6 +589,11 @@ parseStreamChunk obj = do _ -> "Unknown error" Just (StreamError errMsg) _ -> do + let usageChunk = case KeyMap.lookup "usage" obj of + Just usageVal -> case Aeson.fromJSON usageVal of + Aeson.Success usage -> Just (StreamDone (ChatResult (Message Assistant "" Nothing Nothing) (Just usage))) + _ -> Nothing + _ -> Nothing case KeyMap.lookup "choices" obj of Just (Aeson.Array choices) | not (null choices) -> do case toList choices of @@ -603,15 +608,10 @@ parseStreamChunk obj = do | not (null tcs) -> parseToolCallDelta (toList tcs) _ -> Nothing - contentChunk <|> toolCallChunk - _ -> Nothing - _ -> Nothing - _ -> do - case KeyMap.lookup "usage" obj of - Just usageVal -> case Aeson.fromJSON usageVal of - Aeson.Success usage -> Just (StreamDone (ChatResult (Message Assistant "" Nothing Nothing) (Just usage))) - _ -> Nothing - _ -> Nothing + contentChunk <|> toolCallChunk <|> usageChunk + _ -> usageChunk + _ -> usageChunk + _ -> usageChunk parseToolCallDelta :: [Aeson.Value] -> Maybe StreamChunk parseToolCallDelta [] = Nothing diff --git a/Omni/Agent/Telegram.hs b/Omni/Agent/Telegram.hs index 418e589..091ad11 100644 --- a/Omni/Agent/Telegram.hs +++ b/Omni/Agent/Telegram.hs @@ -602,7 +602,9 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe let memoryTools = [ Memory.rememberTool uid, - Memory.recallTool uid + Memory.recallTool uid, + Memory.linkMemoriesTool uid, + Memory.queryGraphTool uid ] searchTools = case Types.tgKagiApiKey tgConfig of Just kagiKey -> [WebSearch.webSearchTool kagiKey] |
