Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 11 additions & 31 deletions pkg/inference/backends/llamacpp/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ package llamacpp

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
Expand Down Expand Up @@ -50,7 +47,7 @@ func SetDesiredServerVersion(version string) {
}

//nolint:unused // Used in platform-specific files (download_darwin.go, download_windows.go)
func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logger,
llamaCppPath, vendoredServerStoragePath, desiredVersion, desiredVariant string,
) error {
ShouldUpdateServerLock.Lock()
Expand All @@ -63,35 +60,18 @@ func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logge

log.Info("downloadLatestLlamaCpp", "desiredVersion", desiredVersion, "desiredVariant", desiredVariant, "vendoredServerStoragePath", vendoredServerStoragePath, "llamaCppPath", llamaCppPath)
desiredTag := desiredVersion + "-" + desiredVariant
url := fmt.Sprintf("https://hub.docker.com/v2/namespaces/%s/repositories/%s/tags/%s", hubNamespace, hubRepo, desiredTag)
resp, err := httpClient.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
// Resolve the desired tag to a digest via the Registry HTTP API v2. This
// honors l.registryMirrors (typically a corporate Artifactory / Nexus /
// Harbor mirror configured for docker.io) and credentials populated by
// `docker login`, so customers behind a private mirror with no direct
// egress to registry-1.docker.io can still resolve and pull the backend
// image. See docker/model-runner#TBD.
tagRef := fmt.Sprintf("registry-1.docker.io/%s/%s:%s", hubNamespace, hubRepo, desiredTag)
latest, err := dockerhub.ResolveDigest(ctx, tagRef, l.registryMirrors)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}

// https://docs.docker.com/reference/api/hub/latest/#tag/repositories/paths/~1v2~1namespaces~1%7Bnamespace%7D~1repositories~1%7Brepository%7D~1tags~1%7Btag%7D/get
var response struct {
Name string `json:"name"`
Digest string `json:"digest"`
}

if unmarshalErr := json.Unmarshal(body, &response); unmarshalErr != nil {
return fmt.Errorf("failed to unmarshal response body: %w", unmarshalErr)
}

var latest string
if response.Name == desiredTag {
latest = response.Digest
}
if latest == "" {
log.Warn("could not find the tag", "tag", desiredTag, "response", body)
return fmt.Errorf("could not find the %s tag", desiredTag)
log.Warn("could not resolve llama.cpp tag", "tag", desiredTag, "mirrors", l.registryMirrors, "error", err)
return fmt.Errorf("could not resolve the %s tag: %w", desiredTag, err)
}

bundledVersionFile := filepath.Join(vendoredServerStoragePath, "com.docker.llama-server.digest")
Expand Down
4 changes: 2 additions & 2 deletions pkg/inference/backends/llamacpp/download_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
"github.com/docker/model-runner/pkg/logging"
)

func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, _ *http.Client,
llamaCppPath, vendoredServerStoragePath string,
) error {
desiredVersion := GetDesiredServerVersion()
desiredVariant := "metal"
return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
return l.downloadLatestLlamaCpp(ctx, log, llamaCppPath, vendoredServerStoragePath, desiredVersion,
desiredVariant)
}
4 changes: 2 additions & 2 deletions pkg/inference/backends/llamacpp/download_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/docker/model-runner/pkg/logging"
)

func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, _ *http.Client,
llamaCppPath, vendoredServerStoragePath string,
) error {
nvGPUInfoBin := filepath.Join(vendoredServerStoragePath, "com.docker.nv-gpu-info.exe")
Expand Down Expand Up @@ -43,6 +43,6 @@ func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger,
desiredVariant = "opencl"
}
l.status = inference.FormatInstalling(fmt.Sprintf("%s llama.cpp %s", inference.DetailCheckingForUpdates, desiredVariant))
return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
return l.downloadLatestLlamaCpp(ctx, log, llamaCppPath, vendoredServerStoragePath, desiredVersion,
desiredVariant)
}
76 changes: 70 additions & 6 deletions pkg/internal/dockerhub/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
Expand All @@ -16,7 +17,9 @@ import (
"github.com/containerd/containerd/v2/core/images/archive"
"github.com/containerd/containerd/v2/core/remotes"
"github.com/containerd/containerd/v2/core/remotes/docker"
remoteerrors "github.com/containerd/containerd/v2/core/remotes/errors"
"github.com/containerd/containerd/v2/plugins/content/local"
"github.com/containerd/errdefs"
"github.com/containerd/platforms"
"github.com/docker/model-runner/pkg/internal/jsonutil"
"github.com/docker/model-runner/pkg/internal/registryutil"
Expand All @@ -40,13 +43,49 @@ func PullPlatform(ctx context.Context, image, destination, requiredOs, requiredA
if err != nil {
return fmt.Errorf("creating new content store: %w", err)
}
desc, err := retry(ctx, 10, 1*time.Second, func() (*v1.Descriptor, error) { return fetch(ctx, store, image, requiredOs, requiredArch, mirrors) })
resolver := newResolver(mirrors)
desc, err := retry(ctx, 10, 1*time.Second, func() (*v1.Descriptor, error) {
return fetch(ctx, resolver, store, image, requiredOs, requiredArch)
})
if err != nil {
return fmt.Errorf("fetching image: %w", err)
}
return archive.Export(ctx, store, output, archive.WithManifest(*desc, image), archive.WithSkipMissing(store))
}

// ResolveDigest resolves the given image reference (e.g. "registry-1.docker.io/docker/foo:tag")
// against the registry (with optional mirrors tried first for Docker Hub references) and
// returns the resolved digest. It does not download any blobs; it issues only the manifest
// HEAD/GET that the registry resolver needs.
//
// Authentication uses the same credentials lookup as PullPlatform (env vars
// DOCKER_HUB_USER/DOCKER_HUB_PASSWORD or ~/.docker/config.json), so a prior
// `docker login <mirror-host>` is honored.
func ResolveDigest(ctx context.Context, ref string, mirrors []string) (string, error) {
resolver := newResolver(mirrors)
desc, err := retry(ctx, 10, 1*time.Second, func() (*v1.Descriptor, error) {
name, d, err := resolver.Resolve(ctx, ref)
if err != nil {
return nil, err
}
slog.Debug("resolved image tag", "ref", ref, "resolved", name, "digest", d.Digest.String())
return &d, nil
})
if err != nil {
return "", fmt.Errorf("resolving image %q: %w", ref, err)
}
return desc.Digest.String(), nil
}

// newResolver builds a containerd docker resolver that authenticates via
// dockerCredentials and tries the given mirrors before the upstream registry.
func newResolver(mirrors []string) remotes.Resolver {
authorizer := docker.NewDockerAuthorizer(docker.WithAuthCreds(dockerCredentials))
return docker.NewResolver(docker.ResolverOptions{
Hosts: registryutil.RegistryHosts(mirrors, authorizer, nil),
})
}

func retry(ctx context.Context, attempts int, sleep time.Duration, f func() (*v1.Descriptor, error)) (*v1.Descriptor, error) {
var err error
var result *v1.Descriptor
Expand All @@ -63,15 +102,40 @@ func retry(ctx context.Context, attempts int, sleep time.Duration, f func() (*v1
if err == nil {
return result, nil
}
if isTerminal(err) {
return nil, err
}
}
return nil, fmt.Errorf("after %d attempts, last error: %w", attempts, err)
}

func fetch(ctx context.Context, store content.Store, ref, requiredOs, requiredArch string, mirrors []string) (*v1.Descriptor, error) {
authorizer := docker.NewDockerAuthorizer(docker.WithAuthCreds(dockerCredentials))
resolver := docker.NewResolver(docker.ResolverOptions{
Hosts: registryutil.RegistryHosts(mirrors, authorizer, nil),
})
// isTerminal reports whether err is non-retryable: a missing tag/manifest, an
// authentication/authorization failure, or a canceled/expired context. Retrying
// these only wastes time, so the caller should fail fast instead of looping.
//
// The containerd resolver only maps 404 to errdefs.ErrNotFound; other 4xx
// statuses (including 401 and 403) surface as a remoteerrors.ErrUnexpectedStatus
// carrying the raw status code, so we inspect that explicitly. 429 is
// deliberately left retryable — the resolver already retries it internally and a
// later attempt can succeed once a rate limit clears.
func isTerminal(err error) bool {
if errdefs.IsNotFound(err) ||
errdefs.IsUnauthorized(err) ||
errors.Is(err, context.Canceled) ||
errors.Is(err, context.DeadlineExceeded) {
return true
}
var unexpected remoteerrors.ErrUnexpectedStatus
if errors.As(err, &unexpected) {
switch unexpected.StatusCode {
case http.StatusUnauthorized, http.StatusForbidden:
return true
}
}
return false
}
Comment thread
ilopezluna marked this conversation as resolved.

func fetch(ctx context.Context, resolver remotes.Resolver, store content.Store, ref, requiredOs, requiredArch string) (*v1.Descriptor, error) {
name, desc, err := resolver.Resolve(ctx, ref)
if err != nil {
return nil, err
Expand Down
177 changes: 177 additions & 0 deletions pkg/internal/dockerhub/download_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package dockerhub

import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"

remoteerrors "github.com/containerd/containerd/v2/core/remotes/errors"
"github.com/containerd/errdefs"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
)

// registryHandler is a minimal Docker Registry v2 HTTP handler that supports
// the manifest HEAD / GET requests issued by containerd's docker resolver.
type registryHandler struct {
// tag is the tag to recognize; for any other tag the handler returns 404.
tag string
// digest returned in the Docker-Content-Digest header.
digest string
// requests counts how many requests this handler received (for assertions).
requests atomic.Int64
}

func (h *registryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.requests.Add(1)
switch {
case r.URL.Path == "/v2/" || r.URL.Path == "/v2":
// API version probe.
w.Header().Set("Docker-Distribution-API-Version", "registry/2.0")
w.WriteHeader(http.StatusOK)
case strings.HasSuffix(r.URL.Path, "/manifests/"+h.tag):
// Manifest HEAD/GET for the recognized tag.
w.Header().Set("Docker-Content-Digest", h.digest)
w.Header().Set("Content-Type", "application/vnd.oci.image.index.v1+json")
body := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.oci.image.index.v1+json","manifests":[]}`)
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
if r.Method == http.MethodHead {
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(body)
default:
http.Error(w, "not found", http.StatusNotFound)
}
}

// TestResolveDigest_UsesMirror verifies that when a mirror is configured for
// Docker Hub references, the resolver issues its manifest lookup against the
// mirror rather than registry-1.docker.io. This is the path enterprise
// customers behind an Artifactory / Nexus / Harbor mirror need.
func TestResolveDigest_UsesMirror(t *testing.T) {
const wantDigest = "sha256:48883a67000000000000000000000000000000000000000000000000deadbeef"

mirror := &registryHandler{tag: "latest-cuda", digest: wantDigest}
srv := httptest.NewServer(mirror)
defer srv.Close()

ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()

// Reference points at registry-1.docker.io; the mirror should intercept it.
ref := "registry-1.docker.io/docker/docker-model-backend-llamacpp:latest-cuda"
got, err := ResolveDigest(ctx, ref, []string{srv.URL})
if err != nil {
t.Fatalf("ResolveDigest returned error: %v", err)
}
if got != wantDigest {
t.Fatalf("digest mismatch: got %q want %q", got, wantDigest)
}
if mirror.requests.Load() == 0 {
t.Fatalf("expected mirror to be called at least once, got 0 requests")
}
}

// TestResolveDigest_CanceledContext verifies the resolver does not block when
// the context is already canceled. This protects against silent stalls when
// the network path to the upstream/mirror is blackholed (a frequent symptom
// in enterprise networks).
func TestResolveDigest_CanceledContext(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
cancel()

// No mirror, no real network call should complete. We bound the test
// with a wall-clock deadline so a regression cannot hang CI. A canceled
// context is classified as terminal, so retry must not loop.
done := make(chan struct{})
var resolveErr error
go func() {
_, resolveErr = ResolveDigest(ctx, "registry-1.docker.io/docker/docker-model-backend-llamacpp:latest-cuda", nil)
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatalf("ResolveDigest did not return on canceled context within 5s")
}
if resolveErr == nil {
t.Fatalf("expected error on canceled context, got nil")
}
}

// TestRetry_FailsFastOnTerminalError verifies retry does not loop on a
// non-retryable error (e.g. a missing tag / 404). Before this, every error was
// retried 10 times with 1s sleeps (~9s), blocking the install/startup path.
func TestRetry_FailsFastOnTerminalError(t *testing.T) {
var calls int
_, err := retry(t.Context(), 10, time.Second, func() (*v1.Descriptor, error) {
calls++
return nil, errdefs.ErrNotFound
})
if err == nil {
t.Fatalf("expected error on terminal failure, got nil")
}
if calls != 1 {
t.Fatalf("expected exactly 1 attempt on a terminal error, got %d", calls)
}
}

// TestRetry_RetriesTransientError verifies retry still loops the full budget on
// an unclassified (transient) error, preserving the original behavior.
func TestRetry_RetriesTransientError(t *testing.T) {
var calls int
_, err := retry(t.Context(), 3, time.Millisecond, func() (*v1.Descriptor, error) {
calls++
return nil, errors.New("transient network blip")
})
if err == nil {
t.Fatalf("expected error after exhausting attempts, got nil")
}
if calls != 3 {
t.Fatalf("expected 3 attempts on a transient error, got %d", calls)
}
}

// TestRetry_FailsFastOnForbidden verifies that a 403 Forbidden — which the
// containerd resolver surfaces as remoteerrors.ErrUnexpectedStatus, not an
// errdefs sentinel — is treated as terminal and not retried. Same applies to
// 401 Unauthorized via the same path.
func TestRetry_FailsFastOnForbidden(t *testing.T) {
for _, code := range []int{http.StatusUnauthorized, http.StatusForbidden} {
var calls int
_, err := retry(t.Context(), 10, time.Second, func() (*v1.Descriptor, error) {
calls++
return nil, remoteerrors.ErrUnexpectedStatus{StatusCode: code}
})
if err == nil {
t.Fatalf("status %d: expected error, got nil", code)
}
if calls != 1 {
t.Fatalf("status %d: expected exactly 1 attempt, got %d", code, calls)
}
}
}

// TestRetry_RetriesRateLimited verifies that a 429 Too Many Requests is NOT
// terminal: a later attempt can succeed once the rate limit clears, so retry
// must keep looping.
func TestRetry_RetriesRateLimited(t *testing.T) {
var calls int
_, err := retry(t.Context(), 3, time.Millisecond, func() (*v1.Descriptor, error) {
calls++
return nil, remoteerrors.ErrUnexpectedStatus{StatusCode: http.StatusTooManyRequests}
})
if err == nil {
t.Fatalf("expected error after exhausting attempts, got nil")
}
if calls != 3 {
t.Fatalf("expected 3 attempts on a rate-limit error, got %d", calls)
}
}
Loading