From b86c6649c0558c174ed6565e94ab76926c5d7670 Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Mon, 15 Jun 2026 11:02:40 -0300 Subject: [PATCH 1/6] implement: bch-1309-core-suggestion-engine --- .../service/relational/suggestions/core.go | 208 ++++++++ .../relational/suggestions/core_test.go | 171 ++++++ .../service/relational/suggestions/gather.go | 385 ++++++++++++++ .../suggestions/models_integration_test.go | 177 +++++++ .../service/relational/suggestions/prompt.go | 116 ++++ .../service/relational/suggestions/service.go | 497 ++++++++++++++++++ .../suggestions/testdata/prompt_v1.golden | 82 +++ .../relational/suggestions/validation.go | 254 +++++++++ 8 files changed, 1890 insertions(+) create mode 100644 internal/service/relational/suggestions/core.go create mode 100644 internal/service/relational/suggestions/core_test.go create mode 100644 internal/service/relational/suggestions/gather.go create mode 100644 internal/service/relational/suggestions/prompt.go create mode 100644 internal/service/relational/suggestions/service.go create mode 100644 internal/service/relational/suggestions/testdata/prompt_v1.golden create mode 100644 internal/service/relational/suggestions/validation.go diff --git a/internal/service/relational/suggestions/core.go b/internal/service/relational/suggestions/core.go new file mode 100644 index 00000000..3dc95d05 --- /dev/null +++ b/internal/service/relational/suggestions/core.go @@ -0,0 +1,208 @@ +package suggestions + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "math" + "sort" + "strings" + + "github.com/compliance-framework/api/internal/converters/labelfilter" +) + +const ( + DefaultMaxControlsPerChunk = 40 + DefaultMaxLabelSetsPerChunk = 200 + DefaultMaxSuggestionsPerRun = 500 + MaxMappingsPerControlPerCell = 10 + MaxReasoningLength = 2000 + ReasoningTruncatedMarker = "\n[truncated]" + DashboardSuggestionStatusPending = "pending" + DashboardSuggestionStatusAccepted = "accepted" + DashboardSuggestionStatusRejected = "rejected" +) + +type ChunkConfig struct { + MaxControlsPerChunk int + MaxLabelSetsPerChunk int +} + +type Scope struct { + ControlKeys []string + LabelSetHashes []string +} + +type Snapshot struct { + ControlKeys []string `json:"controlKeys"` + LabelSetHashes []string `json:"labelSetHashes"` +} + +type GridCell struct { + CellIndex int + ControlKeys []string + LabelSetHashes []string +} + +type ScopeError struct { + UnknownControlKeys []string + UnknownLabelSetHashes []string +} + +func (e *ScopeError) Error() string { + parts := make([]string, 0, 2) + if len(e.UnknownControlKeys) > 0 { + parts = append(parts, "unknown control keys: "+strings.Join(e.UnknownControlKeys, ", ")) + } + if len(e.UnknownLabelSetHashes) > 0 { + parts = append(parts, "unknown label-set hashes: "+strings.Join(e.UnknownLabelSetHashes, ", ")) + } + if len(parts) == 0 { + return "invalid suggestion scope" + } + return strings.Join(parts, "; ") +} + +func ResolveSnapshot(scope Scope, allControlKeys []string, allLabelSetHashes []string) (Snapshot, error) { + controls := append([]string(nil), allControlKeys...) + labelSets := append([]string(nil), allLabelSetHashes...) + sort.Strings(controls) + sort.Strings(labelSets) + controlSet := stringSet(controls) + labelSet := stringSet(labelSets) + selectedControls, unknownControls := selectDimension(scope.ControlKeys, controls, controlSet) + selectedLabelSets, unknownLabelSets := selectDimension(scope.LabelSetHashes, labelSets, labelSet) + if len(unknownControls) > 0 || len(unknownLabelSets) > 0 { + return Snapshot{}, &ScopeError{UnknownControlKeys: unknownControls, UnknownLabelSetHashes: unknownLabelSets} + } + return Snapshot{ControlKeys: selectedControls, LabelSetHashes: selectedLabelSets}, nil +} + +func CanonicalLabelSetHash(labels map[string]string) string { + lines := canonicalLabelLines(labels) + sum := sha256.Sum256([]byte(strings.Join(lines, "\n"))) + return hex.EncodeToString(sum[:]) +} + +func BuildLabelFilter(labels map[string]string) labelfilter.Filter { + keys := make([]string, 0, len(labels)) + for key := range labels { + keys = append(keys, key) + } + sort.Strings(keys) + + scopes := make([]labelfilter.Scope, 0, len(keys)) + for _, key := range keys { + scopes = append(scopes, labelfilter.Scope{ + Condition: &labelfilter.Condition{ + Label: key, + Operator: "=", + Value: labels[key], + }, + }) + } + + if len(scopes) == 1 { + return labelfilter.Filter{Scope: &scopes[0]} + } + + return labelfilter.Filter{ + Scope: &labelfilter.Scope{ + Query: &labelfilter.Query{ + Operator: "AND", + Scopes: scopes, + }, + }, + } +} + +func CanonicalizeFilter(filter labelfilter.Filter) (map[string]string, bool) { + if filter.Scope == nil { + return map[string]string{}, true + } + labels := map[string]string{} + if !canonicalizeScope(*filter.Scope, labels) { + return nil, false + } + return labels, true +} + +func canonicalizeScope(scope labelfilter.Scope, labels map[string]string) bool { + if scope.Condition != nil { + condition := scope.Condition + if condition.Operator != "=" { + return false + } + labels[strings.ToLower(condition.Label)] = condition.Value + return true + } + query := scope.Query + if query == nil || !strings.EqualFold(query.Operator, "AND") { + return false + } + for _, child := range query.Scopes { + if !canonicalizeScope(child, labels) { + return false + } + } + return true +} + +func canonicalLabelLines(labels map[string]string) []string { + keys := make([]string, 0, len(labels)) + for key := range labels { + keys = append(keys, strings.ToLower(key)) + } + sort.Strings(keys) + + lines := make([]string, 0, len(keys)) + for _, lowerKey := range keys { + value := "" + for key, candidate := range labels { + if strings.ToLower(key) == lowerKey { + value = candidate + break + } + } + lines = append(lines, fmt.Sprintf("%s=%s", lowerKey, value)) + } + return lines +} + +func PlannedCalls(controlCount, labelSetCount int, cfg ChunkConfig) int { + cfg = normalizeChunkConfig(cfg) + if controlCount == 0 || labelSetCount == 0 { + return 0 + } + return int(math.Ceil(float64(controlCount)/float64(cfg.MaxControlsPerChunk))) * + int(math.Ceil(float64(labelSetCount)/float64(cfg.MaxLabelSetsPerChunk))) +} + +func BuildGrid(snapshot Snapshot, cfg ChunkConfig) []GridCell { + cfg = normalizeChunkConfig(cfg) + cells := make([]GridCell, 0, PlannedCalls(len(snapshot.ControlKeys), len(snapshot.LabelSetHashes), cfg)) + cellIndex := 0 + for cStart := 0; cStart < len(snapshot.ControlKeys); cStart += cfg.MaxControlsPerChunk { + cEnd := min(cStart+cfg.MaxControlsPerChunk, len(snapshot.ControlKeys)) + for lStart := 0; lStart < len(snapshot.LabelSetHashes); lStart += cfg.MaxLabelSetsPerChunk { + lEnd := min(lStart+cfg.MaxLabelSetsPerChunk, len(snapshot.LabelSetHashes)) + cells = append(cells, GridCell{ + CellIndex: cellIndex, + ControlKeys: append([]string(nil), snapshot.ControlKeys[cStart:cEnd]...), + LabelSetHashes: append([]string(nil), snapshot.LabelSetHashes[lStart:lEnd]...), + }) + cellIndex++ + } + } + return cells +} + +func normalizeChunkConfig(cfg ChunkConfig) ChunkConfig { + if cfg.MaxControlsPerChunk <= 0 { + cfg.MaxControlsPerChunk = DefaultMaxControlsPerChunk + } + if cfg.MaxLabelSetsPerChunk <= 0 { + cfg.MaxLabelSetsPerChunk = DefaultMaxLabelSetsPerChunk + } + return cfg +} diff --git a/internal/service/relational/suggestions/core_test.go b/internal/service/relational/suggestions/core_test.go new file mode 100644 index 00000000..7720313e --- /dev/null +++ b/internal/service/relational/suggestions/core_test.go @@ -0,0 +1,171 @@ +package suggestions + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/compliance-framework/api/internal/converters/labelfilter" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestCanonicalLabelSetHashStable(t *testing.T) { + got := CanonicalLabelSetHash(map[string]string{"repo": "API", "Env": "Prod"}) + require.Equal(t, "3839e61d314736aafb6a1e944edf2b75e8ed4d0bb8ff3d293194a5cc7bb70917", got) + require.Equal(t, got, CanonicalLabelSetHash(map[string]string{"env": "Prod", "Repo": "API"})) + require.NotEqual(t, got, CanonicalLabelSetHash(map[string]string{"env": "prod", "repo": "API"})) +} + +func TestBuildLabelFilterCanonicalizeRoundTrip(t *testing.T) { + labels := map[string]string{"env": "prod", "repo": "api"} + filter := BuildLabelFilter(labels) + got, ok := CanonicalizeFilter(filter) + require.True(t, ok) + require.Equal(t, labels, got) +} + +func TestCanonicalizeFilterRefusesOrAndNotEquals(t *testing.T) { + orFilter := labelfilter.Filter{Scope: &labelfilter.Scope{Query: &labelfilter.Query{ + Operator: "OR", + Scopes: []labelfilter.Scope{ + {Condition: &labelfilter.Condition{Label: "env", Operator: "=", Value: "prod"}}, + {Condition: &labelfilter.Condition{Label: "env", Operator: "=", Value: "staging"}}, + }, + }}} + _, ok := CanonicalizeFilter(orFilter) + require.False(t, ok) + + notEqualFilter := labelfilter.Filter{Scope: &labelfilter.Scope{Condition: &labelfilter.Condition{ + Label: "env", Operator: "!=", Value: "prod", + }}} + _, ok = CanonicalizeFilter(notEqualFilter) + require.False(t, ok) +} + +func TestResolveSnapshot(t *testing.T) { + controls := []string{"catalog-b:AC-2", "catalog-a:AC-1"} + hashes := []string{"h2", "h1"} + + snapshot, err := ResolveSnapshot(Scope{}, controls, hashes) + require.NoError(t, err) + require.Equal(t, Snapshot{ + ControlKeys: []string{"catalog-a:AC-1", "catalog-b:AC-2"}, + LabelSetHashes: []string{"h1", "h2"}, + }, snapshot) + + snapshot, err = ResolveSnapshot(Scope{ControlKeys: []string{"catalog-b:AC-2"}, LabelSetHashes: []string{"h1"}}, controls, hashes) + require.NoError(t, err) + require.Equal(t, []string{"catalog-b:AC-2"}, snapshot.ControlKeys) + require.Equal(t, []string{"h1"}, snapshot.LabelSetHashes) + + _, err = ResolveSnapshot(Scope{ControlKeys: []string{"missing"}, LabelSetHashes: []string{"h3"}}, controls, hashes) + var scopeErr *ScopeError + require.True(t, errors.As(err, &scopeErr)) + require.Equal(t, []string{"missing"}, scopeErr.UnknownControlKeys) + require.Equal(t, []string{"h3"}, scopeErr.UnknownLabelSetHashes) +} + +func TestBuildGridAndPlannedCalls(t *testing.T) { + snapshot := Snapshot{ + ControlKeys: []string{"c1", "c2", "c3"}, + LabelSetHashes: []string{"l1", "l2", "l3", "l4", "l5"}, + } + cfg := ChunkConfig{MaxControlsPerChunk: 2, MaxLabelSetsPerChunk: 2} + require.Equal(t, 6, PlannedCalls(len(snapshot.ControlKeys), len(snapshot.LabelSetHashes), cfg)) + + cells := BuildGrid(snapshot, cfg) + require.Len(t, cells, 6) + seen := map[string]int{} + for _, cell := range cells { + for _, control := range cell.ControlKeys { + for _, labelSet := range cell.LabelSetHashes { + seen[control+"|"+labelSet]++ + } + } + } + require.Len(t, seen, len(snapshot.ControlKeys)*len(snapshot.LabelSetHashes)) + for _, count := range seen { + require.Equal(t, 1, count) + } + require.Equal(t, 1, PlannedCalls(1, 1, ChunkConfig{})) +} + +func TestValidateMappingsRules(t *testing.T) { + controlKey := ControlKey(uuid.New(), "AC-1") + labelHash := CanonicalLabelSetHash(map[string]string{"env": "prod"}) + otherHash := CanonicalLabelSetHash(map[string]string{"env": "stage"}) + sameSSPFilterID := uuid.New() + globalFilterID := uuid.New() + + longReason := strings.Repeat("a", MaxReasoningLength+10) + input := CellInput{ + Controls: []ControlInput{{ControlKey: controlKey}}, + LabelSets: []LabelSetInput{ + {Hash: labelHash, Labels: map[string]string{"env": "prod"}}, + {Hash: otherHash, Labels: map[string]string{"env": "stage"}}, + }, + SameSSPFilters: []VisibleFilterInput{{ID: sameSSPFilterID, Name: "prod", LabelSetHash: &labelHash}}, + VisibleFilters: []VisibleFilterInput{{ID: globalFilterID, Name: "global", LabelSetHash: &labelHash}}, + } + + result := ValidateRawMappings(input, []RawMapping{ + {ControlKey: "missing", LabelSetHash: labelHash, Action: MappingActionNewFilter, ProposedFilterName: "bad", Confidence: 0.5, Reasoning: "x"}, + {ControlKey: controlKey, LabelSetHash: "missing", Action: MappingActionNewFilter, ProposedFilterName: "bad", Confidence: 0.5, Reasoning: "x"}, + {ControlKey: controlKey, LabelSetHash: labelHash, Action: MappingActionNewFilter, ProposedFilterName: "bad", Confidence: 1.5, Reasoning: "x"}, + {ControlKey: controlKey, LabelSetHash: labelHash, Action: MappingActionNewFilter, ProposedFilterName: "bad", Confidence: 0.5, Reasoning: ""}, + {ControlKey: controlKey, LabelSetHash: labelHash, Action: MappingActionExtendFilter, TargetFilterID: globalFilterID.String(), Confidence: 0.4, Reasoning: longReason}, + {ControlKey: controlKey, LabelSetHash: labelHash, Action: MappingActionExtendFilter, TargetFilterID: sameSSPFilterID.String(), Confidence: 0.9, Reasoning: "better"}, + {ControlKey: controlKey, LabelSetHash: labelHash, Action: MappingActionNewFilter, ProposedFilterName: "lower", Confidence: 0.1, Reasoning: "dedupe"}, + {ControlKey: controlKey, LabelSetHash: otherHash, Action: MappingActionNewFilter, Confidence: 0.8, Reasoning: "fallback"}, + }) + + require.Equal(t, 1, result.Counts["rejected_unknown_control"]) + require.Equal(t, 1, result.Counts["rejected_unknown_label_set"]) + require.Equal(t, 1, result.Counts["rejected_confidence_out_of_range"]) + require.Equal(t, 1, result.Counts["rejected_empty_reasoning"]) + require.Equal(t, 1, result.Counts["downgraded_extend_to_new"]) + require.Equal(t, 2, result.Counts["deduped_within_cell"]) + require.Equal(t, 2, result.Counts["fallback_name"]) + require.Len(t, result.Mappings, 2) + require.Equal(t, MappingActionExtendFilter, result.Mappings[0].Action) + require.Equal(t, sameSSPFilterID, *result.Mappings[0].TargetFilterID) + require.Equal(t, "env=stage", result.Mappings[1].ProposedFilterName) +} + +func TestValidateMappingsControlCap(t *testing.T) { + controlKey := ControlKey(uuid.New(), "AC-1") + input := CellInput{Controls: []ControlInput{{ControlKey: controlKey}}} + raw := make([]RawMapping, 0, MaxMappingsPerControlPerCell+1) + for i := 0; i < MaxMappingsPerControlPerCell+1; i++ { + hash := CanonicalLabelSetHash(map[string]string{"n": string(rune('a' + i))}) + input.LabelSets = append(input.LabelSets, LabelSetInput{Hash: hash, Labels: map[string]string{"n": string(rune('a' + i))}}) + raw = append(raw, RawMapping{ControlKey: controlKey, LabelSetHash: hash, Action: MappingActionNewFilter, ProposedFilterName: "x", Confidence: float64(i) / 100, Reasoning: "x"}) + } + result := ValidateRawMappings(input, raw) + require.Len(t, result.Mappings, MaxMappingsPerControlPerCell) + require.Equal(t, 1, result.Counts["dropped_control_cap"]) +} + +func TestPromptGolden(t *testing.T) { + input := GatheredInput{ + SystemContext: SystemContextInput{ + SystemName: "Payments API", + Description: "Processes card payments.", + Components: []SystemComponentInput{{Title: "payments-api", Type: "service", Purpose: "payment processing", Description: "Go API"}}, + }, + Controls: []ControlInput{{ControlKey: "11111111-1111-1111-1111-111111111111:AC-1", Title: "Policy", ImplementationText: "Uses payments-api."}}, + LabelSets: []LabelSetInput{{Hash: "hash1", Labels: map[string]string{"repo": "payments-api"}, EvidenceCount: 2, SampleTitles: []string{"scan"}}}, + LabelKeyDocs: []LabelKeyDocInput{{Key: "repo", Description: "Repository name"}}, + SameSSPFilters: []VisibleFilterInput{{ID: uuid.MustParse("22222222-2222-2222-2222-222222222222"), Name: "payments-api"}}, + } + gotPrompt, err := RenderPrompt(input) + require.NoError(t, err) + got := SystemPrompt + "\n---USER---\n" + gotPrompt + path := filepath.Join("testdata", "prompt_"+PromptVersion+".golden") + expected, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, strings.TrimSuffix(string(expected), "\n"), got) +} diff --git a/internal/service/relational/suggestions/gather.go b/internal/service/relational/suggestions/gather.go new file mode 100644 index 00000000..ce9fe604 --- /dev/null +++ b/internal/service/relational/suggestions/gather.go @@ -0,0 +1,385 @@ +package suggestions + +import ( + "encoding/json" + "errors" + "fmt" + "sort" + "strings" + + "github.com/compliance-framework/api/internal/service/relational" + "github.com/google/uuid" + "gorm.io/gorm" +) + +const ( + defaultMaxComponents = 25 + defaultMaxComponentTextLen = 800 + defaultMaxControlTextLen = 2000 +) + +func (s *SuggestionService) resolvedControlKeys(sspID uuid.UUID) ([]string, error) { + var rows []struct { + CatalogID uuid.UUID `gorm:"column:control_catalog_id"` + ControlID string `gorm:"column:control_id"` + } + if err := s.db. + Table("profile_controls"). + Select("DISTINCT profile_controls.control_catalog_id, profile_controls.control_id"). + Joins("JOIN ssp_profiles ON ssp_profiles.profile_id = profile_controls.profile_id"). + Where("ssp_profiles.system_security_plan_id = ?", sspID). + Order("profile_controls.control_catalog_id ASC, profile_controls.control_id ASC"). + Scan(&rows).Error; err != nil { + return nil, err + } + keys := make([]string, 0, len(rows)) + for _, row := range rows { + keys = append(keys, ControlKey(row.CatalogID, row.ControlID)) + } + sort.Strings(keys) + return keys, nil +} + +func (s *SuggestionService) currentLabelSetHashes() ([]string, error) { + labelSets, err := s.gatherAllLabelSets() + if err != nil { + return nil, err + } + hashes := make([]string, 0, len(labelSets)) + for _, labelSet := range labelSets { + hashes = append(hashes, labelSet.Hash) + } + sort.Strings(hashes) + return hashes, nil +} + +func (s *SuggestionService) gatherControls(sspID uuid.UUID, controlKeys []string, opts GatherOptions) ([]ControlInput, error) { + opts = normalizeGatherOptions(opts) + if len(controlKeys) == 0 { + return []ControlInput{}, nil + } + type controlRow struct { + CatalogID uuid.UUID `gorm:"column:catalog_id"` + ControlID string `gorm:"column:control_id"` + Title string `gorm:"column:title"` + Parts string `gorm:"column:parts"` + CatalogTitle string `gorm:"column:catalog_title"` + ImplementationText string `gorm:"column:implementation_text"` + } + keys := make(map[string]struct{}, len(controlKeys)) + for _, key := range controlKeys { + keys[key] = struct{}{} + } + catalogIDs := make([]uuid.UUID, 0, len(controlKeys)) + controlIDs := make([]string, 0, len(controlKeys)) + for _, key := range controlKeys { + catalogID, controlID, err := ParseControlKey(key) + if err != nil { + return nil, err + } + catalogIDs = append(catalogIDs, catalogID) + controlIDs = append(controlIDs, controlID) + } + + var rows []controlRow + if err := s.db.Raw(` + SELECT + c.catalog_id, + c.id AS control_id, + c.title, + c.parts::text AS parts, + COALESCE(m.title, '') AS catalog_title, + TRIM(CONCAT_WS(E'\n', + NULLIF(ir.remarks, ''), + string_agg(NULLIF(st.remarks, ''), E'\n' ORDER BY st.statement_id) + )) AS implementation_text + FROM controls c + JOIN profile_controls pc ON pc.control_catalog_id = c.catalog_id AND pc.control_id = c.id + JOIN ssp_profiles sp ON sp.profile_id = pc.profile_id AND sp.system_security_plan_id = @ssp_id + LEFT JOIN metadata m ON m.parent_type = 'catalogs' AND m.parent_id = c.catalog_id::text + LEFT JOIN control_implementations ci ON ci.system_security_plan_id = @ssp_id + LEFT JOIN implemented_requirements ir ON ir.control_implementation_id = ci.id AND UPPER(ir.control_id) = UPPER(c.id) + LEFT JOIN statements st ON st.implemented_requirement_id = ir.id + WHERE c.catalog_id IN @catalog_ids AND c.id IN @control_ids + GROUP BY c.catalog_id, c.id, c.title, c.parts, m.title, ir.remarks + ORDER BY c.catalog_id ASC, c.id ASC + `, map[string]any{ + "ssp_id": sspID, + "catalog_ids": catalogIDs, + "control_ids": controlIDs, + }).Scan(&rows).Error; err != nil { + return nil, err + } + + out := make([]ControlInput, 0, len(rows)) + seen := map[string]struct{}{} + for _, row := range rows { + key := ControlKey(row.CatalogID, row.ControlID) + if _, wanted := keys[key]; !wanted { + continue + } + if _, duplicate := seen[key]; duplicate { + continue + } + seen[key] = struct{}{} + out = append(out, ControlInput{ + ControlKey: key, + CatalogID: row.CatalogID.String(), + ControlID: row.ControlID, + CatalogTitle: row.CatalogTitle, + Title: row.Title, + Statement: truncate(extractPartText(row.Parts), opts.MaxControlTextLen), + ImplementationText: truncate(row.ImplementationText, opts.MaxControlTextLen), + }) + } + sort.Slice(out, func(i, j int) bool { return out[i].ControlKey < out[j].ControlKey }) + if len(out) != len(controlKeys) { + return nil, fmt.Errorf("failed to gather all scoped controls") + } + return out, nil +} + +func (s *SuggestionService) gatherLabelSets(hashes []string) ([]LabelSetInput, error) { + all, err := s.gatherAllLabelSets() + if err != nil { + return nil, err + } + wanted := stringSet(hashes) + out := make([]LabelSetInput, 0, len(hashes)) + for _, labelSet := range all { + if _, ok := wanted[labelSet.Hash]; ok { + out = append(out, labelSet) + } + } + sort.Slice(out, func(i, j int) bool { return out[i].Hash < out[j].Hash }) + return out, nil +} + +func (s *SuggestionService) gatherAllLabelSets() ([]LabelSetInput, error) { + type row struct { + EvidenceID uuid.UUID `gorm:"column:evidence_id"` + EvidenceUUID uuid.UUID `gorm:"column:evidence_uuid"` + Title string `gorm:"column:title"` + LabelName string `gorm:"column:labels_name"` + LabelValue string `gorm:"column:labels_value"` + } + var rows []row + if err := s.db.Raw(` + WITH latest AS ( + SELECT DISTINCT ON (uuid) id, uuid, title, "end" + FROM evidences + ORDER BY uuid, "end" DESC + ) + SELECT latest.id AS evidence_id, latest.uuid AS evidence_uuid, latest.title, el.labels_name, el.labels_value + FROM latest + JOIN evidence_labels el ON el.evidence_id = latest.id + ORDER BY latest.uuid ASC, el.labels_name ASC, el.labels_value ASC + `).Scan(&rows).Error; err != nil { + if strings.Contains(err.Error(), "no such table") { + return []LabelSetInput{}, nil + } + return nil, err + } + + type evidenceGroup struct { + title string + labels map[string]string + } + byEvidence := map[uuid.UUID]*evidenceGroup{} + for _, row := range rows { + group := byEvidence[row.EvidenceID] + if group == nil { + group = &evidenceGroup{title: row.Title, labels: map[string]string{}} + byEvidence[row.EvidenceID] = group + } + group.labels[row.LabelName] = row.LabelValue + } + + byHash := map[string]*LabelSetInput{} + for _, group := range byEvidence { + hash := CanonicalLabelSetHash(group.labels) + labelSet := byHash[hash] + if labelSet == nil { + copied := make(map[string]string, len(group.labels)) + for key, value := range group.labels { + copied[key] = value + } + labelSet = &LabelSetInput{Hash: hash, Labels: copied} + byHash[hash] = labelSet + } + labelSet.EvidenceCount++ + if group.title != "" && len(labelSet.SampleTitles) < 3 { + labelSet.SampleTitles = append(labelSet.SampleTitles, group.title) + sort.Strings(labelSet.SampleTitles) + } + } + + out := make([]LabelSetInput, 0, len(byHash)) + for _, labelSet := range byHash { + out = append(out, *labelSet) + } + sort.Slice(out, func(i, j int) bool { return out[i].Hash < out[j].Hash }) + return out, nil +} + +func (s *SuggestionService) gatherSystemContext(sspID uuid.UUID, opts GatherOptions, stats map[string]int) (SystemContextInput, error) { + opts = normalizeGatherOptions(opts) + var characteristics relational.SystemCharacteristics + err := s.db.Where("system_security_plan_id = ?", sspID).First(&characteristics).Error + if err != nil && !errorsIsRecordNotFound(err) { + return SystemContextInput{}, err + } + + type componentRow struct { + Title string + Type string + Purpose string + Description string + } + var rows []componentRow + if err := s.db. + Table("system_components"). + Select("title, type, purpose, description"). + Joins("JOIN system_implementations ON system_implementations.id = system_components.system_implementation_id"). + Where("system_implementations.system_security_plan_id = ?", sspID). + Order("system_components.title ASC, system_components.id ASC"). + Find(&rows).Error; err != nil { + return SystemContextInput{}, err + } + if len(rows) > opts.MaxComponents { + stats["system_components_overflow"] = len(rows) - opts.MaxComponents + rows = rows[:opts.MaxComponents] + } + components := make([]SystemComponentInput, 0, len(rows)) + for _, row := range rows { + component := SystemComponentInput{ + Title: truncate(row.Title, opts.MaxComponentTextLen), + Type: truncate(row.Type, opts.MaxComponentTextLen), + Purpose: truncate(row.Purpose, opts.MaxComponentTextLen), + Description: truncate(row.Description, opts.MaxComponentTextLen), + } + components = append(components, component) + } + return SystemContextInput{ + SystemName: characteristics.SystemName, + Description: truncate(characteristics.Description, opts.MaxControlTextLen), + Components: components, + }, nil +} + +func (s *SuggestionService) gatherLabelKeyDocs() ([]LabelKeyDocInput, error) { + type row struct { + Key string + Description *string + } + var rows []row + if err := s.db.Raw(` + SELECT key, description FROM subject_template_label_schema_fields + UNION + SELECT key, description FROM risk_template_label_schema_fields + ORDER BY key ASC + `).Scan(&rows).Error; err != nil { + if strings.Contains(err.Error(), "no such table") { + return []LabelKeyDocInput{}, nil + } + return nil, err + } + merged := map[string]string{} + for _, row := range rows { + if _, exists := merged[row.Key]; !exists && row.Description != nil { + merged[row.Key] = *row.Description + } + } + keys := make([]string, 0, len(merged)) + for key := range merged { + keys = append(keys, key) + } + sort.Strings(keys) + out := make([]LabelKeyDocInput, 0, len(keys)) + for _, key := range keys { + out = append(out, LabelKeyDocInput{Key: key, Description: merged[key]}) + } + return out, nil +} + +func (s *SuggestionService) gatherVisibleFilters(sspID uuid.UUID) ([]VisibleFilterInput, []VisibleFilterInput, []string, error) { + var filters []relational.Filter + if err := s.db. + Where("ssp_id IS NULL OR ssp_id = ?", sspID). + Order("name ASC, id ASC"). + Find(&filters).Error; err != nil { + return nil, nil, nil, err + } + visible := make([]VisibleFilterInput, 0, len(filters)) + sameSSP := make([]VisibleFilterInput, 0) + globalNames := make([]string, 0) + for _, filter := range filters { + input := VisibleFilterInput{ + ID: *filter.ID, + Name: filter.Name, + SSPID: filter.SSPID, + } + if labels, ok := CanonicalizeFilter(filter.Filter.Data()); ok { + hash := CanonicalLabelSetHash(labels) + input.LabelSetHash = &hash + } + visible = append(visible, input) + if filter.SSPID != nil && *filter.SSPID == sspID { + sameSSP = append(sameSSP, input) + } + if filter.SSPID == nil { + globalNames = append(globalNames, filter.Name) + } + } + sort.Strings(globalNames) + return visible, sameSSP, globalNames, nil +} + +func normalizeGatherOptions(opts GatherOptions) GatherOptions { + if opts.MaxComponents <= 0 { + opts.MaxComponents = defaultMaxComponents + } + if opts.MaxComponentTextLen <= 0 { + opts.MaxComponentTextLen = defaultMaxComponentTextLen + } + if opts.MaxControlTextLen <= 0 { + opts.MaxControlTextLen = defaultMaxControlTextLen + } + return opts +} + +func truncate(value string, limit int) string { + value = strings.TrimSpace(value) + if limit <= 0 || len(value) <= limit { + return value + } + return value[:limit] + ReasoningTruncatedMarker +} + +func extractPartText(partsJSON string) string { + if strings.TrimSpace(partsJSON) == "" { + return "" + } + var parts []relational.Part + if err := json.Unmarshal([]byte(partsJSON), &parts); err != nil { + return "" + } + var lines []string + var walk func([]relational.Part) + walk = func(items []relational.Part) { + for _, part := range items { + if part.Title != "" || part.Prose != "" { + lines = append(lines, strings.TrimSpace(strings.Join([]string{part.Title, part.Prose}, " "))) + } + if len(part.Parts) > 0 { + walk(part.Parts) + } + } + } + walk(parts) + return strings.Join(lines, "\n") +} + +func errorsIsRecordNotFound(err error) bool { + return err == nil || errors.Is(err, gorm.ErrRecordNotFound) +} diff --git a/internal/service/relational/suggestions/models_integration_test.go b/internal/service/relational/suggestions/models_integration_test.go index 4538f0ad..16c54ee0 100644 --- a/internal/service/relational/suggestions/models_integration_test.go +++ b/internal/service/relational/suggestions/models_integration_test.go @@ -228,3 +228,180 @@ func (suite *DashboardSuggestionsIntegrationSuite) TestDashboardSuggestionReason ).Error suite.Error(err) } + +func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptCreatesOneSSPBoundFilterAndLinksControls() { + sspID := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + actorID := uuid.New() + labels := map[string]string{"env": "prod", "repo": "payments-api"} + hash := suggestionrel.CanonicalLabelSetHash(labels) + suite.seedSuggestionSSPAndRun(sspID, runID) + + low := suite.seedDashboardSuggestion(runID, sspID, catalogID, "AC-1", labels, hash, "low name", 0.4, nil) + high := suite.seedDashboardSuggestion(runID, sspID, catalogID, "AC-2", labels, hash, "high name", 0.9, nil) + + svc := suggestionrel.NewSuggestionService(suite.DB) + suite.Require().NoError(svc.Accept(sspID, []uuid.UUID{*low.ID, *high.ID}, actorID)) + + var filters []relational.Filter + suite.Require().NoError(suite.DB.Where("ssp_id = ?", sspID).Find(&filters).Error) + suite.Require().Len(filters, 1) + suite.Equal("high name", filters[0].Name) + filterLabels, ok := suggestionrel.CanonicalizeFilter(filters[0].Filter.Data()) + suite.True(ok) + suite.Equal(labels, filterLabels) + + var linkCount int64 + suite.Require().NoError(suite.DB.Table("filter_controls").Where("filter_id = ?", filters[0].ID).Count(&linkCount).Error) + suite.Equal(int64(2), linkCount) + + var accepted []suggestionrel.DashboardSuggestion + suite.Require().NoError(suite.DB.Where("id IN ?", []uuid.UUID{*low.ID, *high.ID}).Find(&accepted).Error) + for _, suggestion := range accepted { + suite.Equal(suggestionrel.DashboardSuggestionStatusAccepted, suggestion.Status) + suite.Require().NotNil(suggestion.AcceptedFilterID) + suite.Equal(*filters[0].ID, *suggestion.AcceptedFilterID) + } + + var eventCount int64 + suite.Require().NoError(suite.DB.Model(&suggestionrel.DashboardSuggestionEvent{}). + Where("event_type = ?", string(suggestionrel.DashboardSuggestionEventTypeAccepted)). + Count(&eventCount).Error) + suite.Equal(int64(2), eventCount) +} + +func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptExtendsSameSSPMatchingFilter() { + sspID := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + actorID := uuid.New() + labels := map[string]string{"env": "prod"} + hash := suggestionrel.CanonicalLabelSetHash(labels) + suite.seedSuggestionSSPAndRun(sspID, runID) + + filter := relational.Filter{Name: "existing", SSPID: &sspID, Filter: datatypes.NewJSONType(suggestionrel.BuildLabelFilter(labels))} + suite.Require().NoError(suite.DB.Create(&filter).Error) + suggestion := suite.seedDashboardSuggestion(runID, sspID, catalogID, "AC-1", labels, hash, "ignored", 0.8, filter.ID) + + svc := suggestionrel.NewSuggestionService(suite.DB) + suite.Require().NoError(svc.Accept(sspID, []uuid.UUID{*suggestion.ID}, actorID)) + + var filterCount int64 + suite.Require().NoError(suite.DB.Model(&relational.Filter{}).Where("ssp_id = ?", sspID).Count(&filterCount).Error) + suite.Equal(int64(1), filterCount) + + var linkCount int64 + suite.Require().NoError(suite.DB.Table("filter_controls").Where("filter_id = ? AND control_id = ?", filter.ID, "AC-1").Count(&linkCount).Error) + suite.Equal(int64(1), linkCount) +} + +func (suite *DashboardSuggestionsIntegrationSuite) TestInsertExcludesMatchingGlobalFilterAndDoesNotModifyIt() { + sspID := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + labels := map[string]string{"env": "prod"} + hash := suggestionrel.CanonicalLabelSetHash(labels) + suite.seedSuggestionSSPAndRun(sspID, runID) + + global := relational.Filter{Name: "global", Filter: datatypes.NewJSONType(suggestionrel.BuildLabelFilter(labels))} + suite.Require().NoError(suite.DB.Create(&global).Error) + suite.Require().NoError(suite.DB.Exec( + `INSERT INTO filter_controls (filter_id, control_catalog_id, control_id) VALUES (?, ?, ?)`, + global.ID, catalogID, "AC-1", + ).Error) + + svc := suggestionrel.NewSuggestionService(suite.DB) + result, err := svc.InsertValidatedMappings(runID, sspID, suggestionrel.PromptVersion, []suggestionrel.ValidatedMapping{{ + ControlKey: suggestionrel.ControlKey(catalogID, "AC-1"), + LabelSetHash: hash, + LabelSet: labels, + Action: suggestionrel.MappingActionNewFilter, + ProposedFilterName: "new", + Confidence: 0.8, + Reasoning: "matches", + }}, 10) + suite.Require().NoError(err) + suite.Equal(0, result.Inserted) + suite.Equal(1, result.Excluded) + + var suggestionCount int64 + suite.Require().NoError(suite.DB.Model(&suggestionrel.DashboardSuggestion{}).Where("run_id = ?", runID).Count(&suggestionCount).Error) + suite.Zero(suggestionCount) + + var reloaded relational.Filter + suite.Require().NoError(suite.DB.First(&reloaded, "id = ?", global.ID).Error) + suite.Nil(reloaded.SSPID) +} + +func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptSSPIsolationAndGlobalFiltersStayVisible() { + sspA := uuid.New() + sspB := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + actorID := uuid.New() + labels := map[string]string{"env": "prod"} + hash := suggestionrel.CanonicalLabelSetHash(labels) + suite.seedSuggestionSSPAndRun(sspA, runID) + suite.Require().NoError(suite.DB.Create(&relational.SystemSecurityPlan{UUIDModel: relational.UUIDModel{ID: &sspB}}).Error) + global := relational.Filter{Name: "global", Filter: datatypes.NewJSONType(suggestionrel.BuildLabelFilter(map[string]string{"env": "stage"}))} + suite.Require().NoError(suite.DB.Create(&global).Error) + + suggestion := suite.seedDashboardSuggestion(runID, sspA, catalogID, "AC-1", labels, hash, "ssp-a only", 0.8, nil) + svc := suggestionrel.NewSuggestionService(suite.DB) + suite.Require().NoError(svc.Accept(sspA, []uuid.UUID{*suggestion.ID}, actorID)) + + var sspBFilterCount int64 + suite.Require().NoError(suite.DB.Model(&relational.Filter{}).Where("ssp_id = ?", sspB).Count(&sspBFilterCount).Error) + suite.Zero(sspBFilterCount) + + var globalCount int64 + suite.Require().NoError(suite.DB.Model(&relational.Filter{}).Where("id = ? AND ssp_id IS NULL", global.ID).Count(&globalCount).Error) + suite.Equal(int64(1), globalCount) +} + +func (suite *DashboardSuggestionsIntegrationSuite) seedSuggestionSSPAndRun(sspID uuid.UUID, runID uuid.UUID) { + suite.Require().NoError(suite.DB.Create(&relational.SystemSecurityPlan{UUIDModel: relational.UUIDModel{ID: &sspID}}).Error) + suite.Require().NoError(suite.DB.Create(&suggestionrel.DashboardSuggestionRun{ + UUIDModel: relational.UUIDModel{ID: &runID}, + SSPID: sspID, + Status: "completed", + Model: "test-model", + PromptVersion: suggestionrel.PromptVersion, + Scope: datatypes.JSONMap{"controlKeys": []string{}, "labelSetHashes": []string{}}, + PlannedCalls: 1, + SuggestionCount: 0, + Stats: datatypes.JSONMap{}, + }).Error) +} + +func (suite *DashboardSuggestionsIntegrationSuite) seedDashboardSuggestion( + runID uuid.UUID, + sspID uuid.UUID, + catalogID uuid.UUID, + controlID string, + labels map[string]string, + hash string, + name string, + confidence float64, + targetFilterID *uuid.UUID, +) suggestionrel.DashboardSuggestion { + suggestion := suggestionrel.DashboardSuggestion{ + RunID: runID, + SSPID: sspID, + ControlCatalogID: catalogID, + ControlID: controlID, + LabelSet: datatypes.JSONMap{}, + LabelSetHash: hash, + TargetFilterID: targetFilterID, + ProposedFilterName: name, + Reasoning: "Evidence satisfies the control and belongs to the system.", + Confidence: confidence, + Status: suggestionrel.DashboardSuggestionStatusPending, + } + for key, value := range labels { + suggestion.LabelSet[key] = value + } + suite.Require().NoError(suite.DB.Create(&suggestion).Error) + return suggestion +} diff --git a/internal/service/relational/suggestions/prompt.go b/internal/service/relational/suggestions/prompt.go new file mode 100644 index 00000000..5daf8973 --- /dev/null +++ b/internal/service/relational/suggestions/prompt.go @@ -0,0 +1,116 @@ +package suggestions + +import ( + "bytes" + "encoding/json" + "text/template" +) + +const PromptVersion = "v1" + +const SystemPrompt = `You map compliance evidence streams to security controls for a specific SSP. + +You are given: +1. The system context: name, description, components it uses, plus each control's implementation text from this plan. +2. A subset of the controls in this SSP. +3. A subset of evidence label-sets with sample titles. +4. Label-key documentation. +5. The dashboards visible to this plan. + +The controls and label-sets shown are one slice of a larger set. Only map what you see; absence here means nothing. + +For each shown control that a shown label-set provides evidence for, emit a mapping, but only when the evidence pertains to this system. When a label identifies an asset such as a repository, cluster, account, host, service, image, namespace, project, or environment, that asset must correspond to a component or description in the system context. Exclude evidence for assets this system does not use. + +Respect qualifiers in the control text. A control scoped to a provider, technology, component, environment, or other qualifier only matches evidence whose labels satisfy that qualifier. + +Use extend_filter with target_filter_id only when one of this plan's own dashboards has exactly that label-set. Otherwise use new_filter with a short descriptive proposed_filter_name. Global dashboards are listed only to avoid duplicate names; never extend them. + +Only reference control_key and label_set_hash values present in the input. Reasoning must state both why the evidence satisfies the control and why it belongs to this system. Provide confidence from 0 to 1.` + +func OutputSchema() map[string]any { + return map[string]any{ + "type": "object", + "additionalProperties": false, + "required": []any{"mappings"}, + "properties": map[string]any{ + "mappings": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "additionalProperties": false, + "required": []any{"control_key", "label_set_hash", "action", "confidence", "reasoning"}, + "properties": map[string]any{ + "control_key": map[string]any{ + "type": "string", + }, + "label_set_hash": map[string]any{ + "type": "string", + }, + "action": map[string]any{ + "type": "string", + "enum": []any{MappingActionNewFilter, MappingActionExtendFilter}, + }, + "target_filter_id": map[string]any{ + "type": "string", + }, + "proposed_filter_name": map[string]any{ + "type": "string", + }, + "confidence": map[string]any{ + "type": "number", + }, + "reasoning": map[string]any{ + "type": "string", + }, + }, + }, + }, + }, + } +} + +const userPromptTemplate = `Prompt version: {{.PromptVersion}} + +System context: +{{json .Input.SystemContext}} + +Controls: +{{json .Input.Controls}} + +Evidence label-sets: +{{json .Input.LabelSets}} + +Label-key documentation: +{{json .Input.LabelKeyDocs}} + +This SSP's extendable dashboards: +{{json .Input.SameSSPFilters}} + +Global dashboard names: +{{json .Input.GlobalFilterNames}} + +Return JSON matching the provided schema.` + +func RenderPrompt(input GatheredInput) (string, error) { + tmpl, err := template.New("dashboard-suggestion-prompt").Funcs(template.FuncMap{ + "json": func(value any) (string, error) { + raw, err := json.MarshalIndent(value, "", " ") + if err != nil { + return "", err + } + return string(raw), nil + }, + }).Parse(userPromptTemplate) + if err != nil { + return "", err + } + var buf bytes.Buffer + err = tmpl.Execute(&buf, map[string]any{ + "PromptVersion": PromptVersion, + "Input": input, + }) + if err != nil { + return "", err + } + return buf.String(), nil +} diff --git a/internal/service/relational/suggestions/service.go b/internal/service/relational/suggestions/service.go new file mode 100644 index 00000000..95575e25 --- /dev/null +++ b/internal/service/relational/suggestions/service.go @@ -0,0 +1,497 @@ +package suggestions + +import ( + "encoding/json" + "errors" + "fmt" + "sort" + "strings" + "time" + + "github.com/compliance-framework/api/internal/service/relational" + "github.com/google/uuid" + "gorm.io/datatypes" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type SuggestionService struct { + db *gorm.DB +} + +func NewSuggestionService(db *gorm.DB) *SuggestionService { + return &SuggestionService{db: db} +} + +type GatherOptions struct { + MaxComponents int + MaxComponentTextLen int + MaxControlTextLen int +} + +type GatheredInput struct { + Cell GridCell `json:"cell"` + Controls []ControlInput `json:"controls"` + LabelSets []LabelSetInput `json:"label_sets"` + SystemContext SystemContextInput `json:"system_context"` + LabelKeyDocs []LabelKeyDocInput `json:"label_key_docs"` + Filters []VisibleFilterInput `json:"filters"` + SameSSPFilters []VisibleFilterInput `json:"same_ssp_filters"` + GlobalFilterNames []string `json:"global_filter_names"` + Stats map[string]int `json:"stats"` +} + +type SystemContextInput struct { + SystemName string `json:"system_name"` + Description string `json:"description"` + Components []SystemComponentInput `json:"components"` +} + +type SystemComponentInput struct { + Title string `json:"title"` + Type string `json:"type"` + Purpose string `json:"purpose"` + Description string `json:"description"` +} + +type LabelKeyDocInput struct { + Key string `json:"key"` + Description string `json:"description"` +} + +type InsertMappingsResult struct { + Inserted int + Excluded int + Capped int +} + +type ConflictError struct { + IDs []uuid.UUID +} + +func (e *ConflictError) Error() string { + ids := make([]string, 0, len(e.IDs)) + for _, id := range e.IDs { + ids = append(ids, id.String()) + } + return "dashboard suggestions are not pending or do not belong to SSP: " + strings.Join(ids, ", ") +} + +func (s *SuggestionService) ResolveScope(sspID uuid.UUID, scope Scope) (Snapshot, error) { + controls, err := s.resolvedControlKeys(sspID) + if err != nil { + return Snapshot{}, err + } + labelSets, err := s.currentLabelSetHashes() + if err != nil { + return Snapshot{}, err + } + return ResolveSnapshot(scope, controls, labelSets) +} + +func (s *SuggestionService) GatherCellInput(sspID uuid.UUID, cell GridCell, opts GatherOptions) (GatheredInput, error) { + stats := map[string]int{} + controls, err := s.gatherControls(sspID, cell.ControlKeys, opts) + if err != nil { + return GatheredInput{}, err + } + labelSets, err := s.gatherLabelSets(cell.LabelSetHashes) + if err != nil { + return GatheredInput{}, err + } + systemContext, err := s.gatherSystemContext(sspID, opts, stats) + if err != nil { + return GatheredInput{}, err + } + labelDocs, err := s.gatherLabelKeyDocs() + if err != nil { + return GatheredInput{}, err + } + filters, sameSSPFilters, globalNames, err := s.gatherVisibleFilters(sspID) + if err != nil { + return GatheredInput{}, err + } + return GatheredInput{ + Cell: cell, + Controls: controls, + LabelSets: labelSets, + SystemContext: systemContext, + LabelKeyDocs: labelDocs, + Filters: filters, + SameSSPFilters: sameSSPFilters, + GlobalFilterNames: globalNames, + Stats: stats, + }, nil +} + +func (g GatheredInput) CellInput() CellInput { + return CellInput{ + Controls: g.Controls, + LabelSets: g.LabelSets, + VisibleFilters: g.Filters, + SameSSPFilters: g.SameSSPFilters, + GlobalFilterNames: g.GlobalFilterNames, + } +} + +func (s *SuggestionService) InsertValidatedMappings(runID uuid.UUID, sspID uuid.UUID, promptVersion string, mappings []ValidatedMapping, maxSuggestionsPerRun int) (InsertMappingsResult, error) { + if maxSuggestionsPerRun <= 0 { + maxSuggestionsPerRun = DefaultMaxSuggestionsPerRun + } + result := InsertMappingsResult{} + err := s.db.Transaction(func(tx *gorm.DB) error { + var run DashboardSuggestionRun + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ?", runID).First(&run).Error; err != nil { + return err + } + capacity := maxSuggestionsPerRun - run.SuggestionCount + if capacity <= 0 { + result.Capped = len(mappings) + return nil + } + + for _, mapping := range mappings { + if capacity <= 0 { + result.Capped++ + continue + } + excluded, err := s.mappingExcluded(tx, sspID, promptVersion, mapping) + if err != nil { + return err + } + if excluded { + result.Excluded++ + continue + } + catalogID, controlID, err := ParseControlKey(mapping.ControlKey) + if err != nil { + return err + } + suggestion := DashboardSuggestion{ + RunID: runID, + SSPID: sspID, + ControlCatalogID: catalogID, + ControlID: controlID, + LabelSet: labelsToJSONMap(mapping.LabelSet), + LabelSetHash: mapping.LabelSetHash, + TargetFilterID: mapping.TargetFilterID, + ProposedFilterName: mapping.ProposedFilterName, + Reasoning: mapping.Reasoning, + Confidence: mapping.Confidence, + Status: DashboardSuggestionStatusPending, + } + create := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&suggestion) + if create.Error != nil { + return create.Error + } + if create.RowsAffected == 0 { + result.Excluded++ + continue + } + result.Inserted++ + capacity-- + if err := tx.Model(&DashboardSuggestionRun{}). + Where("id = ?", runID). + UpdateColumn("suggestion_count", gorm.Expr("suggestion_count + 1")).Error; err != nil { + return err + } + if err := createSuggestionEvent(tx, &suggestion, DashboardSuggestionEventTypeSuggestionCreated, nil, datatypes.JSONMap{ + "prompt_version": promptVersion, + }); err != nil { + return err + } + } + return nil + }) + return result, err +} + +func (s *SuggestionService) Accept(sspID uuid.UUID, suggestionIDs []uuid.UUID, actorID uuid.UUID) error { + return s.db.Transaction(func(tx *gorm.DB) error { + suggestions, err := loadPendingSuggestions(tx, sspID, suggestionIDs) + if err != nil { + return err + } + now := time.Now().UTC() + byHash := map[string][]DashboardSuggestion{} + for _, suggestion := range suggestions { + byHash[suggestion.LabelSetHash] = append(byHash[suggestion.LabelSetHash], suggestion) + } + + for hash, group := range byHash { + sort.Slice(group, func(i, j int) bool { + if group[i].Confidence != group[j].Confidence { + return group[i].Confidence > group[j].Confidence + } + return group[i].LabelSetHash < group[j].LabelSetHash + }) + labels := jsonMapToLabels(group[0].LabelSet) + filterID, created, err := s.acceptFilterForHash(tx, sspID, hash, labels, group) + if err != nil { + return err + } + for _, suggestion := range group { + if err := tx.Exec(` + INSERT INTO filter_controls (filter_id, control_catalog_id, control_id) + VALUES (?, ?, ?) + ON CONFLICT DO NOTHING + `, filterID, suggestion.ControlCatalogID, suggestion.ControlID).Error; err != nil { + return err + } + if err := tx.Model(&DashboardSuggestion{}). + Where("id = ?", suggestion.ID). + Updates(map[string]any{ + "status": DashboardSuggestionStatusAccepted, + "accepted_filter_id": filterID, + "decided_by_user_id": actorID, + "decided_at": now, + }).Error; err != nil { + return err + } + suggestion.Status = DashboardSuggestionStatusAccepted + suggestion.AcceptedFilterID = &filterID + suggestion.DecidedByUserID = &actorID + suggestion.DecidedAt = &now + payload := datatypes.JSONMap{ + "filter_id": filterID.String(), + "created": created, + "reasoning": suggestion.Reasoning, + "confidence": suggestion.Confidence, + } + var run DashboardSuggestionRun + if err := tx.Select("model", "prompt_version").Where("id = ?", suggestion.RunID).First(&run).Error; err == nil { + payload["model"] = run.Model + payload["prompt_version"] = run.PromptVersion + } + if err := createSuggestionEvent(tx, &suggestion, DashboardSuggestionEventTypeAccepted, &actorID, payload); err != nil { + return err + } + } + } + return nil + }) +} + +func (s *SuggestionService) Reject(sspID uuid.UUID, suggestionIDs []uuid.UUID, reason string, actorID uuid.UUID) error { + return s.db.Transaction(func(tx *gorm.DB) error { + suggestions, err := loadPendingSuggestions(tx, sspID, suggestionIDs) + if err != nil { + return err + } + now := time.Now().UTC() + reason = strings.TrimSpace(reason) + for _, suggestion := range suggestions { + if err := tx.Model(&DashboardSuggestion{}). + Where("id = ?", suggestion.ID). + Updates(map[string]any{ + "status": DashboardSuggestionStatusRejected, + "reject_reason": reason, + "decided_by_user_id": actorID, + "decided_at": now, + }).Error; err != nil { + return err + } + suggestion.Status = DashboardSuggestionStatusRejected + suggestion.RejectReason = &reason + suggestion.DecidedByUserID = &actorID + suggestion.DecidedAt = &now + if err := createSuggestionEvent(tx, &suggestion, DashboardSuggestionEventTypeRejected, &actorID, datatypes.JSONMap{ + "reason": reason, + }); err != nil { + return err + } + } + return nil + }) +} + +func (s *SuggestionService) acceptFilterForHash(tx *gorm.DB, sspID uuid.UUID, hash string, labels map[string]string, group []DashboardSuggestion) (uuid.UUID, bool, error) { + for _, suggestion := range group { + if suggestion.TargetFilterID == nil { + continue + } + filter, ok, err := loadSameSSPFilterWithHash(tx, sspID, *suggestion.TargetFilterID, hash) + if err != nil { + return uuid.Nil, false, err + } + if ok { + return *filter.ID, false, nil + } + } + + var filters []relational.Filter + if err := tx.Where("ssp_id = ?", sspID).Order("name ASC, id ASC").Find(&filters).Error; err != nil { + return uuid.Nil, false, err + } + for _, filter := range filters { + filterLabels, ok := CanonicalizeFilter(filter.Filter.Data()) + if ok && CanonicalLabelSetHash(filterLabels) == hash { + return *filter.ID, false, nil + } + } + + name := group[0].ProposedFilterName + if strings.TrimSpace(name) == "" { + name = fallbackFilterName(labels) + } + filter := relational.Filter{ + Name: name, + SSPID: &sspID, + Filter: datatypes.NewJSONType(BuildLabelFilter(labels)), + } + if err := tx.Create(&filter).Error; err != nil { + return uuid.Nil, false, err + } + return *filter.ID, true, nil +} + +func loadPendingSuggestions(tx *gorm.DB, sspID uuid.UUID, ids []uuid.UUID) ([]DashboardSuggestion, error) { + var suggestions []DashboardSuggestion + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where("id IN ?", ids). + Find(&suggestions).Error; err != nil { + return nil, err + } + found := map[uuid.UUID]DashboardSuggestion{} + for _, suggestion := range suggestions { + found[*suggestion.ID] = suggestion + } + offending := make([]uuid.UUID, 0) + for _, id := range ids { + suggestion, ok := found[id] + if !ok || suggestion.SSPID != sspID || suggestion.Status != DashboardSuggestionStatusPending { + offending = append(offending, id) + } + } + if len(offending) > 0 { + return nil, &ConflictError{IDs: offending} + } + return suggestions, nil +} + +func loadSameSSPFilterWithHash(tx *gorm.DB, sspID uuid.UUID, filterID uuid.UUID, hash string) (relational.Filter, bool, error) { + var filter relational.Filter + err := tx.Where("id = ? AND ssp_id = ?", filterID, sspID).First(&filter).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return relational.Filter{}, false, nil + } + if err != nil { + return relational.Filter{}, false, err + } + labels, ok := CanonicalizeFilter(filter.Filter.Data()) + if !ok || CanonicalLabelSetHash(labels) != hash { + return relational.Filter{}, false, nil + } + return filter, true, nil +} + +func (s *SuggestionService) mappingExcluded(tx *gorm.DB, sspID uuid.UUID, promptVersion string, mapping ValidatedMapping) (bool, error) { + catalogID, controlID, err := ParseControlKey(mapping.ControlKey) + if err != nil { + return false, err + } + var existingCount int64 + if err := tx.Model(&DashboardSuggestion{}). + Joins("JOIN dashboard_suggestion_runs ON dashboard_suggestion_runs.id = dashboard_suggestions.run_id"). + Where("dashboard_suggestions.ssp_id = ? AND control_catalog_id = ? AND control_id = ? AND label_set_hash = ?", sspID, catalogID, controlID, mapping.LabelSetHash). + Where("(dashboard_suggestions.status = ? OR (dashboard_suggestions.status = ? AND dashboard_suggestion_runs.prompt_version = ?))", DashboardSuggestionStatusAccepted, DashboardSuggestionStatusRejected, promptVersion). + Count(&existingCount).Error; err != nil { + return false, err + } + if existingCount > 0 { + return true, nil + } + + var filters []relational.Filter + if err := tx. + Joins("JOIN filter_controls ON filter_controls.filter_id = filters.id"). + Where("(filters.ssp_id IS NULL OR filters.ssp_id = ?) AND filter_controls.control_catalog_id = ? AND filter_controls.control_id = ?", sspID, catalogID, controlID). + Group("filters.id"). + Find(&filters).Error; err != nil { + return false, err + } + for _, filter := range filters { + labels, ok := CanonicalizeFilter(filter.Filter.Data()) + if ok && CanonicalLabelSetHash(labels) == mapping.LabelSetHash { + return true, nil + } + } + return false, nil +} + +func createSuggestionEvent(tx *gorm.DB, suggestion *DashboardSuggestion, eventType DashboardSuggestionEventType, actorID *uuid.UUID, payload datatypes.JSONMap) error { + snapshot, err := suggestionSnapshot(suggestion) + if err != nil { + return err + } + event := DashboardSuggestionEvent{ + RunID: &suggestion.RunID, + SuggestionID: suggestion.ID, + EventType: string(eventType), + ActorUserID: actorID, + OccurredAt: time.Now().UTC(), + Payload: payload, + Snapshot: snapshot, + } + return tx.Create(&event).Error +} + +func suggestionSnapshot(suggestion *DashboardSuggestion) (datatypes.JSONMap, error) { + raw, err := json.Marshal(suggestion) + if err != nil { + return nil, err + } + var snapshot datatypes.JSONMap + if err := json.Unmarshal(raw, &snapshot); err != nil { + return nil, err + } + return snapshot, nil +} + +func selectDimension(requested []string, all []string, allSet map[string]struct{}) ([]string, []string) { + if len(requested) == 0 { + return append([]string(nil), all...), nil + } + selectedSet := map[string]struct{}{} + unknown := make([]string, 0) + for _, value := range requested { + value = strings.TrimSpace(value) + if _, ok := allSet[value]; !ok { + unknown = append(unknown, value) + continue + } + selectedSet[value] = struct{}{} + } + sort.Strings(unknown) + selected := make([]string, 0, len(selectedSet)) + for _, value := range all { + if _, ok := selectedSet[value]; ok { + selected = append(selected, value) + } + } + return selected, unknown +} + +func stringSet(values []string) map[string]struct{} { + out := make(map[string]struct{}, len(values)) + for _, value := range values { + out[value] = struct{}{} + } + return out +} + +func labelsToJSONMap(labels map[string]string) datatypes.JSONMap { + out := datatypes.JSONMap{} + for key, value := range labels { + out[key] = value + } + return out +} + +func jsonMapToLabels(labels datatypes.JSONMap) map[string]string { + out := make(map[string]string, len(labels)) + for key, value := range labels { + out[key] = fmt.Sprint(value) + } + return out +} diff --git a/internal/service/relational/suggestions/testdata/prompt_v1.golden b/internal/service/relational/suggestions/testdata/prompt_v1.golden new file mode 100644 index 00000000..f3db3fb5 --- /dev/null +++ b/internal/service/relational/suggestions/testdata/prompt_v1.golden @@ -0,0 +1,82 @@ +You map compliance evidence streams to security controls for a specific SSP. + +You are given: +1. The system context: name, description, components it uses, plus each control's implementation text from this plan. +2. A subset of the controls in this SSP. +3. A subset of evidence label-sets with sample titles. +4. Label-key documentation. +5. The dashboards visible to this plan. + +The controls and label-sets shown are one slice of a larger set. Only map what you see; absence here means nothing. + +For each shown control that a shown label-set provides evidence for, emit a mapping, but only when the evidence pertains to this system. When a label identifies an asset such as a repository, cluster, account, host, service, image, namespace, project, or environment, that asset must correspond to a component or description in the system context. Exclude evidence for assets this system does not use. + +Respect qualifiers in the control text. A control scoped to a provider, technology, component, environment, or other qualifier only matches evidence whose labels satisfy that qualifier. + +Use extend_filter with target_filter_id only when one of this plan's own dashboards has exactly that label-set. Otherwise use new_filter with a short descriptive proposed_filter_name. Global dashboards are listed only to avoid duplicate names; never extend them. + +Only reference control_key and label_set_hash values present in the input. Reasoning must state both why the evidence satisfies the control and why it belongs to this system. Provide confidence from 0 to 1. +---USER--- +Prompt version: v1 + +System context: +{ + "system_name": "Payments API", + "description": "Processes card payments.", + "components": [ + { + "title": "payments-api", + "type": "service", + "purpose": "payment processing", + "description": "Go API" + } + ] +} + +Controls: +[ + { + "control_key": "11111111-1111-1111-1111-111111111111:AC-1", + "catalog_id": "", + "control_id": "", + "catalog_title": "", + "title": "Policy", + "statement": "", + "implementation_text": "Uses payments-api." + } +] + +Evidence label-sets: +[ + { + "hash": "hash1", + "labels": { + "repo": "payments-api" + }, + "evidence_count": 2, + "sample_titles": [ + "scan" + ] + } +] + +Label-key documentation: +[ + { + "key": "repo", + "description": "Repository name" + } +] + +This SSP's extendable dashboards: +[ + { + "id": "22222222-2222-2222-2222-222222222222", + "name": "payments-api" + } +] + +Global dashboard names: +null + +Return JSON matching the provided schema. diff --git a/internal/service/relational/suggestions/validation.go b/internal/service/relational/suggestions/validation.go new file mode 100644 index 00000000..a53ef4f7 --- /dev/null +++ b/internal/service/relational/suggestions/validation.go @@ -0,0 +1,254 @@ +package suggestions + +import ( + "encoding/json" + "fmt" + "sort" + "strings" + + "github.com/google/uuid" +) + +const ( + MappingActionNewFilter = "new_filter" + MappingActionExtendFilter = "extend_filter" +) + +type CellInput struct { + Controls []ControlInput + LabelSets []LabelSetInput + VisibleFilters []VisibleFilterInput + SameSSPFilters []VisibleFilterInput + GlobalFilterNames []string +} + +type ControlInput struct { + ControlKey string `json:"control_key"` + CatalogID string `json:"catalog_id"` + ControlID string `json:"control_id"` + CatalogTitle string `json:"catalog_title"` + Title string `json:"title"` + Statement string `json:"statement"` + ImplementationText string `json:"implementation_text"` +} + +type LabelSetInput struct { + Hash string `json:"hash"` + Labels map[string]string `json:"labels"` + EvidenceCount int `json:"evidence_count"` + SampleTitles []string `json:"sample_titles"` +} + +type VisibleFilterInput struct { + ID uuid.UUID `json:"id"` + Name string `json:"name"` + SSPID *uuid.UUID `json:"ssp_id,omitempty"` + LabelSetHash *string `json:"label_set_hash,omitempty"` +} + +type RawMappings struct { + Mappings []RawMapping `json:"mappings"` +} + +type RawMapping struct { + ControlKey string `json:"control_key"` + LabelSetHash string `json:"label_set_hash"` + Action string `json:"action"` + TargetFilterID string `json:"target_filter_id,omitempty"` + ProposedFilterName string `json:"proposed_filter_name,omitempty"` + Confidence float64 `json:"confidence"` + Reasoning string `json:"reasoning"` +} + +type ValidatedMapping struct { + ControlKey string + LabelSetHash string + LabelSet map[string]string + Action string + TargetFilterID *uuid.UUID + ProposedFilterName string + Confidence float64 + Reasoning string +} + +type ValidationCounts map[string]int + +type ValidationResult struct { + Mappings []ValidatedMapping + Counts ValidationCounts +} + +func ValidateMappings(input CellInput, raw []byte) (ValidationResult, error) { + var decoded RawMappings + if err := json.Unmarshal(raw, &decoded); err != nil { + return ValidationResult{}, err + } + return ValidateRawMappings(input, decoded.Mappings), nil +} + +func ValidateRawMappings(input CellInput, rawMappings []RawMapping) ValidationResult { + counts := ValidationCounts{} + controlSet := map[string]struct{}{} + for _, control := range input.Controls { + controlSet[control.ControlKey] = struct{}{} + } + labelSets := map[string]LabelSetInput{} + for _, labelSet := range input.LabelSets { + labelSets[labelSet.Hash] = labelSet + } + sameSSPFilters := map[uuid.UUID]VisibleFilterInput{} + for _, filter := range input.SameSSPFilters { + sameSSPFilters[filter.ID] = filter + } + + kept := map[string]ValidatedMapping{} + for _, raw := range rawMappings { + controlKey := strings.TrimSpace(raw.ControlKey) + labelSetHash := strings.TrimSpace(raw.LabelSetHash) + if _, ok := controlSet[controlKey]; !ok { + counts["rejected_unknown_control"]++ + continue + } + labelSet, ok := labelSets[labelSetHash] + if !ok { + counts["rejected_unknown_label_set"]++ + continue + } + if raw.Confidence < 0 || raw.Confidence > 1 { + counts["rejected_confidence_out_of_range"]++ + continue + } + reasoning := strings.TrimSpace(raw.Reasoning) + if reasoning == "" { + counts["rejected_empty_reasoning"]++ + continue + } + if len(reasoning) > MaxReasoningLength { + reasoning = reasoning[:MaxReasoningLength] + ReasoningTruncatedMarker + counts["reasoning_truncated"]++ + } + + action := raw.Action + var targetFilterID *uuid.UUID + if action == MappingActionExtendFilter { + parsed, err := uuid.Parse(strings.TrimSpace(raw.TargetFilterID)) + filter, found := sameSSPFilters[parsed] + if err != nil || !found || filter.LabelSetHash == nil || *filter.LabelSetHash != labelSetHash { + action = MappingActionNewFilter + counts["downgraded_extend_to_new"]++ + } else { + targetFilterID = &parsed + } + } + if action != MappingActionExtendFilter { + action = MappingActionNewFilter + targetFilterID = nil + } + + name := strings.TrimSpace(raw.ProposedFilterName) + if action == MappingActionNewFilter { + if name == "" { + name = fallbackFilterName(labelSet.Labels) + counts["fallback_name"]++ + } + if len(name) > 120 { + name = name[:120] + counts["name_truncated"]++ + } + } + + mapping := ValidatedMapping{ + ControlKey: controlKey, + LabelSetHash: labelSetHash, + LabelSet: labelSet.Labels, + Action: action, + TargetFilterID: targetFilterID, + ProposedFilterName: name, + Confidence: raw.Confidence, + Reasoning: reasoning, + } + dedupeKey := controlKey + "\x00" + labelSetHash + if existing, found := kept[dedupeKey]; !found || mapping.Confidence > existing.Confidence { + if found { + counts["deduped_within_cell"]++ + } + kept[dedupeKey] = mapping + } else { + counts["deduped_within_cell"]++ + } + } + + mappings := make([]ValidatedMapping, 0, len(kept)) + for _, mapping := range kept { + mappings = append(mappings, mapping) + } + sort.Slice(mappings, func(i, j int) bool { + if mappings[i].ControlKey != mappings[j].ControlKey { + return mappings[i].ControlKey < mappings[j].ControlKey + } + if mappings[i].Confidence != mappings[j].Confidence { + return mappings[i].Confidence > mappings[j].Confidence + } + return mappings[i].LabelSetHash < mappings[j].LabelSetHash + }) + + mappings = capMappingsPerControl(mappings, counts) + return ValidationResult{Mappings: mappings, Counts: counts} +} + +func capMappingsPerControl(mappings []ValidatedMapping, counts ValidationCounts) []ValidatedMapping { + byControl := map[string][]ValidatedMapping{} + for _, mapping := range mappings { + byControl[mapping.ControlKey] = append(byControl[mapping.ControlKey], mapping) + } + controls := make([]string, 0, len(byControl)) + for controlKey := range byControl { + controls = append(controls, controlKey) + } + sort.Strings(controls) + + out := make([]ValidatedMapping, 0, len(mappings)) + for _, controlKey := range controls { + group := byControl[controlKey] + sort.Slice(group, func(i, j int) bool { + if group[i].Confidence != group[j].Confidence { + return group[i].Confidence > group[j].Confidence + } + return group[i].LabelSetHash < group[j].LabelSetHash + }) + if len(group) > MaxMappingsPerControlPerCell { + counts["dropped_control_cap"] += len(group) - MaxMappingsPerControlPerCell + group = group[:MaxMappingsPerControlPerCell] + } + out = append(out, group...) + } + return out +} + +func fallbackFilterName(labels map[string]string) string { + lines := canonicalLabelLines(labels) + name := strings.Join(lines, ", ") + if name == "" { + return "Evidence label set" + } + if len(name) > 120 { + return name[:120] + } + return name +} + +func ParseControlKey(controlKey string) (uuid.UUID, string, error) { + catalogIDRaw, controlID, ok := strings.Cut(controlKey, ":") + if !ok || strings.TrimSpace(controlID) == "" { + return uuid.Nil, "", fmt.Errorf("invalid control key %q", controlKey) + } + catalogID, err := uuid.Parse(catalogIDRaw) + if err != nil { + return uuid.Nil, "", fmt.Errorf("invalid control key %q: %w", controlKey, err) + } + return catalogID, controlID, nil +} + +func ControlKey(catalogID uuid.UUID, controlID string) string { + return catalogID.String() + ":" + controlID +} From d6334151c6083e570662cf78261bbe4fc64dc3fa Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Mon, 15 Jun 2026 11:08:17 -0300 Subject: [PATCH 2/6] self-review: address pass 1 findings --- .../suggestions/models_integration_test.go | 46 +++++++++++++++++++ .../service/relational/suggestions/service.go | 17 +++++++ 2 files changed, 63 insertions(+) diff --git a/internal/service/relational/suggestions/models_integration_test.go b/internal/service/relational/suggestions/models_integration_test.go index 16c54ee0..89377a3b 100644 --- a/internal/service/relational/suggestions/models_integration_test.go +++ b/internal/service/relational/suggestions/models_integration_test.go @@ -3,6 +3,7 @@ package suggestions_test import ( + "sync" "testing" "time" @@ -271,6 +272,51 @@ func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptCreatesOneSSPBoundF suite.Equal(int64(2), eventCount) } +func (suite *DashboardSuggestionsIntegrationSuite) TestConcurrentAcceptsCreateOneSSPBoundFilterForSameHash() { + sspID := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + actorID := uuid.New() + labels := map[string]string{"env": "prod", "repo": "payments-api"} + hash := suggestionrel.CanonicalLabelSetHash(labels) + suite.seedSuggestionSSPAndRun(sspID, runID) + + first := suite.seedDashboardSuggestion(runID, sspID, catalogID, "AC-1", labels, hash, "first", 0.8, nil) + second := suite.seedDashboardSuggestion(runID, sspID, catalogID, "AC-2", labels, hash, "second", 0.7, nil) + + svc := suggestionrel.NewSuggestionService(suite.DB) + errs := make(chan error, 2) + var wg sync.WaitGroup + wg.Add(2) + for _, suggestionID := range []uuid.UUID{*first.ID, *second.ID} { + go func(id uuid.UUID) { + defer wg.Done() + errs <- svc.Accept(sspID, []uuid.UUID{id}, actorID) + }(suggestionID) + } + wg.Wait() + close(errs) + for err := range errs { + suite.Require().NoError(err) + } + + var filters []relational.Filter + suite.Require().NoError(suite.DB.Where("ssp_id = ?", sspID).Find(&filters).Error) + suite.Require().Len(filters, 1) + filterLabels, ok := suggestionrel.CanonicalizeFilter(filters[0].Filter.Data()) + suite.True(ok) + suite.Equal(labels, filterLabels) + + var accepted []suggestionrel.DashboardSuggestion + suite.Require().NoError(suite.DB.Where("id IN ?", []uuid.UUID{*first.ID, *second.ID}).Find(&accepted).Error) + suite.Require().Len(accepted, 2) + for _, suggestion := range accepted { + suite.Equal(suggestionrel.DashboardSuggestionStatusAccepted, suggestion.Status) + suite.Require().NotNil(suggestion.AcceptedFilterID) + suite.Equal(*filters[0].ID, *suggestion.AcceptedFilterID) + } +} + func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptExtendsSameSSPMatchingFilter() { sspID := uuid.New() runID := uuid.New() diff --git a/internal/service/relational/suggestions/service.go b/internal/service/relational/suggestions/service.go index 95575e25..e7f66e3d 100644 --- a/internal/service/relational/suggestions/service.go +++ b/internal/service/relational/suggestions/service.go @@ -1,6 +1,8 @@ package suggestions import ( + "crypto/sha256" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -306,6 +308,10 @@ func (s *SuggestionService) Reject(sspID uuid.UUID, suggestionIDs []uuid.UUID, r } func (s *SuggestionService) acceptFilterForHash(tx *gorm.DB, sspID uuid.UUID, hash string, labels map[string]string, group []DashboardSuggestion) (uuid.UUID, bool, error) { + if err := lockAcceptFilterHash(tx, sspID, hash); err != nil { + return uuid.Nil, false, err + } + for _, suggestion := range group { if suggestion.TargetFilterID == nil { continue @@ -345,6 +351,17 @@ func (s *SuggestionService) acceptFilterForHash(tx *gorm.DB, sspID uuid.UUID, ha return *filter.ID, true, nil } +func lockAcceptFilterHash(tx *gorm.DB, sspID uuid.UUID, hash string) error { + if tx.Dialector.Name() != "postgres" { + return nil + } + + sum := sha256.Sum256([]byte(sspID.String() + ":" + hash)) + key1 := int32(binary.BigEndian.Uint32(sum[0:4])) + key2 := int32(binary.BigEndian.Uint32(sum[4:8])) + return tx.Exec("SELECT pg_advisory_xact_lock(?, ?)", key1, key2).Error +} + func loadPendingSuggestions(tx *gorm.DB, sspID uuid.UUID, ids []uuid.UUID) ([]DashboardSuggestion, error) { var suggestions []DashboardSuggestion if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). From 9632e8f05b80dc87e748162270fc2d90f70f0285 Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Mon, 15 Jun 2026 11:13:29 -0300 Subject: [PATCH 3/6] self-review: address pass 2 findings --- .../service/relational/suggestions/core.go | 10 ++++++- .../relational/suggestions/core_test.go | 27 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/internal/service/relational/suggestions/core.go b/internal/service/relational/suggestions/core.go index 3dc95d05..65dd919a 100644 --- a/internal/service/relational/suggestions/core.go +++ b/internal/service/relational/suggestions/core.go @@ -133,7 +133,15 @@ func canonicalizeScope(scope labelfilter.Scope, labels map[string]string) bool { if condition.Operator != "=" { return false } - labels[strings.ToLower(condition.Label)] = condition.Value + key := strings.ToLower(condition.Label) + existing, seen := labels[key] + if seen && existing != condition.Value { + return false + } + if seen { + return true + } + labels[key] = condition.Value return true } query := scope.Query diff --git a/internal/service/relational/suggestions/core_test.go b/internal/service/relational/suggestions/core_test.go index 7720313e..ff80e3ba 100644 --- a/internal/service/relational/suggestions/core_test.go +++ b/internal/service/relational/suggestions/core_test.go @@ -45,6 +45,33 @@ func TestCanonicalizeFilterRefusesOrAndNotEquals(t *testing.T) { require.False(t, ok) } +func TestCanonicalizeFilterRefusesConflictingDuplicateLabelConditions(t *testing.T) { + filter := labelfilter.Filter{Scope: &labelfilter.Scope{Query: &labelfilter.Query{ + Operator: "AND", + Scopes: []labelfilter.Scope{ + {Condition: &labelfilter.Condition{Label: "env", Operator: "=", Value: "prod"}}, + {Condition: &labelfilter.Condition{Label: "Env", Operator: "=", Value: "stage"}}, + }, + }}} + + _, ok := CanonicalizeFilter(filter) + require.False(t, ok) +} + +func TestCanonicalizeFilterAcceptsDuplicateSameLabelConditions(t *testing.T) { + filter := labelfilter.Filter{Scope: &labelfilter.Scope{Query: &labelfilter.Query{ + Operator: "AND", + Scopes: []labelfilter.Scope{ + {Condition: &labelfilter.Condition{Label: "env", Operator: "=", Value: "prod"}}, + {Condition: &labelfilter.Condition{Label: "Env", Operator: "=", Value: "prod"}}, + }, + }}} + + labels, ok := CanonicalizeFilter(filter) + require.True(t, ok) + require.Equal(t, map[string]string{"env": "prod"}, labels) +} + func TestResolveSnapshot(t *testing.T) { controls := []string{"catalog-b:AC-2", "catalog-a:AC-1"} hashes := []string{"h2", "h1"} From 8fb9dd702857e922b87cec919141138c2002074b Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Mon, 15 Jun 2026 11:21:02 -0300 Subject: [PATCH 4/6] self-review: address pass 3 findings --- .../service/relational/suggestions/core.go | 41 ++++++--- .../relational/suggestions/core_test.go | 17 ++++ .../service/relational/suggestions/gather.go | 10 +- .../suggestions/models_integration_test.go | 92 +++++++++++++++++++ .../service/relational/suggestions/service.go | 2 +- 5 files changed, 147 insertions(+), 15 deletions(-) diff --git a/internal/service/relational/suggestions/core.go b/internal/service/relational/suggestions/core.go index 65dd919a..c7588324 100644 --- a/internal/service/relational/suggestions/core.go +++ b/internal/service/relational/suggestions/core.go @@ -79,11 +79,28 @@ func ResolveSnapshot(scope Scope, allControlKeys []string, allLabelSetHashes []s } func CanonicalLabelSetHash(labels map[string]string) string { - lines := canonicalLabelLines(labels) + normalized, ok := NormalizeLabelSet(labels) + lines := canonicalLabelLines(normalized) + if !ok { + lines = conflictingLabelLines(labels) + } sum := sha256.Sum256([]byte(strings.Join(lines, "\n"))) return hex.EncodeToString(sum[:]) } +func NormalizeLabelSet(labels map[string]string) (map[string]string, bool) { + normalized := make(map[string]string, len(labels)) + for key, value := range labels { + lowerKey := strings.ToLower(key) + existing, seen := normalized[lowerKey] + if seen && existing != value { + return nil, false + } + normalized[lowerKey] = value + } + return normalized, true +} + func BuildLabelFilter(labels map[string]string) labelfilter.Filter { keys := make([]string, 0, len(labels)) for key := range labels { @@ -159,21 +176,23 @@ func canonicalizeScope(scope labelfilter.Scope, labels map[string]string) bool { func canonicalLabelLines(labels map[string]string) []string { keys := make([]string, 0, len(labels)) for key := range labels { - keys = append(keys, strings.ToLower(key)) + keys = append(keys, key) } sort.Strings(keys) lines := make([]string, 0, len(keys)) - for _, lowerKey := range keys { - value := "" - for key, candidate := range labels { - if strings.ToLower(key) == lowerKey { - value = candidate - break - } - } - lines = append(lines, fmt.Sprintf("%s=%s", lowerKey, value)) + for _, key := range keys { + lines = append(lines, fmt.Sprintf("%s=%s", key, labels[key])) + } + return lines +} + +func conflictingLabelLines(labels map[string]string) []string { + lines := make([]string, 0, len(labels)) + for key, value := range labels { + lines = append(lines, fmt.Sprintf("%s=%s", strings.ToLower(key), value)) } + sort.Strings(lines) return lines } diff --git a/internal/service/relational/suggestions/core_test.go b/internal/service/relational/suggestions/core_test.go index ff80e3ba..8d52164d 100644 --- a/internal/service/relational/suggestions/core_test.go +++ b/internal/service/relational/suggestions/core_test.go @@ -19,6 +19,23 @@ func TestCanonicalLabelSetHashStable(t *testing.T) { require.NotEqual(t, got, CanonicalLabelSetHash(map[string]string{"env": "prod", "repo": "API"})) } +func TestNormalizeLabelSetCollapsesSameValueCaseDuplicates(t *testing.T) { + labels, ok := NormalizeLabelSet(map[string]string{"Env": "prod", "env": "prod", "Repo": "api"}) + require.True(t, ok) + require.Equal(t, map[string]string{"env": "prod", "repo": "api"}, labels) + require.Equal(t, CanonicalLabelSetHash(map[string]string{"env": "prod", "repo": "api"}), CanonicalLabelSetHash(map[string]string{"Env": "prod", "env": "prod", "Repo": "api"})) +} + +func TestCanonicalLabelSetHashConflictingCaseDuplicatesDeterministic(t *testing.T) { + _, ok := NormalizeLabelSet(map[string]string{"Env": "prod", "env": "stage"}) + require.False(t, ok) + + got := CanonicalLabelSetHash(map[string]string{"Env": "prod", "env": "stage"}) + require.Equal(t, got, CanonicalLabelSetHash(map[string]string{"env": "stage", "Env": "prod"})) + require.NotEqual(t, got, CanonicalLabelSetHash(map[string]string{"env": "stage"})) + require.NotEqual(t, got, CanonicalLabelSetHash(map[string]string{"env": "prod"})) +} + func TestBuildLabelFilterCanonicalizeRoundTrip(t *testing.T) { labels := map[string]string{"env": "prod", "repo": "api"} filter := BuildLabelFilter(labels) diff --git a/internal/service/relational/suggestions/gather.go b/internal/service/relational/suggestions/gather.go index ce9fe604..004baa90 100644 --- a/internal/service/relational/suggestions/gather.go +++ b/internal/service/relational/suggestions/gather.go @@ -197,11 +197,15 @@ func (s *SuggestionService) gatherAllLabelSets() ([]LabelSetInput, error) { byHash := map[string]*LabelSetInput{} for _, group := range byEvidence { - hash := CanonicalLabelSetHash(group.labels) + labels, ok := NormalizeLabelSet(group.labels) + if !ok { + continue + } + hash := CanonicalLabelSetHash(labels) labelSet := byHash[hash] if labelSet == nil { - copied := make(map[string]string, len(group.labels)) - for key, value := range group.labels { + copied := make(map[string]string, len(labels)) + for key, value := range labels { copied[key] = value } labelSet = &LabelSetInput{Hash: hash, Labels: copied} diff --git a/internal/service/relational/suggestions/models_integration_test.go b/internal/service/relational/suggestions/models_integration_test.go index 89377a3b..27e7f98f 100644 --- a/internal/service/relational/suggestions/models_integration_test.go +++ b/internal/service/relational/suggestions/models_integration_test.go @@ -380,6 +380,78 @@ func (suite *DashboardSuggestionsIntegrationSuite) TestInsertExcludesMatchingGlo suite.Nil(reloaded.SSPID) } +func (suite *DashboardSuggestionsIntegrationSuite) TestInsertValidatedMappingsRejectsRunSSPMismatch() { + sspA := uuid.New() + sspB := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + labels := map[string]string{"env": "prod"} + hash := suggestionrel.CanonicalLabelSetHash(labels) + suite.seedSuggestionSSPAndRun(sspA, runID) + suite.Require().NoError(suite.DB.Create(&relational.SystemSecurityPlan{UUIDModel: relational.UUIDModel{ID: &sspB}}).Error) + + svc := suggestionrel.NewSuggestionService(suite.DB) + result, err := svc.InsertValidatedMappings(runID, sspB, suggestionrel.PromptVersion, []suggestionrel.ValidatedMapping{{ + ControlKey: suggestionrel.ControlKey(catalogID, "AC-1"), + LabelSetHash: hash, + LabelSet: labels, + Action: suggestionrel.MappingActionNewFilter, + ProposedFilterName: "prod", + Confidence: 0.8, + Reasoning: "matches", + }}, 10) + suite.Error(err) + suite.Equal(suggestionrel.InsertMappingsResult{}, result) + + var suggestionCount int64 + suite.Require().NoError(suite.DB.Model(&suggestionrel.DashboardSuggestion{}). + Where("ssp_id IN ?", []uuid.UUID{sspA, sspB}). + Count(&suggestionCount).Error) + suite.Zero(suggestionCount) + + var run suggestionrel.DashboardSuggestionRun + suite.Require().NoError(suite.DB.First(&run, "id = ?", runID).Error) + suite.Zero(run.SuggestionCount) +} + +func (suite *DashboardSuggestionsIntegrationSuite) TestGatherLabelSetsNormalizesAndSkipsCaseVariantDuplicates() { + sspID := uuid.New() + runID := uuid.New() + suite.seedSuggestionSSPAndRun(sspID, runID) + + sameValueEvidenceID := uuid.New() + sameValueEvidenceUUID := uuid.New() + conflictingEvidenceID := uuid.New() + conflictingEvidenceUUID := uuid.New() + now := time.Now().UTC() + suite.insertEvidenceLabels(sameValueEvidenceID, sameValueEvidenceUUID, "same value", now, map[string]string{ + "Env": "prod", + "env": "prod", + "Repo": "api", + }) + suite.insertEvidenceLabels(conflictingEvidenceID, conflictingEvidenceUUID, "conflicting", now, map[string]string{ + "Env": "prod", + "env": "stage", + }) + + normalized := map[string]string{"env": "prod", "repo": "api"} + hash := suggestionrel.CanonicalLabelSetHash(normalized) + conflictingHash := suggestionrel.CanonicalLabelSetHash(map[string]string{"Env": "prod", "env": "stage"}) + svc := suggestionrel.NewSuggestionService(suite.DB) + + snapshot, err := svc.ResolveScope(sspID, suggestionrel.Scope{}) + suite.Require().NoError(err) + suite.Contains(snapshot.LabelSetHashes, hash) + suite.NotContains(snapshot.LabelSetHashes, conflictingHash) + + input, err := svc.GatherCellInput(sspID, suggestionrel.GridCell{LabelSetHashes: []string{hash}}, suggestionrel.GatherOptions{}) + suite.Require().NoError(err) + suite.Require().Len(input.LabelSets, 1) + suite.Equal(hash, input.LabelSets[0].Hash) + suite.Equal(normalized, input.LabelSets[0].Labels) + suite.Equal(1, input.LabelSets[0].EvidenceCount) +} + func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptSSPIsolationAndGlobalFiltersStayVisible() { sspA := uuid.New() sspB := uuid.New() @@ -451,3 +523,23 @@ func (suite *DashboardSuggestionsIntegrationSuite) seedDashboardSuggestion( suite.Require().NoError(suite.DB.Create(&suggestion).Error) return suggestion } + +func (suite *DashboardSuggestionsIntegrationSuite) insertEvidenceLabels(id uuid.UUID, streamUUID uuid.UUID, title string, collectedAt time.Time, labels map[string]string) { + suite.Require().NoError(suite.DB.Exec( + `INSERT INTO evidences (id, uuid, title, description, start, "end") VALUES (?, ?, ?, ?, ?, ?)`, + id, + streamUUID, + title, + title, + collectedAt, + collectedAt, + ).Error) + for key, value := range labels { + suite.Require().NoError(suite.DB.Exec( + `INSERT INTO evidence_labels (evidence_id, labels_name, labels_value) VALUES (?, ?, ?)`, + id, + key, + value, + ).Error) + } +} diff --git a/internal/service/relational/suggestions/service.go b/internal/service/relational/suggestions/service.go index e7f66e3d..2f82920a 100644 --- a/internal/service/relational/suggestions/service.go +++ b/internal/service/relational/suggestions/service.go @@ -143,7 +143,7 @@ func (s *SuggestionService) InsertValidatedMappings(runID uuid.UUID, sspID uuid. result := InsertMappingsResult{} err := s.db.Transaction(func(tx *gorm.DB) error { var run DashboardSuggestionRun - if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ?", runID).First(&run).Error; err != nil { + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ? AND ssp_id = ?", runID, sspID).First(&run).Error; err != nil { return err } capacity := maxSuggestionsPerRun - run.SuggestionCount From fc928e7ce3837c29b0ba1b9c5a68554e86c0f517 Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Mon, 15 Jun 2026 11:31:29 -0300 Subject: [PATCH 5/6] fix: CI failures --- internal/service/relational/suggestions/service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/service/relational/suggestions/service.go b/internal/service/relational/suggestions/service.go index 2f82920a..5efe00b4 100644 --- a/internal/service/relational/suggestions/service.go +++ b/internal/service/relational/suggestions/service.go @@ -352,7 +352,7 @@ func (s *SuggestionService) acceptFilterForHash(tx *gorm.DB, sspID uuid.UUID, ha } func lockAcceptFilterHash(tx *gorm.DB, sspID uuid.UUID, hash string) error { - if tx.Dialector.Name() != "postgres" { + if tx.Name() != "postgres" { return nil } From 206848a354cf9713f803bfb37bd33af42ed77609 Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Mon, 15 Jun 2026 11:45:54 -0300 Subject: [PATCH 6/6] fix: address review feedback --- .../relational/suggestions/core_test.go | 37 ++++++++ .../service/relational/suggestions/gather.go | 54 ++++++----- .../suggestions/models_integration_test.go | 94 +++++++++++++++++-- .../service/relational/suggestions/service.go | 12 ++- .../relational/suggestions/validation.go | 39 ++++++-- 5 files changed, 199 insertions(+), 37 deletions(-) diff --git a/internal/service/relational/suggestions/core_test.go b/internal/service/relational/suggestions/core_test.go index 8d52164d..295d3921 100644 --- a/internal/service/relational/suggestions/core_test.go +++ b/internal/service/relational/suggestions/core_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "strings" "testing" + "unicode/utf8" "github.com/compliance-framework/api/internal/converters/labelfilter" "github.com/google/uuid" @@ -193,6 +194,42 @@ func TestValidateMappingsControlCap(t *testing.T) { require.Equal(t, 1, result.Counts["dropped_control_cap"]) } +func TestValidateMappingsTruncatesMultibyteTextAtRuneBoundary(t *testing.T) { + controlKey := ControlKey(uuid.New(), "AC-1") + labelHash := CanonicalLabelSetHash(map[string]string{"env": "prod"}) + input := CellInput{ + Controls: []ControlInput{{ControlKey: controlKey}}, + LabelSets: []LabelSetInput{{Hash: labelHash, Labels: map[string]string{"env": "prod"}}}, + } + + result := ValidateRawMappings(input, []RawMapping{{ + ControlKey: controlKey, + LabelSetHash: labelHash, + Action: MappingActionNewFilter, + ProposedFilterName: strings.Repeat("á", 121), + Confidence: 0.8, + Reasoning: strings.Repeat("🙂", MaxReasoningLength+1), + }}) + + require.Len(t, result.Mappings, 1) + require.Equal(t, 1, result.Counts["reasoning_truncated"]) + require.Equal(t, 1, result.Counts["name_truncated"]) + require.True(t, utf8.ValidString(result.Mappings[0].Reasoning)) + require.True(t, strings.HasSuffix(result.Mappings[0].Reasoning, ReasoningTruncatedMarker)) + require.True(t, utf8.ValidString(result.Mappings[0].ProposedFilterName)) + require.Equal(t, 120, utf8.RuneCountInString(result.Mappings[0].ProposedFilterName)) +} + +func TestFallbackAndGatherTruncatePreserveUTF8(t *testing.T) { + name := fallbackFilterName(map[string]string{"emoji": strings.Repeat("🙂", 121)}) + require.True(t, utf8.ValidString(name)) + require.Equal(t, 120, utf8.RuneCountInString(name)) + + value := truncate(" "+strings.Repeat("á", 6)+" ", 3) + require.True(t, utf8.ValidString(value)) + require.Equal(t, strings.Repeat("á", 3)+ReasoningTruncatedMarker, value) +} + func TestPromptGolden(t *testing.T) { input := GatheredInput{ SystemContext: SystemContextInput{ diff --git a/internal/service/relational/suggestions/gather.go b/internal/service/relational/suggestions/gather.go index 004baa90..94613fa9 100644 --- a/internal/service/relational/suggestions/gather.go +++ b/internal/service/relational/suggestions/gather.go @@ -70,14 +70,14 @@ func (s *SuggestionService) gatherControls(sspID uuid.UUID, controlKeys []string for _, key := range controlKeys { keys[key] = struct{}{} } - catalogIDs := make([]uuid.UUID, 0, len(controlKeys)) + catalogIDStrings := make([]string, 0, len(controlKeys)) controlIDs := make([]string, 0, len(controlKeys)) for _, key := range controlKeys { catalogID, controlID, err := ParseControlKey(key) if err != nil { return nil, err } - catalogIDs = append(catalogIDs, catalogID) + catalogIDStrings = append(catalogIDStrings, catalogID.String()) controlIDs = append(controlIDs, controlID) } @@ -89,23 +89,33 @@ func (s *SuggestionService) gatherControls(sspID uuid.UUID, controlKeys []string c.title, c.parts::text AS parts, COALESCE(m.title, '') AS catalog_title, - TRIM(CONCAT_WS(E'\n', - NULLIF(ir.remarks, ''), - string_agg(NULLIF(st.remarks, ''), E'\n' ORDER BY st.statement_id) - )) AS implementation_text + COALESCE(impl.implementation_text, '') AS implementation_text FROM controls c - JOIN profile_controls pc ON pc.control_catalog_id = c.catalog_id AND pc.control_id = c.id - JOIN ssp_profiles sp ON sp.profile_id = pc.profile_id AND sp.system_security_plan_id = @ssp_id - LEFT JOIN metadata m ON m.parent_type = 'catalogs' AND m.parent_id = c.catalog_id::text - LEFT JOIN control_implementations ci ON ci.system_security_plan_id = @ssp_id - LEFT JOIN implemented_requirements ir ON ir.control_implementation_id = ci.id AND UPPER(ir.control_id) = UPPER(c.id) - LEFT JOIN statements st ON st.implemented_requirement_id = ir.id - WHERE c.catalog_id IN @catalog_ids AND c.id IN @control_ids - GROUP BY c.catalog_id, c.id, c.title, c.parts, m.title, ir.remarks + JOIN profile_controls pc ON pc.control_catalog_id::text = c.catalog_id::text AND pc.control_id::text = c.id::text + JOIN ssp_profiles sp ON sp.profile_id::text = pc.profile_id::text AND sp.system_security_plan_id::text = CAST(@ssp_id AS text) + LEFT JOIN metadata m ON m.parent_type::text = 'catalogs' AND m.parent_id::text = c.catalog_id::text + LEFT JOIN LATERAL ( + SELECT TRIM(string_agg(piece, E'\n' ORDER BY sort_key)) AS implementation_text + FROM ( + SELECT NULLIF(ir.remarks, '') AS piece, '0' AS sort_key + FROM control_implementations ci + JOIN implemented_requirements ir ON ir.control_implementation_id::text = ci.id::text AND UPPER(ir.control_id) = UPPER(c.id) + WHERE ci.system_security_plan_id::text = CAST(@ssp_id AS text) + UNION ALL + SELECT NULLIF(st.remarks, '') AS piece, '1:' || st.statement_id::text AS sort_key + FROM control_implementations ci + JOIN implemented_requirements ir ON ir.control_implementation_id::text = ci.id::text AND UPPER(ir.control_id) = UPPER(c.id) + JOIN statements st ON st.implemented_requirement_id::text = ir.id::text + WHERE ci.system_security_plan_id::text = CAST(@ssp_id AS text) + ) pieces + WHERE piece IS NOT NULL + ) impl ON true + WHERE c.catalog_id::text IN @catalog_ids AND c.id::text IN @control_ids + GROUP BY c.catalog_id, c.id, c.title, c.parts, m.title, impl.implementation_text ORDER BY c.catalog_id ASC, c.id ASC `, map[string]any{ "ssp_id": sspID, - "catalog_ids": catalogIDs, + "catalog_ids": catalogIDStrings, "control_ids": controlIDs, }).Scan(&rows).Error; err != nil { return nil, err @@ -212,14 +222,18 @@ func (s *SuggestionService) gatherAllLabelSets() ([]LabelSetInput, error) { byHash[hash] = labelSet } labelSet.EvidenceCount++ - if group.title != "" && len(labelSet.SampleTitles) < 3 { + if group.title != "" { labelSet.SampleTitles = append(labelSet.SampleTitles, group.title) - sort.Strings(labelSet.SampleTitles) } } out := make([]LabelSetInput, 0, len(byHash)) for _, labelSet := range byHash { + sort.Strings(labelSet.SampleTitles) + labelSet.SampleTitles = dedupeStrings(labelSet.SampleTitles) + if len(labelSet.SampleTitles) > 3 { + labelSet.SampleTitles = labelSet.SampleTitles[:3] + } out = append(out, *labelSet) } sort.Slice(out, func(i, j int) bool { return out[i].Hash < out[j].Hash }) @@ -354,10 +368,8 @@ func normalizeGatherOptions(opts GatherOptions) GatherOptions { func truncate(value string, limit int) string { value = strings.TrimSpace(value) - if limit <= 0 || len(value) <= limit { - return value - } - return value[:limit] + ReasoningTruncatedMarker + truncated, _ := truncateRunes(value, limit, ReasoningTruncatedMarker) + return truncated } func extractPartText(partsJSON string) string { diff --git a/internal/service/relational/suggestions/models_integration_test.go b/internal/service/relational/suggestions/models_integration_test.go index 27e7f98f..c373911f 100644 --- a/internal/service/relational/suggestions/models_integration_test.go +++ b/internal/service/relational/suggestions/models_integration_test.go @@ -272,6 +272,26 @@ func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptCreatesOneSSPBoundF suite.Equal(int64(2), eventCount) } +func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptUsesDeterministicNameTieBreak() { + sspID := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + actorID := uuid.New() + labels := map[string]string{"env": "prod"} + hash := suggestionrel.CanonicalLabelSetHash(labels) + suite.seedSuggestionSSPAndRun(sspID, runID) + + first := suite.seedDashboardSuggestion(runID, sspID, catalogID, "AC-1", labels, hash, "z-name", 0.8, nil) + second := suite.seedDashboardSuggestion(runID, sspID, catalogID, "AC-2", labels, hash, "a-name", 0.8, nil) + + svc := suggestionrel.NewSuggestionService(suite.DB) + suite.Require().NoError(svc.Accept(sspID, []uuid.UUID{*first.ID, *second.ID}, actorID)) + + var filter relational.Filter + suite.Require().NoError(suite.DB.First(&filter, "ssp_id = ?", sspID).Error) + suite.Equal("a-name", filter.Name) +} + func (suite *DashboardSuggestionsIntegrationSuite) TestConcurrentAcceptsCreateOneSSPBoundFilterForSameHash() { sspID := uuid.New() runID := uuid.New() @@ -419,16 +439,16 @@ func (suite *DashboardSuggestionsIntegrationSuite) TestGatherLabelSetsNormalizes runID := uuid.New() suite.seedSuggestionSSPAndRun(sspID, runID) - sameValueEvidenceID := uuid.New() - sameValueEvidenceUUID := uuid.New() conflictingEvidenceID := uuid.New() conflictingEvidenceUUID := uuid.New() now := time.Now().UTC() - suite.insertEvidenceLabels(sameValueEvidenceID, sameValueEvidenceUUID, "same value", now, map[string]string{ - "Env": "prod", - "env": "prod", - "Repo": "api", - }) + for _, title := range []string{"zeta", "alpha", "gamma", "beta"} { + suite.insertEvidenceLabels(uuid.New(), uuid.New(), title, now, map[string]string{ + "Env": "prod", + "env": "prod", + "Repo": "api", + }) + } suite.insertEvidenceLabels(conflictingEvidenceID, conflictingEvidenceUUID, "conflicting", now, map[string]string{ "Env": "prod", "env": "stage", @@ -449,7 +469,61 @@ func (suite *DashboardSuggestionsIntegrationSuite) TestGatherLabelSetsNormalizes suite.Require().Len(input.LabelSets, 1) suite.Equal(hash, input.LabelSets[0].Hash) suite.Equal(normalized, input.LabelSets[0].Labels) - suite.Equal(1, input.LabelSets[0].EvidenceCount) + suite.Equal(4, input.LabelSets[0].EvidenceCount) + suite.Equal([]string{"alpha", "beta", "gamma"}, input.LabelSets[0].SampleTitles) +} + +func (suite *DashboardSuggestionsIntegrationSuite) TestGatherControlsUsesMatchingImplementationWhenSSPHasDuplicateImplementations() { + sspID := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + profileID := uuid.New() + controlID := "AC-1" + suite.seedSuggestionSSPAndRun(sspID, runID) + + suite.Require().NoError(suite.DB.Create(&relational.Profile{UUIDModel: relational.UUIDModel{ID: &profileID}}).Error) + suite.Require().NoError(suite.DB.Create(&relational.Control{ + CatalogID: catalogID, + ID: controlID, + Title: "Access Control Policy", + Parts: datatypes.NewJSONSlice([]relational.Part{}), + }).Error) + suite.Require().NoError(suite.DB.Exec( + `INSERT INTO ssp_profiles (system_security_plan_id, profile_id) VALUES (?, ?)`, + sspID, + profileID, + ).Error) + suite.Require().NoError(suite.DB.Exec( + `INSERT INTO profile_controls (profile_id, control_catalog_id, control_id) VALUES (?, ?, ?)`, + profileID, + catalogID, + controlID, + ).Error) + + emptyImplementationID := uuid.New() + matchingImplementationID := uuid.New() + suite.Require().NoError(suite.DB.Create(&relational.ControlImplementation{ + UUIDModel: relational.UUIDModel{ID: &emptyImplementationID}, + SystemSecurityPlanId: sspID, + }).Error) + suite.Require().NoError(suite.DB.Create(&relational.ControlImplementation{ + UUIDModel: relational.UUIDModel{ID: &matchingImplementationID}, + SystemSecurityPlanId: sspID, + }).Error) + suite.Require().NoError(suite.DB.Create(&relational.ImplementedRequirement{ + UUIDModel: relational.UUIDModel{ID: ptrUUID(uuid.New())}, + ControlImplementationId: matchingImplementationID, + ControlId: controlID, + Remarks: "implemented requirement remarks", + }).Error) + + svc := suggestionrel.NewSuggestionService(suite.DB) + input, err := svc.GatherCellInput(sspID, suggestionrel.GridCell{ + ControlKeys: []string{suggestionrel.ControlKey(catalogID, controlID)}, + }, suggestionrel.GatherOptions{}) + suite.Require().NoError(err) + suite.Require().Len(input.Controls, 1) + suite.Equal("implemented requirement remarks", input.Controls[0].ImplementationText) } func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptSSPIsolationAndGlobalFiltersStayVisible() { @@ -543,3 +617,7 @@ func (suite *DashboardSuggestionsIntegrationSuite) insertEvidenceLabels(id uuid. ).Error) } } + +func ptrUUID(id uuid.UUID) *uuid.UUID { + return &id +} diff --git a/internal/service/relational/suggestions/service.go b/internal/service/relational/suggestions/service.go index 5efe00b4..423cc266 100644 --- a/internal/service/relational/suggestions/service.go +++ b/internal/service/relational/suggestions/service.go @@ -225,7 +225,10 @@ func (s *SuggestionService) Accept(sspID uuid.UUID, suggestionIDs []uuid.UUID, a if group[i].Confidence != group[j].Confidence { return group[i].Confidence > group[j].Confidence } - return group[i].LabelSetHash < group[j].LabelSetHash + if group[i].ProposedFilterName != group[j].ProposedFilterName { + return group[i].ProposedFilterName < group[j].ProposedFilterName + } + return suggestionIDString(group[i]) < suggestionIDString(group[j]) }) labels := jsonMapToLabels(group[0].LabelSet) filterID, created, err := s.acceptFilterForHash(tx, sspID, hash, labels, group) @@ -274,6 +277,13 @@ func (s *SuggestionService) Accept(sspID uuid.UUID, suggestionIDs []uuid.UUID, a }) } +func suggestionIDString(suggestion DashboardSuggestion) string { + if suggestion.ID == nil { + return "" + } + return suggestion.ID.String() +} + func (s *SuggestionService) Reject(sspID uuid.UUID, suggestionIDs []uuid.UUID, reason string, actorID uuid.UUID) error { return s.db.Transaction(func(tx *gorm.DB) error { suggestions, err := loadPendingSuggestions(tx, sspID, suggestionIDs) diff --git a/internal/service/relational/suggestions/validation.go b/internal/service/relational/suggestions/validation.go index a53ef4f7..b7d913f3 100644 --- a/internal/service/relational/suggestions/validation.go +++ b/internal/service/relational/suggestions/validation.go @@ -123,8 +123,8 @@ func ValidateRawMappings(input CellInput, rawMappings []RawMapping) ValidationRe counts["rejected_empty_reasoning"]++ continue } - if len(reasoning) > MaxReasoningLength { - reasoning = reasoning[:MaxReasoningLength] + ReasoningTruncatedMarker + if truncated, ok := truncateRunes(reasoning, MaxReasoningLength, ReasoningTruncatedMarker); ok { + reasoning = truncated counts["reasoning_truncated"]++ } @@ -151,8 +151,8 @@ func ValidateRawMappings(input CellInput, rawMappings []RawMapping) ValidationRe name = fallbackFilterName(labelSet.Labels) counts["fallback_name"]++ } - if len(name) > 120 { - name = name[:120] + if truncated, ok := truncateRunes(name, 120, ""); ok { + name = truncated counts["name_truncated"]++ } } @@ -231,10 +231,35 @@ func fallbackFilterName(labels map[string]string) string { if name == "" { return "Evidence label set" } - if len(name) > 120 { - return name[:120] + truncated, _ := truncateRunes(name, 120, "") + return truncated +} + +func truncateRunes(value string, limit int, marker string) (string, bool) { + if limit <= 0 { + return value, false + } + count := 0 + for index := range value { + if count == limit { + return value[:index] + marker, true + } + count++ + } + return value, false +} + +func dedupeStrings(values []string) []string { + if len(values) < 2 { + return values } - return name + out := values[:0] + for _, value := range values { + if len(out) == 0 || out[len(out)-1] != value { + out = append(out, value) + } + } + return out } func ParseControlKey(controlKey string) (uuid.UUID, string, error) {