diff options
Diffstat (limited to 'Omni/Agent/Engine.hs')
| -rw-r--r-- | Omni/Agent/Engine.hs | 70 |
1 files changed, 48 insertions, 22 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs index a2a24a1..4ee5e5d 100644 --- a/Omni/Agent/Engine.hs +++ b/Omni/Agent/Engine.hs @@ -203,7 +203,7 @@ test = Map.lookup "bash" counts Test.@=? Just 2 Map.lookup "read_file" counts Test.@=? Just 1, Test.unit "Guardrails JSON roundtrip" <| do - let g = Guardrails 75.0 100000 5 4 + let g = Guardrails 75.0 100000 5 4 3 case Aeson.decode (Aeson.encode g) of Nothing -> Test.assertFailure "Failed to decode Guardrails" Just decoded -> decoded Test.@=? g, @@ -213,7 +213,8 @@ test = GuardrailCostExceeded 100.0 50.0, GuardrailTokensExceeded 2000 1000, GuardrailDuplicateToolCalls "bash" 5, - GuardrailTestFailures 3 + GuardrailTestFailures 3, + GuardrailEditFailures 5 ] forM_ results <| \r -> case Aeson.decode (Aeson.encode r) of @@ -311,7 +312,8 @@ data Guardrails = Guardrails { guardrailMaxCostCents :: Double, guardrailMaxTokens :: Int, guardrailMaxDuplicateToolCalls :: Int, - guardrailMaxTestFailures :: Int + guardrailMaxTestFailures :: Int, + guardrailMaxEditFailures :: Int } deriving (Show, Eq, Generic) @@ -325,6 +327,7 @@ data GuardrailResult | GuardrailTokensExceeded Int Int | GuardrailDuplicateToolCalls Text Int | GuardrailTestFailures Int + | GuardrailEditFailures Int deriving (Show, Eq, Generic) instance Aeson.ToJSON GuardrailResult @@ -337,7 +340,8 @@ defaultGuardrails = { guardrailMaxCostCents = 100.0, guardrailMaxTokens = 500000, guardrailMaxDuplicateToolCalls = 3, - guardrailMaxTestFailures = 3 + guardrailMaxTestFailures = 3, + guardrailMaxEditFailures = 5 } defaultAgentConfig :: AgentConfig @@ -602,7 +606,7 @@ runAgent engineCfg agentCfg userPrompt = do initialMessages = [systemMsg, userMsg] engineOnActivity engineCfg "Starting agent loop" - loop llm tools toolMap initialMessages 0 0 0 0.0 Map.empty 0 + loop llm tools toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 where maxIter = agentMaxIterations agentCfg guardrails' = agentGuardrails agentCfg @@ -618,14 +622,15 @@ runAgent engineCfg agentCfg userPrompt = do Double -> Map.Map Text Int -> Int -> + Int -> IO (Either Text AgentResult) - loop llm tools' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures + loop llm tools' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures | iteration >= maxIter = do let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" engineOnError engineCfg errMsg pure <| Left errMsg | otherwise = do - let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures case guardrailViolation of Just (g, errMsg) -> do engineOnGuardrail engineCfg g @@ -675,11 +680,11 @@ runAgent engineCfg agentCfg userPrompt = do resultTotalTokens = newTokens } Just tcs -> do - (toolResults, newTestFailures) <- executeToolCallsWithTracking engineCfg toolMap tcs testFailures + (toolResults, newTestFailures, newEditFailures) <- executeToolCallsWithTracking engineCfg toolMap tcs testFailures editFailures let newMsgs = msgs <> [msg] <> toolResults newCalls = totalCalls + length tcs newToolCallCounts = updateToolCallCounts toolCallCounts tcs - loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures + loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures checkCostGuardrail :: Guardrails -> Double -> GuardrailResult checkCostGuardrail g cost @@ -704,12 +709,17 @@ checkTestFailureGuardrail g failures | failures >= guardrailMaxTestFailures g = GuardrailTestFailures failures | otherwise = GuardrailOk +checkEditFailureGuardrail :: Guardrails -> Int -> GuardrailResult +checkEditFailureGuardrail g failures + | failures >= guardrailMaxEditFailures g = GuardrailEditFailures failures + | otherwise = GuardrailOk + updateToolCallCounts :: Map.Map Text Int -> [ToolCall] -> Map.Map Text Int updateToolCallCounts = foldr (\tc m -> Map.insertWith (+) (fcName (tcFunction tc)) 1 m) -findGuardrailViolation :: Guardrails -> Double -> Int -> Map.Map Text Int -> Int -> Maybe (GuardrailResult, Text) -findGuardrailViolation g cost tokens toolCallCounts testFailures = +findGuardrailViolation :: Guardrails -> Double -> Int -> Map.Map Text Int -> Int -> Int -> Maybe (GuardrailResult, Text) +findGuardrailViolation g cost tokens toolCallCounts testFailures editFailures = case checkCostGuardrail g cost of r@(GuardrailCostExceeded actual limit) -> Just (r, "Guardrail: cost budget exceeded (" <> tshow actual <> "/" <> tshow limit <> " cents)") @@ -722,18 +732,25 @@ findGuardrailViolation g cost tokens toolCallCounts testFailures = _ -> case checkTestFailureGuardrail g testFailures of r@(GuardrailTestFailures count) -> Just (r, "Guardrail: too many test failures (" <> tshow count <> ")") - _ -> Nothing + _ -> case checkEditFailureGuardrail g editFailures of + r@(GuardrailEditFailures count) -> + Just (r, "Guardrail: too many edit_file failures (" <> tshow count <> " 'old_str not found' errors)") + _ -> Nothing buildToolMap :: [Tool] -> Map.Map Text Tool buildToolMap = Map.fromList <. map (\t -> (toolName t, t)) -executeToolCallsWithTracking :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> Int -> IO ([Message], Int) -executeToolCallsWithTracking engineCfg toolMap tcs initialFailures = do +-- | Track both test failures and edit failures +-- Returns (messages, testFailures, editFailures) +executeToolCallsWithTracking :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> Int -> Int -> IO ([Message], Int, Int) +executeToolCallsWithTracking engineCfg toolMap tcs initialTestFailures initialEditFailures = do results <- traverse executeSingle tcs - let msgs = map fst results - failureDeltas = map snd results - totalNewFailures = initialFailures + sum failureDeltas - pure (msgs, totalNewFailures) + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFailures = initialTestFailures + sum testDeltas + totalEditFailures = initialEditFailures + sum editDeltas + pure (msgs, totalTestFailures, totalEditFailures) where executeSingle tc = do let name = fcName (tcFunction tc) @@ -745,21 +762,23 @@ executeToolCallsWithTracking engineCfg toolMap tcs initialFailures = do Nothing -> do let errMsg = "Tool not found: " <> name engineOnToolResult engineCfg name False errMsg - pure (Message ToolRole errMsg Nothing (Just callId), 0) + pure (Message ToolRole errMsg Nothing (Just callId), 0, 0) Just tool -> do case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of Nothing -> do let errMsg = "Invalid JSON arguments: " <> argsText engineOnToolResult engineCfg name False errMsg - pure (Message ToolRole errMsg Nothing (Just callId), 0) + pure (Message ToolRole errMsg Nothing (Just callId), 0, 0) Just args -> do resultValue <- toolExecute tool args let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) isTestFailure = isTestCall && isFailureResult resultValue - failureDelta = if isTestFailure then 1 else 0 + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundError resultValue + editDelta = if isEditFailure then 1 else 0 engineOnToolResult engineCfg name True resultText - pure (Message ToolRole resultText Nothing (Just callId), failureDelta) + pure (Message ToolRole resultText Nothing (Just callId), testDelta, editDelta) isFailureResult :: Aeson.Value -> Bool isFailureResult (Aeson.Object obj) = @@ -775,6 +794,13 @@ executeToolCallsWithTracking engineCfg toolMap tcs initialFailures = do `Text.isInfixOf` s isFailureResult _ = False + isOldStrNotFoundError :: Aeson.Value -> Bool + isOldStrNotFoundError (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundError _ = False + -- | Estimate cost in cents from token count estimateCost :: Text -> Int -> Double estimateCost model tokens |
