diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-13 08:21:23 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-13 08:21:23 -0500 |
| commit | 1c7b30005af27dcc3345f7dee0fe0404c3bc8c49 (patch) | |
| tree | 25eb99ddfce74749264aa30dcf2992c207cc71a3 /Omni/Agent/Engine.hs | |
| parent | f752330c9562b7a1bbdce15c05106a577daa2392 (diff) | |
fix: accumulate streaming tool call arguments across SSE chunks
OpenAI's SSE streaming sends tool calls incrementally - the first chunk
has the id and function name, subsequent chunks contain argument fragments.
Previously each chunk was treated as a complete tool call, causing invalid
JSON arguments.
- Add ToolCallDelta type with index for partial tool call data
- Add StreamToolCallDelta chunk type
- Track tool calls by index in IntMap accumulator
- Merge argument fragments across chunks via mergeToolCallDelta
- Build final ToolCall objects from accumulator when stream ends
- Handle new StreamToolCallDelta in Engine.hs pattern match
Diffstat (limited to 'Omni/Agent/Engine.hs')
| -rw-r--r-- | Omni/Agent/Engine.hs | 184 |
1 files changed, 184 insertions, 0 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs index f9b0355..f137ddb 100644 --- a/Omni/Agent/Engine.hs +++ b/Omni/Agent/Engine.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE NoImplicitPrelude #-} @@ -39,6 +40,7 @@ module Omni.Agent.Engine chat, runAgent, runAgentWithProvider, + runAgentWithProviderStreaming, main, test, ) @@ -50,6 +52,7 @@ import qualified Data.Aeson as Aeson import qualified Data.Aeson.KeyMap as KeyMap import qualified Data.ByteString.Lazy as BL import qualified Data.CaseInsensitive as CI +import Data.IORef (newIORef, writeIORef) import qualified Data.Map.Strict as Map import qualified Data.Text as Text import qualified Data.Text.Encoding as TE @@ -1003,3 +1006,184 @@ runAgentWithProvider engineCfg provider agentCfg userPrompt = do Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s _ -> False isOldStrNotFoundProvider _ = False + +runAgentWithProviderStreaming :: + EngineConfig -> + Provider.Provider -> + AgentConfig -> + Text -> + (Text -> IO ()) -> + IO (Either Text AgentResult) +runAgentWithProviderStreaming engineCfg provider agentCfg userPrompt onStreamChunk = do + let tools = agentTools agentCfg + toolApis = map encodeToolForProvider tools + toolMap = buildToolMap tools + systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg) + userMsg = providerMessage Provider.User userPrompt + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop (Provider+Streaming)" + loopProviderStreaming provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 + where + maxIter = agentMaxIterations agentCfg + guardrails' = agentGuardrails agentCfg + + providerMessage :: Provider.Role -> Text -> Provider.Message + providerMessage role content = Provider.Message role content Nothing Nothing + + loopProviderStreaming :: + Provider.Provider -> + [Provider.ToolApi] -> + Map.Map Text Tool -> + [Provider.Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + Int -> + IO (Either Text AgentResult) + loopProviderStreaming prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures + | iteration >= maxIter = do + let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" + engineOnError engineCfg errMsg + pure <| Left errMsg + | otherwise = do + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + hasToolCalls <- newIORef False + result <- + Provider.chatStream prov toolApis' msgs <| \case + Provider.StreamContent txt -> onStreamChunk txt + Provider.StreamToolCall _ -> writeIORef hasToolCalls True + Provider.StreamToolCallDelta _ -> writeIORef hasToolCalls True + Provider.StreamError err -> engineOnError engineCfg err + Provider.StreamDone _ -> pure () + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = Provider.chatMessage chatRes + tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes) + cost = case Provider.chatUsage chatRes +> Provider.usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (getProviderModelStreaming prov) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = Provider.msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case Provider.msgToolCalls msg of + Nothing + | Text.null (Provider.msgContent msg) && totalCalls > 0 -> do + engineOnActivity engineCfg "Empty response after tools, prompting for text" + let promptMsg = Provider.Message Provider.ToolRole "Please provide a response to the user." Nothing Nothing + newMsgs = msgs <> [msg, promptMsg] + loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures + | otherwise -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures, newEditFailures) <- executeToolCallsStreaming engineCfg toolMap tcs testFailures editFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateToolCallCountsStreaming toolCallCounts tcs + loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures + + getProviderModelStreaming :: Provider.Provider -> Text + getProviderModelStreaming (Provider.OpenRouter cfg) = Provider.providerModel cfg + getProviderModelStreaming (Provider.Ollama cfg) = Provider.providerModel cfg + getProviderModelStreaming (Provider.AmpCLI _) = "amp" + + updateToolCallCountsStreaming :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int + updateToolCallCountsStreaming = + foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m) + + executeToolCallsStreaming :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int) + executeToolCallsStreaming eCfg tMap tcs initialTestFailures initialEditFailures = do + results <- traverse (executeSingleStreaming eCfg tMap) tcs + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFail = initialTestFailures + sum testDeltas + totalEditFail = initialEditFailures + sum editDeltas + pure (msgs, totalTestFail, totalEditFail) + + executeSingleStreaming :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int) + executeSingleStreaming eCfg tMap tc = do + let name = Provider.fcName (Provider.tcFunction tc) + argsText = Provider.fcArguments (Provider.tcFunction tc) + callId = Provider.tcId tc + engineOnActivity eCfg <| "Executing tool: " <> name + engineOnToolCall eCfg name argsText + case Map.lookup name tMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResultStreaming resultValue + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundStreaming resultValue + editDelta = if isEditFailure then 1 else 0 + engineOnToolResult eCfg name True resultText + pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta) + + isFailureResultStreaming :: Aeson.Value -> Bool + isFailureResultStreaming (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResultStreaming (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResultStreaming _ = False + + isOldStrNotFoundStreaming :: Aeson.Value -> Bool + isOldStrNotFoundStreaming (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundStreaming _ = False |
