diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 7ec76d887..d9d2d890e 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -829,10 +829,26 @@ func newRunCmd() *cobra.Command { } _, err := desktopClient.Inspect(model, false) - if err != nil { - if !errors.Is(err, desktop.ErrNotFound) { - return handleClientError(err, "Failed to inspect model") + modelFoundLocally := err == nil + if err != nil && !errors.Is(err, desktop.ErrNotFound) { + return handleClientError(err, "Failed to inspect model") + } + + backend := "" + if resolvedBackend, resolveErr := ResolveRequiredBackend(model); resolveErr == nil { + backend = resolvedBackend + } + + if backend != "" { + if err := EnsureBackendAvailable(backend, cmd); err != nil { + if errors.Is(err, errBackendInstallationCancelled) { + return nil + } + return err } + } + + if !modelFoundLocally { cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.") if err := pullModel(cmd, desktopClient, model); err != nil { return err diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go index fbe1a932e..c96fc46c5 100644 --- a/cmd/cli/commands/utils.go +++ b/cmd/cli/commands/utils.go @@ -1,7 +1,9 @@ package commands import ( + "bufio" "bytes" + "encoding/json" "errors" "fmt" "io" @@ -43,6 +45,8 @@ func getDefaultRegistry() string { var errNotRunning = fmt.Errorf("Docker Model Runner is not running. Please start it and try again.\n") +var errBackendInstallationCancelled = errors.New("backend installation cancelled") + func handleClientError(err error, message string) error { if errors.Is(err, desktop.ErrServiceUnavailable) { err = errNotRunning @@ -278,6 +282,96 @@ func newTable(w io.Writer) *tablewriter.Table { ) } +func CheckBackendInstalled(backend string) (bool, error) { + status := desktopClient.Status() + if status.Error != nil { + return false, fmt.Errorf("failed to get backend status: %w", status.Error) + } + + var backendStatus map[string]string + if err := json.Unmarshal(status.Status, &backendStatus); err != nil { + return false, fmt.Errorf("failed to parse backend status: %w", err) + } + + backendState, exists := backendStatus[backend] + if !exists { + return false, nil + } + + state := strings.TrimSpace(strings.ToLower(backendState)) + if strings.HasPrefix(state, "not ") || strings.HasPrefix(state, "error") { + return false, nil + } + + return strings.HasPrefix(state, "installed") || strings.HasPrefix(state, "running"), nil +} + +func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) { + fmt.Fprintf(cmd.OutOrStdout(), "Backend %q is not installed. Download and install it now? [Y/n]: ", backend) + + reader := bufio.NewReader(cmd.InOrStdin()) + input, err := reader.ReadString('\n') + if err != nil { + return false, fmt.Errorf("failed to read input: %w", err) + } + + input = strings.TrimSpace(strings.ToLower(input)) + return input == "" || input == "y" || input == "yes", nil +} + +func InstallBackend(backend string) error { + if err := desktopClient.InstallBackend(backend); err != nil { + return fmt.Errorf("failed to install backend %s: %w", backend, err) + } + + return nil +} + +func EnsureBackendAvailable(backend string, cmd *cobra.Command) error { + installed, err := CheckBackendInstalled(backend) + if err != nil { + return err + } + + if installed { + return nil + } + + confirm, err := PromptInstallBackend(backend, cmd) + if err != nil { + return err + } + + if !confirm { + cmd.Printf("Run 'docker model install-runner --backend %s' to install it manually.\n", backend) + return errBackendInstallationCancelled + } + + if err := InstallBackend(backend); err != nil { + return err + } + + installed, err = CheckBackendInstalled(backend) + if err != nil { + return err + } + if !installed { + return fmt.Errorf("backend %q is still not installed; run 'docker model install-runner --backend %s'", backend, backend) + } + + cmd.Printf("Backend %q installed successfully.\n", backend) + return nil +} + +func ResolveRequiredBackend(model string) (string, error) { + selection, err := desktopClient.ResolveModelBackend(model) + if err != nil { + return "", err + } + + return selection.Backend, nil +} + func printNextSteps(out io.Writer, messages []string) { if len(messages) == 0 { return diff --git a/cmd/cli/commands/utils_test.go b/cmd/cli/commands/utils_test.go index 0433fc06f..b9eaf24d7 100644 --- a/cmd/cli/commands/utils_test.go +++ b/cmd/cli/commands/utils_test.go @@ -1,9 +1,24 @@ package commands import ( + "bytes" + "encoding/json" "errors" "fmt" + "io" + "net/http" + "net/url" + "strings" "testing" + + "github.com/docker/model-runner/cmd/cli/desktop" + mockdesktop "github.com/docker/model-runner/cmd/cli/mocks" + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/backends/vllm" + "github.com/docker/model-runner/pkg/inference/scheduling" + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" ) func TestStripDefaultsFromModelName(t *testing.T) { @@ -112,3 +127,118 @@ func TestHandleClientErrorFormat(t *testing.T) { } }) } + +func setupDesktopClientStatusMock(t *testing.T, ctrl *gomock.Controller, backendStatus map[string]string) { + t.Helper() + + client := mockdesktop.NewMockDockerHttpClient(ctrl) + modelRunner = desktop.NewContextForMock(client) + desktopClient = desktop.New(modelRunner) + + statusJSON, err := json.Marshal(backendStatus) + require.NoError(t, err) + + expectedModelsURL := modelRunner.URL(inference.ModelsPrefix) + expectedStatusURL := modelRunner.URL(inference.InferencePrefix + "/status") + expectedUserAgent := "docker-model-cli/" + desktop.Version + + client.EXPECT().Do(gomock.Cond(func(req any) bool { + r, ok := req.(*http.Request) + return ok && r.URL.String() == expectedModelsURL && r.Header.Get("User-Agent") == expectedUserAgent + })).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("{}"))}, nil) + + client.EXPECT().Do(gomock.Cond(func(req any) bool { + r, ok := req.(*http.Request) + return ok && r.URL.String() == expectedStatusURL && r.Header.Get("User-Agent") == expectedUserAgent + })).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(statusJSON))}, nil) +} + +func TestCheckBackendInstalled(t *testing.T) { + t.Run("running status string is treated as installed", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "running vllm latest-cuda"}) + + installed, err := CheckBackendInstalled(vllm.Name) + require.NoError(t, err) + require.True(t, installed) + }) + + t.Run("not running status is treated as missing", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "not running"}) + + installed, err := CheckBackendInstalled(vllm.Name) + require.NoError(t, err) + require.False(t, installed) + }) + + t.Run("error status is treated as missing", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "error failed to start"}) + + installed, err := CheckBackendInstalled(vllm.Name) + require.NoError(t, err) + require.False(t, installed) + }) +} + +func TestPromptInstallBackend(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.SetIn(strings.NewReader("yes\n")) + out := new(bytes.Buffer) + cmd.SetOut(out) + + confirmed, err := PromptInstallBackend(vllm.Name, cmd) + require.NoError(t, err) + require.True(t, confirmed) + require.Contains(t, out.String(), "Backend \"vllm\" is not installed") +} + +func TestEnsureBackendAvailableCancelled(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "not running"}) + + cmd := &cobra.Command{Use: "test"} + cmd.SetIn(strings.NewReader("n\n")) + out := new(bytes.Buffer) + cmd.SetOut(out) + + err := EnsureBackendAvailable(vllm.Name, cmd) + require.Error(t, err) + require.ErrorIs(t, err, errBackendInstallationCancelled) + require.Contains(t, out.String(), "docker model install-runner --backend vllm") +} + +func TestResolveRequiredBackend(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + client := mockdesktop.NewMockDockerHttpClient(ctrl) + modelRunner = desktop.NewContextForMock(client) + desktopClient = desktop.New(modelRunner) + + model := "ai/functiongemma-vllm:270M" + selection := scheduling.ModelBackendSelection{Backend: vllm.Name, Installed: false} + body, err := json.Marshal(selection) + require.NoError(t, err) + + expectedResolveURL := modelRunner.URL(inference.ModelsPrefix + "/backend?model=" + url.QueryEscape(model)) + expectedUserAgent := "docker-model-cli/" + desktop.Version + + client.EXPECT().Do(gomock.Cond(func(req any) bool { + r, ok := req.(*http.Request) + return ok && r.URL.String() == expectedResolveURL && r.Header.Get("User-Agent") == expectedUserAgent + })).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(body))}, nil) + + backend, err := ResolveRequiredBackend(model) + require.NoError(t, err) + require.Equal(t, vllm.Name, backend) +} diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 95a069e05..4fa1a9905 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -362,6 +362,36 @@ func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) { return modelInspect, nil } +func (c *Client) ResolveModelBackend(model string) (scheduling.ModelBackendSelection, error) { + resolvePath := fmt.Sprintf("%s/backend?model=%s", inference.ModelsPrefix, url.QueryEscape(model)) + + resp, err := c.doRequest(http.MethodGet, resolvePath, nil) + if err != nil { + return scheduling.ModelBackendSelection{}, c.handleQueryError(err, resolvePath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusNotFound { + return scheduling.ModelBackendSelection{}, errors.Wrap(ErrNotFound, model) + } + body, _ := io.ReadAll(resp.Body) + return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to resolve model backend: %s: %s", resp.Status, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to read response body: %w", err) + } + + var selection scheduling.ModelBackendSelection + if err := json.Unmarshal(body, &selection); err != nil { + return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to unmarshal response body: %w", err) + } + + return selection, nil +} + func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) { modelsRoute := c.modelRunner.OpenAIPathPrefix() + "/models" rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", modelsRoute, model), model) diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go index 2dd11df46..f54454a34 100644 --- a/pkg/inference/scheduling/api.go +++ b/pkg/inference/scheduling/api.go @@ -119,3 +119,9 @@ type ModelConfigEntry struct { Mode inference.BackendMode Config inference.BackendConfiguration } + +// ModelBackendSelection describes the backend selected by the scheduler for a model. +type ModelBackendSelection struct { + Backend string `json:"backend"` + Installed bool `json:"installed"` +} diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go index 769bdc7ec..29fe2a481 100644 --- a/pkg/inference/scheduling/http_handler.go +++ b/pkg/inference/scheduling/http_handler.go @@ -119,6 +119,7 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { m["GET "+inference.InferencePrefix+"/status"] = h.GetBackendStatus m["GET "+inference.InferencePrefix+"/ps"] = h.GetRunningBackends m["GET "+inference.InferencePrefix+"/df"] = h.GetDiskUsage + m["GET "+inference.ModelsPrefix+"/backend"] = h.GetModelBackend m["POST "+inference.InferencePrefix+"/unload"] = h.Unload m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = h.Configure m["POST "+inference.InferencePrefix+"/_configure"] = h.Configure @@ -540,6 +541,30 @@ func (h *HTTPHandler) GetModelConfigs(w http.ResponseWriter, r *http.Request) { } } +// GetModelBackend resolves the backend selected by the scheduler for the provided model. +func (h *HTTPHandler) GetModelBackend(w http.ResponseWriter, r *http.Request) { + modelRef := r.URL.Query().Get("model") + if modelRef == "" { + http.Error(w, "model is required", http.StatusBadRequest) + return + } + + backend, err := h.scheduler.ResolveBackendForModel(r.Context(), modelRef) + if err != nil { + http.Error(w, fmt.Sprintf("failed to resolve backend: %v", err), http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(ModelBackendSelection{ + Backend: backend.Name(), + Installed: h.scheduler.installer.isInstalled(backend.Name()), + }); err != nil { + http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError) + return + } +} + // ServeHTTP implements net/http.Handler.ServeHTTP. func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.lock.RLock() diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 7e8087baf..93ac79d61 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -142,6 +142,10 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B format = inferFormatFromModel(model) } + return s.selectBackendForFormat(format, backend, modelRef) +} + +func (s *Scheduler) selectBackendForFormat(format types.Format, backend inference.Backend, modelRef string) inference.Backend { switch format { case types.FormatSafetensors: // Prefer vLLM for safetensors models (handles platform dispatch internally) @@ -211,6 +215,28 @@ func inferFormatFromModel(model types.Model) types.Format { return "" } +// ResolveBackendForModel resolves the backend that should be used for a model reference. +// It prefers local model metadata and falls back to remote metadata when needed. +func (s *Scheduler) ResolveBackendForModel(ctx context.Context, modelRef string) (inference.Backend, error) { + backend := s.defaultBackend + + if model, err := s.modelManager.GetLocal(modelRef); err == nil { + return s.selectBackendForModel(model, backend, modelRef), nil + } + + artifact, err := s.modelManager.GetRemote(ctx, modelRef) + if err != nil { + return nil, err + } + + config, err := artifact.Config() + if err != nil { + return nil, err + } + + return s.selectBackendForFormat(config.GetFormat(), backend, modelRef), nil +} + // ResetInstaller resets the backend installer with a new HTTP client. func (s *Scheduler) ResetInstaller(httpClient *http.Client) { s.installer = newInstaller(s.log, s.backends, httpClient, s.deferredBackends) diff --git a/pkg/routing/router.go b/pkg/routing/router.go index f5249237f..164663c15 100644 --- a/pkg/routing/router.go +++ b/pkg/routing/router.go @@ -49,6 +49,7 @@ func NewRouter(cfg RouterConfig) *NormalizedServeMux { if cfg.ModelHandlerMiddleware != nil { modelEndpoint = cfg.ModelHandlerMiddleware(cfg.ModelHandler) } + router.Handle(inference.ModelsPrefix+"/backend", cfg.SchedulerHTTP) router.Handle(inference.ModelsPrefix, modelEndpoint) router.Handle(inference.ModelsPrefix+"/", modelEndpoint)