summaryrefslogtreecommitdiff
path: root/Omni/Ava/Trace.hs
blob: 6dbdf513f5a5c7727965f638b91cd1237a45deac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NoImplicitPrelude #-}

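-- | SQLite-backed storage for Ava's tool execution traces.
--
-- Each trace records a tool invocation (tool name, input, output, duration,
-- and optional user/chat ids) as a row of the @tool_traces@ table, for
-- debugging and analytics.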
--
-- : out omni-ava-trace
-- : dep aeson
-- : dep sqlite-simple
-- : dep uuid
module Omni.Ava.Trace
  ( TraceRecord (..),
    insertTrace,
    getTrace,
    cleanupOldTraces,
    main,
    test,
  )
where

import Alpha
import Data.Aeson ((.=))
import qualified Data.Aeson as Aeson
import qualified Data.Text as Text
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID
import qualified Database.SQLite.Simple as SQL
import qualified Database.SQLite.Simple.ToField as SQL
import qualified Omni.Test as Test

-- | Entry point: runs this module's test suite.
main :: IO ()
main = Test.run <| test

-- | Unit tests for the JSON encoding of 'TraceRecord'.
--
-- Both cases compare the whole decoded record (not individual fields), so a
-- key typo in either the 'Aeson.ToJSON' or 'Aeson.FromJSON' instance fails.
test :: Test.Tree
test =
  Test.group
    "Omni.Ava.Trace"
    [ Test.unit "TraceRecord JSON roundtrip" <| do
        Aeson.decode (Aeson.encode sampleTrace) Test.@=? Just sampleTrace,
      -- Nothing is encoded as null and must decode back to Nothing.
      Test.unit "TraceRecord JSON roundtrip without user/chat ids" <| do
        let anonymous = sampleTrace {trcUserId = Nothing, trcChatId = Nothing}
        Aeson.decode (Aeson.encode anonymous) Test.@=? Just anonymous
    ]
  where
    sampleTrace :: TraceRecord
    sampleTrace =
      TraceRecord
        { trcId = "trace-123",
          trcCreatedAt = "2024-01-15T10:30:00Z",
          trcToolName = "web_search",
          trcInput = "{\"query\":\"test\"}",
          trcOutput = "{\"results\":[]}",
          trcDurationMs = 150,
          trcUserId = Just "user-456",
          trcChatId = Just "chat-789"
        }

-- | One stored tool trace, mirroring a row of the @tool_traces@ table.
--
-- Field order matters: the SQL 'SQL.FromRow' / 'SQL.ToRow' instances read and
-- write columns positionally in this order.
data TraceRecord = TraceRecord
  { -- | Trace identifier; 'insertTrace' substitutes a random UUID when empty.
    trcId :: Text,
    -- | Creation timestamp as text. The tests use ISO-8601 such as
    -- @2024-01-15T10:30:00Z@ (NOTE(review): confirm the writer's format).
    trcCreatedAt :: Text,
    -- | Name of the tool that was executed.
    trcToolName :: Text,
    -- | Tool input as text (presumably JSON-encoded; TODO confirm).
    trcInput :: Text,
    -- | Tool output as text (presumably JSON-encoded; TODO confirm).
    trcOutput :: Text,
    -- | Execution duration, in milliseconds per the field/column name.
    trcDurationMs :: Int,
    -- | Optional id of the user the trace belongs to.
    trcUserId :: Maybe Text,
    -- | Optional id of the chat the trace belongs to.
    trcChatId :: Maybe Text
  }
  deriving (Show, Eq, Generic)

-- | Encodes with snake_case keys; absent user/chat ids are emitted as @null@.
instance Aeson.ToJSON TraceRecord where
  toJSON (TraceRecord rId rCreated rTool rInput rOutput rDuration rUser rChat) =
    Aeson.object
      [ "id" .= rId,
        "created_at" .= rCreated,
        "tool_name" .= rTool,
        "input" .= rInput,
        "output" .= rOutput,
        "duration_ms" .= rDuration,
        "user_id" .= rUser,
        "chat_id" .= rChat
      ]

-- | Decodes the snake_case object produced by the 'Aeson.ToJSON' instance.
-- @user_id@ and @chat_id@ are optional: a missing key or @null@ yields
-- 'Nothing'. All other keys are required.
instance Aeson.FromJSON TraceRecord where
  parseJSON = Aeson.withObject "TraceRecord" parseFields
    where
      parseFields obj = do
        ident <- obj Aeson..: "id"
        createdAt <- obj Aeson..: "created_at"
        toolName <- obj Aeson..: "tool_name"
        inputText <- obj Aeson..: "input"
        outputText <- obj Aeson..: "output"
        durationMs <- obj Aeson..: "duration_ms"
        userId <- obj Aeson..:? "user_id"
        chatId <- obj Aeson..:? "chat_id"
        pure
          TraceRecord
            { trcId = ident,
              trcCreatedAt = createdAt,
              trcToolName = toolName,
              trcInput = inputText,
              trcOutput = outputText,
              trcDurationMs = durationMs,
              trcUserId = userId,
              trcChatId = chatId
            }

-- | Reads the eight columns positionally, in the order id, created_at,
-- tool_name, input, output, duration_ms, user_id, chat_id. This matches the
-- column list selected by 'getTrace'.
instance SQL.FromRow TraceRecord where
  fromRow = do
    ident <- SQL.field
    createdAt <- SQL.field
    toolName <- SQL.field
    inputText <- SQL.field
    outputText <- SQL.field
    durationMs <- SQL.field
    userId <- SQL.field
    chatId <- SQL.field
    pure
      TraceRecord
        { trcId = ident,
          trcCreatedAt = createdAt,
          trcToolName = toolName,
          trcInput = inputText,
          trcOutput = outputText,
          trcDurationMs = durationMs,
          trcUserId = userId,
          trcChatId = chatId
        }

-- | Emits the eight fields in declaration order, matching the placeholder
-- order of the INSERT in 'insertTrace'.
instance SQL.ToRow TraceRecord where
  toRow (TraceRecord rId rCreated rTool rInput rOutput rDuration rUser rChat) =
    [ SQL.toField rId,
      SQL.toField rCreated,
      SQL.toField rTool,
      SQL.toField rInput,
      SQL.toField rOutput,
      SQL.toField rDuration,
      SQL.toField rUser,
      SQL.toField rChat
    ]

-- | Insert a trace into @tool_traces@ and return the id it was stored under.
--
-- If 'trcId' is empty, a fresh random (v4) UUID is generated and used instead.
-- Otherwise the caller's id is kept.
insertTrace :: SQL.Connection -> TraceRecord -> IO Text
insertTrace conn tr = do
  newId <- chooseId
  SQL.execute conn insertSql (tr {trcId = newId})
  pure newId
  where
    chooseId
      | Text.null (trcId tr) = (Text.pack <. UUID.toString) </ UUID.nextRandom
      | otherwise = pure (trcId tr)
    insertSql :: SQL.Query
    insertSql =
      "INSERT INTO tool_traces (id, created_at, tool_name, input, output, duration_ms, user_id, chat_id) \
      \VALUES (?, ?, ?, ?, ?, ?, ?, ?)"

-- | Look up a single trace by id. Returns 'Nothing' when no row matches; if
-- several rows share the id, the first one returned by SQLite is used.
getTrace :: SQL.Connection -> Text -> IO (Maybe TraceRecord)
getTrace conn tid =
  listToMaybe
    </ SQL.query
      conn
      "SELECT id, created_at, tool_name, input, output, duration_ms, user_id, chat_id \
      \FROM tool_traces WHERE id = ?"
      (SQL.Only tid)

-- | Delete traces created more than seven days ago and return the number of
-- rows removed, as reported by SQLite for the DELETE.
--
-- The cutoff from SQLite's datetime function has the form
-- @YYYY-MM-DD HH:MM:SS@ (space separator). Stored values may use the ISO-8601
-- form with a T separator and Z suffix (as in @2024-01-15T10:30:00Z@). Comparing
-- those as raw text ranks T above the space, so rows from the cutoff day would
-- survive up to an extra day. Each value is therefore normalised with SQLite's
-- datetime() first. Values SQLite cannot parse fall back to the raw text
-- comparison, which preserves the previous behaviour for them.
cleanupOldTraces :: SQL.Connection -> IO Int
cleanupOldTraces conn = do
  SQL.execute_
    conn
    "DELETE FROM tool_traces \
    \WHERE COALESCE(datetime(created_at), created_at) < datetime('now', '-7 days')"
  SQL.changes conn