diff options
Diffstat (limited to 'Omni/Agent/Provider.hs')
| -rw-r--r-- | Omni/Agent/Provider.hs | 257 |
1 files changed, 257 insertions, 0 deletions
diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs index 2ad6ea8..fd6920d 100644 --- a/Omni/Agent/Provider.hs +++ b/Omni/Agent/Provider.hs @@ -12,6 +12,8 @@ -- : out omni-agent-provider -- : dep aeson -- : dep http-conduit +-- : dep http-client-tls +-- : dep http-types -- : dep case-insensitive module Omni.Agent.Provider ( Provider (..), @@ -23,10 +25,12 @@ module Omni.Agent.Provider FunctionCall (..), Usage (..), ToolApi (..), + StreamChunk (..), defaultOpenRouter, defaultOllama, chat, chatWithUsage, + chatStream, main, test, ) @@ -36,11 +40,17 @@ import Alpha import Data.Aeson ((.=)) import qualified Data.Aeson as Aeson import qualified Data.Aeson.KeyMap as KeyMap +import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL import qualified Data.CaseInsensitive as CI +import Data.IORef (modifyIORef, newIORef, readIORef, writeIORef) +import qualified Data.IntMap.Strict as IntMap import qualified Data.Text as Text import qualified Data.Text.Encoding as TE +import qualified Network.HTTP.Client as HTTPClient +import qualified Network.HTTP.Client.TLS as HTTPClientTLS import qualified Network.HTTP.Simple as HTTP +import Network.HTTP.Types.Status (statusCode) import qualified Omni.Test as Test main :: IO () @@ -388,3 +398,250 @@ parseOllamaResponse val = Right msg -> pure (Right (ChatResult msg usageResult)) Left e -> pure (Left e) _ -> pure (Left "Expected object response from Ollama") + +data StreamChunk + = StreamContent Text + | StreamToolCall ToolCall + | StreamToolCallDelta ToolCallDelta + | StreamDone ChatResult + | StreamError Text + deriving (Show, Eq) + +data ToolCallDelta = ToolCallDelta + { tcdIndex :: Int, + tcdId :: Maybe Text, + tcdFunctionName :: Maybe Text, + tcdFunctionArgs :: Maybe Text + } + deriving (Show, Eq) + +chatStream :: Provider -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult) +chatStream (OpenRouter cfg) tools messages onChunk = chatStreamOpenAI cfg tools messages onChunk +chatStream (Ollama _cfg) _tools _messages _onChunk = pure (Left "Streaming not implemented for Ollama") +chatStream (AmpCLI _) _tools _messages _onChunk = pure (Left "Streaming not implemented for AmpCLI") + +chatStreamOpenAI :: ProviderConfig -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult) +chatStreamOpenAI cfg tools messages onChunk = do + let url = Text.unpack (providerBaseUrl cfg) <> "/chat/completions" + manager <- HTTPClient.newManager HTTPClientTLS.tlsManagerSettings + req0 <- HTTP.parseRequest url + let body = + Aeson.object + <| catMaybes + [ Just ("model" .= providerModel cfg), + Just ("messages" .= messages), + if null tools then Nothing else Just ("tools" .= tools), + Just ("stream" .= True), + Just ("usage" .= Aeson.object ["include" .= True]) + ] + baseReq = + HTTP.setRequestMethod "POST" + <| HTTP.setRequestHeader "Content-Type" ["application/json"] + <| HTTP.setRequestHeader "Authorization" ["Bearer " <> TE.encodeUtf8 (providerApiKey cfg)] + <| HTTP.setRequestBodyLBS (Aeson.encode body) + <| req0 + req = foldr addHeader baseReq (providerExtraHeaders cfg) + addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value + + HTTPClient.withResponse req manager <| \response -> do + let status = HTTPClient.responseStatus response + code = statusCode status + if code >= 200 && code < 300 + then processSSEStream (HTTPClient.responseBody response) onChunk + else do + bodyChunks <- readAllBody (HTTPClient.responseBody response) + let errBody = TE.decodeUtf8 (BS.concat bodyChunks) + pure (Left ("HTTP error: " <> tshow code <> " - " <> errBody)) + +readAllBody :: IO BS.ByteString -> IO [BS.ByteString] +readAllBody readBody = go [] + where + go acc = do + chunk <- readBody + if BS.null chunk + then pure (reverse acc) + else go (chunk : acc) + +data ToolCallAccum = ToolCallAccum + { tcaId :: Text, + tcaName :: Text, + tcaArgs :: Text + } + +processSSEStream :: IO BS.ByteString -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult) +processSSEStream readBody onChunk = do + accumulatedContent <- newIORef ("" :: Text) + toolCallAccum <- newIORef (IntMap.empty :: IntMap.IntMap ToolCallAccum) + lastUsage <- newIORef (Nothing :: Maybe Usage) + buffer <- newIORef ("" :: Text) + + let loop = do + chunk <- readBody + if BS.null chunk + then do + content <- readIORef accumulatedContent + accum <- readIORef toolCallAccum + usage <- readIORef lastUsage + let toolCalls = map accumToToolCall (IntMap.elems accum) + finalMsg = + Message + { msgRole = Assistant, + msgContent = content, + msgToolCalls = if null toolCalls then Nothing else Just toolCalls, + msgToolCallId = Nothing + } + pure (Right (ChatResult finalMsg usage)) + else do + modifyIORef buffer (<> TE.decodeUtf8 chunk) + buf <- readIORef buffer + let (events, remaining) = parseSSEEvents buf + writeIORef buffer remaining + forM_ events <| \event -> do + case parseStreamEvent event of + Just (StreamContent txt) -> do + modifyIORef accumulatedContent (<> txt) + onChunk (StreamContent txt) + Just (StreamToolCallDelta delta) -> do + modifyIORef toolCallAccum (mergeToolCallDelta delta) + Just (StreamToolCall tc) -> do + modifyIORef toolCallAccum (mergeCompleteToolCall tc) + onChunk (StreamToolCall tc) + Just (StreamDone result) -> do + writeIORef lastUsage (chatUsage result) + Just (StreamError err) -> do + onChunk (StreamError err) + Nothing -> pure () + loop + + loop + +accumToToolCall :: ToolCallAccum -> ToolCall +accumToToolCall acc = + ToolCall + { tcId = tcaId acc, + tcType = "function", + tcFunction = FunctionCall (tcaName acc) (tcaArgs acc) + } + +mergeToolCallDelta :: ToolCallDelta -> IntMap.IntMap ToolCallAccum -> IntMap.IntMap ToolCallAccum +mergeToolCallDelta delta accum = + let idx = tcdIndex delta + existing = IntMap.lookup idx accum + updated = case existing of + Nothing -> + ToolCallAccum + { tcaId = fromMaybe "" (tcdId delta), + tcaName = fromMaybe "" (tcdFunctionName delta), + tcaArgs = fromMaybe "" (tcdFunctionArgs delta) + } + Just a -> + a + { tcaId = fromMaybe (tcaId a) (tcdId delta), + tcaName = fromMaybe (tcaName a) (tcdFunctionName delta), + tcaArgs = tcaArgs a <> fromMaybe "" (tcdFunctionArgs delta) + } + in IntMap.insert idx updated accum + +mergeCompleteToolCall :: ToolCall -> IntMap.IntMap ToolCallAccum -> IntMap.IntMap ToolCallAccum +mergeCompleteToolCall tc accum = + let nextIdx = if IntMap.null accum then 0 else fst (IntMap.findMax accum) + 1 + newAccum = + ToolCallAccum + { tcaId = tcId tc, + tcaName = fcName (tcFunction tc), + tcaArgs = fcArguments (tcFunction tc) + } + in IntMap.insert nextIdx newAccum accum + +parseSSEEvents :: Text -> ([Text], Text) +parseSSEEvents input = + let lines' = Text.splitOn "\n" input + (events, remaining) = go [] [] lines' + in (events, remaining) + where + go events current [] = (reverse events, Text.intercalate "\n" (reverse current)) + go events current (line : rest) + | Text.null line && not (null current) = + go (Text.intercalate "\n" (reverse current) : events) [] rest + | otherwise = + go events (line : current) rest + +parseStreamEvent :: Text -> Maybe StreamChunk +parseStreamEvent eventText = do + let dataLines = filter ("data:" `Text.isPrefixOf`) (Text.lines eventText) + case dataLines of + [] -> Nothing + (dataLine : _) -> do + let jsonStr = Text.strip (Text.drop 5 dataLine) + if jsonStr == "[DONE]" + then Nothing + else case Aeson.decode (BL.fromStrict (TE.encodeUtf8 jsonStr)) of + Nothing -> Nothing + Just (Aeson.Object obj) -> parseStreamChunk obj + _ -> Nothing + +parseStreamChunk :: Aeson.Object -> Maybe StreamChunk +parseStreamChunk obj = do + case KeyMap.lookup "error" obj of + Just (Aeson.Object errObj) -> do + let errMsg = case KeyMap.lookup "message" errObj of + Just (Aeson.String m) -> m + _ -> "Unknown error" + Just (StreamError errMsg) + _ -> do + case KeyMap.lookup "choices" obj of + Just (Aeson.Array choices) | not (null choices) -> do + case toList choices of + (Aeson.Object choice : _) -> do + case KeyMap.lookup "delta" choice of + Just (Aeson.Object delta) -> do + let contentChunk = case KeyMap.lookup "content" delta of + Just (Aeson.String c) | not (Text.null c) -> Just (StreamContent c) + _ -> Nothing + toolCallChunk = case KeyMap.lookup "tool_calls" delta of + Just (Aeson.Array tcs) + | not (null tcs) -> + parseToolCallDelta (toList tcs) + _ -> Nothing + contentChunk <|> toolCallChunk + _ -> Nothing + _ -> Nothing + _ -> do + case KeyMap.lookup "usage" obj of + Just usageVal -> case Aeson.fromJSON usageVal of + Aeson.Success usage -> Just (StreamDone (ChatResult (Message Assistant "" Nothing Nothing) (Just usage))) + _ -> Nothing + _ -> Nothing + +parseToolCallDelta :: [Aeson.Value] -> Maybe StreamChunk +parseToolCallDelta [] = Nothing +parseToolCallDelta (Aeson.Object tcObj : _) = do + idx <- case KeyMap.lookup "index" tcObj of + Just (Aeson.Number n) -> Just (round n) + _ -> Nothing + let tcId' = case KeyMap.lookup "id" tcObj of + Just (Aeson.String s) -> Just s + _ -> Nothing + funcObj = case KeyMap.lookup "function" tcObj of + Just (Aeson.Object f) -> Just f + _ -> Nothing + funcName = case funcObj of + Just f -> case KeyMap.lookup "name" f of + Just (Aeson.String s) -> Just s + _ -> Nothing + Nothing -> Nothing + funcArgs = case funcObj of + Just f -> case KeyMap.lookup "arguments" f of + Just (Aeson.String s) -> Just s + _ -> Nothing + Nothing -> Nothing + Just + ( StreamToolCallDelta + ToolCallDelta + { tcdIndex = idx, + tcdId = tcId', + tcdFunctionName = funcName, + tcdFunctionArgs = funcArgs + } + ) +parseToolCallDelta _ = Nothing |
