diff options
Diffstat (limited to 'Omni/Agent')
| -rw-r--r-- | Omni/Agent/Engine.hs | 228 |
1 files changed, 223 insertions, 5 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs index ac6c517..10b36b2 100644 --- a/Omni/Agent/Engine.hs +++ b/Omni/Agent/Engine.hs @@ -15,7 +15,9 @@ module Omni.Agent.Engine ( Tool (..), LLM (..), + EngineConfig (..), AgentConfig (..), + AgentResult (..), Message (..), Role (..), ToolCall (..), @@ -24,9 +26,12 @@ module Omni.Agent.Engine ChatCompletionRequest (..), ChatCompletionResponse (..), Choice (..), + Usage (..), defaultLLM, + defaultEngineConfig, defaultAgentConfig, chat, + runAgent, main, test, ) @@ -36,6 +41,7 @@ import Alpha import Data.Aeson ((.!=), (.:), (.:?), (.=)) import qualified Data.Aeson as Aeson import qualified Data.ByteString.Lazy as BL +import qualified Data.Map.Strict as Map import qualified Data.Text as Text import qualified Data.Text.Encoding as TE import qualified Network.HTTP.Simple as HTTP @@ -68,7 +74,58 @@ test = Test.unit "defaultLLM has correct endpoint" <| do llmBaseUrl defaultLLM Test.@=? "https://api.openai.com", Test.unit "defaultAgentConfig has sensible defaults" <| do - agentMaxIterations defaultAgentConfig Test.@=? 10 + agentMaxIterations defaultAgentConfig Test.@=? 10, + Test.unit "defaultEngineConfig has no-op callbacks" <| do + engineOnCost defaultEngineConfig 100 5 + engineOnActivity defaultEngineConfig "test" + engineOnToolCall defaultEngineConfig "tool" "result" + True Test.@=? True, + Test.unit "buildToolMap creates correct map" <| do + let tool1 = + Tool + { toolName = "tool1", + toolDescription = "First tool", + toolJsonSchema = Aeson.object [], + toolExecute = \_ -> pure Aeson.Null + } + tool2 = + Tool + { toolName = "tool2", + toolDescription = "Second tool", + toolJsonSchema = Aeson.object [], + toolExecute = \_ -> pure Aeson.Null + } + toolMap = buildToolMap [tool1, tool2] + Map.size toolMap Test.@=? 2 + Map.member "tool1" toolMap Test.@=? True + Map.member "tool2" toolMap Test.@=? True, + Test.unit "Usage JSON parsing" <| do + let json = "{\"prompt_tokens\":100,\"completion_tokens\":50,\"total_tokens\":150}" + case Aeson.decode json of + Nothing -> Test.assertFailure "Failed to decode usage" + Just usage -> do + usagePromptTokens usage Test.@=? 100 + usageCompletionTokens usage Test.@=? 50 + usageTotalTokens usage Test.@=? 150, + Test.unit "AgentResult JSON roundtrip" <| do + let result = + AgentResult + { resultFinalMessage = "Done", + resultToolCallCount = 3, + resultIterations = 2, + resultTotalCost = 50, + resultTotalTokens = 1500 + } + case Aeson.decode (Aeson.encode result) of + Nothing -> Test.assertFailure "Failed to decode AgentResult" + Just decoded -> do + resultFinalMessage decoded Test.@=? "Done" + resultToolCallCount decoded Test.@=? 3 + resultIterations decoded Test.@=? 2, + Test.unit "estimateCost calculates correctly" <| do + let gpt4oCost = estimateCost "gpt-4o" 1000 + gpt4oMiniCost = estimateCost "gpt-4o-mini" 1000 + (gpt4oCost >= gpt4oMiniCost) Test.@=? True ] data Tool = Tool @@ -148,6 +205,35 @@ defaultAgentConfig = agentMaxIterations = 10 } +data EngineConfig = EngineConfig + { engineLLM :: LLM, + engineOnCost :: Int -> Int -> IO (), + engineOnActivity :: Text -> IO (), + engineOnToolCall :: Text -> Text -> IO () + } + +defaultEngineConfig :: EngineConfig +defaultEngineConfig = + EngineConfig + { engineLLM = defaultLLM, + engineOnCost = \_ _ -> pure (), + engineOnActivity = \_ -> pure (), + engineOnToolCall = \_ _ -> pure () + } + +data AgentResult = AgentResult + { resultFinalMessage :: Text, + resultToolCallCount :: Int, + resultIterations :: Int, + resultTotalCost :: Int, + resultTotalTokens :: Int + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON AgentResult + +instance Aeson.FromJSON AgentResult + data Role = System | User | Assistant | ToolRole deriving (Show, Eq, Generic) @@ -273,10 +359,25 @@ instance Aeson.FromJSON Choice where <*> (v .: "message") <*> (v .:? "finish_reason") +data Usage = Usage + { usagePromptTokens :: Int, + usageCompletionTokens :: Int, + usageTotalTokens :: Int + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON Usage where + parseJSON = + Aeson.withObject "Usage" <| \v -> + (Usage </ (v .: "prompt_tokens")) + <*> (v .: "completion_tokens") + <*> (v .: "total_tokens") + data ChatCompletionResponse = ChatCompletionResponse { respId :: Text, respChoices :: [Choice], - respModel :: Text + respModel :: Text, + respUsage :: Maybe Usage } deriving (Show, Eq, Generic) @@ -286,9 +387,16 @@ instance Aeson.FromJSON ChatCompletionResponse where (ChatCompletionResponse </ (v .: "id")) <*> (v .: "choices") <*> (v .: "model") + <*> (v .:? "usage") -chat :: LLM -> [Tool] -> [Message] -> IO (Either Text Message) -chat llm tools messages = do +data ChatResult = ChatResult + { chatMessage :: Message, + chatUsage :: Maybe Usage + } + deriving (Show, Eq) + +chatWithUsage :: LLM -> [Tool] -> [Message] -> IO (Either Text ChatResult) +chatWithUsage llm tools messages = do let url = Text.unpack (llmBaseUrl llm) <> "/v1/chat/completions" req0 <- HTTP.parseRequest url let toolApis = [encodeToolForApi t | not (null tools), t <- tools] @@ -311,7 +419,117 @@ chat llm tools messages = do then case Aeson.decode (HTTP.getResponseBody response) of Just resp -> case respChoices resp of - (c : _) -> pure (Right (choiceMessage c)) + (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp))) [] -> pure (Left "No choices in response") Nothing -> pure (Left "Failed to parse response") else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + +chat :: LLM -> [Tool] -> [Message] -> IO (Either Text Message) +chat llm tools messages = do + result <- chatWithUsage llm tools messages + pure (chatMessage </ result) + +runAgent :: EngineConfig -> AgentConfig -> Text -> IO (Either Text AgentResult) +runAgent engineCfg agentCfg userPrompt = do + let llm = + (engineLLM engineCfg) + { llmModel = agentModel agentCfg + } + tools = agentTools agentCfg + toolMap = buildToolMap tools + systemMsg = Message System (agentSystemPrompt agentCfg) Nothing Nothing + userMsg = Message User userPrompt Nothing Nothing + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop" + loop llm tools toolMap initialMessages 0 0 0 + where + maxIter = agentMaxIterations agentCfg + + loop :: LLM -> [Tool] -> Map.Map Text Tool -> [Message] -> Int -> Int -> Int -> IO (Either Text AgentResult) + loop llm tools' toolMap msgs iteration totalCalls totalTokens + | iteration >= maxIter = + pure + <| Left + <| "Max iterations (" + <> tshow maxIter + <> ") reached" + | otherwise = do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + result <- chatWithUsage llm tools' msgs + case result of + Left err -> pure (Left err) + Right chatRes -> do + let msg = chatMessage chatRes + tokens = maybe 0 usageTotalTokens (chatUsage chatRes) + cost = estimateCost (llmModel llm) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + case msgToolCalls msg of + Nothing -> do + engineOnActivity engineCfg "Agent completed" + pure + <| Right + <| AgentResult + { resultFinalMessage = msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = estimateTotalCost (llmModel llm) newTokens, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + pure + <| Right + <| AgentResult + { resultFinalMessage = msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = estimateTotalCost (llmModel llm) newTokens, + resultTotalTokens = newTokens + } + Just tcs -> do + toolResults <- executeToolCalls engineCfg toolMap tcs + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens + +buildToolMap :: [Tool] -> Map.Map Text Tool +buildToolMap = Map.fromList <. map (\t -> (toolName t, t)) + +executeToolCalls :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> IO [Message] +executeToolCalls engineCfg toolMap = traverse executeSingle + where + executeSingle tc = do + let name = fcName (tcFunction tc) + argsText = fcArguments (tcFunction tc) + callId = tcId tc + engineOnActivity engineCfg <| "Executing tool: " <> name + case Map.lookup name toolMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolCall engineCfg name errMsg + pure <| Message ToolRole errMsg Nothing (Just callId) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolCall engineCfg name errMsg + pure <| Message ToolRole errMsg Nothing (Just callId) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + summary = Text.take 100 resultText + engineOnToolCall engineCfg name summary + pure <| Message ToolRole resultText Nothing (Just callId) + +estimateCost :: Text -> Int -> Int +estimateCost model tokens + | "gpt-4o-mini" `Text.isInfixOf` model = tokens * 15 `div` 1000000 + | "gpt-4o" `Text.isInfixOf` model = tokens * 250 `div` 100000 + | "gpt-4" `Text.isInfixOf` model = tokens * 3 `div` 100000 + | "claude" `Text.isInfixOf` model = tokens * 3 `div` 100000 + | otherwise = tokens `div` 100000 + +estimateTotalCost :: Text -> Int -> Int +estimateTotalCost = estimateCost |
