summaryrefslogtreecommitdiff
path: root/Omni/Agent/Provider.hs
diff options
context:
space:
mode:
author Ben Sima <ben@bensima.com> 2025-12-17 13:29:40 -0500
committer Ben Sima <ben@bensima.com> 2025-12-17 13:29:40 -0500
commit ab01b34bf563990e0f491ada646472aaade97610 (patch)
tree 5e46a1a157bb846b0c3a090a83153c788da2b977 /Omni/Agent/Provider.hs
parent e112d3ce07fa24f31a281e521a554cc881a76c7b (diff)
parent 337648981cc5a55935116141341521f4fce83214 (diff)
Merge Ava deployment changes
Diffstat (limited to 'Omni/Agent/Provider.hs')
-rw-r--r-- Omni/Agent/Provider.hs 695
1 files changed, 695 insertions, 0 deletions
diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs
new file mode 100644
index 0000000..db30e5f
--- /dev/null
+++ b/Omni/Agent/Provider.hs
@@ -0,0 +1,695 @@
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE NoImplicitPrelude #-}
+
+-- | LLM Provider abstraction for multi-backend support.
+--
+-- Supports multiple LLM backends:
+-- - OpenRouter (cloud API, multiple models)
+-- - Ollama (local models)
+-- - Amp CLI (subprocess)
+--
+-- : out omni-agent-provider
+-- : dep aeson
+-- : dep http-conduit
+-- : dep http-client-tls
+-- : dep http-types
+-- : dep case-insensitive
+module Omni.Agent.Provider
+ ( Provider (..),
+ ProviderConfig (..),
+ ChatResult (..),
+ Message (..),
+ Role (..),
+ ToolCall (..),
+ FunctionCall (..),
+ Usage (..),
+ ToolApi (..),
+ StreamChunk (..),
+ defaultOpenRouter,
+ defaultOllama,
+ chat,
+ chatWithUsage,
+ chatStream,
+ main,
+ test,
+ )
+where
+
+import Alpha
+import Data.Aeson ((.=))
+import qualified Data.Aeson as Aeson
+import qualified Data.Aeson.KeyMap as KeyMap
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Lazy as BL
+import qualified Data.CaseInsensitive as CI
+import Data.IORef (modifyIORef, newIORef, readIORef, writeIORef)
+import qualified Data.IntMap.Strict as IntMap
+import qualified Data.Text as Text
+import qualified Data.Text.Encoding as TE
+import qualified Network.HTTP.Client as HTTPClient
+import qualified Network.HTTP.Client.TLS as HTTPClientTLS
+import qualified Network.HTTP.Simple as HTTP
+import Network.HTTP.Types.Status (statusCode)
+import qualified Omni.Test as Test
+import qualified System.Timeout as Timeout
+
+-- | Entry point: runs this module's test tree.
+main :: IO ()
+main = Test.run test
+
+-- | Unit tests for the default provider constructors and for 'ChatResult'.
+test :: Test.Tree
+test =
+  Test.group
+    "Omni.Agent.Provider"
+    [ openRouterEndpoint,
+      ollamaEndpoint,
+      resultKeepsMessage
+    ]
+  where
+    -- The OpenRouter default must point at the public API root.
+    openRouterEndpoint =
+      Test.unit "defaultOpenRouter has correct endpoint"
+        <| case defaultOpenRouter "" "test-model" of
+          OpenRouter cfg -> providerBaseUrl cfg Test.@=? "https://openrouter.ai/api/v1"
+          _ -> Test.assertFailure "Expected OpenRouter"
+    -- The Ollama default must point at the local daemon.
+    ollamaEndpoint =
+      Test.unit "defaultOllama has correct endpoint"
+        <| case defaultOllama "test-model" of
+          Ollama cfg -> providerBaseUrl cfg Test.@=? "http://localhost:11434"
+          _ -> Test.assertFailure "Expected Ollama"
+    -- Wrapping a message in a ChatResult must not alter it.
+    resultKeepsMessage =
+      Test.unit "ChatResult preserves message"
+        <| let msg = Message User "test" Nothing Nothing
+            in chatMessage (ChatResult msg Nothing) Test.@=? msg
+
+-- | Upper bound on a single HTTP attempt, in microseconds (60 seconds).
+httpTimeoutMicros :: Int
+httpTimeoutMicros = 60 * 1000 * 1000
+
+-- | How many times a transient failure is retried before giving up.
+maxRetries :: Int
+maxRetries = 3
+
+-- | Delay before the first retry, in microseconds (1 second).
+initialBackoffMicros :: Int
+initialBackoffMicros = 1000 * 1000
+
+-- | Run an IO action, retrying with exponential backoff on transient failure.
+--
+-- Every attempt, including the last one, is bounded by 'httpTimeoutMicros'.
+-- An attempt is retried when it times out, or when it returns a @Left@ whose
+-- text looks transient (an @"HTTP error: 5xx"@ prefix, or a mention of
+-- "connection" or "timeout"). Before each retry the thread sleeps for
+-- @backoff@ microseconds, and the delay doubles for the next round.
+--
+-- At most @retriesLeft + 1@ attempts are made. If the final attempt also
+-- times out, a @Left@ describing the timeout is returned instead of blocking
+-- indefinitely. Exceptions thrown by the action are not caught here and
+-- propagate to the caller.
+retryWithBackoff :: Int -> Int -> IO (Either Text a) -> IO (Either Text a)
+retryWithBackoff retriesLeft backoff action = do
+  result <- Timeout.timeout httpTimeoutMicros action
+  case result of
+    Nothing
+      | retriesLeft > 0 -> retryAfterDelay
+      | otherwise -> pure (Left "Request timeout: no response within the time limit")
+    Just (Left err)
+      | retriesLeft > 0 && isRetryable err -> retryAfterDelay
+    Just r -> pure r
+  where
+    -- Sleep, then try again with one fewer retry and twice the delay.
+    retryAfterDelay = do
+      threadDelay backoff
+      retryWithBackoff (retriesLeft - 1) (backoff * 2) action
+    -- Heuristic on the error text produced by the chat functions.
+    isRetryable err =
+      "HTTP error: 5"
+        `Text.isPrefixOf` err
+        || "connection"
+        `Text.isInfixOf` Text.toLower err
+        || "timeout"
+        `Text.isInfixOf` Text.toLower err
+
+-- | The LLM backends this module can talk to.
+--
+-- * 'OpenRouter': served by 'chatOpenAI' and, for streaming, 'chatStreamOpenAI'.
+-- * 'Ollama': served by 'chatOllama'; streaming is not implemented.
+-- * 'AmpCLI': carries a file path; neither chat nor streaming is implemented yet.
+data Provider
+ = OpenRouter ProviderConfig
+ | Ollama ProviderConfig
+ | AmpCLI FilePath
+ deriving (Show, Eq, Generic)
+
+-- | Connection settings shared by the HTTP-based providers.
+data ProviderConfig = ProviderConfig
+  { -- | Base URL; the endpoint path is appended to it.
+    providerBaseUrl :: Text,
+    -- | API key, sent as a bearer token by the OpenAI-style requests.
+    providerApiKey :: Text,
+    -- | Model identifier placed in the request body.
+    providerModel :: Text,
+    -- | Additional headers added to OpenAI-style requests.
+    providerExtraHeaders :: [(ByteString, ByteString)]
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Serialises the base URL, API key and model. The extra headers are not
+-- written out. Note that the API key is emitted in clear text.
+instance Aeson.ToJSON ProviderConfig where
+  toJSON (ProviderConfig baseUrl apiKey model _extraHeaders) =
+    Aeson.object
+      [ "baseUrl" .= baseUrl,
+        "apiKey" .= apiKey,
+        "model" .= model
+      ]
+
+-- | Reads the three serialised fields; the extra headers always come back empty.
+instance Aeson.FromJSON ProviderConfig where
+  parseJSON =
+    Aeson.withObject "ProviderConfig" <| \o -> do
+      baseUrl <- o Aeson..: "baseUrl"
+      apiKey <- o Aeson..: "apiKey"
+      model <- o Aeson..: "model"
+      pure (ProviderConfig baseUrl apiKey model [])
+
+-- | OpenRouter provider for the given API key and model, using the public
+-- API root and the two attribution headers defined below.
+defaultOpenRouter :: Text -> Text -> Provider
+defaultOpenRouter apiKey model = OpenRouter config
+  where
+    config =
+      ProviderConfig
+        { providerBaseUrl = "https://openrouter.ai/api/v1",
+          providerApiKey = apiKey,
+          providerModel = model,
+          providerExtraHeaders = attributionHeaders
+        }
+    attributionHeaders =
+      [ ("HTTP-Referer", "https://omni.dev"),
+        ("X-Title", "Omni Agent")
+      ]
+
+-- | Ollama provider for the given model, pointing at @localhost:11434@ with
+-- an empty API key and no extra headers.
+defaultOllama :: Text -> Provider
+defaultOllama model = Ollama localConfig
+  where
+    localConfig =
+      ProviderConfig
+        { providerBaseUrl = "http://localhost:11434",
+          providerApiKey = "",
+          providerModel = model,
+          providerExtraHeaders = []
+        }
+
+-- | Author of a chat message. 'ToolRole' is written on the wire as @"tool"@.
+data Role = System | User | Assistant | ToolRole
+  deriving (Show, Eq, Generic)
+
+-- | Encodes a role as its lowercase wire name.
+instance Aeson.ToJSON Role where
+  toJSON role =
+    Aeson.String <| case role of
+      System -> "system"
+      User -> "user"
+      Assistant -> "assistant"
+      ToolRole -> "tool"
+
+-- | Accepts exactly the four wire names; any other string fails to parse.
+instance Aeson.FromJSON Role where
+  parseJSON =
+    Aeson.withText "Role" <| \name -> case name of
+      "system" -> pure System
+      "user" -> pure User
+      "assistant" -> pure Assistant
+      "tool" -> pure ToolRole
+      _ -> empty
+
+-- | One chat message, in the shape used by OpenAI-style chat APIs.
+data Message = Message
+  { msgRole :: Role,
+    msgContent :: Text,
+    -- | Tool invocations attached to the message, if any.
+    msgToolCalls :: Maybe [ToolCall],
+    -- | Identifier of the tool call this message refers to, if any.
+    msgToolCallId :: Maybe Text
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Always writes @role@ and @content@; the two optional fields are omitted
+-- entirely when they are 'Nothing'.
+instance Aeson.ToJSON Message where
+  toJSON (Message role content toolCalls toolCallId) =
+    Aeson.object
+      ( ["role" .= role, "content" .= content]
+          <> catMaybes
+            [ ("tool_calls" .=) </ toolCalls,
+              ("tool_call_id" .=) </ toolCallId
+            ]
+      )
+
+-- | @role@ is required; missing or null @content@ becomes the empty string.
+instance Aeson.FromJSON Message where
+  parseJSON =
+    Aeson.withObject "Message" <| \o -> do
+      role <- o Aeson..: "role"
+      content <- o Aeson..:? "content" Aeson..!= ""
+      toolCalls <- o Aeson..:? "tool_calls"
+      toolCallId <- o Aeson..:? "tool_call_id"
+      pure (Message role content toolCalls toolCallId)
+
+-- | A tool invocation requested by the model.
+data ToolCall = ToolCall
+  { tcId :: Text,
+    tcType :: Text,
+    tcFunction :: FunctionCall
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Writes @id@, @type@ and @function@.
+instance Aeson.ToJSON ToolCall where
+  toJSON (ToolCall callId callType function) =
+    Aeson.object
+      [ "id" .= callId,
+        "type" .= callType,
+        "function" .= function
+      ]
+
+-- | @id@ and @function@ are required; a missing @type@ defaults to @"function"@.
+instance Aeson.FromJSON ToolCall where
+  parseJSON =
+    Aeson.withObject "ToolCall" <| \o -> do
+      callId <- o Aeson..: "id"
+      callType <- o Aeson..:? "type" Aeson..!= "function"
+      function <- o Aeson..: "function"
+      pure (ToolCall callId callType function)
+
+-- | Function name plus its arguments, which are kept as raw text.
+data FunctionCall = FunctionCall
+  { fcName :: Text,
+    fcArguments :: Text
+  }
+  deriving (Show, Eq, Generic)
+
+-- | Writes @name@ and @arguments@.
+instance Aeson.ToJSON FunctionCall where
+  toJSON (FunctionCall name arguments) =
+    Aeson.object
+      [ "name" .= name,
+        "arguments" .= arguments
+      ]
+
+-- | @name@ is required; missing @arguments@ default to @"{}"@.
+instance Aeson.FromJSON FunctionCall where
+  parseJSON =
+    Aeson.withObject "FunctionCall" <| \o -> do
+      name <- o Aeson..: "name"
+      arguments <- o Aeson..:? "arguments" Aeson..!= "{}"
+      pure (FunctionCall name arguments)
+
+-- | Token accounting for one completion, with an optional cost figure.
+data Usage = Usage
+  { usagePromptTokens :: Int,
+    usageCompletionTokens :: Int,
+    usageTotalTokens :: Int,
+    usageCost :: Maybe Double
+  }
+  deriving (Show, Eq, Generic)
+
+-- | The three token counts are required; @cost@ is optional.
+instance Aeson.FromJSON Usage where
+  parseJSON =
+    Aeson.withObject "Usage" <| \o -> do
+      promptTokens <- o Aeson..: "prompt_tokens"
+      completionTokens <- o Aeson..: "completion_tokens"
+      totalTokens <- o Aeson..: "total_tokens"
+      cost <- o Aeson..:? "cost"
+      pure (Usage promptTokens completionTokens totalTokens cost)
+
+-- | Outcome of a chat request: the reply message, plus token usage when it
+-- is available.
+data ChatResult = ChatResult
+ { chatMessage :: Message,
+ chatUsage :: Maybe Usage
+ }
+ deriving (Show, Eq)
+
+-- | Description of a tool offered to the model: its name, a description,
+-- and its parameters as an arbitrary JSON value.
+data ToolApi = ToolApi
+  { toolApiName :: Text,
+    toolApiDescription :: Text,
+    toolApiParameters :: Aeson.Value
+  }
+  deriving (Generic)
+
+-- | Encodes the tool as @{"type": "function", "function": {...}}@.
+instance Aeson.ToJSON ToolApi where
+  toJSON (ToolApi name description parameters) =
+    Aeson.object
+      [ "type" .= ("function" :: Text),
+        "function" .= functionSpec
+      ]
+    where
+      functionSpec =
+        Aeson.object
+          [ "name" .= name,
+            "description" .= description,
+            "parameters" .= parameters
+          ]
+
+-- | Body of a non-streaming @/chat/completions@ request.
+data ChatCompletionRequest = ChatCompletionRequest
+  { reqModel :: Text,
+    reqMessages :: [Message],
+    -- | 'Nothing' leaves the @tools@ key out of the request.
+    reqTools :: Maybe [ToolApi]
+  }
+  deriving (Generic)
+
+-- | Writes @model@ and @messages@, adds @tools@ only when present, and always
+-- includes @"usage": {"include": true}@.
+instance Aeson.ToJSON ChatCompletionRequest where
+  toJSON (ChatCompletionRequest model messages tools) =
+    Aeson.object
+      ( ["model" .= model, "messages" .= messages]
+          <> maybe [] (\ts -> ["tools" .= ts]) tools
+          <> ["usage" .= Aeson.object ["include" .= True]]
+      )
+
+-- | One entry of the @choices@ array in a chat completion response.
+data Choice = Choice
+  { choiceIndex :: Int,
+    choiceMessage :: Message,
+    choiceFinishReason :: Maybe Text
+  }
+  deriving (Show, Eq, Generic)
+
+-- | @index@ and @message@ are required; @finish_reason@ is optional.
+instance Aeson.FromJSON Choice where
+  parseJSON =
+    Aeson.withObject "Choice" <| \o -> do
+      index <- o Aeson..: "index"
+      message <- o Aeson..: "message"
+      finishReason <- o Aeson..:? "finish_reason"
+      pure (Choice index message finishReason)
+
+-- | Top-level body of a non-streaming chat completion response.
+data ChatCompletionResponse = ChatCompletionResponse
+  { respId :: Text,
+    respChoices :: [Choice],
+    respModel :: Text,
+    respUsage :: Maybe Usage
+  }
+  deriving (Show, Eq, Generic)
+
+-- | @id@, @choices@ and @model@ are required; @usage@ is optional.
+instance Aeson.FromJSON ChatCompletionResponse where
+  parseJSON =
+    Aeson.withObject "ChatCompletionResponse" <| \o -> do
+      responseId <- o Aeson..: "id"
+      choices <- o Aeson..: "choices"
+      model <- o Aeson..: "model"
+      usage <- o Aeson..:? "usage"
+      pure (ChatCompletionResponse responseId choices model usage)
+
+-- | Like 'chatWithUsage', but returns only the reply message and drops the
+-- usage information.
+chat :: Provider -> [ToolApi] -> [Message] -> IO (Either Text Message)
+chat provider tools messages = do
+  outcome <- chatWithUsage provider tools messages
+  pure <| case outcome of
+    Left err -> Left err
+    Right reply -> Right (chatMessage reply)
+
+-- | Send a chat request to the given provider and return the reply together
+-- with usage data. The Amp CLI backend always yields a @Left@ because it is
+-- not implemented.
+chatWithUsage :: Provider -> [ToolApi] -> [Message] -> IO (Either Text ChatResult)
+chatWithUsage provider tools messages =
+  case provider of
+    OpenRouter cfg -> chatOpenAI cfg tools messages
+    Ollama cfg -> chatOllama cfg tools messages
+    AmpCLI _promptFile -> pure (Left "Amp CLI provider not yet implemented")
+
+-- | Send a non-streaming request to an OpenAI-compatible @/chat/completions@
+-- endpoint and return the first choice plus usage data.
+--
+-- The request carries the API key as a bearer token and any extra headers
+-- from the config. It is run through 'retryWithBackoff'. A non-2xx status,
+-- an undecodable body, or an empty @choices@ list all produce a @Left@ with
+-- a description.
+--
+-- Response bodies are decoded leniently when they are quoted in error
+-- messages: invalid UTF-8 (including a multi-byte character cut in half by
+-- the 500-byte preview limit) becomes U+FFFD instead of raising an exception.
+chatOpenAI :: ProviderConfig -> [ToolApi] -> [Message] -> IO (Either Text ChatResult)
+chatOpenAI cfg tools messages = do
+  let url = Text.unpack (providerBaseUrl cfg) <> "/chat/completions"
+  req0 <- HTTP.parseRequest url
+  let body =
+        ChatCompletionRequest
+          { reqModel = providerModel cfg,
+            reqMessages = messages,
+            reqTools = if null tools then Nothing else Just tools
+          }
+      baseReq =
+        HTTP.setRequestMethod "POST"
+          <| HTTP.setRequestHeader "Content-Type" ["application/json"]
+          <| HTTP.setRequestHeader "Authorization" ["Bearer " <> TE.encodeUtf8 (providerApiKey cfg)]
+          <| HTTP.setRequestBodyLBS (Aeson.encode body)
+          <| req0
+      req = foldr addHeader baseReq (providerExtraHeaders cfg)
+      addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value
+
+  retryWithBackoff maxRetries initialBackoffMicros <| do
+    response <- HTTP.httpLBS req
+    let status = HTTP.getResponseStatusCode response
+        respBody = HTTP.getResponseBody response
+        -- Skip leading newlines, carriage returns and spaces before decoding.
+        cleanedBody = BL.dropWhile (\b -> b `elem` [0x0a, 0x0d, 0x20]) respBody
+    if status >= 200 && status < 300
+      then case Aeson.eitherDecode cleanedBody of
+        Right resp ->
+          case respChoices resp of
+            (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp)))
+            [] -> pure (Left "No choices in response")
+        Left err -> do
+          let bodyPreview = decodeLenient (BL.take 500 cleanedBody)
+          pure (Left ("Failed to parse response: " <> Text.pack err <> " | Body: " <> bodyPreview))
+      else pure (Left ("HTTP error: " <> tshow status <> " - " <> decodeLenient respBody))
+  where
+    -- Total UTF-8 decoding: undecodable bytes are replaced, never thrown on.
+    decodeLenient bytes = TE.decodeUtf8With (\_ _ -> Just '\xfffd') (BL.toStrict bytes)
+
+-- | Send a non-streaming request to Ollama's @/api/chat@ endpoint.
+--
+-- The body contains the model, the messages, the tools (JSON @null@ when the
+-- list is empty) and @"stream": false@. No authorization header is sent. The
+-- request is run through 'retryWithBackoff'. A non-2xx status or a body that
+-- is not valid JSON produces a @Left@ quoting the body.
+--
+-- Quoted bodies are decoded leniently: invalid UTF-8 becomes U+FFFD instead
+-- of raising an exception when the error text is later inspected.
+chatOllama :: ProviderConfig -> [ToolApi] -> [Message] -> IO (Either Text ChatResult)
+chatOllama cfg tools messages = do
+  let url = Text.unpack (providerBaseUrl cfg) <> "/api/chat"
+  req0 <- HTTP.parseRequest url
+  let body =
+        Aeson.object
+          [ "model" .= providerModel cfg,
+            "messages" .= messages,
+            "tools" .= if null tools then Aeson.Null else Aeson.toJSON tools,
+            "stream" .= False
+          ]
+      req =
+        HTTP.setRequestMethod "POST"
+          <| HTTP.setRequestHeader "Content-Type" ["application/json"]
+          <| HTTP.setRequestBodyLBS (Aeson.encode body)
+          <| req0
+
+  retryWithBackoff maxRetries initialBackoffMicros <| do
+    response <- HTTP.httpLBS req
+    let status = HTTP.getResponseStatusCode response
+        rawBody = HTTP.getResponseBody response
+    if status >= 200 && status < 300
+      then case Aeson.decode rawBody of
+        Just resp -> parseOllamaResponse resp
+        Nothing -> pure (Left ("Failed to parse Ollama response: " <> decodeLenient rawBody))
+      else pure (Left ("HTTP error: " <> tshow status <> " - " <> decodeLenient rawBody))
+  where
+    -- Total UTF-8 decoding: undecodable bytes are replaced, never thrown on.
+    decodeLenient bytes = TE.decodeUtf8With (\_ _ -> Just '\xfffd') (BL.toStrict bytes)
+
+-- | Turn a decoded Ollama response into a 'ChatResult'.
+--
+-- The value must be an object with a @message@ field that parses as a
+-- 'Message'. Usage is filled in only when both @prompt_eval_count@ and
+-- @eval_count@ are numbers; the total is their sum and the cost is left unset.
+parseOllamaResponse :: Aeson.Value -> IO (Either Text ChatResult)
+parseOllamaResponse (Aeson.Object obj) =
+  pure <| do
+    rawMessage <- maybe (Left "No message in response") Right (KeyMap.lookup "message" obj)
+    message <- case Aeson.fromJSON rawMessage of
+      Aeson.Success m -> Right m
+      Aeson.Error e -> Left (Text.pack e)
+    Right (ChatResult message tokenUsage)
+  where
+    -- A numeric field rounded to an integer count, if present.
+    count key = case KeyMap.lookup key obj of
+      Just (Aeson.Number n) -> Just (round n)
+      _ -> Nothing
+    tokenUsage = do
+      promptTokens <- count "prompt_eval_count"
+      evalTokens <- count "eval_count"
+      Just
+        Usage
+          { usagePromptTokens = promptTokens,
+            usageCompletionTokens = evalTokens,
+            usageTotalTokens = promptTokens + evalTokens,
+            usageCost = Nothing
+          }
+parseOllamaResponse _ = pure (Left "Expected object response from Ollama")
+
+-- | Items produced while reading a streamed completion.
+--
+-- * 'StreamContent': a fragment of reply text.
+-- * 'StreamToolCall': a complete tool call.
+-- * 'StreamToolCallDelta': a partial tool call, to be merged by index.
+-- * 'StreamDone': carries a 'ChatResult'; the stream parser uses it to
+--   deliver usage data.
+-- * 'StreamError': an error message reported inside the stream.
+data StreamChunk
+ = StreamContent Text
+ | StreamToolCall ToolCall
+ | StreamToolCallDelta ToolCallDelta
+ | StreamDone ChatResult
+ | StreamError Text
+ deriving (Show, Eq)
+
+-- | A partial tool call from a streamed response. 'tcdIndex' identifies which
+-- tool call the fragment belongs to; the remaining fields are present only
+-- when the fragment carries them.
+data ToolCallDelta = ToolCallDelta
+ { tcdIndex :: Int,
+ tcdId :: Maybe Text,
+ tcdFunctionName :: Maybe Text,
+ tcdFunctionArgs :: Maybe Text
+ }
+ deriving (Show, Eq)
+
+-- | Streaming variant of 'chatWithUsage'. The callback receives chunks as
+-- they arrive. Only OpenRouter is supported; the other providers return a
+-- @Left@ immediately.
+chatStream :: Provider -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult)
+chatStream provider tools messages onChunk =
+  case provider of
+    OpenRouter cfg -> chatStreamOpenAI cfg tools messages onChunk
+    Ollama _cfg -> pure (Left "Streaming not implemented for Ollama")
+    AmpCLI _ -> pure (Left "Streaming not implemented for AmpCLI")
+
+-- | Send a streaming request to an OpenAI-compatible @/chat/completions@
+-- endpoint and feed the server-sent events to 'processSSEStream'.
+--
+-- A fresh TLS manager is created for each call, with its response timeout
+-- set to 'httpTimeoutMicros'. A non-2xx status yields a @Left@ quoting the
+-- response body. Any exception raised while performing the request is
+-- caught and reported as @"Stream request failed: ..."@.
+--
+-- The error body is decoded leniently. With a strict decoder, invalid UTF-8
+-- would raise only when the lazily built message is forced, which happens
+-- after the surrounding 'try' has already returned.
+chatStreamOpenAI :: ProviderConfig -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult)
+chatStreamOpenAI cfg tools messages onChunk = do
+  let url = Text.unpack (providerBaseUrl cfg) <> "/chat/completions"
+      managerSettings =
+        HTTPClientTLS.tlsManagerSettings
+          { HTTPClient.managerResponseTimeout = HTTPClient.responseTimeoutMicro httpTimeoutMicros
+          }
+  manager <- HTTPClient.newManager managerSettings
+  req0 <- HTTP.parseRequest url
+  let body =
+        Aeson.object
+          <| catMaybes
+            [ Just ("model" .= providerModel cfg),
+              Just ("messages" .= messages),
+              if null tools then Nothing else Just ("tools" .= tools),
+              Just ("stream" .= True),
+              Just ("usage" .= Aeson.object ["include" .= True])
+            ]
+      baseReq =
+        HTTP.setRequestMethod "POST"
+          <| HTTP.setRequestHeader "Content-Type" ["application/json"]
+          <| HTTP.setRequestHeader "Authorization" ["Bearer " <> TE.encodeUtf8 (providerApiKey cfg)]
+          <| HTTP.setRequestBodyLBS (Aeson.encode body)
+          <| req0
+      req = foldr addHeader baseReq (providerExtraHeaders cfg)
+      addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value
+
+  result <-
+    try <| HTTPClient.withResponse req manager <| \response -> do
+      let status = HTTPClient.responseStatus response
+          code = statusCode status
+      if code >= 200 && code < 300
+        then processSSEStream (HTTPClient.responseBody response) onChunk
+        else do
+          bodyChunks <- readAllBody (HTTPClient.responseBody response)
+          let errBody = decodeLenient (BS.concat bodyChunks)
+          pure (Left ("HTTP error: " <> tshow code <> " - " <> errBody))
+  case result of
+    Left (e :: SomeException) -> pure (Left ("Stream request failed: " <> tshow e))
+    Right r -> pure r
+  where
+    -- Total UTF-8 decoding: undecodable bytes are replaced, never thrown on.
+    decodeLenient = TE.decodeUtf8With (\_ _ -> Just '\xfffd')
+
+-- | Drain a chunked body reader. The reader is called repeatedly until it
+-- returns an empty chunk; the non-empty chunks are returned in arrival order.
+readAllBody :: IO BS.ByteString -> IO [BS.ByteString]
+readAllBody readBody = do
+  chunk <- readBody
+  if BS.null chunk
+    then pure []
+    else (chunk :) </ readAllBody readBody
+
+-- | Working state for one tool call while its streamed fragments are being
+-- merged: the id, the function name, and the argument text gathered so far.
+data ToolCallAccum = ToolCallAccum
+ { tcaId :: Text,
+ tcaName :: Text,
+ tcaArgs :: Text
+ }
+
+-- | Consume a server-sent-event body and assemble the final 'ChatResult'.
+--
+-- The reader is polled until it returns an empty chunk. Complete events are
+-- parsed with 'parseStreamEvent' and handled as follows:
+--
+-- * content fragments are appended to the reply and passed to @onChunk@;
+-- * tool-call deltas are merged by index (they are not passed to @onChunk@);
+-- * complete tool calls are recorded and passed to @onChunk@;
+-- * a done event updates the usage stored for the final result;
+-- * an error event is passed to @onChunk@; reading continues afterwards.
+--
+-- Network chunks can end in the middle of a multi-byte UTF-8 character, so
+-- chunks are not decoded one by one. Bytes after the last newline are held
+-- back undecoded until a later chunk completes the line. Only the bytes up to
+-- and including that newline are decoded, leniently, so invalid input becomes
+-- U+FFFD instead of an exception. Whatever is still buffered when the body
+-- ends is discarded.
+processSSEStream :: IO BS.ByteString -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult)
+processSSEStream readBody onChunk = do
+  accumulatedContent <- newIORef ("" :: Text)
+  toolCallAccum <- newIORef (IntMap.empty :: IntMap.IntMap ToolCallAccum)
+  lastUsage <- newIORef (Nothing :: Maybe Usage)
+  -- Decoded text that does not yet form a complete event.
+  buffer <- newIORef ("" :: Text)
+  -- Raw bytes of the current, still unterminated line.
+  pendingBytes <- newIORef BS.empty
+
+  let decodeLenient = TE.decodeUtf8With (\_ _ -> Just '\xfffd')
+
+      handleEvent event =
+        case parseStreamEvent event of
+          Just (StreamContent txt) -> do
+            modifyIORef accumulatedContent (<> txt)
+            onChunk (StreamContent txt)
+          Just (StreamToolCallDelta delta) ->
+            modifyIORef toolCallAccum (mergeToolCallDelta delta)
+          Just (StreamToolCall tc) -> do
+            modifyIORef toolCallAccum (mergeCompleteToolCall tc)
+            onChunk (StreamToolCall tc)
+          Just (StreamDone result) ->
+            writeIORef lastUsage (chatUsage result)
+          Just (StreamError err) ->
+            onChunk (StreamError err)
+          Nothing -> pure ()
+
+      finish = do
+        content <- readIORef accumulatedContent
+        accum <- readIORef toolCallAccum
+        usage <- readIORef lastUsage
+        let toolCalls = map accumToToolCall (IntMap.elems accum)
+            finalMsg =
+              Message
+                { msgRole = Assistant,
+                  msgContent = content,
+                  msgToolCalls = if null toolCalls then Nothing else Just toolCalls,
+                  msgToolCallId = Nothing
+                }
+        pure (Right (ChatResult finalMsg usage))
+
+      loop = do
+        chunk <- readBody
+        if BS.null chunk
+          then finish
+          else do
+            held <- readIORef pendingBytes
+            -- Split after the last newline: the first part holds only whole
+            -- lines, the second is the unfinished line.
+            let (wholeLines, unfinished) = BS.breakEnd (== 0x0a) (held <> chunk)
+            writeIORef pendingBytes unfinished
+            buffered <- readIORef buffer
+            let (events, remaining) = parseSSEEvents (buffered <> decodeLenient wholeLines)
+            writeIORef buffer remaining
+            forM_ events handleEvent
+            loop
+
+  loop
+
+-- | Finalise an accumulated tool call. The type is always @"function"@.
+accumToToolCall :: ToolCallAccum -> ToolCall
+accumToToolCall (ToolCallAccum callId name args) =
+  ToolCall callId "function" (FunctionCall name args)
+
+-- | Fold one streamed fragment into the accumulator stored at its index.
+--
+-- A new index starts an entry, with absent fields set to the empty string.
+-- For an existing entry, an id or name in the fragment replaces the stored
+-- one, and argument text is appended to what has been gathered so far.
+mergeToolCallDelta :: ToolCallDelta -> IntMap.IntMap ToolCallAccum -> IntMap.IntMap ToolCallAccum
+mergeToolCallDelta (ToolCallDelta idx newId newName newArgs) accum =
+  IntMap.insert idx merged accum
+  where
+    argsFragment = fromMaybe "" newArgs
+    merged = case IntMap.lookup idx accum of
+      Nothing ->
+        ToolCallAccum (fromMaybe "" newId) (fromMaybe "" newName) argsFragment
+      Just (ToolCallAccum oldId oldName oldArgs) ->
+        ToolCallAccum (fromMaybe oldId newId) (fromMaybe oldName newName) (oldArgs <> argsFragment)
+
+-- | Record a complete tool call under the next free index: 0 for an empty
+-- map, otherwise one past the largest key in use.
+mergeCompleteToolCall :: ToolCall -> IntMap.IntMap ToolCallAccum -> IntMap.IntMap ToolCallAccum
+mergeCompleteToolCall (ToolCall callId _ (FunctionCall name args)) accum =
+  IntMap.insert slot (ToolCallAccum callId name args) accum
+  where
+    slot = case IntMap.lookupMax accum of
+      Nothing -> 0
+      Just (highest, _) -> highest + 1
+
+-- | Split buffered stream text into complete events and an unconsumed rest.
+--
+-- The input is cut at every @\\n@. A blank line that follows at least one
+-- collected line ends the current event. Lines gathered after the last
+-- terminator are returned as the second component, joined with @\\n@, so the
+-- caller can prepend them to the next chunk.
+--
+-- A line counts as blank when it is empty or consists only of carriage
+-- returns, so streams that use CRLF line endings are split correctly too.
+-- Carriage returns inside event lines are left in place.
+parseSSEEvents :: Text -> ([Text], Text)
+parseSSEEvents input = go [] [] (Text.splitOn "\n" input)
+  where
+    -- events: finished events, newest first; current: lines of the event
+    -- being collected, newest first.
+    go events current [] = (reverse events, Text.intercalate "\n" (reverse current))
+    go events current (line : rest)
+      | isBlank line && not (null current) =
+          go (Text.intercalate "\n" (reverse current) : events) [] rest
+      | otherwise =
+          go events (line : current) rest
+    isBlank line = Text.null (Text.dropWhileEnd (== '\r') line)
+
+-- | Interpret one event's text as a 'StreamChunk'.
+--
+-- Only the first line starting with @data:@ is considered. The @[DONE]@
+-- sentinel, payloads that are not JSON, and JSON values that are not objects
+-- all give 'Nothing'; objects are handed to 'parseStreamChunk'.
+parseStreamEvent :: Text -> Maybe StreamChunk
+parseStreamEvent eventText =
+  case filter (Text.isPrefixOf "data:") (Text.lines eventText) of
+    [] -> Nothing
+    (firstData : _)
+      | payload == "[DONE]" -> Nothing
+      | otherwise ->
+          case Aeson.decode (BL.fromStrict (TE.encodeUtf8 payload)) of
+            Just (Aeson.Object obj) -> parseStreamChunk obj
+            _ -> Nothing
+      where
+        -- Drop the 5-character "data:" prefix and surrounding whitespace.
+        payload = Text.strip (Text.drop 5 firstData)
+
+-- | Classify one decoded stream object.
+--
+-- An @error@ object wins outright and becomes 'StreamError', using its
+-- @message@ string or @"Unknown error"@. Otherwise the @delta@ of the first
+-- choice is examined: non-empty @content@ becomes 'StreamContent', else
+-- @tool_calls@ goes to 'parseToolCallDelta'. If neither applies, a parseable
+-- @usage@ object becomes 'StreamDone' with a placeholder empty assistant
+-- message. Only one chunk is returned per object, in that order of priority.
+parseStreamChunk :: Aeson.Object -> Maybe StreamChunk
+parseStreamChunk obj =
+  case KeyMap.lookup "error" obj of
+    Just (Aeson.Object errObj) -> Just (StreamError (errorMessage errObj))
+    _ -> fromDelta <|> fromUsage
+  where
+    errorMessage errObj = case KeyMap.lookup "message" errObj of
+      Just (Aeson.String m) -> m
+      _ -> "Unknown error"
+
+    fromUsage = do
+      usageVal <- KeyMap.lookup "usage" obj
+      case Aeson.fromJSON usageVal of
+        Aeson.Success usage ->
+          Just (StreamDone (ChatResult (Message Assistant "" Nothing Nothing) (Just usage)))
+        _ -> Nothing
+
+    -- The delta object of the first choice, when the chunk has one.
+    firstDelta = case KeyMap.lookup "choices" obj of
+      Just (Aeson.Array choices) -> case toList choices of
+        (Aeson.Object choice : _) -> case KeyMap.lookup "delta" choice of
+          Just (Aeson.Object delta) -> Just delta
+          _ -> Nothing
+        _ -> Nothing
+      _ -> Nothing
+
+    fromDelta = do
+      delta <- firstDelta
+      deltaContent delta <|> deltaToolCall delta
+
+    deltaContent delta = case KeyMap.lookup "content" delta of
+      Just (Aeson.String c) | not (Text.null c) -> Just (StreamContent c)
+      _ -> Nothing
+
+    deltaToolCall delta = case KeyMap.lookup "tool_calls" delta of
+      Just (Aeson.Array tcs) -> parseToolCallDelta (toList tcs)
+      _ -> Nothing
+
+-- | Build a 'StreamToolCallDelta' from the first element of a streamed
+-- @tool_calls@ array; later elements are ignored.
+--
+-- The element must be an object with a numeric @index@. The @id@ and the
+-- @function@ object's @name@ and @arguments@ are picked up when they are
+-- strings and left as 'Nothing' otherwise.
+parseToolCallDelta :: [Aeson.Value] -> Maybe StreamChunk
+parseToolCallDelta (Aeson.Object tcObj : _) =
+  case KeyMap.lookup "index" tcObj of
+    Just (Aeson.Number n) ->
+      Just
+        ( StreamToolCallDelta
+            ToolCallDelta
+              { tcdIndex = round n,
+                tcdId = textAt "id" tcObj,
+                tcdFunctionName = functionField "name",
+                tcdFunctionArgs = functionField "arguments"
+              }
+        )
+    _ -> Nothing
+  where
+    -- A string-valued field of an object, if present.
+    textAt key o = case KeyMap.lookup key o of
+      Just (Aeson.String s) -> Just s
+      _ -> Nothing
+    functionObj = case KeyMap.lookup "function" tcObj of
+      Just (Aeson.Object f) -> Just f
+      _ -> Nothing
+    functionField key = do
+      f <- functionObj
+      textAt key f
+parseToolCallDelta _ = Nothing