summaryrefslogtreecommitdiff
path: root/Omni
diff options
context:
space:
mode:
authorBen Sima <ben@bensima.com>2025-12-13 08:21:23 -0500
committerBen Sima <ben@bensima.com>2025-12-13 08:21:23 -0500
commit1c7b30005af27dcc3345f7dee0fe0404c3bc8c49 (patch)
tree25eb99ddfce74749264aa30dcf2992c207cc71a3 /Omni
parentf752330c9562b7a1bbdce15c05106a577daa2392 (diff)
fix: accumulate streaming tool call arguments across SSE chunks
OpenAI's SSE streaming sends tool calls incrementally - the first chunk has the id and function name, subsequent chunks contain argument fragments. Previously each chunk was treated as a complete tool call, causing invalid JSON arguments. - Add ToolCallDelta type with index for partial tool call data - Add StreamToolCallDelta chunk type - Track tool calls by index in IntMap accumulator - Merge argument fragments across chunks via mergeToolCallDelta - Build final ToolCall objects from accumulator when stream ends - Handle new StreamToolCallDelta in Engine.hs pattern match
Diffstat (limited to 'Omni')
-rw-r--r--Omni/Agent/Engine.hs184
-rw-r--r--Omni/Agent/Provider.hs257
-rw-r--r--Omni/Agent/Telegram.hs87
3 files changed, 523 insertions, 5 deletions
diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs
index f9b0355..f137ddb 100644
--- a/Omni/Agent/Engine.hs
+++ b/Omni/Agent/Engine.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoImplicitPrelude #-}
@@ -39,6 +40,7 @@ module Omni.Agent.Engine
chat,
runAgent,
runAgentWithProvider,
+ runAgentWithProviderStreaming,
main,
test,
)
@@ -50,6 +52,7 @@ import qualified Data.Aeson as Aeson
import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
+import Data.IORef (newIORef, writeIORef)
import qualified Data.Map.Strict as Map
import qualified Data.Text as Text
import qualified Data.Text.Encoding as TE
@@ -1003,3 +1006,184 @@ runAgentWithProvider engineCfg provider agentCfg userPrompt = do
Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s
_ -> False
isOldStrNotFoundProvider _ = False
+
+runAgentWithProviderStreaming ::
+ EngineConfig ->
+ Provider.Provider ->
+ AgentConfig ->
+ Text ->
+ (Text -> IO ()) ->
+ IO (Either Text AgentResult)
+runAgentWithProviderStreaming engineCfg provider agentCfg userPrompt onStreamChunk = do
+ let tools = agentTools agentCfg
+ toolApis = map encodeToolForProvider tools
+ toolMap = buildToolMap tools
+ systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg)
+ userMsg = providerMessage Provider.User userPrompt
+ initialMessages = [systemMsg, userMsg]
+
+ engineOnActivity engineCfg "Starting agent loop (Provider+Streaming)"
+ loopProviderStreaming provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0
+ where
+ maxIter = agentMaxIterations agentCfg
+ guardrails' = agentGuardrails agentCfg
+
+ providerMessage :: Provider.Role -> Text -> Provider.Message
+ providerMessage role content = Provider.Message role content Nothing Nothing
+
+ loopProviderStreaming ::
+ Provider.Provider ->
+ [Provider.ToolApi] ->
+ Map.Map Text Tool ->
+ [Provider.Message] ->
+ Int ->
+ Int ->
+ Int ->
+ Double ->
+ Map.Map Text Int ->
+ Int ->
+ Int ->
+ IO (Either Text AgentResult)
+ loopProviderStreaming prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures
+ | iteration >= maxIter = do
+ let errMsg = "Max iterations (" <> tshow maxIter <> ") reached"
+ engineOnError engineCfg errMsg
+ pure <| Left errMsg
+ | otherwise = do
+ let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures
+ case guardrailViolation of
+ Just (g, errMsg) -> do
+ engineOnGuardrail engineCfg g
+ pure <| Left errMsg
+ Nothing -> do
+ engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1)
+ hasToolCalls <- newIORef False
+ result <-
+ Provider.chatStream prov toolApis' msgs <| \case
+ Provider.StreamContent txt -> onStreamChunk txt
+ Provider.StreamToolCall _ -> writeIORef hasToolCalls True
+ Provider.StreamToolCallDelta _ -> writeIORef hasToolCalls True
+ Provider.StreamError err -> engineOnError engineCfg err
+ Provider.StreamDone _ -> pure ()
+ case result of
+ Left err -> do
+ engineOnError engineCfg err
+ pure (Left err)
+ Right chatRes -> do
+ let msg = Provider.chatMessage chatRes
+ tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes)
+ cost = case Provider.chatUsage chatRes +> Provider.usageCost of
+ Just actualCost -> actualCost * 100
+ Nothing -> estimateCost (getProviderModelStreaming prov) tokens
+ engineOnCost engineCfg tokens cost
+ let newTokens = totalTokens + tokens
+ newCost = totalCost + cost
+ let assistantText = Provider.msgContent msg
+ unless (Text.null assistantText)
+ <| engineOnAssistant engineCfg assistantText
+ case Provider.msgToolCalls msg of
+ Nothing
+ | Text.null (Provider.msgContent msg) && totalCalls > 0 -> do
+ engineOnActivity engineCfg "Empty response after tools, prompting for text"
+ let promptMsg = Provider.Message Provider.ToolRole "Please provide a response to the user." Nothing Nothing
+ newMsgs = msgs <> [msg, promptMsg]
+ loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) totalCalls newTokens newCost toolCallCounts testFailures editFailures
+ | otherwise -> do
+ engineOnActivity engineCfg "Agent completed"
+ engineOnComplete engineCfg
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = Provider.msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = newCost,
+ resultTotalTokens = newTokens
+ }
+ Just [] -> do
+ engineOnActivity engineCfg "Agent completed (empty tool calls)"
+ engineOnComplete engineCfg
+ pure
+ <| Right
+ <| AgentResult
+ { resultFinalMessage = Provider.msgContent msg,
+ resultToolCallCount = totalCalls,
+ resultIterations = iteration + 1,
+ resultTotalCost = newCost,
+ resultTotalTokens = newTokens
+ }
+ Just tcs -> do
+ (toolResults, newTestFailures, newEditFailures) <- executeToolCallsStreaming engineCfg toolMap tcs testFailures editFailures
+ let newMsgs = msgs <> [msg] <> toolResults
+ newCalls = totalCalls + length tcs
+ newToolCallCounts = updateToolCallCountsStreaming toolCallCounts tcs
+ loopProviderStreaming prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures
+
+ getProviderModelStreaming :: Provider.Provider -> Text
+ getProviderModelStreaming (Provider.OpenRouter cfg) = Provider.providerModel cfg
+ getProviderModelStreaming (Provider.Ollama cfg) = Provider.providerModel cfg
+ getProviderModelStreaming (Provider.AmpCLI _) = "amp"
+
+ updateToolCallCountsStreaming :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int
+ updateToolCallCountsStreaming =
+ foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m)
+
+ executeToolCallsStreaming :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int)
+ executeToolCallsStreaming eCfg tMap tcs initialTestFailures initialEditFailures = do
+ results <- traverse (executeSingleStreaming eCfg tMap) tcs
+ let msgs = map (\(m, _, _) -> m) results
+ testDeltas = map (\(_, t, _) -> t) results
+ editDeltas = map (\(_, _, e) -> e) results
+ totalTestFail = initialTestFailures + sum testDeltas
+ totalEditFail = initialEditFailures + sum editDeltas
+ pure (msgs, totalTestFail, totalEditFail)
+
+ executeSingleStreaming :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int)
+ executeSingleStreaming eCfg tMap tc = do
+ let name = Provider.fcName (Provider.tcFunction tc)
+ argsText = Provider.fcArguments (Provider.tcFunction tc)
+ callId = Provider.tcId tc
+ engineOnActivity eCfg <| "Executing tool: " <> name
+ engineOnToolCall eCfg name argsText
+ case Map.lookup name tMap of
+ Nothing -> do
+ let errMsg = "Tool not found: " <> name
+ engineOnToolResult eCfg name False errMsg
+ pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0)
+ Just tool -> do
+ case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of
+ Nothing -> do
+ let errMsg = "Invalid JSON arguments: " <> argsText
+ engineOnToolResult eCfg name False errMsg
+ pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0)
+ Just args -> do
+ resultValue <- toolExecute tool args
+ let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue))
+ isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText)
+ isTestFailure = isTestCall && isFailureResultStreaming resultValue
+ testDelta = if isTestFailure then 1 else 0
+ isEditFailure = name == "edit_file" && isOldStrNotFoundStreaming resultValue
+ editDelta = if isEditFailure then 1 else 0
+ engineOnToolResult eCfg name True resultText
+ pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta)
+
+ isFailureResultStreaming :: Aeson.Value -> Bool
+ isFailureResultStreaming (Aeson.Object obj) =
+ case KeyMap.lookup "exit_code" obj of
+ Just (Aeson.Number n) -> n /= 0
+ _ -> False
+ isFailureResultStreaming (Aeson.String s) =
+ "error"
+ `Text.isInfixOf` Text.toLower s
+ || "failed"
+ `Text.isInfixOf` Text.toLower s
+ || "FAILED"
+ `Text.isInfixOf` s
+ isFailureResultStreaming _ = False
+
+ isOldStrNotFoundStreaming :: Aeson.Value -> Bool
+ isOldStrNotFoundStreaming (Aeson.Object obj) =
+ case KeyMap.lookup "error" obj of
+ Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s
+ _ -> False
+ isOldStrNotFoundStreaming _ = False
diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs
index 2ad6ea8..fd6920d 100644
--- a/Omni/Agent/Provider.hs
+++ b/Omni/Agent/Provider.hs
@@ -12,6 +12,8 @@
-- : out omni-agent-provider
-- : dep aeson
-- : dep http-conduit
+-- : dep http-client-tls
+-- : dep http-types
-- : dep case-insensitive
module Omni.Agent.Provider
( Provider (..),
@@ -23,10 +25,12 @@ module Omni.Agent.Provider
FunctionCall (..),
Usage (..),
ToolApi (..),
+ StreamChunk (..),
defaultOpenRouter,
defaultOllama,
chat,
chatWithUsage,
+ chatStream,
main,
test,
)
@@ -36,11 +40,17 @@ import Alpha
import Data.Aeson ((.=))
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.KeyMap as KeyMap
+import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
+import Data.IORef (modifyIORef, newIORef, readIORef, writeIORef)
+import qualified Data.IntMap.Strict as IntMap
import qualified Data.Text as Text
import qualified Data.Text.Encoding as TE
+import qualified Network.HTTP.Client as HTTPClient
+import qualified Network.HTTP.Client.TLS as HTTPClientTLS
import qualified Network.HTTP.Simple as HTTP
+import Network.HTTP.Types.Status (statusCode)
import qualified Omni.Test as Test
main :: IO ()
@@ -388,3 +398,250 @@ parseOllamaResponse val =
Right msg -> pure (Right (ChatResult msg usageResult))
Left e -> pure (Left e)
_ -> pure (Left "Expected object response from Ollama")
+
+data StreamChunk
+ = StreamContent Text
+ | StreamToolCall ToolCall
+ | StreamToolCallDelta ToolCallDelta
+ | StreamDone ChatResult
+ | StreamError Text
+ deriving (Show, Eq)
+
+data ToolCallDelta = ToolCallDelta
+ { tcdIndex :: Int,
+ tcdId :: Maybe Text,
+ tcdFunctionName :: Maybe Text,
+ tcdFunctionArgs :: Maybe Text
+ }
+ deriving (Show, Eq)
+
+chatStream :: Provider -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult)
+chatStream (OpenRouter cfg) tools messages onChunk = chatStreamOpenAI cfg tools messages onChunk
+chatStream (Ollama _cfg) _tools _messages _onChunk = pure (Left "Streaming not implemented for Ollama")
+chatStream (AmpCLI _) _tools _messages _onChunk = pure (Left "Streaming not implemented for AmpCLI")
+
+chatStreamOpenAI :: ProviderConfig -> [ToolApi] -> [Message] -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult)
+chatStreamOpenAI cfg tools messages onChunk = do
+ let url = Text.unpack (providerBaseUrl cfg) <> "/chat/completions"
+ manager <- HTTPClient.newManager HTTPClientTLS.tlsManagerSettings
+ req0 <- HTTP.parseRequest url
+ let body =
+ Aeson.object
+ <| catMaybes
+ [ Just ("model" .= providerModel cfg),
+ Just ("messages" .= messages),
+ if null tools then Nothing else Just ("tools" .= tools),
+ Just ("stream" .= True),
+ Just ("usage" .= Aeson.object ["include" .= True])
+ ]
+ baseReq =
+ HTTP.setRequestMethod "POST"
+ <| HTTP.setRequestHeader "Content-Type" ["application/json"]
+ <| HTTP.setRequestHeader "Authorization" ["Bearer " <> TE.encodeUtf8 (providerApiKey cfg)]
+ <| HTTP.setRequestBodyLBS (Aeson.encode body)
+ <| req0
+ req = foldr addHeader baseReq (providerExtraHeaders cfg)
+ addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value
+
+ HTTPClient.withResponse req manager <| \response -> do
+ let status = HTTPClient.responseStatus response
+ code = statusCode status
+ if code >= 200 && code < 300
+ then processSSEStream (HTTPClient.responseBody response) onChunk
+ else do
+ bodyChunks <- readAllBody (HTTPClient.responseBody response)
+ let errBody = TE.decodeUtf8 (BS.concat bodyChunks)
+ pure (Left ("HTTP error: " <> tshow code <> " - " <> errBody))
+
+readAllBody :: IO BS.ByteString -> IO [BS.ByteString]
+readAllBody readBody = go []
+ where
+ go acc = do
+ chunk <- readBody
+ if BS.null chunk
+ then pure (reverse acc)
+ else go (chunk : acc)
+
+data ToolCallAccum = ToolCallAccum
+ { tcaId :: Text,
+ tcaName :: Text,
+ tcaArgs :: Text
+ }
+
+processSSEStream :: IO BS.ByteString -> (StreamChunk -> IO ()) -> IO (Either Text ChatResult)
+processSSEStream readBody onChunk = do
+ accumulatedContent <- newIORef ("" :: Text)
+ toolCallAccum <- newIORef (IntMap.empty :: IntMap.IntMap ToolCallAccum)
+ lastUsage <- newIORef (Nothing :: Maybe Usage)
+ buffer <- newIORef ("" :: Text)
+
+ let loop = do
+ chunk <- readBody
+ if BS.null chunk
+ then do
+ content <- readIORef accumulatedContent
+ accum <- readIORef toolCallAccum
+ usage <- readIORef lastUsage
+ let toolCalls = map accumToToolCall (IntMap.elems accum)
+ finalMsg =
+ Message
+ { msgRole = Assistant,
+ msgContent = content,
+ msgToolCalls = if null toolCalls then Nothing else Just toolCalls,
+ msgToolCallId = Nothing
+ }
+ pure (Right (ChatResult finalMsg usage))
+ else do
+ modifyIORef buffer (<> TE.decodeUtf8 chunk)
+ buf <- readIORef buffer
+ let (events, remaining) = parseSSEEvents buf
+ writeIORef buffer remaining
+ forM_ events <| \event -> do
+ case parseStreamEvent event of
+ Just (StreamContent txt) -> do
+ modifyIORef accumulatedContent (<> txt)
+ onChunk (StreamContent txt)
+ Just (StreamToolCallDelta delta) -> do
+ modifyIORef toolCallAccum (mergeToolCallDelta delta)
+ Just (StreamToolCall tc) -> do
+ modifyIORef toolCallAccum (mergeCompleteToolCall tc)
+ onChunk (StreamToolCall tc)
+ Just (StreamDone result) -> do
+ writeIORef lastUsage (chatUsage result)
+ Just (StreamError err) -> do
+ onChunk (StreamError err)
+ Nothing -> pure ()
+ loop
+
+ loop
+
+accumToToolCall :: ToolCallAccum -> ToolCall
+accumToToolCall acc =
+ ToolCall
+ { tcId = tcaId acc,
+ tcType = "function",
+ tcFunction = FunctionCall (tcaName acc) (tcaArgs acc)
+ }
+
+mergeToolCallDelta :: ToolCallDelta -> IntMap.IntMap ToolCallAccum -> IntMap.IntMap ToolCallAccum
+mergeToolCallDelta delta accum =
+ let idx = tcdIndex delta
+ existing = IntMap.lookup idx accum
+ updated = case existing of
+ Nothing ->
+ ToolCallAccum
+ { tcaId = fromMaybe "" (tcdId delta),
+ tcaName = fromMaybe "" (tcdFunctionName delta),
+ tcaArgs = fromMaybe "" (tcdFunctionArgs delta)
+ }
+ Just a ->
+ a
+ { tcaId = fromMaybe (tcaId a) (tcdId delta),
+ tcaName = fromMaybe (tcaName a) (tcdFunctionName delta),
+ tcaArgs = tcaArgs a <> fromMaybe "" (tcdFunctionArgs delta)
+ }
+ in IntMap.insert idx updated accum
+
+mergeCompleteToolCall :: ToolCall -> IntMap.IntMap ToolCallAccum -> IntMap.IntMap ToolCallAccum
+mergeCompleteToolCall tc accum =
+ let nextIdx = if IntMap.null accum then 0 else fst (IntMap.findMax accum) + 1
+ newAccum =
+ ToolCallAccum
+ { tcaId = tcId tc,
+ tcaName = fcName (tcFunction tc),
+ tcaArgs = fcArguments (tcFunction tc)
+ }
+ in IntMap.insert nextIdx newAccum accum
+
+parseSSEEvents :: Text -> ([Text], Text)
+parseSSEEvents input =
+ let lines' = Text.splitOn "\n" input
+ (events, remaining) = go [] [] lines'
+ in (events, remaining)
+ where
+ go events current [] = (reverse events, Text.intercalate "\n" (reverse current))
+ go events current (line : rest)
+ | Text.null line && not (null current) =
+ go (Text.intercalate "\n" (reverse current) : events) [] rest
+ | otherwise =
+ go events (line : current) rest
+
+parseStreamEvent :: Text -> Maybe StreamChunk
+parseStreamEvent eventText = do
+ let dataLines = filter ("data:" `Text.isPrefixOf`) (Text.lines eventText)
+ case dataLines of
+ [] -> Nothing
+ (dataLine : _) -> do
+ let jsonStr = Text.strip (Text.drop 5 dataLine)
+ if jsonStr == "[DONE]"
+ then Nothing
+ else case Aeson.decode (BL.fromStrict (TE.encodeUtf8 jsonStr)) of
+ Nothing -> Nothing
+ Just (Aeson.Object obj) -> parseStreamChunk obj
+ _ -> Nothing
+
+parseStreamChunk :: Aeson.Object -> Maybe StreamChunk
+parseStreamChunk obj = do
+ case KeyMap.lookup "error" obj of
+ Just (Aeson.Object errObj) -> do
+ let errMsg = case KeyMap.lookup "message" errObj of
+ Just (Aeson.String m) -> m
+ _ -> "Unknown error"
+ Just (StreamError errMsg)
+ _ -> do
+ case KeyMap.lookup "choices" obj of
+ Just (Aeson.Array choices) | not (null choices) -> do
+ case toList choices of
+ (Aeson.Object choice : _) -> do
+ case KeyMap.lookup "delta" choice of
+ Just (Aeson.Object delta) -> do
+ let contentChunk = case KeyMap.lookup "content" delta of
+ Just (Aeson.String c) | not (Text.null c) -> Just (StreamContent c)
+ _ -> Nothing
+ toolCallChunk = case KeyMap.lookup "tool_calls" delta of
+ Just (Aeson.Array tcs)
+ | not (null tcs) ->
+ parseToolCallDelta (toList tcs)
+ _ -> Nothing
+ contentChunk <|> toolCallChunk
+ _ -> Nothing
+ _ -> Nothing
+ _ -> do
+ case KeyMap.lookup "usage" obj of
+ Just usageVal -> case Aeson.fromJSON usageVal of
+ Aeson.Success usage -> Just (StreamDone (ChatResult (Message Assistant "" Nothing Nothing) (Just usage)))
+ _ -> Nothing
+ _ -> Nothing
+
+parseToolCallDelta :: [Aeson.Value] -> Maybe StreamChunk
+parseToolCallDelta [] = Nothing
+parseToolCallDelta (Aeson.Object tcObj : _) = do
+ idx <- case KeyMap.lookup "index" tcObj of
+ Just (Aeson.Number n) -> Just (round n)
+ _ -> Nothing
+ let tcId' = case KeyMap.lookup "id" tcObj of
+ Just (Aeson.String s) -> Just s
+ _ -> Nothing
+ funcObj = case KeyMap.lookup "function" tcObj of
+ Just (Aeson.Object f) -> Just f
+ _ -> Nothing
+ funcName = case funcObj of
+ Just f -> case KeyMap.lookup "name" f of
+ Just (Aeson.String s) -> Just s
+ _ -> Nothing
+ Nothing -> Nothing
+ funcArgs = case funcObj of
+ Just f -> case KeyMap.lookup "arguments" f of
+ Just (Aeson.String s) -> Just s
+ _ -> Nothing
+ Nothing -> Nothing
+ Just
+ ( StreamToolCallDelta
+ ToolCallDelta
+ { tcdIndex = idx,
+ tcdId = tcId',
+ tcdFunctionName = funcName,
+ tcdFunctionArgs = funcArgs
+ }
+ )
+parseToolCallDelta _ = Nothing
diff --git a/Omni/Agent/Telegram.hs b/Omni/Agent/Telegram.hs
index ee6784b..d6a8a30 100644
--- a/Omni/Agent/Telegram.hs
+++ b/Omni/Agent/Telegram.hs
@@ -30,6 +30,8 @@ module Omni.Agent.Telegram
-- * Telegram API
getUpdates,
sendMessage,
+ sendMessageReturningId,
+ editMessage,
sendTypingAction,
-- * Media (re-exported from Media)
@@ -67,8 +69,9 @@ import Data.Aeson ((.=))
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.ByteString.Lazy as BL
+import Data.IORef (modifyIORef, newIORef, readIORef, writeIORef)
import qualified Data.Text as Text
-import Data.Time (getCurrentTime, utcToLocalTime)
+import Data.Time (UTCTime (..), getCurrentTime, utcToLocalTime)
import Data.Time.Format (defaultTimeLocale, formatTime)
import Data.Time.LocalTime (getCurrentTimeZone)
import qualified Network.HTTP.Client as HTTPClient
@@ -221,6 +224,11 @@ getBotUsername cfg = do
sendMessage :: Types.TelegramConfig -> Int -> Text -> IO ()
sendMessage cfg chatId text = do
+ _ <- sendMessageReturningId cfg chatId text
+ pure ()
+
+sendMessageReturningId :: Types.TelegramConfig -> Int -> Text -> IO (Maybe Int)
+sendMessageReturningId cfg chatId text = do
let url =
Text.unpack (Types.tgApiBaseUrl cfg)
<> "/bot"
@@ -237,6 +245,38 @@ sendMessage cfg chatId text = do
<| HTTP.setRequestHeader "Content-Type" ["application/json"]
<| HTTP.setRequestBodyLBS (Aeson.encode body)
<| req0
+ result <- try @SomeException (HTTP.httpLBS req)
+ case result of
+ Left _ -> pure Nothing
+ Right response -> do
+ let respBody = HTTP.getResponseBody response
+ case Aeson.decode respBody of
+ Just (Aeson.Object obj) -> case KeyMap.lookup "result" obj of
+ Just (Aeson.Object msgObj) -> case KeyMap.lookup "message_id" msgObj of
+ Just (Aeson.Number n) -> pure (Just (round n))
+ _ -> pure Nothing
+ _ -> pure Nothing
+ _ -> pure Nothing
+
+editMessage :: Types.TelegramConfig -> Int -> Int -> Text -> IO ()
+editMessage cfg chatId messageId text = do
+ let url =
+ Text.unpack (Types.tgApiBaseUrl cfg)
+ <> "/bot"
+ <> Text.unpack (Types.tgBotToken cfg)
+ <> "/editMessageText"
+ body =
+ Aeson.object
+ [ "chat_id" .= chatId,
+ "message_id" .= messageId,
+ "text" .= text
+ ]
+ req0 <- HTTP.parseRequest url
+ let req =
+ HTTP.setRequestMethod "POST"
+ <| HTTP.setRequestHeader "Content-Type" ["application/json"]
+ <| HTTP.setRequestBodyLBS (Aeson.encode body)
+ <| req0
_ <- try @SomeException (HTTP.httpLBS req)
pure ()
@@ -540,12 +580,40 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe
}
}
- result <- Engine.runAgentWithProvider engineCfg provider agentCfg userMessage
+ streamState <- newIORef StreamInit
+ lastUpdate <- newIORef (0 :: Int)
+ accumulatedText <- newIORef ("" :: Text)
+
+ let onStreamChunk txt = do
+ modifyIORef accumulatedText (<> txt)
+ streamSt <- readIORef streamState
+ currentText <- readIORef accumulatedText
+ currentTime <- getCurrentTime
+ let nowMs = round (utctDayTime currentTime * 1000) :: Int
+ lastTime <- readIORef lastUpdate
+
+ case streamSt of
+ StreamInit | Text.length currentText >= 20 -> do
+ maybeId <- sendMessageReturningId tgConfig chatId (currentText <> "...")
+ case maybeId of
+ Just msgId -> do
+ writeIORef streamState (StreamActive msgId)
+ writeIORef lastUpdate nowMs
+ Nothing -> pure ()
+ StreamActive msgId | nowMs - lastTime > 400 -> do
+ editMessage tgConfig chatId msgId (currentText <> "...")
+ writeIORef lastUpdate nowMs
+ _ -> pure ()
+
+ result <- Engine.runAgentWithProviderStreaming engineCfg provider agentCfg userMessage onStreamChunk
case result of
Left err -> do
putText <| "Agent error: " <> err
- sendMessage tgConfig chatId "Sorry, I encountered an error. Please try again."
+ streamSt <- readIORef streamState
+ case streamSt of
+ StreamActive msgId -> editMessage tgConfig chatId msgId ("error: " <> err)
+ _ -> sendMessage tgConfig chatId "Sorry, I encountered an error. Please try again."
Right agentResult -> do
let response = Engine.resultFinalMessage agentResult
putText <| "Response text: " <> Text.take 200 response
@@ -558,9 +626,15 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe
then putText "Agent chose not to respond (group chat)"
else do
putText "Warning: empty response from agent"
- sendMessage tgConfig chatId "hmm, i don't have a response for that"
+ streamSt <- readIORef streamState
+ case streamSt of
+ StreamActive msgId -> editMessage tgConfig chatId msgId "hmm, i don't have a response for that"
+ _ -> sendMessage tgConfig chatId "hmm, i don't have a response for that"
else do
- sendMessage tgConfig chatId response
+ streamSt <- readIORef streamState
+ case streamSt of
+ StreamActive msgId -> editMessage tgConfig chatId msgId response
+ _ -> sendMessage tgConfig chatId response
checkAndSummarize (Types.tgOpenRouterApiKey tgConfig) uid chatId
putText
<| "Responded to "
@@ -569,6 +643,9 @@ processEngagedMessage tgConfig provider engineCfg msg uid userName chatId userMe
<> tshow (Engine.resultTotalCost agentResult)
<> " cents)"
+data StreamState = StreamInit | StreamActive Int
+ deriving (Show, Eq)
+
maxConversationTokens :: Int
maxConversationTokens = 4000