Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 9 additions & 37 deletions ociauth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,8 @@ func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) {

ctx := req.Context()
requiredScope := RequestInfoFromContext(ctx).RequiredScope
wantScope := ScopeFromContext(ctx)

if err := r.setAuthorization(ctx, req, requiredScope, wantScope); err != nil {
if err := r.setAuthorization(ctx, req, requiredScope); err != nil {
return nil, err
}
resp, err := r.transport.RoundTrip(req)
Expand All @@ -171,7 +170,7 @@ func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if challenge == nil {
return resp, nil
}
authAdded, tokenAcquired, err := r.setAuthorizationFromChallenge(ctx, req, challenge, requiredScope, wantScope)
authAdded, tokenAcquired, err := r.setAuthorizationFromChallenge(ctx, req, challenge)
if err != nil {
resp.Body.Close()
return nil, err
Expand Down Expand Up @@ -218,7 +217,7 @@ func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) {

// setAuthorization sets up authorization on the given request using any
// auth information currently available.
func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requiredScope, wantScope Scope) error {
func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requiredScope Scope) error {
r.mu.Lock()
defer r.mu.Unlock()
// Remove tokens that have expired or will expire soon so that
Expand Down Expand Up @@ -247,7 +246,7 @@ func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requ
// acquiring several tokens concurrently. We should relax the lock
// to allow that.

accessToken, err := r.acquireAccessToken(ctx, requiredScope, wantScope)
accessToken, err := r.acquireAccessToken(ctx, requiredScope)
if err != nil {
// Avoid using %w to wrap the error because we don't want the
// caller of RoundTrip (usually ociclient) to assume that the
Expand All @@ -264,15 +263,15 @@ func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requ
return nil
}

func (r *registry) setAuthorizationFromChallenge(ctx context.Context, req *http.Request, challenge *authHeader, requiredScope, wantScope Scope) (authAdded, tokenAcquired bool, _ error) {
func (r *registry) setAuthorizationFromChallenge(ctx context.Context, req *http.Request, challenge *authHeader) (authAdded, tokenAcquired bool, _ error) {
r.mu.Lock()
defer r.mu.Unlock()
r.wwwAuthenticate = challenge

switch {
case r.wwwAuthenticate.scheme == "bearer":
scope := ParseScope(r.wwwAuthenticate.params["scope"])
accessToken, err := r.acquireAccessToken(ctx, scope, wantScope.Union(requiredScope))
accessToken, err := r.acquireAccessToken(ctx, scope)
if err != nil {
return false, false, err
}
Expand Down Expand Up @@ -320,41 +319,14 @@ func (r *registry) init() error {
}

// acquireAccessToken tries to acquire an access token for authorizing a request.
// The requiredScopeStr parameter indicates the scope that's definitely
// required. This is a string because apparently some servers are picky
// about getting exactly the same scope in the auth request that was
// returned in the challenge. The wantScope parameter indicates
// what scope might be required in the future.
// The scope comes from the registry's Www-Authenticate challenge.
//
// This method assumes that there has been a previous 401 response with
// a Www-Authenticate: Bearer... header.
func (r *registry) acquireAccessToken(ctx context.Context, requiredScope, wantScope Scope) (string, error) {
scope := requiredScope.Union(wantScope)
func (r *registry) acquireAccessToken(ctx context.Context, scope Scope) (string, error) {
tok, err := r.acquireToken(ctx, scope)
if err != nil {
var herr oci.HTTPError
if !errors.As(err, &herr) || herr.StatusCode() != http.StatusUnauthorized {
return "", err
}
// The documentation says this:
//
// If the client only has a subset of the requested
// access it _must not be considered an error_ as it is
// not the responsibility of the token server to
// indicate authorization errors as part of this
// workflow.
//
// However it's apparently not uncommon for servers to reject
// such requests anyway, so if we've got an unauthorized error
// and wantScope goes beyond requiredScope, it may be because
// the server is rejecting the request.
scope = requiredScope
tok, err = r.acquireToken(ctx, scope)
if err != nil {
return "", err
}
// TODO mark the registry as picky about tokens so we don't
// attempt twice every time?
return "", err
}
if tok.RefreshToken != "" {
r.refreshToken = tok.RefreshToken
Expand Down
83 changes: 71 additions & 12 deletions ociauth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ func TestBearerAuth(t *testing.T) {
assertRequest(context.Background(), t, ts, "/test", client, Scope{})
}

func TestBearerAuthAdditionalScope(t *testing.T) {
// This tests the scenario where there's a larger scope in the context
// than the required scope.
func TestBearerAuthAdditionalScopeDoesNotOverrideChallenge(t *testing.T) {
// This tests that additional context scope is not unioned into the
// registry-provided challenge scope.
requiredScope := ParseScope("repository:foo:push,pull")
additionalScope := ParseScope("repository:bar:pull somethingElse")
authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
Expand All @@ -114,8 +114,7 @@ func TestBearerAuthAdditionalScope(t *testing.T) {
}
requestedScope := ParseScope(strings.Join(req.Form["scope"], " "))
runNonFatal(t, func(t testing.TB) {
wantScope := requiredScope.Union(additionalScope)
require.True(t, wantScope.Equal(requestedScope), "scope mismatch: got %v, want %v", requestedScope, wantScope)
require.True(t, requiredScope.Equal(requestedScope), "scope mismatch: got %v, want %v", requestedScope, requiredScope)
require.Equal(t, []string{"someService"}, req.Form["service"])
})
return &wireToken{
Expand All @@ -132,8 +131,7 @@ func TestBearerAuthAdditionalScope(t *testing.T) {
}
}
runNonFatal(t, func(t testing.TB) {
wantScope := requiredScope.Union(additionalScope)
require.True(t, wantScope.Equal(authScopeFromRequest(t, req)), "scope mismatch")
require.True(t, requiredScope.Equal(authScopeFromRequest(t, req)), "scope mismatch")
})
return nil
})
Expand Down Expand Up @@ -418,10 +416,14 @@ func TestLaterRequestCanUseEarlierTokenWithLargerScope(t *testing.T) {
Action: ActionPull,
})
if req.Header.Get("Authorization") == "" {
challengeScope := requiredScope
if resource == "foo1" {
challengeScope = ParseScope("repository:foo1:pull repository:foo2:pull")
}
return &httpError{
statusCode: http.StatusUnauthorized,
header: http.Header{
"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, requiredScope)},
"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, challengeScope)},
},
}
}
Expand All @@ -438,15 +440,72 @@ func TestLaterRequestCanUseEarlierTokenWithLargerScope(t *testing.T) {
}),
}),
}
ctx := ContextWithScope(context.Background(), ParseScope("repository:foo1:pull repository:foo2:pull"))
assertRequest(ctx, t, ts, "/test/foo1", client, Scope{})
assertRequest(ctx, t, ts, "/test/foo2", client, Scope{})
assertRequest(context.Background(), t, ts, "/test/foo1", client, Scope{})
assertRequest(context.Background(), t, ts, "/test/foo2", client, Scope{})
// One token fetch should have been sufficient for both requests.
require.Equal(t, 1, authCount)
}

func TestLaterRequestCanAcquireTokenProactively(t *testing.T) {
authCount := 0
authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
authCount++
requestedScope := ParseScope(strings.Join(req.Form["scope"], " "))
return &wireToken{
Token: token{requestedScope}.String(),
}, nil
})
targetCount := 0
ts := newTargetServer(t, func(req *http.Request) *httpError {
targetCount++
resource := strings.TrimPrefix(req.URL.Path, "/test/")
requiredScope := NewScope(ResourceScope{
ResourceType: TypeRepository,
Resource: resource,
Action: ActionPull,
})
if req.Header.Get("Authorization") == "" {
return &httpError{
statusCode: http.StatusUnauthorized,
header: http.Header{
"Www-Authenticate": []string{fmt.Sprintf("Bearer realm=%q,service=someService,scope=%q", authSrv, requiredScope)},
},
}
}
runNonFatal(t, func(t testing.TB) {
requestScope := authScopeFromRequest(t, req)
require.True(t, requestScope.Contains(requiredScope), "request scope: %q; required scope: %q", requestScope, requiredScope)
})
return nil
})
client := &http.Client{
Transport: NewStdTransport(StdTransportParams{
Config: configFunc(func(host string) (ConfigEntry, error) {
if host == ts.Host {
return ConfigEntry{
RefreshToken: "someRefreshToken",
}, nil
}
return ConfigEntry{}, nil
}),
}),
}
assertRequest1(ContextWithRequestInfo(context.Background(), RequestInfo{
RequiredScope: ParseScope("repository:foo1:pull"),
}), t, ts, "/test/foo1", client)
require.Equal(t, 2, targetCount)
require.Equal(t, 1, authCount)

assertRequest1(ContextWithRequestInfo(context.Background(), RequestInfo{
RequiredScope: ParseScope("repository:foo2:pull"),
}), t, ts, "/test/foo2", client)
require.Equal(t, 3, targetCount)
require.Equal(t, 2, authCount)
}

func TestAuthServerRejectsRequestsWithTooMuchScope(t *testing.T) {
// This tests the scenario described in the comment in registry.acquireAccessToken.
// This verifies that caller-provided desired scope is not added to the
// registry-provided challenge scope.
userHasScope := ParseScope("repository:foo:pull")

authSrv := newAuthServer(t, func(req *http.Request) (any, *httpError) {
Expand Down
18 changes: 8 additions & 10 deletions ociauth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ import (
type scopeKey struct{}

// ContextWithScope returns ctx annotated with the given
// scope. When the ociauth transport receives a request with a scope in the context,
// it will treat it as "desired authorization scope"; new authorization tokens
// will be acquired with that scope as well as any scope required by
// the operation.
// scope. The ociauth transport does not add this scope to a registry
// challenge; challenges remain the source of truth for new token acquisition.
func ContextWithScope(ctx context.Context, s Scope) context.Context {
return context.WithValue(ctx, scopeKey{}, s)
}
Expand All @@ -29,18 +27,18 @@ type requestInfoKey struct{}
// request context. The [ociclient] package will add this to all
// requests that is makes.
type RequestInfo struct {
// RequiredScope holds the authorization scope that's required
// by the request. The ociauth logic will reuse any available
// auth token that has this scope. When acquiring a new token,
// it will add any scope found in [ScopeFromContext] too.
// RequiredScope holds the authorization scope that can satisfy
// the request for cached-token reuse. When the transport already
// knows the registry's bearer token realm and has a refresh token,
// it may use this scope to acquire a token proactively. A
// Www-Authenticate challenge remains authoritative when present.
RequiredScope Scope
}

// ContextWithRequestInfo returns ctx annotated with the given
// request informaton. When ociclient receives a request with
// this attached, it will respect info.RequiredScope to determine
// what auth tokens to reuse. When it acquires a new token,
// it will ask for the union of info.RequiredScope [ScopeFromContext].
// what auth tokens to reuse or proactively acquire.
func ContextWithRequestInfo(ctx context.Context, info RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, info)
}
Expand Down
9 changes: 8 additions & 1 deletion ociauth/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type knownAction byte
const (
unknownAction knownAction = iota
// Note: ordered by lexical string representation.
deleteAction
pullAction
pushAction
numActions
Expand All @@ -24,6 +25,8 @@ const (
// TypeRegistry is the resource type for registry-wide operations.
TypeRegistry = "registry"

// ActionDelete is the action for deleting content from a repository.
ActionDelete = "delete"
// ActionPull is the action for pulling content from a repository.
ActionPull = "pull"
// ActionPush is the action for pushing content to a repository.
Expand All @@ -32,6 +35,8 @@ const (

func (a knownAction) String() string {
switch a {
case deleteAction:
return ActionDelete
case pullAction:
return ActionPull
case pushAction:
Expand Down Expand Up @@ -65,7 +70,7 @@ type ResourceScope struct {
Resource string

// Action names an action that can be performed on the resource.
// This is usually ActionPush or ActionPull.
// This is usually ActionPull, ActionPush or ActionDelete.
Action string
}

Expand Down Expand Up @@ -487,6 +492,8 @@ func (s Scope) String() string {

func parseKnownAction(s string) knownAction {
switch s {
case ActionDelete:
return deleteAction
case ActionPull:
return pullAction
case ActionPush:
Expand Down
14 changes: 7 additions & 7 deletions ociclient/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ func TestAuthScopes(t *testing.T) {
assertScope("repository:foo/bar:pull", func(ctx context.Context, r oci.Interface) {
r.ResolveTag(ctx, "foo/bar", "sometag")
})
assertScope("repository:foo/bar:push", func(ctx context.Context, r oci.Interface) {
assertScope("repository:foo/bar:pull,push", func(ctx context.Context, r oci.Interface) {
r.PushBlob(ctx, "foo/bar", oci.Descriptor{
MediaType: "application/json",
Digest: testDigest,
Size: 3,
}, strings.NewReader("foo"))
})
assertScope("repository:foo/bar:push", func(ctx context.Context, r oci.Interface) {
assertScope("repository:foo/bar:pull,push", func(ctx context.Context, r oci.Interface) {
w, err := r.PushBlobChunked(ctx, "foo/bar", 0)
require.NoError(t, err)
w.Write([]byte("foo"))
Expand All @@ -74,21 +74,21 @@ func TestAuthScopes(t *testing.T) {
_, err = w.Commit(ocidigest.FromBytes([]byte("foobar")))
require.NoError(t, err)
})
assertScope("repository:x/y:pull repository:z/w:push", func(ctx context.Context, r oci.Interface) {
assertScope("repository:x/y:pull repository:z/w:pull,push", func(ctx context.Context, r oci.Interface) {
r.MountBlob(ctx, "x/y", "z/w", testDigest)
})
assertScope("repository:foo/bar:push", func(ctx context.Context, r oci.Interface) {
assertScope("repository:foo/bar:pull,push", func(ctx context.Context, r oci.Interface) {
r.PushManifest(ctx, "foo/bar", []byte("something"), "application/json", &oci.PushManifestParameters{
Tags: []string{"sometag"},
})
})
assertScope("repository:foo/bar:push", func(ctx context.Context, r oci.Interface) {
assertScope("repository:foo/bar:delete", func(ctx context.Context, r oci.Interface) {
r.DeleteBlob(ctx, "foo/bar", testDigest)
})
assertScope("repository:foo/bar:push", func(ctx context.Context, r oci.Interface) {
assertScope("repository:foo/bar:delete", func(ctx context.Context, r oci.Interface) {
r.DeleteManifest(ctx, "foo/bar", testDigest)
})
assertScope("repository:foo/bar:push", func(ctx context.Context, r oci.Interface) {
assertScope("repository:foo/bar:delete", func(ctx context.Context, r oci.Interface) {
r.DeleteTag(ctx, "foo/bar", "sometag")
})
assertScope("registry:catalog:*", func(ctx context.Context, r oci.Interface) {
Expand Down
4 changes: 2 additions & 2 deletions ociclient/badname_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ func TestBadRepoName(t *testing.T) {
})
require.NoError(t, err)
_, err = r.GetBlob(ctx, "Invalid--Repo", ocidigest.FromBytes(nil))
assert.Regexp(t, "invalid OCI request: name invalid: invalid repository name", err.Error())
assert.Regexp(t, "no can do", err.Error())
_, err = r.ResolveTag(ctx, "okrepo", "bad-Tag!")
assert.Regexp(t, "invalid OCI request: 404 Not Found: page not found", err.Error())
assert.Regexp(t, "no can do", err.Error())
}

type noTransport struct{}
Expand Down
Loading
Loading