diff options
Diffstat (limited to 'Omni/Agent/Provider.hs')
| -rw-r--r-- | Omni/Agent/Provider.hs | 110 |
1 files changed, 79 insertions, 31 deletions
diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs index 1bb4f04..db30e5f 100644 --- a/Omni/Agent/Provider.hs +++ b/Omni/Agent/Provider.hs @@ -52,6 +52,7 @@ import qualified Network.HTTP.Client.TLS as HTTPClientTLS import qualified Network.HTTP.Simple as HTTP import Network.HTTP.Types.Status (statusCode) import qualified Omni.Test as Test +import qualified System.Timeout as Timeout main :: IO () main = Test.run test @@ -74,6 +75,43 @@ test = chatMessage result Test.@=? msg ] +-- | HTTP request timeout in microseconds (60 seconds) +httpTimeoutMicros :: Int +httpTimeoutMicros = 60 * 1000000 + +-- | Maximum number of retries for transient failures +maxRetries :: Int +maxRetries = 3 + +-- | Initial backoff delay in microseconds (1 second) +initialBackoffMicros :: Int +initialBackoffMicros = 1000000 + +-- | Retry an IO action with exponential backoff +-- Retries on timeout, connection errors, and 5xx status codes +retryWithBackoff :: Int -> Int -> IO (Either Text a) -> IO (Either Text a) +retryWithBackoff retriesLeft backoff action + | retriesLeft <= 0 = action + | otherwise = do + result <- Timeout.timeout httpTimeoutMicros action + case result of + Nothing -> do + threadDelay backoff + retryWithBackoff (retriesLeft - 1) (backoff * 2) action + Just (Left err) + | isRetryable err -> do + threadDelay backoff + retryWithBackoff (retriesLeft - 1) (backoff * 2) action + Just r -> pure r + where + isRetryable err = + "HTTP error: 5" + `Text.isPrefixOf` err + || "connection" + `Text.isInfixOf` Text.toLower err + || "timeout" + `Text.isInfixOf` Text.toLower err + data Provider = OpenRouter ProviderConfig | Ollama ProviderConfig @@ -330,20 +368,21 @@ chatOpenAI cfg tools messages = do req = foldr addHeader baseReq (providerExtraHeaders cfg) addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value - response <- HTTP.httpLBS req - let status = HTTP.getResponseStatusCode response - respBody = HTTP.getResponseBody response - cleanedBody = BL.dropWhile (\b -> b `elem` [0x0a, 0x0d, 0x20]) respBody - if status >= 200 && status < 300 - then case Aeson.eitherDecode cleanedBody of - Right resp -> - case respChoices resp of - (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp))) - [] -> pure (Left "No choices in response") - Left err -> do - let bodyPreview = TE.decodeUtf8 (BL.toStrict (BL.take 500 cleanedBody)) - pure (Left ("Failed to parse response: " <> Text.pack err <> " | Body: " <> bodyPreview)) - else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict respBody))) + retryWithBackoff maxRetries initialBackoffMicros <| do + response <- HTTP.httpLBS req + let status = HTTP.getResponseStatusCode response + respBody = HTTP.getResponseBody response + cleanedBody = BL.dropWhile (\b -> b `elem` [0x0a, 0x0d, 0x20]) respBody + if status >= 200 && status < 300 + then case Aeson.eitherDecode cleanedBody of + Right resp -> + case respChoices resp of + (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp))) + [] -> pure (Left "No choices in response") + Left err -> do + let bodyPreview = TE.decodeUtf8 (BL.toStrict (BL.take 500 cleanedBody)) + pure (Left ("Failed to parse response: " <> Text.pack err <> " | Body: " <> bodyPreview)) + else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict respBody))) chatOllama :: ProviderConfig -> [ToolApi] -> [Message] -> IO (Either Text ChatResult) chatOllama cfg tools messages = do @@ -362,13 +401,14 @@ chatOllama cfg tools messages = do <| HTTP.setRequestBodyLBS (Aeson.encode body) <| req0 - response <- HTTP.httpLBS req - let status = HTTP.getResponseStatusCode response - if status >= 200 && status < 300 - then case Aeson.decode (HTTP.getResponseBody response) of - Just resp -> parseOllamaResponse resp - Nothing -> pure (Left ("Failed to parse Ollama response: " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) - else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + retryWithBackoff maxRetries initialBackoffMicros <| do + response <- HTTP.httpLBS req + let status = HTTP.getResponseStatusCode response + if status >= 200 && status < 300 + then case Aeson.decode (HTTP.getResponseBody response) of + Just resp -> parseOllamaResponse resp + Nothing -> pure (Left ("Failed to parse Ollama response: " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) parseOllamaResponse :: Aeson.Value -> IO (Either Text ChatResult) parseOllamaResponse val = @@ -423,7 +463,11 @@ chatStream (AmpCLI _) _tools _messages _onChunk = pure (Left "Streaming not impl chatStreamOpenAI :: ProviderConfig -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult) chatStreamOpenAI cfg tools messages onChunk = do let url = Text.unpack (providerBaseUrl cfg) <> "/chat/completions" - manager <- HTTPClient.newManager HTTPClientTLS.tlsManagerSettings + managerSettings = + HTTPClientTLS.tlsManagerSettings + { HTTPClient.managerResponseTimeout = HTTPClient.responseTimeoutMicro httpTimeoutMicros + } + manager <- HTTPClient.newManager managerSettings req0 <- HTTP.parseRequest url let body = Aeson.object @@ -443,15 +487,19 @@ chatStreamOpenAI cfg tools messages onChunk = do req = foldr addHeader baseReq (providerExtraHeaders cfg) addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value - HTTPClient.withResponse req manager <| \response -> do - let status = HTTPClient.responseStatus response - code = statusCode status - if code >= 200 && code < 300 - then processSSEStream (HTTPClient.responseBody response) onChunk - else do - bodyChunks <- readAllBody (HTTPClient.responseBody response) - let errBody = TE.decodeUtf8 (BS.concat bodyChunks) - pure (Left ("HTTP error: " <> tshow code <> " - " <> errBody)) + result <- + try <| HTTPClient.withResponse req manager <| \response -> do + let status = HTTPClient.responseStatus response + code = statusCode status + if code >= 200 && code < 300 + then processSSEStream (HTTPClient.responseBody response) onChunk + else do + bodyChunks <- readAllBody (HTTPClient.responseBody response) + let errBody = TE.decodeUtf8 (BS.concat bodyChunks) + pure (Left ("HTTP error: " <> tshow code <> " - " <> errBody)) + case result of + Left (e :: SomeException) -> pure (Left ("Stream request failed: " <> tshow e)) + Right r -> pure r readAllBody :: IO BS.ByteString -> IO [BS.ByteString] readAllBody readBody = go [] |
