diff --git a/file_handling.go b/file_handling.go index 7c0b1a7..86764b3 100644 --- a/file_handling.go +++ b/file_handling.go @@ -93,15 +93,19 @@ func (f *File) Read() (string, error) { // step after its creation fails (including the rename); on success the remove // is a no-op because the file has already been renamed away. func (f *File) Write(content string) error { - mode, err := f.Mode() + info, err := f.Info() if err != nil { return err } + mode := info.Mode() tempName := filepath.Join(f.Dir(), RandomString(20)) if err := os.WriteFile(tempName, []byte(content), mode); err != nil { return fmt.Errorf("create tempfile in %v: %w", f.Dir(), err) } + if err := chownTempFromInfo(tempName, info); err != nil { + return fmt.Errorf("preserve ownership on temp file %v: %w", tempName, err) + } // Make sure the temp file is removed if the rename below fails. On // success, the rename has already moved the file to f.Path so this is // a no-op (we deliberately ignore the not-exist error). diff --git a/file_handling_test.go b/file_handling_test.go index 91ee538..eccf9dd 100644 --- a/file_handling_test.go +++ b/file_handling_test.go @@ -1,7 +1,10 @@ package main import ( + "os" "path/filepath" + "runtime" + "syscall" "testing" ) @@ -77,3 +80,53 @@ func TestNewFile(t *testing.T) { }) } } + +func TestWritePreservesOwnership(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("ownership preservation not implemented on Windows") + } + + dir := t.TempDir() + path := filepath.Join(dir, "f.txt") + if err := os.WriteFile(path, []byte("old"), 0o644); err != nil { + t.Fatal(err) + } + + before, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + beforeStat, ok := before.Sys().(*syscall.Stat_t) + if !ok { + t.Skip("syscall.Stat_t not available") + } + + f, err := NewFile(path) + if err != nil { + t.Fatal(err) + } + if err := f.Write("new"); err != nil { + t.Fatal(err) + } + + after, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + afterStat, ok := after.Sys().(*syscall.Stat_t) + if !ok { + t.Fatal("expected syscall.Stat_t after rewrite") + } + if afterStat.Uid != beforeStat.Uid || afterStat.Gid != beforeStat.Gid { + t.Fatalf("ownership changed: uid %d->%d gid %d->%d", + beforeStat.Uid, afterStat.Uid, beforeStat.Gid, afterStat.Gid) + } + got, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if string(got) != "new" { + t.Fatalf("content = %q; want %q", got, "new") + } +} + diff --git a/ownership_unix.go b/ownership_unix.go new file mode 100644 index 0000000..1c3a8d5 --- /dev/null +++ b/ownership_unix.go @@ -0,0 +1,16 @@ +//go:build !windows + +package main + +import ( + "os" + "syscall" +) + +func chownTempFromInfo(tempPath string, info os.FileInfo) error { + sys, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return nil + } + return os.Chown(tempPath, int(sys.Uid), int(sys.Gid)) +} diff --git a/ownership_windows.go b/ownership_windows.go new file mode 100644 index 0000000..f0abaca --- /dev/null +++ b/ownership_windows.go @@ -0,0 +1,9 @@ +//go:build windows + +package main + +import "os" + +func chownTempFromInfo(tempPath string, info os.FileInfo) error { + return nil +}