diff --git a/file_handling.go b/file_handling.go index 7c0b1a7..5331f25 100644 --- a/file_handling.go +++ b/file_handling.go @@ -98,18 +98,33 @@ func (f *File) Write(content string) error { return err } - tempName := filepath.Join(f.Dir(), RandomString(20)) - if err := os.WriteFile(tempName, []byte(content), mode); err != nil { + tempFile, err := os.CreateTemp(f.Dir(), ".find-replace-*") + if err != nil { return fmt.Errorf("create tempfile in %v: %w", f.Dir(), err) } - // Make sure the temp file is removed if the rename below fails. On - // success, the rename has already moved the file to f.Path so this is - // a no-op (we deliberately ignore the not-exist error). - defer os.Remove(tempName) + tempName := tempFile.Name() + removeTemp := true + defer func() { + if removeTemp { + _ = os.Remove(tempName) + } + }() + if err := tempFile.Chmod(mode); err != nil { + _ = tempFile.Close() + return fmt.Errorf("chmod temp file %v: %w", tempName, err) + } + if _, err := tempFile.WriteString(content); err != nil { + _ = tempFile.Close() + return fmt.Errorf("write temp file %v: %w", tempName, err) + } + if err := tempFile.Close(); err != nil { + return fmt.Errorf("close temp file %v: %w", tempName, err) + } log.Printf("Rewriting %v", f.Path) if err := os.Rename(tempName, f.Path); err != nil { return fmt.Errorf("atomically move temp file %v to %v: %w", tempName, f.Path, err) } + removeTemp = false return nil } diff --git a/file_handling_test.go b/file_handling_test.go index 91ee538..b5b8f64 100644 --- a/file_handling_test.go +++ b/file_handling_test.go @@ -1,6 +1,7 @@ package main import ( + "os" "path/filepath" "testing" ) @@ -77,3 +78,28 @@ func TestNewFile(t *testing.T) { }) } } + +func TestWriteUsesCreateTempPrefix(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "target.txt") + if err := os.WriteFile(path, []byte("old"), 0o644); err != nil { + t.Fatal(err) + } + f, err := NewFile(path) + if err != nil { + t.Fatal(err) + } + if err := f.Write("new"); err != nil { + t.Fatal(err) + } + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + for _, e := range entries { + if e.Name() != "target.txt" { + t.Fatalf("unexpected leftover file %q", e.Name()) + } + } +} +