Skip to content
22 changes: 19 additions & 3 deletions cmd/cli/commands/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,10 +829,26 @@ func newRunCmd() *cobra.Command {
}

_, err := desktopClient.Inspect(model, false)
if err != nil {
if !errors.Is(err, desktop.ErrNotFound) {
return handleClientError(err, "Failed to inspect model")
modelFoundLocally := err == nil
if err != nil && !errors.Is(err, desktop.ErrNotFound) {
return handleClientError(err, "Failed to inspect model")
}

backend := ""
if resolvedBackend, resolveErr := ResolveRequiredBackend(model); resolveErr == nil {
backend = resolvedBackend
}

if backend != "" {
if err := EnsureBackendAvailable(backend, cmd); err != nil {
if errors.Is(err, errBackendInstallationCancelled) {
return nil
}
return err
}
}

if !modelFoundLocally {
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
if err := pullModel(cmd, desktopClient, model); err != nil {
return err
Expand Down
94 changes: 94 additions & 0 deletions cmd/cli/commands/utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package commands

import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -43,6 +45,8 @@ func getDefaultRegistry() string {

var errNotRunning = fmt.Errorf("Docker Model Runner is not running. Please start it and try again.\n")

var errBackendInstallationCancelled = errors.New("backend installation cancelled")

func handleClientError(err error, message string) error {
if errors.Is(err, desktop.ErrServiceUnavailable) {
err = errNotRunning
Expand Down Expand Up @@ -278,6 +282,96 @@ func newTable(w io.Writer) *tablewriter.Table {
)
}

func CheckBackendInstalled(backend string) (bool, error) {
status := desktopClient.Status()
if status.Error != nil {
return false, fmt.Errorf("failed to get backend status: %w", status.Error)
}

var backendStatus map[string]string
if err := json.Unmarshal(status.Status, &backendStatus); err != nil {
return false, fmt.Errorf("failed to parse backend status: %w", err)
}

backendState, exists := backendStatus[backend]
if !exists {
return false, nil
}

state := strings.TrimSpace(strings.ToLower(backendState))
if strings.HasPrefix(state, "not ") || strings.HasPrefix(state, "error") {
return false, nil
}

return strings.HasPrefix(state, "installed") || strings.HasPrefix(state, "running"), nil
}

func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) {
fmt.Fprintf(cmd.OutOrStdout(), "Backend %q is not installed. Download and install it now? [Y/n]: ", backend)

reader := bufio.NewReader(cmd.InOrStdin())
input, err := reader.ReadString('\n')
if err != nil {
return false, fmt.Errorf("failed to read input: %w", err)
}

input = strings.TrimSpace(strings.ToLower(input))
return input == "" || input == "y" || input == "yes", nil
}

func InstallBackend(backend string) error {
if err := desktopClient.InstallBackend(backend); err != nil {
return fmt.Errorf("failed to install backend %s: %w", backend, err)
}

return nil
}

func EnsureBackendAvailable(backend string, cmd *cobra.Command) error {
installed, err := CheckBackendInstalled(backend)
if err != nil {
return err
}

if installed {
return nil
}

confirm, err := PromptInstallBackend(backend, cmd)
if err != nil {
return err
}

if !confirm {
cmd.Printf("Run 'docker model install-runner --backend %s' to install it manually.\n", backend)
return errBackendInstallationCancelled
}

if err := InstallBackend(backend); err != nil {
return err
}

installed, err = CheckBackendInstalled(backend)
if err != nil {
return err
}
if !installed {
return fmt.Errorf("backend %q is still not installed; run 'docker model install-runner --backend %s'", backend, backend)
}

cmd.Printf("Backend %q installed successfully.\n", backend)
return nil
}

func ResolveRequiredBackend(model string) (string, error) {
selection, err := desktopClient.ResolveModelBackend(model)
if err != nil {
return "", err
}

return selection.Backend, nil
}

func printNextSteps(out io.Writer, messages []string) {
if len(messages) == 0 {
return
Expand Down
130 changes: 130 additions & 0 deletions cmd/cli/commands/utils_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
package commands

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"

"github.com/docker/model-runner/cmd/cli/desktop"
mockdesktop "github.com/docker/model-runner/cmd/cli/mocks"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/backends/vllm"
"github.com/docker/model-runner/pkg/inference/scheduling"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func TestStripDefaultsFromModelName(t *testing.T) {
Expand Down Expand Up @@ -112,3 +127,118 @@ func TestHandleClientErrorFormat(t *testing.T) {
}
})
}

func setupDesktopClientStatusMock(t *testing.T, ctrl *gomock.Controller, backendStatus map[string]string) {
t.Helper()

client := mockdesktop.NewMockDockerHttpClient(ctrl)
modelRunner = desktop.NewContextForMock(client)
desktopClient = desktop.New(modelRunner)

statusJSON, err := json.Marshal(backendStatus)
require.NoError(t, err)

expectedModelsURL := modelRunner.URL(inference.ModelsPrefix)
expectedStatusURL := modelRunner.URL(inference.InferencePrefix + "/status")
expectedUserAgent := "docker-model-cli/" + desktop.Version

client.EXPECT().Do(gomock.Cond(func(req any) bool {
r, ok := req.(*http.Request)
return ok && r.URL.String() == expectedModelsURL && r.Header.Get("User-Agent") == expectedUserAgent
})).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("{}"))}, nil)

client.EXPECT().Do(gomock.Cond(func(req any) bool {
r, ok := req.(*http.Request)
return ok && r.URL.String() == expectedStatusURL && r.Header.Get("User-Agent") == expectedUserAgent
})).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(statusJSON))}, nil)
}

func TestCheckBackendInstalled(t *testing.T) {
t.Run("running status string is treated as installed", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "running vllm latest-cuda"})

installed, err := CheckBackendInstalled(vllm.Name)
require.NoError(t, err)
require.True(t, installed)
})

t.Run("not running status is treated as missing", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "not running"})

installed, err := CheckBackendInstalled(vllm.Name)
require.NoError(t, err)
require.False(t, installed)
})

t.Run("error status is treated as missing", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "error failed to start"})

installed, err := CheckBackendInstalled(vllm.Name)
require.NoError(t, err)
require.False(t, installed)
})
}

func TestPromptInstallBackend(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.SetIn(strings.NewReader("yes\n"))
out := new(bytes.Buffer)
cmd.SetOut(out)

confirmed, err := PromptInstallBackend(vllm.Name, cmd)
require.NoError(t, err)
require.True(t, confirmed)
require.Contains(t, out.String(), "Backend \"vllm\" is not installed")
}

func TestEnsureBackendAvailableCancelled(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "not running"})

cmd := &cobra.Command{Use: "test"}
cmd.SetIn(strings.NewReader("n\n"))
out := new(bytes.Buffer)
cmd.SetOut(out)

err := EnsureBackendAvailable(vllm.Name, cmd)
require.Error(t, err)
require.ErrorIs(t, err, errBackendInstallationCancelled)
require.Contains(t, out.String(), "docker model install-runner --backend vllm")
}

func TestResolveRequiredBackend(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

client := mockdesktop.NewMockDockerHttpClient(ctrl)
modelRunner = desktop.NewContextForMock(client)
desktopClient = desktop.New(modelRunner)

model := "ai/functiongemma-vllm:270M"
selection := scheduling.ModelBackendSelection{Backend: vllm.Name, Installed: false}
body, err := json.Marshal(selection)
require.NoError(t, err)

expectedResolveURL := modelRunner.URL(inference.ModelsPrefix + "/backend?model=" + url.QueryEscape(model))
expectedUserAgent := "docker-model-cli/" + desktop.Version

client.EXPECT().Do(gomock.Cond(func(req any) bool {
r, ok := req.(*http.Request)
return ok && r.URL.String() == expectedResolveURL && r.Header.Get("User-Agent") == expectedUserAgent
})).Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(body))}, nil)

backend, err := ResolveRequiredBackend(model)
require.NoError(t, err)
require.Equal(t, vllm.Name, backend)
}
30 changes: 30 additions & 0 deletions cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,36 @@ func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) {
return modelInspect, nil
}

func (c *Client) ResolveModelBackend(model string) (scheduling.ModelBackendSelection, error) {
resolvePath := fmt.Sprintf("%s/backend?model=%s", inference.ModelsPrefix, url.QueryEscape(model))

resp, err := c.doRequest(http.MethodGet, resolvePath, nil)
if err != nil {
return scheduling.ModelBackendSelection{}, c.handleQueryError(err, resolvePath)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusNotFound {
return scheduling.ModelBackendSelection{}, errors.Wrap(ErrNotFound, model)
}
body, _ := io.ReadAll(resp.Body)
return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to resolve model backend: %s: %s", resp.Status, string(body))
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to read response body: %w", err)
}

var selection scheduling.ModelBackendSelection
if err := json.Unmarshal(body, &selection); err != nil {
return scheduling.ModelBackendSelection{}, fmt.Errorf("failed to unmarshal response body: %w", err)
}

return selection, nil
}

func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) {
modelsRoute := c.modelRunner.OpenAIPathPrefix() + "/models"
rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", modelsRoute, model), model)
Expand Down
6 changes: 6 additions & 0 deletions pkg/inference/scheduling/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,9 @@ type ModelConfigEntry struct {
Mode inference.BackendMode
Config inference.BackendConfiguration
}

// ModelBackendSelection describes the backend selected by the scheduler for a model.
type ModelBackendSelection struct {
Backend string `json:"backend"`
Installed bool `json:"installed"`
}
25 changes: 25 additions & 0 deletions pkg/inference/scheduling/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc {
m["GET "+inference.InferencePrefix+"/status"] = h.GetBackendStatus
m["GET "+inference.InferencePrefix+"/ps"] = h.GetRunningBackends
m["GET "+inference.InferencePrefix+"/df"] = h.GetDiskUsage
m["GET "+inference.ModelsPrefix+"/backend"] = h.GetModelBackend
m["POST "+inference.InferencePrefix+"/unload"] = h.Unload
m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = h.Configure
m["POST "+inference.InferencePrefix+"/_configure"] = h.Configure
Expand Down Expand Up @@ -540,6 +541,30 @@ func (h *HTTPHandler) GetModelConfigs(w http.ResponseWriter, r *http.Request) {
}
}

// GetModelBackend resolves the backend selected by the scheduler for the provided model.
func (h *HTTPHandler) GetModelBackend(w http.ResponseWriter, r *http.Request) {
modelRef := r.URL.Query().Get("model")
if modelRef == "" {
http.Error(w, "model is required", http.StatusBadRequest)
return
}

backend, err := h.scheduler.ResolveBackendForModel(r.Context(), modelRef)
if err != nil {
http.Error(w, fmt.Sprintf("failed to resolve backend: %v", err), http.StatusNotFound)
return
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(ModelBackendSelection{
Backend: backend.Name(),
Installed: h.scheduler.installer.isInstalled(backend.Name()),
}); err != nil {
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.lock.RLock()
Expand Down
Loading