diff options
| author | Ben Sima <ben@bensima.com> | 2025-12-11 19:50:20 -0500 |
|---|---|---|
| committer | Ben Sima <ben@bensima.com> | 2025-12-11 19:50:20 -0500 |
| commit | 276a27f27aeff7781a25e13fad0d568f5455ce05 (patch) | |
| tree | 6a7957986d14a9424f9e7f438dbd47a402b414fe /Omni | |
| parent | 225e5b7a24f0b30f6de1bd7418bf834ad345b0f3 (diff) | |
t-247: Add Provider abstraction for multi-backend LLM support
- Create Omni/Agent/Provider.hs with unified Provider interface
- Support OpenRouter (cloud), Ollama (local), Amp (subprocess stub)
- Add runAgentWithProvider to Engine.hs for Provider-based execution
- Add EngineType to Core.hs (EngineOpenRouter, EngineOllama, EngineAmp)
- Add --engine flag to 'jr work' command
- Worker.hs dispatches to appropriate provider based on engine type
Usage:
jr work <task-id> # OpenRouter (default)
jr work <task-id> --engine=ollama # Local Ollama
jr work <task-id> --engine=amp # Amp CLI (stub)
Diffstat (limited to 'Omni')
| -rw-r--r-- | Omni/Agent/Core.hs | 14 | ||||
| -rw-r--r-- | Omni/Agent/Engine.hs | 176 | ||||
| -rw-r--r-- | Omni/Agent/Provider.hs | 386 | ||||
| -rw-r--r-- | Omni/Agent/Worker.hs | 11 | ||||
| -rwxr-xr-x | Omni/Jr.hs | 15 |
5 files changed, 596 insertions, 6 deletions
diff --git a/Omni/Agent/Core.hs b/Omni/Agent/Core.hs index 88f7237..fb4a4b3 100644 --- a/Omni/Agent/Core.hs +++ b/Omni/Agent/Core.hs @@ -6,6 +6,17 @@ module Omni.Agent.Core where import Alpha import Data.Aeson (FromJSON, ToJSON) +-- | Engine/provider selection for agent +data EngineType + = EngineOpenRouter + | EngineOllama + | EngineAmp + deriving (Show, Eq, Generic) + +instance ToJSON EngineType + +instance FromJSON EngineType + -- | Status of a worker agent data WorkerStatus = Idle @@ -28,7 +39,8 @@ data Worker = Worker workerPid :: Maybe Int, workerStatus :: WorkerStatus, workerPath :: FilePath, - workerQuiet :: Bool -- Disable ANSI status bar (for loop mode) + workerQuiet :: Bool, -- Disable ANSI status bar (for loop mode) + workerEngine :: EngineType -- Which LLM backend to use } deriving (Show, Eq, Generic) diff --git a/Omni/Agent/Engine.hs b/Omni/Agent/Engine.hs index 4ee5e5d..fe3b3d5 100644 --- a/Omni/Agent/Engine.hs +++ b/Omni/Agent/Engine.hs @@ -30,12 +30,15 @@ module Omni.Agent.Engine ChatCompletionResponse (..), Choice (..), Usage (..), + ToolApi (..), + encodeToolForApi, defaultLLM, defaultEngineConfig, defaultAgentConfig, defaultGuardrails, chat, runAgent, + runAgentWithProvider, main, test, ) @@ -51,6 +54,7 @@ import qualified Data.Map.Strict as Map import qualified Data.Text as Text import qualified Data.Text.Encoding as TE import qualified Network.HTTP.Simple as HTTP +import qualified Omni.Agent.Provider as Provider import qualified Omni.Test as Test main :: IO () @@ -264,6 +268,14 @@ encodeToolForApi t = toolApiParameters = toolJsonSchema t } +encodeToolForProvider :: Tool -> Provider.ToolApi +encodeToolForProvider t = + Provider.ToolApi + { Provider.toolApiName = toolName t, + Provider.toolApiDescription = toolDescription t, + Provider.toolApiParameters = toolJsonSchema t + } + data LLM = LLM { llmBaseUrl :: Text, llmApiKey :: Text, @@ -809,3 +821,167 @@ estimateCost model tokens | "gpt-4" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000 | "claude" `Text.isInfixOf` model = fromIntegral tokens * 3 / 100000 | otherwise = fromIntegral tokens / 100000 + +-- | Run agent with a Provider instead of LLM. +-- This is the new preferred way to run agents with multiple backend support. +runAgentWithProvider :: EngineConfig -> Provider.Provider -> AgentConfig -> Text -> IO (Either Text AgentResult) +runAgentWithProvider engineCfg provider agentCfg userPrompt = do + let tools = agentTools agentCfg + toolApis = map encodeToolForProvider tools + toolMap = buildToolMap tools + systemMsg = providerMessage Provider.System (agentSystemPrompt agentCfg) + userMsg = providerMessage Provider.User userPrompt + initialMessages = [systemMsg, userMsg] + + engineOnActivity engineCfg "Starting agent loop (Provider)" + loopProvider provider toolApis toolMap initialMessages 0 0 0 0.0 Map.empty 0 0 + where + maxIter = agentMaxIterations agentCfg + guardrails' = agentGuardrails agentCfg + + providerMessage :: Provider.Role -> Text -> Provider.Message + providerMessage role content = Provider.Message role content Nothing Nothing + + loopProvider :: + Provider.Provider -> + [Provider.ToolApi] -> + Map.Map Text Tool -> + [Provider.Message] -> + Int -> + Int -> + Int -> + Double -> + Map.Map Text Int -> + Int -> + Int -> + IO (Either Text AgentResult) + loopProvider prov toolApis' toolMap msgs iteration totalCalls totalTokens totalCost toolCallCounts testFailures editFailures + | iteration >= maxIter = do + let errMsg = "Max iterations (" <> tshow maxIter <> ") reached" + engineOnError engineCfg errMsg + pure <| Left errMsg + | otherwise = do + let guardrailViolation = findGuardrailViolation guardrails' totalCost totalTokens toolCallCounts testFailures editFailures + case guardrailViolation of + Just (g, errMsg) -> do + engineOnGuardrail engineCfg g + pure <| Left errMsg + Nothing -> do + engineOnActivity engineCfg <| "Iteration " <> tshow (iteration + 1) + result <- Provider.chatWithUsage prov toolApis' msgs + case result of + Left err -> do + engineOnError engineCfg err + pure (Left err) + Right chatRes -> do + let msg = Provider.chatMessage chatRes + tokens = maybe 0 Provider.usageTotalTokens (Provider.chatUsage chatRes) + cost = case Provider.chatUsage chatRes +> Provider.usageCost of + Just actualCost -> actualCost * 100 + Nothing -> estimateCost (getProviderModel prov) tokens + engineOnCost engineCfg tokens cost + let newTokens = totalTokens + tokens + newCost = totalCost + cost + let assistantText = Provider.msgContent msg + unless (Text.null assistantText) + <| engineOnAssistant engineCfg assistantText + case Provider.msgToolCalls msg of + Nothing -> do + engineOnActivity engineCfg "Agent completed" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just [] -> do + engineOnActivity engineCfg "Agent completed (empty tool calls)" + engineOnComplete engineCfg + pure + <| Right + <| AgentResult + { resultFinalMessage = Provider.msgContent msg, + resultToolCallCount = totalCalls, + resultIterations = iteration + 1, + resultTotalCost = newCost, + resultTotalTokens = newTokens + } + Just tcs -> do + (toolResults, newTestFailures, newEditFailures) <- executeProviderToolCalls engineCfg toolMap tcs testFailures editFailures + let newMsgs = msgs <> [msg] <> toolResults + newCalls = totalCalls + length tcs + newToolCallCounts = updateProviderToolCallCounts toolCallCounts tcs + loopProvider prov toolApis' toolMap newMsgs (iteration + 1) newCalls newTokens newCost newToolCallCounts newTestFailures newEditFailures + + getProviderModel :: Provider.Provider -> Text + getProviderModel (Provider.OpenRouter cfg) = Provider.providerModel cfg + getProviderModel (Provider.Ollama cfg) = Provider.providerModel cfg + getProviderModel (Provider.AmpCLI _) = "amp" + + updateProviderToolCallCounts :: Map.Map Text Int -> [Provider.ToolCall] -> Map.Map Text Int + updateProviderToolCallCounts = + foldr (\tc m -> Map.insertWith (+) (Provider.fcName (Provider.tcFunction tc)) 1 m) + + executeProviderToolCalls :: EngineConfig -> Map.Map Text Tool -> [Provider.ToolCall] -> Int -> Int -> IO ([Provider.Message], Int, Int) + executeProviderToolCalls eCfg tMap tcs initialTestFailures initialEditFailures = do + results <- traverse (executeSingleProvider eCfg tMap) tcs + let msgs = map (\(m, _, _) -> m) results + testDeltas = map (\(_, t, _) -> t) results + editDeltas = map (\(_, _, e) -> e) results + totalTestFail = initialTestFailures + sum testDeltas + totalEditFail = initialEditFailures + sum editDeltas + pure (msgs, totalTestFail, totalEditFail) + + executeSingleProvider :: EngineConfig -> Map.Map Text Tool -> Provider.ToolCall -> IO (Provider.Message, Int, Int) + executeSingleProvider eCfg tMap tc = do + let name = Provider.fcName (Provider.tcFunction tc) + argsText = Provider.fcArguments (Provider.tcFunction tc) + callId = Provider.tcId tc + engineOnActivity eCfg <| "Executing tool: " <> name + engineOnToolCall eCfg name argsText + case Map.lookup name tMap of + Nothing -> do + let errMsg = "Tool not found: " <> name + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just tool -> do + case Aeson.decode (BL.fromStrict (TE.encodeUtf8 argsText)) of + Nothing -> do + let errMsg = "Invalid JSON arguments: " <> argsText + engineOnToolResult eCfg name False errMsg + pure (Provider.Message Provider.ToolRole errMsg Nothing (Just callId), 0, 0) + Just args -> do + resultValue <- toolExecute tool args + let resultText = TE.decodeUtf8 (BL.toStrict (Aeson.encode resultValue)) + isTestCall = name == "bash" && ("bild --test" `Text.isInfixOf` argsText || "bild -t" `Text.isInfixOf` argsText) + isTestFailure = isTestCall && isFailureResultProvider resultValue + testDelta = if isTestFailure then 1 else 0 + isEditFailure = name == "edit_file" && isOldStrNotFoundProvider resultValue + editDelta = if isEditFailure then 1 else 0 + engineOnToolResult eCfg name True resultText + pure (Provider.Message Provider.ToolRole resultText Nothing (Just callId), testDelta, editDelta) + + isFailureResultProvider :: Aeson.Value -> Bool + isFailureResultProvider (Aeson.Object obj) = + case KeyMap.lookup "exit_code" obj of + Just (Aeson.Number n) -> n /= 0 + _ -> False + isFailureResultProvider (Aeson.String s) = + "error" + `Text.isInfixOf` Text.toLower s + || "failed" + `Text.isInfixOf` Text.toLower s + || "FAILED" + `Text.isInfixOf` s + isFailureResultProvider _ = False + + isOldStrNotFoundProvider :: Aeson.Value -> Bool + isOldStrNotFoundProvider (Aeson.Object obj) = + case KeyMap.lookup "error" obj of + Just (Aeson.String s) -> "old_str not found" `Text.isInfixOf` s + _ -> False + isOldStrNotFoundProvider _ = False diff --git a/Omni/Agent/Provider.hs b/Omni/Agent/Provider.hs new file mode 100644 index 0000000..a8a5381 --- /dev/null +++ b/Omni/Agent/Provider.hs @@ -0,0 +1,386 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE NoImplicitPrelude #-} + +-- | LLM Provider abstraction for multi-backend support. +-- +-- Supports multiple LLM backends: +-- - OpenRouter (cloud API, multiple models) +-- - Ollama (local models) +-- - Amp CLI (subprocess) +-- +-- : out omni-agent-provider +-- : dep aeson +-- : dep http-conduit +-- : dep case-insensitive +module Omni.Agent.Provider + ( Provider (..), + ProviderConfig (..), + ChatResult (..), + Message (..), + Role (..), + ToolCall (..), + FunctionCall (..), + Usage (..), + ToolApi (..), + defaultOpenRouter, + defaultOllama, + chat, + chatWithUsage, + main, + test, + ) +where + +import Alpha +import Data.Aeson ((.=)) +import qualified Data.Aeson as Aeson +import qualified Data.Aeson.KeyMap as KeyMap +import qualified Data.ByteString.Lazy as BL +import qualified Data.CaseInsensitive as CI +import qualified Data.Text as Text +import qualified Data.Text.Encoding as TE +import qualified Network.HTTP.Simple as HTTP +import qualified Omni.Test as Test + +main :: IO () +main = Test.run test + +test :: Test.Tree +test = + Test.group + "Omni.Agent.Provider" + [ Test.unit "defaultOpenRouter has correct endpoint" <| do + case defaultOpenRouter "" "test-model" of + OpenRouter cfg -> providerBaseUrl cfg Test.@=? "https://openrouter.ai/api/v1" + _ -> Test.assertFailure "Expected OpenRouter", + Test.unit "defaultOllama has correct endpoint" <| do + case defaultOllama "test-model" of + Ollama cfg -> providerBaseUrl cfg Test.@=? "http://localhost:11434" + _ -> Test.assertFailure "Expected Ollama", + Test.unit "ChatResult preserves message" <| do + let msg = Message User "test" Nothing Nothing + result = ChatResult msg Nothing + chatMessage result Test.@=? msg + ] + +data Provider + = OpenRouter ProviderConfig + | Ollama ProviderConfig + | AmpCLI FilePath + deriving (Show, Eq, Generic) + +data ProviderConfig = ProviderConfig + { providerBaseUrl :: Text, + providerApiKey :: Text, + providerModel :: Text, + providerExtraHeaders :: [(ByteString, ByteString)] + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON ProviderConfig where + toJSON c = + Aeson.object + [ "baseUrl" .= providerBaseUrl c, + "apiKey" .= providerApiKey c, + "model" .= providerModel c + ] + +instance Aeson.FromJSON ProviderConfig where + parseJSON = + Aeson.withObject "ProviderConfig" <| \v -> + (ProviderConfig </ (v Aeson..: "baseUrl")) + <*> (v Aeson..: "apiKey") + <*> (v Aeson..: "model") + <*> pure [] + +defaultOpenRouter :: Text -> Text -> Provider +defaultOpenRouter apiKey model = + OpenRouter + ProviderConfig + { providerBaseUrl = "https://openrouter.ai/api/v1", + providerApiKey = apiKey, + providerModel = model, + providerExtraHeaders = + [ ("HTTP-Referer", "https://omni.dev"), + ("X-Title", "Omni Agent") + ] + } + +defaultOllama :: Text -> Provider +defaultOllama model = + Ollama + ProviderConfig + { providerBaseUrl = "http://localhost:11434", + providerApiKey = "", + providerModel = model, + providerExtraHeaders = [] + } + +data Role = System | User | Assistant | ToolRole + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON Role where + toJSON System = Aeson.String "system" + toJSON User = Aeson.String "user" + toJSON Assistant = Aeson.String "assistant" + toJSON ToolRole = Aeson.String "tool" + +instance Aeson.FromJSON Role where + parseJSON = Aeson.withText "Role" parseRole + where + parseRole "system" = pure System + parseRole "user" = pure User + parseRole "assistant" = pure Assistant + parseRole "tool" = pure ToolRole + parseRole _ = empty + +data Message = Message + { msgRole :: Role, + msgContent :: Text, + msgToolCalls :: Maybe [ToolCall], + msgToolCallId :: Maybe Text + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON Message where + toJSON m = + Aeson.object + <| catMaybes + [ Just ("role" .= msgRole m), + Just ("content" .= msgContent m), + ("tool_calls" .=) </ msgToolCalls m, + ("tool_call_id" .=) </ msgToolCallId m + ] + +instance Aeson.FromJSON Message where + parseJSON = + Aeson.withObject "Message" <| \v -> + (Message </ (v Aeson..: "role")) + <*> (v Aeson..:? "content" Aeson..!= "") + <*> (v Aeson..:? "tool_calls") + <*> (v Aeson..:? "tool_call_id") + +data ToolCall = ToolCall + { tcId :: Text, + tcType :: Text, + tcFunction :: FunctionCall + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON ToolCall where + toJSON tc = + Aeson.object + [ "id" .= tcId tc, + "type" .= tcType tc, + "function" .= tcFunction tc + ] + +instance Aeson.FromJSON ToolCall where + parseJSON = + Aeson.withObject "ToolCall" <| \v -> + (ToolCall </ (v Aeson..: "id")) + <*> (v Aeson..:? "type" Aeson..!= "function") + <*> (v Aeson..: "function") + +data FunctionCall = FunctionCall + { fcName :: Text, + fcArguments :: Text + } + deriving (Show, Eq, Generic) + +instance Aeson.ToJSON FunctionCall where + toJSON fc = + Aeson.object + [ "name" .= fcName fc, + "arguments" .= fcArguments fc + ] + +instance Aeson.FromJSON FunctionCall where + parseJSON = + Aeson.withObject "FunctionCall" <| \v -> + (FunctionCall </ (v Aeson..: "name")) + <*> (v Aeson..: "arguments") + +data Usage = Usage + { usagePromptTokens :: Int, + usageCompletionTokens :: Int, + usageTotalTokens :: Int, + usageCost :: Maybe Double + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON Usage where + parseJSON = + Aeson.withObject "Usage" <| \v -> + (Usage </ (v Aeson..: "prompt_tokens")) + <*> (v Aeson..: "completion_tokens") + <*> (v Aeson..: "total_tokens") + <*> (v Aeson..:? "cost") + +data ChatResult = ChatResult + { chatMessage :: Message, + chatUsage :: Maybe Usage + } + deriving (Show, Eq) + +data ToolApi = ToolApi + { toolApiName :: Text, + toolApiDescription :: Text, + toolApiParameters :: Aeson.Value + } + deriving (Generic) + +instance Aeson.ToJSON ToolApi where + toJSON t = + Aeson.object + [ "type" .= ("function" :: Text), + "function" + .= Aeson.object + [ "name" .= toolApiName t, + "description" .= toolApiDescription t, + "parameters" .= toolApiParameters t + ] + ] + +data ChatCompletionRequest = ChatCompletionRequest + { reqModel :: Text, + reqMessages :: [Message], + reqTools :: Maybe [ToolApi] + } + deriving (Generic) + +instance Aeson.ToJSON ChatCompletionRequest where + toJSON r = + Aeson.object + <| catMaybes + [ Just ("model" .= reqModel r), + Just ("messages" .= reqMessages r), + ("tools" .=) </ reqTools r, + Just ("usage" .= Aeson.object ["include" .= True]) + ] + +data Choice = Choice + { choiceIndex :: Int, + choiceMessage :: Message, + choiceFinishReason :: Maybe Text + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON Choice where + parseJSON = + Aeson.withObject "Choice" <| \v -> + (Choice </ (v Aeson..: "index")) + <*> (v Aeson..: "message") + <*> (v Aeson..:? "finish_reason") + +data ChatCompletionResponse = ChatCompletionResponse + { respId :: Text, + respChoices :: [Choice], + respModel :: Text, + respUsage :: Maybe Usage + } + deriving (Show, Eq, Generic) + +instance Aeson.FromJSON ChatCompletionResponse where + parseJSON = + Aeson.withObject "ChatCompletionResponse" <| \v -> + (ChatCompletionResponse </ (v Aeson..: "id")) + <*> (v Aeson..: "choices") + <*> (v Aeson..: "model") + <*> (v Aeson..:? "usage") + +chat :: Provider -> [ToolApi] -> [Message] -> IO (Either Text Message) +chat provider tools messages = do + result <- chatWithUsage provider tools messages + pure (chatMessage </ result) + +chatWithUsage :: Provider -> [ToolApi] -> [Message] -> IO (Either Text ChatResult) +chatWithUsage (OpenRouter cfg) tools messages = chatOpenAI cfg tools messages +chatWithUsage (Ollama cfg) tools messages = chatOllama cfg tools messages +chatWithUsage (AmpCLI _promptFile) _tools _messages = do + pure (Left "Amp CLI provider not yet implemented") + +chatOpenAI :: ProviderConfig -> [ToolApi] -> [Message] -> IO (Either Text ChatResult) +chatOpenAI cfg tools messages = do + let url = Text.unpack (providerBaseUrl cfg) <> "/chat/completions" + req0 <- HTTP.parseRequest url + let body = + ChatCompletionRequest + { reqModel = providerModel cfg, + reqMessages = messages, + reqTools = if null tools then Nothing else Just tools + } + baseReq = + HTTP.setRequestMethod "POST" + <| HTTP.setRequestHeader "Content-Type" ["application/json"] + <| HTTP.setRequestHeader "Authorization" ["Bearer " <> TE.encodeUtf8 (providerApiKey cfg)] + <| HTTP.setRequestBodyLBS (Aeson.encode body) + <| req0 + req = foldr addHeader baseReq (providerExtraHeaders cfg) + addHeader (name, value) = HTTP.addRequestHeader (CI.mk name) value + + response <- HTTP.httpLBS req + let status = HTTP.getResponseStatusCode response + if status >= 200 && status < 300 + then case Aeson.decode (HTTP.getResponseBody response) of + Just resp -> + case respChoices resp of + (c : _) -> pure (Right (ChatResult (choiceMessage c) (respUsage resp))) + [] -> pure (Left "No choices in response") + Nothing -> pure (Left "Failed to parse response") + else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + +chatOllama :: ProviderConfig -> [ToolApi] -> [Message] -> IO (Either Text ChatResult) +chatOllama cfg tools messages = do + let url = Text.unpack (providerBaseUrl cfg) <> "/api/chat" + req0 <- HTTP.parseRequest url + let body = + Aeson.object + [ "model" .= providerModel cfg, + "messages" .= messages, + "tools" .= if null tools then Aeson.Null else Aeson.toJSON tools, + "stream" .= False + ] + req = + HTTP.setRequestMethod "POST" + <| HTTP.setRequestHeader "Content-Type" ["application/json"] + <| HTTP.setRequestBodyLBS (Aeson.encode body) + <| req0 + + response <- HTTP.httpLBS req + let status = HTTP.getResponseStatusCode response + if status >= 200 && status < 300 + then case Aeson.decode (HTTP.getResponseBody response) of + Just resp -> parseOllamaResponse resp + Nothing -> pure (Left ("Failed to parse Ollama response: " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + else pure (Left ("HTTP error: " <> tshow status <> " - " <> TE.decodeUtf8 (BL.toStrict (HTTP.getResponseBody response)))) + +parseOllamaResponse :: Aeson.Value -> IO (Either Text ChatResult) +parseOllamaResponse val = + case val of + Aeson.Object obj -> do + let msgResult = do + msgObj <- case KeyMap.lookup "message" obj of + Just m -> Right m + Nothing -> Left "No message in response" + case Aeson.fromJSON msgObj of + Aeson.Success msg -> Right msg + Aeson.Error e -> Left (Text.pack e) + usageResult = case KeyMap.lookup "prompt_eval_count" obj of + Just (Aeson.Number promptTokens) -> + case KeyMap.lookup "eval_count" obj of + Just (Aeson.Number evalTokens) -> + Just + Usage + { usagePromptTokens = round promptTokens, + usageCompletionTokens = round evalTokens, + usageTotalTokens = round promptTokens + round evalTokens, + usageCost = Nothing + } + _ -> Nothing + _ -> Nothing + case msgResult of + Right msg -> pure (Right (ChatResult msg usageResult)) + Left e -> pure (Left e) + _ -> pure (Left "Expected object response from Ollama") diff --git a/Omni/Agent/Worker.hs b/Omni/Agent/Worker.hs index 66f894d..3b0c563 100644 --- a/Omni/Agent/Worker.hs +++ b/Omni/Agent/Worker.hs @@ -21,6 +21,7 @@ import qualified Data.Time import qualified Omni.Agent.Core as Core import qualified Omni.Agent.Engine as Engine import qualified Omni.Agent.Log as AgentLog +import qualified Omni.Agent.Provider as Provider import qualified Omni.Agent.Tools as Tools import qualified Omni.Fact as Fact import qualified Omni.Task.Core as TaskCore @@ -357,8 +358,14 @@ runWithEngine worker repo task = do Engine.agentGuardrails = guardrails } - -- Run the agent - result <- Engine.runAgent engineCfg agentCfg userPrompt + -- Run the agent with appropriate provider + result <- case Core.workerEngine worker of + Core.EngineOpenRouter -> Engine.runAgent engineCfg agentCfg userPrompt + Core.EngineOllama -> do + ollamaModel <- fromMaybe "llama3.1:8b" </ Env.lookupEnv "OLLAMA_MODEL" + let provider = Provider.defaultOllama (Text.pack ollamaModel) + Engine.runAgentWithProvider engineCfg provider agentCfg userPrompt + Core.EngineAmp -> pure (Left "Amp engine not yet implemented") totalCost <- readIORef totalCostRef case result of @@ -53,7 +53,7 @@ jr Usage: jr task [<args>...] - jr work [<task-id>] + jr work [<task-id>] [--engine=ENGINE] jr prompt <task-id> jr web [--port=PORT] jr review [<task-id>] [--auto] @@ -77,6 +77,7 @@ Commands: Options: -h --help Show this help --port=PORT Port for web server [default: 8080] + --engine=ENGINE LLM engine: openrouter, ollama, amp [default: openrouter] --auto Auto-review: accept if tests pass, reject if they fail --delay=SECONDS Delay between loop iterations [default: 5] --project=PROJECT Filter facts by project @@ -119,13 +120,20 @@ move args absPath <- Directory.getCurrentDirectory let name = Text.pack (takeFileName absPath) + -- Parse engine flag + let engineType = case Cli.getArg args (Cli.longOption "engine") of + Just "ollama" -> AgentCore.EngineOllama + Just "amp" -> AgentCore.EngineAmp + _ -> AgentCore.EngineOpenRouter + let worker = AgentCore.Worker { AgentCore.workerName = name, AgentCore.workerPid = Nothing, AgentCore.workerStatus = AgentCore.Idle, AgentCore.workerPath = path, - AgentCore.workerQuiet = False -- Show ANSI status bar for manual work + AgentCore.workerQuiet = False, -- Show ANSI status bar for manual work + AgentCore.workerEngine = engineType } let taskId = fmap Text.pack (Cli.getArg args (Cli.argument "task-id")) @@ -183,7 +191,8 @@ runLoop delaySec = do AgentCore.workerPid = Nothing, AgentCore.workerStatus = AgentCore.Idle, AgentCore.workerPath = ".", - AgentCore.workerQuiet = True -- No ANSI status bar in loop mode + AgentCore.workerQuiet = True, -- No ANSI status bar in loop mode + AgentCore.workerEngine = AgentCore.EngineOpenRouter -- Default for loop } putText "[loop] Starting worker..." AgentWorker.start worker (Just (TaskCore.taskId task)) |
