From ea45ea68d16c05934a11b24db0aeb68a8e3e4f22 Mon Sep 17 00:00:00 2001 From: Manish Biswal Date: Wed, 4 Mar 2026 14:32:44 +0530 Subject: [PATCH 1/4] feat: auto download backends Signed-off-by: Manish Biswal --- cmd/cli/commands/run.go | 15 ++++++ cmd/cli/commands/utils.go | 103 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 58e929131..4c3a1aa4a 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -820,6 +820,21 @@ func newRunCmd() *cobra.Command { } } + modelInfo, err := desktopClient.Inspect(model, true) + backend := "" + if err == nil { + backend, _ = GetRequiredBackendFromModelInfo(&modelInfo) + } + + if backend != "" { + if err := EnsureBackendAvailable(backend, cmd); err != nil { + if err.Error() == "backend installation cancelled" { + return nil + } + return err + } + } + // Handle --detach flag: just load the model without interaction if detach { if err := desktopClient.Preload(cmd.Context(), model); err != nil { diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go index c77f472b7..76bed532a 100644 --- a/cmd/cli/commands/utils.go +++ b/cmd/cli/commands/utils.go @@ -1,7 +1,9 @@ package commands import ( + "bufio" "bytes" + "encoding/json" "errors" "fmt" "io" @@ -12,7 +14,10 @@ import ( "github.com/docker/model-runner/cmd/cli/desktop" "github.com/docker/model-runner/cmd/cli/pkg/standalone" "github.com/docker/model-runner/pkg/distribution/oci/reference" + "github.com/docker/model-runner/pkg/distribution/types" + "github.com/docker/model-runner/pkg/inference/backends/llamacpp" "github.com/docker/model-runner/pkg/inference/backends/vllm" + dmrm "github.com/docker/model-runner/pkg/inference/models" "github.com/moby/term" "github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter/renderer" @@ -270,3 +275,101 @@ func newTable(w io.Writer) *tablewriter.Table { }), ) } + +func CheckBackendInstalled(backend string) (bool, error) { + status := desktopClient.Status() + if status.Error != nil { + return false, fmt.Errorf("failed to get backend status: %w", status.Error) + } + + var backendStatus map[string]string + if err := json.Unmarshal(status.Status, &backendStatus); err != nil { + return false, fmt.Errorf("failed to parse backend status: %w", err) + } + + backendState, exists := backendStatus[backend] + if !exists { + return false, nil + } + + return backendState == "installed" || backendState == "running", nil +} + +func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) { + fmt.Fprintf(cmd.OutOrStdout(), "Backend %q is not installed. Download and install it now? [Y/n]: ", backend) + + reader := bufio.NewReader(os.Stdin) + input, err := reader.ReadString('\n') + if err != nil { + return false, fmt.Errorf("failed to read input: %w", err) + } + + input = strings.TrimSpace(strings.ToLower(input)) + return input == "" || input == "y" || input == "yes", nil +} + +func InstallBackend(backend string, cmd *cobra.Command) error { + installCmd := newInstallRunner() + installCmd.SetArgs([]string{"--backend", backend}) + + if err := installCmd.Execute(); err != nil { + return fmt.Errorf("failed to install backend %s: %w", backend, err) + } + + return nil +} + +func EnsureBackendAvailable(backend string, cmd *cobra.Command) error { + installed, err := CheckBackendInstalled(backend) + if err != nil { + return err + } + + if installed { + return nil + } + + confirm, err := PromptInstallBackend(backend, cmd) + if err != nil { + return err + } + + if !confirm { + cmd.Printf("Run 'docker model install-runner --backend %s' to install it manually.\n", backend) + return fmt.Errorf("backend installation cancelled") + } + + if err := InstallBackend(backend, cmd); err != nil { + return err + } + + cmd.Printf("Backend %q installed successfully.\n", backend) + return nil +} + +func GetRequiredBackend(model string) (string, error) { + modelInfo, err := desktopClient.Inspect(model, false) + if err != nil { + return "", err + } + + return GetRequiredBackendFromModelInfo(&modelInfo) +} + +func GetRequiredBackendFromModelInfo(modelInfo *dmrm.Model) (string, error) { + config, ok := modelInfo.Config.(*types.Config) + if !ok { + return llamacpp.Name, nil + } + + switch config.Format { + case types.FormatSafetensors: + return vllm.Name, nil + case types.FormatGGUF: + return llamacpp.Name, nil + case types.FormatDiffusers: + return "diffusers", nil + default: + return llamacpp.Name, nil + } +} From fcca922754b456c2f1e62a54a9cfb60e30205a8c Mon Sep 17 00:00:00 2001 From: Manish Biswal Date: Sat, 7 Mar 2026 21:45:33 +0530 Subject: [PATCH 2/4] feat(cli): prompt and install required backend during run --- cmd/cli/commands/run.go | 28 ++++++++++++++++++---------- cmd/cli/commands/utils.go | 15 ++++++++++++++- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 4c3a1aa4a..f3d7beb4e 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -809,20 +809,21 @@ func newRunCmd() *cobra.Command { return nil } - _, err := desktopClient.Inspect(model, false) - if err != nil { - if !errors.Is(err, desktop.ErrNotFound) { - return handleClientError(err, "Failed to inspect model") - } - cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.") - if err := pullModel(cmd, desktopClient, model); err != nil { - return err + modelInfo, err := desktopClient.Inspect(model, false) + modelFoundLocally := err == nil + if err != nil && !errors.Is(err, desktop.ErrNotFound) { + return handleClientError(err, "Failed to inspect model") + } + + if !modelFoundLocally { + remoteInfo, remoteErr := desktopClient.Inspect(model, true) + if remoteErr == nil { + modelInfo = remoteInfo } } - modelInfo, err := desktopClient.Inspect(model, true) backend := "" - if err == nil { + if modelInfo.ID != "" { backend, _ = GetRequiredBackendFromModelInfo(&modelInfo) } @@ -835,6 +836,13 @@ func newRunCmd() *cobra.Command { } } + if !modelFoundLocally { + cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.") + if err := pullModel(cmd, desktopClient, model); err != nil { + return err + } + } + // Handle --detach flag: just load the model without interaction if detach { if err := desktopClient.Preload(cmd.Context(), model); err != nil { diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go index 76bed532a..27d504823 100644 --- a/cmd/cli/commands/utils.go +++ b/cmd/cli/commands/utils.go @@ -292,7 +292,12 @@ func CheckBackendInstalled(backend string) (bool, error) { return false, nil } - return backendState == "installed" || backendState == "running", nil + state := strings.TrimSpace(strings.ToLower(backendState)) + if strings.HasPrefix(state, "not ") || strings.HasPrefix(state, "error") { + return false, nil + } + + return strings.HasPrefix(state, "installed") || strings.HasPrefix(state, "running"), nil } func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) { @@ -343,6 +348,14 @@ func EnsureBackendAvailable(backend string, cmd *cobra.Command) error { return err } + installed, err = CheckBackendInstalled(backend) + if err != nil { + return err + } + if !installed { + return fmt.Errorf("backend %q is still not installed; run 'docker model install-runner --backend %s'", backend, backend) + } + cmd.Printf("Backend %q installed successfully.\n", backend) return nil } From cd54d0b57b4871c9322e266ae3056199eabb62d2 Mon Sep 17 00:00:00 2001 From: Manish Biswal Date: Thu, 12 Mar 2026 00:41:14 +0530 Subject: [PATCH 3/4] refactor(cli): use desktop client backend install API --- cmd/cli/commands/utils.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go index 27d504823..967ef1363 100644 --- a/cmd/cli/commands/utils.go +++ b/cmd/cli/commands/utils.go @@ -313,11 +313,8 @@ func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) { return input == "" || input == "y" || input == "yes", nil } -func InstallBackend(backend string, cmd *cobra.Command) error { - installCmd := newInstallRunner() - installCmd.SetArgs([]string{"--backend", backend}) - - if err := installCmd.Execute(); err != nil { +func InstallBackend(backend string) error { + if err := desktopClient.InstallBackend(backend); err != nil { return fmt.Errorf("failed to install backend %s: %w", backend, err) } @@ -344,7 +341,7 @@ func EnsureBackendAvailable(backend string, cmd *cobra.Command) error { return fmt.Errorf("backend installation cancelled") } - if err := InstallBackend(backend, cmd); err != nil { + if err := InstallBackend(backend); err != nil { return err } From 82ea264f933fc046d4e5749d31e660f02853eeda Mon Sep 17 00:00:00 2001 From: Manish Biswal Date: Sat, 11 Apr 2026 18:51:31 +0530 Subject: [PATCH 4/4] use server side backend selection instead of duplicating logic in CLI Remove the CLI side GetRequiredBackendFromModelInfo mapping and delegate backend resolution to the server via a new GET /models/backend?model= endpoint. The endpoint reuses Scheduler.selectBackendForModel so the platform aware fallback chain (vLLM, MLX, SGLang, diffusers) stays in one place. Also address earlier review feedback: sentinel error for cancelled install, cmd.InOrStdin for prompt input, and unit tests for the new helpers. --- cmd/cli/commands/run.go | 13 ++----- cmd/cli/commands/utils.go | 23 +++---------- cmd/cli/commands/utils_test.go | 44 +++++++++++++----------- cmd/cli/desktop/desktop.go | 30 ++++++++++++++++ pkg/inference/scheduling/api.go | 6 ++++ pkg/inference/scheduling/http_handler.go | 25 ++++++++++++++ pkg/inference/scheduling/scheduler.go | 26 ++++++++++++++ pkg/routing/router.go | 1 + 8 files changed, 120 insertions(+), 48 deletions(-) diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index e12830d95..d9d2d890e 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -828,22 +828,15 @@ func newRunCmd() *cobra.Command { return nil } - modelInfo, err := desktopClient.Inspect(model, false) + _, err := desktopClient.Inspect(model, false) modelFoundLocally := err == nil if err != nil && !errors.Is(err, desktop.ErrNotFound) { return handleClientError(err, "Failed to inspect model") } - if !modelFoundLocally { - remoteInfo, remoteErr := desktopClient.Inspect(model, true) - if remoteErr == nil { - modelInfo = remoteInfo - } - } - backend := "" - if modelInfo.ID != "" { - backend, _ = GetRequiredBackendFromModelInfo(&modelInfo) + if resolvedBackend, resolveErr := ResolveRequiredBackend(model); resolveErr == nil { + backend = resolvedBackend } if backend != "" { diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go index 431bf7129..c96fc46c5 100644 --- a/cmd/cli/commands/utils.go +++ b/cmd/cli/commands/utils.go @@ -14,11 +14,7 @@ import ( "github.com/docker/model-runner/cmd/cli/pkg/standalone" "github.com/docker/model-runner/pkg/distribution/distribution" "github.com/docker/model-runner/pkg/distribution/oci/reference" - "github.com/docker/model-runner/pkg/distribution/types" - "github.com/docker/model-runner/pkg/inference/backends/diffusers" - "github.com/docker/model-runner/pkg/inference/backends/llamacpp" "github.com/docker/model-runner/pkg/inference/backends/vllm" - dmrm "github.com/docker/model-runner/pkg/inference/models" "github.com/moby/term" "github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter/renderer" @@ -367,22 +363,13 @@ func EnsureBackendAvailable(backend string, cmd *cobra.Command) error { return nil } -func GetRequiredBackendFromModelInfo(modelInfo *dmrm.Model) (string, error) { - config, ok := modelInfo.Config.(*types.Config) - if !ok { - return llamacpp.Name, nil +func ResolveRequiredBackend(model string) (string, error) { + selection, err := desktopClient.ResolveModelBackend(model) + if err != nil { + return "", err } - switch config.Format { - case types.FormatSafetensors: - return vllm.Name, nil - case types.FormatGGUF: - return llamacpp.Name, nil - case types.FormatDiffusers: - return diffusers.Name, nil - default: - return llamacpp.Name, nil - } + return selection.Backend, nil } func printNextSteps(out io.Writer, messages []string) { diff --git a/cmd/cli/commands/utils_test.go b/cmd/cli/commands/utils_test.go index cb62ff4ef..b9eaf24d7 100644 --- a/cmd/cli/commands/utils_test.go +++ b/cmd/cli/commands/utils_test.go @@ -7,17 +7,15 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" "testing" "github.com/docker/model-runner/cmd/cli/desktop" mockdesktop "github.com/docker/model-runner/cmd/cli/mocks" - "github.com/docker/model-runner/pkg/distribution/types" "github.com/docker/model-runner/pkg/inference" - "github.com/docker/model-runner/pkg/inference/backends/diffusers" - "github.com/docker/model-runner/pkg/inference/backends/llamacpp" "github.com/docker/model-runner/pkg/inference/backends/vllm" - dmrm "github.com/docker/model-runner/pkg/inference/models" + "github.com/docker/model-runner/pkg/inference/scheduling" "github.com/spf13/cobra" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -219,22 +217,28 @@ func TestEnsureBackendAvailableCancelled(t *testing.T) { require.Contains(t, out.String(), "docker model install-runner --backend vllm") } -func TestGetRequiredBackendFromModelInfo(t *testing.T) { - t.Run("safetensors chooses vllm", func(t *testing.T) { - backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatSafetensors}}) - require.NoError(t, err) - require.Equal(t, vllm.Name, backend) - }) +func TestResolveRequiredBackend(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - t.Run("gguf chooses llamacpp", func(t *testing.T) { - backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatGGUF}}) - require.NoError(t, err) - require.Equal(t, llamacpp.Name, backend) - }) + client := mockdesktop.NewMockDockerHttpClient(ctrl) + modelRunner = desktop.NewContextForMock(client) + desktopClient = desktop.New(modelRunner) - t.Run("diffusers chooses diffusers backend", func(t *testing.T) { - backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatDiffusers}}) - require.NoError(t, err) - require.Equal(t, diffusers.Name, backend) - }) + model := "ai/functiongemma-vllm:270M" + selection := scheduling.ModelBackendSelection{Backend: vllm.Name, Installed: false} + body, err := json.Marshal(selection) + require.NoError(t, err) + + expectedResolveURL := modelRunner.URL(inference.ModelsPrefix + "/backend?model=" + url.QueryEscape(model)) + expectedUserAgent := "docker-model-cli/" + desktop.Version + + client.EXPECT().Do(gomock.Cond(func(req any) bool { + r, ok := req.(*http.Request) + return ok && r.URL.String() == expectedResolveURL && r.Header.Get("User-Agent") == expectedUserAgent + })).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(body))}, nil) + + backend, err := ResolveRequiredBackend(model) + require.NoError(t, err) + require.Equal(t, vllm.Name, backend) } diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 95a069e05..4fa1a9905 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -362,6 +362,36 @@ func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) { return modelInspect, nil } +func (c *Client) ResolveModelBackend(model string) (scheduling.ModelBackendSelection, error) { + resolvePath := fmt.Sprintf("%s/backend?model=%s", inference.ModelsPrefix, url.QueryEscape(model)) + + resp, err := c.doRequest(http.MethodGet, resolvePath, nil) + if err != nil { + return scheduling.ModelBackendSelection{}, c.handleQueryError(err, resolvePath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusNotFound { + return scheduling.ModelBackendSelection{}, errors.Wrap(ErrNotFound, model) + } + body, _ := io.ReadAll(resp.Body) + return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to resolve model backend: %s: %s", resp.Status, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to read response body: %w", err) + } + + var selection scheduling.ModelBackendSelection + if err := json.Unmarshal(body, &selection); err != nil { + return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to unmarshal response body: %w", err) + } + + return selection, nil +} + func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) { modelsRoute := c.modelRunner.OpenAIPathPrefix() + "/models" rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", modelsRoute, model), model) diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go index 2dd11df46..f54454a34 100644 --- a/pkg/inference/scheduling/api.go +++ b/pkg/inference/scheduling/api.go @@ -119,3 +119,9 @@ type ModelConfigEntry struct { Mode inference.BackendMode Config inference.BackendConfiguration } + +// ModelBackendSelection describes the backend selected by the scheduler for a model. +type ModelBackendSelection struct { + Backend string `json:"backend"` + Installed bool `json:"installed"` +} diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go index 769bdc7ec..29fe2a481 100644 --- a/pkg/inference/scheduling/http_handler.go +++ b/pkg/inference/scheduling/http_handler.go @@ -119,6 +119,7 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { m["GET "+inference.InferencePrefix+"/status"] = h.GetBackendStatus m["GET "+inference.InferencePrefix+"/ps"] = h.GetRunningBackends m["GET "+inference.InferencePrefix+"/df"] = h.GetDiskUsage + m["GET "+inference.ModelsPrefix+"/backend"] = h.GetModelBackend m["POST "+inference.InferencePrefix+"/unload"] = h.Unload m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = h.Configure m["POST "+inference.InferencePrefix+"/_configure"] = h.Configure @@ -540,6 +541,30 @@ func (h *HTTPHandler) GetModelConfigs(w http.ResponseWriter, r *http.Request) { } } +// GetModelBackend resolves the backend selected by the scheduler for the provided model. +func (h *HTTPHandler) GetModelBackend(w http.ResponseWriter, r *http.Request) { + modelRef := r.URL.Query().Get("model") + if modelRef == "" { + http.Error(w, "model is required", http.StatusBadRequest) + return + } + + backend, err := h.scheduler.ResolveBackendForModel(r.Context(), modelRef) + if err != nil { + http.Error(w, fmt.Sprintf("failed to resolve backend: %v", err), http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(ModelBackendSelection{ + Backend: backend.Name(), + Installed: h.scheduler.installer.isInstalled(backend.Name()), + }); err != nil { + http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError) + return + } +} + // ServeHTTP implements net/http.Handler.ServeHTTP. func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.lock.RLock() diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 7e8087baf..93ac79d61 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -142,6 +142,10 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B format = inferFormatFromModel(model) } + return s.selectBackendForFormat(format, backend, modelRef) +} + +func (s *Scheduler) selectBackendForFormat(format types.Format, backend inference.Backend, modelRef string) inference.Backend { switch format { case types.FormatSafetensors: // Prefer vLLM for safetensors models (handles platform dispatch internally) @@ -211,6 +215,28 @@ func inferFormatFromModel(model types.Model) types.Format { return "" } +// ResolveBackendForModel resolves the backend that should be used for a model reference. +// It prefers local model metadata and falls back to remote metadata when needed. +func (s *Scheduler) ResolveBackendForModel(ctx context.Context, modelRef string) (inference.Backend, error) { + backend := s.defaultBackend + + if model, err := s.modelManager.GetLocal(modelRef); err == nil { + return s.selectBackendForModel(model, backend, modelRef), nil + } + + artifact, err := s.modelManager.GetRemote(ctx, modelRef) + if err != nil { + return nil, err + } + + config, err := artifact.Config() + if err != nil { + return nil, err + } + + return s.selectBackendForFormat(config.GetFormat(), backend, modelRef), nil +} + // ResetInstaller resets the backend installer with a new HTTP client. func (s *Scheduler) ResetInstaller(httpClient *http.Client) { s.installer = newInstaller(s.log, s.backends, httpClient, s.deferredBackends) diff --git a/pkg/routing/router.go b/pkg/routing/router.go index f5249237f..164663c15 100644 --- a/pkg/routing/router.go +++ b/pkg/routing/router.go @@ -49,6 +49,7 @@ func NewRouter(cfg RouterConfig) *NormalizedServeMux { if cfg.ModelHandlerMiddleware != nil { modelEndpoint = cfg.ModelHandlerMiddleware(cfg.ModelHandler) } + router.Handle(inference.ModelsPrefix+"/backend", cfg.SchedulerHTTP) router.Handle(inference.ModelsPrefix, modelEndpoint) router.Handle(inference.ModelsPrefix+"/", modelEndpoint)