Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions file_handling.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,33 @@ func (f *File) Write(content string) error {
return err
}

tempName := filepath.Join(f.Dir(), RandomString(20))
if err := os.WriteFile(tempName, []byte(content), mode); err != nil {
tempFile, err := os.CreateTemp(f.Dir(), ".find-replace-*")
if err != nil {
return fmt.Errorf("create tempfile in %v: %w", f.Dir(), err)
}
// Make sure the temp file is removed if the rename below fails. On
// success, the rename has already moved the file to f.Path so this is
// a no-op (we deliberately ignore the not-exist error).
defer os.Remove(tempName)
tempName := tempFile.Name()
removeTemp := true
defer func() {
if removeTemp {
_ = os.Remove(tempName)
}
}()
if err := tempFile.Chmod(mode); err != nil {
_ = tempFile.Close()
return fmt.Errorf("chmod temp file %v: %w", tempName, err)
}
if _, err := tempFile.WriteString(content); err != nil {
_ = tempFile.Close()
return fmt.Errorf("write temp file %v: %w", tempName, err)
}
if err := tempFile.Close(); err != nil {
return fmt.Errorf("close temp file %v: %w", tempName, err)
}

log.Printf("Rewriting %v", f.Path)
if err := os.Rename(tempName, f.Path); err != nil {
return fmt.Errorf("atomically move temp file %v to %v: %w", tempName, f.Path, err)
}
removeTemp = false
return nil
}
26 changes: 26 additions & 0 deletions file_handling_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"os"
"path/filepath"
"testing"
)
Expand Down Expand Up @@ -77,3 +78,28 @@ func TestNewFile(t *testing.T) {
})
}
}

func TestWriteUsesCreateTempPrefix(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "target.txt")
if err := os.WriteFile(path, []byte("old"), 0o644); err != nil {
t.Fatal(err)
}
f, err := NewFile(path)
if err != nil {
t.Fatal(err)
}
if err := f.Write("new"); err != nil {
t.Fatal(err)
}
entries, err := os.ReadDir(dir)
if err != nil {
t.Fatal(err)
}
for _, e := range entries {
if e.Name() != "target.txt" {
t.Fatalf("unexpected leftover file %q", e.Name())
}
}
}