summaryrefslogtreecommitdiff
path: root/Omni/Agent/Engine.hs
diff options
context:
space:
mode:
authorBen Sima <ben@bensima.com>2025-12-17 13:29:40 -0500
committerBen Sima <ben@bensima.com>2025-12-17 13:29:40 -0500
commitab01b34bf563990e0f491ada646472aaade97610 (patch)
tree5e46a1a157bb846b0c3a090a83153c788da2b977 /Omni/Agent/Engine.hs
parente112d3ce07fa24f31a281e521a554cc881a76c7b (diff)
parent337648981cc5a55935116141341521f4fce83214 (diff)
Merge Ava deployment changes
Diffstat (limited to 'Omni/Agent/Engine.hs')
-rw-r--r--Omni/Agent/Engine.hs414
1 files changed, 396 insertions, 18 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs
index 4ee5e5d..f137ddb 100644
--- a/Omni/Agent/Engine.hs
+++ b/Omni/Agent/Engine.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoImplicitPrelude #-}
@@ -30,12 +31,16 @@ module Omni.Agent.Engine
ChatCompletionResponse (..),
Choice (..),
Usage (..),
+ ToolApi (..),
+ encodeToolForApi,
defaultLLM,
defaultEngineConfig,
defaultAgentConfig,
defaultGuardrails,
chat,
runAgent,
+ runAgentWithProvider,
+ runAgentWithProviderStreaming,
main,
test,
)
@@ -47,10 +52,12 @@ import qualified Data.Aeson as Aeson
import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
+import Data.IORef (newIORef, writeIORef)
import qualified Data.Map.Strict as Map
import qualified Data.Text as Text
import qualified Data.Text.Encoding as TE
import qualified Network.HTTP.Simple as HTTP
+import qualified Omni.Agent.Provider as Provider
import qualified Omni.Test as Test
main :: IO ()
@@ -264,6 +271,14 @@ encodeToolForApi t =
toolApiParameters = toolJsonSchema t
}
+encodeToolForProvider :: Tool -> Provider.ToolApi
+encodeToolForProvider t =
+ Provider.ToolApi
+ { Provider.toolApiName = toolName t,
+ Provider.toolApiDescription = toolDescription t,
+ Provider.toolApiParameters = toolJsonSchema t
+ }
+
data LLM = LLM
{ llmBaseUrl :: Text,
llmApiKey :: Text,
@@ -655,18 +670,24 @@ runAgent engineCfg agentCfg userPrompt = do
unless (Text.null assistantText)
<| engineOnAssistant engineCfg assistantText
case msgToolCalls msg of
- Nothing -> do
- engineOnActivity engineCfg "Agent completed"
- engineOnComplete engineCfg
- pure
- <| Right
- <| AgentResult
- { resultFinalMessage = msgContent msg,
- resultToolCallCount = totalCalls,
- resultIterations = iteration + 1,
- resultTotalCost = newCost,
- resultTotalTokens = newTokens
- }
+ Nothing
+ | Text.null (msgContent msg) && totalCalls > 0 -> do
+ engineOnActivity engineCfg "Empty response after tools, prompting for text"
+ let promptMsg = Message ToolRole "Please provide a response to the user." Nothing Nothing
+ newMsgs = msgs <> [msg, promptMsg]
+ loop llm tools' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures
+ | otherwise -> do
+ engineOnActivity engineCfg "Agent completed"
+ engineOnComplete engineCfg
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = newCost,
+ resultTotalTokens = newTokens
+ }
Just [] -> do
engineOnActivity engineCfg "Agent completed (empty tool calls)"
engineOnComplete engineCfg
@@ -801,11 +822,368 @@ executeToolCallsWithTracking engineCfg toolMap tcs initialTestFailures initialEd
_ -> False
isOldStrNotFoundError _ = False
--- | Estimate cost in cents from token count
+-- | Estimate cost in cents from token count.
+-- Uses blended input/output rates (roughly 2:1 output:input ratio).
+-- Prices as of Dec 2024 from OpenRouter.
estimateCost :: Text -> Int -> Double
estimateCost model tokens
- | "gpt-4o-mini" `Text.isInfixOf` model = fromIntegral tokens * 15 / 1000000
- | "gpt-4o" `Text.isInfixOf` model = fromIntegral tokens * 250 / 100000
- | "gpt-4" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000
- | "claude" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000
- | otherwise = fromIntegral tokens / 100000
+ | "gpt-4o-mini" `Text.isInfixOf` model = fromIntegral tokens * 0.04 / 1000
+ | "gpt-4o" `Text.isInfixOf` model = fromIntegral tokens * 0.7 / 1000
+ | "gemini-2.0-flash" `Text.isInfixOf` model = fromIntegral tokens * 0.15 / 1000
+ | "gemini-2.5-flash" `Text.isInfixOf` model = fromIntegral tokens * 0.15 / 1000
+ | "claude-sonnet-4.5" `Text.isInfixOf` model = fromIntegral tokens * 0.9 / 1000
+ | "claude-sonnet-4" `Text.isInfixOf` model = fromIntegral tokens * 0.9 / 1000
+ | "claude-3-haiku" `Text.isInfixOf` model = fromIntegral tokens * 0.1 / 1000
+ | "claude" `Text.isInfixOf` model = fromIntegral tokens * 0.9 / 1000
+ | otherwise = fromIntegral tokens * 0.5 / 1000
+
+-- | Run agent with a Provider instead of LLM.
+-- This is the new preferred way to run agents with multiple backend support.
+runAgentWithProvider :: EngineConfig -> Provider.Provider -> AgentConfig -> Text -> IO (Either Text AgentResult)
+runAgentWithProvider engineCfg provider agentCfg userPrompt = do
+ let tools = agentTools agentCfg
+ toolApis = map encodeToolForProvider tools
+ toolMap = buildToolMap tools
+ systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg)
+ userMsg = providerMessage Provider.User userPrompt
+ initialMessages = [systemMsg, userMsg]
+
+ engineOnActivity engineCfg "Starting agent loop (Provider)"
+ loopProvider provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0
+ where
+ maxIter = agentMaxIterations agentCfg
+ guardrails' = agentGuardrails agentCfg
+
+ providerMessage :: Provider.Role -> Text -> Provider.Message
+ providerMessage role content = Provider.Message role content Nothing Nothing
+
+ loopProvider ::
+ Provider.Provider ->
+ [Provider.ToolApi] ->
+ Map.Map Text Tool ->
+ [Provider.Message] ->
+ Int ->
+ Int ->
+ Int ->
+ Double ->
+ Map.Map Text Int ->
+ Int ->
+ Int ->
+ IO (Either Text AgentResult)
+ loopProvider prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures
+ | iteration >= maxIter = do
+ let errMsg = "Max iterations (" <> tshow maxIter <> ") reached"
+ engineOnError engineCfg errMsg
+ pure <| Left errMsg
+ | otherwise = do
+ let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures
+ case guardrailViolation of
+ Just (g, errMsg) -> do
+ engineOnGuardrail engineCfg g
+ pure <| Left errMsg
+ Nothing -> do
+ engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1)
+ result <- Provider.chatWithUsage prov toolApis' msgs
+ case result of
+ Left err -> do
+ engineOnError engineCfg err
+ pure (Left err)
+ Right chatRes -> do
+ let msg = Provider.chatMessage chatRes
+ tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes)
+ cost = case Provider.chatUsage chatRes +> Provider.usageCost of
+ Just actualCost -> actualCost * 100
+ Nothing -> estimateCost (getProviderModel prov) tokens
+ engineOnCost engineCfg tokens cost
+ let newTokens = totalTokens + tokens
+ newCost = totalCost + cost
+ let assistantText = Provider.msgContent msg
+ unless (Text.null assistantText)
+ <| engineOnAssistant engineCfg assistantText
+ case Provider.msgToolCalls msg of
+ Nothing
+ | Text.null (Provider.msgContent msg) && totalCalls > 0 -> do
+ engineOnActivity engineCfg "Empty response after tools, prompting for text"
+ let promptMsg = Provider.Message Provider.ToolRole "Please provide a response to the user." Nothing Nothing
+ newMsgs = msgs <> [msg, promptMsg]
+ loopProvider prov toolApis' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures
+ | otherwise -> do
+ engineOnActivity engineCfg "Agent completed"
+ engineOnComplete engineCfg
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = Provider.msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = newCost,
+ resultTotalTokens = newTokens
+ }
+ Just [] -> do
+ engineOnActivity engineCfg "Agent completed (empty tool calls)"
+ engineOnComplete engineCfg
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = Provider.msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = newCost,
+ resultTotalTokens = newTokens
+ }
+ Just tcs -> do
+ (toolResults, newTestFailures, newEditFailures) <- executeProviderToolCalls engineCfg toolMap tcs testFailures editFailures
+ let newMsgs = msgs <> [msg] <> toolResults
+ newCalls = totalCalls + length tcs
+ newToolCallCounts = updateProviderToolCallCounts toolCallCounts tcs
+ loopProvider prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures
+
+ getProviderModel :: Provider.Provider -> Text
+ getProviderModel (Provider.OpenRouter cfg) = Provider.providerModel cfg
+ getProviderModel (Provider.Ollama cfg) = Provider.providerModel cfg
+ getProviderModel (Provider.AmpCLI _) = "amp"
+
+ updateProviderToolCallCounts :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int
+ updateProviderToolCallCounts =
+ foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m)
+
+ executeProviderToolCalls :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int)
+ executeProviderToolCalls eCfg tMap tcs initialTestFailures initialEditFailures = do
+ results <- traverse (executeSingleProvider eCfg tMap) tcs
+ let msgs = map (\(m, _, _) -> m) results
+ testDeltas = map (\(_, t, _) -> t) results
+ editDeltas = map (\(_, _, e) -> e) results
+ totalTestFail = initialTestFailures + sum testDeltas
+ totalEditFail = initialEditFailures + sum editDeltas
+ pure (msgs, totalTestFail, totalEditFail)
+
+ executeSingleProvider :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int)
+ executeSingleProvider eCfg tMap tc = do
+ let name = Provider.fcName (Provider.tcFunction tc)
+ argsText = Provider.fcArguments (Provider.tcFunction tc)
+ callId = Provider.tcId tc
+ engineOnActivity eCfg <| "Executing tool: " <> name
+ engineOnToolCall eCfg name argsText
+ case Map.lookup name tMap of
+ Nothing -> do
+ let errMsg = "Tool not found: " <> name
+ engineOnToolResult eCfg name False errMsg
+ pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0)
+ Just tool -> do
+ case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of
+ Nothing -> do
+ let errMsg = "Invalid JSON arguments: " <> argsText
+ engineOnToolResult eCfg name False errMsg
+ pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0)
+ Just args -> do
+ resultValue <- toolExecute tool args
+ let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue))
+ isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText)
+ isTestFailure = isTestCall && isFailureResultProvider resultValue
+ testDelta = if isTestFailure then 1 else 0
+ isEditFailure = name == "edit_file" && isOldStrNotFoundProvider resultValue
+ editDelta = if isEditFailure then 1 else 0
+ engineOnToolResult eCfg name True resultText
+ pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta)
+
+ isFailureResultProvider :: Aeson.Value -> Bool
+ isFailureResultProvider (Aeson.Object obj) =
+ case KeyMap.lookup "exit_code" obj of
+ Just (Aeson.Number n) -> n /= 0
+ _ -> False
+ isFailureResultProvider (Aeson.String s) =
+ "error"
+ `Text.isInfixOf` Text.toLower s
+ || "failed"
+ `Text.isInfixOf` Text.toLower s
+ || "FAILED"
+ `Text.isInfixOf` s
+ isFailureResultProvider _ = False
+
+ isOldStrNotFoundProvider :: Aeson.Value -> Bool
+ isOldStrNotFoundProvider (Aeson.Object obj) =
+ case KeyMap.lookup "error" obj of
+ Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s
+ _ -> False
+ isOldStrNotFoundProvider _ = False
+
+runAgentWithProviderStreaming ::
+ EngineConfig ->
+ Provider.Provider ->
+ AgentConfig ->
+ Text ->
+ (Text -> IO ()) ->
+ IO (Either Text AgentResult)
+runAgentWithProviderStreaming engineCfg provider agentCfg userPrompt onStreamChunk = do
+ let tools = agentTools agentCfg
+ toolApis = map encodeToolForProvider tools
+ toolMap = buildToolMap tools
+ systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg)
+ userMsg = providerMessage Provider.User userPrompt
+ initialMessages = [systemMsg, userMsg]
+
+ engineOnActivity engineCfg "Starting agent loop (Provider+Streaming)"
+ loopProviderStreaming provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0
+ where
+ maxIter = agentMaxIterations agentCfg
+ guardrails' = agentGuardrails agentCfg
+
+ providerMessage :: Provider.Role -> Text -> Provider.Message
+ providerMessage role content = Provider.Message role content Nothing Nothing
+
+ loopProviderStreaming ::
+ Provider.Provider ->
+ [Provider.ToolApi] ->
+ Map.Map Text Tool ->
+ [Provider.Message] ->
+ Int ->
+ Int ->
+ Int ->
+ Double ->
+ Map.Map Text Int ->
+ Int ->
+ Int ->
+ IO (Either Text AgentResult)
+ loopProviderStreaming prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures
+ | iteration >= maxIter = do
+ let errMsg = "Max iterations (" <> tshow maxIter <> ") reached"
+ engineOnError engineCfg errMsg
+ pure <| Left errMsg
+ | otherwise = do
+ let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures
+ case guardrailViolation of
+ Just (g, errMsg) -> do
+ engineOnGuardrail engineCfg g
+ pure <| Left errMsg
+ Nothing -> do
+ engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1)
+ hasToolCalls <- newIORef False
+ result <-
+ Provider.chatStream prov toolApis' msgs <| \case
+ Provider.StreamContent txt -> onStreamChunk txt
+ Provider.StreamToolCall _ -> writeIORef hasToolCalls True
+ Provider.StreamToolCallDelta _ -> writeIORef hasToolCalls True
+ Provider.StreamError err -> engineOnError engineCfg err
+ Provider.StreamDone _ -> pure ()
+ case result of
+ Left err -> do
+ engineOnError engineCfg err
+ pure (Left err)
+ Right chatRes -> do
+ let msg = Provider.chatMessage chatRes
+ tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes)
+ cost = case Provider.chatUsage chatRes +> Provider.usageCost of
+ Just actualCost -> actualCost * 100
+ Nothing -> estimateCost (getProviderModelStreaming prov) tokens
+ engineOnCost engineCfg tokens cost
+ let newTokens = totalTokens + tokens
+ newCost = totalCost + cost
+ let assistantText = Provider.msgContent msg
+ unless (Text.null assistantText)
+ <| engineOnAssistant engineCfg assistantText
+ case Provider.msgToolCalls msg of
+ Nothing
+ | Text.null (Provider.msgContent msg) && totalCalls > 0 -> do
+ engineOnActivity engineCfg "Empty response after tools, prompting for text"
+ let promptMsg = Provider.Message Provider.ToolRole "Please provide a response to the user." Nothing Nothing
+ newMsgs = msgs <> [msg, promptMsg]
+ loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures
+ | otherwise -> do
+ engineOnActivity engineCfg "Agent completed"
+ engineOnComplete engineCfg
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = Provider.msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = newCost,
+ resultTotalTokens = newTokens
+ }
+ Just [] -> do
+ engineOnActivity engineCfg "Agent completed (empty tool calls)"
+ engineOnComplete engineCfg
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = Provider.msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = newCost,
+ resultTotalTokens = newTokens
+ }
+ Just tcs -> do
+ (toolResults, newTestFailures, newEditFailures) <- executeToolCallsStreaming engineCfg toolMap tcs testFailures editFailures
+ let newMsgs = msgs <> [msg] <> toolResults
+ newCalls = totalCalls + length tcs
+ newToolCallCounts = updateToolCallCountsStreaming toolCallCounts tcs
+ loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures
+
+ getProviderModelStreaming :: Provider.Provider -> Text
+ getProviderModelStreaming (Provider.OpenRouter cfg) = Provider.providerModel cfg
+ getProviderModelStreaming (Provider.Ollama cfg) = Provider.providerModel cfg
+ getProviderModelStreaming (Provider.AmpCLI _) = "amp"
+
+ updateToolCallCountsStreaming :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int
+ updateToolCallCountsStreaming =
+ foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m)
+
+ executeToolCallsStreaming :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int)
+ executeToolCallsStreaming eCfg tMap tcs initialTestFailures initialEditFailures = do
+ results <- traverse (executeSingleStreaming eCfg tMap) tcs
+ let msgs = map (\(m, _, _) -> m) results
+ testDeltas = map (\(_, t, _) -> t) results
+ editDeltas = map (\(_, _, e) -> e) results
+ totalTestFail = initialTestFailures + sum testDeltas
+ totalEditFail = initialEditFailures + sum editDeltas
+ pure (msgs, totalTestFail, totalEditFail)
+
+ executeSingleStreaming :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int)
+ executeSingleStreaming eCfg tMap tc = do
+ let name = Provider.fcName (Provider.tcFunction tc)
+ argsText = Provider.fcArguments (Provider.tcFunction tc)
+ callId = Provider.tcId tc
+ engineOnActivity eCfg <| "Executing tool: " <> name
+ engineOnToolCall eCfg name argsText
+ case Map.lookup name tMap of
+ Nothing -> do
+ let errMsg = "Tool not found: " <> name
+ engineOnToolResult eCfg name False errMsg
+ pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0)
+ Just tool -> do
+ case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of
+ Nothing -> do
+ let errMsg = "Invalid JSON arguments: " <> argsText
+ engineOnToolResult eCfg name False errMsg
+ pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0)
+ Just args -> do
+ resultValue <- toolExecute tool args
+ let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue))
+ isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText)
+ isTestFailure = isTestCall && isFailureResultStreaming resultValue
+ testDelta = if isTestFailure then 1 else 0
+ isEditFailure = name == "edit_file" && isOldStrNotFoundStreaming resultValue
+ editDelta = if isEditFailure then 1 else 0
+ engineOnToolResult eCfg name True resultText
+ pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta)
+
+ isFailureResultStreaming :: Aeson.Value -> Bool
+ isFailureResultStreaming (Aeson.Object obj) =
+ case KeyMap.lookup "exit_code" obj of
+ Just (Aeson.Number n) -> n /= 0
+ _ -> False
+ isFailureResultStreaming (Aeson.String s) =
+ "error"
+ `Text.isInfixOf` Text.toLower s
+ || "failed"
+ `Text.isInfixOf` Text.toLower s
+ || "FAILED"
+ `Text.isInfixOf` s
+ isFailureResultStreaming _ = False
+
+ isOldStrNotFoundStreaming :: Aeson.Value -> Bool
+ isOldStrNotFoundStreaming (Aeson.Object obj) =
+ case KeyMap.lookup "error" obj of
+ Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s
+ _ -> False
+ isOldStrNotFoundStreaming _ = False