diff options
Diffstat (limited to 'Omni/Agent/Engine.hs')
| -rw-r--r-- | Omni/Agent/Engine.hs | 176 |
1 files changed, 176 insertions, 0 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs index 4ee5e5d..fe3b3d5 100644 --- a/Omni/Agent/Engine.hs +++ b/Omni/Agent/Engine.hs @@ -30,12 +30,15 @@ module Omni.Agent.Engine ChatCompletionResponse (..), Choice (..), Usage (..), + ToolApi (..), + encodeToolForApi, defaultLLM, defaultEngineConfig, defaultAgentConfig, defaultGuardrails, chat, runAgent, + runAgentWithProvider, main, test, ) @@ -51,6 +54,7 @@ import qualified Data.Map.Strict as Map import qualified Data.Text as Text import qualified Data.Text.Encoding as TE import qualified Network.HTTP.Simple as HTTP +import qualified Omni.Agent.Provider as Provider import qualified Omni.Test as Test main :: IO () @@ -264,6 +268,14 @@ encodeToolForApi t = toolApiParameters = toolJsonSchema t } +encodeToolForProvider :: Tool -> Provider.ToolApi +encodeToolForProvider t = + Provider.ToolApi + { Provider.toolApiName = toolName t, + Provider.toolApiDescription = toolDescription t, + Provider.toolApiParameters = toolJsonSchema t + } + data LLM = LLM { llmBaseUrl :: Text, llmApiKey :: Text, @@ -809,3 +821,167 @@ estimateCost model tokens | "gpt-4" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000 | "claude" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000 | otherwise = fromIntegral tokens / 100000 + +-- | Run agent with a Provider instead of LLM. +-- This is the new preferred way to run agents with multiple backend support. +runAgentWithProvider :: EngineConfig -> Provider.Provider -> AgentConfig -> Text -> IO (Either Text AgentResult) +runAgentWithProvider engineCfg provider agentCfg userPrompt = do + let tools = agentTools agentCfg + toolApis = map encodeToolForProvider tools + toolMap = buildToolMap tools + systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg) + userMsg = providerMessage Provider.User userPrompt + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop (Provider)" + loopProvider provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 + where + maxIter = agentMaxIterations agentCfg + guardrails' = agentGuardrails agentCfg + + providerMessage :: Provider.Role -> Text -> Provider.Message + providerMessage role content = Provider.Message role content Nothing Nothing + + loopProvider :: + Provider.Provider -> + [Provider.ToolApi] -> + Map.Map Text Tool -> + [Provider.Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + Int -> + IO (Either Text AgentResult) + loopProvider prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures + | iteration >= maxIter = do + let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" + engineOnError engineCfg errMsg + pure <| Left errMsg + | otherwise = do + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + result <- Provider.chatWithUsage prov toolApis' msgs + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = Provider.chatMessage chatRes + tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes) + cost = case Provider.chatUsage chatRes +> Provider.usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (getProviderModel prov) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = Provider.msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case Provider.msgToolCalls msg of + Nothing -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures, newEditFailures) <- executeProviderToolCalls engineCfg toolMap tcs testFailures editFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateProviderToolCallCounts toolCallCounts tcs + loopProvider prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures + + getProviderModel :: Provider.Provider -> Text + getProviderModel (Provider.OpenRouter cfg) = Provider.providerModel cfg + getProviderModel (Provider.Ollama cfg) = Provider.providerModel cfg + getProviderModel (Provider.AmpCLI _) = "amp" + + updateProviderToolCallCounts :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int + updateProviderToolCallCounts = + foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m) + + executeProviderToolCalls :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int) + executeProviderToolCalls eCfg tMap tcs initialTestFailures initialEditFailures = do + results <- traverse (executeSingleProvider eCfg tMap) tcs + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFail = initialTestFailures + sum testDeltas + totalEditFail = initialEditFailures + sum editDeltas + pure (msgs, totalTestFail, totalEditFail) + + executeSingleProvider :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int) + executeSingleProvider eCfg tMap tc = do + let name = Provider.fcName (Provider.tcFunction tc) + argsText = Provider.fcArguments (Provider.tcFunction tc) + callId = Provider.tcId tc + engineOnActivity eCfg <| "Executing tool: " <> name + engineOnToolCall eCfg name argsText + case Map.lookup name tMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResultProvider resultValue + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundProvider resultValue + editDelta = if isEditFailure then 1 else 0 + engineOnToolResult eCfg name True resultText + pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta) + + isFailureResultProvider :: Aeson.Value -> Bool + isFailureResultProvider (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResultProvider (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResultProvider _ = False + + isOldStrNotFoundProvider :: Aeson.Value -> Bool + isOldStrNotFoundProvider (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundProvider _ = False |
