diff options
Diffstat (limited to 'Omni')
| -rw-r--r-- | Omni/Jr/Web.hs | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/Omni/Jr/Web.hs b/Omni/Jr/Web.hs index 86647d4..2be8ea1 100644 --- a/Omni/Jr/Web.hs +++ b/Omni/Jr/Web.hs @@ -17,6 +17,8 @@ module Omni.Jr.Web where import Alpha +import qualified Control.Concurrent as Concurrent +import qualified Data.Aeson as Aeson import qualified Data.List as List import qualified Data.Text as Text import qualified Data.Text.Lazy as LazyText @@ -33,6 +35,7 @@ import qualified Omni.Jr.Web.Style as Style import qualified Omni.Task.Core as TaskCore import Servant import qualified Servant.HTML.Lucid as Lucid +import qualified Servant.Types.SourceT as Source import qualified System.Exit as Exit import qualified System.Process as Process import Web.FormUrlEncoded (FromForm (..), lookupUnique, parseUnique) @@ -242,6 +245,7 @@ type API = :> Get '[Lucid.HTML] TaskListPartial :<|> "partials" :> "task" :> Capture "id" Text :> "metrics" :> Get '[Lucid.HTML] TaskMetricsPartial :<|> "partials" :> "task" :> Capture "id" Text :> "events" :> QueryParam "since" Int :> Get '[Lucid.HTML] AgentEventsPartial + :<|> "tasks" :> Capture "id" Text :> "events" :> "stream" :> StreamGet NoFraming SSE (SourceIO ByteString) data CSS @@ -251,6 +255,14 @@ instance Accept CSS where instance MimeRender CSS LazyText.Text where mimeRender _ = LazyText.encodeUtf8 +data SSE + +instance Accept SSE where + contentType _ = "text/event-stream" + +instance MimeRender SSE ByteString where + mimeRender _ = identity + data HomePage = HomePage TaskCore.TaskStats [TaskCore.Task] [TaskCore.Task] Bool TaskCore.AggregatedMetrics TimeRange UTCTime data ReadyQueuePage = ReadyQueuePage [TaskCore.Task] SortOrder UTCTime @@ -2547,6 +2559,81 @@ instance Lucid.ToHtml AgentEventsPartial where traverse_ (renderAgentEvent now) events agentLogScrollScript +-- | Stream agent events as SSE +streamAgentEvents :: Text -> Text -> IO (SourceIO ByteString) +streamAgentEvents tid sid = do + -- Get existing events first + existingEvents <- TaskCore.getEventsForSession sid + let lastId = if null existingEvents then 0 else maximum (map TaskCore.storedEventId existingEvents) + + -- Convert existing events to SSE format + let existingSSE = map eventToSSE existingEvents + + -- Create a streaming source that sends existing events, then polls for new ones + pure <| Source.fromStepT <| streamEventsStep tid sid lastId existingSSE True + +-- | Step function for streaming events +streamEventsStep :: Text -> Text -> Int -> [ByteString] -> Bool -> Source.StepT IO ByteString +streamEventsStep tid sid lastId buffer sendExisting = case (sendExisting, buffer) of + -- Send buffered existing events first + (True, b : bs) -> pure <| Source.Yield b (streamEventsStep tid sid lastId bs True) + (True, []) -> streamEventsStep tid sid lastId [] False + -- Poll for new events + (False, _) -> + Source.Effect <| do + -- Check if task is still in progress + tasks <- TaskCore.loadTasks + let isComplete = case TaskCore.findTask tid tasks of + Nothing -> True + Just task -> TaskCore.taskStatus task /= TaskCore.InProgress + + if isComplete + then do + -- Send complete event and stop + let completeSSE = formatSSE "complete" "{}" + pure <| Source.Yield completeSSE Source.Stop + else do + -- Poll for new events + Concurrent.threadDelay 500000 -- 500ms + newEvents <- TaskCore.getEventsSince sid lastId + if null newEvents + then pure <| streamEventsStep tid sid lastId [] False + else do + let newLastId = maximum (map TaskCore.storedEventId newEvents) + let newSSE = map eventToSSE newEvents + case newSSE of + (e : es) -> pure <| Source.Yield e (streamEventsStep tid sid newLastId es False) + [] -> pure <| streamEventsStep tid sid newLastId [] False + +-- | Convert a StoredEvent to SSE format +eventToSSE :: TaskCore.StoredEvent -> ByteString +eventToSSE event = + let eventType = Text.toLower (TaskCore.storedEventType event) + content = TaskCore.storedEventContent event + jsonData = case eventType of + "assistant" -> Aeson.object ["content" Aeson..= content] + "toolcall" -> + let (tool, args) = parseToolCallContent content + in Aeson.object ["tool" Aeson..= tool, "args" Aeson..= Aeson.object ["data" Aeson..= args]] + "toolresult" -> + Aeson.object ["tool" Aeson..= ("unknown" :: Text), "success" Aeson..= True, "output" Aeson..= content] + "cost" -> Aeson.object ["cost" Aeson..= content] + "error" -> Aeson.object ["error" Aeson..= content] + "complete" -> Aeson.object [] + _ -> Aeson.object ["content" Aeson..= content] + in formatSSE eventType (str (Aeson.encode jsonData)) + +-- | Format an SSE message +formatSSE :: Text -> ByteString -> ByteString +formatSSE eventType jsonData = + str + <| "event: " + <> eventType + <> "\n" + <> "data: " + <> str jsonData + <> "\n\n" + api :: Proxy API api = Proxy @@ -2584,6 +2671,7 @@ server = :<|> taskListPartialHandler :<|> taskMetricsPartialHandler :<|> agentEventsPartialHandler + :<|> taskEventsStreamHandler where styleHandler :: Servant.Handler LazyText.Text styleHandler = pure Style.css @@ -2903,6 +2991,13 @@ server = Just task -> TaskCore.taskStatus task == TaskCore.InProgress pure (AgentEventsPartial events isInProgress now) + taskEventsStreamHandler :: Text -> Servant.Handler (SourceIO ByteString) + taskEventsStreamHandler tid = do + maybeSession <- liftIO (TaskCore.getLatestSessionForTask tid) + case maybeSession of + Nothing -> pure (Source.source []) + Just sid -> liftIO (streamAgentEvents tid sid) + taskToUnixTs :: TaskCore.Task -> Int taskToUnixTs t = round (utcTimeToPOSIXSeconds (TaskCore.taskUpdatedAt t)) |
