diff options
Diffstat (limited to 'Omni/Agent/Engine.hs')
| -rw-r--r-- | Omni/Agent/Engine.hs | 1189 |
1 files changed, 1189 insertions, 0 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs new file mode 100644 index 0000000..f137ddb --- /dev/null +++ b/Omni/Agent/Engine.hs @@ -0,0 +1,1189 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE NoImplicitPrelude #-} + +-- | LLM Agent Engine - Tool protocol and LLM provider abstraction. +-- +-- This module provides the core abstractions for building LLM-powered agents: +-- - Tool: Defines tools that agents can use +-- - LLM: OpenAI-compatible chat completions API provider +-- - AgentConfig: Configuration for running agents +-- +-- : out omni-agent-engine +-- : dep http-conduit +-- : dep aeson +-- : dep case-insensitive +module Omni.Agent.Engine + ( Tool (..), + LLM (..), + EngineConfig (..), + AgentConfig (..), + AgentResult (..), + Guardrails (..), + GuardrailResult (..), + Message (..), + Role (..), + ToolCall (..), + FunctionCall (..), + ToolResult (..), + ChatCompletionRequest (..), + ChatCompletionResponse (..), + Choice (..), + Usage (..), + ToolApi (..), + encodeToolForApi, + defaultLLM, + defaultEngineConfig, + defaultAgentConfig, + defaultGuardrails, + chat, + runAgent, + runAgentWithProvider, + runAgentWithProviderStreaming, + main, + test, + ) +where + +import Alpha +import Data.Aeson ((.!=), (.:), (.:?), (.=)) +import qualified Data.Aeson as Aeson +import qualified Data.Aeson.KeyMap as KeyMap +import qualified Data.ByteString.Lazy as BL +import qualified Data.CaseInsensitive as CI +import Data.IORef (newIORef, writeIORef) +import qualified Data.Map.Strict as Map +import qualified Data.Text as Text +import qualified Data.Text.Encoding as TE +import qualified Network.HTTP.Simple as HTTP +import qualified Omni.Agent.Provider as Provider +import qualified Omni.Test as Test + +main :: IO () +main = Test.run test + +test :: Test.Tree +test = + Test.group + "Omni.Agent.Engine" + [ Test.unit "Tool JSON roundtrip" <| do + let tool = + Tool + { toolName = "get_weather", + toolDescription = "Get weather for a location", + toolJsonSchema = Aeson.object ["type" .= ("object" :: Text), "properties" .= Aeson.object []], + toolExecute = \_ -> pure (Aeson.String "sunny") + } + let encoded = encodeToolForApi tool + case Aeson.decode (Aeson.encode encoded) of + Nothing -> Test.assertFailure "Failed to decode tool" + Just decoded -> toolName tool Test.@=? toolApiName decoded, + Test.unit "Message JSON roundtrip" <| do + let msg = Message User "Hello" Nothing Nothing + case Aeson.decode (Aeson.encode msg) of + Nothing -> Test.assertFailure "Failed to decode message" + Just decoded -> msgContent msg Test.@=? msgContent decoded, + Test.unit "defaultLLM has correct endpoint" <| do + llmBaseUrl defaultLLM Test.@=? "https://openrouter.ai/api/v1", + Test.unit "defaultLLM has OpenRouter headers" <| do + length (llmExtraHeaders defaultLLM) Test.@=? 2 + llmModel defaultLLM Test.@=? "anthropic/claude-sonnet-4.5", + Test.unit "defaultAgentConfig has sensible defaults" <| do + agentMaxIterations defaultAgentConfig Test.@=? 10, + Test.unit "defaultEngineConfig has no-op callbacks" <| do + engineOnCost defaultEngineConfig 100 5 + engineOnActivity defaultEngineConfig "test" + engineOnToolCall defaultEngineConfig "tool" "result" + True Test.@=? True, + Test.unit "buildToolMap creates correct map" <| do + let tool1 = + Tool + { toolName = "tool1", + toolDescription = "First tool", + toolJsonSchema = Aeson.object [], + toolExecute = \_ -> pure Aeson.Null + } + tool2 = + Tool + { toolName = "tool2", + toolDescription = "Second tool", + toolJsonSchema = Aeson.object [], + toolExecute = \_ -> pure Aeson.Null + } + toolMap = buildToolMap [tool1, tool2] + Map.size toolMap Test.@=? 2 + Map.member "tool1" toolMap Test.@=? True + Map.member "tool2" toolMap Test.@=? True, + Test.unit "Usage JSON parsing" <| do + let json = "{\"prompt_tokens\":100,\"completion_tokens\":50,\"total_tokens\":150}" + case Aeson.decode json of + Nothing -> Test.assertFailure "Failed to decode usage" + Just usage -> do + usagePromptTokens usage Test.@=? 100 + usageCompletionTokens usage Test.@=? 50 + usageTotalTokens usage Test.@=? 150 + usageCost usage Test.@=? Nothing, + Test.unit "Usage JSON parsing with cost" <| do + let json = "{\"prompt_tokens\":194,\"completion_tokens\":2,\"total_tokens\":196,\"cost\":0.95}" + case Aeson.decode json of + Nothing -> Test.assertFailure "Failed to decode usage with cost" + Just usage -> do + usagePromptTokens usage Test.@=? 194 + usageCompletionTokens usage Test.@=? 2 + usageTotalTokens usage Test.@=? 196 + usageCost usage Test.@=? Just 0.95, + Test.unit "AgentResult JSON roundtrip" <| do + let result = + AgentResult + { resultFinalMessage = "Done", + resultToolCallCount = 3, + resultIterations = 2, + resultTotalCost = 50, + resultTotalTokens = 1500 + } + case Aeson.decode (Aeson.encode result) of + Nothing -> Test.assertFailure "Failed to decode AgentResult" + Just decoded -> do + resultFinalMessage decoded Test.@=? "Done" + resultToolCallCount decoded Test.@=? 3 + resultIterations decoded Test.@=? 2, + Test.unit "estimateCost calculates correctly" <| do + let gpt4oCost = estimateCost "gpt-4o" 1000 + gpt4oMiniCost = estimateCost "gpt-4o-mini" 1000 + (gpt4oCost >= gpt4oMiniCost) Test.@=? True + (gpt4oCost > 0) Test.@=? True, + Test.unit "ToolCall JSON roundtrip" <| do + let tc = + ToolCall + { tcId = "call_123", + tcType = "function", + tcFunction = FunctionCall "read_file" "{\"path\":\"/tmp/test\"}" + } + case Aeson.decode (Aeson.encode tc) of + Nothing -> Test.assertFailure "Failed to decode ToolCall" + Just decoded -> tcId decoded Test.@=? "call_123", + Test.unit "FunctionCall JSON roundtrip" <| do + let fc = FunctionCall "test_func" "{\"arg\":\"value\"}" + case Aeson.decode (Aeson.encode fc) of + Nothing -> Test.assertFailure "Failed to decode FunctionCall" + Just decoded -> do + fcName decoded Test.@=? "test_func" + fcArguments decoded Test.@=? "{\"arg\":\"value\"}", + Test.unit "Role JSON roundtrip for all roles" <| do + let roles = [System, User, Assistant, ToolRole] + forM_ roles <| \role -> + case Aeson.decode (Aeson.encode role) of + Nothing -> Test.assertFailure ("Failed to decode Role: " <> show role) + Just decoded -> decoded Test.@=? role, + Test.unit "defaultGuardrails has sensible defaults" <| do + guardrailMaxCostCents defaultGuardrails Test.@=? 100.0 + guardrailMaxTokens defaultGuardrails Test.@=? 500000 + guardrailMaxDuplicateToolCalls defaultGuardrails Test.@=? 3 + guardrailMaxTestFailures defaultGuardrails Test.@=? 3, + Test.unit "checkCostGuardrail detects exceeded budget" <| do + let g = defaultGuardrails {guardrailMaxCostCents = 50.0} + checkCostGuardrail g 60.0 Test.@=? GuardrailCostExceeded 60.0 50.0 + checkCostGuardrail g 40.0 Test.@=? GuardrailOk, + Test.unit "checkTokenGuardrail detects exceeded budget" <| do + let g = defaultGuardrails {guardrailMaxTokens = 1000} + checkTokenGuardrail g 1500 Test.@=? GuardrailTokensExceeded 1500 1000 + checkTokenGuardrail g 500 Test.@=? GuardrailOk, + Test.unit "checkDuplicateGuardrail detects repeated calls" <| do + let g = defaultGuardrails {guardrailMaxDuplicateToolCalls = 3} + counts = Map.fromList [("bash", 3), ("read_file", 1)] + case checkDuplicateGuardrail g counts of + GuardrailDuplicateToolCalls name count -> do + name Test.@=? "bash" + count Test.@=? 3 + _ -> Test.assertFailure "Expected GuardrailDuplicateToolCalls" + checkDuplicateGuardrail g (Map.fromList [("bash", 2)]) Test.@=? GuardrailOk, + Test.unit "checkTestFailureGuardrail detects failures" <| do + let g = defaultGuardrails {guardrailMaxTestFailures = 3} + checkTestFailureGuardrail g 3 Test.@=? GuardrailTestFailures 3 + checkTestFailureGuardrail g 2 Test.@=? GuardrailOk, + Test.unit "updateToolCallCounts accumulates correctly" <| do + let tc1 = ToolCall "1" "function" (FunctionCall "bash" "{}") + tc2 = ToolCall "2" "function" (FunctionCall "bash" "{}") + tc3 = ToolCall "3" "function" (FunctionCall "read_file" "{}") + counts = updateToolCallCounts Map.empty [tc1, tc2, tc3] + Map.lookup "bash" counts Test.@=? Just 2 + Map.lookup "read_file" counts Test.@=? Just 1, + Test.unit "Guardrails JSON roundtrip" <| do + let g = Guardrails 75.0 100000 5 4 3 + case Aeson.decode (Aeson.encode g) of + Nothing -> Test.assertFailure "Failed to decode Guardrails" + Just decoded -> decoded Test.@=? g, + Test.unit "GuardrailResult JSON roundtrip" <| do + let results = + [ GuardrailOk, + GuardrailCostExceeded 100.0 50.0, + GuardrailTokensExceeded 2000 1000, + GuardrailDuplicateToolCalls "bash" 5, + GuardrailTestFailures 3, + GuardrailEditFailures 5 + ] + forM_ results <| \r -> + case Aeson.decode (Aeson.encode r) of + Nothing -> Test.assertFailure ("Failed to decode GuardrailResult: " <> show r) + Just decoded -> decoded Test.@=? r + ] + +data Tool = Tool + { toolName :: Text, + toolDescription :: Text, + toolJsonSchema :: Aeson.Value, + toolExecute :: Aeson.Value -> IO Aeson.Value + } + +data ToolApi = ToolApi + { toolApiName :: Text, + toolApiDescription :: Text, + toolApiParameters :: Aeson.Value + } + deriving (Generic) + +instance Aeson.ToJSON ToolApi where + toJSON t = + Aeson.object + [ "type" .= ("function" :: Text), + "function" + .= Aeson.object + [ "name" .= toolApiName t, + "description" .= toolApiDescription t, + "parameters" .= toolApiParameters t + ] + ] + +instance Aeson.FromJSON ToolApi where + parseJSON = + Aeson.withObject "ToolApi" <| \v -> do + fn <- v .: "function" + (ToolApi </ (fn .: "name")) + <*> (fn .: "description") + <*> (fn .: "parameters") + +encodeToolForApi :: Tool -> ToolApi +encodeToolForApi t = + ToolApi + { toolApiName = toolName t, + toolApiDescription = toolDescription t, + toolApiParameters = toolJsonSchema t + } + +encodeToolForProvider :: Tool -> Provider.ToolApi +encodeToolForProvider t = + Provider.ToolApi + { Provider.toolApiName = toolName t, + Provider.toolApiDescription = toolDescription t, + Provider.toolApiParameters = toolJsonSchema t + } + +data LLM = LLM + { llmBaseUrl :: Text, + llmApiKey :: Text, + llmModel :: Text, + llmExtraHeaders :: [(ByteString, ByteString)] + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON LLM where + toJSON l = + Aeson.object + [ "llmBaseUrl" .= llmBaseUrl l, + "llmApiKey" .= llmApiKey l, + "llmModel" .= llmModel l + ] + +instance Aeson.FromJSON LLM where + parseJSON = + Aeson.withObject "LLM" <| \v -> + (LLM </ (v .: "llmBaseUrl")) + <*> (v .: "llmApiKey") + <*> (v .: "llmModel") + <*> pure [] + +defaultLLM :: LLM +defaultLLM = + LLM + { llmBaseUrl = "https://openrouter.ai/api/v1", + llmApiKey = "", + llmModel = "anthropic/claude-sonnet-4.5", + llmExtraHeaders = + [ ("HTTP-Referer", "https://omni.dev"), + ("X-Title", "Omni Agent") + ] + } + +data AgentConfig = AgentConfig + { agentModel :: Text, + agentTools :: [Tool], + agentSystemPrompt :: Text, + agentMaxIterations :: Int, + agentGuardrails :: Guardrails + } + +data Guardrails = Guardrails + { guardrailMaxCostCents :: Double, + guardrailMaxTokens :: Int, + guardrailMaxDuplicateToolCalls :: Int, + guardrailMaxTestFailures :: Int, + guardrailMaxEditFailures :: Int + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON Guardrails + +instance Aeson.FromJSON Guardrails + +data GuardrailResult + = GuardrailOk + | GuardrailCostExceeded Double Double + | GuardrailTokensExceeded Int Int + | GuardrailDuplicateToolCalls Text Int + | GuardrailTestFailures Int + | GuardrailEditFailures Int + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON GuardrailResult + +instance Aeson.FromJSON GuardrailResult + +defaultGuardrails :: Guardrails +defaultGuardrails = + Guardrails + { guardrailMaxCostCents = 100.0, + guardrailMaxTokens = 500000, + guardrailMaxDuplicateToolCalls = 3, + guardrailMaxTestFailures = 3, + guardrailMaxEditFailures = 5 + } + +defaultAgentConfig :: AgentConfig +defaultAgentConfig = + AgentConfig + { agentModel = "gpt-4", + agentTools = [], + agentSystemPrompt = "You are a helpful assistant.", + agentMaxIterations = 10, + agentGuardrails = defaultGuardrails + } + +data EngineConfig = EngineConfig + { engineLLM :: LLM, + engineOnCost :: Int -> Double -> IO (), + engineOnActivity :: Text -> IO (), + engineOnToolCall :: Text -> Text -> IO (), + engineOnAssistant :: Text -> IO (), + engineOnToolResult :: Text -> Bool -> Text -> IO (), + engineOnComplete :: IO (), + engineOnError :: Text -> IO (), + engineOnGuardrail :: GuardrailResult -> IO () + } + +defaultEngineConfig :: EngineConfig +defaultEngineConfig = + EngineConfig + { engineLLM = defaultLLM, + engineOnCost = \_ _ -> pure (), + engineOnActivity = \_ -> pure (), + engineOnToolCall = \_ _ -> pure (), + engineOnAssistant = \_ -> pure (), + engineOnToolResult = \_ _ _ -> pure (), + engineOnComplete = pure (), + engineOnError = \_ -> pure (), + engineOnGuardrail = \_ -> pure () + } + +data AgentResult = AgentResult + { resultFinalMessage :: Text, + resultToolCallCount :: Int, + resultIterations :: Int, + resultTotalCost :: Double, + resultTotalTokens :: Int + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON AgentResult + +instance Aeson.FromJSON AgentResult + +data Role = System | User | Assistant | ToolRole + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON Role where + toJSON System = Aeson.String "system" + toJSON User = Aeson.String "user" + toJSON Assistant = Aeson.String "assistant" + toJSON ToolRole = Aeson.String "tool" + +instance Aeson.FromJSON Role where + parseJSON = Aeson.withText "Role" parseRole + where + parseRole "system" = pure System + parseRole "user" = pure User + parseRole "assistant" = pure Assistant + parseRole "tool" = pure ToolRole + parseRole _ = empty + +data Message = Message + { msgRole :: Role, + msgContent :: Text, + msgToolCalls :: Maybe [ToolCall], + msgToolCallId :: Maybe Text + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON Message where + toJSON m = + Aeson.object + <| catMaybes + [ Just ("role" .= msgRole m), + Just ("content" .= msgContent m), + ("tool_calls" .=) </ msgToolCalls m, + ("tool_call_id" .=) </ msgToolCallId m + ] + +instance Aeson.FromJSON Message where + parseJSON = + Aeson.withObject "Message" <| \v -> + (Message </ (v .: "role")) + <*> (v .:? "content" .!= "") + <*> (v .:? "tool_calls") + <*> (v .:? "tool_call_id") + +data ToolCall = ToolCall + { tcId :: Text, + tcType :: Text, + tcFunction :: FunctionCall + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON ToolCall where + toJSON tc = + Aeson.object + [ "id" .= tcId tc, + "type" .= tcType tc, + "function" .= tcFunction tc + ] + +instance Aeson.FromJSON ToolCall where + parseJSON = + Aeson.withObject "ToolCall" <| \v -> + (ToolCall </ (v .: "id")) + <*> (v .:? "type" .!= "function") + <*> (v .: "function") + +data FunctionCall = FunctionCall + { fcName :: Text, + fcArguments :: Text + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON FunctionCall where + toJSON fc = + Aeson.object + [ "name" .= fcName fc, + "arguments" .= fcArguments fc + ] + +instance Aeson.FromJSON FunctionCall where + parseJSON = + Aeson.withObject "FunctionCall" <| \v -> + (FunctionCall </ (v .: "name")) + <*> (v .: "arguments") + +data ToolResult = ToolResult + { trToolCallId :: Text, + trContent :: Text + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON ToolResult + +instance Aeson.FromJSON ToolResult + +data ChatCompletionRequest = ChatCompletionRequest + { reqModel :: Text, + reqMessages :: [Message], + reqTools :: Maybe [ToolApi] + } + deriving (Generic) + +instance Aeson.ToJSON ChatCompletionRequest where + toJSON r = + Aeson.object + <| catMaybes + [ Just ("model" .= reqModel r), + Just ("messages" .= reqMessages r), + ("tools" .=) </ reqTools r, + Just ("usage" .= Aeson.object ["include" .= True]) + ] + +data Choice = Choice + { choiceIndex :: Int, + choiceMessage :: Message, + choiceFinishReason :: Maybe Text + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON Choice where + parseJSON = + Aeson.withObject "Choice" <| \v -> + (Choice </ (v .: "index")) + <*> (v .: "message") + <*> (v .:? "finish_reason") + +data Usage = Usage + { usagePromptTokens :: Int, + usageCompletionTokens :: Int, + usageTotalTokens :: Int, + usageCost :: Maybe Double + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON Usage where + parseJSON = + Aeson.withObject "Usage" <| \v -> + (Usage </ (v .: "prompt_tokens")) + <*> (v .: "completion_tokens") + <*> (v .: "total_tokens") + <*> (v .:? "cost") + +data ChatCompletionResponse = ChatCompletionResponse + { respId :: Text, + respChoices :: [Choice], + respModel :: Text, + respUsage :: Maybe Usage + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON ChatCompletionResponse where + parseJSON = + Aeson.withObject "ChatCompletionResponse" <| \v -> + (ChatCompletionResponse </ (v .: "id")) + <*> (v .: "choices") + <*> (v .: "model") + <*> (v .:? "usage") + +data ChatResult = ChatResult + { chatMessage :: Message, + chatUsage :: Maybe Usage + } + deriving (Show, Eq) + +chatWithUsage :: LLM -> [Tool] -> [Message] -> IO (Either Text ChatResult) +chatWithUsage llm tools messages = do + let url = Text.unpack (llmBaseUrl llm) <> "/chat/completions" + req0 <- HTTP.parseRequest url + let toolApis = [encodeToolForApi t | not (null tools), t <- tools] + body = + ChatCompletionRequest + { reqModel = llmModel llm, + reqMessages = messages, + reqTools = if null toolApis then Nothing else Just toolApis + } + baseReq = + HTTP.setRequestMethod "POST" + <| HTTP.setRequestHeader "Content-Type" ["application/json"] + <| HTTP.setRequestHeader "Authorization" ["Bearer " <> TE.encodeUtf8 (llmApiKey llm)] + <| HTTP.setRequestBodyLBS (Aeson.encode body) + <| req0 + req = foldr addHeader baseReq (llmExtraHeaders llm) + addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value + + response <- HTTP.httpLBS req + let status = HTTP.getResponseStatusCode response + if status >= 200 && status < 300 + then case Aeson.decode (HTTP.getResponseBody response) of + Just resp -> + case respChoices resp of + (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp))) + [] -> pure (Left "No choices in response") + Nothing -> pure (Left "Failed to parse response") + else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + +chat :: LLM -> [Tool] -> [Message] -> IO (Either Text Message) +chat llm tools messages = do + result <- chatWithUsage llm tools messages + pure (chatMessage </ result) + +runAgent :: EngineConfig -> AgentConfig -> Text -> IO (Either Text AgentResult) +runAgent engineCfg agentCfg userPrompt = do + let llm = + (engineLLM engineCfg) + { llmModel = agentModel agentCfg + } + tools = agentTools agentCfg + toolMap = buildToolMap tools + systemMsg = Message System (agentSystemPrompt agentCfg) Nothing Nothing + userMsg = Message User userPrompt Nothing Nothing + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop" + loop llm tools toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 + where + maxIter = agentMaxIterations agentCfg + guardrails' = agentGuardrails agentCfg + + loop :: + LLM -> + [Tool] -> + Map.Map Text Tool -> + [Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + Int -> + IO (Either Text AgentResult) + loop llm tools' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures + | iteration >= maxIter = do + let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" + engineOnError engineCfg errMsg + pure <| Left errMsg + | otherwise = do + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + result <- chatWithUsage llm tools' msgs + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = chatMessage chatRes + tokens = maybe 0 usageTotalTokens (chatUsage chatRes) + cost = case chatUsage chatRes +> usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (llmModel llm) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case msgToolCalls msg of + Nothing + | Text.null (msgContent msg) && totalCalls > 0 -> do + engineOnActivity engineCfg "Empty response after tools, prompting for text" + let promptMsg = Message ToolRole "Please provide a response to the user." Nothing Nothing + newMsgs = msgs <> [msg, promptMsg] + loop llm tools' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures + | otherwise -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures, newEditFailures) <- executeToolCallsWithTracking engineCfg toolMap tcs testFailures editFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateToolCallCounts toolCallCounts tcs + loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures + +checkCostGuardrail :: Guardrails -> Double -> GuardrailResult +checkCostGuardrail g cost + | cost > guardrailMaxCostCents g = GuardrailCostExceeded cost (guardrailMaxCostCents g) + | otherwise = GuardrailOk + +checkTokenGuardrail :: Guardrails -> Int -> GuardrailResult +checkTokenGuardrail g tokens + | tokens > guardrailMaxTokens g = GuardrailTokensExceeded tokens (guardrailMaxTokens g) + | otherwise = GuardrailOk + +checkDuplicateGuardrail :: Guardrails -> Map.Map Text Int -> GuardrailResult +checkDuplicateGuardrail g counts = + let maxAllowed = guardrailMaxDuplicateToolCalls g + violations = [(name, count) | (name, count) <- Map.toList counts, count >= maxAllowed] + in case violations of + ((name, count) : _) -> GuardrailDuplicateToolCalls name count + [] -> GuardrailOk + +checkTestFailureGuardrail :: Guardrails -> Int -> GuardrailResult +checkTestFailureGuardrail g failures + | failures >= guardrailMaxTestFailures g = GuardrailTestFailures failures + | otherwise = GuardrailOk + +checkEditFailureGuardrail :: Guardrails -> Int -> GuardrailResult +checkEditFailureGuardrail g failures + | failures >= guardrailMaxEditFailures g = GuardrailEditFailures failures + | otherwise = GuardrailOk + +updateToolCallCounts :: Map.Map Text Int -> [ToolCall] -> Map.Map Text Int +updateToolCallCounts = + foldr (\tc m -> Map.insertWith (+) (fcName (tcFunction tc)) 1 m) + +findGuardrailViolation :: Guardrails -> Double -> Int -> Map.Map Text Int -> Int -> Int -> Maybe (GuardrailResult, Text) +findGuardrailViolation g cost tokens toolCallCounts testFailures editFailures = + case checkCostGuardrail g cost of + r@(GuardrailCostExceeded actual limit) -> + Just (r, "Guardrail: cost budget exceeded (" <> tshow actual <> "/" <> tshow limit <> " cents)") + _ -> case checkTokenGuardrail g tokens of + r@(GuardrailTokensExceeded actual limit) -> + Just (r, "Guardrail: token budget exceeded (" <> tshow actual <> "/" <> tshow limit <> " tokens)") + _ -> case checkDuplicateGuardrail g toolCallCounts of + r@(GuardrailDuplicateToolCalls tool count) -> + Just (r, "Guardrail: duplicate tool calls (" <> tool <> " called " <> tshow count <> " times)") + _ -> case checkTestFailureGuardrail g testFailures of + r@(GuardrailTestFailures count) -> + Just (r, "Guardrail: too many test failures (" <> tshow count <> ")") + _ -> case checkEditFailureGuardrail g editFailures of + r@(GuardrailEditFailures count) -> + Just (r, "Guardrail: too many edit_file failures (" <> tshow count <> " 'old_str not found' errors)") + _ -> Nothing + +buildToolMap :: [Tool] -> Map.Map Text Tool +buildToolMap = Map.fromList <. map (\t -> (toolName t, t)) + +-- | Track both test failures and edit failures +-- Returns (messages, testFailures, editFailures) +executeToolCallsWithTracking :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> Int -> Int -> IO ([Message], Int, Int) +executeToolCallsWithTracking engineCfg toolMap tcs initialTestFailures initialEditFailures = do + results <- traverse executeSingle tcs + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFailures = initialTestFailures + sum testDeltas + totalEditFailures = initialEditFailures + sum editDeltas + pure (msgs, totalTestFailures, totalEditFailures) + where + executeSingle tc = do + let name = fcName (tcFunction tc) + argsText = fcArguments (tcFunction tc) + callId = tcId tc + engineOnActivity engineCfg <| "Executing tool: " <> name + engineOnToolCall engineCfg name argsText + case Map.lookup name toolMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolResult engineCfg name False errMsg + pure (Message ToolRole errMsg Nothing (Just callId), 0, 0) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolResult engineCfg name False errMsg + pure (Message ToolRole errMsg Nothing (Just callId), 0, 0) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResult resultValue + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundError resultValue + editDelta = if isEditFailure then 1 else 0 + engineOnToolResult engineCfg name True resultText + pure (Message ToolRole resultText Nothing (Just callId), testDelta, editDelta) + + isFailureResult :: Aeson.Value -> Bool + isFailureResult (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResult (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResult _ = False + + isOldStrNotFoundError :: Aeson.Value -> Bool + isOldStrNotFoundError (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundError _ = False + +-- | Estimate cost in cents from token count. +-- Uses blended input/output rates (roughly 2:1 output:input ratio). +-- Prices as of Dec 2024 from OpenRouter. +estimateCost :: Text -> Int -> Double +estimateCost model tokens + | "gpt-4o-mini" `Text.isInfixOf` model = fromIntegral tokens * 0.04 / 1000 + | "gpt-4o" `Text.isInfixOf` model = fromIntegral tokens * 0.7 / 1000 + | "gemini-2.0-flash" `Text.isInfixOf` model = fromIntegral tokens * 0.15 / 1000 + | "gemini-2.5-flash" `Text.isInfixOf` model = fromIntegral tokens * 0.15 / 1000 + | "claude-sonnet-4.5" `Text.isInfixOf` model = fromIntegral tokens * 0.9 / 1000 + | "claude-sonnet-4" `Text.isInfixOf` model = fromIntegral tokens * 0.9 / 1000 + | "claude-3-haiku" `Text.isInfixOf` model = fromIntegral tokens * 0.1 / 1000 + | "claude" `Text.isInfixOf` model = fromIntegral tokens * 0.9 / 1000 + | otherwise = fromIntegral tokens * 0.5 / 1000 + +-- | Run agent with a Provider instead of LLM. +-- This is the new preferred way to run agents with multiple backend support. +runAgentWithProvider :: EngineConfig -> Provider.Provider -> AgentConfig -> Text -> IO (Either Text AgentResult) +runAgentWithProvider engineCfg provider agentCfg userPrompt = do + let tools = agentTools agentCfg + toolApis = map encodeToolForProvider tools + toolMap = buildToolMap tools + systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg) + userMsg = providerMessage Provider.User userPrompt + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop (Provider)" + loopProvider provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 + where + maxIter = agentMaxIterations agentCfg + guardrails' = agentGuardrails agentCfg + + providerMessage :: Provider.Role -> Text -> Provider.Message + providerMessage role content = Provider.Message role content Nothing Nothing + + loopProvider :: + Provider.Provider -> + [Provider.ToolApi] -> + Map.Map Text Tool -> + [Provider.Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + Int -> + IO (Either Text AgentResult) + loopProvider prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures + | iteration >= maxIter = do + let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" + engineOnError engineCfg errMsg + pure <| Left errMsg + | otherwise = do + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + result <- Provider.chatWithUsage prov toolApis' msgs + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = Provider.chatMessage chatRes + tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes) + cost = case Provider.chatUsage chatRes +> Provider.usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (getProviderModel prov) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = Provider.msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case Provider.msgToolCalls msg of + Nothing + | Text.null (Provider.msgContent msg) && totalCalls > 0 -> do + engineOnActivity engineCfg "Empty response after tools, prompting for text" + let promptMsg = Provider.Message Provider.ToolRole "Please provide a response to the user." Nothing Nothing + newMsgs = msgs <> [msg, promptMsg] + loopProvider prov toolApis' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures + | otherwise -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures, newEditFailures) <- executeProviderToolCalls engineCfg toolMap tcs testFailures editFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateProviderToolCallCounts toolCallCounts tcs + loopProvider prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures + + getProviderModel :: Provider.Provider -> Text + getProviderModel (Provider.OpenRouter cfg) = Provider.providerModel cfg + getProviderModel (Provider.Ollama cfg) = Provider.providerModel cfg + getProviderModel (Provider.AmpCLI _) = "amp" + + updateProviderToolCallCounts :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int + updateProviderToolCallCounts = + foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m) + + executeProviderToolCalls :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int) + executeProviderToolCalls eCfg tMap tcs initialTestFailures initialEditFailures = do + results <- traverse (executeSingleProvider eCfg tMap) tcs + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFail = initialTestFailures + sum testDeltas + totalEditFail = initialEditFailures + sum editDeltas + pure (msgs, totalTestFail, totalEditFail) + + executeSingleProvider :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int) + executeSingleProvider eCfg tMap tc = do + let name = Provider.fcName (Provider.tcFunction tc) + argsText = Provider.fcArguments (Provider.tcFunction tc) + callId = Provider.tcId tc + engineOnActivity eCfg <| "Executing tool: " <> name + engineOnToolCall eCfg name argsText + case Map.lookup name tMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResultProvider resultValue + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundProvider resultValue + editDelta = if isEditFailure then 1 else 0 + engineOnToolResult eCfg name True resultText + pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta) + + isFailureResultProvider :: Aeson.Value -> Bool + isFailureResultProvider (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResultProvider (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResultProvider _ = False + + isOldStrNotFoundProvider :: Aeson.Value -> Bool + isOldStrNotFoundProvider (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundProvider _ = False + +runAgentWithProviderStreaming :: + EngineConfig -> + Provider.Provider -> + AgentConfig -> + Text -> + (Text -> IO ()) -> + IO (Either Text AgentResult) +runAgentWithProviderStreaming engineCfg provider agentCfg userPrompt onStreamChunk = do + let tools = agentTools agentCfg + toolApis = map encodeToolForProvider tools + toolMap = buildToolMap tools + systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg) + userMsg = providerMessage Provider.User userPrompt + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop (Provider+Streaming)" + loopProviderStreaming provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 + where + maxIter = agentMaxIterations agentCfg + guardrails' = agentGuardrails agentCfg + + providerMessage :: Provider.Role -> Text -> Provider.Message + providerMessage role content = Provider.Message role content Nothing Nothing + + loopProviderStreaming :: + Provider.Provider -> + [Provider.ToolApi] -> + Map.Map Text Tool -> + [Provider.Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + Int -> + IO (Either Text AgentResult) + loopProviderStreaming prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures + | iteration >= maxIter = do + let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" + engineOnError engineCfg errMsg + pure <| Left errMsg + | otherwise = do + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + hasToolCalls <- newIORef False + result <- + Provider.chatStream prov toolApis' msgs <| \case + Provider.StreamContent txt -> onStreamChunk txt + Provider.StreamToolCall _ -> writeIORef hasToolCalls True + Provider.StreamToolCallDelta _ -> writeIORef hasToolCalls True + Provider.StreamError err -> engineOnError engineCfg err + Provider.StreamDone _ -> pure () + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = Provider.chatMessage chatRes + tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes) + cost = case Provider.chatUsage chatRes +> Provider.usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (getProviderModelStreaming prov) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = Provider.msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case Provider.msgToolCalls msg of + Nothing + | Text.null (Provider.msgContent msg) && totalCalls > 0 -> do + engineOnActivity engineCfg "Empty response after tools, prompting for text" + let promptMsg = Provider.Message Provider.ToolRole "Please provide a response to the user." Nothing Nothing + newMsgs = msgs <> [msg, promptMsg] + loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures + | otherwise -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures, newEditFailures) <- executeToolCallsStreaming engineCfg toolMap tcs testFailures editFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateToolCallCountsStreaming toolCallCounts tcs + loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures + + getProviderModelStreaming :: Provider.Provider -> Text + getProviderModelStreaming (Provider.OpenRouter cfg) = Provider.providerModel cfg + getProviderModelStreaming (Provider.Ollama cfg) = Provider.providerModel cfg + getProviderModelStreaming (Provider.AmpCLI _) = "amp" + + updateToolCallCountsStreaming :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int + updateToolCallCountsStreaming = + foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m) + + executeToolCallsStreaming :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int) + executeToolCallsStreaming eCfg tMap tcs initialTestFailures initialEditFailures = do + results <- traverse (executeSingleStreaming eCfg tMap) tcs + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFail = initialTestFailures + sum testDeltas + totalEditFail = initialEditFailures + sum editDeltas + pure (msgs, totalTestFail, totalEditFail) + + executeSingleStreaming :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int) + executeSingleStreaming eCfg tMap tc = do + let name = Provider.fcName (Provider.tcFunction tc) + argsText = Provider.fcArguments (Provider.tcFunction tc) + callId = Provider.tcId tc + engineOnActivity eCfg <| "Executing tool: " <> name + engineOnToolCall eCfg name argsText + case Map.lookup name tMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResultStreaming resultValue + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundStreaming resultValue + editDelta = if isEditFailure then 1 else 0 + engineOnToolResult eCfg name True resultText + pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta) + + isFailureResultStreaming :: Aeson.Value -> Bool + isFailureResultStreaming (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResultStreaming (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResultStreaming _ = False + + isOldStrNotFoundStreaming :: Aeson.Value -> Bool + isOldStrNotFoundStreaming (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundStreaming _ = False |
