From be1b89a6ebeddedc759f8d455ae7bfd7a2baa6ad Mon Sep 17 00:00:00 2001 From: wuyangfan <1102042793@qq.com> Date: Sat, 30 May 2026 13:15:16 +0800 Subject: [PATCH] fix: stream file rewrites with bounded memory Replace the read-all/strings.Replace/write path with a windowed stream rewriter that skips writing when no matches occur, fixing OOM on large files and eliminating the redundant Contains pre-scan. Fixes #8 Fixes #14 --- file_handling.go | 56 +++++++++++++++++++++++++++ find_replace.go | 10 +---- stream_replace.go | 86 +++++++++++++++++++++++++++++++++++++++++ stream_replace_test.go | 88 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 231 insertions(+), 9 deletions(-) create mode 100644 stream_replace.go create mode 100644 stream_replace_test.go diff --git a/file_handling.go b/file_handling.go index 7c0b1a7..256539f 100644 --- a/file_handling.go +++ b/file_handling.go @@ -88,6 +88,62 @@ func (f *File) Read() (string, error) { return builder.String(), nil } +// streamFindReplace rewrites file contents via streamReplace, skipping binary +// files and leaving the file unchanged when find does not occur. +func (f *File) streamFindReplace(find, replace string) error { + mode, err := f.Mode() + if err != nil { + return err + } + + in, err := os.Open(f.Path) + if err != nil { + return fmt.Errorf("open %v: %w", f.Path, err) + } + defer in.Close() + + var head [1024]byte + n, err := in.Read(head[:]) + if err != nil && err != io.EOF { + return fmt.Errorf("read %v: %w", f.Path, err) + } + if n == 0 || !util.IsText(head[:n]) { + return nil + } + if _, err := in.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("seek to start of %v: %w", f.Path, err) + } + + tempName := filepath.Join(f.Dir(), RandomString(20)) + out, err := os.OpenFile(tempName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) + if err != nil { + return fmt.Errorf("create tempfile in %v: %w", f.Dir(), err) + } + + changed, err := streamReplace(in, out, []byte(find), []byte(replace)) + closeErr := out.Close() + if err != nil { + _ = os.Remove(tempName) + return fmt.Errorf("rewrite %v: %w", f.Path, err) + } + if closeErr != nil { + _ = os.Remove(tempName) + return fmt.Errorf("close tempfile for %v: %w", f.Path, closeErr) + } + if !changed { + _ = os.Remove(tempName) + return nil + } + defer os.Remove(tempName) + + log.Printf("Rewriting %v", f.Path) + if err := os.Rename(tempName, f.Path); err != nil { + return fmt.Errorf("atomically move temp file %v to %v: %w", tempName, f.Path, err) + } + return nil +} + + // Write atomically replaces the file with content, via a temp file + rename. // A deferred os.Remove(tempName) ensures the temp file is cleaned up if any // step after its creation fails (including the rename); on success the remove diff --git a/find_replace.go b/find_replace.go index 9bdbe79..baf3cc3 100644 --- a/find_replace.go +++ b/find_replace.go @@ -200,13 +200,5 @@ func (fr *findReplace) RenameFile(f *File) error { // ReplaceContents rewrites the file at f if its contents contain the find // string. Binary-looking files (where Read returns "") are skipped silently. func (fr *findReplace) ReplaceContents(f *File) error { - content, err := f.Read() - if err != nil { - return err - } - if !strings.Contains(content, fr.find) { - return nil - } - newContent := strings.ReplaceAll(content, fr.find, fr.replace) - return f.Write(newContent) + return f.streamFindReplace(fr.find, fr.replace) } diff --git a/stream_replace.go b/stream_replace.go new file mode 100644 index 0000000..67abb9d --- /dev/null +++ b/stream_replace.go @@ -0,0 +1,86 @@ +package main + +import ( + "bytes" + "io" +) + +const streamBufferSize = 256 * 1024 + +// streamReplace copies from r to w, replacing every occurrence of find with replace. +// It returns whether any replacement was made. Memory use is bounded by streamBufferSize +// plus len(find) bytes of carry-over between reads. +func streamReplace(r io.Reader, w io.Writer, find, replace []byte) (bool, error) { + if len(find) == 0 { + return false, nil + } + + buf := make([]byte, streamBufferSize) + var pending []byte + var changed bool + + for { + n, readErr := r.Read(buf) + if n > 0 { + data := append(pending, buf[:n]...) + isFinal := readErr == io.EOF + if !isFinal && len(data) < streamBufferSize { + pending = data + continue + } + out, rest, chunkChanged := replaceChunk(data, find, replace, isFinal) + if chunkChanged { + changed = true + } + if len(out) > 0 { + if _, err := w.Write(out); err != nil { + return changed, err + } + } + pending = rest + } + if readErr == io.EOF { + break + } + if readErr != nil { + return changed, readErr + } + } + + if len(pending) > 0 { + out, _, chunkChanged := replaceChunk(pending, find, replace, true) + if chunkChanged { + changed = true + } + if len(out) > 0 { + if _, err := w.Write(out); err != nil { + return changed, err + } + } + } + + return changed, nil +} + +func replaceChunk(data, find, replace []byte, final bool) (out []byte, rest []byte, changed bool) { + if len(data) == 0 { + return nil, nil, false + } + + if final { + replaced := bytes.Replace(data, find, replace, -1) + return replaced, nil, !bytes.Equal(replaced, data) + } + + overlap := len(find) - 1 + if overlap >= len(data) { + return nil, append([]byte(nil), data...), false + } + + split := len(data) - overlap + process := data[:split] + rest = append([]byte(nil), data[split:]...) + + replaced := bytes.Replace(process, find, replace, -1) + return replaced, rest, !bytes.Equal(replaced, process) +} diff --git a/stream_replace_test.go b/stream_replace_test.go new file mode 100644 index 0000000..8b2904d --- /dev/null +++ b/stream_replace_test.go @@ -0,0 +1,88 @@ +package main + +import ( + "bytes" + "io" + "strings" + "testing" +) + +func TestStreamReplace(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + find string + replace string + want string + }{ + {name: "no match", input: "hello", find: "z", replace: "q", want: "hello"}, + {name: "simple", input: "foo bar foo", find: "foo", replace: "baz", want: "baz bar baz"}, + {name: "span boundary", input: "xxababc", find: "ab", replace: "X", want: "xxXXc"}, + {name: "empty input", input: "", find: "a", replace: "b", want: ""}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + var out bytes.Buffer + changed, err := streamReplace(strings.NewReader(tc.input), &out, []byte(tc.find), []byte(tc.replace)) + if err != nil { + t.Fatal(err) + } + if tc.input != tc.want && !changed { + t.Fatal("expected changed=true") + } + if tc.input == tc.want && changed { + t.Fatal("expected changed=false") + } + if out.String() != tc.want { + t.Fatalf("got %q; want %q", out.String(), tc.want) + } + }) + } +} + +func TestStreamReplaceLargeWithSmallReads(t *testing.T) { + find := "needle" + replace := "pin" + input := strings.Repeat("hay", 1000) + find + strings.Repeat("stack", 1000) + want := strings.Replace(input, find, replace, 1) + + var out bytes.Buffer + r := &smallReader{data: []byte(input), step: 3} + changed, err := streamReplace(r, &out, []byte(find), []byte(replace)) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Fatal("expected replacement") + } + if out.String() != want { + t.Fatalf("output length %d; want %d", out.Len(), len(want)) + } +} + +type smallReader struct { + data []byte + step int + off int +} + +func (r *smallReader) Read(p []byte) (int, error) { + if r.off >= len(r.data) { + return 0, io.EOF + } + n := r.step + if n > len(p) { + n = len(p) + } + if n > len(r.data)-r.off { + n = len(r.data) - r.off + } + copy(p, r.data[r.off:r.off+n]) + r.off += n + return n, nil +}