diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-17 13:29:40 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-17 13:29:40 -0500 |
| commit | ab01b34bf563990e0f491ada646472aaade97610 (patch) | |
| tree | 5e46a1a157bb846b0c3a090a83153c788da2b977 /Omni/Agent/Engine.hs | |
| parent | e112d3ce07fa24f31a281e521a554cc881a76c7b (diff) | |
| parent | 337648981cc5a55935116141341521f4fce83214 (diff) | |
Merge Ava deployment changes
Diffstat (limited to 'Omni/Agent/Engine.hs')
| -rw-r--r-- | Omni/Agent/Engine.hs | 414 |
1 files changed, 396 insertions, 18 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs index 4ee5e5d..f137ddb 100644 --- a/Omni/Agent/Engine.hs +++ b/Omni/Agent/Engine.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE NoImplicitPrelude #-} @@ -30,12 +31,16 @@ module Omni.Agent.Engine ChatCompletionResponse (..), Choice (..), Usage (..), + ToolApi (..), + encodeToolForApi, defaultLLM, defaultEngineConfig, defaultAgentConfig, defaultGuardrails, chat, runAgent, + runAgentWithProvider, + runAgentWithProviderStreaming, main, test, ) @@ -47,10 +52,12 @@ import qualified Data.Aeson as Aeson import qualified Data.Aeson.KeyMap as KeyMap import qualified Data.ByteString.Lazy as BL import qualified Data.CaseInsensitive as CI +import Data.IORef (newIORef, writeIORef) import qualified Data.Map.Strict as Map import qualified Data.Text as Text import qualified Data.Text.Encoding as TE import qualified Network.HTTP.Simple as HTTP +import qualified Omni.Agent.Provider as Provider import qualified Omni.Test as Test main :: IO () @@ -264,6 +271,14 @@ encodeToolForApi t = toolApiParameters = toolJsonSchema t } +encodeToolForProvider :: Tool -> Provider.ToolApi +encodeToolForProvider t = + Provider.ToolApi + { Provider.toolApiName = toolName t, + Provider.toolApiDescription = toolDescription t, + Provider.toolApiParameters = toolJsonSchema t + } + data LLM = LLM { llmBaseUrl :: Text, llmApiKey :: Text, @@ -655,18 +670,24 @@ runAgent engineCfg agentCfg userPrompt = do unless (Text.null assistantText) <| engineOnAssistant engineCfg assistantText case msgToolCalls msg of - Nothing -> do - engineOnActivity engineCfg "Agent completed" - engineOnComplete engineCfg - pure - <| Right - <| AgentResult - { resultFinalMessage = msgContent msg, - resultToolCallCount = totalCalls, - resultIterations = iteration + 1, - resultTotalCost = newCost, - resultTotalTokens = newTokens - } + Nothing + | Text.null (msgContent msg) && totalCalls > 0 -> do + engineOnActivity engineCfg "Empty response after tools, prompting for text" + let promptMsg = Message ToolRole "Please provide a response to the user." Nothing Nothing + newMsgs = msgs <> [msg, promptMsg] + loop llm tools' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures + | otherwise -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } Just [] -> do engineOnActivity engineCfg "Agent completed (empty tool calls)" engineOnComplete engineCfg @@ -801,11 +822,368 @@ executeToolCallsWithTracking engineCfg toolMap tcs initialTestFailures initialEd _ -> False isOldStrNotFoundError _ = False --- | Estimate cost in cents from token count +-- | Estimate cost in cents from token count. +-- Uses blended input/output rates (roughly 2:1 output:input ratio). +-- Prices as of Dec 2024 from OpenRouter. estimateCost :: Text -> Int -> Double estimateCost model tokens - | "gpt-4o-mini" `Text.isInfixOf` model = fromIntegral tokens * 15 / 1000000 - | "gpt-4o" `Text.isInfixOf` model = fromIntegral tokens * 250 / 100000 - | "gpt-4" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000 - | "claude" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000 - | otherwise = fromIntegral tokens / 100000 + | "gpt-4o-mini" `Text.isInfixOf` model = fromIntegral tokens * 0.04 / 1000 + | "gpt-4o" `Text.isInfixOf` model = fromIntegral tokens * 0.7 / 1000 + | "gemini-2.0-flash" `Text.isInfixOf` model = fromIntegral tokens * 0.15 / 1000 + | "gemini-2.5-flash" `Text.isInfixOf` model = fromIntegral tokens * 0.15 / 1000 + | "claude-sonnet-4.5" `Text.isInfixOf` model = fromIntegral tokens * 0.9 / 1000 + | "claude-sonnet-4" `Text.isInfixOf` model = fromIntegral tokens * 0.9 / 1000 + | "claude-3-haiku" `Text.isInfixOf` model = fromIntegral tokens * 0.1 / 1000 + | "claude" `Text.isInfixOf` model = fromIntegral tokens * 0.9 / 1000 + | otherwise = fromIntegral tokens * 0.5 / 1000 + +-- | Run agent with a Provider instead of LLM. +-- This is the new preferred way to run agents with multiple backend support. +runAgentWithProvider :: EngineConfig -> Provider.Provider -> AgentConfig -> Text -> IO (Either Text AgentResult) +runAgentWithProvider engineCfg provider agentCfg userPrompt = do + let tools = agentTools agentCfg + toolApis = map encodeToolForProvider tools + toolMap = buildToolMap tools + systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg) + userMsg = providerMessage Provider.User userPrompt + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop (Provider)" + loopProvider provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 + where + maxIter = agentMaxIterations agentCfg + guardrails' = agentGuardrails agentCfg + + providerMessage :: Provider.Role -> Text -> Provider.Message + providerMessage role content = Provider.Message role content Nothing Nothing + + loopProvider :: + Provider.Provider -> + [Provider.ToolApi] -> + Map.Map Text Tool -> + [Provider.Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + Int -> + IO (Either Text AgentResult) + loopProvider prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures + | iteration >= maxIter = do + let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" + engineOnError engineCfg errMsg + pure <| Left errMsg + | otherwise = do + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + result <- Provider.chatWithUsage prov toolApis' msgs + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = Provider.chatMessage chatRes + tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes) + cost = case Provider.chatUsage chatRes +> Provider.usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (getProviderModel prov) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = Provider.msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case Provider.msgToolCalls msg of + Nothing + | Text.null (Provider.msgContent msg) && totalCalls > 0 -> do + engineOnActivity engineCfg "Empty response after tools, prompting for text" + let promptMsg = Provider.Message Provider.ToolRole "Please provide a response to the user." Nothing Nothing + newMsgs = msgs <> [msg, promptMsg] + loopProvider prov toolApis' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures + | otherwise -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures, newEditFailures) <- executeProviderToolCalls engineCfg toolMap tcs testFailures editFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateProviderToolCallCounts toolCallCounts tcs + loopProvider prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures + + getProviderModel :: Provider.Provider -> Text + getProviderModel (Provider.OpenRouter cfg) = Provider.providerModel cfg + getProviderModel (Provider.Ollama cfg) = Provider.providerModel cfg + getProviderModel (Provider.AmpCLI _) = "amp" + + updateProviderToolCallCounts :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int + updateProviderToolCallCounts = + foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m) + + executeProviderToolCalls :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int) + executeProviderToolCalls eCfg tMap tcs initialTestFailures initialEditFailures = do + results <- traverse (executeSingleProvider eCfg tMap) tcs + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFail = initialTestFailures + sum testDeltas + totalEditFail = initialEditFailures + sum editDeltas + pure (msgs, totalTestFail, totalEditFail) + + executeSingleProvider :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int) + executeSingleProvider eCfg tMap tc = do + let name = Provider.fcName (Provider.tcFunction tc) + argsText = Provider.fcArguments (Provider.tcFunction tc) + callId = Provider.tcId tc + engineOnActivity eCfg <| "Executing tool: " <> name + engineOnToolCall eCfg name argsText + case Map.lookup name tMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResultProvider resultValue + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundProvider resultValue + editDelta = if isEditFailure then 1 else 0 + engineOnToolResult eCfg name True resultText + pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta) + + isFailureResultProvider :: Aeson.Value -> Bool + isFailureResultProvider (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResultProvider (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResultProvider _ = False + + isOldStrNotFoundProvider :: Aeson.Value -> Bool + isOldStrNotFoundProvider (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundProvider _ = False + +runAgentWithProviderStreaming :: + EngineConfig -> + Provider.Provider -> + AgentConfig -> + Text -> + (Text -> IO ()) -> + IO (Either Text AgentResult) +runAgentWithProviderStreaming engineCfg provider agentCfg userPrompt onStreamChunk = do + let tools = agentTools agentCfg + toolApis = map encodeToolForProvider tools + toolMap = buildToolMap tools + systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg) + userMsg = providerMessage Provider.User userPrompt + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop (Provider+Streaming)" + loopProviderStreaming provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 + where + maxIter = agentMaxIterations agentCfg + guardrails' = agentGuardrails agentCfg + + providerMessage :: Provider.Role -> Text -> Provider.Message + providerMessage role content = Provider.Message role content Nothing Nothing + + loopProviderStreaming :: + Provider.Provider -> + [Provider.ToolApi] -> + Map.Map Text Tool -> + [Provider.Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + Int -> + IO (Either Text AgentResult) + loopProviderStreaming prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures + | iteration >= maxIter = do + let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" + engineOnError engineCfg errMsg + pure <| Left errMsg + | otherwise = do + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + hasToolCalls <- newIORef False + result <- + Provider.chatStream prov toolApis' msgs <| \case + Provider.StreamContent txt -> onStreamChunk txt + Provider.StreamToolCall _ -> writeIORef hasToolCalls True + Provider.StreamToolCallDelta _ -> writeIORef hasToolCalls True + Provider.StreamError err -> engineOnError engineCfg err + Provider.StreamDone _ -> pure () + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = Provider.chatMessage chatRes + tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes) + cost = case Provider.chatUsage chatRes +> Provider.usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (getProviderModelStreaming prov) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = Provider.msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case Provider.msgToolCalls msg of + Nothing + | Text.null (Provider.msgContent msg) && totalCalls > 0 -> do + engineOnActivity engineCfg "Empty response after tools, prompting for text" + let promptMsg = Provider.Message Provider.ToolRole "Please provide a response to the user." Nothing Nothing + newMsgs = msgs <> [msg, promptMsg] + loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures + | otherwise -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures, newEditFailures) <- executeToolCallsStreaming engineCfg toolMap tcs testFailures editFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateToolCallCountsStreaming toolCallCounts tcs + loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures + + getProviderModelStreaming :: Provider.Provider -> Text + getProviderModelStreaming (Provider.OpenRouter cfg) = Provider.providerModel cfg + getProviderModelStreaming (Provider.Ollama cfg) = Provider.providerModel cfg + getProviderModelStreaming (Provider.AmpCLI _) = "amp" + + updateToolCallCountsStreaming :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int + updateToolCallCountsStreaming = + foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m) + + executeToolCallsStreaming :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int) + executeToolCallsStreaming eCfg tMap tcs initialTestFailures initialEditFailures = do + results <- traverse (executeSingleStreaming eCfg tMap) tcs + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFail = initialTestFailures + sum testDeltas + totalEditFail = initialEditFailures + sum editDeltas + pure (msgs, totalTestFail, totalEditFail) + + executeSingleStreaming :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int) + executeSingleStreaming eCfg tMap tc = do + let name = Provider.fcName (Provider.tcFunction tc) + argsText = Provider.fcArguments (Provider.tcFunction tc) + callId = Provider.tcId tc + engineOnActivity eCfg <| "Executing tool: " <> name + engineOnToolCall eCfg name argsText + case Map.lookup name tMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResultStreaming resultValue + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundStreaming resultValue + editDelta = if isEditFailure then 1 else 0 + engineOnToolResult eCfg name True resultText + pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta) + + isFailureResultStreaming :: Aeson.Value -> Bool + isFailureResultStreaming (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResultStreaming (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResultStreaming _ = False + + isOldStrNotFoundStreaming :: Aeson.Value -> Bool + isOldStrNotFoundStreaming (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundStreaming _ = False |
