From 95dc5d3f1dfd05a872a437064531fea1b8bc30b4 Mon Sep 17 00:00:00 2001 From: Swarit Pandey Date: Mon, 18 May 2026 10:27:05 +0530 Subject: [PATCH 1/2] chore(node-scan): optimize scan times Signed-off-by: Swarit Pandey --- internal/detector/nodescan.go | 121 +++++++++++++++---- internal/detector/nodescan_cache.go | 143 +++++++++++++++++++++++ internal/detector/nodescan_cache_test.go | 106 +++++++++++++++++ 3 files changed, 350 insertions(+), 20 deletions(-) create mode 100644 internal/detector/nodescan_cache.go create mode 100644 internal/detector/nodescan_cache_test.go diff --git a/internal/detector/nodescan.go b/internal/detector/nodescan.go index 3304be2..01895fe 100644 --- a/internal/detector/nodescan.go +++ b/internal/detector/nodescan.go @@ -9,6 +9,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "github.com/step-security/dev-machine-guard/internal/executor" @@ -228,7 +229,11 @@ type projectEntry struct { modTime int64 } -// ScanProjects finds package.json files, sorts by most recently modified, then scans. +// ScanProjects finds package.json files, sorts by most recently modified, then +// scans projects concurrently. Per-project results are cached locally; on the +// next run we skip `npm ls` for any project whose package.json and lockfile +// haven't been modified since the cached scan timestamp. +// // Respects the size limit (default 500MB, override via STEPSEC_MAX_NODE_SCAN_BYTES). func (s *NodeScanner) ScanProjects(ctx context.Context, searchDirs []string) []model.NodeScanResult { // Phase 1: Discover all package.json files @@ -254,7 +259,6 @@ func (s *NodeScanner) ScanProjects(ctx context.Context, searchDirs []string) []m if isInsideNodeModules(projectDir) { return nil } - // Get modification time for sorting modTime := int64(0) if info, err := entry.Info(); err == nil { modTime = info.ModTime().Unix() @@ -269,40 +273,117 @@ func (s *NodeScanner) ScanProjects(ctx context.Context, searchDirs []string) []m return projects[i].modTime > projects[j].modTime }) - // Phase 3: Scan in order, respecting limits - maxBytes := getMaxProjectScanBytes() - var results []model.NodeScanResult - totalSize := int64(0) - + // Phase 3: Build the work plan. For each project decide whether the + // previous cached result is still valid (skip) or we need to re-scan. + cachePath := scanCachePath(s.exec) + cache := loadScanCache(cachePath) + nowUnix := time.Now().Unix() + + type plan struct { + dir string + pm string + skip bool + cached model.NodeScanResult + } + plans := make([]plan, 0, len(projects)) for i, p := range projects { if i >= maxNodeProjects { s.log.Progress(" Reached maximum of %d projects, stopping search", maxNodeProjects) break } - if totalSize > maxBytes { - s.log.Progress(" Reached data size limit (%d bytes collected, limit: %d bytes)", totalSize, maxBytes) - s.log.Progress(" Skipping remaining projects (prioritized by most recently modified)") - break + pm := DetectProjectPM(s.exec, p.dir) + pl := plan{dir: p.dir, pm: pm} + if entry, ok := cache.Projects[p.dir]; ok && entry.PackageManager == pm { + lockPath := lockfileFor(s.exec, p.dir, pm) + // No lockfile means we can't trust mtime — always re-scan. + if lockPath != "" { + pkgMt := mtimeOr0(s.exec, filepath.Join(p.dir, "package.json")) + lockMt := mtimeOr0(s.exec, lockPath) + if pkgMt <= entry.LastScanUnix && lockMt <= entry.LastScanUnix { + pl.skip = true + pl.cached = entry.CachedResult + } + } } + plans = append(plans, pl) + } - s.log.Progress(" Found project: %s", p.dir) - pm := DetectProjectPM(s.exec, p.dir) - s.log.Progress(" Package manager: %s", pm) + // Phase 4: Dispatch fresh scans concurrently. Skipped projects already + // have a result; only cache-miss/invalid entries hit the worker pool. + results := make([]model.NodeScanResult, len(plans)) + for i, pl := range plans { + if pl.skip { + results[i] = pl.cached + s.log.Progress(" Skipping (unchanged): %s (%s)", pl.dir, pl.pm) + } + } - r := s.scanProject(ctx, p.dir) - resultSize := int64(len(r.RawStdoutBase64)) + int64(len(r.RawStderrBase64)) + workers := scanWorkerCount(s.exec) + jobs := make(chan int, len(plans)) + var wg sync.WaitGroup + for range workers { + wg.Add(1) + go func() { + defer wg.Done() + for idx := range jobs { + pl := plans[idx] + s.log.Progress(" Scanning project: %s (%s)", pl.dir, pl.pm) + results[idx] = s.scanProject(ctx, pl.dir) + } + }() + } + scanned := 0 + for i, pl := range plans { + if !pl.skip { + jobs <- i + scanned++ + } + } + close(jobs) + wg.Wait() + s.log.Progress(" Scanned %d projects (%d skipped via cache)", scanned, len(plans)-scanned) - if totalSize+resultSize > maxBytes { + // Phase 5: Apply the size cap in mtime-desc order (matches prior behavior) + // and update cache with freshly-scanned successful results. + maxBytes := getMaxProjectScanBytes() + final := make([]model.NodeScanResult, 0, len(plans)) + totalSize := int64(0) + for i := range plans { + r := results[i] + size := int64(len(r.RawStdoutBase64)) + int64(len(r.RawStderrBase64)) + if totalSize+size > maxBytes { s.log.Progress(" Reached data size limit (%d bytes collected, limit: %d bytes)", totalSize, maxBytes) s.log.Progress(" Skipping remaining projects (prioritized by most recently modified)") break } + totalSize += size + final = append(final, r) + // Only cache successful fresh scans. Failed scans should be retried. + if !plans[i].skip && r.ExitCode == 0 { + cache.Projects[plans[i].dir] = cacheEntry{ + PackageManager: plans[i].pm, + LastScanUnix: nowUnix, + CachedResult: r, + } + } + } - totalSize += resultSize - results = append(results, r) + // Drop cache entries for projects no longer on disk so the cache file + // doesn't grow unboundedly across runs. + seen := make(map[string]struct{}, len(plans)) + for _, pl := range plans { + seen[pl.dir] = struct{}{} + } + for dir := range cache.Projects { + if _, ok := seen[dir]; !ok { + delete(cache.Projects, dir) + } + } + if err := cache.save(cachePath); err != nil { + s.log.Progress(" Warning: failed to write scan cache: %v", err) } - return results + return final } func (s *NodeScanner) scanProject(ctx context.Context, projectDir string) model.NodeScanResult { diff --git a/internal/detector/nodescan_cache.go b/internal/detector/nodescan_cache.go new file mode 100644 index 0000000..4c72cf1 --- /dev/null +++ b/internal/detector/nodescan_cache.go @@ -0,0 +1,143 @@ +package detector + +import ( + "encoding/json" + "os" + "path/filepath" + "runtime" + "strconv" + + "github.com/step-security/dev-machine-guard/internal/executor" + "github.com/step-security/dev-machine-guard/internal/model" +) + +const scanCacheVersion = 1 + +// cacheEntry is one project's cached scan result, used to skip re-running +// `npm/yarn/pnpm/bun ls` when neither package.json nor the lockfile has been +// modified since LastScanUnix. +type cacheEntry struct { + PackageManager string `json:"package_manager"` + LastScanUnix int64 `json:"last_scan_unix"` + CachedResult model.NodeScanResult `json:"cached_result"` +} + +type scanCache struct { + Version int `json:"version"` + Projects map[string]cacheEntry `json:"projects"` +} + +func newScanCache() *scanCache { + return &scanCache{Version: scanCacheVersion, Projects: map[string]cacheEntry{}} +} + +// scanCachePath returns the on-disk path for the per-project scan cache. +// Override with STEPSEC_NODE_SCAN_CACHE for tests / non-root runs. +func scanCachePath(exec executor.Executor) string { + if override := exec.Getenv("STEPSEC_NODE_SCAN_CACHE"); override != "" { + return override + } + if exec.GOOS() == "windows" { + return filepath.Join(`C:\ProgramData\StepSecurity\dev-machine-guard`, "scan-cache.json") + } + return "/var/lib/stepsecurity/dev-machine-guard/scan-cache.json" +} + +// loadScanCache reads the cache file. Returns an empty cache on miss or any +// parse error — a corrupt cache must never break a scan, only force a full one. +func loadScanCache(path string) *scanCache { + data, err := os.ReadFile(path) + if err != nil { + return newScanCache() + } + var c scanCache + if err := json.Unmarshal(data, &c); err != nil || c.Version != scanCacheVersion { + return newScanCache() + } + if c.Projects == nil { + c.Projects = map[string]cacheEntry{} + } + return &c +} + +// save writes the cache atomically (write to tmp, rename). +func (c *scanCache) save(path string) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + data, err := json.Marshal(c) + if err != nil { + return err + } + tmp, err := os.CreateTemp(dir, ".scan-cache-*.tmp") + if err != nil { + return err + } + tmpPath := tmp.Name() + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpPath) + return err + } + if err := tmp.Close(); err != nil { + _ = os.Remove(tmpPath) + return err + } + return os.Rename(tmpPath, path) +} + +// lockfileFor returns the path of the lockfile for the given package manager +// in projectDir, or "" if no expected lockfile is present. +func lockfileFor(exec executor.Executor, projectDir, pm string) string { + var names []string + switch pm { + case "npm": + names = []string{"package-lock.json"} + case "yarn", "yarn-berry": + names = []string{"yarn.lock"} + case "pnpm": + names = []string{"pnpm-lock.yaml"} + case "bun": + names = []string{"bun.lock", "bun.lockb"} + default: + return "" + } + for _, n := range names { + p := filepath.Join(projectDir, n) + if exec.FileExists(p) { + return p + } + } + return "" +} + +// mtimeOr0 returns the file's mtime in unix seconds, or 0 if it can't be stat'd. +func mtimeOr0(exec executor.Executor, path string) int64 { + if path == "" { + return 0 + } + info, err := exec.Stat(path) + if err != nil { + return 0 + } + return info.ModTime().Unix() +} + +// scanWorkerCount returns the number of concurrent project scans to dispatch. +// Defaults to min(NumCPU, 8). Override with STEPSEC_NODE_SCAN_WORKERS. +func scanWorkerCount(exec executor.Executor) int { + if v := exec.Getenv("STEPSEC_NODE_SCAN_WORKERS"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + n := runtime.NumCPU() + if n > 8 { + n = 8 + } + if n < 1 { + n = 1 + } + return n +} diff --git a/internal/detector/nodescan_cache_test.go b/internal/detector/nodescan_cache_test.go new file mode 100644 index 0000000..96abc0f --- /dev/null +++ b/internal/detector/nodescan_cache_test.go @@ -0,0 +1,106 @@ +package detector + +import ( + "os" + "path/filepath" + "testing" + + "github.com/step-security/dev-machine-guard/internal/executor" + "github.com/step-security/dev-machine-guard/internal/model" +) + +func TestScanCache_RoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "scan-cache.json") + + c := newScanCache() + c.Projects["/app"] = cacheEntry{ + PackageManager: "npm", + LastScanUnix: 1700000000, + CachedResult: model.NodeScanResult{ + ProjectPath: "/app", + PackageManager: "npm", + RawStdoutBase64: "eyJkZXBzIjpbXX0=", + ExitCode: 0, + }, + } + if err := c.save(path); err != nil { + t.Fatalf("save: %v", err) + } + + loaded := loadScanCache(path) + if loaded.Version != scanCacheVersion { + t.Errorf("version: got %d, want %d", loaded.Version, scanCacheVersion) + } + entry, ok := loaded.Projects["/app"] + if !ok { + t.Fatal("missing /app entry after reload") + } + if entry.LastScanUnix != 1700000000 || entry.PackageManager != "npm" { + t.Errorf("entry mismatch: %+v", entry) + } + if entry.CachedResult.RawStdoutBase64 != "eyJkZXBzIjpbXX0=" { + t.Errorf("cached result lost: %+v", entry.CachedResult) + } +} + +func TestScanCache_MissReturnsEmpty(t *testing.T) { + c := loadScanCache(filepath.Join(t.TempDir(), "does-not-exist.json")) + if c == nil || c.Projects == nil { + t.Fatal("expected non-nil empty cache on miss") + } + if len(c.Projects) != 0 { + t.Errorf("expected empty projects map, got %d entries", len(c.Projects)) + } +} + +func TestScanCache_CorruptReturnsEmpty(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "scan-cache.json") + if err := os.WriteFile(path, []byte("not json"), 0o644); err != nil { + t.Fatal(err) + } + c := loadScanCache(path) + if len(c.Projects) != 0 { + t.Errorf("expected empty cache after corrupt read, got %d entries", len(c.Projects)) + } +} + +func TestScanCache_WrongVersionReturnsEmpty(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "scan-cache.json") + if err := os.WriteFile(path, []byte(`{"version":999,"projects":{"/a":{"package_manager":"npm","last_scan_unix":1}}}`), 0o644); err != nil { + t.Fatal(err) + } + c := loadScanCache(path) + if len(c.Projects) != 0 { + t.Errorf("expected empty cache on version mismatch, got %d entries", len(c.Projects)) + } +} + +func TestLockfileFor(t *testing.T) { + mock := executor.NewMock() + mock.SetFile(filepath.Join("/proj-npm", "package-lock.json"), []byte{}) + mock.SetFile(filepath.Join("/proj-yarn", "yarn.lock"), []byte{}) + mock.SetFile(filepath.Join("/proj-pnpm", "pnpm-lock.yaml"), []byte{}) + mock.SetFile(filepath.Join("/proj-bun", "bun.lockb"), []byte{}) + + cases := []struct { + dir, pm, want string + }{ + {"/proj-npm", "npm", filepath.Join("/proj-npm", "package-lock.json")}, + {"/proj-yarn", "yarn", filepath.Join("/proj-yarn", "yarn.lock")}, + {"/proj-yarn", "yarn-berry", filepath.Join("/proj-yarn", "yarn.lock")}, + {"/proj-pnpm", "pnpm", filepath.Join("/proj-pnpm", "pnpm-lock.yaml")}, + {"/proj-bun", "bun", filepath.Join("/proj-bun", "bun.lockb")}, + {"/missing", "npm", ""}, + {"/proj-npm", "unknown", ""}, + } + for _, c := range cases { + got := lockfileFor(mock, c.dir, c.pm) + if got != c.want { + t.Errorf("lockfileFor(%q,%q): got %q, want %q", c.dir, c.pm, got, c.want) + } + } +} + From a59519057c43041112b4223da8a9cdf32f90371d Mon Sep 17 00:00:00 2001 From: Swarit Pandey Date: Mon, 18 May 2026 17:38:49 +0530 Subject: [PATCH 2/2] fix(node-scan): bypass cmd /c on Windows for project scans MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Project and global yarn scans constructed `cd "" && ls ...` shell strings and dispatched them via `cmd /c`. Go's os/exec quotes the third arg (MSVC rules: backslash-escaped quotes), and cmd.exe applies its own quote-stripping rules to /c args. The two disagree when the arg contains both an embedded quoted path and special characters like &&, so cmd.exe emits "The filename, directory name, or volume label syntax is incorrect." and the scan returns exit code 1 with empty stdout. Confirmed on Windows Server 2025: every npm/yarn/pnpm scan failed, telemetry shipped a 92-byte error message instead of the package tree. Add Executor.RunInDir which uses os/exec's cmd.Dir on Windows (no cmd.exe at all). Unix path still uses the shell so RunAsUser/sudo delegation keeps working for root cron runs. Existing TestNodeScanner_ScanProject_Windows and ScanYarnGlobal_Windows tests only exercised the mock — they never actually invoked cmd.exe, so the bug never surfaced. Updated to mock the new dispatch path. Signed-off-by: Swarit Pandey --- internal/detector/nodescan.go | 23 ++++++++++++----------- internal/detector/nodescan_test.go | 9 +++++---- internal/detector/shellcmd.go | 18 ++++++++++++++++++ internal/executor/executor.go | 28 ++++++++++++++++++++++++++++ internal/executor/mock.go | 14 ++++++++++++++ 5 files changed, 77 insertions(+), 15 deletions(-) diff --git a/internal/detector/nodescan.go b/internal/detector/nodescan.go index 01895fe..86c8f7c 100644 --- a/internal/detector/nodescan.go +++ b/internal/detector/nodescan.go @@ -69,10 +69,16 @@ func (s *NodeScanner) runCmd(ctx context.Context, timeout time.Duration, name st return s.exec.RunWithTimeout(ctx, timeout, name, args...) } -// runShellCmd runs a shell command string, delegating to the logged-in user when running as root. -// Falls through to the platform-aware free function for the normal (non-delegation) path. -func (s *NodeScanner) runShellCmd(ctx context.Context, timeout time.Duration, shellCmd string) (string, string, int, error) { +// runCmdInDir runs a command from `dir`, delegating to the logged-in user when +// running as root. On Windows this bypasses cmd /c entirely (see runCmdInDir +// in shellcmd.go); RunAsUser delegation is Unix-only, so the sudo path always +// constructs a shell string. +func (s *NodeScanner) runCmdInDir(ctx context.Context, timeout time.Duration, dir, name string, args ...string) (string, string, int, error) { if s.shouldRunAsUser() { + shellCmd := "cd " + platformShellQuote(s.exec, dir) + " && " + name + for _, a := range args { + shellCmd += " " + a + } ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() stdout, err := s.exec.RunAsUser(ctx, s.loggedInUser, shellCmd) @@ -84,7 +90,7 @@ func (s *NodeScanner) runShellCmd(ctx context.Context, timeout time.Duration, sh } return stdout, "", 0, nil } - return runShellCmd(ctx, s.exec, timeout, shellCmd) + return runCmdInDir(ctx, s.exec, timeout, dir, name, args...) } // checkPath checks if a binary is available, using the logged-in user's PATH when running as root. @@ -167,8 +173,7 @@ func (s *NodeScanner) scanYarnGlobal(ctx context.Context) (model.NodeScanResult, } start := time.Now() - shellCmd := "cd " + platformShellQuote(s.exec, globalDir) + " && yarn list --json --depth=0" - stdout, stderr, exitCode, _ := s.runShellCmd(ctx, 60*time.Second, shellCmd) + stdout, stderr, exitCode, _ := s.runCmdInDir(ctx, 60*time.Second, globalDir, "yarn", "list", "--json", "--depth=0") duration := time.Since(start).Milliseconds() errMsg := "" @@ -424,11 +429,7 @@ func (s *NodeScanner) scanProject(ctx context.Context, projectDir string) model. } start := time.Now() - cmdStr := "cd " + platformShellQuote(s.exec, projectDir) + " && " + cmd - for _, a := range args { - cmdStr += " " + a - } - stdout, stderr, exitCode, _ := s.runShellCmd(ctx, 30*time.Second, cmdStr) + stdout, stderr, exitCode, _ := s.runCmdInDir(ctx, 30*time.Second, projectDir, cmd, args...) duration := time.Since(start).Milliseconds() errMsg := "" diff --git a/internal/detector/nodescan_test.go b/internal/detector/nodescan_test.go index 0d047b8..2be33a1 100644 --- a/internal/detector/nodescan_test.go +++ b/internal/detector/nodescan_test.go @@ -89,9 +89,9 @@ func TestNodeScanner_ScanYarnGlobal_Windows(t *testing.T) { mock.SetPath("yarn", `C:\Program Files\nodejs\yarn.cmd`) mock.SetCommand("1.22.19\n", "", 0, "yarn", "--version") mock.SetCommand(`C:\Users\dev\AppData\Local\Yarn\Data\global`+"\n", "", 0, "yarn", "global", "dir") - // runShellCmd dispatches to cmd /c on Windows; platformShellQuote uses double quotes + // On Windows the scanner uses RunInDir (cmd.Dir-based) to avoid cmd /c quoting issues. mock.SetCommand(`{"type":"tree","data":{"trees":[]}}`, "", 0, - "cmd", "/c", `cd "C:\Users\dev\AppData\Local\Yarn\Data\global" && yarn list --json --depth=0`) + `C:\Users\dev\AppData\Local\Yarn\Data\global`+"|yarn", "list", "--json", "--depth=0") scanner := newTestScanner(mock) results := scanner.ScanGlobalPackages(context.Background()) @@ -154,8 +154,9 @@ func TestNodeScanner_ScanProject_Windows(t *testing.T) { // DetectProjectPM uses filepath.Join which is host-dependent; // construct the mock file path the same way the code will. mock.SetFile(filepath.Join(`C:\Users\dev\myapp`, "package-lock.json"), []byte{}) + // On Windows the scanner uses RunInDir (cmd.Dir-based) — mock keys on dir|name + args. mock.SetCommand(`{"dependencies":{"lodash":{"version":"4.17.21"}}}`, "", 0, - "cmd", "/c", `cd "C:\Users\dev\myapp" && npm ls --json --depth=3`) + `C:\Users\dev\myapp|npm`, "ls", "--json", "--depth=3") scanner := newTestScanner(mock) result := scanner.scanProject(context.Background(), `C:\Users\dev\myapp`) @@ -188,7 +189,7 @@ func TestNodeScanner_ScanProject_YarnBerry_Windows(t *testing.T) { mock.SetFile(filepath.Join(projectDir, "yarn.lock"), []byte{}) mock.SetFile(filepath.Join(projectDir, ".yarnrc.yml"), []byte{}) mock.SetCommand(`{"name":"myapp","children":[]}`, "", 0, - "cmd", "/c", `cd "C:\Users\dev\myapp" && yarn info --all --json`) + `C:\Users\dev\myapp|yarn`, "info", "--all", "--json") scanner := newTestScanner(mock) result := scanner.scanProject(context.Background(), projectDir) diff --git a/internal/detector/shellcmd.go b/internal/detector/shellcmd.go index a47c971..69ebffa 100644 --- a/internal/detector/shellcmd.go +++ b/internal/detector/shellcmd.go @@ -18,6 +18,24 @@ func runShellCmd(ctx context.Context, exec executor.Executor, timeout time.Durat return exec.RunWithTimeout(ctx, timeout, "bash", "-c", command) } +// runCmdInDir runs a command from a specific working directory. On Windows it +// dispatches directly via the executor's RunInDir (using os/exec's cmd.Dir), +// avoiding the `cmd /c "cd && "` pattern — Go's os/exec quoting and +// cmd.exe's quote-stripping rules conflict when paths or arguments need +// escaping, producing "The filename, directory name, or volume label syntax +// is incorrect." On Unix we keep the shell-command-string approach so root +// runs can still delegate via RunAsUser/sudo. +func runCmdInDir(ctx context.Context, exec executor.Executor, timeout time.Duration, dir, name string, args ...string) (string, string, int, error) { + if exec.GOOS() == "windows" { + return exec.RunInDir(ctx, timeout, dir, name, args...) + } + shellCmd := "cd " + platformShellQuote(exec, dir) + " && " + name + for _, a := range args { + shellCmd += " " + a + } + return runShellCmd(ctx, exec, timeout, shellCmd) +} + // platformShellQuote quotes a string for use in a shell command. // On Unix: single quotes with escaping. // On Windows: double quotes with escaping. diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 8b15cbb..759ba41 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -20,6 +20,11 @@ type Executor interface { Run(ctx context.Context, name string, args ...string) (stdout, stderr string, exitCode int, err error) // RunWithTimeout executes a command with a timeout. RunWithTimeout(ctx context.Context, timeout time.Duration, name string, args ...string) (stdout, stderr string, exitCode int, err error) + // RunInDir executes a command with a timeout, in the given working directory. + // Used to avoid the `cmd /c "cd && "` pattern on Windows, where + // Go's os/exec quoting and cmd.exe's quote-stripping rules conflict and + // mangle paths. + RunInDir(ctx context.Context, timeout time.Duration, dir, name string, args ...string) (stdout, stderr string, exitCode int, err error) // RunAsUser runs a shell command as a specific user (for root -> user delegation). RunAsUser(ctx context.Context, username, command string) (string, error) // LookPath searches for an executable in PATH. @@ -87,6 +92,29 @@ func (r *Real) RunWithTimeout(ctx context.Context, timeout time.Duration, name s return stdout, stderr, code, err } +func (r *Real) RunInDir(ctx context.Context, timeout time.Duration, dir, name string, args ...string) (string, string, int, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + cmd := exec.CommandContext(ctx, name, args...) + cmd.Dir = dir + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + if ctx.Err() == context.DeadlineExceeded { + return stdout.String(), stderr.String(), 124, fmt.Errorf("command timed out after %s", timeout) + } + exitCode := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + return stdout.String(), stderr.String(), -1, err + } + } + return stdout.String(), stderr.String(), exitCode, nil +} + func (r *Real) LookPath(name string) (string, error) { return exec.LookPath(name) } diff --git a/internal/executor/mock.go b/internal/executor/mock.go index cfb79a7..af8151f 100644 --- a/internal/executor/mock.go +++ b/internal/executor/mock.go @@ -167,6 +167,20 @@ func (m *Mock) RunWithTimeout(ctx context.Context, _ time.Duration, name string, return m.Run(ctx, name, args...) } +// RunInDir matches first on the (dir, name, args) tuple via SetCommandInDir, +// then falls back to plain (name, args) — so existing tests that don't care +// about the working directory don't need to be rewritten. +func (m *Mock) RunInDir(ctx context.Context, _ time.Duration, dir, name string, args ...string) (string, string, int, error) { + m.mu.RLock() + key := cmdKey(dir+"|"+name, args...) + if r, ok := m.commands[key]; ok { + m.mu.RUnlock() + return r.Stdout, r.Stderr, r.ExitCode, r.Err + } + m.mu.RUnlock() + return m.Run(ctx, name, args...) +} + func (m *Mock) RunAsUser(ctx context.Context, _ string, command string) (string, error) { stdout, _, _, err := m.Run(ctx, "bash", "-c", command) return stdout, err