diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-13 08:21:23 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-13 08:21:23 -0500 |
| commit | 1c7b30005af27dcc3345f7dee0fe0404c3bc8c49 (patch) | |
| tree | 25eb99ddfce74749264aa30dcf2992c207cc71a3 | |
| parent | f752330c9562b7a1bbdce15c05106a577daa2392 (diff) | |
fix: accumulate streaming tool call arguments across SSE chunks
OpenAI's SSE streaming sends tool calls incrementally - the first chunk
has the id and function name, subsequent chunks contain argument fragments.
Previously each chunk was treated as a complete tool call, causing invalid
JSON arguments.
- Add ToolCallDelta type with index for partial tool call data
- Add StreamToolCallDelta chunk type
- Track tool calls by index in IntMap accumulator
- Merge argument fragments across chunks via mergeToolCallDelta
- Build final ToolCall objects from accumulator when stream ends
- Handle new StreamToolCallDelta in Engine.hs pattern match
| -rw-r--r-- | Omni/Agent/Engine.hs | 184 | ||||
| -rw-r--r-- | Omni/Agent/Provider.hs | 257 | ||||
| -rw-r--r-- | Omni/Agent/Telegram.hs | 87 |
3 files changed, 523 insertions, 5 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs index f9b0355..f137ddb 100644 --- a/Omni/Agent/Engine.hs +++ b/Omni/Agent/Engine.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE NoImplicitPrelude #-} @@ -39,6 +40,7 @@ module Omni.Agent.Engine chat, runAgent, runAgentWithProvider, + runAgentWithProviderStreaming, main, test, ) @@ -50,6 +52,7 @@ import qualified Data.Aeson as Aeson import qualified Data.Aeson.KeyMap as KeyMap import qualified Data.ByteString.Lazy as BL import qualified Data.CaseInsensitive as CI +import Data.IORef (newIORef, writeIORef) import qualified Data.Map.Strict as Map import qualified Data.Text as Text import qualified Data.Text.Encoding as TE @@ -1003,3 +1006,184 @@ runAgentWithProvider engineCfg provider agentCfg userPrompt = do Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s _ -> False isOldStrNotFoundProvider _ = False + +runAgentWithProviderStreaming :: + EngineConfig -> + Provider.Provider -> + AgentConfig -> + Text -> + (Text -> IO ()) -> + IO (Either Text AgentResult) +runAgentWithProviderStreaming engineCfg provider agentCfg userPrompt onStreamChunk = do + let tools = agentTools agentCfg + toolApis = map encodeToolForProvider tools + toolMap = buildToolMap tools + systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg) + userMsg = providerMessage Provider.User userPrompt + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop (Provider+Streaming)" + loopProviderStreaming provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 + where + maxIter = agentMaxIterations agentCfg + guardrails' = agentGuardrails agentCfg + + providerMessage :: Provider.Role -> Text -> Provider.Message + providerMessage role content = Provider.Message role content Nothing Nothing + + loopProviderStreaming :: + Provider.Provider -> + [Provider.ToolApi] -> + Map.Map Text Tool -> + [Provider.Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + Int -> + IO (Either Text AgentResult) + loopProviderStreaming prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures + | iteration >= maxIter = do + let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" + engineOnError engineCfg errMsg + pure <| Left errMsg + | otherwise = do + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + hasToolCalls <- newIORef False + result <- + Provider.chatStream prov toolApis' msgs <| \case + Provider.StreamContent txt -> onStreamChunk txt + Provider.StreamToolCall _ -> writeIORef hasToolCalls True + Provider.StreamToolCallDelta _ -> writeIORef hasToolCalls True + Provider.StreamError err -> engineOnError engineCfg err + Provider.StreamDone _ -> pure () + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = Provider.chatMessage chatRes + tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes) + cost = case Provider.chatUsage chatRes +> Provider.usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (getProviderModelStreaming prov) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = Provider.msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case Provider.msgToolCalls msg of + Nothing + | Text.null (Provider.msgContent msg) && totalCalls > 0 -> do + engineOnActivity engineCfg "Empty response after tools, prompting for text" + let promptMsg = Provider.Message Provider.ToolRole "Please provide a response to the user." Nothing Nothing + newMsgs = msgs <> [msg, promptMsg] + loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures + | otherwise -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures, newEditFailures) <- executeToolCallsStreaming engineCfg toolMap tcs testFailures editFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateToolCallCountsStreaming toolCallCounts tcs + loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures + + getProviderModelStreaming :: Provider.Provider -> Text + getProviderModelStreaming (Provider.OpenRouter cfg) = Provider.providerModel cfg + getProviderModelStreaming (Provider.Ollama cfg) = Provider.providerModel cfg + getProviderModelStreaming (Provider.AmpCLI _) = "amp" + + updateToolCallCountsStreaming :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int + updateToolCallCountsStreaming = + foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m) + + executeToolCallsStreaming :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int) + executeToolCallsStreaming eCfg tMap tcs initialTestFailures initialEditFailures = do + results <- traverse (executeSingleStreaming eCfg tMap) tcs + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFail = initialTestFailures + sum testDeltas + totalEditFail = initialEditFailures + sum editDeltas + pure (msgs, totalTestFail, totalEditFail) + + executeSingleStreaming :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int) + executeSingleStreaming eCfg tMap tc = do + let name = Provider.fcName (Provider.tcFunction tc) + argsText = Provider.fcArguments (Provider.tcFunction tc) + callId = Provider.tcId tc + engineOnActivity eCfg <| "Executing tool: " <> name + engineOnToolCall eCfg name argsText + case Map.lookup name tMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResultStreaming resultValue + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundStreaming resultValue + editDelta = if isEditFailure then 1 else 0 + engineOnToolResult eCfg name True resultText + pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta) + + isFailureResultStreaming :: Aeson.Value -> Bool + isFailureResultStreaming (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResultStreaming (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResultStreaming _ = False + + isOldStrNotFoundStreaming :: Aeson.Value -> Bool + isOldStrNotFoundStreaming (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundStreaming _ = False diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs index 2ad6ea8..fd6920d 100644 --- a/Omni/Agent/Provider.hs +++ b/Omni/Agent/Provider.hs @@ -12,6 +12,8 @@ -- : out omni-agent-provider -- : dep aeson -- : dep http-conduit +-- : dep http-client-tls +-- : dep http-types -- : dep case-insensitive module Omni.Agent.Provider ( Provider (..), @@ -23,10 +25,12 @@ module Omni.Agent.Provider FunctionCall (..), Usage (..), ToolApi (..), + StreamChunk (..), defaultOpenRouter, defaultOllama, chat, chatWithUsage, + chatStream, main, test, ) @@ -36,11 +40,17 @@ import Alpha import Data.Aeson ((.=)) import qualified Data.Aeson as Aeson import qualified Data.Aeson.KeyMap as KeyMap +import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL import qualified Data.CaseInsensitive as CI +import Data.IORef (modifyIORef, newIORef, readIORef, writeIORef) +import qualified Data.IntMap.Strict as IntMap import qualified Data.Text as Text import qualified Data.Text.Encoding as TE +import qualified Network.HTTP.Client as HTTPClient +import qualified Network.HTTP.Client.TLS as HTTPClientTLS import qualified Network.HTTP.Simple as HTTP +import Network.HTTP.Types.Status (statusCode) import qualified Omni.Test as Test main :: IO () @@ -388,3 +398,250 @@ parseOllamaResponse val = Right msg -> pure (Right (ChatResult msg usageResult)) Left e -> pure (Left e) _ -> pure (Left "Expected object response from Ollama") + +data StreamChunk + = StreamContent Text + | StreamToolCall ToolCall + | StreamToolCallDelta ToolCallDelta + | StreamDone ChatResult + | StreamError Text + deriving (Show, Eq) + +data ToolCallDelta = ToolCallDelta + { tcdIndex :: Int, + tcdId :: Maybe Text, + tcdFunctionName :: Maybe Text, + tcdFunctionArgs :: Maybe Text + } + deriving (Show, Eq) + +chatStream :: Provider -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult) +chatStream (OpenRouter cfg) tools messages onChunk = chatStreamOpenAI cfg tools messages onChunk +chatStream (Ollama _cfg) _tools _messages _onChunk = pure (Left "Streaming not implemented for Ollama") +chatStream (AmpCLI _) _tools _messages _onChunk = pure (Left "Streaming not implemented for AmpCLI") + +chatStreamOpenAI :: ProviderConfig -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult) +chatStreamOpenAI cfg tools messages onChunk = do + let url = Text.unpack (providerBaseUrl cfg) <> "/chat/completions" + manager <- HTTPClient.newManager HTTPClientTLS.tlsManagerSettings + req0 <- HTTP.parseRequest url + let body = + Aeson.object + <| catMaybes + [ Just ("model" .= providerModel cfg), + Just ("messages" .= messages), + if null tools then Nothing else Just ("tools" .= tools), + Just ("stream" .= True), + Just ("usage" .= Aeson.object ["include" .= True]) + ] + baseReq = + HTTP.setRequestMethod "POST" + <| HTTP.setRequestHeader "Content-Type" ["application/json"] + <| HTTP.setRequestHeader "Authorization" ["Bearer " <> TE.encodeUtf8 (providerApiKey cfg)] + <| HTTP.setRequestBodyLBS (Aeson.encode body) + <| req0 + req = foldr addHeader baseReq (providerExtraHeaders cfg) + addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value + + HTTPClient.withResponse req manager <| \response -> do + let status = HTTPClient.responseStatus response + code = statusCode status + if code >= 200 && code < 300 + then processSSEStream (HTTPClient.responseBody response) onChunk + else do + bodyChunks <- readAllBody (HTTPClient.responseBody response) + let errBody = TE.decodeUtf8 (BS.concat bodyChunks) + pure (Left ("HTTP error: " <> tshow code <> " - " <> errBody)) + +readAllBody :: IO BS.ByteString -> IO [BS.ByteString] +readAllBody readBody = go [] + where + go acc = do + chunk <- readBody + if BS.null chunk + then pure (reverse acc) + else go (chunk : acc) + +data ToolCallAccum = ToolCallAccum + { tcaId :: Text, + tcaName :: Text, + tcaArgs :: Text + } + +processSSEStream :: IO BS.ByteString -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult) +processSSEStream readBody onChunk = do + accumulatedContent <- newIORef ("" :: Text) + toolCallAccum <- newIORef (IntMap.empty :: IntMap.IntMap ToolCallAccum) + lastUsage <- newIORef (Nothing :: Maybe Usage) + buffer <- newIORef ("" :: Text) + + let loop = do + chunk <- readBody + if BS.null chunk + then do + content <- readIORef accumulatedContent + accum <- readIORef toolCallAccum + usage <- readIORef lastUsage + let toolCalls = map accumToToolCall (IntMap.elems accum) + finalMsg = + Message + { msgRole = Assistant, + msgContent = content, + msgToolCalls = if null toolCalls then Nothing else Just toolCalls, + msgToolCallId = Nothing + } + pure (Right (ChatResult finalMsg usage)) + else do + modifyIORef buffer (<> TE.decodeUtf8 chunk) + buf <- readIORef buffer + let (events, remaining) = parseSSEEvents buf + writeIORef buffer remaining + forM_ events <| \event -> do + case parseStreamEvent event of + Just (StreamContent txt) -> do + modifyIORef accumulatedContent (<> txt) + onChunk (StreamContent txt) + Just (StreamToolCallDelta delta) -> do + modifyIORef toolCallAccum (mergeToolCallDelta delta) + Just (StreamToolCall tc) -> do + modifyIORef toolCallAccum (mergeCompleteToolCall tc) + onChunk (StreamToolCall tc) + Just (StreamDone result) -> do + writeIORef lastUsage (chatUsage result) + Just (StreamError err) -> do + onChunk (StreamError err) + Nothing -> pure () + loop + + loop + +accumToToolCall :: ToolCallAccum -> ToolCall +accumToToolCall acc = + ToolCall + { tcId = tcaId acc, + tcType = "function", + tcFunction = FunctionCall (tcaName acc) (tcaArgs acc) + } + +mergeToolCallDelta :: ToolCallDelta -> IntMap.IntMap ToolCallAccum -> IntMap.IntMap ToolCallAccum +mergeToolCallDelta delta accum = + let idx = tcdIndex delta + existing = IntMap.lookup idx accum + updated = case existing of + Nothing -> + ToolCallAccum + { tcaId = fromMaybe "" (tcdId delta), + tcaName = fromMaybe "" (tcdFunctionName delta), + tcaArgs = fromMaybe "" (tcdFunctionArgs delta) + } + Just a -> + a + { tcaId = fromMaybe (tcaId a) (tcdId delta), + tcaName = fromMaybe (tcaName a) (tcdFunctionName delta), + tcaArgs = tcaArgs a <> fromMaybe "" (tcdFunctionArgs delta) + } + in IntMap.insert idx updated accum + +mergeCompleteToolCall :: ToolCall -> IntMap.IntMap ToolCallAccum -> IntMap.IntMap ToolCallAccum +mergeCompleteToolCall tc accum = + let nextIdx = if IntMap.null accum then 0 else fst (IntMap.findMax accum) + 1 + newAccum = + ToolCallAccum + { tcaId = tcId tc, + tcaName = fcName (tcFunction tc), + tcaArgs = fcArguments (tcFunction tc) + } + in IntMap.insert nextIdx newAccum accum + +parseSSEEvents :: Text -> ([Text], Text) +parseSSEEvents input = + let lines' = Text.splitOn "\n" input + (events, remaining) = go [] [] lines' + in (events, remaining) + where + go events current [] = (reverse events, Text.intercalate "\n" (reverse current)) + go events current (line : rest) + | Text.null line && not (null current) = + go (Text.intercalate "\n" (reverse current) : events) [] rest + | otherwise = + go events (line : current) rest + +parseStreamEvent :: Text -> Maybe StreamChunk +parseStreamEvent eventText = do + let dataLines = filter ("data:" `Text.isPrefixOf`) (Text.lines eventText) + case dataLines of + [] -> Nothing + (dataLine : _) -> do + let jsonStr = Text.strip (Text.drop 5 dataLine) + if jsonStr == "[DONE]" + then Nothing + else case Aeson.decode (BL.fromStrict (TE.encodeUtf8 jsonStr)) of + Nothing -> Nothing + Just (Aeson.Object obj) -> parseStreamChunk obj + _ -> Nothing + +parseStreamChunk :: Aeson.Object -> Maybe StreamChunk +parseStreamChunk obj = do + case KeyMap.lookup "error" obj of + Just (Aeson.Object errObj) -> do + let errMsg = case KeyMap.lookup "message" errObj of + Just (Aeson.String m) -> m + _ -> "Unknown error" + Just (StreamError errMsg) + _ -> do + case KeyMap.lookup "choices" obj of + Just (Aeson.Array choices) | not (null choices) -> do + case toList choices of + (Aeson.Object choice : _) -> do + case KeyMap.lookup "delta" choice of + Just (Aeson.Object delta) -> do + let contentChunk = case KeyMap.lookup "content" delta of + Just (Aeson.String c) | not (Text.null c) -> Just (StreamContent c) + _ -> Nothing + toolCallChunk = case KeyMap.lookup "tool_calls" delta of + Just (Aeson.Array tcs) + | not (null tcs) -> + parseToolCallDelta (toList tcs) + _ -> Nothing + contentChunk <|> toolCallChunk + _ -> Nothing + _ -> Nothing + _ -> do + case KeyMap.lookup "usage" obj of + Just usageVal -> case Aeson.fromJSON usageVal of + Aeson.Success usage -> Just (StreamDone (ChatResult (Message Assistant "" Nothing Nothing) (Just usage))) + _ -> Nothing + _ -> Nothing + +parseToolCallDelta :: [Aeson.Value] -> Maybe StreamChunk +parseToolCallDelta [] = Nothing +parseToolCallDelta (Aeson.Object tcObj : _) = do + idx <- case KeyMap.lookup "index" tcObj of + Just (Aeson.Number n) -> Just (round n) + _ -> Nothing + let tcId' = case KeyMap.lookup "id" tcObj of + Just (Aeson.String s) -> Just s + _ -> Nothing + funcObj = case KeyMap.lookup "function" tcObj of + Just (Aeson.Object f) -> Just f + _ -> Nothing + funcName = case funcObj of + Just f -> case KeyMap.lookup "name" f of + Just (Aeson.String s) -> Just s + _ -> Nothing + Nothing -> Nothing + funcArgs = case funcObj of + Just f -> case KeyMap.lookup "arguments" f of + Just (Aeson.String s) -> Just s + _ -> Nothing + Nothing -> Nothing + Just + ( StreamToolCallDelta + ToolCallDelta + { tcdIndex = idx, + tcdId = tcId', + tcdFunctionName = funcName, + tcdFunctionArgs = funcArgs + } + ) +parseToolCallDelta _ = Nothing diff --git a/Omni/Agent/Telegram.hs b/Omni/Agent/Telegram.hs index ee6784b..d6a8a30 100644 --- a/Omni/Agent/Telegram.hs +++ b/Omni/Agent/Telegram.hs @@ -30,6 +30,8 @@ module Omni.Agent.Telegram -- * Telegram API getUpdates, sendMessage, + sendMessageReturningId, + editMessage, sendTypingAction, -- * Media (re-exported from Media) @@ -67,8 +69,9 @@ import Data.Aeson ((.=)) import qualified Data.Aeson as Aeson import qualified Data.Aeson.KeyMap as KeyMap import qualified Data.ByteString.Lazy as BL +import Data.IORef (modifyIORef, newIORef, readIORef, writeIORef) import qualified Data.Text as Text -import Data.Time (getCurrentTime, utcToLocalTime) +import Data.Time (UTCTime (..), getCurrentTime, utcToLocalTime) import Data.Time.Format (defaultTimeLocale, formatTime) import Data.Time.LocalTime (getCurrentTimeZone) import qualified Network.HTTP.Client as HTTPClient @@ -221,6 +224,11 @@ getBotUsername cfg = do sendMessage :: Types.TelegramConfig -> Int -> Text -> IO () sendMessage cfg chatId text = do + _ <- sendMessageReturningId cfg chatId text + pure () + +sendMessageReturningId :: Types.TelegramConfig -> Int -> Text -> IO (Maybe Int) +sendMessageReturningId cfg chatId text = do let url = Text.unpack (Types.tgApiBaseUrl cfg) <> "/bot" @@ -237,6 +245,38 @@ sendMessage cfg chatId text = do <| HTTP.setRequestHeader "Content-Type" ["application/json"] <| HTTP.setRequestBodyLBS (Aeson.encode body) <| req0 + result <- try @SomeException (HTTP.httpLBS req) + case result of + Left _ -> pure Nothing + Right response -> do + let respBody = HTTP.getResponseBody response + case Aeson.decode respBody of + Just (Aeson.Object obj) -> case KeyMap.lookup "result" obj of + Just (Aeson.Object msgObj) -> case KeyMap.lookup "message_id" msgObj of + Just (Aeson.Number n) -> pure (Just (round n)) + _ -> pure Nothing + _ -> pure Nothing + _ -> pure Nothing + +editMessage :: Types.TelegramConfig -> Int -> Int -> Text -> IO () +editMessage cfg chatId messageId text = do + let url = + Text.unpack (Types.tgApiBaseUrl cfg) + <> "/bot" + <> Text.unpack (Types.tgBotToken cfg) + <> "/editMessageText" + body = + Aeson.object + [ "chat_id" .= chatId, + "message_id" .= messageId, + "text" .= text + ] + req0 <- HTTP.parseRequest url + let req = + HTTP.setRequestMethod "POST" + <| HTTP.setRequestHeader "Content-Type" ["application/json"] + <| HTTP.setRequestBodyLBS (Aeson.encode body) + <| req0 _ <- try @SomeException (HTTP.httpLBS req) pure () @@ -540,12 +580,40 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe } } - result <- Engine.runAgentWithProvider engineCfg provider agentCfg userMessage + streamState <- newIORef StreamInit + lastUpdate <- newIORef (0 :: Int) + accumulatedText <- newIORef ("" :: Text) + + let onStreamChunk txt = do + modifyIORef accumulatedText (<> txt) + streamSt <- readIORef streamState + currentText <- readIORef accumulatedText + currentTime <- getCurrentTime + let nowMs = round (utctDayTime currentTime * 1000) :: Int + lastTime <- readIORef lastUpdate + + case streamSt of + StreamInit | Text.length currentText >= 20 -> do + maybeId <- sendMessageReturningId tgConfig chatId (currentText <> "...") + case maybeId of + Just msgId -> do + writeIORef streamState (StreamActive msgId) + writeIORef lastUpdate nowMs + Nothing -> pure () + StreamActive msgId | nowMs - lastTime > 400 -> do + editMessage tgConfig chatId msgId (currentText <> "...") + writeIORef lastUpdate nowMs + _ -> pure () + + result <- Engine.runAgentWithProviderStreaming engineCfg provider agentCfg userMessage onStreamChunk case result of Left err -> do putText <| "Agent error: " <> err - sendMessage tgConfig chatId "Sorry, I encountered an error. Please try again." + streamSt <- readIORef streamState + case streamSt of + StreamActive msgId -> editMessage tgConfig chatId msgId ("error: " <> err) + _ -> sendMessage tgConfig chatId "Sorry, I encountered an error. Please try again." Right agentResult -> do let response = Engine.resultFinalMessage agentResult putText <| "Response text: " <> Text.take 200 response @@ -558,9 +626,15 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe then putText "Agent chose not to respond (group chat)" else do putText "Warning: empty response from agent" - sendMessage tgConfig chatId "hmm, i don't have a response for that" + streamSt <- readIORef streamState + case streamSt of + StreamActive msgId -> editMessage tgConfig chatId msgId "hmm, i don't have a response for that" + _ -> sendMessage tgConfig chatId "hmm, i don't have a response for that" else do - sendMessage tgConfig chatId response + streamSt <- readIORef streamState + case streamSt of + StreamActive msgId -> editMessage tgConfig chatId msgId response + _ -> sendMessage tgConfig chatId response checkAndSummarize (Types.tgOpenRouterApiKey tgConfig) uid chatId putText <| "Responded to " @@ -569,6 +643,9 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe <> tshow (Engine.resultTotalCost agentResult) <> " cents)" +data StreamState = StreamInit | StreamActive Int + deriving (Show, Eq) + maxConversationTokens :: Int maxConversationTokens = 4000 |
