summaryrefslogtreecommitdiff
path: root/Omni/Agent/Engine.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Omni/Agent/Engine.hs')
-rw-r--r--Omni/Agent/Engine.hs70
1 files changed, 48 insertions, 22 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs
index a2a24a1..4ee5e5d 100644
--- a/Omni/Agent/Engine.hs
+++ b/Omni/Agent/Engine.hs
@@ -203,7 +203,7 @@ test =
Map.lookup "bash" counts Test.@=? Just 2
Map.lookup "read_file" counts Test.@=? Just 1,
Test.unit "Guardrails JSON roundtrip" <| do
- let g = Guardrails 75.0 100000 5 4
+ let g = Guardrails 75.0 100000 5 4 3
case Aeson.decode (Aeson.encode g) of
Nothing -> Test.assertFailure "Failed to decode Guardrails"
Just decoded -> decoded Test.@=? g,
@@ -213,7 +213,8 @@ test =
GuardrailCostExceeded 100.0 50.0,
GuardrailTokensExceeded 2000 1000,
GuardrailDuplicateToolCalls "bash" 5,
- GuardrailTestFailures 3
+ GuardrailTestFailures 3,
+ GuardrailEditFailures 5
]
forM_ results <| \r ->
case Aeson.decode (Aeson.encode r) of
@@ -311,7 +312,8 @@ data Guardrails = Guardrails
{ guardrailMaxCostCents :: Double,
guardrailMaxTokens :: Int,
guardrailMaxDuplicateToolCalls :: Int,
- guardrailMaxTestFailures :: Int
+ guardrailMaxTestFailures :: Int,
+ guardrailMaxEditFailures :: Int
}
deriving (Show, Eq, Generic)
@@ -325,6 +327,7 @@ data GuardrailResult
| GuardrailTokensExceeded Int Int
| GuardrailDuplicateToolCalls Text Int
| GuardrailTestFailures Int
+ | GuardrailEditFailures Int
deriving (Show, Eq, Generic)
instance Aeson.ToJSON GuardrailResult
@@ -337,7 +340,8 @@ defaultGuardrails =
{ guardrailMaxCostCents = 100.0,
guardrailMaxTokens = 500000,
guardrailMaxDuplicateToolCalls = 3,
- guardrailMaxTestFailures = 3
+ guardrailMaxTestFailures = 3,
+ guardrailMaxEditFailures = 5
}
defaultAgentConfig :: AgentConfig
@@ -602,7 +606,7 @@ runAgent engineCfg agentCfg userPrompt = do
initialMessages = [systemMsg, userMsg]
engineOnActivity engineCfg "Starting agent loop"
- loop llm tools toolMap initialMessages 0 0 0 0.0 Map.empty 0
+ loop llm tools toolMap initialMessages 0 0 0 0.0 Map.empty 0 0
where
maxIter = agentMaxIterations agentCfg
guardrails' = agentGuardrails agentCfg
@@ -618,14 +622,15 @@ runAgent engineCfg agentCfg userPrompt = do
Double ->
Map.Map Text Int ->
Int ->
+ Int ->
IO (Either Text AgentResult)
- loop llm tools' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures
+ loop llm tools' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures
| iteration >= maxIter = do
let errMsg = "Max iterations (" <> tshow maxIter <> ") reached"
engineOnError engineCfg errMsg
pure <| Left errMsg
| otherwise = do
- let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures
+ let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures
case guardrailViolation of
Just (g, errMsg) -> do
engineOnGuardrail engineCfg g
@@ -675,11 +680,11 @@ runAgent engineCfg agentCfg userPrompt = do
resultTotalTokens = newTokens
}
Just tcs -> do
- (toolResults, newTestFailures) <- executeToolCallsWithTracking engineCfg toolMap tcs testFailures
+ (toolResults, newTestFailures, newEditFailures) <- executeToolCallsWithTracking engineCfg toolMap tcs testFailures editFailures
let newMsgs = msgs <> [msg] <> toolResults
newCalls = totalCalls + length tcs
newToolCallCounts = updateToolCallCounts toolCallCounts tcs
- loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures
+ loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures
checkCostGuardrail :: Guardrails -> Double -> GuardrailResult
checkCostGuardrail g cost
@@ -704,12 +709,17 @@ checkTestFailureGuardrail g failures
| failures >= guardrailMaxTestFailures g = GuardrailTestFailures failures
| otherwise = GuardrailOk
+checkEditFailureGuardrail :: Guardrails -> Int -> GuardrailResult
+checkEditFailureGuardrail g failures
+ | failures >= guardrailMaxEditFailures g = GuardrailEditFailures failures
+ | otherwise = GuardrailOk
+
updateToolCallCounts :: Map.Map Text Int -> [ToolCall] -> Map.Map Text Int
updateToolCallCounts =
foldr (\tc m -> Map.insertWith (+) (fcName (tcFunction tc)) 1 m)
-findGuardrailViolation :: Guardrails -> Double -> Int -> Map.Map Text Int -> Int -> Maybe (GuardrailResult, Text)
-findGuardrailViolation g cost tokens toolCallCounts testFailures =
+findGuardrailViolation :: Guardrails -> Double -> Int -> Map.Map Text Int -> Int -> Int -> Maybe (GuardrailResult, Text)
+findGuardrailViolation g cost tokens toolCallCounts testFailures editFailures =
case checkCostGuardrail g cost of
r@(GuardrailCostExceeded actual limit) ->
Just (r, "Guardrail: cost budget exceeded (" <> tshow actual <> "/" <> tshow limit <> " cents)")
@@ -722,18 +732,25 @@ findGuardrailViolation g cost tokens toolCallCounts testFailures =
_ -> case checkTestFailureGuardrail g testFailures of
r@(GuardrailTestFailures count) ->
Just (r, "Guardrail: too many test failures (" <> tshow count <> ")")
- _ -> Nothing
+ _ -> case checkEditFailureGuardrail g editFailures of
+ r@(GuardrailEditFailures count) ->
+ Just (r, "Guardrail: too many edit_file failures (" <> tshow count <> " 'old_str not found' errors)")
+ _ -> Nothing
buildToolMap :: [Tool] -> Map.Map Text Tool
buildToolMap = Map.fromList <. map (\t -> (toolName t, t))
-executeToolCallsWithTracking :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> Int -> IO ([Message], Int)
-executeToolCallsWithTracking engineCfg toolMap tcs initialFailures = do
+-- | Track both test failures and edit failures
+-- Returns (messages, testFailures, editFailures)
+executeToolCallsWithTracking :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> Int -> Int -> IO ([Message], Int, Int)
+executeToolCallsWithTracking engineCfg toolMap tcs initialTestFailures initialEditFailures = do
results <- traverse executeSingle tcs
- let msgs = map fst results
- failureDeltas = map snd results
- totalNewFailures = initialFailures + sum failureDeltas
- pure (msgs, totalNewFailures)
+ let msgs = map (\(m, _, _) -> m) results
+ testDeltas = map (\(_, t, _) -> t) results
+ editDeltas = map (\(_, _, e) -> e) results
+ totalTestFailures = initialTestFailures + sum testDeltas
+ totalEditFailures = initialEditFailures + sum editDeltas
+ pure (msgs, totalTestFailures, totalEditFailures)
where
executeSingle tc = do
let name = fcName (tcFunction tc)
@@ -745,21 +762,23 @@ executeToolCallsWithTracking engineCfg toolMap tcs initialFailures = do
Nothing -> do
let errMsg = "Tool not found: " <> name
engineOnToolResult engineCfg name False errMsg
- pure (Message ToolRole errMsg Nothing (Just callId), 0)
+ pure (Message ToolRole errMsg Nothing (Just callId), 0, 0)
Just tool -> do
case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of
Nothing -> do
let errMsg = "Invalid JSON arguments: " <> argsText
engineOnToolResult engineCfg name False errMsg
- pure (Message ToolRole errMsg Nothing (Just callId), 0)
+ pure (Message ToolRole errMsg Nothing (Just callId), 0, 0)
Just args -> do
resultValue <- toolExecute tool args
let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue))
isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText)
isTestFailure = isTestCall && isFailureResult resultValue
- failureDelta = if isTestFailure then 1 else 0
+ testDelta = if isTestFailure then 1 else 0
+ isEditFailure = name == "edit_file" && isOldStrNotFoundError resultValue
+ editDelta = if isEditFailure then 1 else 0
engineOnToolResult engineCfg name True resultText
- pure (Message ToolRole resultText Nothing (Just callId), failureDelta)
+ pure (Message ToolRole resultText Nothing (Just callId), testDelta, editDelta)
isFailureResult :: Aeson.Value -> Bool
isFailureResult (Aeson.Object obj) =
@@ -775,6 +794,13 @@ executeToolCallsWithTracking engineCfg toolMap tcs initialFailures = do
`Text.isInfixOf` s
isFailureResult _ = False
+ isOldStrNotFoundError :: Aeson.Value -> Bool
+ isOldStrNotFoundError (Aeson.Object obj) =
+ case KeyMap.lookup "error" obj of
+ Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s
+ _ -> False
+ isOldStrNotFoundError _ = False
+
-- | Estimate cost in cents from token count
estimateCost :: Text -> Int -> Double
estimateCost model tokens