diff --git a/file_handling.go b/file_handling.go
index 7c0b1a7..a3538b4 100644
--- a/file_handling.go
+++ b/file_handling.go
@@ -7,8 +7,6 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
-
-	"golang.org/x/tools/godoc/util"
 )
 
 type File struct {
@@ -60,8 +58,8 @@ func (f *File) Mode() (os.FileMode, error) {
 }
 
 // Read reads the file into a string, or returns the empty string for binary
-// files. An error indicates the file could not be opened or fully read; the
-// caller should log-and-skip rather than abort.
+// files (NUL bytes or invalid UTF-8). An error indicates the file could not be
+// opened or fully read; the caller should log-and-skip rather than abort.
 func (f *File) Read() (string, error) {
 	handle, err := os.Open(f.Path)
 	if err != nil {
@@ -69,21 +67,37 @@ func (f *File) Read() (string, error) {
 	}
 	defer handle.Close()
 
-	// Check if the file looks like text before reading the entire file.
-	var buf [1024]byte
-	n, err := handle.Read(buf[0:])
-	if err != nil || !util.IsText(buf[0:n]) {
-		return "", nil
-	}
-
-	// Reset file handle so we can read the entire file.
-	if _, err := handle.Seek(0, io.SeekStart); err != nil {
-		return "", fmt.Errorf("seek to start of %v: %w", f.Path, err)
-	}
-
-	builder := new(strings.Builder)
-	if _, err := io.Copy(builder, handle); err != nil {
-		return "", fmt.Errorf("read %v: %w", f.Path, err)
+	// Validate the content chunk by chunk while reading, so a binary file is
+	// abandoned at the first offending chunk instead of being loaded whole.
+	builder := new(strings.Builder)
+	chunk := make([]byte, 32*1024)
+	// pending counts the bytes at the start of chunk carried over from the
+	// previous read: the head of a multi-byte UTF-8 sequence cut off by the
+	// read boundary, which can only be validated together with its tail.
+	pending := 0
+	for {
+		n, readErr := handle.Read(chunk[pending:])
+		if readErr != nil && readErr != io.EOF {
+			return "", fmt.Errorf("read %v: %w", f.Path, readErr)
+		}
+		atEOF := readErr == io.EOF
+		data := chunk[:pending+n]
+
+		// Hold back a trailing rune that is merely cut off by the read
+		// boundary; at EOF nothing can complete it, so validate everything.
+		pending = 0
+		if !atEOF {
+			pending = incompleteUTF8Suffix(data)
+		}
+		text := data[:len(data)-pending]
+		if !isTextBytes(text) {
+			return "", nil
+		}
+		builder.Write(text) // strings.Builder.Write never returns an error.
+		if atEOF {
+			break
+		}
+		copy(chunk, data[len(text):])
 	}
 	return builder.String(), nil
 }
diff --git a/file_handling_test.go b/file_handling_test.go
index 91ee538..17f2e26 100644
--- a/file_handling_test.go
+++ b/file_handling_test.go
@@ -1,6 +1,8 @@
 package main
 
 import (
+	"os"
 	"path/filepath"
+	"strings"
 	"testing"
 )
@@ -77,3 +79,83 @@ func TestNewFile(t *testing.T) {
 		})
 	}
 }
+
+func TestReadSkipsBinaryWithNUL(t *testing.T) {
+	dir := t.TempDir()
+	path := filepath.Join(dir, "mixed.txt")
+	content := []byte("text prefix\x00binary suffix")
+	if err := os.WriteFile(path, content, 0o644); err != nil {
+		t.Fatal(err)
+	}
+
+	f, err := NewFile(path)
+	if err != nil {
+		t.Fatal(err)
+	}
+	got, err := f.Read()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got != "" {
+		t.Fatalf("Read() = %q; want empty for NUL-containing file", got)
+	}
+}
+
+func TestReadReturnsShortTextFile(t *testing.T) {
+	dir := t.TempDir()
+	path := filepath.Join(dir, "short.txt")
+	if err := os.WriteFile(path, []byte("hello"), 0o644); err != nil {
+		t.Fatal(err)
+	}
+	f, err := NewFile(path)
+	if err != nil {
+		t.Fatal(err)
+	}
+	got, err := f.Read()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got != "hello" {
+		t.Fatalf("Read() = %q; want %q", got, "hello")
+	}
+}
+
+func TestReadKeepsRuneSplitAcrossReads(t *testing.T) {
+	// Place a two-byte rune so that it straddles a likely read boundary.
+	for _, offset := range []int{1023, 32*1024 - 1, 64*1024 - 1} {
+		want := strings.Repeat("a", offset) + "é tail"
+		path := filepath.Join(t.TempDir(), "split.txt")
+		if err := os.WriteFile(path, []byte(want), 0o644); err != nil {
+			t.Fatal(err)
+		}
+		f, err := NewFile(path)
+		if err != nil {
+			t.Fatal(err)
+		}
+		got, err := f.Read()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if got != want {
+			t.Fatalf("Read() returned %d bytes; want %d with rune at offset %d", len(got), len(want), offset)
+		}
+	}
+}
+
+func TestReadSkipsTruncatedUTF8(t *testing.T) {
+	path := filepath.Join(t.TempDir(), "truncated.txt")
+	if err := os.WriteFile(path, []byte("abc\xc3"), 0o644); err != nil {
+		t.Fatal(err)
+	}
+	f, err := NewFile(path)
+	if err != nil {
+		t.Fatal(err)
+	}
+	got, err := f.Read()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got != "" {
+		t.Fatalf("Read() = %q; want empty for truncated UTF-8", got)
+	}
+}
diff --git a/go.mod b/go.mod
index 66fd9af..d0c7f3d 100644
--- a/go.mod
+++ b/go.mod
@@ -1,5 +1,3 @@
 module github.com/dolph/find-replace
 
 go 1.20
-
-require golang.org/x/tools v0.7.0
diff --git a/go.sum b/go.sum
index b522ba0..e69de29 100644
--- a/go.sum
+++ b/go.sum
@@ -1,4 +0,0 @@
-golang.org/x/tools v0.1.9 h1:j9KsMiaP1c3B0OTQGth0/k+miLGTgLsAFUCrF2vLcF8=
-golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU=
-golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4=
-golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s=
diff --git a/text_detect.go b/text_detect.go
new file mode 100644
index 0000000..ae50590
--- /dev/null
+++ b/text_detect.go
@@ -0,0 +1,35 @@
+package main
+
+import (
+	"bytes"
+	"unicode/utf8"
+)
+
+// isTextBytes reports whether b is valid UTF-8 and contains no NUL bytes.
+func isTextBytes(b []byte) bool {
+	if len(b) == 0 {
+		return true
+	}
+	if bytes.IndexByte(b, 0) >= 0 {
+		return false
+	}
+	return utf8.Valid(b)
+}
+
+// incompleteUTF8Suffix returns the length (0 to utf8.UTFMax-1) of a trailing
+// fragment of b that starts a multi-byte UTF-8 sequence cut off by the end of
+// b. It returns 0 when b ends on a complete rune, or on bytes that no
+// continuation could make valid; those are left for isTextBytes to reject.
+func incompleteUTF8Suffix(b []byte) int {
+	for i := 1; i < utf8.UTFMax && i <= len(b); i++ {
+		tail := b[len(b)-i:]
+		if !utf8.RuneStart(tail[0]) {
+			continue // Continuation byte: the lead byte is further back.
+		}
+		if utf8.FullRune(tail) {
+			return 0
+		}
+		return i
+	}
+	return 0
+}
diff --git a/text_detect_test.go b/text_detect_test.go
new file mode 100644
index 0000000..05810a9
--- /dev/null
+++ b/text_detect_test.go
@@ -0,0 +1,47 @@
+package main
+
+import "testing"
+
+func TestIsTextBytes(t *testing.T) {
+	cases := []struct {
+		name string
+		in   []byte
+		want bool
+	}{
+		{name: "empty", in: nil, want: true},
+		{name: "ascii", in: []byte("hello"), want: true},
+		{name: "utf8", in: []byte("héllo"), want: true},
+		{name: "nul", in: []byte("a\x00b"), want: false},
+		{name: "invalid utf8", in: []byte{0xff, 0xfe}, want: false},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			if got := isTextBytes(tc.in); got != tc.want {
+				t.Fatalf("isTextBytes(%q) = %v; want %v", tc.in, got, tc.want)
+			}
+		})
+	}
+}
+
+func TestIncompleteUTF8Suffix(t *testing.T) {
+	cases := []struct {
+		name string
+		in   []byte
+		want int
+	}{
+		{name: "empty", in: nil, want: 0},
+		{name: "ascii", in: []byte("abc"), want: 0},
+		{name: "complete rune", in: []byte("é"), want: 0},
+		{name: "lone lead byte", in: []byte{'a', 0xc3}, want: 1},
+		{name: "two of three bytes", in: []byte{0xe2, 0x82}, want: 2},
+		{name: "three of four bytes", in: []byte{0xf0, 0x9f, 0x98}, want: 3},
+		{name: "invalid byte", in: []byte{'a', 0xff}, want: 0},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			if got := incompleteUTF8Suffix(tc.in); got != tc.want {
+				t.Fatalf("incompleteUTF8Suffix(%q) = %d; want %d", tc.in, got, tc.want)
+			}
+		})
+	}
+}