From fb019f46c3adcf772df2dacf688cc75c30ed6e8e Mon Sep 17 00:00:00 2001 From: Ben Sima Date: Mon, 1 Dec 2025 10:02:12 -0500 Subject: Add guardrails and progress tracking to Jr agent Implement runtime guardrails in Engine.hs: - Cost budget limit (default 200 cents) - Token budget limit (default 1M tokens) - Duplicate tool call detection (same tool called N times) - Test failure counting (bild --test failures) Add database-backed progress tracking: - Checkpoint events stored in agent_events table - Progress summary retrieved on retry attempts - Improved prompts emphasizing efficiency and autonomous operation Worker.hs improvements: - Uses guardrails configuration - Reports guardrail violations via callbacks - Better prompt structure for autonomous operation Task-Id: t-203 --- Omni/Agent/Engine.hs | 303 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 237 insertions(+), 66 deletions(-) (limited to 'Omni/Agent/Engine.hs') diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs index 7da7fa5..a2a24a1 100644 --- a/Omni/Agent/Engine.hs +++ b/Omni/Agent/Engine.hs @@ -19,6 +19,8 @@ module Omni.Agent.Engine EngineConfig (..), AgentConfig (..), AgentResult (..), + Guardrails (..), + GuardrailResult (..), Message (..), Role (..), ToolCall (..), @@ -31,6 +33,7 @@ module Omni.Agent.Engine defaultLLM, defaultEngineConfig, defaultAgentConfig, + defaultGuardrails, chat, runAgent, main, @@ -41,6 +44,7 @@ where import Alpha import Data.Aeson ((.!=), (.:), (.:?), (.=)) import qualified Data.Aeson as Aeson +import qualified Data.Aeson.KeyMap as KeyMap import qualified Data.ByteString.Lazy as BL import qualified Data.CaseInsensitive as CI import qualified Data.Map.Strict as Map @@ -164,7 +168,57 @@ test = forM_ roles <| \role -> case Aeson.decode (Aeson.encode role) of Nothing -> Test.assertFailure ("Failed to decode Role: " <> show role) - Just decoded -> decoded Test.@=? role + Just decoded -> decoded Test.@=? role, + Test.unit "defaultGuardrails has sensible defaults" <| do + guardrailMaxCostCents defaultGuardrails Test.@=? 100.0 + guardrailMaxTokens defaultGuardrails Test.@=? 500000 + guardrailMaxDuplicateToolCalls defaultGuardrails Test.@=? 3 + guardrailMaxTestFailures defaultGuardrails Test.@=? 3, + Test.unit "checkCostGuardrail detects exceeded budget" <| do + let g = defaultGuardrails {guardrailMaxCostCents = 50.0} + checkCostGuardrail g 60.0 Test.@=? GuardrailCostExceeded 60.0 50.0 + checkCostGuardrail g 40.0 Test.@=? GuardrailOk, + Test.unit "checkTokenGuardrail detects exceeded budget" <| do + let g = defaultGuardrails {guardrailMaxTokens = 1000} + checkTokenGuardrail g 1500 Test.@=? GuardrailTokensExceeded 1500 1000 + checkTokenGuardrail g 500 Test.@=? GuardrailOk, + Test.unit "checkDuplicateGuardrail detects repeated calls" <| do + let g = defaultGuardrails {guardrailMaxDuplicateToolCalls = 3} + counts = Map.fromList [("bash", 3), ("read_file", 1)] + case checkDuplicateGuardrail g counts of + GuardrailDuplicateToolCalls name count -> do + name Test.@=? "bash" + count Test.@=? 3 + _ -> Test.assertFailure "Expected GuardrailDuplicateToolCalls" + checkDuplicateGuardrail g (Map.fromList [("bash", 2)]) Test.@=? GuardrailOk, + Test.unit "checkTestFailureGuardrail detects failures" <| do + let g = defaultGuardrails {guardrailMaxTestFailures = 3} + checkTestFailureGuardrail g 3 Test.@=? GuardrailTestFailures 3 + checkTestFailureGuardrail g 2 Test.@=? GuardrailOk, + Test.unit "updateToolCallCounts accumulates correctly" <| do + let tc1 = ToolCall "1" "function" (FunctionCall "bash" "{}") + tc2 = ToolCall "2" "function" (FunctionCall "bash" "{}") + tc3 = ToolCall "3" "function" (FunctionCall "read_file" "{}") + counts = updateToolCallCounts Map.empty [tc1, tc2, tc3] + Map.lookup "bash" counts Test.@=? Just 2 + Map.lookup "read_file" counts Test.@=? Just 1, + Test.unit "Guardrails JSON roundtrip" <| do + let g = Guardrails 75.0 100000 5 4 + case Aeson.decode (Aeson.encode g) of + Nothing -> Test.assertFailure "Failed to decode Guardrails" + Just decoded -> decoded Test.@=? g, + Test.unit "GuardrailResult JSON roundtrip" <| do + let results = + [ GuardrailOk, + GuardrailCostExceeded 100.0 50.0, + GuardrailTokensExceeded 2000 1000, + GuardrailDuplicateToolCalls "bash" 5, + GuardrailTestFailures 3 + ] + forM_ results <| \r -> + case Aeson.decode (Aeson.encode r) of + Nothing -> Test.assertFailure ("Failed to decode GuardrailResult: " <> show r) + Just decoded -> decoded Test.@=? r ] data Tool = Tool @@ -249,16 +303,51 @@ data AgentConfig = AgentConfig { agentModel :: Text, agentTools :: [Tool], agentSystemPrompt :: Text, - agentMaxIterations :: Int + agentMaxIterations :: Int, + agentGuardrails :: Guardrails } +data Guardrails = Guardrails + { guardrailMaxCostCents :: Double, + guardrailMaxTokens :: Int, + guardrailMaxDuplicateToolCalls :: Int, + guardrailMaxTestFailures :: Int + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON Guardrails + +instance Aeson.FromJSON Guardrails + +data GuardrailResult + = GuardrailOk + | GuardrailCostExceeded Double Double + | GuardrailTokensExceeded Int Int + | GuardrailDuplicateToolCalls Text Int + | GuardrailTestFailures Int + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON GuardrailResult + +instance Aeson.FromJSON GuardrailResult + +defaultGuardrails :: Guardrails +defaultGuardrails = + Guardrails + { guardrailMaxCostCents = 100.0, + guardrailMaxTokens = 500000, + guardrailMaxDuplicateToolCalls = 3, + guardrailMaxTestFailures = 3 + } + defaultAgentConfig :: AgentConfig defaultAgentConfig = AgentConfig { agentModel = "gpt-4", agentTools = [], agentSystemPrompt = "You are a helpful assistant.", - agentMaxIterations = 10 + agentMaxIterations = 10, + agentGuardrails = defaultGuardrails } data EngineConfig = EngineConfig @@ -269,7 +358,8 @@ data EngineConfig = EngineConfig engineOnAssistant :: Text -> IO (), engineOnToolResult :: Text -> Bool -> Text -> IO (), engineOnComplete :: IO (), - engineOnError :: Text -> IO () + engineOnError :: Text -> IO (), + engineOnGuardrail :: GuardrailResult -> IO () } defaultEngineConfig :: EngineConfig @@ -282,7 +372,8 @@ defaultEngineConfig = engineOnAssistant = \_ -> pure (), engineOnToolResult = \_ _ _ -> pure (), engineOnComplete = pure (), - engineOnError = \_ -> pure () + engineOnError = \_ -> pure (), + engineOnGuardrail = \_ -> pure () } data AgentResult = AgentResult @@ -511,72 +602,138 @@ runAgent engineCfg agentCfg userPrompt = do initialMessages = [systemMsg, userMsg] engineOnActivity engineCfg "Starting agent loop" - loop llm tools toolMap initialMessages 0 0 0 + loop llm tools toolMap initialMessages 0 0 0 0.0 Map.empty 0 where maxIter = agentMaxIterations agentCfg - - loop :: LLM -> [Tool] -> Map.Map Text Tool -> [Message] -> Int -> Int -> Int -> IO (Either Text AgentResult) - loop llm tools' toolMap msgs iteration totalCalls totalTokens + guardrails' = agentGuardrails agentCfg + + loop :: + LLM -> + [Tool] -> + Map.Map Text Tool -> + [Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + IO (Either Text AgentResult) + loop llm tools' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures | iteration >= maxIter = do let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" engineOnError engineCfg errMsg pure <| Left errMsg | otherwise = do - engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) - result <- chatWithUsage llm tools' msgs - case result of - Left err -> do - engineOnError engineCfg err - pure (Left err) - Right chatRes -> do - let msg = chatMessage chatRes - tokens = maybe 0 usageTotalTokens (chatUsage chatRes) - -- Use actual cost from API response when available - -- OpenRouter returns cost in dollars, convert to cents - cost = case chatUsage chatRes +> usageCost of - Just actualCost -> actualCost * 100 - Nothing -> estimateCost (llmModel llm) tokens - engineOnCost engineCfg tokens cost - let newTokens = totalTokens + tokens - let assistantText = msgContent msg - unless (Text.null assistantText) - <| engineOnAssistant engineCfg assistantText - case msgToolCalls msg of - Nothing -> do - engineOnActivity engineCfg "Agent completed" - engineOnComplete engineCfg - pure - <| Right - <| AgentResult - { resultFinalMessage = msgContent msg, - resultToolCallCount = totalCalls, - resultIterations = iteration + 1, - resultTotalCost = estimateTotalCost (llmModel llm) newTokens, - resultTotalTokens = newTokens - } - Just [] -> do - engineOnActivity engineCfg "Agent completed (empty tool calls)" - engineOnComplete engineCfg - pure - <| Right - <| AgentResult - { resultFinalMessage = msgContent msg, - resultToolCallCount = totalCalls, - resultIterations = iteration + 1, - resultTotalCost = estimateTotalCost (llmModel llm) newTokens, - resultTotalTokens = newTokens - } - Just tcs -> do - toolResults <- executeToolCalls engineCfg toolMap tcs - let newMsgs = msgs <> [msg] <> toolResults - newCalls = totalCalls + length tcs - loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + result <- chatWithUsage llm tools' msgs + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = chatMessage chatRes + tokens = maybe 0 usageTotalTokens (chatUsage chatRes) + cost = case chatUsage chatRes +> usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (llmModel llm) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case msgToolCalls msg of + Nothing -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures) <- executeToolCallsWithTracking engineCfg toolMap tcs testFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateToolCallCounts toolCallCounts tcs + loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures + +checkCostGuardrail :: Guardrails -> Double -> GuardrailResult +checkCostGuardrail g cost + | cost > guardrailMaxCostCents g = GuardrailCostExceeded cost (guardrailMaxCostCents g) + | otherwise = GuardrailOk + +checkTokenGuardrail :: Guardrails -> Int -> GuardrailResult +checkTokenGuardrail g tokens + | tokens > guardrailMaxTokens g = GuardrailTokensExceeded tokens (guardrailMaxTokens g) + | otherwise = GuardrailOk + +checkDuplicateGuardrail :: Guardrails -> Map.Map Text Int -> GuardrailResult +checkDuplicateGuardrail g counts = + let maxAllowed = guardrailMaxDuplicateToolCalls g + violations = [(name, count) | (name, count) <- Map.toList counts, count >= maxAllowed] + in case violations of + ((name, count) : _) -> GuardrailDuplicateToolCalls name count + [] -> GuardrailOk + +checkTestFailureGuardrail :: Guardrails -> Int -> GuardrailResult +checkTestFailureGuardrail g failures + | failures >= guardrailMaxTestFailures g = GuardrailTestFailures failures + | otherwise = GuardrailOk + +updateToolCallCounts :: Map.Map Text Int -> [ToolCall] -> Map.Map Text Int +updateToolCallCounts = + foldr (\tc m -> Map.insertWith (+) (fcName (tcFunction tc)) 1 m) + +findGuardrailViolation :: Guardrails -> Double -> Int -> Map.Map Text Int -> Int -> Maybe (GuardrailResult, Text) +findGuardrailViolation g cost tokens toolCallCounts testFailures = + case checkCostGuardrail g cost of + r@(GuardrailCostExceeded actual limit) -> + Just (r, "Guardrail: cost budget exceeded (" <> tshow actual <> "/" <> tshow limit <> " cents)") + _ -> case checkTokenGuardrail g tokens of + r@(GuardrailTokensExceeded actual limit) -> + Just (r, "Guardrail: token budget exceeded (" <> tshow actual <> "/" <> tshow limit <> " tokens)") + _ -> case checkDuplicateGuardrail g toolCallCounts of + r@(GuardrailDuplicateToolCalls tool count) -> + Just (r, "Guardrail: duplicate tool calls (" <> tool <> " called " <> tshow count <> " times)") + _ -> case checkTestFailureGuardrail g testFailures of + r@(GuardrailTestFailures count) -> + Just (r, "Guardrail: too many test failures (" <> tshow count <> ")") + _ -> Nothing buildToolMap :: [Tool] -> Map.Map Text Tool buildToolMap = Map.fromList <. map (\t -> (toolName t, t)) -executeToolCalls :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> IO [Message] -executeToolCalls engineCfg toolMap = traverse executeSingle +executeToolCallsWithTracking :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> Int -> IO ([Message], Int) +executeToolCallsWithTracking engineCfg toolMap tcs initialFailures = do + results <- traverse executeSingle tcs + let msgs = map fst results + failureDeltas = map snd results + totalNewFailures = initialFailures + sum failureDeltas + pure (msgs, totalNewFailures) where executeSingle tc = do let name = fcName (tcFunction tc) @@ -588,18 +745,35 @@ executeToolCalls engineCfg toolMap = traverse executeSingle Nothing -> do let errMsg = "Tool not found: " <> name engineOnToolResult engineCfg name False errMsg - pure <| Message ToolRole errMsg Nothing (Just callId) + pure (Message ToolRole errMsg Nothing (Just callId), 0) Just tool -> do case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of Nothing -> do let errMsg = "Invalid JSON arguments: " <> argsText engineOnToolResult engineCfg name False errMsg - pure <| Message ToolRole errMsg Nothing (Just callId) + pure (Message ToolRole errMsg Nothing (Just callId), 0) Just args -> do resultValue <- toolExecute tool args let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResult resultValue + failureDelta = if isTestFailure then 1 else 0 engineOnToolResult engineCfg name True resultText - pure <| Message ToolRole resultText Nothing (Just callId) + pure (Message ToolRole resultText Nothing (Just callId), failureDelta) + + isFailureResult :: Aeson.Value -> Bool + isFailureResult (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResult (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResult _ = False -- | Estimate cost in cents from token count estimateCost :: Text -> Int -> Double @@ -609,6 +783,3 @@ estimateCost model tokens | "gpt-4" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000 | "claude" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000 | otherwise = fromIntegral tokens / 100000 - -estimateTotalCost :: Text -> Int -> Double -estimateTotalCost = estimateCost -- cgit v1.2.3