From 054d756fe1e0c5d38068df32382609baf96dc073 Mon Sep 17 00:00:00 2001 From: George Tsiolis Date: Tue, 9 Jun 2026 17:07:47 +0300 Subject: [PATCH] Add input validation against malformed inputs --- cmd/root.go | 10 ++ internal/config/config.go | 20 ++++ internal/config/env_validation_test.go | 52 ++++++++++ internal/snapshot/destination.go | 13 ++- internal/snapshot/destination_test.go | 25 +++++ internal/validate/validate.go | 128 +++++++++++++++++++++++++ internal/validate/validate_test.go | 108 +++++++++++++++++++++ 7 files changed, 349 insertions(+), 7 deletions(-) create mode 100644 internal/config/env_validation_test.go create mode 100644 internal/validate/validate.go create mode 100644 internal/validate/validate_test.go diff --git a/cmd/root.go b/cmd/root.go index aac3d754..e4df77d0 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -24,6 +24,7 @@ import ( "github.com/localstack/lstk/internal/tracing" "github.com/localstack/lstk/internal/ui" "github.com/localstack/lstk/internal/update" + "github.com/localstack/lstk/internal/validate" "github.com/localstack/lstk/internal/version" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -134,6 +135,15 @@ func Execute(ctx context.Context) error { resolvedToken = token } } + // Trim surrounding whitespace: env-injected tokens (e.g. CI secrets) commonly + // carry a trailing newline. Then reject clearly malformed tokens before they + // reach the platform API, telemetry, or the container environment. 
+ resolvedToken = strings.TrimSpace(resolvedToken) + if err := validate.AuthToken(resolvedToken); err != nil { + err = fmt.Errorf("invalid auth token: %w", err) + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + return err + } cfg.AuthToken = resolvedToken tel.SetAuthToken(resolvedToken) diff --git a/internal/config/config.go b/internal/config/config.go index 8f17825b..cf39aad3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,6 +9,7 @@ import ( "regexp" "strings" + "github.com/localstack/lstk/internal/validate" "github.com/pelletier/go-toml/v2" "github.com/spf13/viper" ) @@ -185,5 +186,24 @@ func Get() (*Config, error) { return nil, fmt.Errorf("invalid container config: %w", err) } } + if err := validateNamedEnvs(cfg.Env); err != nil { + return nil, err + } return &cfg, nil } + +// validateNamedEnvs rejects malformed variables defined in the top-level [env.*] +// config sections before they are injected into a container's environment. +func validateNamedEnvs(envs map[string]map[string]string) error { + for name, vars := range envs { + for key, value := range vars { + if err := validate.EnvVarName(key); err != nil { + return fmt.Errorf("invalid variable in [env.%s]: %w", name, err) + } + if err := validate.NoControlChars("value for "+key, value); err != nil { + return fmt.Errorf("invalid variable in [env.%s]: %w", name, err) + } + } + } + return nil +} diff --git a/internal/config/env_validation_test.go b/internal/config/env_validation_test.go new file mode 100644 index 00000000..b69f0e18 --- /dev/null +++ b/internal/config/env_validation_test.go @@ -0,0 +1,52 @@ +package config + +import "testing" + +func TestValidateNamedEnvs(t *testing.T) { + t.Parallel() + tests := []struct { + name string + envs map[string]map[string]string + wantErr bool + }{ + {"nil", nil, false}, + {"empty", map[string]map[string]string{}, false}, + { + name: "valid", + envs: map[string]map[string]string{ + "debug": {"ls_log": "trace", "debug": "1"}, + "ci": 
{"services": "s3,sqs"}, + }, + wantErr: false, + }, + { + name: "control char in value", + envs: map[string]map[string]string{"bad": {"debug": "1\x00"}}, + wantErr: true, + }, + { + name: "hyphen in key", + envs: map[string]map[string]string{"bad": {"my-key": "1"}}, + wantErr: true, + }, + { + name: "equals in key", + envs: map[string]map[string]string{"bad": {"a=b": "1"}}, + wantErr: true, + }, + { + name: "key starts with digit", + envs: map[string]map[string]string{"bad": {"1var": "1"}}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateNamedEnvs(tt.envs) + if (err != nil) != tt.wantErr { + t.Errorf("validateNamedEnvs() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/snapshot/destination.go b/internal/snapshot/destination.go index 6e5534c5..50a9e66b 100644 --- a/internal/snapshot/destination.go +++ b/internal/snapshot/destination.go @@ -6,9 +6,10 @@ import ( "fmt" "os" "path/filepath" - "regexp" "strings" "time" + + "github.com/localstack/lstk/internal/validate" ) // ErrHomeNotSet is returned when a path needs "~" expansion but no home directory was provided. @@ -19,8 +20,6 @@ var ( ErrRemoteNotSupported = errors.New("remote destinations are not yet supported — coming soon") // ErrUnknownScheme is returned for unrecognized URL schemes. ErrUnknownScheme = errors.New("unrecognized destination scheme") - - validPodName = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9-]*$`) ) // DestinationKind distinguishes local file paths from remote pod destinations. @@ -71,8 +70,8 @@ func ParseSource(ref, home string) (Destination, error) { return Destination{}, fmt.Errorf("'%s' is not a valid reference. Aliases use a single colon. 
Did you mean:\npod:%s", ref, podName) case strings.HasPrefix(lower, "pod:"): podName := ref[len("pod:"):] - if !validPodName.MatchString(podName) { - return Destination{}, fmt.Errorf("invalid pod name %q: use letters, digits, and hyphens only, starting with a letter or digit", podName) + if err := validate.ResourceName("pod name", podName); err != nil { + return Destination{}, fmt.Errorf("invalid pod name %q: %w", podName, err) } return Destination{Kind: KindPod, Value: podName}, nil case strings.HasPrefix(lower, "s3://"), @@ -131,8 +130,8 @@ func ParseDestination(dest, home string, now time.Time) (Destination, error) { return Destination{}, fmt.Errorf("'%s' is not a valid reference. Aliases use a single colon. Did you mean:\npod:%s", dest, podName) case strings.HasPrefix(lower, "pod:"): podName := dest[len("pod:"):] - if !validPodName.MatchString(podName) { - return Destination{}, fmt.Errorf("invalid pod name %q: use letters, digits, and hyphens only, starting with a letter or digit", podName) + if err := validate.ResourceName("pod name", podName); err != nil { + return Destination{}, fmt.Errorf("invalid pod name %q: %w", podName, err) } return Destination{Kind: KindPod, Value: podName}, nil case strings.HasPrefix(lower, "s3://"), diff --git a/internal/snapshot/destination_test.go b/internal/snapshot/destination_test.go index bbcc111e..21482060 100644 --- a/internal/snapshot/destination_test.go +++ b/internal/snapshot/destination_test.go @@ -123,6 +123,16 @@ func TestParseSource(t *testing.T) { input: "pod:-bad", wantErr: "invalid pod name", }, + { + name: "pod: percent encoding rejected", + input: "pod:staging%2Fpod", + wantErr: "invalid pod name", + }, + { + name: "pod: shell metacharacters rejected", + input: "pod:a;rm", + wantErr: "invalid pod name", + }, // --- remote schemes --- { @@ -435,6 +445,21 @@ func TestParseDestination(t *testing.T) { input: "pod:my_pod", wantErr: "invalid pod name", }, + { + name: "pod: percent encoding rejected", + input: 
"pod:staging%2Fpod", + wantErr: "invalid pod name", + }, + { + name: "pod: embedded query rejected", + input: "pod:abc?fields=name", + wantErr: "invalid pod name", + }, + { + name: "pod: shell metacharacters rejected", + input: "pod:a;rm", + wantErr: "invalid pod name", + }, // --- unknown schemes --- { diff --git a/internal/validate/validate.go b/internal/validate/validate.go new file mode 100644 index 00000000..7772a2a0 --- /dev/null +++ b/internal/validate/validate.go @@ -0,0 +1,128 @@ +// Package validate provides reusable, deterministic validators for user-supplied +// CLI inputs. It exists to make the CLI a safe target for AI agents and scripts, +// which can produce malformed or hostile input — control characters, path +// traversal, percent-encoding, embedded query parameters, shell metacharacters — +// in ways humans rarely do. +// +// Validators return an *Error carrying a machine-classifiable Rule so callers can +// surface a precise, stable reason (and, in JSON output mode, a stable error +// code) instead of a generic "invalid input" message. Error() returns the bare +// reason so it composes cleanly when wrapped with caller context. +package validate + +import ( + "fmt" + "regexp" + "strings" + "unicode" +) + +// Rule classifies why a value was rejected. The values are stable and intended to +// be surfaced as machine-readable error codes. +const ( + RuleEmpty = "empty" + RuleControlChars = "control_chars" + RuleEncoding = "encoding" + RuleTraversal = "traversal" + RuleEmbedded = "embedded" + RuleMetachars = "metachars" + RuleFormat = "format" + RuleRange = "range" +) + +type Error struct { + Field string + Rule string + Msg string +} + +func (e *Error) Error() string { return e.Msg } + +func newError(field, rule, msg string) *Error { + return &Error{Field: field, Rule: rule, Msg: msg} +} + +// containsControlChars reports whether s contains any control character other +// than tab, newline, or carriage return. 
+func containsControlChars(s string) bool {
+	for _, r := range s { // ranging over a string yields runes, not bytes
+		if r == '\t' || r == '\n' || r == '\r' {
+			continue // these three whitespace controls are tolerated
+		}
+		if unicode.IsControl(r) {
+			return true
+		}
+	}
+	return false
+}
+
+// NoControlChars returns a RuleControlChars *Error labelled with field when value holds control characters.
+func NoControlChars(field, value string) error {
+	if containsControlChars(value) { // tab, newline, and carriage return do not count
+		return newError(field, RuleControlChars, "contains control characters")
+	}
+	return nil
+}
+
+var envVarKeyRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) // letter or underscore first, then letters, digits, underscores
+
+// EnvVarName validates an environment variable name (the key of a KEY=VALUE pair); a mismatch yields a RuleFormat *Error.
+func EnvVarName(name string) error {
+	if !envVarKeyRegexp.MatchString(name) { // also fails for the empty string and for a leading digit
+		return newError("env", RuleFormat, fmt.Sprintf("env key %q contains invalid characters", name))
+	}
+	return nil
+}
+
+var resourceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9-]*$`) // letter or digit first, then letters, digits, hyphens
+
+// shellMetaChars lists characters that enable command injection if an identifier is ever
+// interpolated into a shell. ResourceName checks "/", "?" and "#" separately, so they are omitted.
+const shellMetaChars = ";&|$\x60<>(){}[]!*'~\\\"" // \x60 is the backtick
+
+// ResourceName validates an opaque resource identifier such as a Cloud Pod name.
+// It runs ordered deny-checks so the most specific reason wins, then a strict
+// allow-list. The deny-checks exist to give precise, machine-classifiable
+// feedback; the allow-list alone would reject every invalid value.
+func ResourceName(field, value string) error {
+	switch { // the first matching case wins, so case order decides which Rule is reported
+	case value == "":
+		return newError(field, RuleEmpty, "must not be empty")
+	case containsControlChars(value): // tab, newline, and CR pass here and are rejected by the allow-list below
+		return newError(field, RuleControlChars, "contains control characters")
+	case strings.Contains(value, "%"): // any '%' counts, even when not followed by hex digits
+		return newError(field, RuleEncoding, "contains percent-encoding (pass the decoded value)")
+	case strings.Contains(value, ".."): // matches ".." anywhere, not only as a whole path segment
+		return newError(field, RuleTraversal, "contains a path traversal sequence (..)")
+	case strings.ContainsAny(value, "/?#"):
+		return newError(field, RuleEmbedded, "contains path or query characters (/, ?, #)")
+	case strings.ContainsAny(value, shellMetaChars):
+		return newError(field, RuleMetachars, "contains shell metacharacters")
+	case !resourceNameRegexp.MatchString(value): // final allow-list: catches what is left, e.g. "_", ".", spaces, a leading "-"
+		return newError(field, RuleFormat, "use letters, digits, and hyphens only, starting with a letter or digit")
+	}
+	return nil
+}
+
+// AuthToken validates a LocalStack auth token. The character set is intentionally
+// not restricted — tokens are opaque — so only clearly malformed values are
+// rejected: control characters, embedded whitespace, or a length over 1024 bytes.
+// An empty token is allowed (it means none is set). Callers should TrimSpace first,
+// since environment injection (e.g. CI secrets) commonly appends a trailing newline.
+func AuthToken(value string) error {
+	if value == "" {
+		return nil
+	}
+	for _, r := range value {
+		if unicode.IsControl(r) { // checked first, so tab/newline/CR report control_chars rather than whitespace
+			return newError("auth token", RuleControlChars, "contains control characters")
+		}
+		if unicode.IsSpace(r) { // remaining, non-control whitespace such as ' '
+			return newError("auth token", RuleFormat, "contains whitespace")
+		}
+	}
+	if len(value) > 1024 { // len counts bytes, so multi-byte runes hit the limit sooner than "characters" suggests
+		return newError("auth token", RuleRange, "is implausibly long (over 1024 characters)")
+	}
+	return nil
+}
diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go
new file mode 100644
index 00000000..7b9bfb55
--- /dev/null
+++ b/internal/validate/validate_test.go
@@ -0,0 +1,108 @@
+package validate
+
+import (
+	"errors"
+	"strings"
+	"testing"
+)
+
+func TestNoControlChars(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name    string
+		value   string
+		wantErr bool
+	}{
+		{"clean string", "hello world", false},
+		{"with tab", "hello\tworld", false}, // tab is one of the tolerated controls
+		{"with newline", "hello\nworld", false}, // newline is tolerated as well
+		{"with null byte", "hello\x00world", true},
+		{"with bell", "hello\x07world", true},
+		{"with escape", "hello\x1bworld", true},
+		{"with delete", "hello\x7fworld", true},
+		{"empty", "", false},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			err := NoControlChars("test", tt.value)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("NoControlChars() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestResourceName(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name     string
+		value    string
+		wantErr  bool
+		wantRule string // expected Error.Rule; left empty for accepted values
+	}{
+		{"simple", "my-baseline", false, ""},
+		{"alphanumeric", "abc123", false, ""},
+		{"single char", "a", false, ""},
+		{"long hyphenated", "my-long-pod-name-123", false, ""},
+		{"empty", "", true, RuleEmpty},
+		{"control char", "ba\x00d", true, RuleControlChars},
+		{"percent encoding", "staging%2Fpod", true, RuleEncoding},
+		{"path traversal", "../etc", true, RuleTraversal}, // ".." is checked before "/", so traversal wins
+		{"embedded query", "abc?fields=name", true, RuleEmbedded},
+		{"slash", "a/b", true, RuleEmbedded},
+		{"fragment", "id#frag", true, RuleEmbedded},
+		{"shell metachar semicolon", "a;rm", true, RuleMetachars},
+		{"shell metachar subshell", "a$(id)", true, RuleMetachars},
+		{"shell metachar backtick", "a`id`", true, RuleMetachars},
+		{"underscore", "my_pod", true, RuleFormat},
+		{"leading hyphen", "-bad", true, RuleFormat},
+		{"leading dot", ".hidden", true, RuleFormat},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			err := ResourceName("name", tt.value)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("ResourceName(%q) error = %v, wantErr %v", tt.value, err, tt.wantErr)
+			}
+			if tt.wantRule != "" { // only rejecting cases carry a Rule to compare
+				var ve *Error
+				if !errors.As(err, &ve) { // Rule is a field of the concrete *Error, not of the error interface
+					t.Fatalf("ResourceName(%q) error is not *validate.Error: %v", tt.value, err)
+				}
+				if ve.Rule != tt.wantRule {
+					t.Errorf("ResourceName(%q) Rule = %q, want %q", tt.value, ve.Rule, tt.wantRule)
+				}
+			}
+		})
+	}
+}
+
+func TestAuthToken(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name    string
+		value   string
+		wantErr bool
+	}{
+		{"empty is allowed", "", false},
+		{"typical token", "ls-example-token", false},
+		{"alphanumeric", "exampletoken123", false},
+		{"with null byte", "tok\x00en", true},
+		{"with escape", "tok\x1ben", true},
+		{"with newline", "token\n", true}, // AuthToken itself does not trim; its doc asks callers to TrimSpace first
+		{"with tab", "tok\ten", true},
+		{"with space", "tok en", true},
+		{"too long", strings.Repeat("a", 1025), true}, // one byte over the 1024 limit
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			err := AuthToken(tt.value)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("AuthToken(%q) error = %v, wantErr %v", tt.value, err, tt.wantErr)
+			}
+		})
+	}
+}