summaryrefslogtreecommitdiff
path: root/Omni
diff options
context:
space:
mode:
Diffstat (limited to 'Omni')
-rw-r--r--Omni/Agent/Engine.hs228
1 files changed, 223 insertions, 5 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs
index ac6c517..10b36b2 100644
--- a/Omni/Agent/Engine.hs
+++ b/Omni/Agent/Engine.hs
@@ -15,7 +15,9 @@
module Omni.Agent.Engine
( Tool (..),
LLM (..),
+ EngineConfig (..),
AgentConfig (..),
+ AgentResult (..),
Message (..),
Role (..),
ToolCall (..),
@@ -24,9 +26,12 @@ module Omni.Agent.Engine
ChatCompletionRequest (..),
ChatCompletionResponse (..),
Choice (..),
+ Usage (..),
defaultLLM,
+ defaultEngineConfig,
defaultAgentConfig,
chat,
+ runAgent,
main,
test,
)
@@ -36,6 +41,7 @@ import Alpha
import Data.Aeson ((.!=), (.:), (.:?), (.=))
import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy as BL
+import qualified Data.Map.Strict as Map
import qualified Data.Text as Text
import qualified Data.Text.Encoding as TE
import qualified Network.HTTP.Simple as HTTP
@@ -68,7 +74,58 @@ test =
Test.unit "defaultLLM has correct endpoint" <| do
llmBaseUrl defaultLLM Test.@=? "https://api.openai.com",
Test.unit "defaultAgentConfig has sensible defaults" <| do
- agentMaxIterations defaultAgentConfig Test.@=? 10
+ agentMaxIterations defaultAgentConfig Test.@=? 10,
+ Test.unit "defaultEngineConfig has no-op callbacks" <| do
+ engineOnCost defaultEngineConfig 100 5
+ engineOnActivity defaultEngineConfig "test"
+ engineOnToolCall defaultEngineConfig "tool" "result"
+ True Test.@=? True,
+ Test.unit "buildToolMap creates correct map" <| do
+ let tool1 =
+ Tool
+ { toolName = "tool1",
+ toolDescription = "First tool",
+ toolJsonSchema = Aeson.object [],
+ toolExecute = \_ -> pure Aeson.Null
+ }
+ tool2 =
+ Tool
+ { toolName = "tool2",
+ toolDescription = "Second tool",
+ toolJsonSchema = Aeson.object [],
+ toolExecute = \_ -> pure Aeson.Null
+ }
+ toolMap = buildToolMap [tool1, tool2]
+ Map.size toolMap Test.@=? 2
+ Map.member "tool1" toolMap Test.@=? True
+ Map.member "tool2" toolMap Test.@=? True,
+ Test.unit "Usage JSON parsing" <| do
+ let json = "{\"prompt_tokens\":100,\"completion_tokens\":50,\"total_tokens\":150}"
+ case Aeson.decode json of
+ Nothing -> Test.assertFailure "Failed to decode usage"
+ Just usage -> do
+ usagePromptTokens usage Test.@=? 100
+ usageCompletionTokens usage Test.@=? 50
+ usageTotalTokens usage Test.@=? 150,
+ Test.unit "AgentResult JSON roundtrip" <| do
+ let result =
+ AgentResult
+ { resultFinalMessage = "Done",
+ resultToolCallCount = 3,
+ resultIterations = 2,
+ resultTotalCost = 50,
+ resultTotalTokens = 1500
+ }
+ case Aeson.decode (Aeson.encode result) of
+ Nothing -> Test.assertFailure "Failed to decode AgentResult"
+ Just decoded -> do
+ resultFinalMessage decoded Test.@=? "Done"
+ resultToolCallCount decoded Test.@=? 3
+ resultIterations decoded Test.@=? 2,
+ Test.unit "estimateCost calculates correctly" <| do
+ let gpt4oCost = estimateCost "gpt-4o" 1000
+ gpt4oMiniCost = estimateCost "gpt-4o-mini" 1000
+ (gpt4oCost >= gpt4oMiniCost) Test.@=? True
]
data Tool = Tool
@@ -148,6 +205,35 @@ defaultAgentConfig =
agentMaxIterations = 10
}
+data EngineConfig = EngineConfig
+ { engineLLM :: LLM,
+ engineOnCost :: Int -> Int -> IO (),
+ engineOnActivity :: Text -> IO (),
+ engineOnToolCall :: Text -> Text -> IO ()
+ }
+
+defaultEngineConfig :: EngineConfig
+defaultEngineConfig =
+ EngineConfig
+ { engineLLM = defaultLLM,
+ engineOnCost = \_ _ -> pure (),
+ engineOnActivity = \_ -> pure (),
+ engineOnToolCall = \_ _ -> pure ()
+ }
+
+data AgentResult = AgentResult
+ { resultFinalMessage :: Text,
+ resultToolCallCount :: Int,
+ resultIterations :: Int,
+ resultTotalCost :: Int,
+ resultTotalTokens :: Int
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.ToJSON AgentResult
+
+instance Aeson.FromJSON AgentResult
+
data Role = System | User | Assistant | ToolRole
deriving (Show, Eq, Generic)
@@ -273,10 +359,25 @@ instance Aeson.FromJSON Choice where
<*> (v .: "message")
<*> (v .:? "finish_reason")
+data Usage = Usage
+ { usagePromptTokens :: Int,
+ usageCompletionTokens :: Int,
+ usageTotalTokens :: Int
+ }
+ deriving (Show, Eq, Generic)
+
+instance Aeson.FromJSON Usage where
+ parseJSON =
+ Aeson.withObject "Usage" <| \v ->
+ (Usage </ (v .: "prompt_tokens"))
+ <*> (v .: "completion_tokens")
+ <*> (v .: "total_tokens")
+
data ChatCompletionResponse = ChatCompletionResponse
{ respId :: Text,
respChoices :: [Choice],
- respModel :: Text
+ respModel :: Text,
+ respUsage :: Maybe Usage
}
deriving (Show, Eq, Generic)
@@ -286,9 +387,16 @@ instance Aeson.FromJSON ChatCompletionResponse where
(ChatCompletionResponse </ (v .: "id"))
<*> (v .: "choices")
<*> (v .: "model")
+ <*> (v .:? "usage")
-chat :: LLM -> [Tool] -> [Message] -> IO (Either Text Message)
-chat llm tools messages = do
+data ChatResult = ChatResult
+ { chatMessage :: Message,
+ chatUsage :: Maybe Usage
+ }
+ deriving (Show, Eq)
+
+chatWithUsage :: LLM -> [Tool] -> [Message] -> IO (Either Text ChatResult)
+chatWithUsage llm tools messages = do
let url = Text.unpack (llmBaseUrl llm) <> "/v1/chat/completions"
req0 <- HTTP.parseRequest url
let toolApis = [encodeToolForApi t | not (null tools), t <- tools]
@@ -311,7 +419,117 @@ chat llm tools messages = do
then case Aeson.decode (HTTP.getResponseBody response) of
Just resp ->
case respChoices resp of
- (c : _) -> pure (Right (choiceMessage c))
+ (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp)))
[] -> pure (Left "No choices in response")
Nothing -> pure (Left "Failed to parse response")
else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response))))
+
+chat :: LLM -> [Tool] -> [Message] -> IO (Either Text Message)
+chat llm tools messages = do
+ result <- chatWithUsage llm tools messages
+ pure (chatMessage </ result)
+
+runAgent :: EngineConfig -> AgentConfig -> Text -> IO (Either Text AgentResult)
+runAgent engineCfg agentCfg userPrompt = do
+ let llm =
+ (engineLLM engineCfg)
+ { llmModel = agentModel agentCfg
+ }
+ tools = agentTools agentCfg
+ toolMap = buildToolMap tools
+ systemMsg = Message System (agentSystemPrompt agentCfg) Nothing Nothing
+ userMsg = Message User userPrompt Nothing Nothing
+ initialMessages = [systemMsg, userMsg]
+
+ engineOnActivity engineCfg "Starting agent loop"
+ loop llm tools toolMap initialMessages 0 0 0
+ where
+ maxIter = agentMaxIterations agentCfg
+
+ loop :: LLM -> [Tool] -> Map.Map Text Tool -> [Message] -> Int -> Int -> Int -> IO (Either Text AgentResult)
+ loop llm tools' toolMap msgs iteration totalCalls totalTokens
+ | iteration >= maxIter =
+ pure
+ <| Left
+ <| "Max iterations ("
+ <> tshow maxIter
+ <> ") reached"
+ | otherwise = do
+ engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1)
+ result <- chatWithUsage llm tools' msgs
+ case result of
+ Left err -> pure (Left err)
+ Right chatRes -> do
+ let msg = chatMessage chatRes
+ tokens = maybe 0 usageTotalTokens (chatUsage chatRes)
+ cost = estimateCost (llmModel llm) tokens
+ engineOnCost engineCfg tokens cost
+ let newTokens = totalTokens + tokens
+ case msgToolCalls msg of
+ Nothing -> do
+ engineOnActivity engineCfg "Agent completed"
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = estimateTotalCost (llmModel llm) newTokens,
+ resultTotalTokens = newTokens
+ }
+ Just [] -> do
+ engineOnActivity engineCfg "Agent completed (empty tool calls)"
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = estimateTotalCost (llmModel llm) newTokens,
+ resultTotalTokens = newTokens
+ }
+ Just tcs -> do
+ toolResults <- executeToolCalls engineCfg toolMap tcs
+ let newMsgs = msgs <> [msg] <> toolResults
+ newCalls = totalCalls + length tcs
+ loop llm tools' toolMap newMsgs (iteration + 1) newCalls newTokens
+
+buildToolMap :: [Tool] -> Map.Map Text Tool
+buildToolMap = Map.fromList <. map (\t -> (toolName t, t))
+
+executeToolCalls :: EngineConfig -> Map.Map Text Tool -> [ToolCall] -> IO [Message]
+executeToolCalls engineCfg toolMap = traverse executeSingle
+ where
+ executeSingle tc = do
+ let name = fcName (tcFunction tc)
+ argsText = fcArguments (tcFunction tc)
+ callId = tcId tc
+ engineOnActivity engineCfg <| "Executing tool: " <> name
+ case Map.lookup name toolMap of
+ Nothing -> do
+ let errMsg = "Tool not found: " <> name
+ engineOnToolCall engineCfg name errMsg
+ pure <| Message ToolRole errMsg Nothing (Just callId)
+ Just tool -> do
+ case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of
+ Nothing -> do
+ let errMsg = "Invalid JSON arguments: " <> argsText
+ engineOnToolCall engineCfg name errMsg
+ pure <| Message ToolRole errMsg Nothing (Just callId)
+ Just args -> do
+ resultValue <- toolExecute tool args
+ let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue))
+ summary = Text.take 100 resultText
+ engineOnToolCall engineCfg name summary
+ pure <| Message ToolRole resultText Nothing (Just callId)
+
+estimateCost :: Text -> Int -> Int
+estimateCost model tokens
+ | "gpt-4o-mini" `Text.isInfixOf` model = tokens * 15 `div` 1000000
+ | "gpt-4o" `Text.isInfixOf` model = tokens * 250 `div` 100000
+ | "gpt-4" `Text.isInfixOf` model = tokens * 3 `div` 100000
+ | "claude" `Text.isInfixOf` model = tokens * 3 `div` 100000
+ | otherwise = tokens `div` 100000
+
+estimateTotalCost :: Text -> Int -> Int
+estimateTotalCost = estimateCost