summaryrefslogtreecommitdiff
path: root/Omni/Agent/Provider.hs
diff options
context:
space:
mode:
Diffstat (limited to 'Omni/Agent/Provider.hs')
-rw-r--r--Omni/Agent/Provider.hs110
1 files changed, 79 insertions, 31 deletions
diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs
index 1bb4f04..db30e5f 100644
--- a/Omni/Agent/Provider.hs
+++ b/Omni/Agent/Provider.hs
@@ -52,6 +52,7 @@ import qualified Network.HTTP.Client.TLS as HTTPClientTLS
import qualified Network.HTTP.Simple as HTTP
import Network.HTTP.Types.Status (statusCode)
import qualified Omni.Test as Test
+import qualified System.Timeout as Timeout
main :: IO ()
main = Test.run test
@@ -74,6 +75,43 @@ test =
chatMessage result Test.@=? msg
]
+-- | HTTP request timeout in microseconds (60 seconds)
+httpTimeoutMicros :: Int
+httpTimeoutMicros = 60 * 1000000
+
+-- | Maximum number of retries for transient failures
+maxRetries :: Int
+maxRetries = 3
+
+-- | Initial backoff delay in microseconds (1 second)
+initialBackoffMicros :: Int
+initialBackoffMicros = 1000000
+
+-- | Retry an IO action with exponential backoff
+-- Retries on timeout, connection errors, and 5xx status codes
+retryWithBackoff :: Int -> Int -> IO (Either Text a) -> IO (Either Text a)
+retryWithBackoff retriesLeft backoff action
+ | retriesLeft <= 0 = action
+ | otherwise = do
+ result <- Timeout.timeout httpTimeoutMicros action
+ case result of
+ Nothing -> do
+ threadDelay backoff
+ retryWithBackoff (retriesLeft - 1) (backoff * 2) action
+ Just (Left err)
+ | isRetryable err -> do
+ threadDelay backoff
+ retryWithBackoff (retriesLeft - 1) (backoff * 2) action
+ Just r -> pure r
+ where
+ isRetryable err =
+ "HTTP error: 5"
+ `Text.isPrefixOf` err
+ || "connection"
+ `Text.isInfixOf` Text.toLower err
+ || "timeout"
+ `Text.isInfixOf` Text.toLower err
+
data Provider
= OpenRouter ProviderConfig
| Ollama ProviderConfig
@@ -330,20 +368,21 @@ chatOpenAI cfg tools messages = do
req = foldr addHeader baseReq (providerExtraHeaders cfg)
addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value
- response <- HTTP.httpLBS req
- let status = HTTP.getResponseStatusCode response
- respBody = HTTP.getResponseBody response
- cleanedBody = BL.dropWhile (\b -> b `elem` [0x0a, 0x0d, 0x20]) respBody
- if status >= 200 && status < 300
- then case Aeson.eitherDecode cleanedBody of
- Right resp ->
- case respChoices resp of
- (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp)))
- [] -> pure (Left "No choices in response")
- Left err -> do
- let bodyPreview = TE.decodeUtf8 (BL.toStrict (BL.take 500 cleanedBody))
- pure (Left ("Failed to parse response: " <> Text.pack err <> " | Body: " <> bodyPreview))
- else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict respBody)))
+ retryWithBackoff maxRetries initialBackoffMicros <| do
+ response <- HTTP.httpLBS req
+ let status = HTTP.getResponseStatusCode response
+ respBody = HTTP.getResponseBody response
+ cleanedBody = BL.dropWhile (\b -> b `elem` [0x0a, 0x0d, 0x20]) respBody
+ if status >= 200 && status < 300
+ then case Aeson.eitherDecode cleanedBody of
+ Right resp ->
+ case respChoices resp of
+ (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp)))
+ [] -> pure (Left "No choices in response")
+ Left err -> do
+ let bodyPreview = TE.decodeUtf8 (BL.toStrict (BL.take 500 cleanedBody))
+ pure (Left ("Failed to parse response: " <> Text.pack err <> " | Body: " <> bodyPreview))
+ else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict respBody)))
chatOllama :: ProviderConfig -> [ToolApi] -> [Message] -> IO (Either Text ChatResult)
chatOllama cfg tools messages = do
@@ -362,13 +401,14 @@ chatOllama cfg tools messages = do
<| HTTP.setRequestBodyLBS (Aeson.encode body)
<| req0
- response <- HTTP.httpLBS req
- let status = HTTP.getResponseStatusCode response
- if status >= 200 && status < 300
- then case Aeson.decode (HTTP.getResponseBody response) of
- Just resp -> parseOllamaResponse resp
- Nothing -> pure (Left ("Failed to parse Ollama response: " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response))))
- else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response))))
+ retryWithBackoff maxRetries initialBackoffMicros <| do
+ response <- HTTP.httpLBS req
+ let status = HTTP.getResponseStatusCode response
+ if status >= 200 && status < 300
+ then case Aeson.decode (HTTP.getResponseBody response) of
+ Just resp -> parseOllamaResponse resp
+ Nothing -> pure (Left ("Failed to parse Ollama response: " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response))))
+ else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response))))
parseOllamaResponse :: Aeson.Value -> IO (Either Text ChatResult)
parseOllamaResponse val =
@@ -423,7 +463,11 @@ chatStream (AmpCLI _) _tools _messages _onChunk = pure (Left "Streaming not impl
chatStreamOpenAI :: ProviderConfig -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult)
chatStreamOpenAI cfg tools messages onChunk = do
let url = Text.unpack (providerBaseUrl cfg) <> "/chat/completions"
- manager <- HTTPClient.newManager HTTPClientTLS.tlsManagerSettings
+ managerSettings =
+ HTTPClientTLS.tlsManagerSettings
+ { HTTPClient.managerResponseTimeout = HTTPClient.responseTimeoutMicro httpTimeoutMicros
+ }
+ manager <- HTTPClient.newManager managerSettings
req0 <- HTTP.parseRequest url
let body =
Aeson.object
@@ -443,15 +487,19 @@ chatStreamOpenAI cfg tools messages onChunk = do
req = foldr addHeader baseReq (providerExtraHeaders cfg)
addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value
- HTTPClient.withResponse req manager <| \response -> do
- let status = HTTPClient.responseStatus response
- code = statusCode status
- if code >= 200 && code < 300
- then processSSEStream (HTTPClient.responseBody response) onChunk
- else do
- bodyChunks <- readAllBody (HTTPClient.responseBody response)
- let errBody = TE.decodeUtf8 (BS.concat bodyChunks)
- pure (Left ("HTTP error: " <> tshow code <> " - " <> errBody))
+ result <-
+ try <| HTTPClient.withResponse req manager <| \response -> do
+ let status = HTTPClient.responseStatus response
+ code = statusCode status
+ if code >= 200 && code < 300
+ then processSSEStream (HTTPClient.responseBody response) onChunk
+ else do
+ bodyChunks <- readAllBody (HTTPClient.responseBody response)
+ let errBody = TE.decodeUtf8 (BS.concat bodyChunks)
+ pure (Left ("HTTP error: " <> tshow code <> " - " <> errBody))
+ case result of
+ Left (e :: SomeException) -> pure (Left ("Stream request failed: " <> tshow e))
+ Right r -> pure r
readAllBody :: IO BS.ByteString -> IO [BS.ByteString]
readAllBody readBody = go []