From 5aa010364350d1a7d9f7ad2633d77969394b76ef Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 08:53:33 -0300 Subject: [PATCH 1/5] implement: fix-prompt-to-use-fewer-labels --- internal/service/migrator.go | 16 ++ .../relational/suggestions/core_test.go | 87 +++++++++ .../service/relational/suggestions/gather.go | 1 + .../service/relational/suggestions/models.go | 31 ++-- .../suggestions/models_integration_test.go | 118 +++++++++++++ .../service/relational/suggestions/prompt.go | 14 +- .../service/relational/suggestions/service.go | 77 +++++--- .../suggestions/testdata/prompt_v2.golden | 84 +++++++++ .../relational/suggestions/validation.go | 166 +++++++++++++++--- internal/tests/migrate.go | 16 ++ 10 files changed, 541 insertions(+), 69 deletions(-) create mode 100644 internal/service/relational/suggestions/testdata/prompt_v2.golden diff --git a/internal/service/migrator.go b/internal/service/migrator.go index 39b89f29..7fb78bf7 100644 --- a/internal/service/migrator.go +++ b/internal/service/migrator.go @@ -359,6 +359,22 @@ func MigrateUpWithConfig(db *gorm.DB, cfg *config.Config) error { return err } + if err := db.Exec(` + DROP INDEX IF EXISTS idx_dashboard_suggestions_unique_pending_filter_labels + `).Error; err != nil { + return err + } + + if err := db.Exec(` + CREATE UNIQUE INDEX idx_dashboard_suggestions_unique_pending_filter_labels + ON dashboard_suggestions (ssp_id, control_catalog_id, control_id, proposed_filter_label_set) + WHERE status = 'pending' + AND proposed_filter_label_set IS NOT NULL + AND proposed_filter_label_set <> 'null'::jsonb + `).Error; err != nil { + return err + } + if err := db.Exec(` CREATE UNIQUE INDEX IF NOT EXISTS idx_dashboard_suggestion_runs_unique_active ON dashboard_suggestion_runs (ssp_id) diff --git a/internal/service/relational/suggestions/core_test.go b/internal/service/relational/suggestions/core_test.go index 295d3921..30f06d31 100644 --- a/internal/service/relational/suggestions/core_test.go +++ b/internal/service/relational/suggestions/core_test.go @@ -180,6 +180,93 @@ func TestValidateMappingsRules(t *testing.T) { require.Equal(t, "env=stage", result.Mappings[1].ProposedFilterName) } +func TestValidateMappingsFilterLabelSubsetRules(t *testing.T) { + controlKey := ControlKey(uuid.New(), "AC-1") + labels := map[string]string{ + "_agent": "agent-1", + "_plugin": "github_repos", + "_policy": "secret_scanning_push_protection", + "organization": "compliance-framework", + "provider": "github", + "repository": "todo-app", + "type": "repository", + } + hash := CanonicalLabelSetHash(labels) + input := CellInput{ + Controls: []ControlInput{{ControlKey: controlKey}}, + LabelSets: []LabelSetInput{{Hash: hash, Labels: labels}}, + } + + result := ValidateRawMappings(input, []RawMapping{ + { + ControlKey: controlKey, + LabelSetHash: hash, + Action: MappingActionNewFilter, + Confidence: 0.9, + Reasoning: "matches", + ProposedFilterLabels: map[string]string{"provider": "github", "type": "repository", "_agent": "agent-1"}, + }, + { + ControlKey: controlKey, + LabelSetHash: hash, + Action: MappingActionNewFilter, + Confidence: 0.8, + Reasoning: "hallucinated", + ProposedFilterLabels: map[string]string{"provider": "gitlab"}, + }, + }) + + require.Equal(t, 1, result.Counts["rejected_invalid_filter_labels"]) + require.Equal(t, 1, result.Counts["dropped_identity_filter_labels"]) + require.Equal(t, 1, result.Counts["added_policy_filter_label"]) + require.Len(t, result.Mappings, 1) + require.Equal(t, map[string]string{ + "_policy": "secret_scanning_push_protection", + "provider": "github", + "type": "repository", + }, result.Mappings[0].ProposedFilterLabelSet) +} + +func TestValidateMappingsDedupesByProposedFilterSubset(t *testing.T) { + controlKey := ControlKey(uuid.New(), "AC-1") + subset := map[string]string{ + "_policy": "secret_scanning_push_protection", + "provider": "github", + "type": "repository", + } + firstLabels := map[string]string{ + "_policy": "secret_scanning_push_protection", + "provider": "github", + "type": "repository", + "repository": "todo-app", + } + secondLabels := map[string]string{ + "_policy": "secret_scanning_push_protection", + "provider": "github", + "type": "repository", + "repository": "payments-api", + } + firstHash := CanonicalLabelSetHash(firstLabels) + secondHash := CanonicalLabelSetHash(secondLabels) + input := CellInput{ + Controls: []ControlInput{{ControlKey: controlKey}}, + LabelSets: []LabelSetInput{ + {Hash: firstHash, Labels: firstLabels}, + {Hash: secondHash, Labels: secondLabels}, + }, + } + + result := ValidateRawMappings(input, []RawMapping{ + {ControlKey: controlKey, LabelSetHash: firstHash, Action: MappingActionNewFilter, Confidence: 0.7, Reasoning: "matches", ProposedFilterLabels: subset}, + {ControlKey: controlKey, LabelSetHash: secondHash, Action: MappingActionNewFilter, Confidence: 0.9, Reasoning: "better match", ProposedFilterLabels: subset}, + }) + + require.Equal(t, 1, result.Counts["deduped_within_cell"]) + require.Len(t, result.Mappings, 1) + require.Equal(t, secondHash, result.Mappings[0].LabelSetHash) + require.Equal(t, subset, result.Mappings[0].ProposedFilterLabelSet) +} + func TestValidateMappingsControlCap(t *testing.T) { controlKey := ControlKey(uuid.New(), "AC-1") input := CellInput{Controls: []ControlInput{{ControlKey: controlKey}}} diff --git a/internal/service/relational/suggestions/gather.go b/internal/service/relational/suggestions/gather.go index 94613fa9..95acaaf2 100644 --- a/internal/service/relational/suggestions/gather.go +++ b/internal/service/relational/suggestions/gather.go @@ -340,6 +340,7 @@ func (s *SuggestionService) gatherVisibleFilters(sspID uuid.UUID) ([]VisibleFilt if labels, ok := CanonicalizeFilter(filter.Filter.Data()); ok { hash := CanonicalLabelSetHash(labels) input.LabelSetHash = &hash + input.Labels = labels } visible = append(visible, input) if filter.SSPID != nil && *filter.SSPID == sspID { diff --git a/internal/service/relational/suggestions/models.go b/internal/service/relational/suggestions/models.go index b7f50e7a..82cecabe 100644 --- a/internal/service/relational/suggestions/models.go +++ b/internal/service/relational/suggestions/models.go @@ -58,21 +58,22 @@ func (DashboardSuggestionRunCell) TableName() string { type DashboardSuggestion struct { relational.UUIDModel - RunID uuid.UUID `json:"runId" gorm:"type:uuid;not null;index"` - SSPID uuid.UUID `json:"sspId" gorm:"column:ssp_id;type:uuid;not null;index"` - ControlCatalogID uuid.UUID `json:"controlCatalogId" gorm:"type:uuid;not null"` - ControlID string `json:"controlId" gorm:"type:text;not null"` - LabelSet datatypes.JSONMap `json:"labelSet" gorm:"type:jsonb;not null"` - LabelSetHash string `json:"labelSetHash" gorm:"type:char(64);not null;index"` - TargetFilterID *uuid.UUID `json:"targetFilterId" gorm:"type:uuid;index"` - ProposedFilterName string `json:"proposedFilterName" gorm:"type:text;not null"` - Reasoning string `json:"reasoning" gorm:"type:text;not null"` - Confidence float64 `json:"confidence" gorm:"type:double precision;not null"` - Status string `json:"status" gorm:"type:varchar(16);not null;index"` - AcceptedFilterID *uuid.UUID `json:"acceptedFilterId" gorm:"type:uuid;index"` - DecidedByUserID *uuid.UUID `json:"decidedByUserId" gorm:"type:uuid;index"` - DecidedAt *time.Time `json:"decidedAt"` - RejectReason *string `json:"rejectReason" gorm:"type:text"` + RunID uuid.UUID `json:"runId" gorm:"type:uuid;not null;index"` + SSPID uuid.UUID `json:"sspId" gorm:"column:ssp_id;type:uuid;not null;index"` + ControlCatalogID uuid.UUID `json:"controlCatalogId" gorm:"type:uuid;not null"` + ControlID string `json:"controlId" gorm:"type:text;not null"` + LabelSet datatypes.JSONMap `json:"labelSet" gorm:"type:jsonb;not null"` + LabelSetHash string `json:"labelSetHash" gorm:"type:char(64);not null;index"` + ProposedFilterLabelSet datatypes.JSONMap `json:"proposedFilterLabelSet" gorm:"column:proposed_filter_label_set;type:jsonb"` + TargetFilterID *uuid.UUID `json:"targetFilterId" gorm:"type:uuid;index"` + ProposedFilterName string `json:"proposedFilterName" gorm:"type:text;not null"` + Reasoning string `json:"reasoning" gorm:"type:text;not null"` + Confidence float64 `json:"confidence" gorm:"type:double precision;not null"` + Status string `json:"status" gorm:"type:varchar(16);not null;index"` + AcceptedFilterID *uuid.UUID `json:"acceptedFilterId" gorm:"type:uuid;index"` + DecidedByUserID *uuid.UUID `json:"decidedByUserId" gorm:"type:uuid;index"` + DecidedAt *time.Time `json:"decidedAt"` + RejectReason *string `json:"rejectReason" gorm:"type:text"` Run *DashboardSuggestionRun `json:"-" gorm:"foreignKey:RunID;references:ID;constraint:OnDelete:CASCADE"` SystemSecurityPlan *relational.SystemSecurityPlan `json:"-" gorm:"foreignKey:SSPID;references:ID"` diff --git a/internal/service/relational/suggestions/models_integration_test.go b/internal/service/relational/suggestions/models_integration_test.go index c373911f..da5c67f1 100644 --- a/internal/service/relational/suggestions/models_integration_test.go +++ b/internal/service/relational/suggestions/models_integration_test.go @@ -362,6 +362,114 @@ func (suite *DashboardSuggestionsIntegrationSuite) TestAcceptExtendsSameSSPMatch suite.Equal(int64(1), linkCount) } +func (suite *DashboardSuggestionsIntegrationSuite) TestMinimalProposedFilterDedupesAndMatchesNewRepositories() { + sspID := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + actorID := uuid.New() + suite.seedSuggestionSSPAndRun(sspID, runID) + + subset := map[string]string{ + "_policy": "secret_scanning_push_protection", + "provider": "github", + "type": "repository", + } + firstLabels := map[string]string{ + "_agent": "agent-1", + "_plugin": "github_repos", + "_policy": "secret_scanning_push_protection", + "organization": "compliance-framework", + "provider": "github", + "repository": "todo-app", + "team": "ccf", + "type": "repository", + } + secondLabels := map[string]string{ + "_agent": "agent-2", + "_plugin": "github_repos", + "_policy": "secret_scanning_push_protection", + "organization": "compliance-framework", + "provider": "github", + "repository": "payments-api", + "team": "ccf", + "type": "repository", + } + firstHash := suggestionrel.CanonicalLabelSetHash(firstLabels) + secondHash := suggestionrel.CanonicalLabelSetHash(secondLabels) + + svc := suggestionrel.NewSuggestionService(suite.DB) + result, err := svc.InsertValidatedMappings(runID, sspID, suggestionrel.PromptVersion, []suggestionrel.ValidatedMapping{ + { + ControlKey: suggestionrel.ControlKey(catalogID, "AC-1"), + LabelSetHash: firstHash, + LabelSet: firstLabels, + ProposedFilterLabelSet: subset, + Action: suggestionrel.MappingActionNewFilter, + ProposedFilterName: "GitHub push protection", + Confidence: 0.8, + Reasoning: "matches", + }, + { + ControlKey: suggestionrel.ControlKey(catalogID, "AC-1"), + LabelSetHash: secondHash, + LabelSet: secondLabels, + ProposedFilterLabelSet: subset, + Action: suggestionrel.MappingActionNewFilter, + ProposedFilterName: "GitHub push protection", + Confidence: 0.9, + Reasoning: "matches another repo", + }, + }, 10) + suite.Require().NoError(err) + suite.Equal(1, result.Inserted) + suite.Equal(1, result.Excluded) + + var suggestions []suggestionrel.DashboardSuggestion + suite.Require().NoError(suite.DB.Where("run_id = ?", runID).Find(&suggestions).Error) + suite.Require().Len(suggestions, 1) + suite.Equal(subset, jsonMapToStringMap(suggestions[0].ProposedFilterLabelSet)) + suite.NotEqual(suggestions[0].LabelSetHash, suggestionrel.CanonicalLabelSetHash(subset)) + suite.Contains(jsonMapToStringMap(suggestions[0].LabelSet), "repository") + + suite.Require().NoError(svc.Accept(sspID, []uuid.UUID{*suggestions[0].ID}, actorID)) + + var filter relational.Filter + suite.Require().NoError(suite.DB.First(&filter, "ssp_id = ?", sspID).Error) + filterLabels, ok := suggestionrel.CanonicalizeFilter(filter.Filter.Data()) + suite.True(ok) + suite.Equal(subset, filterLabels) + + var linkCount int64 + suite.Require().NoError(suite.DB.Table("filter_controls").Where("filter_id = ? AND control_id = ?", filter.ID, "AC-1").Count(&linkCount).Error) + suite.Equal(int64(1), linkCount) + var eventCount int64 + suite.Require().NoError(suite.DB.Model(&suggestionrel.DashboardSuggestionEvent{}). + Where("suggestion_id = ? AND event_type = ?", suggestions[0].ID, suggestionrel.DashboardSuggestionEventTypeAccepted). + Count(&eventCount).Error) + suite.Equal(int64(1), eventCount) + + now := time.Now().UTC() + suite.insertEvidenceLabels(uuid.New(), uuid.New(), "todo-app", now, firstLabels) + suite.insertEvidenceLabels(uuid.New(), uuid.New(), "payments-api", now, secondLabels) + thirdLabels := map[string]string{ + "_agent": "agent-3", + "_plugin": "github_repos", + "_policy": "secret_scanning_push_protection", + "organization": "new-org", + "provider": "github", + "repository": "future-repo", + "team": "new-team", + "type": "repository", + } + suite.insertEvidenceLabels(uuid.New(), uuid.New(), "future-repo", now, thirdLabels) + + query, err := relational.GetEvidenceSearchByFilterQuery(relational.GetLatestEvidenceStreamsQuery(suite.DB), suite.DB, filter.Filter.Data()) + suite.Require().NoError(err) + var matched []relational.Evidence + suite.Require().NoError(query.Find(&matched).Error) + suite.Require().Len(matched, 3) +} + func (suite *DashboardSuggestionsIntegrationSuite) TestInsertExcludesMatchingGlobalFilterAndDoesNotModifyIt() { sspID := uuid.New() runID := uuid.New() @@ -618,6 +726,16 @@ func (suite *DashboardSuggestionsIntegrationSuite) insertEvidenceLabels(id uuid. } } +func jsonMapToStringMap(values datatypes.JSONMap) map[string]string { + out := make(map[string]string, len(values)) + for key, value := range values { + if stringValue, ok := value.(string); ok { + out[key] = stringValue + } + } + return out +} + func ptrUUID(id uuid.UUID) *uuid.UUID { return &id } diff --git a/internal/service/relational/suggestions/prompt.go b/internal/service/relational/suggestions/prompt.go index 5daf8973..ca71fad4 100644 --- a/internal/service/relational/suggestions/prompt.go +++ b/internal/service/relational/suggestions/prompt.go @@ -6,7 +6,7 @@ import ( "text/template" ) -const PromptVersion = "v1" +const PromptVersion = "v2" const SystemPrompt = `You map compliance evidence streams to security controls for a specific SSP. @@ -23,9 +23,11 @@ For each shown control that a shown label-set provides evidence for, emit a mapp Respect qualifiers in the control text. A control scoped to a provider, technology, component, environment, or other qualifier only matches evidence whose labels satisfy that qualifier. -Use extend_filter with target_filter_id only when one of this plan's own dashboards has exactly that label-set. Otherwise use new_filter with a short descriptive proposed_filter_name. Global dashboards are listed only to avoid duplicate names; never extend them. +For every mapping, include proposed_filter_labels: the smallest key/value subset that the dashboard filter should use. Choose labels that capture the control's evidence intent and generalize to future components. Always include _policy when present. Do not include _agent or _plugin because they describe collection mechanics, not the evidence. Avoid instance identity labels such as repository, organization, account, host, namespace, project, or environment unless the control or system context is clearly scoped to that specific instance. -Only reference control_key and label_set_hash values present in the input. Reasoning must state both why the evidence satisfies the control and why it belongs to this system. Provide confidence from 0 to 1.` +Use extend_filter with target_filter_id only when one of this plan's own dashboards has exactly the same proposed_filter_labels. Otherwise use new_filter with a short descriptive proposed_filter_name. Global dashboards are listed only to avoid duplicate names; never extend them. + +Only reference control_key and label_set_hash values present in the input, and only choose proposed_filter_labels from labels present on that evidence label-set. Reasoning must state both why the evidence satisfies the control and why it belongs to this system. Provide confidence from 0 to 1.` func OutputSchema() map[string]any { return map[string]any{ @@ -38,7 +40,7 @@ func OutputSchema() map[string]any { "items": map[string]any{ "type": "object", "additionalProperties": false, - "required": []any{"control_key", "label_set_hash", "action", "confidence", "reasoning"}, + "required": []any{"control_key", "label_set_hash", "action", "proposed_filter_labels", "confidence", "reasoning"}, "properties": map[string]any{ "control_key": map[string]any{ "type": "string", @@ -56,6 +58,10 @@ func OutputSchema() map[string]any { "proposed_filter_name": map[string]any{ "type": "string", }, + "proposed_filter_labels": map[string]any{ + "type": "object", + "additionalProperties": map[string]any{"type": "string"}, + }, "confidence": map[string]any{ "type": "number", }, diff --git a/internal/service/relational/suggestions/service.go b/internal/service/relational/suggestions/service.go index 231b30c6..28ad2c02 100644 --- a/internal/service/relational/suggestions/service.go +++ b/internal/service/relational/suggestions/service.go @@ -173,6 +173,12 @@ func (s *SuggestionService) InsertValidatedMappingsTx(tx *gorm.DB, runID uuid.UU result.Capped++ continue } + filterLabels, _, ok := validateProposedFilterLabels(mapping.ProposedFilterLabelSet, mapping.LabelSet) + if !ok { + result.Excluded++ + continue + } + mapping.ProposedFilterLabelSet = filterLabels excluded, err := s.mappingExcluded(tx, sspID, promptVersion, mapping) if err != nil { return result, err @@ -186,17 +192,18 @@ func (s *SuggestionService) InsertValidatedMappingsTx(tx *gorm.DB, runID uuid.UU return result, err } suggestion := DashboardSuggestion{ - RunID: runID, - SSPID: sspID, - ControlCatalogID: catalogID, - ControlID: controlID, - LabelSet: labelsToJSONMap(mapping.LabelSet), - LabelSetHash: mapping.LabelSetHash, - TargetFilterID: mapping.TargetFilterID, - ProposedFilterName: mapping.ProposedFilterName, - Reasoning: mapping.Reasoning, - Confidence: mapping.Confidence, - Status: DashboardSuggestionStatusPending, + RunID: runID, + SSPID: sspID, + ControlCatalogID: catalogID, + ControlID: controlID, + LabelSet: labelsToJSONMap(mapping.LabelSet), + LabelSetHash: mapping.LabelSetHash, + ProposedFilterLabelSet: labelsToJSONMap(mapping.ProposedFilterLabelSet), + TargetFilterID: mapping.TargetFilterID, + ProposedFilterName: mapping.ProposedFilterName, + Reasoning: mapping.Reasoning, + Confidence: mapping.Confidence, + Status: DashboardSuggestionStatusPending, } create := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&suggestion) if create.Error != nil { @@ -229,12 +236,16 @@ func (s *SuggestionService) Accept(sspID uuid.UUID, suggestionIDs []uuid.UUID, a return err } now := time.Now().UTC() - byHash := map[string][]DashboardSuggestion{} + byFilterHash := map[string][]DashboardSuggestion{} for _, suggestion := range suggestions { - byHash[suggestion.LabelSetHash] = append(byHash[suggestion.LabelSetHash], suggestion) + filterLabels := suggestionFilterLabels(suggestion) + if len(filterLabels) == 0 { + return fmt.Errorf("dashboard suggestion %s has empty proposed filter labels", suggestionIDString(suggestion)) + } + byFilterHash[CanonicalLabelSetHash(filterLabels)] = append(byFilterHash[CanonicalLabelSetHash(filterLabels)], suggestion) } - for hash, group := range byHash { + for filterHash, group := range byFilterHash { sort.Slice(group, func(i, j int) bool { if group[i].Confidence != group[j].Confidence { return group[i].Confidence > group[j].Confidence @@ -244,8 +255,8 @@ func (s *SuggestionService) Accept(sspID uuid.UUID, suggestionIDs []uuid.UUID, a } return suggestionIDString(group[i]) < suggestionIDString(group[j]) }) - labels := jsonMapToLabels(group[0].LabelSet) - filterID, created, err := s.acceptFilterForHash(tx, sspID, hash, labels, group) + filterLabels := suggestionFilterLabels(group[0]) + filterID, created, err := s.acceptFilterForLabels(tx, sspID, filterHash, filterLabels, group) if err != nil { return err } @@ -331,8 +342,11 @@ func (s *SuggestionService) Reject(sspID uuid.UUID, suggestionIDs []uuid.UUID, r }) } -func (s *SuggestionService) acceptFilterForHash(tx *gorm.DB, sspID uuid.UUID, hash string, labels map[string]string, group []DashboardSuggestion) (uuid.UUID, bool, error) { - if err := lockAcceptFilterHash(tx, sspID, hash); err != nil { +func (s *SuggestionService) acceptFilterForLabels(tx *gorm.DB, sspID uuid.UUID, filterHash string, labels map[string]string, group []DashboardSuggestion) (uuid.UUID, bool, error) { + if len(labels) == 0 { + return uuid.Nil, false, errors.New("cannot accept dashboard suggestion with empty proposed filter labels") + } + if err := lockAcceptFilterHash(tx, sspID, filterHash); err != nil { return uuid.Nil, false, err } @@ -340,7 +354,7 @@ func (s *SuggestionService) acceptFilterForHash(tx *gorm.DB, sspID uuid.UUID, ha if suggestion.TargetFilterID == nil { continue } - filter, ok, err := loadSameSSPFilterWithHash(tx, sspID, *suggestion.TargetFilterID, hash) + filter, ok, err := loadSameSSPFilterWithHash(tx, sspID, *suggestion.TargetFilterID, filterHash) if err != nil { return uuid.Nil, false, err } @@ -355,7 +369,7 @@ func (s *SuggestionService) acceptFilterForHash(tx *gorm.DB, sspID uuid.UUID, ha } for _, filter := range filters { filterLabels, ok := CanonicalizeFilter(filter.Filter.Data()) - if ok && CanonicalLabelSetHash(filterLabels) == hash { + if ok && CanonicalLabelSetHash(filterLabels) == filterHash { return *filter.ID, false, nil } } @@ -426,6 +440,14 @@ func loadSameSSPFilterWithHash(tx *gorm.DB, sspID uuid.UUID, filterID uuid.UUID, return filter, true, nil } +func suggestionFilterLabels(suggestion DashboardSuggestion) map[string]string { + labels := jsonMapToLabels(suggestion.ProposedFilterLabelSet) + if len(labels) > 0 { + return labels + } + return jsonMapToLabels(suggestion.LabelSet) +} + func (s *SuggestionService) mappingExcluded(tx *gorm.DB, sspID uuid.UUID, promptVersion string, mapping ValidatedMapping) (bool, error) { catalogID, controlID, err := ParseControlKey(mapping.ControlKey) if err != nil { @@ -443,6 +465,19 @@ func (s *SuggestionService) mappingExcluded(tx *gorm.DB, sspID uuid.UUID, prompt return true, nil } + var pending []DashboardSuggestion + if err := tx. + Where("ssp_id = ? AND control_catalog_id = ? AND control_id = ? AND status = ?", sspID, catalogID, controlID, DashboardSuggestionStatusPending). + Find(&pending).Error; err != nil { + return false, err + } + mappingFilterHash := CanonicalLabelSetHash(mapping.ProposedFilterLabelSet) + for _, suggestion := range pending { + if CanonicalLabelSetHash(suggestionFilterLabels(suggestion)) == mappingFilterHash { + return true, nil + } + } + var filters []relational.Filter if err := tx. Joins("JOIN filter_controls ON filter_controls.filter_id = filters.id"). @@ -453,7 +488,7 @@ func (s *SuggestionService) mappingExcluded(tx *gorm.DB, sspID uuid.UUID, prompt } for _, filter := range filters { labels, ok := CanonicalizeFilter(filter.Filter.Data()) - if ok && CanonicalLabelSetHash(labels) == mapping.LabelSetHash { + if ok && CanonicalLabelSetHash(labels) == mappingFilterHash { return true, nil } } diff --git a/internal/service/relational/suggestions/testdata/prompt_v2.golden b/internal/service/relational/suggestions/testdata/prompt_v2.golden new file mode 100644 index 00000000..6f470518 --- /dev/null +++ b/internal/service/relational/suggestions/testdata/prompt_v2.golden @@ -0,0 +1,84 @@ +You map compliance evidence streams to security controls for a specific SSP. + +You are given: +1. The system context: name, description, components it uses, plus each control's implementation text from this plan. +2. A subset of the controls in this SSP. +3. A subset of evidence label-sets with sample titles. +4. Label-key documentation. +5. The dashboards visible to this plan. + +The controls and label-sets shown are one slice of a larger set. Only map what you see; absence here means nothing. + +For each shown control that a shown label-set provides evidence for, emit a mapping, but only when the evidence pertains to this system. When a label identifies an asset such as a repository, cluster, account, host, service, image, namespace, project, or environment, that asset must correspond to a component or description in the system context. Exclude evidence for assets this system does not use. + +Respect qualifiers in the control text. A control scoped to a provider, technology, component, environment, or other qualifier only matches evidence whose labels satisfy that qualifier. + +For every mapping, include proposed_filter_labels: the smallest key/value subset that the dashboard filter should use. Choose labels that capture the control's evidence intent and generalize to future components. Always include _policy when present. Do not include _agent or _plugin because they describe collection mechanics, not the evidence. Avoid instance identity labels such as repository, organization, account, host, namespace, project, or environment unless the control or system context is clearly scoped to that specific instance. + +Use extend_filter with target_filter_id only when one of this plan's own dashboards has exactly the same proposed_filter_labels. Otherwise use new_filter with a short descriptive proposed_filter_name. Global dashboards are listed only to avoid duplicate names; never extend them. + +Only reference control_key and label_set_hash values present in the input, and only choose proposed_filter_labels from labels present on that evidence label-set. Reasoning must state both why the evidence satisfies the control and why it belongs to this system. Provide confidence from 0 to 1. +---USER--- +Prompt version: v2 + +System context: +{ + "system_name": "Payments API", + "description": "Processes card payments.", + "components": [ + { + "title": "payments-api", + "type": "service", + "purpose": "payment processing", + "description": "Go API" + } + ] +} + +Controls: +[ + { + "control_key": "11111111-1111-1111-1111-111111111111:AC-1", + "catalog_id": "", + "control_id": "", + "catalog_title": "", + "title": "Policy", + "statement": "", + "implementation_text": "Uses payments-api." + } +] + +Evidence label-sets: +[ + { + "hash": "hash1", + "labels": { + "repo": "payments-api" + }, + "evidence_count": 2, + "sample_titles": [ + "scan" + ] + } +] + +Label-key documentation: +[ + { + "key": "repo", + "description": "Repository name" + } +] + +This SSP's extendable dashboards: +[ + { + "id": "22222222-2222-2222-2222-222222222222", + "name": "payments-api" + } +] + +Global dashboard names: +null + +Return JSON matching the provided schema. diff --git a/internal/service/relational/suggestions/validation.go b/internal/service/relational/suggestions/validation.go index b7d913f3..e1dd8dc2 100644 --- a/internal/service/relational/suggestions/validation.go +++ b/internal/service/relational/suggestions/validation.go @@ -40,10 +40,11 @@ type LabelSetInput struct { } type VisibleFilterInput struct { - ID uuid.UUID `json:"id"` - Name string `json:"name"` - SSPID *uuid.UUID `json:"ssp_id,omitempty"` - LabelSetHash *string `json:"label_set_hash,omitempty"` + ID uuid.UUID `json:"id"` + Name string `json:"name"` + SSPID *uuid.UUID `json:"ssp_id,omitempty"` + LabelSetHash *string `json:"label_set_hash,omitempty"` + Labels map[string]string `json:"labels,omitempty"` } type RawMappings struct { @@ -51,24 +52,26 @@ type RawMappings struct { } type RawMapping struct { - ControlKey string `json:"control_key"` - LabelSetHash string `json:"label_set_hash"` - Action string `json:"action"` - TargetFilterID string `json:"target_filter_id,omitempty"` - ProposedFilterName string `json:"proposed_filter_name,omitempty"` - Confidence float64 `json:"confidence"` - Reasoning string `json:"reasoning"` + ControlKey string `json:"control_key"` + LabelSetHash string `json:"label_set_hash"` + Action string `json:"action"` + TargetFilterID string `json:"target_filter_id,omitempty"` + ProposedFilterName string `json:"proposed_filter_name,omitempty"` + ProposedFilterLabels map[string]string `json:"proposed_filter_labels,omitempty"` + Confidence float64 `json:"confidence"` + Reasoning string `json:"reasoning"` } type ValidatedMapping struct { - ControlKey string - LabelSetHash string - LabelSet map[string]string - Action string - TargetFilterID *uuid.UUID - ProposedFilterName string - Confidence float64 - Reasoning string + ControlKey string + LabelSetHash string + LabelSet map[string]string + ProposedFilterLabelSet map[string]string + Action string + TargetFilterID *uuid.UUID + ProposedFilterName string + Confidence float64 + Reasoning string } type ValidationCounts map[string]int @@ -128,12 +131,20 @@ func ValidateRawMappings(input CellInput, rawMappings []RawMapping) ValidationRe counts["reasoning_truncated"]++ } + filterLabels, filterLabelCounts, ok := validateProposedFilterLabels(raw.ProposedFilterLabels, labelSet.Labels) + mergeValidationCounts(counts, filterLabelCounts) + if !ok { + counts["rejected_invalid_filter_labels"]++ + continue + } + action := raw.Action var targetFilterID *uuid.UUID if action == MappingActionExtendFilter { parsed, err := uuid.Parse(strings.TrimSpace(raw.TargetFilterID)) filter, found := sameSSPFilters[parsed] - if err != nil || !found || filter.LabelSetHash == nil || *filter.LabelSetHash != labelSetHash { + filterLabelHash := CanonicalLabelSetHash(filterLabels) + if err != nil || !found || filter.LabelSetHash == nil || *filter.LabelSetHash != filterLabelHash { action = MappingActionNewFilter counts["downgraded_extend_to_new"]++ } else { @@ -158,16 +169,17 @@ func ValidateRawMappings(input CellInput, rawMappings []RawMapping) ValidationRe } mapping := ValidatedMapping{ - ControlKey: controlKey, - LabelSetHash: labelSetHash, - LabelSet: labelSet.Labels, - Action: action, - TargetFilterID: targetFilterID, - ProposedFilterName: name, - Confidence: raw.Confidence, - Reasoning: reasoning, + ControlKey: controlKey, + LabelSetHash: labelSetHash, + LabelSet: labelSet.Labels, + ProposedFilterLabelSet: filterLabels, + Action: action, + TargetFilterID: targetFilterID, + ProposedFilterName: name, + Confidence: raw.Confidence, + Reasoning: reasoning, } - dedupeKey := controlKey + "\x00" + labelSetHash + dedupeKey := mappingDedupeKey(mapping) if existing, found := kept[dedupeKey]; !found || mapping.Confidence > existing.Confidence { if found { counts["deduped_within_cell"]++ @@ -196,6 +208,102 @@ func ValidateRawMappings(input CellInput, rawMappings []RawMapping) ValidationRe return ValidationResult{Mappings: mappings, Counts: counts} } +func mappingDedupeKey(mapping ValidatedMapping) string { + return mapping.ControlKey + "\x00" + CanonicalLabelSetHash(mapping.ProposedFilterLabelSet) +} + +func validateProposedFilterLabels(raw map[string]string, evidenceLabels map[string]string) (map[string]string, ValidationCounts, bool) { + counts := ValidationCounts{} + evidence, ok := NormalizeLabelSet(evidenceLabels) + if !ok { + return nil, counts, false + } + + var normalized map[string]string + if len(raw) > 0 { + normalized, ok = NormalizeLabelSet(raw) + if !ok { + return nil, counts, false + } + } else { + normalized = map[string]string{} + counts["fallback_filter_labels"]++ + } + + filterLabels := make(map[string]string, len(normalized)+1) + for key, value := range normalized { + if isGatheringIdentityLabel(key) { + counts["dropped_identity_filter_labels"]++ + continue + } + evidenceValue, found := evidence[key] + if !found || evidenceValue != value { + return nil, counts, false + } + filterLabels[key] = value + } + + if policy, found := evidence["_policy"]; found { + if existing, included := filterLabels["_policy"]; included && existing != policy { + return nil, counts, false + } + if _, included := filterLabels["_policy"]; !included { + counts["added_policy_filter_label"]++ + } + filterLabels["_policy"] = policy + } + + if len(filterLabels) == 0 { + filterLabels = defaultFilterLabelSubset(evidence) + if len(filterLabels) > 0 { + counts["fallback_filter_labels"]++ + } + } + if len(filterLabels) == 0 { + return nil, counts, false + } + return filterLabels, counts, true +} + +func defaultFilterLabelSubset(labels map[string]string) map[string]string { + out := map[string]string{} + for _, key := range []string{"_policy", "provider", "type"} { + if value, found := labels[key]; found && !isGatheringIdentityLabel(key) { + out[key] = value + } + } + if len(out) > 0 { + return out + } + + keys := make([]string, 0, len(labels)) + for key := range labels { + if !isGatheringIdentityLabel(key) { + keys = append(keys, key) + } + } + sort.Strings(keys) + for _, key := range keys { + out[key] = labels[key] + } + return out +} + +func isGatheringIdentityLabel(key string) bool { + switch strings.ToLower(key) { + case "_agent", "_plugin": + return true + default: + return false + } +} + +func mergeValidationCounts(dst ValidationCounts, src ValidationCounts) { + for key, value := range src { + dst[key] += value + } +} + func capMappingsPerControl(mappings []ValidatedMapping, counts ValidationCounts) []ValidatedMapping { byControl := map[string][]ValidatedMapping{} for _, mapping := range mappings { diff --git a/internal/tests/migrate.go b/internal/tests/migrate.go index e02aaacf..61df1f71 100644 --- a/internal/tests/migrate.go +++ b/internal/tests/migrate.go @@ -320,6 +320,22 @@ func (t *TestMigrator) Up() error { return err } + if err := t.db.Exec(` + DROP INDEX IF EXISTS idx_dashboard_suggestions_unique_pending_filter_labels + `).Error; err != nil { + return err + } + + if err := t.db.Exec(` + CREATE UNIQUE INDEX idx_dashboard_suggestions_unique_pending_filter_labels + ON dashboard_suggestions (ssp_id, control_catalog_id, control_id, proposed_filter_label_set) + WHERE status = 'pending' + AND proposed_filter_label_set IS NOT NULL + AND proposed_filter_label_set <> 'null'::jsonb + `).Error; err != nil { + return err + } + if err := t.db.Exec(` CREATE UNIQUE INDEX IF NOT EXISTS idx_dashboard_suggestion_runs_unique_active ON dashboard_suggestion_runs (ssp_id) From 4c9485ec8ecb59c35504521f7c4abf4b31e29129 Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 08:59:49 -0300 Subject: [PATCH 2/5] self-review: address pass 1 findings --- docs/docs.go | 6 ++++++ docs/swagger.json | 6 ++++++ docs/swagger.yaml | 4 ++++ 3 files changed, 16 insertions(+) diff --git a/docs/docs.go b/docs/docs.go index 12a3ccb9..7f6e23e3 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -32495,6 +32495,9 @@ const docTemplate = `{ "labelSetHash": { "type": "string" }, + "proposedFilterLabelSet": { + "$ref": "#/definitions/datatypes.JSONMap" + }, "proposedFilterName": { "type": "string" }, @@ -40481,6 +40484,9 @@ const docTemplate = `{ "labelSetHash": { "type": "string" }, + "proposedFilterLabelSet": { + "$ref": "#/definitions/datatypes.JSONMap" + }, "proposedFilterName": { "type": "string" }, diff --git a/docs/swagger.json b/docs/swagger.json index 4d5d5193..2158d73b 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -32489,6 +32489,9 @@ "labelSetHash": { "type": "string" }, + "proposedFilterLabelSet": { + "$ref": "#/definitions/datatypes.JSONMap" + }, "proposedFilterName": { "type": "string" }, @@ -40475,6 +40478,9 @@ "labelSetHash": { "type": "string" }, + "proposedFilterLabelSet": { + "$ref": "#/definitions/datatypes.JSONMap" + }, "proposedFilterName": { "type": "string" }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index a40f6be0..225e115a 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -3093,6 +3093,8 @@ definitions: $ref: '#/definitions/datatypes.JSONMap' labelSetHash: type: string + proposedFilterLabelSet: + $ref: '#/definitions/datatypes.JSONMap' proposedFilterName: type: string reasoning: @@ -8372,6 +8374,8 @@ definitions: $ref: '#/definitions/datatypes.JSONMap' labelSetHash: type: string + proposedFilterLabelSet: + $ref: '#/definitions/datatypes.JSONMap' proposedFilterName: type: string reasoning: From b0ffab59825facdcaf73aedb9429db01f78891ae Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 09:26:20 -0300 Subject: [PATCH 3/5] fix: address review feedback --- internal/service/migrator.go | 8 +---- .../relational/suggestions/core_test.go | 32 +++++++++++++++++++ .../relational/suggestions/validation.go | 22 +++++++++---- internal/tests/migrate.go | 8 +---- 4 files changed, 49 insertions(+), 21 deletions(-) diff --git a/internal/service/migrator.go b/internal/service/migrator.go index 7fb78bf7..9a53b18f 100644 --- a/internal/service/migrator.go +++ b/internal/service/migrator.go @@ -360,13 +360,7 @@ func MigrateUpWithConfig(db *gorm.DB, cfg *config.Config) error { } if err := db.Exec(` - DROP INDEX IF EXISTS idx_dashboard_suggestions_unique_pending_filter_labels - `).Error; err != nil { - return err - } - - if err := db.Exec(` - CREATE UNIQUE INDEX idx_dashboard_suggestions_unique_pending_filter_labels + CREATE UNIQUE INDEX IF NOT EXISTS idx_dashboard_suggestions_unique_pending_filter_labels ON dashboard_suggestions (ssp_id, control_catalog_id, control_id, proposed_filter_label_set) WHERE status = 'pending' AND proposed_filter_label_set IS NOT NULL diff --git a/internal/service/relational/suggestions/core_test.go b/internal/service/relational/suggestions/core_test.go index 30f06d31..b5bd6647 100644 --- a/internal/service/relational/suggestions/core_test.go +++ b/internal/service/relational/suggestions/core_test.go @@ -227,6 +227,38 @@ func TestValidateMappingsFilterLabelSubsetRules(t *testing.T) { }, result.Mappings[0].ProposedFilterLabelSet) } +func TestValidateMappingsEmptyFilterLabelsUseDefaultSubset(t *testing.T) { + controlKey := ControlKey(uuid.New(), "AC-1") + labels := map[string]string{ + "_policy": "secret_scanning_push_protection", + "organization": "compliance-framework", + "provider": "github", + "repository": "todo-app", + "type": "repository", + } + hash := CanonicalLabelSetHash(labels) + input := CellInput{ + Controls: []ControlInput{{ControlKey: controlKey}}, + LabelSets: []LabelSetInput{{Hash: hash, Labels: labels}}, + } + + result := ValidateRawMappings(input, []RawMapping{{ + ControlKey: controlKey, + LabelSetHash: hash, + Action: MappingActionNewFilter, + Confidence: 0.9, + Reasoning: "matches", + }}) + + require.Equal(t, 1, result.Counts["fallback_filter_labels"]) + require.Len(t, result.Mappings, 1) + require.Equal(t, map[string]string{ + "_policy": "secret_scanning_push_protection", + "provider": "github", + "type": "repository", + }, result.Mappings[0].ProposedFilterLabelSet) +} + func TestValidateMappingsDedupesByProposedFilterSubset(t *testing.T) { controlKey := ControlKey(uuid.New(), "AC-1") subset := map[string]string{ diff --git a/internal/service/relational/suggestions/validation.go b/internal/service/relational/suggestions/validation.go index e1dd8dc2..d2479ecd 100644 --- a/internal/service/relational/suggestions/validation.go +++ b/internal/service/relational/suggestions/validation.go @@ -219,15 +219,23 @@ func validateProposedFilterLabels(raw map[string]string, evidenceLabels map[stri return nil, counts, false } - var normalized map[string]string - if len(raw) > 0 { - normalized, ok = NormalizeLabelSet(raw) - if !ok { + if len(raw) == 0 { + filterLabels := defaultFilterLabelSubset(evidence) + if len(filterLabels) > 0 { + counts["fallback_filter_labels"]++ + } + if policy, found := evidence["_policy"]; found { + filterLabels["_policy"] = policy + } + if len(filterLabels) == 0 { return nil, counts, false } - } else { - normalized = map[string]string{} - counts["fallback_filter_labels"]++ + return filterLabels, counts, true + } + + normalized, ok := NormalizeLabelSet(raw) + if !ok { + return nil, counts, false } filterLabels := make(map[string]string, len(normalized)+1) diff --git a/internal/tests/migrate.go b/internal/tests/migrate.go index 61df1f71..821dcd7e 100644 --- a/internal/tests/migrate.go +++ b/internal/tests/migrate.go @@ -321,13 +321,7 @@ func (t *TestMigrator) Up() error { } if err := t.db.Exec(` - DROP INDEX IF EXISTS idx_dashboard_suggestions_unique_pending_filter_labels - `).Error; err != nil { - return err - } - - if err := t.db.Exec(` - CREATE UNIQUE INDEX idx_dashboard_suggestions_unique_pending_filter_labels + CREATE UNIQUE INDEX IF NOT EXISTS idx_dashboard_suggestions_unique_pending_filter_labels ON dashboard_suggestions (ssp_id, control_catalog_id, control_id, proposed_filter_label_set) WHERE status = 'pending' AND proposed_filter_label_set IS NOT NULL From cf3f8db2222833f9bdda763005df57ec55de1676 Mon Sep 17 00:00:00 2001 From: "ccf-lisa[bot]" <286799724+ccf-lisa[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 10:00:01 -0300 Subject: [PATCH 4/5] fix: address review feedback --- .../relational/suggestions/core_test.go | 99 +++++++++++++++++++ .../service/relational/suggestions/prompt.go | 18 +++- .../suggestions/testdata/prompt_v2.golden | 2 +- .../relational/suggestions/validation.go | 40 ++++++-- 4 files changed, 147 insertions(+), 12 deletions(-) diff --git a/internal/service/relational/suggestions/core_test.go b/internal/service/relational/suggestions/core_test.go index b5bd6647..395ace2c 100644 --- a/internal/service/relational/suggestions/core_test.go +++ b/internal/service/relational/suggestions/core_test.go @@ -1,7 +1,9 @@ package suggestions import ( + "encoding/json" "errors" + "fmt" "os" "path/filepath" "strings" @@ -227,6 +229,103 @@ func TestValidateMappingsFilterLabelSubsetRules(t *testing.T) { }, result.Mappings[0].ProposedFilterLabelSet) } +func TestOutputSchemaProposedFilterLabelsUsesStrictPairArray(t *testing.T) { + schema := OutputSchema() + properties := schema["properties"].(map[string]any) + mappings := properties["mappings"].(map[string]any) + mappingItem := mappings["items"].(map[string]any) + mappingProperties := mappingItem["properties"].(map[string]any) + proposedLabels := mappingProperties["proposed_filter_labels"].(map[string]any) + + require.Equal(t, "array", proposedLabels["type"]) + item := proposedLabels["items"].(map[string]any) + require.Equal(t, "object", item["type"]) + require.Equal(t, false, item["additionalProperties"]) + require.ElementsMatch(t, []any{"key", "value"}, item["required"]) + itemProperties := item["properties"].(map[string]any) + require.Equal(t, map[string]any{"type": "string"}, itemProperties["key"]) + require.Equal(t, map[string]any{"type": "string"}, itemProperties["value"]) + + requireEverySchemaObjectIsStrict(t, schema) +} + +func requireEverySchemaObjectIsStrict(t *testing.T, node any) { + t.Helper() + switch typed := node.(type) { + case map[string]any: + if typed["type"] == "object" { + require.Equal(t, false, typed["additionalProperties"]) + } + for _, value := range typed { + requireEverySchemaObjectIsStrict(t, value) + } + case []any: + for _, value := range typed { + requireEverySchemaObjectIsStrict(t, value) + } + } +} + +func TestValidateMappingsDecodesProposedFilterLabelsPairArrayLikeLegacyMap(t *testing.T) { + controlKey := ControlKey(uuid.New(), "AC-1") + labels := map[string]string{ + "_policy": "secret_scanning_push_protection", + "provider": "github", + "type": "repository", + } + hash := CanonicalLabelSetHash(labels) + input := CellInput{ + Controls: []ControlInput{{ControlKey: controlKey}}, + LabelSets: []LabelSetInput{{Hash: hash, Labels: labels}}, + } + rawTemplate := `{ + "mappings": [{ + "control_key": %q, + "label_set_hash": %q, + "action": "new_filter", + "confidence": 0.9, + "reasoning": "matches", + "proposed_filter_labels": %s + }] + }` + arrayResult, err := ValidateMappings(input, []byte(fmt.Sprintf( + rawTemplate, + controlKey, + hash, + `[{"key":"provider","value":"github"},{"key":"type","value":"repository"}]`, + ))) + require.NoError(t, err) + legacyResult, err := ValidateMappings(input, []byte(fmt.Sprintf( + rawTemplate, + controlKey, + hash, + `{"provider":"github","type":"repository"}`, + ))) + require.NoError(t, err) + + require.Len(t, arrayResult.Mappings, 1) + require.Equal(t, legacyResult.Mappings[0].ProposedFilterLabelSet, arrayResult.Mappings[0].ProposedFilterLabelSet) +} + +func TestProposedFilterLabelsPairArrayDuplicateKeyLastWins(t *testing.T) { + var decoded RawMappings + err := json.Unmarshal([]byte(`{ + "mappings": [{ + "control_key": "control", + "label_set_hash": "hash", + "action": "new_filter", + "confidence": 0.9, + "reasoning": "matches", + "proposed_filter_labels": [ + {"key":"provider","value":"aws"}, + {"key":"provider","value":"github"} + ] + }] + }`), &decoded) + require.NoError(t, err) + require.Equal(t, ProposedFilterLabels{"provider": "github"}, decoded.Mappings[0].ProposedFilterLabels) +} + func TestValidateMappingsEmptyFilterLabelsUseDefaultSubset(t *testing.T) { controlKey := ControlKey(uuid.New(), "AC-1") labels := map[string]string{ diff --git a/internal/service/relational/suggestions/prompt.go b/internal/service/relational/suggestions/prompt.go index ca71fad4..fb29c14b 100644 --- a/internal/service/relational/suggestions/prompt.go +++ b/internal/service/relational/suggestions/prompt.go @@ -23,7 +23,7 @@ For each shown control that a shown label-set provides evidence for, emit a mapp Respect qualifiers in the control text. A control scoped to a provider, technology, component, environment, or other qualifier only matches evidence whose labels satisfy that qualifier. -For every mapping, include proposed_filter_labels: the smallest key/value subset that the dashboard filter should use. Choose labels that capture the control's evidence intent and generalize to future components. Always include _policy when present. Do not include _agent or _plugin because they describe collection mechanics, not the evidence. Avoid instance identity labels such as repository, organization, account, host, namespace, project, or environment unless the control or system context is clearly scoped to that specific instance. +For every mapping, include proposed_filter_labels as a list of {"key","value"} pairs: the smallest label subset that the dashboard filter should use. Choose labels that capture the control's evidence intent and generalize to future components. Always include _policy when present. Do not include _agent or _plugin because they describe collection mechanics, not the evidence. Avoid instance identity labels such as repository, organization, account, host, namespace, project, or environment unless the control or system context is clearly scoped to that specific instance. Use extend_filter with target_filter_id only when one of this plan's own dashboards has exactly the same proposed_filter_labels. Otherwise use new_filter with a short descriptive proposed_filter_name. Global dashboards are listed only to avoid duplicate names; never extend them. @@ -59,8 +59,20 @@ func OutputSchema() map[string]any { "type": "string", }, "proposed_filter_labels": map[string]any{ - "type": "object", - "additionalProperties": map[string]any{"type": "string"}, + "type": "array", + "items": map[string]any{ + "type": "object", + "additionalProperties": false, + "required": []any{"key", "value"}, + "properties": map[string]any{ + "key": map[string]any{ + "type": "string", + }, + "value": map[string]any{ + "type": "string", + }, + }, + }, }, "confidence": map[string]any{ "type": "number", diff --git a/internal/service/relational/suggestions/testdata/prompt_v2.golden b/internal/service/relational/suggestions/testdata/prompt_v2.golden index 6f470518..cf0f9f2c 100644 --- a/internal/service/relational/suggestions/testdata/prompt_v2.golden +++ b/internal/service/relational/suggestions/testdata/prompt_v2.golden @@ -13,7 +13,7 @@ For each shown control that a shown label-set provides evidence for, emit a mapp Respect qualifiers in the control text. A control scoped to a provider, technology, component, environment, or other qualifier only matches evidence whose labels satisfy that qualifier. -For every mapping, include proposed_filter_labels: the smallest key/value subset that the dashboard filter should use. Choose labels that capture the control's evidence intent and generalize to future components. Always include _policy when present. Do not include _agent or _plugin because they describe collection mechanics, not the evidence. Avoid instance identity labels such as repository, organization, account, host, namespace, project, or environment unless the control or system context is clearly scoped to that specific instance. +For every mapping, include proposed_filter_labels as a list of {"key","value"} pairs: the smallest label subset that the dashboard filter should use. Choose labels that capture the control's evidence intent and generalize to future components. Always include _policy when present. Do not include _agent or _plugin because they describe collection mechanics, not the evidence. Avoid instance identity labels such as repository, organization, account, host, namespace, project, or environment unless the control or system context is clearly scoped to that specific instance. Use extend_filter with target_filter_id only when one of this plan's own dashboards has exactly the same proposed_filter_labels. Otherwise use new_filter with a short descriptive proposed_filter_name. Global dashboards are listed only to avoid duplicate names; never extend them. diff --git a/internal/service/relational/suggestions/validation.go b/internal/service/relational/suggestions/validation.go index d2479ecd..61dad6a1 100644 --- a/internal/service/relational/suggestions/validation.go +++ b/internal/service/relational/suggestions/validation.go @@ -52,14 +52,38 @@ type RawMappings struct { } type RawMapping struct { - ControlKey string `json:"control_key"` - LabelSetHash string `json:"label_set_hash"` - Action string `json:"action"` - TargetFilterID string `json:"target_filter_id,omitempty"` - ProposedFilterName string `json:"proposed_filter_name,omitempty"` - ProposedFilterLabels map[string]string `json:"proposed_filter_labels,omitempty"` - Confidence float64 `json:"confidence"` - Reasoning string `json:"reasoning"` + ControlKey string `json:"control_key"` + LabelSetHash string `json:"label_set_hash"` + Action string `json:"action"` + TargetFilterID string `json:"target_filter_id,omitempty"` + ProposedFilterName string `json:"proposed_filter_name,omitempty"` + ProposedFilterLabels ProposedFilterLabels `json:"proposed_filter_labels,omitempty"` + Confidence float64 `json:"confidence"` + Reasoning string `json:"reasoning"` +} + +type ProposedFilterLabels map[string]string + +func (labels *ProposedFilterLabels) UnmarshalJSON(raw []byte) error { + var pairs []struct { + Key string `json:"key"` + Value string `json:"value"` + } + if err := json.Unmarshal(raw, &pairs); err == nil { + decoded := make(map[string]string, len(pairs)) + for _, pair := range pairs { + decoded[pair.Key] = pair.Value + } + *labels = decoded + return nil + } + + var legacy map[string]string + if err := json.Unmarshal(raw, &legacy); err != nil { + return err + } + *labels = legacy + return nil } type ValidatedMapping struct { From f76859798693bbfe7eddca6164e614760ed716f8 Mon Sep 17 00:00:00 2001 From: Gustavo Carvalho Date: Wed, 17 Jun 2026 09:55:21 -0300 Subject: [PATCH 5/5] fix: several fixes Signed-off-by: Gustavo Carvalho --- docs/docs.go | 104 ++++++++++-------- docs/swagger.json | 104 ++++++++++-------- docs/swagger.yaml | 69 +++++++----- .../handler/oscal/dashboard_suggestions.go | 72 +++++++++++- .../dashboard_suggestions_integration_test.go | 38 ++++++- internal/api/handler/users.go | 6 +- .../api/handler/users_integration_test.go | 2 +- internal/service/llm/anthropic.go | 70 ++++++++++-- internal/service/llm/anthropic_test.go | 91 +++++++++++++++ internal/service/llm/client.go | 42 +++++++ .../relational/suggestions/core_test.go | 7 +- .../service/relational/suggestions/prompt.go | 60 ++++++++-- .../suggestions/testdata/prompt_v2.golden | 42 +++---- .../worker/dashboard_suggestion_worker.go | 94 ++++++++++++++-- ...oard_suggestion_worker_integration_test.go | 29 ++++- .../dashboard_suggestion_worker_test.go | 86 ++++++++++++++- 16 files changed, 731 insertions(+), 185 deletions(-) diff --git a/docs/docs.go b/docs/docs.go index 7f6e23e3..42ade9db 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -16572,7 +16572,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/handler.GenericDataListResponse-suggestions_DashboardSuggestionEvent" + "$ref": "#/definitions/handler.GenericDataListResponse-oscal_dashboardSuggestionEventResponse" } }, "400": { @@ -28947,6 +28947,18 @@ const docTemplate = `{ } } }, + "handler.GenericDataListResponse-oscal_dashboardSuggestionEventResponse": { + "type": "object", + "properties": { + "data": { + "description": "Items from the list response", + "type": "array", + "items": { + "$ref": "#/definitions/oscal.dashboardSuggestionEventResponse" + } + } + } + }, "handler.GenericDataListResponse-oscal_dashboardSuggestionResponse": { "type": "object", "properties": { @@ -29055,18 +29067,6 @@ const docTemplate = `{ } } }, - "handler.GenericDataListResponse-suggestions_DashboardSuggestionEvent": { - "type": "object", - "properties": { - "data": { - "description": "Items from the list response", - "type": "array", - "items": { - "$ref": "#/definitions/suggestions.DashboardSuggestionEvent" - } - } - } - }, "handler.GenericDataListResponse-suggestions_LabelSetInput": { "type": "object", "properties": { @@ -32442,6 +32442,41 @@ const docTemplate = `{ } } }, + "oscal.dashboardSuggestionEventResponse": { + "type": "object", + "properties": { + "actor": { + "$ref": "#/definitions/oscal.suggestionEventActor" + }, + "actorUserId": { + "type": "string" + }, + "details": { + "type": "string" + }, + "eventType": { + "type": "string" + }, + "id": { + "type": "string" + }, + "occurredAt": { + "type": "string" + }, + "payload": { + "$ref": "#/definitions/datatypes.JSONMap" + }, + "runId": { + "type": "string" + }, + "snapshot": { + "$ref": "#/definitions/datatypes.JSONMap" + }, + "suggestionId": { + "type": "string" + } + } + }, "oscal.dashboardSuggestionPreviewResponse": { "type": "object", "properties": { @@ -32666,6 +32701,17 @@ const docTemplate = `{ } } }, + "oscal.suggestionEventActor": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "name": { + "type": "string" + } + } + }, "oscalTypes_1_1_3.Action": { "type": "object", "properties": { @@ -40510,38 +40556,6 @@ const docTemplate = `{ } } }, - "suggestions.DashboardSuggestionEvent": { - "type": "object", - "properties": { - "actorUserId": { - "type": "string" - }, - "details": { - "type": "string" - }, - "eventType": { - "type": "string" - }, - "id": { - "type": "string" - }, - "occurredAt": { - "type": "string" - }, - "payload": { - "$ref": "#/definitions/datatypes.JSONMap" - }, - "runId": { - "type": "string" - }, - "snapshot": { - "$ref": "#/definitions/datatypes.JSONMap" - }, - "suggestionId": { - "type": "string" - } - } - }, "suggestions.DashboardSuggestionRunCell": { "type": "object", "properties": { diff --git a/docs/swagger.json b/docs/swagger.json index 2158d73b..023ab624 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -16566,7 +16566,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/handler.GenericDataListResponse-suggestions_DashboardSuggestionEvent" + "$ref": "#/definitions/handler.GenericDataListResponse-oscal_dashboardSuggestionEventResponse" } }, "400": { @@ -28941,6 +28941,18 @@ } } }, + "handler.GenericDataListResponse-oscal_dashboardSuggestionEventResponse": { + "type": "object", + "properties": { + "data": { + "description": "Items from the list response", + "type": "array", + "items": { + "$ref": "#/definitions/oscal.dashboardSuggestionEventResponse" + } + } + } + }, "handler.GenericDataListResponse-oscal_dashboardSuggestionResponse": { "type": "object", "properties": { @@ -29049,18 +29061,6 @@ } } }, - "handler.GenericDataListResponse-suggestions_DashboardSuggestionEvent": { - "type": "object", - "properties": { - "data": { - "description": "Items from the list response", - "type": "array", - "items": { - "$ref": "#/definitions/suggestions.DashboardSuggestionEvent" - } - } - } - }, "handler.GenericDataListResponse-suggestions_LabelSetInput": { "type": "object", "properties": { @@ -32436,6 +32436,41 @@ } } }, + "oscal.dashboardSuggestionEventResponse": { + "type": "object", + "properties": { + "actor": { + "$ref": "#/definitions/oscal.suggestionEventActor" + }, + "actorUserId": { + "type": "string" + }, + "details": { + "type": "string" + }, + "eventType": { + "type": "string" + }, + "id": { + "type": "string" + }, + "occurredAt": { + "type": "string" + }, + "payload": { + "$ref": "#/definitions/datatypes.JSONMap" + }, + "runId": { + "type": "string" + }, + "snapshot": { + "$ref": "#/definitions/datatypes.JSONMap" + }, + "suggestionId": { + "type": "string" + } + } + }, "oscal.dashboardSuggestionPreviewResponse": { "type": "object", "properties": { @@ -32660,6 +32695,17 @@ } } }, + "oscal.suggestionEventActor": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "name": { + "type": "string" + } + } + }, "oscalTypes_1_1_3.Action": { "type": "object", "properties": { @@ -40504,38 +40550,6 @@ } } }, - "suggestions.DashboardSuggestionEvent": { - "type": "object", - "properties": { - "actorUserId": { - "type": "string" - }, - "details": { - "type": "string" - }, - "eventType": { - "type": "string" - }, - "id": { - "type": "string" - }, - "occurredAt": { - "type": "string" - }, - "payload": { - "$ref": "#/definitions/datatypes.JSONMap" - }, - "runId": { - "type": "string" - }, - "snapshot": { - "$ref": "#/definitions/datatypes.JSONMap" - }, - "suggestionId": { - "type": "string" - } - } - }, "suggestions.DashboardSuggestionRunCell": { "type": "object", "properties": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 225e115a..dfcf3a85 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -629,6 +629,14 @@ definitions: $ref: '#/definitions/oscal.ProfileHandler' type: array type: object + handler.GenericDataListResponse-oscal_dashboardSuggestionEventResponse: + properties: + data: + description: Items from the list response + items: + $ref: '#/definitions/oscal.dashboardSuggestionEventResponse' + type: array + type: object handler.GenericDataListResponse-oscal_dashboardSuggestionResponse: properties: data: @@ -949,14 +957,6 @@ definitions: $ref: '#/definitions/relational.User' type: array type: object - handler.GenericDataListResponse-suggestions_DashboardSuggestionEvent: - properties: - data: - description: Items from the list response - items: - $ref: '#/definitions/suggestions.DashboardSuggestionEvent' - type: array - type: object handler.GenericDataListResponse-suggestions_LabelSetInput: properties: data: @@ -3058,6 +3058,29 @@ definitions: required: - ids type: object + oscal.dashboardSuggestionEventResponse: + properties: + actor: + $ref: '#/definitions/oscal.suggestionEventActor' + actorUserId: + type: string + details: + type: string + eventType: + type: string + id: + type: string + occurredAt: + type: string + payload: + $ref: '#/definitions/datatypes.JSONMap' + runId: + type: string + snapshot: + $ref: '#/definitions/datatypes.JSONMap' + suggestionId: + type: string + type: object oscal.dashboardSuggestionPreviewResponse: properties: controlCount: @@ -3205,6 +3228,13 @@ definitions: value: type: string type: object + oscal.suggestionEventActor: + properties: + id: + type: string + name: + type: string + type: object oscalTypes_1_1_3.Action: properties: date: @@ -8391,27 +8421,6 @@ definitions: targetFilterId: type: string type: object - suggestions.DashboardSuggestionEvent: - properties: - actorUserId: - type: string - details: - type: string - eventType: - type: string - id: - type: string - occurredAt: - type: string - payload: - $ref: '#/definitions/datatypes.JSONMap' - runId: - type: string - snapshot: - $ref: '#/definitions/datatypes.JSONMap' - suggestionId: - type: string - type: object suggestions.DashboardSuggestionRunCell: properties: cellIndex: @@ -20472,7 +20481,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/handler.GenericDataListResponse-suggestions_DashboardSuggestionEvent' + $ref: '#/definitions/handler.GenericDataListResponse-oscal_dashboardSuggestionEventResponse' "400": description: Bad Request schema: diff --git a/internal/api/handler/oscal/dashboard_suggestions.go b/internal/api/handler/oscal/dashboard_suggestions.go index f767dbab..90a3eea5 100644 --- a/internal/api/handler/oscal/dashboard_suggestions.go +++ b/internal/api/handler/oscal/dashboard_suggestions.go @@ -61,6 +61,19 @@ type dashboardSuggestionResponse struct { TargetFilterName string `json:"targetFilterName,omitempty"` } +// suggestionEventActor carries the resolved display details for the user that +// triggered a dashboard suggestion event, so the UI can render who acted rather +// than just an opaque user ID. +type suggestionEventActor struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type dashboardSuggestionEventResponse struct { + suggestionrel.DashboardSuggestionEvent + Actor *suggestionEventActor `json:"actor,omitempty"` +} + type acceptDashboardSuggestionsResponse struct { AcceptedFilterIDs []uuid.UUID `json:"acceptedFilterIds"` Suggestions []dashboardSuggestionResponse `json:"suggestions"` @@ -464,7 +477,7 @@ func (h *DashboardSuggestionHandler) Reject(ctx echo.Context) error { // @Produce json // @Param id path string true "System Security Plan ID" // @Param suggestionId path string true "Dashboard suggestion ID" -// @Success 200 {object} handler.GenericDataListResponse[suggestions.DashboardSuggestionEvent] +// @Success 200 {object} handler.GenericDataListResponse[oscal.dashboardSuggestionEventResponse] // @Failure 400 {object} api.Error // @Failure 401 {object} api.Error // @Failure 404 {object} api.Error @@ -491,7 +504,62 @@ func (h *DashboardSuggestionHandler) Events(ctx echo.Context) error { if err := h.db.Where("suggestion_id = ?", suggestionID).Order("occurred_at ASC").Find(&events).Error; err != nil { return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) } - return ctx.JSON(http.StatusOK, handler.GenericDataListResponse[suggestionrel.DashboardSuggestionEvent]{Data: events}) + + actors, err := h.resolveEventActors(events) + if err != nil { + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + responses := make([]dashboardSuggestionEventResponse, 0, len(events)) + for _, event := range events { + resp := dashboardSuggestionEventResponse{DashboardSuggestionEvent: event} + if event.ActorUserID != nil { + if actor, ok := actors[*event.ActorUserID]; ok { + resp.Actor = &actor + } + } + responses = append(responses, resp) + } + + return ctx.JSON(http.StatusOK, handler.GenericDataListResponse[dashboardSuggestionEventResponse]{Data: responses}) +} + +// resolveEventActors loads the distinct actor users referenced by the given +// events and returns a map keyed by user ID. Users that no longer exist are +// simply omitted from the map. +func (h *DashboardSuggestionHandler) resolveEventActors(events []suggestionrel.DashboardSuggestionEvent) (map[uuid.UUID]suggestionEventActor, error) { + ids := make([]uuid.UUID, 0, len(events)) + seen := make(map[uuid.UUID]struct{}, len(events)) + for _, event := range events { + if event.ActorUserID == nil { + continue + } + if _, ok := seen[*event.ActorUserID]; ok { + continue + } + seen[*event.ActorUserID] = struct{}{} + ids = append(ids, *event.ActorUserID) + } + + actors := make(map[uuid.UUID]suggestionEventActor, len(ids)) + if len(ids) == 0 { + return actors, nil + } + + var users []relational.User + if err := h.db.Where("id IN ?", ids).Find(&users).Error; err != nil { + return nil, err + } + for _, user := range users { + if user.ID == nil { + continue + } + actors[*user.ID] = suggestionEventActor{ + ID: user.ID.String(), + Name: handler.UserDisplayName(user), + } + } + return actors, nil } func (h *DashboardSuggestionHandler) aiEnabled() bool { diff --git a/internal/api/handler/oscal/dashboard_suggestions_integration_test.go b/internal/api/handler/oscal/dashboard_suggestions_integration_test.go index 7ac2d6bd..7a64f75a 100644 --- a/internal/api/handler/oscal/dashboard_suggestions_integration_test.go +++ b/internal/api/handler/oscal/dashboard_suggestions_integration_test.go @@ -420,6 +420,40 @@ func (suite *DashboardSuggestionsHTTPSuite) TestAcceptCreatesSSPFilterAndWritesE suite.Equal(int64(2), eventCount) } +func (suite *DashboardSuggestionsHTTPSuite) TestEventsResolveActorDisplayDetails() { + sspID, controlKeys, _ := suite.seedScope([]string{"AC-1"}, []map[string]string{{"env": "prod"}}) + runID := suite.seedSuggestionRun(sspID) + catalogID, _ := suite.parseControlKey(controlKeys[0]) + labels := map[string]string{"env": "prod"} + hash := suggestionrel.CanonicalLabelSetHash(labels) + suggestion := suite.seedDashboardSuggestion(runID, sspID, catalogID, "AC-1", labels, hash, "prod evidence", 0.9) + + body := dashboardSuggestionDecisionRequest{IDs: []uuid.UUID{*suggestion.ID}} + rec, req := suite.req(http.MethodPost, fmt.Sprintf("/api/oscal/system-security-plans/%s/dashboard-suggestions/accept", sspID), body) + suite.server.E().ServeHTTP(rec, req) + suite.Require().Equal(http.StatusOK, rec.Code, rec.Body.String()) + + rec, req = suite.req(http.MethodGet, fmt.Sprintf("/api/oscal/system-security-plans/%s/dashboard-suggestions/%s/events", sspID, suggestion.ID), nil) + suite.server.E().ServeHTTP(rec, req) + suite.Require().Equal(http.StatusOK, rec.Code, rec.Body.String()) + + var eventsResponse apihandler.GenericDataListResponse[dashboardSuggestionEventResponse] + suite.Require().NoError(json.Unmarshal(rec.Body.Bytes(), &eventsResponse)) + + var acceptedEvent *dashboardSuggestionEventResponse + for i := range eventsResponse.Data { + if eventsResponse.Data[i].EventType == string(suggestionrel.DashboardSuggestionEventTypeAccepted) { + acceptedEvent = &eventsResponse.Data[i] + break + } + } + suite.Require().NotNil(acceptedEvent, "expected an accepted event") + suite.Require().NotNil(acceptedEvent.ActorUserID, "accepted event should record the actor") + suite.Require().NotNil(acceptedEvent.Actor, "accepted event should resolve actor details") + suite.Equal(acceptedEvent.ActorUserID.String(), acceptedEvent.Actor.ID) + suite.Equal("Dummy User", acceptedEvent.Actor.Name) +} + func (suite *DashboardSuggestionsHTTPSuite) TestRejectPersistsDecisionAndWritesEvents() { sspID, controlKeys, _ := suite.seedScope([]string{"AC-1"}, []map[string]string{{"env": "prod"}}) runID := suite.seedSuggestionRun(sspID) @@ -475,10 +509,12 @@ func (suite *DashboardSuggestionsHTTPSuite) TestListSuggestionsAndEventsScopeByS rec, req = suite.req(http.MethodGet, fmt.Sprintf("/api/oscal/system-security-plans/%s/dashboard-suggestions/%s/events", sspID, suggestion.ID), nil) suite.server.E().ServeHTTP(rec, req) suite.Equal(http.StatusOK, rec.Code, rec.Body.String()) - var eventsResponse apihandler.GenericDataListResponse[suggestionrel.DashboardSuggestionEvent] + var eventsResponse apihandler.GenericDataListResponse[dashboardSuggestionEventResponse] suite.Require().NoError(json.Unmarshal(rec.Body.Bytes(), &eventsResponse)) suite.Require().Len(eventsResponse.Data, 1) suite.Equal(string(suggestionrel.DashboardSuggestionEventTypeSuggestionCreated), eventsResponse.Data[0].EventType) + // This event has no actor, so no actor details should be resolved. + suite.Nil(eventsResponse.Data[0].Actor) rec, req = suite.req(http.MethodGet, fmt.Sprintf("/api/oscal/system-security-plans/%s/dashboard-suggestions/%s/events", otherSSPID, suggestion.ID), nil) suite.server.E().ServeHTTP(rec, req) diff --git a/internal/api/handler/users.go b/internal/api/handler/users.go index 5acd25ef..e6c0a659 100644 --- a/internal/api/handler/users.go +++ b/internal/api/handler/users.go @@ -184,7 +184,7 @@ func (h *UserHandler) ListSelectableUsers(ctx echo.Context) error { responses = append(responses, selectableUserResponse{ ID: user.ID.String(), - DisplayName: userDisplayName(user), + DisplayName: UserDisplayName(user), }) } @@ -271,7 +271,7 @@ func (h *UserHandler) GetPublicUser(ctx echo.Context) error { return ctx.JSON(200, GenericDataResponse[publicUserResponse]{ Data: publicUserResponse{ ID: user.ID.String(), - Name: userDisplayName(user), + Name: UserDisplayName(user), }, }) } @@ -299,7 +299,7 @@ func (h *UserHandler) attachAuthProvider(resp *userResponse) { resp.AuthProvider = &link.Provider } -func userDisplayName(user relational.User) string { +func UserDisplayName(user relational.User) string { if user.ID == nil { return "" } diff --git a/internal/api/handler/users_integration_test.go b/internal/api/handler/users_integration_test.go index b6b4a770..39a83085 100644 --- a/internal/api/handler/users_integration_test.go +++ b/internal/api/handler/users_integration_test.go @@ -107,7 +107,7 @@ func (suite *UserApiIntegrationSuite) TestGetPublicUser() { err = json.Unmarshal(rec.Body.Bytes(), &response) suite.Require().NoError(err, "Expected valid JSON response for GetPublicUser") suite.Require().Equal(existingUser.UUIDModel.ID.String(), response.Data.ID, "Expected matching user ID in response for GetPublicUser") - suite.Require().Equal(userDisplayName(existingUser), response.Data.Name, "Expected public user name to match the user's display name") + suite.Require().Equal(UserDisplayName(existingUser), response.Data.Name, "Expected public user name to match the user's display name") blankNameUser := relational.User{ Email: "blank-name-user@example.com", diff --git a/internal/service/llm/anthropic.go b/internal/service/llm/anthropic.go index a88ec699..47c03084 100644 --- a/internal/service/llm/anthropic.go +++ b/internal/service/llm/anthropic.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/http" + "strconv" "strings" "time" @@ -98,18 +99,32 @@ func (c *AnthropicClient) CompleteStructured(ctx context.Context, req Structured maxTokens = DefaultAnthropicMaxTokens } + // Build the user message as one or two text blocks. A non-empty cached + // prefix is sent first with its own cache breakpoint so the volatile tail in + // req.Prompt stays uncached. + userBlocks := make([]anthropic.ContentBlockParamUnion, 0, 2) + if req.CachedUserPrefix != "" { + prefix := anthropic.NewTextBlock(req.CachedUserPrefix) + prefix.OfText.CacheControl = cacheControlParam(req.CachedUserPrefixTTL) + userBlocks = append(userBlocks, prefix) + } + userBlocks = append(userBlocks, anthropic.NewTextBlock(req.Prompt)) + params := anthropic.MessageNewParams{ MaxTokens: int64(maxTokens), Model: anthropic.Model(c.model), Messages: []anthropic.MessageParam{ - anthropic.NewUserMessage(anthropic.NewTextBlock(req.Prompt)), + anthropic.NewUserMessage(userBlocks...), }, OutputConfig: anthropic.OutputConfigParam{ Format: anthropic.JSONOutputFormatParam{Schema: sanitizeAnthropicSchema(req.Schema)}, }, } if req.System != "" { - params.System = []anthropic.TextBlockParam{{Text: req.System}} + params.System = []anthropic.TextBlockParam{{ + Text: req.System, + CacheControl: cacheControlParam(req.SystemCacheTTL), + }} } msg, err := c.client.Messages.New(callCtx, params, option.WithRequestTimeout(c.requestTimeout), option.WithMaxRetries(0)) @@ -123,13 +138,32 @@ func (c *AnthropicClient) CompleteStructured(ctx context.Context, req Structured } return &StructuredResponse{ - Raw: raw, - Model: string(msg.Model), - InputTokens: int(msg.Usage.InputTokens), - OutputTokens: int(msg.Usage.OutputTokens), + Raw: raw, + Model: string(msg.Model), + InputTokens: int(msg.Usage.InputTokens), + OutputTokens: int(msg.Usage.OutputTokens), + CacheReadInputTokens: int(msg.Usage.CacheReadInputTokens), + CacheCreationInputTokens: int(msg.Usage.CacheCreationInputTokens), }, nil } +// cacheControlParam maps a CacheTTL to the SDK cache-control param. The empty +// TTL returns the zero value, which the SDK omits (no cache breakpoint). +func cacheControlParam(ttl CacheTTL) anthropic.CacheControlEphemeralParam { + switch ttl { + case CacheTTL5m: + cc := anthropic.NewCacheControlEphemeralParam() + cc.TTL = anthropic.CacheControlEphemeralTTLTTL5m + return cc + case CacheTTL1h: + cc := anthropic.NewCacheControlEphemeralParam() + cc.TTL = anthropic.CacheControlEphemeralTTLTTL1h + return cc + default: + return anthropic.CacheControlEphemeralParam{} + } +} + func structuredRawJSON(msg *anthropic.Message) (json.RawMessage, error) { if msg == nil { return nil, fmt.Errorf("%w: empty provider response", ErrInvalidOutput) @@ -184,7 +218,7 @@ func mapAnthropicError(ctx context.Context, callCtx context.Context, err error) // sending; malformed or otherwise rejected schemas are still invalid input. return fmt.Errorf("%w: %v", ErrInvalidOutput, err) case http.StatusTooManyRequests: - return fmt.Errorf("%w: %v", ErrRateLimited, err) + return &RateLimitError{RetryAfter: retryAfterFromResponse(apiErr.Response), Err: err} case 529: return fmt.Errorf("%w: %v", ErrOverloaded, err) case http.StatusRequestTimeout, http.StatusConflict: @@ -210,3 +244,25 @@ func mapAnthropicError(ctx context.Context, callCtx context.Context, err error) } return fmt.Errorf("%w: %v", ErrOverloaded, err) } + +// retryAfterFromResponse reads the Retry-After header (integer seconds or an +// HTTP-date) from a 429 response. It returns 0 when the header is absent or +// unparseable so callers can fall back to a default backoff. +func retryAfterFromResponse(resp *http.Response) time.Duration { + if resp == nil { + return 0 + } + value := strings.TrimSpace(resp.Header.Get("Retry-After")) + if value == "" { + return 0 + } + if secs, err := strconv.Atoi(value); err == nil && secs >= 0 { + return time.Duration(secs) * time.Second + } + if t, err := http.ParseTime(value); err == nil { + if d := time.Until(t); d > 0 { + return d + } + } + return 0 +} diff --git a/internal/service/llm/anthropic_test.go b/internal/service/llm/anthropic_test.go index aff4d24b..f5492cd4 100644 --- a/internal/service/llm/anthropic_test.go +++ b/internal/service/llm/anthropic_test.go @@ -203,6 +203,97 @@ func TestAnthropicClientStructuredOutputSchemaPassthrough(t *testing.T) { require.Equal(t, float64(2), schema["properties"].(map[string]any)["answer"].(map[string]any)["minLength"]) } +func TestAnthropicClientRateLimitCarriesRetryAfter(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "12") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"type":"error","error":{"type":"rate_limit_error","message":"slow down"}}`)) + })) + t.Cleanup(server.Close) + + client := NewAnthropicClient(AnthropicConfig{ + Enabled: true, + APIKey: "test-key", + Model: "claude-test", + BaseURL: server.URL, + RequestTimeout: time.Second, + }) + + _, err := client.CompleteStructured(context.Background(), StructuredRequest{ + Prompt: "prompt", + Schema: map[string]any{"type": "object"}, + MaxTokens: 64, + }) + + require.ErrorIs(t, err, ErrRateLimited) + var rateLimit *RateLimitError + require.ErrorAs(t, err, &rateLimit) + require.Equal(t, 12*time.Second, rateLimit.RetryAfter) +} + +func TestAnthropicClientAppliesCacheControl(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var payload map[string]any + require.NoError(t, json.NewDecoder(r.Body).Decode(&payload)) + + system := payload["system"].([]any) + require.Len(t, system, 1) + systemCache := system[0].(map[string]any)["cache_control"].(map[string]any) + require.Equal(t, "ephemeral", systemCache["type"]) + require.Equal(t, "1h", systemCache["ttl"]) + + messages := payload["messages"].([]any) + require.Len(t, messages, 1) + content := messages[0].(map[string]any)["content"].([]any) + require.Len(t, content, 2) + + prefixCache := content[0].(map[string]any)["cache_control"].(map[string]any) + require.Equal(t, "1h", prefixCache["ttl"]) + require.Equal(t, "controls", content[0].(map[string]any)["text"]) + + // The volatile tail must stay uncached. + _, volatileHasCache := content[1].(map[string]any)["cache_control"] + require.False(t, volatileHasCache) + require.Equal(t, "labels", content[1].(map[string]any)["text"]) + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"msg_test", + "type":"message", + "role":"assistant", + "model":"claude-test", + "content":[{"type":"text","text":"{}"}], + "stop_reason":"end_turn", + "stop_sequence":"", + "usage":{"input_tokens":1,"output_tokens":1,"cache_read_input_tokens":40,"cache_creation_input_tokens":20} + }`)) + })) + t.Cleanup(server.Close) + + client := NewAnthropicClient(AnthropicConfig{ + Enabled: true, + APIKey: "test-key", + Model: "claude-test", + BaseURL: server.URL, + RequestTimeout: time.Second, + }) + + resp, err := client.CompleteStructured(context.Background(), StructuredRequest{ + System: "sys", + SystemCacheTTL: CacheTTL1h, + CachedUserPrefix: "controls", + CachedUserPrefixTTL: CacheTTL1h, + Prompt: "labels", + Schema: map[string]any{"type": "object"}, + MaxTokens: 64, + }) + + require.NoError(t, err) + require.Equal(t, 40, resp.CacheReadInputTokens) + require.Equal(t, 20, resp.CacheCreationInputTokens) +} + func TestSanitizeAnthropicSchema(t *testing.T) { schema := map[string]any{ "type": "object", diff --git a/internal/service/llm/client.go b/internal/service/llm/client.go index 8d51841a..f1058f28 100644 --- a/internal/service/llm/client.go +++ b/internal/service/llm/client.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "errors" + "fmt" + "time" ) type Client interface { @@ -12,11 +14,30 @@ type Client interface { CompleteStructured(ctx context.Context, req StructuredRequest) (*StructuredResponse, error) } +// CacheTTL controls prompt-cache breakpoints. The empty value disables caching +// for that block; "5m" and "1h" map to Anthropic's ephemeral cache lifetimes. +type CacheTTL string + +const ( + CacheTTLNone CacheTTL = "" + CacheTTL5m CacheTTL = "5m" + CacheTTL1h CacheTTL = "1h" +) + type StructuredRequest struct { System string Prompt string Schema map[string]any // JSON Schema for the output object MaxTokens int + + // Prompt caching (additive; zero values preserve the uncached behaviour). + // SystemCacheTTL sets a cache breakpoint at the end of the system block. + SystemCacheTTL CacheTTL + // CachedUserPrefix, when non-empty, is emitted as a user text block before + // Prompt with its own cache breakpoint. Use it for the large, stable portion + // of the user content so the volatile tail in Prompt stays uncached. + CachedUserPrefix string + CachedUserPrefixTTL CacheTTL } type StructuredResponse struct { @@ -24,6 +45,10 @@ type StructuredResponse struct { Model string InputTokens int OutputTokens int + // Prompt-cache accounting; cache reads do not count against ITPM rate limits + // for most models, so these are useful for verifying caching is effective. + CacheReadInputTokens int + CacheCreationInputTokens int } var ( @@ -33,3 +58,20 @@ var ( ErrOverloaded = errors.New("llm: provider overloaded") ErrInvalidOutput = errors.New("llm: output did not match schema") ) + +// RateLimitError wraps a provider 429 and carries the Retry-After hint (zero +// when the provider did not supply one). It unwraps to ErrRateLimited so +// existing errors.Is(err, ErrRateLimited) checks keep working. +type RateLimitError struct { + RetryAfter time.Duration + Err error +} + +func (e *RateLimitError) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s: %v", ErrRateLimited.Error(), e.Err) + } + return ErrRateLimited.Error() +} + +func (e *RateLimitError) Unwrap() error { return ErrRateLimited } diff --git a/internal/service/relational/suggestions/core_test.go b/internal/service/relational/suggestions/core_test.go index 395ace2c..29a6f1e5 100644 --- a/internal/service/relational/suggestions/core_test.go +++ b/internal/service/relational/suggestions/core_test.go @@ -460,10 +460,13 @@ func TestPromptGolden(t *testing.T) { LabelKeyDocs: []LabelKeyDocInput{{Key: "repo", Description: "Repository name"}}, SameSSPFilters: []VisibleFilterInput{{ID: uuid.MustParse("22222222-2222-2222-2222-222222222222"), Name: "payments-api"}}, } - gotPrompt, err := RenderPrompt(input) + rendered, err := RenderPrompt(input) require.NoError(t, err) - got := SystemPrompt + "\n---USER---\n" + gotPrompt + got := rendered.System + "\n---CONTROLS---\n" + rendered.Controls + "\n---LABELSETS---\n" + rendered.Volatile path := filepath.Join("testdata", "prompt_"+PromptVersion+".golden") + if os.Getenv("UPDATE_GOLDEN") != "" { + require.NoError(t, os.WriteFile(path, []byte(got+"\n"), 0o644)) + } expected, err := os.ReadFile(path) require.NoError(t, err) require.Equal(t, strings.TrimSuffix(string(expected), "\n"), got) diff --git a/internal/service/relational/suggestions/prompt.go b/internal/service/relational/suggestions/prompt.go index fb29c14b..d4dc1ca2 100644 --- a/internal/service/relational/suggestions/prompt.go +++ b/internal/service/relational/suggestions/prompt.go @@ -87,17 +87,24 @@ func OutputSchema() map[string]any { } } -const userPromptTemplate = `Prompt version: {{.PromptVersion}} +// The prompt is split into three segments so the request can place prompt-cache +// breakpoints at the boundaries: +// +// System — the instructions plus run-stable context (system context, +// label-key docs, dashboards). Identical for every cell of an SSP. +// Controls — the control chunk for this cell. Identical across the label-set +// cells that share a control-row, and the token-heavy dimension. +// Volatile — the per-cell evidence label-sets and the closing instruction. +// +// Keeping the volatile content last means the System and Controls prefixes are +// byte-stable and therefore cacheable. +const systemBlockTemplate = `{{.SystemPrompt}} + +Prompt version: {{.PromptVersion}} System context: {{json .Input.SystemContext}} -Controls: -{{json .Input.Controls}} - -Evidence label-sets: -{{json .Input.LabelSets}} - Label-key documentation: {{json .Input.LabelKeyDocs}} @@ -105,12 +112,26 @@ This SSP's extendable dashboards: {{json .Input.SameSSPFilters}} Global dashboard names: -{{json .Input.GlobalFilterNames}} +{{json .Input.GlobalFilterNames}}` + +const controlsBlockTemplate = `Controls: +{{json .Input.Controls}}` + +const volatileBlockTemplate = `Evidence label-sets: +{{json .Input.LabelSets}} Return JSON matching the provided schema.` -func RenderPrompt(input GatheredInput) (string, error) { - tmpl, err := template.New("dashboard-suggestion-prompt").Funcs(template.FuncMap{ +// RenderedPrompt holds the cacheable system/controls prefixes and the volatile +// tail for a single cell. +type RenderedPrompt struct { + System string + Controls string + Volatile string +} + +func renderPromptTemplate(name, text string, input GatheredInput) (string, error) { + tmpl, err := template.New(name).Funcs(template.FuncMap{ "json": func(value any) (string, error) { raw, err := json.MarshalIndent(value, "", " ") if err != nil { @@ -118,12 +139,13 @@ func RenderPrompt(input GatheredInput) (string, error) { } return string(raw), nil }, - }).Parse(userPromptTemplate) + }).Parse(text) if err != nil { return "", err } var buf bytes.Buffer err = tmpl.Execute(&buf, map[string]any{ + "SystemPrompt": SystemPrompt, "PromptVersion": PromptVersion, "Input": input, }) @@ -132,3 +154,19 @@ func RenderPrompt(input GatheredInput) (string, error) { } return buf.String(), nil } + +func RenderPrompt(input GatheredInput) (RenderedPrompt, error) { + system, err := renderPromptTemplate("dashboard-suggestion-system", systemBlockTemplate, input) + if err != nil { + return RenderedPrompt{}, err + } + controls, err := renderPromptTemplate("dashboard-suggestion-controls", controlsBlockTemplate, input) + if err != nil { + return RenderedPrompt{}, err + } + volatile, err := renderPromptTemplate("dashboard-suggestion-volatile", volatileBlockTemplate, input) + if err != nil { + return RenderedPrompt{}, err + } + return RenderedPrompt{System: system, Controls: controls, Volatile: volatile}, nil +} diff --git a/internal/service/relational/suggestions/testdata/prompt_v2.golden b/internal/service/relational/suggestions/testdata/prompt_v2.golden index cf0f9f2c..bbeade7e 100644 --- a/internal/service/relational/suggestions/testdata/prompt_v2.golden +++ b/internal/service/relational/suggestions/testdata/prompt_v2.golden @@ -18,7 +18,7 @@ For every mapping, include proposed_filter_labels as a list of {"key","value"} p Use extend_filter with target_filter_id only when one of this plan's own dashboards has exactly the same proposed_filter_labels. Otherwise use new_filter with a short descriptive proposed_filter_name. Global dashboards are listed only to avoid duplicate names; never extend them. Only reference control_key and label_set_hash values present in the input, and only choose proposed_filter_labels from labels present on that evidence label-set. Reasoning must state both why the evidence satisfies the control and why it belongs to this system. Provide confidence from 0 to 1. ----USER--- + Prompt version: v2 System context: @@ -35,6 +35,25 @@ System context: ] } +Label-key documentation: +[ + { + "key": "repo", + "description": "Repository name" + } +] + +This SSP's extendable dashboards: +[ + { + "id": "22222222-2222-2222-2222-222222222222", + "name": "payments-api" + } +] + +Global dashboard names: +null +---CONTROLS--- Controls: [ { @@ -47,7 +66,7 @@ Controls: "implementation_text": "Uses payments-api." } ] - +---LABELSETS--- Evidence label-sets: [ { @@ -62,23 +81,4 @@ Evidence label-sets: } ] -Label-key documentation: -[ - { - "key": "repo", - "description": "Repository name" - } -] - -This SSP's extendable dashboards: -[ - { - "id": "22222222-2222-2222-2222-222222222222", - "name": "payments-api" - } -] - -Global dashboard names: -null - Return JSON matching the provided schema. diff --git a/internal/service/worker/dashboard_suggestion_worker.go b/internal/service/worker/dashboard_suggestion_worker.go index 58d4d89f..77815798 100644 --- a/internal/service/worker/dashboard_suggestion_worker.go +++ b/internal/service/worker/dashboard_suggestion_worker.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "math/rand/v2" "time" "github.com/compliance-framework/api/internal/config" @@ -29,6 +30,18 @@ const ( dashboardSuggestionCellStatusFailed = "failed" ) +const ( + // defaultRateLimitSnooze is used when a 429 carries no Retry-After hint. + // The ITPM limit replenishes per minute, so a sub-minute base is reasonable. + defaultRateLimitSnooze = 30 * time.Second + // rateLimitSnoozeJitter is the upper bound of random delay added to each + // snooze to de-synchronise concurrently throttled workers. + rateLimitSnoozeJitter = 15 * time.Second + // maxRateLimitSnoozes caps how many times a cell may be snoozed before it is + // failed, so a persistently throttled run cannot snooze forever. + maxRateLimitSnoozes = 20 +) + type DashboardSuggestionWorker struct { river.WorkerDefaults[DashboardSuggestionCellArgs] @@ -80,14 +93,22 @@ func (w *DashboardSuggestionWorker) Work(ctx context.Context, job *river.Job[Das missingLabelSets = 0 } - prompt, err := suggestionrel.RenderPrompt(gathered) + rendered, err := suggestionrel.RenderPrompt(gathered) if err != nil { return w.handleAttemptFailure(ctx, job, err) } - response, err := w.completeWithOneRetry(ctx, prompt) + response, err := w.completeWithRetry(ctx, rendered) if err != nil { - if isNonRetryableLLMError(err) { + // A rate limit skips the cell and requeues it after the provider's + // Retry-After (or a default) plus jitter, without consuming a regular + // attempt or marking the cell failed. + if delay, ok := rateLimitSnooze(err, job.Attempt); ok { + return river.JobSnooze(delay) + } + // Non-retryable errors, and rate limits that have exhausted their snooze + // budget, fail the cell and cancel the job. + if isNonRetryableLLMError(err) || errors.Is(err, llm.ErrRateLimited) { if markErr := w.failCellAndMaybeFinalize(ctx, job.Args, err); markErr != nil { return markErr } @@ -96,6 +117,20 @@ func (w *DashboardSuggestionWorker) Work(ctx context.Context, job *river.Job[Das return w.handleAttemptFailure(ctx, job, err) } + if w.logger != nil { + // Surface prompt-cache accounting so cache effectiveness is observable + // without relying on the provider dashboard. cache_read_input_tokens are + // the "cache hits"; cache_creation_input_tokens are first-time writes. + w.logger.Infow("dashboard suggestion cell llm usage", + "run_id", job.Args.RunID, + "cell_index", cell.CellIndex, + "input_tokens", response.InputTokens, + "cache_creation_input_tokens", response.CacheCreationInputTokens, + "cache_read_input_tokens", response.CacheReadInputTokens, + "output_tokens", response.OutputTokens, + ) + } + rawCount, err := rawMappingCount(response.Raw) if err != nil { err = fmt.Errorf("%w: %v", llm.ErrInvalidOutput, err) @@ -170,14 +205,19 @@ func (w *DashboardSuggestionWorker) loadPendingCellAndStartRun(ctx context.Conte return run, cell, true, nil } -func (w *DashboardSuggestionWorker) completeWithOneRetry(ctx context.Context, prompt string) (*llm.StructuredResponse, error) { +func (w *DashboardSuggestionWorker) completeWithRetry(ctx context.Context, rendered suggestionrel.RenderedPrompt) (*llm.StructuredResponse, error) { requestTimeout := dashboardSuggestionRequestTimeout(w.aiCfg) + // Two prompt-cache breakpoints (1h TTL): the system block (run-stable) and + // the controls prefix (row-stable). The volatile label-sets stay uncached. req := llm.StructuredRequest{ - System: suggestionrel.SystemPrompt, - Prompt: prompt, - Schema: suggestionrel.OutputSchema(), - MaxTokens: llm.DefaultAnthropicMaxTokens, + System: rendered.System, + SystemCacheTTL: llm.CacheTTL1h, + CachedUserPrefix: rendered.Controls, + CachedUserPrefixTTL: llm.CacheTTL1h, + Prompt: rendered.Volatile, + Schema: suggestionrel.OutputSchema(), + MaxTokens: llm.DefaultAnthropicMaxTokens, } var lastErr error @@ -189,13 +229,38 @@ func (w *DashboardSuggestionWorker) completeWithOneRetry(ctx context.Context, pr return response, nil } lastErr = err - if !isRetryableLLMError(err) { + if !isInlineRetryableLLMError(err) { break } } return nil, lastErr } +// rateLimitSnooze decides whether a failed completion should be snoozed. ok is +// false when err is not a rate limit, or when the snooze budget for this cell is +// exhausted (so the caller can fail it instead of snoozing forever). +func rateLimitSnooze(err error, attempt int) (time.Duration, bool) { + var rateLimit *llm.RateLimitError + if !errors.As(err, &rateLimit) { + return 0, false + } + if attempt >= maxRateLimitSnoozes { + return 0, false + } + return snoozeDelay(rateLimit.RetryAfter), true +} + +// snoozeDelay returns how long to defer a rate-limited cell: the provider's +// Retry-After when available (otherwise a default), plus random jitter so the +// concurrent workers that all hit the limit at once do not wake in lockstep. +func snoozeDelay(retryAfter time.Duration) time.Duration { + base := retryAfter + if base <= 0 { + base = defaultRateLimitSnooze + } + return base + rand.N(rateLimitSnoozeJitter) +} + func dashboardSuggestionRequestTimeout(aiCfg *config.AIConfig) time.Duration { if aiCfg != nil && aiCfg.RequestTimeout > 0 { return aiCfg.RequestTimeout @@ -407,8 +472,15 @@ func rawMappingCount(raw json.RawMessage) (int, error) { return len(decoded.Mappings), nil } -func isRetryableLLMError(err error) bool { - return errors.Is(err, llm.ErrRateLimited) || errors.Is(err, llm.ErrOverloaded) +// isInlineRetryableLLMError reports whether an immediate in-process retry is +// worthwhile. Overloaded errors are transient and worth one quick retry; rate +// limits are deliberately excluded so they bubble up to be snoozed instead of +// burning another call into the same throttled window. +func isInlineRetryableLLMError(err error) bool { + if errors.Is(err, llm.ErrRateLimited) { + return false + } + return errors.Is(err, llm.ErrOverloaded) } func isNonRetryableLLMError(err error) bool { diff --git a/internal/service/worker/dashboard_suggestion_worker_integration_test.go b/internal/service/worker/dashboard_suggestion_worker_integration_test.go index 09cfd85e..4056bcf2 100644 --- a/internal/service/worker/dashboard_suggestion_worker_integration_test.go +++ b/internal/service/worker/dashboard_suggestion_worker_integration_test.go @@ -93,6 +93,28 @@ func (suite *DashboardSuggestionWorkerIntegrationSuite) TestTwoByTwoGridConcurre suite.Equal(suggestionCount, afterRerunCount) } +func (suite *DashboardSuggestionWorkerIntegrationSuite) TestRateLimitSnoozesCellWithoutFailing() { + ctx := context.Background() + runID, cells := suite.seedTwoByTwoSuggestionRun() + client := &llm.FakeClient{Err: &llm.RateLimitError{RetryAfter: 5 * time.Second}} + worker := NewDashboardSuggestionWorker(suite.DB, client, &config.AIConfig{RequestTimeout: 120 * time.Second, MaxSuggestionsPerRun: 10}, zap.NewNop().Sugar()) + + err := worker.Work(ctx, dashboardSuggestionIntegrationJob(runID, cells[0].CellIndex)) + + // The cell is snoozed (not failed) and the run is not finalized as failed. + var snooze *rivertype.JobSnoozeError + suite.Require().ErrorAs(err, &snooze) + suite.Require().GreaterOrEqual(snooze.Duration, 5*time.Second) + + var cell suggestionrel.DashboardSuggestionRunCell + suite.Require().NoError(suite.DB.First(&cell, "run_id = ? AND cell_index = ?", runID, cells[0].CellIndex).Error) + suite.Require().Equal(dashboardSuggestionCellStatusPending, cell.Status) + + var run suggestionrel.DashboardSuggestionRun + suite.Require().NoError(suite.DB.First(&run, "id = ?", runID).Error) + suite.Require().NotEqual(dashboardSuggestionRunStatusFailed, run.Status) +} + func (suite *DashboardSuggestionWorkerIntegrationSuite) seedTwoByTwoSuggestionRun() (uuid.UUID, []suggestionrel.GridCell) { sspID := uuid.New() runID := uuid.New() @@ -189,8 +211,11 @@ func (c *promptMappingClient) CompleteStructured(ctx context.Context, req llm.St c.requests++ c.mu.Unlock() - controlKey := firstPromptValue(req.Prompt, `"control_key": "([^"]+)"`) - labelSetHash := firstPromptValue(req.Prompt, `"hash": "([^"]+)"`) + // Controls and label-sets are split across cache segments now, so search the + // whole rendered request rather than just the volatile tail. + combined := req.System + "\n" + req.CachedUserPrefix + "\n" + req.Prompt + controlKey := firstPromptValue(combined, `"control_key": "([^"]+)"`) + labelSetHash := firstPromptValue(combined, `"hash": "([^"]+)"`) raw, err := json.Marshal(suggestionrel.RawMappings{Mappings: []suggestionrel.RawMapping{ { ControlKey: controlKey, diff --git a/internal/service/worker/dashboard_suggestion_worker_test.go b/internal/service/worker/dashboard_suggestion_worker_test.go index 1b42ba6d..11510f36 100644 --- a/internal/service/worker/dashboard_suggestion_worker_test.go +++ b/internal/service/worker/dashboard_suggestion_worker_test.go @@ -213,9 +213,11 @@ func TestDashboardSuggestionWorkerCompleteCellCountsMissingLabelSetsAsRejected(t } func TestDashboardSuggestionWorkerLLMRetryAndNonRetryableFailure(t *testing.T) { - t.Run("retryable error retries once in job", func(t *testing.T) { + rendered := suggestionrel.RenderedPrompt{System: "system", Controls: "controls", Volatile: "labels"} + + t.Run("overloaded error retries once in job", func(t *testing.T) { fake := &llm.FakeClient{ - Errors: []error{llm.ErrRateLimited}, + Errors: []error{llm.ErrOverloaded}, Responses: []*llm.StructuredResponse{ nil, {Raw: json.RawMessage(`{"mappings":[]}`), Model: "fake", InputTokens: 3, OutputTokens: 5}, @@ -223,22 +225,98 @@ func TestDashboardSuggestionWorkerLLMRetryAndNonRetryableFailure(t *testing.T) { } worker := NewDashboardSuggestionWorker(nil, fake, config.DefaultAIConfig(), zap.NewNop().Sugar()) - response, err := worker.completeWithOneRetry(context.Background(), "prompt") + response, err := worker.completeWithRetry(context.Background(), rendered) require.NoError(t, err) require.Equal(t, 3, response.InputTokens) require.Equal(t, 5, response.OutputTokens) require.Len(t, fake.Requests, 2) }) + t.Run("rate limit is not retried inline", func(t *testing.T) { + fake := &llm.FakeClient{Err: &llm.RateLimitError{RetryAfter: 5 * time.Second}} + worker := NewDashboardSuggestionWorker(nil, fake, config.DefaultAIConfig(), zap.NewNop().Sugar()) + + _, err := worker.completeWithRetry(context.Background(), rendered) + require.ErrorIs(t, err, llm.ErrRateLimited) + require.Len(t, fake.Requests, 1) + }) + t.Run("non retryable error fails cell immediately", func(t *testing.T) { fake := &llm.FakeClient{Err: llm.ErrAuth} worker := NewDashboardSuggestionWorker(nil, fake, config.DefaultAIConfig(), zap.NewNop().Sugar()) - _, err := worker.completeWithOneRetry(context.Background(), "prompt") + _, err := worker.completeWithRetry(context.Background(), rendered) require.ErrorIs(t, err, llm.ErrAuth) require.True(t, isNonRetryableLLMError(err)) require.Len(t, fake.Requests, 1) }) + + t.Run("cached request carries both cache breakpoints", func(t *testing.T) { + fake := &llm.FakeClient{Raw: json.RawMessage(`{"mappings":[]}`)} + worker := NewDashboardSuggestionWorker(nil, fake, config.DefaultAIConfig(), zap.NewNop().Sugar()) + + _, err := worker.completeWithRetry(context.Background(), rendered) + require.NoError(t, err) + require.Len(t, fake.Requests, 1) + req := fake.Requests[0] + require.Equal(t, "system", req.System) + require.Equal(t, llm.CacheTTL1h, req.SystemCacheTTL) + require.Equal(t, "controls", req.CachedUserPrefix) + require.Equal(t, llm.CacheTTL1h, req.CachedUserPrefixTTL) + require.Equal(t, "labels", req.Prompt) + }) +} + +func TestSnoozeDelay(t *testing.T) { + t.Parallel() + + // With a Retry-After hint, the delay is at least the hint and within the + // jitter band above it. + for range 50 { + d := snoozeDelay(12 * time.Second) + require.GreaterOrEqual(t, d, 12*time.Second) + require.Less(t, d, 12*time.Second+rateLimitSnoozeJitter) + } + + // Without a hint, it falls back to the default base plus jitter. + for range 50 { + d := snoozeDelay(0) + require.GreaterOrEqual(t, d, defaultRateLimitSnooze) + require.Less(t, d, defaultRateLimitSnooze+rateLimitSnoozeJitter) + } +} + +func TestRateLimitSnooze(t *testing.T) { + t.Parallel() + + t.Run("rate limit within budget snoozes", func(t *testing.T) { + delay, ok := rateLimitSnooze(&llm.RateLimitError{RetryAfter: 7 * time.Second}, 1) + require.True(t, ok) + require.GreaterOrEqual(t, delay, 7*time.Second) + }) + + t.Run("wrapped rate limit is detected", func(t *testing.T) { + wrapped := river.JobCancel(&llm.RateLimitError{RetryAfter: 2 * time.Second}) + _, ok := rateLimitSnooze(wrapped, 1) + require.True(t, ok) + }) + + t.Run("exhausted budget does not snooze", func(t *testing.T) { + _, ok := rateLimitSnooze(&llm.RateLimitError{RetryAfter: time.Second}, maxRateLimitSnoozes) + require.False(t, ok) + }) + + t.Run("non rate limit does not snooze", func(t *testing.T) { + _, ok := rateLimitSnooze(llm.ErrAuth, 1) + require.False(t, ok) + }) + + // The bare ErrRateLimited sentinel is not a *RateLimitError, so it is not + // snoozed (it carries no Retry-After) and falls through to normal handling. + t.Run("bare sentinel does not snooze", func(t *testing.T) { + _, ok := rateLimitSnooze(llm.ErrRateLimited, 1) + require.False(t, ok) + }) } func newDashboardSuggestionWorkerTestDB(t *testing.T) *gorm.DB {