diff --git a/internal/app/azldev/core/sources/overlays.go b/internal/app/azldev/core/sources/overlays.go index 95705f8..d89e476 100644 --- a/internal/app/azldev/core/sources/overlays.go +++ b/internal/app/azldev/core/sources/overlays.go @@ -173,7 +173,7 @@ func ApplySpecOverlay(overlay projectconfig.ComponentOverlay, openedSpec *spec.S return fmt.Errorf("failed to add patch entry to spec:\n%w", err) } case projectconfig.ComponentOverlayRemovePatch: - err := openedSpec.RemovePatchEntry(overlay.PackageName, overlay.Filename) + err := openedSpec.RemovePatchEntry(overlay.Filename) if err != nil { return fmt.Errorf("failed to remove patch entry from spec:\n%w", err) } diff --git a/internal/rpm/spec/edit.go b/internal/rpm/spec/edit.go index a4e3a75..a979067 100644 --- a/internal/rpm/spec/edit.go +++ b/internal/rpm/spec/edit.go @@ -50,7 +50,7 @@ func (s *Spec) UpdateExistingTag(packageName string, tag string, value string) ( var updated bool - err = s.VisitTags(packageName, func(tagLine *TagLine, ctx *Context) error { + err = s.VisitTagsPackage(packageName, func(tagLine *TagLine, ctx *Context) error { if strings.ToLower(tagLine.Tag) != tagToCompareAgainst { return nil } @@ -100,20 +100,14 @@ func (s *Spec) RemoveTag(packageName string, tag string, value string) (err erro return nil } -// VisitTags iterates over all tag lines in the given package, calling the visitor function -// for each one. The visitor receives the parsed [TagLine] and the mutation [Context]. This -// extracts the common target-type / package / tag-type filtering that many tag-oriented -// methods need. -func (s *Spec) VisitTags(packageName string, visitor func(tagLine *TagLine, ctx *Context) error) error { +// VisitTags iterates over all tag lines across all packages, calling the visitor function +// for each one. The visitor receives the parsed [TagLine] and the mutation [Context]. +func (s *Spec) VisitTags(visitor func(tagLine *TagLine, ctx *Context) error) error { return s.Visit(func(ctx *Context) error { if ctx.Target.TargetType != SectionLineTarget { return nil } - if ctx.CurrentSection.Package != packageName { - return nil - } - if ctx.Target.Line.Parsed.GetType() != Tag { return nil } @@ -127,13 +121,27 @@ func (s *Spec) VisitTags(packageName string, visitor func(tagLine *TagLine, ctx }) } +// VisitTagsPackage iterates over all tag lines in the given package, calling the visitor +// function for each one. The visitor receives the parsed [TagLine] and the mutation [Context]. +// This extracts the common target-type / package / tag-type filtering that many tag-oriented +// methods need. +func (s *Spec) VisitTagsPackage(packageName string, visitor func(tagLine *TagLine, ctx *Context) error) error { + return s.VisitTags(func(tagLine *TagLine, ctx *Context) error { + if ctx.CurrentSection.Package != packageName { + return nil + } + + return visitor(tagLine, ctx) + }) +} + // RemoveTagsMatching removes all tags in the given package for which the provided matcher // function returns true. The matcher receives the tag name and value as arguments. Returns // the number of tags removed. If no matching tags were found, returns 0 and no error. func (s *Spec) RemoveTagsMatching(packageName string, matcher func(tag, value string) bool) (int, error) { removed := 0 - err := s.VisitTags(packageName, func(tagLine *TagLine, ctx *Context) error { + err := s.VisitTagsPackage(packageName, func(tagLine *TagLine, ctx *Context) error { if !matcher(tagLine.Tag, tagLine.Value) { return nil } @@ -600,7 +608,7 @@ func (s *Spec) AddPatchEntry(packageName, filename string) error { return s.AppendLinesToSection("%patchlist", "", []string{filename}) } - highest, err := s.GetHighestPatchTagNumber(packageName) + highest, err := s.GetHighestPatchTagNumber() if err != nil { return fmt.Errorf("failed to scan for existing patch tags:\n%w", err) } @@ -610,24 +618,13 @@ func (s *Spec) AddPatchEntry(packageName, filename string) error { // RemovePatchEntry removes all references to patches matching the given pattern from the spec. // The pattern is a glob pattern (supporting doublestar syntax) matched against PatchN tag values -// and %patchlist entries. Returns an error if no references matched the pattern. -func (s *Spec) RemovePatchEntry(packageName, pattern string) error { - slog.Debug("Removing patch entry from spec", "package", packageName, "pattern", pattern) +// and %patchlist entries across all packages. Returns an error if no references matched the pattern. +func (s *Spec) RemovePatchEntry(pattern string) error { + slog.Debug("Removing patch entry from spec", "pattern", pattern) totalRemoved := 0 - tagsRemoved, err := s.RemoveTagsMatching(packageName, func(tag, value string) bool { - if _, ok := ParsePatchTagNumber(tag); !ok { - return false - } - - matched, matchErr := doublestar.Match(pattern, value) - if matchErr != nil { - return false - } - - return matched - }) + tagsRemoved, err := s.removePatchTagsMatching(pattern) if err != nil { return fmt.Errorf("failed to remove matching patch tags:\n%w", err) } @@ -655,6 +652,33 @@ func (s *Spec) RemovePatchEntry(packageName, pattern string) error { return nil } +// removePatchTagsMatching removes all PatchN tags across all packages whose values match the +// given glob pattern. Returns the number of tags removed. +func (s *Spec) removePatchTagsMatching(pattern string) (int, error) { + removed := 0 + + err := s.VisitTags(func(tagLine *TagLine, ctx *Context) error { + if _, ok := ParsePatchTagNumber(tagLine.Tag); !ok { + return nil + } + + matched, matchErr := doublestar.Match(pattern, tagLine.Value) + if matchErr != nil { + return fmt.Errorf("failed to match glob pattern %#q against %#q:\n%w", pattern, tagLine.Value, matchErr) + } + + if matched { + ctx.RemoveLine() + + removed++ + } + + return nil + }) + + return removed, err +} + // removePatchlistEntriesMatching removes lines from the %patchlist section whose trimmed content // matches the given glob pattern. Returns the number of entries removed. func (s *Spec) removePatchlistEntriesMatching(pattern string) (int, error) { @@ -692,13 +716,13 @@ func (s *Spec) removePatchlistEntriesMatching(pattern string) (int, error) { } // GetHighestPatchTagNumber scans the spec for all PatchN tags (where N is a decimal number) -// in the given package and returns the highest N found. Returns -1 if no numeric patch tags +// across all packages and returns the highest N found. Returns -1 if no numeric patch tags // exist. Tags with non-numeric suffixes (e.g., macro-based names like Patch%{n}) are silently // skipped. -func (s *Spec) GetHighestPatchTagNumber(packageName string) (int, error) { +func (s *Spec) GetHighestPatchTagNumber() (int, error) { highest := -1 - err := s.VisitTags(packageName, func(tagLine *TagLine, _ *Context) error { + err := s.VisitTags(func(tagLine *TagLine, _ *Context) error { num, isPatchTag := ParsePatchTagNumber(tagLine.Tag) if isPatchTag && num > highest { highest = num diff --git a/internal/rpm/spec/edit_test.go b/internal/rpm/spec/edit_test.go index 2db001f..3ad3b80 100644 --- a/internal/rpm/spec/edit_test.go +++ b/internal/rpm/spec/edit_test.go @@ -1150,10 +1150,9 @@ func TestHasSection(t *testing.T) { func TestGetHighestPatchTagNumber(t *testing.T) { tests := []struct { - name string - input string - packageName string - expected int + name string + input string + expected int }{ { name: "no patch tags", @@ -1181,16 +1180,9 @@ func TestGetHighestPatchTagNumber(t *testing.T) { expected: 1, }, { - name: "patches in sub-package are isolated", - input: "Name: test\nPatch0: main.patch\n\n%package devel\nPatch1: devel.patch\n", - packageName: "devel", - expected: 1, - }, - { - name: "main package with sub-package present", - input: "Name: test\nPatch0: main.patch\n\n%package devel\nPatch5: devel.patch\n", - packageName: "", - expected: 0, + name: "scans across all packages", + input: "Name: test\nPatch0: main.patch\n\n%package devel\nPatch5: devel.patch\n", + expected: 5, }, } @@ -1199,7 +1191,7 @@ func TestGetHighestPatchTagNumber(t *testing.T) { specFile, err := spec.OpenSpec(strings.NewReader(testCase.input)) require.NoError(t, err) - result, err := specFile.GetHighestPatchTagNumber(testCase.packageName) + result, err := specFile.GetHighestPatchTagNumber() require.NoError(t, err) assert.Equal(t, testCase.expected, result) }) @@ -1281,7 +1273,6 @@ func TestRemovePatchEntry(t *testing.T) { tests := []struct { name string input string - packageName string pattern string expectedOutput string expectedFailure bool @@ -1337,6 +1328,20 @@ func TestRemovePatchEntry(t *testing.T) { pattern: "*.patch", expectedOutput: "Name: test\n", }, + { + name: "removes matching patches across all packages", + input: "Name: test\nPatch0: CVE-001.patch\nPatch1: keep.patch\n\n%package devel\n" + + "Summary: Dev\nPatch2: CVE-002.patch\nPatch3: also-keep.patch\n", + pattern: "CVE-*.patch", + expectedOutput: "Name: test\nPatch1: keep.patch\n\n%package devel\nSummary: Dev\nPatch3: also-keep.patch\n", + }, + { + name: "no match across multiple packages returns error", + input: "Name: test\nPatch0: keep.patch\n\n%package devel\nSummary: Dev\nPatch1: also-keep.patch\n", + pattern: "nonexistent.patch", + expectedFailure: true, + errorContains: "no patches matching", + }, } for _, testCase := range tests { @@ -1344,7 +1349,7 @@ func TestRemovePatchEntry(t *testing.T) { specFile, err := spec.OpenSpec(strings.NewReader(testCase.input)) require.NoError(t, err) - err = specFile.RemovePatchEntry(testCase.packageName, testCase.pattern) + err = specFile.RemovePatchEntry(testCase.pattern) if testCase.expectedFailure { require.Error(t, err) @@ -1390,3 +1395,104 @@ func TestParsePatchTagNumber(t *testing.T) { }) } } + +func TestVisitTags(t *testing.T) { + input := `Name: main-pkg +Version: 1.0 +Patch0: main.patch + +%package devel +Summary: Development files +Patch1: devel.patch + +%package -n other +Summary: Other package +Patch2: other.patch +` + + tests := []struct { + name string + expectedTags []string + }{ + { + name: "visits tags across all packages", + expectedTags: []string{"Name", "Version", "Patch0", "Summary", "Patch1", "Summary", "Patch2"}, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + sf, err := spec.OpenSpec(strings.NewReader(input)) + require.NoError(t, err) + + var tags []string + + err = sf.VisitTags(func(tagLine *spec.TagLine, _ *spec.Context) error { + tags = append(tags, tagLine.Tag) + + return nil + }) + require.NoError(t, err) + assert.Equal(t, testCase.expectedTags, tags) + }) + } +} + +func TestVisitTagsPackage(t *testing.T) { + input := `Name: main-pkg +Version: 1.0 +Patch0: main.patch + +%package devel +Summary: Development files +Patch1: devel.patch + +%package -n other +Summary: Other package +Patch2: other.patch +` + + tests := []struct { + name string + packageName string + expectedTags []string + }{ + { + name: "global package only", + packageName: "", + expectedTags: []string{"Name", "Version", "Patch0"}, + }, + { + name: "devel sub-package only", + packageName: "devel", + expectedTags: []string{"Summary", "Patch1"}, + }, + { + name: "other sub-package only", + packageName: "other", + expectedTags: []string{"Summary", "Patch2"}, + }, + { + name: "non-existing package returns no tags", + packageName: "nonexistent", + expectedTags: nil, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + sf, err := spec.OpenSpec(strings.NewReader(input)) + require.NoError(t, err) + + var tags []string + + err = sf.VisitTagsPackage(testCase.packageName, func(tagLine *spec.TagLine, _ *spec.Context) error { + tags = append(tags, tagLine.Tag) + + return nil + }) + require.NoError(t, err) + assert.Equal(t, testCase.expectedTags, tags) + }) + } +}