summaryrefslogtreecommitdiff
path: root/Omni/Agent/Engine.hs
diff options
context:
space:
mode:
author: Ben Sima <ben@bensima.com> 2025-12-01 10:02:12 -0500
committer: Ben Sima <ben@bensima.com> 2025-12-01 10:02:12 -0500
commit: fb019f46c3adcf772df2dacf688cc75c30ed6e8e (patch)
tree: 1b365bac2cfd513852f73355893ffb9501ece18f /Omni/Agent/Engine.hs
parent: ffeb13fb9f2543dfc9cdecf8ed6778226267b403 (diff)
Add guardrails and progress tracking to Jr agent
Implement runtime guardrails in Engine.hs:
- Cost budget limit (default 200 cents)
- Token budget limit (default 1M tokens)
- Duplicate tool call detection (same tool called N times)
- Test failure counting (bild --test failures)

Add database-backed progress tracking:
- Checkpoint events stored in agent_events table
- Progress summary retrieved on retry attempts
- Improved prompts emphasizing efficiency and autonomous operation

Worker.hs improvements:
- Uses guardrails configuration
- Reports guardrail violations via callbacks
- Better prompt structure for autonomous operation

Task-Id: t-203
Diffstat (limited to 'Omni/Agent/Engine.hs')
-rw-r--r-- Omni/Agent/Engine.hs 303
1 files changed, 237 insertions, 66 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs
index 7da7fa5..a2a24a1 100644
--- a/Omni/Agent/Engine.hs
+++ b/Omni/Agent/Engine.hs
@@ -19,6 +19,8 @@ module Omni.Agent.Engine
EngineConfig (..),
AgentConfig (..),
AgentResult (..),
+ Guardrails (..),
+ GuardrailResult (..),
Message (..),
Role (..),
ToolCall (..),
@@ -31,6 +33,7 @@ module Omni.Agent.Engine
defaultLLM,
defaultEngineConfig,
defaultAgentConfig,
+ defaultGuardrails,
chat,
runAgent,
main,
@@ -41,6 +44,7 @@ where
import Alpha
import Data.Aeson ((.!=), (.:), (.:?), (.=))
import qualified Data.Aeson as Aeson
+import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
import qualified Data.Map.Strict as Map
@@ -164,7 +168,57 @@ test =
forM_ roles <| \role ->
case Aeson.decode (Aeson.encode role) of
Nothing -> Test.assertFailure ("Failed to decode Role: " <> show role)
- Just decoded -> decoded Test.@=? role
+ Just decoded -> decoded Test.@=? role,
+ Test.unit "defaultGuardrails has sensible defaults" <| do
+ guardrailMaxCostCents defaultGuardrails Test.@=? 100.0
+ guardrailMaxTokens defaultGuardrails Test.@=? 500000
+ guardrailMaxDuplicateToolCalls defaultGuardrails Test.@=? 3
+ guardrailMaxTestFailures defaultGuardrails Test.@=? 3,
+ Test.unit "checkCostGuardrail detects exceeded budget" <| do
+ let g = defaultGuardrails {guardrailMaxCostCents = 50.0}
+ checkCostGuardrail g 60.0 Test.@=? GuardrailCostExceeded 60.0 50.0
+ checkCostGuardrail g 40.0 Test.@=? GuardrailOk,
+ Test.unit "checkTokenGuardrail detects exceeded budget" <| do
+ let g = defaultGuardrails {guardrailMaxTokens = 1000}
+ checkTokenGuardrail g 1500 Test.@=? GuardrailTokensExceeded 1500 1000
+ checkTokenGuardrail g 500 Test.@=? GuardrailOk,
+ Test.unit "checkDuplicateGuardrail detects repeated calls" <| do
+ let g = defaultGuardrails {guardrailMaxDuplicateToolCalls = 3}
+ counts = Map.fromList [("bash", 3), ("read_file", 1)]
+ case checkDuplicateGuardrail g counts of
+ GuardrailDuplicateToolCalls name count -> do
+ name Test.@=? "bash"
+ count Test.@=? 3
+ _ -> Test.assertFailure "Expected GuardrailDuplicateToolCalls"
+ checkDuplicateGuardrail g (Map.fromList [("bash", 2)]) Test.@=? GuardrailOk,
+ Test.unit "checkTestFailureGuardrail detects failures" <| do
+ let g = defaultGuardrails {guardrailMaxTestFailures = 3}
+ checkTestFailureGuardrail g 3 Test.@=? GuardrailTestFailures 3
+ checkTestFailureGuardrail g 2 Test.@=? GuardrailOk,
+ Test.unit "updateToolCallCounts accumulates correctly" <| do
+ let tc1 = ToolCall "1" "function" (FunctionCall "bash" "{}")
+ tc2 = ToolCall "2" "function" (FunctionCall "bash" "{}")
+ tc3 = ToolCall "3" "function" (FunctionCall "read_file" "{}")
+ counts = updateToolCallCounts Map.empty [tc1, tc2, tc3]
+ Map.lookup "bash" counts Test.@=? Just 2
+ Map.lookup "read_file" counts Test.@=? Just 1,
+ Test.unit "Guardrails JSON roundtrip" <| do
+ let g = Guardrails 75.0 100000 5 4
+ case Aeson.decode (Aeson.encode g) of
+ Nothing -> Test.assertFailure "Failed to decode Guardrails"
+ Just decoded -> decoded Test.@=? g,
+ Test.unit "GuardrailResult JSON roundtrip" <| do
+ let results =
+ [ GuardrailOk,
+ GuardrailCostExceeded 100.0 50.0,
+ GuardrailTokensExceeded 2000 1000,
+ GuardrailDuplicateToolCalls "bash" 5,
+ GuardrailTestFailures 3
+ ]
+ forM_ results <| \r ->
+ case Aeson.decode (Aeson.encode r) of
+ Nothing -> Test.assertFailure ("Failed to decode GuardrailResult: " <> show r)
+ Just decoded -> decoded Test.@=? r
]
data Tool = Tool
@@ -249,16 +303,51 @@ data AgentConfig = AgentConfig
{ agentModel :: Text,
agentTools :: [Tool],
agentSystemPrompt :: Text,
- agentMaxIterations :: Int
+ agentMaxIterations :: Int,
+ agentGuardrails :: Guardrails
}
+data Guardrails = Guardrails
+ { guardrailMaxCostCents :: Double,
+ guardrailMaxTokens :: Int,
+ guardrailMaxDuplicateToolCalls :: Int,
+ guardrailMaxTestFailures :: Int
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON Guardrails
+
+instance Aeson.FromJSON Guardrails
+
+data GuardrailResult
+ = GuardrailOk
+ | GuardrailCostExceeded Double Double
+ | GuardrailTokensExceeded Int Int
+ | GuardrailDuplicateToolCalls Text Int
+ | GuardrailTestFailures Int
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON GuardrailResult
+
+instance Aeson.FromJSON GuardrailResult
+
+defaultGuardrails :: Guardrails
+defaultGuardrails =
+ Guardrails
+ { guardrailMaxCostCents = 100.0,
+ guardrailMaxTokens = 500000,
+ guardrailMaxDuplicateToolCalls = 3,
+ guardrailMaxTestFailures = 3
+ }
+
defaultAgentConfig :: AgentConfig
defaultAgentConfig =
AgentConfig
{ agentModel = "gpt-4",
agentTools = [],
agentSystemPrompt = "You are a helpful assistant.",
- agentMaxIterations = 10
+ agentMaxIterations = 10,
+ agentGuardrails = defaultGuardrails
}
data EngineConfig = EngineConfig
@@ -269,7 +358,8 @@ data EngineConfig = EngineConfig
engineOnAssistant :: Text -> IO (),
engineOnToolResult :: Text -> Bool -> Text -> IO (),
engineOnComplete :: IO (),
- engineOnError :: Text -> IO ()
+ engineOnError :: Text -> IO (),
+ engineOnGuardrail :: GuardrailResult -> IO ()
}
defaultEngineConfig :: EngineConfig
@@ -282,7 +372,8 @@ defaultEngineConfig =
engineOnAssistant = \_ -> pure (),
engineOnToolResult = \_ _ _ -> pure (),
engineOnComplete = pure (),
- engineOnError = \_ -> pure ()
+ engineOnError = \_ -> pure (),
+ engineOnGuardrail = \_ -> pure ()
}
data AgentResult = AgentResult
@@ -511,72 +602,138 @@ runAgent engineCfg agentCfg userPrompt = do
initialMessages = [systemMsg, userMsg]
engineOnActivity engineCfg "Starting agent loop"
- loop llm tools toolMap initialMessages 0 0 0
+ loop llm tools toolMap initialMessages 0 0 0 0.0 Map.empty 0
where
maxIter = agentMaxIterations agentCfg
-
- loop :: LLM -> [Tool] -> Map.Map Text Tool -> [Message] -> Int -> Int -> Int -> IO (Either Text AgentResult)
- loop llm tools' toolMap msgs iteration totalCalls totalTokens
+ guardrails' = agentGuardrails agentCfg
+
+ loop ::
+ LLM ->
+ [Tool] ->
+ Map.Map Text Tool ->
+ [Message] ->
+ Int ->
+ Int ->
+ Int ->
+ Double ->
+ Map.Map Text Int ->
+ Int ->
+ IO (Either Text AgentResult)
+ loop llm tools' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures
| iteration >= maxIter = do
let errMsg = "Max iterations (" <> tshow maxIter <> ") reached"
engineOnError engineCfg errMsg
pure <| Left errMsg
| otherwise = do
- engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1)
- result <- chatWithUsage llm tools' msgs
- case result of
- Left err -> do
- engineOnError engineCfg err
- pure (Left err)
- Right chatRes -> do
- let msg = chatMessage chatRes
- tokens = maybe 0 usageTotalTokens (chatUsage chatRes)
- -- Use actual cost from API response when available
- -- OpenRouter returns cost in dollars, convert to cents
- cost = case chatUsage chatRes +> usageCost of
- Just actualCost -> actualCost * 100
- Nothing -> estimateCost (llmModel llm) tokens
- engineOnCost engineCfg tokens cost
- let newTokens = totalTokens + tokens
- let assistantText = msgContent msg
- unless (Text.null assistantText)
- <| engineOnAssistant engineCfg assistantText
- case msgToolCalls msg of
- Nothing -> do
- engineOnActivity engineCfg "Agent completed"
- engineOnComplete engineCfg
- pure
- <| Right
- <| AgentResult
- { resultFinalMessage = msgContent msg,
- resultToolCallCount = totalCalls,
- resultIterations = iteration + 1,
- resultTotalCost = estimateTotalCost (llmModel llm) newTokens,
- resultTotalTokens = newTokens
- }
- Just [] -> do
- engineOnActivity engineCfg "Agent completed (empty tool calls)"
- engineOnComplete engineCfg
- pure
- <| Right
- <| AgentResult
- { resultFinalMessage = msgContent msg,
- resultToolCallCount = totalCalls,
- resultIterations = iteration + 1,
- resultTotalCost = estimateTotalCost (llmModel llm) newTokens,
- resultTotalTokens = newTokens
- }
- Just tcs -> do
- toolResults <- executeToolCalls engineCfg toolMap tcs
- let newMsgs = msgs <> [msg] <> toolResults
- newCalls = totalCalls + length tcs
- loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens
+ let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures
+ case guardrailViolation of
+ Just (g, errMsg) -> do
+ engineOnGuardrail engineCfg g
+ pure <| Left errMsg
+ Nothing -> do
+ engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1)
+ result <- chatWithUsage llm tools' msgs
+ case result of
+ Left err -> do
+ engineOnError engineCfg err
+ pure (Left err)
+ Right chatRes -> do
+ let msg = chatMessage chatRes
+ tokens = maybe 0 usageTotalTokens (chatUsage chatRes)
+ cost = case chatUsage chatRes +> usageCost of
+ Just actualCost -> actualCost * 100
+ Nothing -> estimateCost (llmModel llm) tokens
+ engineOnCost engineCfg tokens cost
+ let newTokens = totalTokens + tokens
+ newCost = totalCost + cost
+ let assistantText = msgContent msg
+ unless (Text.null assistantText)
+ <| engineOnAssistant engineCfg assistantText
+ case msgToolCalls msg of
+ Nothing -> do
+ engineOnActivity engineCfg "Agent completed"
+ engineOnComplete engineCfg
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = newCost,
+ resultTotalTokens = newTokens
+ }
+ Just [] -> do
+ engineOnActivity engineCfg "Agent completed (empty tool calls)"
+ engineOnComplete engineCfg
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = newCost,
+ resultTotalTokens = newTokens
+ }
+ Just tcs -> do
+ (toolResults, newTestFailures) <- executeToolCallsWithTracking engineCfg toolMap tcs testFailures
+ let newMsgs = msgs <> [msg] <> toolResults
+ newCalls = totalCalls + length tcs
+ newToolCallCounts = updateToolCallCounts toolCallCounts tcs
+ loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures
+
+checkCostGuardrail :: Guardrails -> Double -> GuardrailResult
+checkCostGuardrail g cost
+ | cost > guardrailMaxCostCents g = GuardrailCostExceeded cost (guardrailMaxCostCents g)
+ | otherwise = GuardrailOk
+
+checkTokenGuardrail :: Guardrails -> Int -> GuardrailResult
+checkTokenGuardrail g tokens
+ | tokens > guardrailMaxTokens g = GuardrailTokensExceeded tokens (guardrailMaxTokens g)
+ | otherwise = GuardrailOk
+
+checkDuplicateGuardrail :: Guardrails -> Map.Map Text Int -> GuardrailResult
+checkDuplicateGuardrail g counts =
+ let maxAllowed = guardrailMaxDuplicateToolCalls g
+ violations = [(name, count) | (name, count) <- Map.toList counts, count >= maxAllowed]
+ in case violations of
+ ((name, count) : _) -> GuardrailDuplicateToolCalls name count
+ [] -> GuardrailOk
+
+checkTestFailureGuardrail :: Guardrails -> Int -> GuardrailResult
+checkTestFailureGuardrail g failures
+ | failures >= guardrailMaxTestFailures g = GuardrailTestFailures failures
+ | otherwise = GuardrailOk
+
+updateToolCallCounts :: Map.Map Text Int -> [ToolCall] -> Map.Map Text Int
+updateToolCallCounts =
+ foldr (\tc m -> Map.insertWith (+) (fcName (tcFunction tc)) 1 m)
+
+findGuardrailViolation :: Guardrails -> Double -> Int -> Map.Map Text Int -> Int -> Maybe (GuardrailResult, Text)
+findGuardrailViolation g cost tokens toolCallCounts testFailures =
+ case checkCostGuardrail g cost of
+ r@(GuardrailCostExceeded actual limit) ->
+ Just (r, "Guardrail: cost budget exceeded (" <> tshow actual <> "/" <> tshow limit <> " cents)")
+ _ -> case checkTokenGuardrail g tokens of
+ r@(GuardrailTokensExceeded actual limit) ->
+ Just (r, "Guardrail: token budget exceeded (" <> tshow actual <> "/" <> tshow limit <> " tokens)")
+ _ -> case checkDuplicateGuardrail g toolCallCounts of
+ r@(GuardrailDuplicateToolCalls tool count) ->
+ Just (r, "Guardrail: duplicate tool calls (" <> tool <> " called " <> tshow count <> " times)")
+ _ -> case checkTestFailureGuardrail g testFailures of
+ r@(GuardrailTestFailures count) ->
+ Just (r, "Guardrail: too many test failures (" <> tshow count <> ")")
+ _ -> Nothing
buildToolMap :: [Tool] -> Map.Map Text Tool
buildToolMap = Map.fromList <. map (\t -> (toolName t, t))
-executeToolCalls :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> IO [Message]
-executeToolCalls engineCfg toolMap = traverse executeSingle
+executeToolCallsWithTracking :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> Int -> IO ([Message], Int)
+executeToolCallsWithTracking engineCfg toolMap tcs initialFailures = do
+ results <- traverse executeSingle tcs
+ let msgs = map fst results
+ failureDeltas = map snd results
+ totalNewFailures = initialFailures + sum failureDeltas
+ pure (msgs, totalNewFailures)
where
executeSingle tc = do
let name = fcName (tcFunction tc)
@@ -588,18 +745,35 @@ executeToolCalls engineCfg toolMap = traverse executeSingle
Nothing -> do
let errMsg = "Tool not found: " <> name
engineOnToolResult engineCfg name False errMsg
- pure <| Message ToolRole errMsg Nothing (Just callId)
+ pure (Message ToolRole errMsg Nothing (Just callId), 0)
Just tool -> do
case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of
Nothing -> do
let errMsg = "Invalid JSON arguments: " <> argsText
engineOnToolResult engineCfg name False errMsg
- pure <| Message ToolRole errMsg Nothing (Just callId)
+ pure (Message ToolRole errMsg Nothing (Just callId), 0)
Just args -> do
resultValue <- toolExecute tool args
let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue))
+ isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText)
+ isTestFailure = isTestCall && isFailureResult resultValue
+ failureDelta = if isTestFailure then 1 else 0
engineOnToolResult engineCfg name True resultText
- pure <| Message ToolRole resultText Nothing (Just callId)
+ pure (Message ToolRole resultText Nothing (Just callId), failureDelta)
+
+ isFailureResult :: Aeson.Value -> Bool
+ isFailureResult (Aeson.Object obj) =
+ case KeyMap.lookup "exit_code" obj of
+ Just (Aeson.Number n) -> n /= 0
+ _ -> False
+ isFailureResult (Aeson.String s) =
+ "error"
+ `Text.isInfixOf` Text.toLower s
+ || "failed"
+ `Text.isInfixOf` Text.toLower s
+ || "FAILED"
+ `Text.isInfixOf` s
+ isFailureResult _ = False
-- | Estimate cost in cents from token count
estimateCost :: Text -> Int -> Double
@@ -609,6 +783,3 @@ estimateCost model tokens
| "gpt-4" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000
| "claude" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000
| otherwise = fromIntegral tokens / 100000
-
-estimateTotalCost :: Text -> Int -> Double
-estimateTotalCost = estimateCost