Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
- `certificates`: [v1.2.0](services/certificates/CHANGELOG.md#v120)
- **Feature:** Switch from `v2beta` API version to `v2` version.
- **Breaking change:** Rename `CreateCertificateResponse` to `GetCertificateResponse`
- `core`:
- [v0.21.0](core/CHANGELOG.md#v0210)
- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token`
- **Feature:** Support Workload Identity Federation flow
- `sfs`:
- [v0.2.0](services/sfs/CHANGELOG.md)
- **Breaking change:** Remove region configuration in `APIClient`
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,4 @@ See the [release documentation](./RELEASE.md) for further information.

## License

Apache 2.0
Apache 2.0
4 changes: 4 additions & 0 deletions core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## v0.21.0
- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token`
- **Feature:** Support Workload Identity Federation flow

## v0.20.1
- **Improvement:** Improve error message when passing a PEM encoded file to as service account key

Expand Down
2 changes: 1 addition & 1 deletion core/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v0.20.1
v0.21.0
45 changes: 39 additions & 6 deletions core/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
return nil, fmt.Errorf("configuring no auth client: %w", err)
}
return noAuthRoundTripper, nil
} else if cfg.WorkloadIdentityFederation {
wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg)
if err != nil {
return nil, fmt.Errorf("configuring no auth client: %w", err)
}
return wifRoundTripper, nil
} else if cfg.ServiceAccountKey != "" || cfg.ServiceAccountKeyPath != "" {
keyRoundTripper, err := KeyAuth(cfg)
if err != nil {
Expand Down Expand Up @@ -84,14 +90,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
cfg = &config.Configuration{}
}

// Key flow
rt, err = KeyAuth(cfg)
// WIF flow
rt, err = WorkloadIdentityFederationAuth(cfg)
if err != nil {
keyFlowErr := err
// Token flow
rt, err = TokenAuth(cfg)
// Key flow
rt, err = KeyAuth(cfg)
if err != nil {
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
keyFlowErr := err
// Token flow
rt, err = TokenAuth(cfg)
if err != nil {
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
}
}
}
return rt, nil
Expand Down Expand Up @@ -221,6 +231,29 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
return client, nil
}

// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper
// that can be used to make authenticated requests using an access token
func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) {
wifConfig := clients.WorkloadIdentityFederationFlowConfig{
TokenUrl: cfg.TokenCustomUrl,
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
ClientID: cfg.ServiceAccountEmail,
TokenExpiration: cfg.ServiceAccountFederatedTokenExpiration,
FederatedTokenFunction: cfg.ServiceAccountFederatedTokenFunc,
}

if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
wifConfig.HTTPTransport = cfg.HTTPClient.Transport
}

client := &clients.WorkloadIdentityFederationFlow{}
if err := client.Init(&wifConfig); err != nil {
return nil, fmt.Errorf("error initializing client: %w", err)
}

return client, nil
}

// readCredentialsFile reads the credentials file from the specified path and returns Credentials
func readCredentialsFile(path string) (*Credentials, error) {
if path == "" {
Expand Down
88 changes: 87 additions & 1 deletion core/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stackitcloud/stackit-sdk-go/core/clients"
"github.com/stackitcloud/stackit-sdk-go/core/config"
Expand Down Expand Up @@ -121,6 +122,32 @@ func TestSetupAuth(t *testing.T) {
}
}()

// create a wif assertion file
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
if errs != nil {
t.Fatalf("Creating temporary file: %s", err)
}
defer func() {
_ = wifAssertionFile.Close()
err := os.Remove(wifAssertionFile.Name())
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
}
}()

token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
Subject: "sub",
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
}

_, errs = wifAssertionFile.WriteString(string(token))
if errs != nil {
t.Fatalf("Writing wif assertion to temporary file: %s", err)
}

// create a credentials file with saKey and private key
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
if errs != nil {
Expand All @@ -147,12 +174,19 @@ func TestSetupAuth(t *testing.T) {
desc string
config *config.Configuration
setToken bool
setWorkloadIdentity bool
setKeys bool
setKeyPaths bool
setCredentialsFilePathToken bool
setCredentialsFilePathKey bool
isValid bool
}{
{
desc: "wif_config",
config: nil,
setWorkloadIdentity: true,
isValid: true,
},
{
desc: "token_config",
config: nil,
Expand Down Expand Up @@ -241,6 +275,12 @@ func TestSetupAuth(t *testing.T) {
t.Setenv("STACKIT_CREDENTIALS_PATH", "")
}

if test.setWorkloadIdentity {
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
} else {
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
}

t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")

authRoundTripper, err := SetupAuth(test.config)
Expand All @@ -253,7 +293,7 @@ func TestSetupAuth(t *testing.T) {
t.Fatalf("Test didn't return error on invalid test case")
}

if test.isValid && authRoundTripper == nil {
if authRoundTripper == nil && test.isValid {
t.Fatalf("Roundtripper returned is nil for valid test case")
}
})
Expand Down Expand Up @@ -381,6 +421,32 @@ func TestDefaultAuth(t *testing.T) {
t.Fatalf("Writing private key to temporary file: %s", err)
}

// create a wif assertion file
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
if errs != nil {
t.Fatalf("Creating temporary file: %s", err)
}
defer func() {
_ = wifAssertionFile.Close()
err := os.Remove(wifAssertionFile.Name())
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
}
}()

token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
Subject: "sub",
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
}

_, errs = wifAssertionFile.WriteString(string(token))
if errs != nil {
t.Fatalf("Writing wif assertion to temporary file: %s", err)
}

// create a credentials file with saKey and private key
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
if errs != nil {
Expand Down Expand Up @@ -409,6 +475,7 @@ func TestDefaultAuth(t *testing.T) {
setKeyPaths bool
setKeys bool
setCredentialsFilePathKey bool
setWorkloadIdentity bool
isValid bool
expectedFlow string
}{
Expand All @@ -418,6 +485,14 @@ func TestDefaultAuth(t *testing.T) {
isValid: true,
expectedFlow: "token",
},
{
desc: "wif_precedes_key_precedes_token",
setToken: true,
setKeyPaths: true,
setWorkloadIdentity: true,
isValid: true,
expectedFlow: "wif",
},
{
desc: "key_precedes_token",
setToken: true,
Expand Down Expand Up @@ -475,6 +550,13 @@ func TestDefaultAuth(t *testing.T) {
} else {
t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "")
}

if test.setWorkloadIdentity {
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
} else {
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
}

t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")

// Get the default authentication client and ensure that it's not nil
Expand All @@ -501,6 +583,10 @@ func TestDefaultAuth(t *testing.T) {
if _, ok := authClient.(*clients.KeyFlow); !ok {
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
}
case "wif":
if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok {
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
}
}
}
})
Expand Down
84 changes: 84 additions & 0 deletions core/clients/auth_flow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package clients

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/stackitcloud/stackit-sdk-go/core/oapierror"
)

const (
defaultTokenExpirationLeeway = time.Second * 5
)

type AuthFlow interface {
RoundTrip(req *http.Request) (*http.Response, error)
GetAccessToken() (string, error)
GetBackgroundTokenRefreshContext() context.Context
}

// TokenResponseBody is the API response
// when requesting a new token
type TokenResponseBody struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}

func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) {
if res == nil {
return nil, fmt.Errorf("received bad response from API")
}
if res.StatusCode != http.StatusOK {
body, err := io.ReadAll(res.Body)
if err != nil {
// Fail silently, omit body from error
// We're trying to show error details, so it's unnecessary to fail because of this err
body = []byte{}
}
return nil, &oapierror.GenericOpenAPIError{
StatusCode: res.StatusCode,
Body: body,
}
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}

token := &TokenResponseBody{}
err = json.Unmarshal(body, token)
if err != nil {
return nil, fmt.Errorf("unmarshal token response: %w", err)
}
return token, nil
}

func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) {
if token == "" {
return true, nil
}

// We can safely use ParseUnverified because we are not authenticating the user at this point.
// We're just checking the expiration time
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
if err != nil {
return false, fmt.Errorf("parse token: %w", err)
}

expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime()
if err != nil {
return false, fmt.Errorf("get expiration timestamp: %w", err)
}

// Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring
// between retrieving the token and upstream systems validating it.
now := time.Now().Add(tokenExpirationLeeway)
return now.After(expirationTimestampNumeric.Time), nil
}
Loading
Loading