summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Omni/Jr/Web.hs95
1 files changed, 95 insertions, 0 deletions
diff --git a/Omni/Jr/Web.hs b/Omni/Jr/Web.hs
index 86647d4..2be8ea1 100644
--- a/Omni/Jr/Web.hs
+++ b/Omni/Jr/Web.hs
@@ -17,6 +17,8 @@ module Omni.Jr.Web
where
import Alpha
+import qualified Control.Concurrent as Concurrent
+import qualified Data.Aeson as Aeson
import qualified Data.List as List
import qualified Data.Text as Text
import qualified Data.Text.Lazy as LazyText
@@ -33,6 +35,7 @@ import qualified Omni.Jr.Web.Style as Style
import qualified Omni.Task.Core as TaskCore
import Servant
import qualified Servant.HTML.Lucid as Lucid
+import qualified Servant.Types.SourceT as Source
import qualified System.Exit as Exit
import qualified System.Process as Process
import Web.FormUrlEncoded (FromForm (..), lookupUnique, parseUnique)
@@ -242,6 +245,7 @@ type API =
:> Get '[Lucid.HTML] TaskListPartial
:<|> "partials" :> "task" :> Capture "id" Text :> "metrics" :> Get '[Lucid.HTML] TaskMetricsPartial
:<|> "partials" :> "task" :> Capture "id" Text :> "events" :> QueryParam "since" Int :> Get '[Lucid.HTML] AgentEventsPartial
+ :<|> "tasks" :> Capture "id" Text :> "events" :> "stream" :> StreamGet NoFraming SSE (SourceIO ByteString)
data CSS
@@ -251,6 +255,14 @@ instance Accept CSS where
instance MimeRender CSS LazyText.Text where
mimeRender _ = LazyText.encodeUtf8
+data SSE
+
+instance Accept SSE where
+ contentType _ = "text/event-stream"
+
+instance MimeRender SSE ByteString where
+ mimeRender _ = identity
+
data HomePage = HomePage TaskCore.TaskStats [TaskCore.Task] [TaskCore.Task] Bool TaskCore.AggregatedMetrics TimeRange UTCTime
data ReadyQueuePage = ReadyQueuePage [TaskCore.Task] SortOrder UTCTime
@@ -2547,6 +2559,81 @@ instance Lucid.ToHtml AgentEventsPartial where
traverse_ (renderAgentEvent now) events
agentLogScrollScript
+-- | Stream agent events as SSE
+streamAgentEvents :: Text -> Text -> IO (SourceIO ByteString)
+streamAgentEvents tid sid = do
+ -- Get existing events first
+ existingEvents <- TaskCore.getEventsForSession sid
+ let lastId = if null existingEvents then 0 else maximum (map TaskCore.storedEventId existingEvents)
+
+ -- Convert existing events to SSE format
+ let existingSSE = map eventToSSE existingEvents
+
+ -- Create a streaming source that sends existing events, then polls for new ones
+ pure <| Source.fromStepT <| streamEventsStep tid sid lastId existingSSE True
+
+-- | Step function for streaming events
+streamEventsStep :: Text -> Text -> Int -> [ByteString] -> Bool -> Source.StepT IO ByteString
+streamEventsStep tid sid lastId buffer sendExisting = case (sendExisting, buffer) of
+ -- Send buffered existing events first
+ (True, b : bs) -> pure <| Source.Yield b (streamEventsStep tid sid lastId bs True)
+ (True, []) -> streamEventsStep tid sid lastId [] False
+ -- Poll for new events
+ (False, _) ->
+ Source.Effect <| do
+ -- Check if task is still in progress
+ tasks <- TaskCore.loadTasks
+ let isComplete = case TaskCore.findTask tid tasks of
+ Nothing -> True
+ Just task -> TaskCore.taskStatus task /= TaskCore.InProgress
+
+ if isComplete
+ then do
+ -- Send complete event and stop
+ let completeSSE = formatSSE "complete" "{}"
+ pure <| Source.Yield completeSSE Source.Stop
+ else do
+ -- Poll for new events
+ Concurrent.threadDelay 500000 -- 500ms
+ newEvents <- TaskCore.getEventsSince sid lastId
+ if null newEvents
+ then pure <| streamEventsStep tid sid lastId [] False
+ else do
+ let newLastId = maximum (map TaskCore.storedEventId newEvents)
+ let newSSE = map eventToSSE newEvents
+ case newSSE of
+ (e : es) -> pure <| Source.Yield e (streamEventsStep tid sid newLastId es False)
+ [] -> pure <| streamEventsStep tid sid newLastId [] False
+
+-- | Convert a StoredEvent to SSE format
+eventToSSE :: TaskCore.StoredEvent -> ByteString
+eventToSSE event =
+ let eventType = Text.toLower (TaskCore.storedEventType event)
+ content = TaskCore.storedEventContent event
+ jsonData = case eventType of
+ "assistant" -> Aeson.object ["content" Aeson..= content]
+ "toolcall" ->
+ let (tool, args) = parseToolCallContent content
+ in Aeson.object ["tool" Aeson..= tool, "args" Aeson..= Aeson.object ["data" Aeson..= args]]
+ "toolresult" ->
+ Aeson.object ["tool" Aeson..= ("unknown" :: Text), "success" Aeson..= True, "output" Aeson..= content]
+ "cost" -> Aeson.object ["cost" Aeson..= content]
+ "error" -> Aeson.object ["error" Aeson..= content]
+ "complete" -> Aeson.object []
+ _ -> Aeson.object ["content" Aeson..= content]
+ in formatSSE eventType (str (Aeson.encode jsonData))
+
+-- | Format an SSE message
+formatSSE :: Text -> ByteString -> ByteString
+formatSSE eventType jsonData =
+ str
+ <| "event: "
+ <> eventType
+ <> "\n"
+ <> "data: "
+ <> str jsonData
+ <> "\n\n"
+
api :: Proxy API
api = Proxy
@@ -2584,6 +2671,7 @@ server =
:<|> taskListPartialHandler
:<|> taskMetricsPartialHandler
:<|> agentEventsPartialHandler
+ :<|> taskEventsStreamHandler
where
styleHandler :: Servant.Handler LazyText.Text
styleHandler = pure Style.css
@@ -2903,6 +2991,13 @@ server =
Just task -> TaskCore.taskStatus task == TaskCore.InProgress
pure (AgentEventsPartial events isInProgress now)
+ taskEventsStreamHandler :: Text -> Servant.Handler (SourceIO ByteString)
+ taskEventsStreamHandler tid = do
+ maybeSession <- liftIO (TaskCore.getLatestSessionForTask tid)
+ case maybeSession of
+ Nothing -> pure (Source.source [])
+ Just sid -> liftIO (streamAgentEvents tid sid)
+
taskToUnixTs :: TaskCore.Task -> Int
taskToUnixTs t = round (utcTimeToPOSIXSeconds (TaskCore.taskUpdatedAt t))