diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-11 19:50:20 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-11 19:50:20 -0500 |
| commit | 276a27f27aeff7781a25e13fad0d568f5455ce05 (patch) | |
| tree | 6a7957986d14a9424f9e7f438dbd47a402b414fe /Omni/Agent/Provider.hs | |
| parent | 225e5b7a24f0b30f6de1bd7418bf834ad345b0f3 (diff) | |
t-247: Add Provider abstraction for multi-backend LLM support
- Create Omni/Agent/Provider.hs with unified Provider interface
- Support OpenRouter (cloud), Ollama (local), Amp (subprocess stub)
- Add runAgentWithProvider to Engine.hs for Provider-based execution
- Add EngineType to Core.hs (EngineOpenRouter, EngineOllama, EngineAmp)
- Add --engine flag to 'jr work' command
- Worker.hs dispatches to appropriate provider based on engine type
Usage:
jr work <task-id> # OpenRouter (default)
jr work <task-id> --engine=ollama # Local Ollama
jr work <task-id> --engine=amp # Amp CLI (stub)
Diffstat (limited to 'Omni/Agent/Provider.hs')
| -rw-r--r-- | Omni/Agent/Provider.hs | 386 |
1 files changed, 386 insertions, 0 deletions
diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs new file mode 100644 index 0000000..a8a5381 --- /dev/null +++ b/Omni/Agent/Provider.hs @@ -0,0 +1,386 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE NoImplicitPrelude #-} + +-- | LLM Provider abstraction for multi-backend support. +-- +-- Supports multiple LLM backends: +-- - OpenRouter (cloud API, multiple models) +-- - Ollama (local models) +-- - Amp CLI (subprocess) +-- +-- : out omni-agent-provider +-- : dep aeson +-- : dep http-conduit +-- : dep case-insensitive +module Omni.Agent.Provider + ( Provider (..), + ProviderConfig (..), + ChatResult (..), + Message (..), + Role (..), + ToolCall (..), + FunctionCall (..), + Usage (..), + ToolApi (..), + defaultOpenRouter, + defaultOllama, + chat, + chatWithUsage, + main, + test, + ) +where + +import Alpha +import Data.Aeson ((.=)) +import qualified Data.Aeson as Aeson +import qualified Data.Aeson.KeyMap as KeyMap +import qualified Data.ByteString.Lazy as BL +import qualified Data.CaseInsensitive as CI +import qualified Data.Text as Text +import qualified Data.Text.Encoding as TE +import qualified Network.HTTP.Simple as HTTP +import qualified Omni.Test as Test + +main :: IO () +main = Test.run test + +test :: Test.Tree +test = + Test.group + "Omni.Agent.Provider" + [ Test.unit "defaultOpenRouter has correct endpoint" <| do + case defaultOpenRouter "" "test-model" of + OpenRouter cfg -> providerBaseUrl cfg Test.@=? "https://openrouter.ai/api/v1" + _ -> Test.assertFailure "Expected OpenRouter", + Test.unit "defaultOllama has correct endpoint" <| do + case defaultOllama "test-model" of + Ollama cfg -> providerBaseUrl cfg Test.@=? "http://localhost:11434" + _ -> Test.assertFailure "Expected Ollama", + Test.unit "ChatResult preserves message" <| do + let msg = Message User "test" Nothing Nothing + result = ChatResult msg Nothing + chatMessage result Test.@=? msg + ] + +data Provider + = OpenRouter ProviderConfig + | Ollama ProviderConfig + | AmpCLI FilePath + deriving (Show, Eq, Generic) + +data ProviderConfig = ProviderConfig + { providerBaseUrl :: Text, + providerApiKey :: Text, + providerModel :: Text, + providerExtraHeaders :: [(ByteString, ByteString)] + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON ProviderConfig where + toJSON c = + Aeson.object + [ "baseUrl" .= providerBaseUrl c, + "apiKey" .= providerApiKey c, + "model" .= providerModel c + ] + +instance Aeson.FromJSON ProviderConfig where + parseJSON = + Aeson.withObject "ProviderConfig" <| \v -> + (ProviderConfig </ (v Aeson..: "baseUrl")) + <*> (v Aeson..: "apiKey") + <*> (v Aeson..: "model") + <*> pure [] + +defaultOpenRouter :: Text -> Text -> Provider +defaultOpenRouter apiKey model = + OpenRouter + ProviderConfig + { providerBaseUrl = "https://openrouter.ai/api/v1", + providerApiKey = apiKey, + providerModel = model, + providerExtraHeaders = + [ ("HTTP-Referer", "https://omni.dev"), + ("X-Title", "Omni Agent") + ] + } + +defaultOllama :: Text -> Provider +defaultOllama model = + Ollama + ProviderConfig + { providerBaseUrl = "http://localhost:11434", + providerApiKey = "", + providerModel = model, + providerExtraHeaders = [] + } + +data Role = System | User | Assistant | ToolRole + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON Role where + toJSON System = Aeson.String "system" + toJSON User = Aeson.String "user" + toJSON Assistant = Aeson.String "assistant" + toJSON ToolRole = Aeson.String "tool" + +instance Aeson.FromJSON Role where + parseJSON = Aeson.withText "Role" parseRole + where + parseRole "system" = pure System + parseRole "user" = pure User + parseRole "assistant" = pure Assistant + parseRole "tool" = pure ToolRole + parseRole _ = empty + +data Message = Message + { msgRole :: Role, + msgContent :: Text, + msgToolCalls :: Maybe [ToolCall], + msgToolCallId :: Maybe Text + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON Message where + toJSON m = + Aeson.object + <| catMaybes + [ Just ("role" .= msgRole m), + Just ("content" .= msgContent m), + ("tool_calls" .=) </ msgToolCalls m, + ("tool_call_id" .=) </ msgToolCallId m + ] + +instance Aeson.FromJSON Message where + parseJSON = + Aeson.withObject "Message" <| \v -> + (Message </ (v Aeson..: "role")) + <*> (v Aeson..:? "content" Aeson..!= "") + <*> (v Aeson..:? "tool_calls") + <*> (v Aeson..:? "tool_call_id") + +data ToolCall = ToolCall + { tcId :: Text, + tcType :: Text, + tcFunction :: FunctionCall + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON ToolCall where + toJSON tc = + Aeson.object + [ "id" .= tcId tc, + "type" .= tcType tc, + "function" .= tcFunction tc + ] + +instance Aeson.FromJSON ToolCall where + parseJSON = + Aeson.withObject "ToolCall" <| \v -> + (ToolCall </ (v Aeson..: "id")) + <*> (v Aeson..:? "type" Aeson..!= "function") + <*> (v Aeson..: "function") + +data FunctionCall = FunctionCall + { fcName :: Text, + fcArguments :: Text + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON FunctionCall where + toJSON fc = + Aeson.object + [ "name" .= fcName fc, + "arguments" .= fcArguments fc + ] + +instance Aeson.FromJSON FunctionCall where + parseJSON = + Aeson.withObject "FunctionCall" <| \v -> + (FunctionCall </ (v Aeson..: "name")) + <*> (v Aeson..: "arguments") + +data Usage = Usage + { usagePromptTokens :: Int, + usageCompletionTokens :: Int, + usageTotalTokens :: Int, + usageCost :: Maybe Double + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON Usage where + parseJSON = + Aeson.withObject "Usage" <| \v -> + (Usage </ (v Aeson..: "prompt_tokens")) + <*> (v Aeson..: "completion_tokens") + <*> (v Aeson..: "total_tokens") + <*> (v Aeson..:? "cost") + +data ChatResult = ChatResult + { chatMessage :: Message, + chatUsage :: Maybe Usage + } + deriving (Show, Eq) + +data ToolApi = ToolApi + { toolApiName :: Text, + toolApiDescription :: Text, + toolApiParameters :: Aeson.Value + } + deriving (Generic) + +instance Aeson.ToJSON ToolApi where + toJSON t = + Aeson.object + [ "type" .= ("function" :: Text), + "function" + .= Aeson.object + [ "name" .= toolApiName t, + "description" .= toolApiDescription t, + "parameters" .= toolApiParameters t + ] + ] + +data ChatCompletionRequest = ChatCompletionRequest + { reqModel :: Text, + reqMessages :: [Message], + reqTools :: Maybe [ToolApi] + } + deriving (Generic) + +instance Aeson.ToJSON ChatCompletionRequest where + toJSON r = + Aeson.object + <| catMaybes + [ Just ("model" .= reqModel r), + Just ("messages" .= reqMessages r), + ("tools" .=) </ reqTools r, + Just ("usage" .= Aeson.object ["include" .= True]) + ] + +data Choice = Choice + { choiceIndex :: Int, + choiceMessage :: Message, + choiceFinishReason :: Maybe Text + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON Choice where + parseJSON = + Aeson.withObject "Choice" <| \v -> + (Choice </ (v Aeson..: "index")) + <*> (v Aeson..: "message") + <*> (v Aeson..:? "finish_reason") + +data ChatCompletionResponse = ChatCompletionResponse + { respId :: Text, + respChoices :: [Choice], + respModel :: Text, + respUsage :: Maybe Usage + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON ChatCompletionResponse where + parseJSON = + Aeson.withObject "ChatCompletionResponse" <| \v -> + (ChatCompletionResponse </ (v Aeson..: "id")) + <*> (v Aeson..: "choices") + <*> (v Aeson..: "model") + <*> (v Aeson..:? "usage") + +chat :: Provider -> [ToolApi] -> [Message] -> IO (Either Text Message) +chat provider tools messages = do + result <- chatWithUsage provider tools messages + pure (chatMessage </ result) + +chatWithUsage :: Provider -> [ToolApi] -> [Message] -> IO (Either Text ChatResult) +chatWithUsage (OpenRouter cfg) tools messages = chatOpenAI cfg tools messages +chatWithUsage (Ollama cfg) tools messages = chatOllama cfg tools messages +chatWithUsage (AmpCLI _promptFile) _tools _messages = do + pure (Left "Amp CLI provider not yet implemented") + +chatOpenAI :: ProviderConfig -> [ToolApi] -> [Message] -> IO (Either Text ChatResult) +chatOpenAI cfg tools messages = do + let url = Text.unpack (providerBaseUrl cfg) <> "/chat/completions" + req0 <- HTTP.parseRequest url + let body = + ChatCompletionRequest + { reqModel = providerModel cfg, + reqMessages = messages, + reqTools = if null tools then Nothing else Just tools + } + baseReq = + HTTP.setRequestMethod "POST" + <| HTTP.setRequestHeader "Content-Type" ["application/json"] + <| HTTP.setRequestHeader "Authorization" ["Bearer " <> TE.encodeUtf8 (providerApiKey cfg)] + <| HTTP.setRequestBodyLBS (Aeson.encode body) + <| req0 + req = foldr addHeader baseReq (providerExtraHeaders cfg) + addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value + + response <- HTTP.httpLBS req + let status = HTTP.getResponseStatusCode response + if status >= 200 && status < 300 + then case Aeson.decode (HTTP.getResponseBody response) of + Just resp -> + case respChoices resp of + (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp))) + [] -> pure (Left "No choices in response") + Nothing -> pure (Left "Failed to parse response") + else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + +chatOllama :: ProviderConfig -> [ToolApi] -> [Message] -> IO (Either Text ChatResult) +chatOllama cfg tools messages = do + let url = Text.unpack (providerBaseUrl cfg) <> "/api/chat" + req0 <- HTTP.parseRequest url + let body = + Aeson.object + [ "model" .= providerModel cfg, + "messages" .= messages, + "tools" .= if null tools then Aeson.Null else Aeson.toJSON tools, + "stream" .= False + ] + req = + HTTP.setRequestMethod "POST" + <| HTTP.setRequestHeader "Content-Type" ["application/json"] + <| HTTP.setRequestBodyLBS (Aeson.encode body) + <| req0 + + response <- HTTP.httpLBS req + let status = HTTP.getResponseStatusCode response + if status >= 200 && status < 300 + then case Aeson.decode (HTTP.getResponseBody response) of + Just resp -> parseOllamaResponse resp + Nothing -> pure (Left ("Failed to parse Ollama response: " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + +parseOllamaResponse :: Aeson.Value -> IO (Either Text ChatResult) +parseOllamaResponse val = + case val of + Aeson.Object obj -> do + let msgResult = do + msgObj <- case KeyMap.lookup "message" obj of + Just m -> Right m + Nothing -> Left "No message in response" + case Aeson.fromJSON msgObj of + Aeson.Success msg -> Right msg + Aeson.Error e -> Left (Text.pack e) + usageResult = case KeyMap.lookup "prompt_eval_count" obj of + Just (Aeson.Number promptTokens) -> + case KeyMap.lookup "eval_count" obj of + Just (Aeson.Number evalTokens) -> + Just + Usage + { usagePromptTokens = round promptTokens, + usageCompletionTokens = round evalTokens, + usageTotalTokens = round promptTokens + round evalTokens, + usageCost = Nothing + } + _ -> Nothing + _ -> Nothing + case msgResult of + Right msg -> pure (Right (ChatResult msg usageResult)) + Left e -> pure (Left e) + _ -> pure (Left "Expected object response from Ollama") |
