summaryrefslogtreecommitdiff
path: root/Omni
diff options
context:
space:
mode:
authorBen Sima <ben@bensima.com>2025-12-13 11:37:10 -0500
committerBen Sima <ben@bensima.com>2025-12-13 11:37:10 -0500
commit0936eb15144e2fc15b073e989d6c5d700dc47435 (patch)
tree20e77aa1205dea43f398bf54deac12759fd54b7c /Omni
parented629a3335c6c5a172322a8d7387f0c6990b0ae5 (diff)
Add knowledge graph with typed relations to Memory module
- Add RelationType with 6 relation types - Add MemoryLink type and memory_links table - Add graph functions: linkMemories, getMemoryLinks, queryGraph - Add link_memories and query_graph agent tools - Wire up graph tools to Telegram bot - Include memory ID in recall results for linking - Fix streaming usage parsing for cost tracking Closes t-255 Amp-Thread-ID: https://ampcode.com/threads/T-019b181f-d6cd-70de-8857-c445baef7508 Co-authored-by: Amp <amp@ampcode.com>
Diffstat (limited to 'Omni')
-rw-r--r--Omni/Agent/Memory.hs410
-rw-r--r--Omni/Agent/Provider.hs18
-rw-r--r--Omni/Agent/Telegram.hs4
3 files changed, 419 insertions, 13 deletions
diff --git a/Omni/Agent/Memory.hs b/Omni/Agent/Memory.hs
index 136ac1e..0a050b7 100644
--- a/Omni/Agent/Memory.hs
+++ b/Omni/Agent/Memory.hs
@@ -29,6 +29,8 @@ module Omni.Agent.Memory
ConversationMessage (..),
ConversationSummary (..),
MessageRole (..),
+ RelationType (..),
+ MemoryLink (..),
-- * User Management
createUser,
@@ -43,6 +45,12 @@ module Omni.Agent.Memory
getAllMemoriesForUser,
updateMemoryAccess,
+ -- * Knowledge Graph
+ linkMemories,
+ getMemoryLinks,
+ getLinkedMemories,
+ queryGraph,
+
-- * Conversation History
saveMessage,
getRecentMessages,
@@ -56,6 +64,8 @@ module Omni.Agent.Memory
-- * Agent Integration
rememberTool,
recallTool,
+ linkMemoriesTool,
+ queryGraphTool,
formatMemoriesForPrompt,
runAgentWithMemory,
@@ -198,7 +208,38 @@ test =
Engine.toolName tool Test.@=? "remember",
Test.unit "recallTool has correct schema" <| do
let tool = recallTool "test-user-id"
- Engine.toolName tool Test.@=? "recall"
+ Engine.toolName tool Test.@=? "recall",
+ Test.unit "RelationType JSON roundtrip" <| do
+ let types = [Contradicts, Supports, Elaborates, Supersedes, Related, ContingentOn]
+ forM_ types <| \rt ->
+ case Aeson.decode (Aeson.encode rt) of
+ Nothing -> Test.assertFailure ("Failed to decode RelationType: " <> show rt)
+ Just decoded -> decoded Test.@=? rt,
+ Test.unit "MemoryLink JSON roundtrip" <| do
+ now <- getCurrentTime
+ let memLink =
+ MemoryLink
+ { linkFromMemoryId = "mem-1",
+ linkToMemoryId = "mem-2",
+ linkRelationType = Contradicts,
+ linkCreatedAt = now
+ }
+ case Aeson.decode (Aeson.encode memLink) of
+ Nothing -> Test.assertFailure "Failed to decode MemoryLink"
+ Just decoded -> do
+ linkFromMemoryId decoded Test.@=? "mem-1"
+ linkToMemoryId decoded Test.@=? "mem-2"
+ linkRelationType decoded Test.@=? Contradicts,
+ Test.unit "relationTypeToText and textToRelationType roundtrip" <| do
+ let types = [Contradicts, Supports, Elaborates, Supersedes, Related, ContingentOn]
+ forM_ types <| \rt ->
+ textToRelationType (relationTypeToText rt) Test.@=? Just rt,
+ Test.unit "linkMemoriesTool has correct schema" <| do
+ let tool = linkMemoriesTool "test-user-id"
+ Engine.toolName tool Test.@=? "link_memories",
+ Test.unit "queryGraphTool has correct schema" <| do
+ let tool = queryGraphTool "test-user-id"
+ Engine.toolName tool Test.@=? "query_graph"
]
-- | User record for multi-user memory system.
@@ -433,6 +474,93 @@ instance SQL.FromRow ConversationSummary where
<*> SQL.field
<*> SQL.field
+-- | Relation types for the knowledge graph.
+data RelationType
+ = Contradicts
+ | Supports
+ | Elaborates
+ | Supersedes
+ | Related
+ | ContingentOn
+ deriving (Show, Eq, Generic, Ord)
+
+instance Aeson.ToJSON RelationType where
+ toJSON Contradicts = Aeson.String "contradicts"
+ toJSON Supports = Aeson.String "supports"
+ toJSON Elaborates = Aeson.String "elaborates"
+ toJSON Supersedes = Aeson.String "supersedes"
+ toJSON Related = Aeson.String "related"
+ toJSON ContingentOn = Aeson.String "contingent_on"
+
+instance Aeson.FromJSON RelationType where
+ parseJSON =
+ Aeson.withText "RelationType" <| \case
+ "contradicts" -> pure Contradicts
+ "supports" -> pure Supports
+ "elaborates" -> pure Elaborates
+ "supersedes" -> pure Supersedes
+ "related" -> pure Related
+ "contingent_on" -> pure ContingentOn
+ _ -> empty
+
+relationTypeToText :: RelationType -> Text
+relationTypeToText Contradicts = "contradicts"
+relationTypeToText Supports = "supports"
+relationTypeToText Elaborates = "elaborates"
+relationTypeToText Supersedes = "supersedes"
+relationTypeToText Related = "related"
+relationTypeToText ContingentOn = "contingent_on"
+
+textToRelationType :: Text -> Maybe RelationType
+textToRelationType "contradicts" = Just Contradicts
+textToRelationType "supports" = Just Supports
+textToRelationType "elaborates" = Just Elaborates
+textToRelationType "supersedes" = Just Supersedes
+textToRelationType "related" = Just Related
+textToRelationType "contingent_on" = Just ContingentOn
+textToRelationType _ = Nothing
+
+-- | A link between two memories in the knowledge graph.
+data MemoryLink = MemoryLink
+ { linkFromMemoryId :: Text,
+ linkToMemoryId :: Text,
+ linkRelationType :: RelationType,
+ linkCreatedAt :: UTCTime
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON MemoryLink where
+ toJSON l =
+ Aeson.object
+ [ "from_memory_id" .= linkFromMemoryId l,
+ "to_memory_id" .= linkToMemoryId l,
+ "relation_type" .= linkRelationType l,
+ "created_at" .= linkCreatedAt l
+ ]
+
+instance Aeson.FromJSON MemoryLink where
+ parseJSON =
+ Aeson.withObject "MemoryLink" <| \v ->
+ (MemoryLink </ (v .: "from_memory_id"))
+ <*> (v .: "to_memory_id")
+ <*> (v .: "relation_type")
+ <*> (v .: "created_at")
+
+instance SQL.FromRow MemoryLink where
+ fromRow = do
+ fromId <- SQL.field
+ toId <- SQL.field
+ relTypeText <- SQL.field
+ createdAt <- SQL.field
+ let relType = fromMaybe Related (textToRelationType relTypeText)
+ pure
+ MemoryLink
+ { linkFromMemoryId = fromId,
+ linkToMemoryId = toId,
+ linkRelationType = relType,
+ linkCreatedAt = createdAt
+ }
+
-- | Get the path to memory.db
getMemoryDbPath :: IO FilePath
getMemoryDbPath = do
@@ -549,6 +677,24 @@ initMemoryDb conn = do
SQL.execute_
conn
"CREATE INDEX IF NOT EXISTS idx_todos_due ON todos(user_id, due_date)"
+ SQL.execute_
+ conn
+ "CREATE TABLE IF NOT EXISTS memory_links (\
+ \ from_memory_id TEXT NOT NULL REFERENCES memories(id) ON DELETE CASCADE,\
+ \ to_memory_id TEXT NOT NULL REFERENCES memories(id) ON DELETE CASCADE,\
+ \ relation_type TEXT NOT NULL,\
+ \ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\
+ \ PRIMARY KEY (from_memory_id, to_memory_id, relation_type)\
+ \)"
+ SQL.execute_
+ conn
+ "CREATE INDEX IF NOT EXISTS idx_memory_links_from ON memory_links(from_memory_id)"
+ SQL.execute_
+ conn
+ "CREATE INDEX IF NOT EXISTS idx_memory_links_to ON memory_links(to_memory_id)"
+ SQL.execute_
+ conn
+ "CREATE INDEX IF NOT EXISTS idx_memory_links_type ON memory_links(relation_type)"
-- | Migrate conversation_messages to add sender_name column.
migrateConversationMessages :: SQL.Connection -> IO ()
@@ -694,6 +840,91 @@ updateMemoryAccess now mid =
withMemoryDb <| \conn ->
SQL.execute conn "UPDATE memories SET last_accessed_at = ? WHERE id = ?" (now, mid)
+-- | Create a link between two memories.
+linkMemories :: Text -> Text -> RelationType -> IO MemoryLink
+linkMemories fromId toId relType = do
+ now <- getCurrentTime
+ withMemoryDb <| \conn ->
+ SQL.execute
+ conn
+ "INSERT OR REPLACE INTO memory_links (from_memory_id, to_memory_id, relation_type, created_at) VALUES (?, ?, ?, ?)"
+ (fromId, toId, relationTypeToText relType, now)
+ pure
+ MemoryLink
+ { linkFromMemoryId = fromId,
+ linkToMemoryId = toId,
+ linkRelationType = relType,
+ linkCreatedAt = now
+ }
+
+-- | Get all links from a memory.
+getMemoryLinks :: Text -> IO [MemoryLink]
+getMemoryLinks memId =
+ withMemoryDb <| \conn ->
+ SQL.query
+ conn
+ "SELECT from_memory_id, to_memory_id, relation_type, created_at \
+ \FROM memory_links WHERE from_memory_id = ? OR to_memory_id = ?"
+ (memId, memId)
+
+-- | Get memories linked to a given memory with their content.
+getLinkedMemories :: Text -> Maybe RelationType -> IO [(MemoryLink, Memory)]
+getLinkedMemories memId maybeRelType = do
+ links <- getMemoryLinks memId
+ let filteredLinks = case maybeRelType of
+ Nothing -> links
+ Just rt -> filter (\l -> linkRelationType l == rt) links
+ mems <- traverse loadMemory filteredLinks
+ pure [(l, m) | (l, Just m) <- zip filteredLinks mems]
+ where
+ loadMemory memLink = do
+ let targetId =
+ if linkFromMemoryId memLink == memId
+ then linkToMemoryId memLink
+ else linkFromMemoryId memLink
+ withMemoryDb <| \conn -> do
+ results <-
+ SQL.query
+ conn
+ "SELECT id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags \
+ \FROM memories WHERE id = ?"
+ (SQL.Only targetId)
+ pure (listToMaybe results)
+
+-- | Query the knowledge graph by traversing links from a starting memory.
+-- Returns all memories reachable within the given depth.
+queryGraph :: Text -> Int -> Maybe RelationType -> IO [(Memory, [MemoryLink])]
+queryGraph startMemId maxDepth maybeRelType = do
+ startMem <- getMemoryById startMemId
+ case startMem of
+ Nothing -> pure []
+ Just mem -> go [startMemId] [(mem, [])] 0
+ where
+ go :: [Text] -> [(Memory, [MemoryLink])] -> Int -> IO [(Memory, [MemoryLink])]
+ go _ acc depth | depth >= maxDepth = pure acc
+ go visitedIds acc depth = do
+ let currentIds = map (memoryId <. fst) acc
+ newIds = filter (`notElem` visitedIds) currentIds
+ if null newIds
+ then pure acc
+ else do
+ newLinked <- concat </ traverse (`getLinkedMemories` maybeRelType) newIds
+ let newMems = [(m, [l]) | (l, m) <- newLinked, memoryId m `notElem` visitedIds]
+ newVisited = visitedIds <> map (memoryId <. fst) newMems
+ go newVisited (acc <> newMems) (depth + 1)
+
+-- | Get a memory by ID.
+getMemoryById :: Text -> IO (Maybe Memory)
+getMemoryById memId =
+ withMemoryDb <| \conn -> do
+ results <-
+ SQL.query
+ conn
+ "SELECT id, user_id, content, embedding, source_agent, source_session, source_context, confidence, created_at, last_accessed_at, tags \
+ \FROM memories WHERE id = ?"
+ (SQL.Only memId)
+ pure (listToMaybe results)
+
-- | Embed text using Ollama's nomic-embed-text model.
embedText :: Text -> IO (Either Text (VS.Vector Float))
embedText content = do
@@ -781,7 +1012,11 @@ runAgentWithMemory user engineCfg agentCfg userPrompt = do
{ Engine.agentSystemPrompt = enhancedPrompt,
Engine.agentTools =
Engine.agentTools agentCfg
- <> [rememberTool (userId user), recallTool (userId user)]
+ <> [ rememberTool (userId user),
+ recallTool (userId user),
+ linkMemoriesTool (userId user),
+ queryGraphTool (userId user)
+ ]
}
Engine.runAgent engineCfg enhancedConfig userPrompt
@@ -884,7 +1119,8 @@ executeRecall uid v =
.= map
( \m ->
Aeson.object
- [ "content" .= memoryContent m,
+ [ "id" .= memoryId m,
+ "content" .= memoryContent m,
"confidence" .= memoryConfidence m,
"source" .= sourceAgent (memorySource m),
"tags" .= memoryTags m
@@ -921,6 +1157,174 @@ instance Aeson.FromJSON RecallArgs where
(RecallArgs </ (v .: "query"))
<*> (v .:? "limit" .!= 5)
+-- | Tool for agents to link memories in the knowledge graph.
+linkMemoriesTool :: Text -> Engine.Tool
+linkMemoriesTool _uid =
+ Engine.Tool
+ { Engine.toolName = "link_memories",
+ Engine.toolDescription =
+ "Create a typed relationship between two memories. "
+ <> "Use this to connect related information. Relation types:\n"
+ <> "- contradicts: conflicting information\n"
+ <> "- supports: evidence that reinforces another memory\n"
+ <> "- elaborates: adds detail to an existing memory\n"
+ <> "- supersedes: newer info replaces older\n"
+ <> "- related: general topical connection\n"
+ <> "- contingent_on: depends on another fact being true",
+ Engine.toolJsonSchema =
+ Aeson.object
+ [ "type" .= ("object" :: Text),
+ "properties"
+ .= Aeson.object
+ [ "from_memory_id"
+ .= Aeson.object
+ [ "type" .= ("string" :: Text),
+ "description" .= ("ID of the source memory" :: Text)
+ ],
+ "to_memory_id"
+ .= Aeson.object
+ [ "type" .= ("string" :: Text),
+ "description" .= ("ID of the target memory" :: Text)
+ ],
+ "relation_type"
+ .= Aeson.object
+ [ "type" .= ("string" :: Text),
+ "enum" .= (["contradicts", "supports", "elaborates", "supersedes", "related", "contingent_on"] :: [Text]),
+ "description" .= ("Type of relationship between memories" :: Text)
+ ]
+ ],
+ "required" .= (["from_memory_id", "to_memory_id", "relation_type"] :: [Text])
+ ],
+ Engine.toolExecute = executeLinkMemories
+ }
+
+executeLinkMemories :: Aeson.Value -> IO Aeson.Value
+executeLinkMemories v =
+ case Aeson.fromJSON v of
+ Aeson.Error e -> pure (Aeson.object ["error" .= Text.pack e])
+ Aeson.Success (args :: LinkMemoriesArgs) -> do
+ case textToRelationType (linkArgsRelationType args) of
+ Nothing ->
+ pure
+ ( Aeson.object
+ [ "success" .= False,
+ "error" .= ("Invalid relation type: " <> linkArgsRelationType args)
+ ]
+ )
+ Just relType -> do
+ memLink <- linkMemories (linkArgsFromId args) (linkArgsToId args) relType
+ pure
+ ( Aeson.object
+ [ "success" .= True,
+ "message"
+ .= ( "Linked memory "
+ <> linkFromMemoryId memLink
+ <> " -> "
+ <> linkToMemoryId memLink
+ <> " ("
+ <> relationTypeToText (linkRelationType memLink)
+ <> ")"
+ )
+ ]
+ )
+
+data LinkMemoriesArgs = LinkMemoriesArgs
+ { linkArgsFromId :: Text,
+ linkArgsToId :: Text,
+ linkArgsRelationType :: Text
+ }
+ deriving (Generic)
+
+instance Aeson.FromJSON LinkMemoriesArgs where
+ parseJSON =
+ Aeson.withObject "LinkMemoriesArgs" <| \v ->
+ (LinkMemoriesArgs </ (v .: "from_memory_id"))
+ <*> (v .: "to_memory_id")
+ <*> (v .: "relation_type")
+
+-- | Tool for agents to query the memory knowledge graph.
+queryGraphTool :: Text -> Engine.Tool
+queryGraphTool _uid =
+ Engine.Tool
+ { Engine.toolName = "query_graph",
+ Engine.toolDescription =
+ "Explore the knowledge graph to find related memories. "
+ <> "Given a starting memory, traverse links to find connected memories. "
+ <> "Useful for understanding context and finding contradictions or supporting evidence.",
+ Engine.toolJsonSchema =
+ Aeson.object
+ [ "type" .= ("object" :: Text),
+ "properties"
+ .= Aeson.object
+ [ "memory_id"
+ .= Aeson.object
+ [ "type" .= ("string" :: Text),
+ "description" .= ("ID of the memory to start from" :: Text)
+ ],
+ "depth"
+ .= Aeson.object
+ [ "type" .= ("integer" :: Text),
+ "description" .= ("How many link hops to traverse (default: 2)" :: Text)
+ ],
+ "relation_type"
+ .= Aeson.object
+ [ "type" .= ("string" :: Text),
+ "enum" .= (["contradicts", "supports", "elaborates", "supersedes", "related", "contingent_on"] :: [Text]),
+ "description" .= ("Optional: filter by relation type" :: Text)
+ ]
+ ],
+ "required" .= (["memory_id"] :: [Text])
+ ],
+ Engine.toolExecute = executeQueryGraph
+ }
+
+executeQueryGraph :: Aeson.Value -> IO Aeson.Value
+executeQueryGraph v =
+ case Aeson.fromJSON v of
+ Aeson.Error e -> pure (Aeson.object ["error" .= Text.pack e])
+ Aeson.Success (args :: QueryGraphArgs) -> do
+ let maybeRelType = queryArgsRelationType args +> textToRelationType
+ results <- queryGraph (queryArgsMemoryId args) (queryArgsDepth args) maybeRelType
+ pure
+ ( Aeson.object
+ [ "success" .= True,
+ "count" .= length results,
+ "memories"
+ .= map
+ ( \(m, links) ->
+ Aeson.object
+ [ "id" .= memoryId m,
+ "content" .= memoryContent m,
+ "links"
+ .= map
+ ( \l ->
+ Aeson.object
+ [ "from" .= linkFromMemoryId l,
+ "to" .= linkToMemoryId l,
+ "relation" .= linkRelationType l
+ ]
+ )
+ links
+ ]
+ )
+ results
+ ]
+ )
+
+data QueryGraphArgs = QueryGraphArgs
+ { queryArgsMemoryId :: Text,
+ queryArgsDepth :: Int,
+ queryArgsRelationType :: Maybe Text
+ }
+ deriving (Generic)
+
+instance Aeson.FromJSON QueryGraphArgs where
+ parseJSON =
+ Aeson.withObject "QueryGraphArgs" <| \v ->
+ (QueryGraphArgs </ (v .: "memory_id"))
+ <*> (v .:? "depth" .!= 2)
+ <*> (v .:? "relation_type")
+
-- | Estimate token count for text (rough: ~4 chars per token).
estimateTokens :: Text -> Int
estimateTokens t = max 1 (Text.length t `div` 4)
diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs
index fd6920d..1bb4f04 100644
--- a/Omni/Agent/Provider.hs
+++ b/Omni/Agent/Provider.hs
@@ -589,6 +589,11 @@ parseStreamChunk obj = do
_ -> "Unknown error"
Just (StreamError errMsg)
_ -> do
+ let usageChunk = case KeyMap.lookup "usage" obj of
+ Just usageVal -> case Aeson.fromJSON usageVal of
+ Aeson.Success usage -> Just (StreamDone (ChatResult (Message Assistant "" Nothing Nothing) (Just usage)))
+ _ -> Nothing
+ _ -> Nothing
case KeyMap.lookup "choices" obj of
Just (Aeson.Array choices) | not (null choices) -> do
case toList choices of
@@ -603,15 +608,10 @@ parseStreamChunk obj = do
| not (null tcs) ->
parseToolCallDelta (toList tcs)
_ -> Nothing
- contentChunk <|> toolCallChunk
- _ -> Nothing
- _ -> Nothing
- _ -> do
- case KeyMap.lookup "usage" obj of
- Just usageVal -> case Aeson.fromJSON usageVal of
- Aeson.Success usage -> Just (StreamDone (ChatResult (Message Assistant "" Nothing Nothing) (Just usage)))
- _ -> Nothing
- _ -> Nothing
+ contentChunk <|> toolCallChunk <|> usageChunk
+ _ -> usageChunk
+ _ -> usageChunk
+ _ -> usageChunk
parseToolCallDelta :: [Aeson.Value] -> Maybe StreamChunk
parseToolCallDelta [] = Nothing
diff --git a/Omni/Agent/Telegram.hs b/Omni/Agent/Telegram.hs
index 418e589..091ad11 100644
--- a/Omni/Agent/Telegram.hs
+++ b/Omni/Agent/Telegram.hs
@@ -602,7 +602,9 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe
let memoryTools =
[ Memory.rememberTool uid,
- Memory.recallTool uid
+ Memory.recallTool uid,
+ Memory.linkMemoriesTool uid,
+ Memory.queryGraphTool uid
]
searchTools = case Types.tgKagiApiKey tgConfig of
Just kagiKey -> [WebSearch.webSearchTool kagiKey]