diff --git a/internal/api/handler/oscal/system_security_plans.go b/internal/api/handler/oscal/system_security_plans.go index f3eddb3d..4c9290e5 100644 --- a/internal/api/handler/oscal/system_security_plans.go +++ b/internal/api/handler/oscal/system_security_plans.go @@ -30,6 +30,7 @@ import ( // Defined here to avoid a circular import between the oscal handler and worker packages. type SSPJobEnqueuer interface { EnqueueOrphanedRiskCleanup(ctx context.Context, sspID uuid.UUID, oldProfileID, newProfileID *uuid.UUID) error + EnqueueDashboardSuggestionCells(ctx context.Context, runID uuid.UUID, cellCount int) error } // profileSummary is a lightweight DTO returned by the multi-profile list endpoint. diff --git a/internal/service/llm/fake.go b/internal/service/llm/fake.go index b5201d42..c2b6f7ee 100644 --- a/internal/service/llm/fake.go +++ b/internal/service/llm/fake.go @@ -3,19 +3,34 @@ package llm import ( "context" "encoding/json" + "sync" ) type FakeClient struct { + mu sync.Mutex + Raw json.RawMessage Model string InputTokens int OutputTokens int Err error Requests []StructuredRequest + Responses []*StructuredResponse + Errors []error } func (f *FakeClient) CompleteStructured(ctx context.Context, req StructuredRequest) (*StructuredResponse, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.Requests = append(f.Requests, req) + index := len(f.Requests) - 1 + if index < len(f.Errors) && f.Errors[index] != nil { + return nil, f.Errors[index] + } + if index < len(f.Responses) && f.Responses[index] != nil { + return f.Responses[index], nil + } if f.Err != nil { return nil, f.Err } diff --git a/internal/service/relational/suggestions/service.go b/internal/service/relational/suggestions/service.go index 423cc266..e0ea8678 100644 --- a/internal/service/relational/suggestions/service.go +++ b/internal/service/relational/suggestions/service.go @@ -137,75 +137,82 @@ func (g GatheredInput) CellInput() CellInput { } func (s *SuggestionService) InsertValidatedMappings(runID uuid.UUID, sspID uuid.UUID, promptVersion string, mappings []ValidatedMapping, maxSuggestionsPerRun int) (InsertMappingsResult, error) { + result := InsertMappingsResult{} + err := s.db.Transaction(func(tx *gorm.DB) error { + var err error + result, err = s.InsertValidatedMappingsTx(tx, runID, sspID, promptVersion, mappings, maxSuggestionsPerRun) + return err + }) + return result, err +} + +func (s *SuggestionService) InsertValidatedMappingsTx(tx *gorm.DB, runID uuid.UUID, sspID uuid.UUID, promptVersion string, mappings []ValidatedMapping, maxSuggestionsPerRun int) (InsertMappingsResult, error) { if maxSuggestionsPerRun <= 0 { maxSuggestionsPerRun = DefaultMaxSuggestionsPerRun } result := InsertMappingsResult{} - err := s.db.Transaction(func(tx *gorm.DB) error { - var run DashboardSuggestionRun - if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ? AND ssp_id = ?", runID, sspID).First(&run).Error; err != nil { - return err - } - capacity := maxSuggestionsPerRun - run.SuggestionCount + var run DashboardSuggestionRun + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ? AND ssp_id = ?", runID, sspID).First(&run).Error; err != nil { + return result, err + } + capacity := maxSuggestionsPerRun - run.SuggestionCount + if capacity <= 0 { + result.Capped = len(mappings) + return result, nil + } + + for _, mapping := range mappings { if capacity <= 0 { - result.Capped = len(mappings) - return nil + result.Capped++ + continue } - - for _, mapping := range mappings { - if capacity <= 0 { - result.Capped++ - continue - } - excluded, err := s.mappingExcluded(tx, sspID, promptVersion, mapping) - if err != nil { - return err - } - if excluded { - result.Excluded++ - continue - } - catalogID, controlID, err := ParseControlKey(mapping.ControlKey) - if err != nil { - return err - } - suggestion := DashboardSuggestion{ - RunID: runID, - SSPID: sspID, - ControlCatalogID: catalogID, - ControlID: controlID, - LabelSet: labelsToJSONMap(mapping.LabelSet), - LabelSetHash: mapping.LabelSetHash, - TargetFilterID: mapping.TargetFilterID, - ProposedFilterName: mapping.ProposedFilterName, - Reasoning: mapping.Reasoning, - Confidence: mapping.Confidence, - Status: DashboardSuggestionStatusPending, - } - create := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&suggestion) - if create.Error != nil { - return create.Error - } - if create.RowsAffected == 0 { - result.Excluded++ - continue - } - result.Inserted++ - capacity-- - if err := tx.Model(&DashboardSuggestionRun{}). - Where("id = ?", runID). - UpdateColumn("suggestion_count", gorm.Expr("suggestion_count + 1")).Error; err != nil { - return err - } - if err := createSuggestionEvent(tx, &suggestion, DashboardSuggestionEventTypeSuggestionCreated, nil, datatypes.JSONMap{ - "prompt_version": promptVersion, - }); err != nil { - return err - } + excluded, err := s.mappingExcluded(tx, sspID, promptVersion, mapping) + if err != nil { + return result, err } - return nil - }) - return result, err + if excluded { + result.Excluded++ + continue + } + catalogID, controlID, err := ParseControlKey(mapping.ControlKey) + if err != nil { + return result, err + } + suggestion := DashboardSuggestion{ + RunID: runID, + SSPID: sspID, + ControlCatalogID: catalogID, + ControlID: controlID, + LabelSet: labelsToJSONMap(mapping.LabelSet), + LabelSetHash: mapping.LabelSetHash, + TargetFilterID: mapping.TargetFilterID, + ProposedFilterName: mapping.ProposedFilterName, + Reasoning: mapping.Reasoning, + Confidence: mapping.Confidence, + Status: DashboardSuggestionStatusPending, + } + create := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&suggestion) + if create.Error != nil { + return result, create.Error + } + if create.RowsAffected == 0 { + result.Excluded++ + continue + } + result.Inserted++ + capacity-- + if err := tx.Model(&DashboardSuggestionRun{}). + Where("id = ?", runID). + UpdateColumn("suggestion_count", gorm.Expr("suggestion_count + 1")).Error; err != nil { + return result, err + } + if err := createSuggestionEvent(tx, &suggestion, DashboardSuggestionEventTypeSuggestionCreated, nil, datatypes.JSONMap{ + "prompt_version": promptVersion, + }); err != nil { + return result, err + } + } + return result, nil } func (s *SuggestionService) Accept(sspID uuid.UUID, suggestionIDs []uuid.UUID, actorID uuid.UUID) error { @@ -463,6 +470,21 @@ func createSuggestionEvent(tx *gorm.DB, suggestion *DashboardSuggestion, eventTy return tx.Create(&event).Error } +func CreateRunEventTx(tx *gorm.DB, run *DashboardSuggestionRun, eventType DashboardSuggestionEventType, payload datatypes.JSONMap) error { + snapshot, err := runSnapshot(run) + if err != nil { + return err + } + event := DashboardSuggestionEvent{ + RunID: run.ID, + EventType: string(eventType), + OccurredAt: time.Now().UTC(), + Payload: payload, + Snapshot: snapshot, + } + return tx.Create(&event).Error +} + func suggestionSnapshot(suggestion *DashboardSuggestion) (datatypes.JSONMap, error) { raw, err := json.Marshal(suggestion) if err != nil { @@ -475,6 +497,18 @@ func suggestionSnapshot(suggestion *DashboardSuggestion) (datatypes.JSONMap, err return snapshot, nil } +func runSnapshot(run *DashboardSuggestionRun) (datatypes.JSONMap, error) { + raw, err := json.Marshal(run) + if err != nil { + return nil, err + } + var snapshot datatypes.JSONMap + if err := json.Unmarshal(raw, &snapshot); err != nil { + return nil, err + } + return snapshot, nil +} + func selectDimension(requested []string, all []string, allSet map[string]struct{}) ([]string, []string) { if len(requested) == 0 { return append([]string(nil), all...), nil diff --git a/internal/service/worker/dashboard_suggestion_job_types.go b/internal/service/worker/dashboard_suggestion_job_types.go new file mode 100644 index 00000000..1350d411 --- /dev/null +++ b/internal/service/worker/dashboard_suggestion_job_types.go @@ -0,0 +1,36 @@ +package worker + +import ( + "errors" + "time" + + "github.com/google/uuid" + "github.com/riverqueue/river" +) + +const ( + JobTypeDashboardSuggestionCell = "dashboard_suggestion_cell" + DashboardSuggestionQueue = "suggestion" + DashboardSuggestionMaxAttempts = 3 +) + +var ErrDashboardSuggestionWorkerDisabled = errors.New("dashboard suggestion worker is disabled") + +type DashboardSuggestionCellArgs struct { + RunID uuid.UUID `json:"run_id" river:"unique"` + CellIndex int `json:"cell_index" river:"unique"` +} + +func (DashboardSuggestionCellArgs) Kind() string { return JobTypeDashboardSuggestionCell } + +func (DashboardSuggestionCellArgs) Timeout() time.Duration { return 5 * time.Minute } + +func JobInsertOptionsForDashboardSuggestionCell() *river.InsertOpts { + return &river.InsertOpts{ + Queue: DashboardSuggestionQueue, + MaxAttempts: DashboardSuggestionMaxAttempts, + UniqueOpts: river.UniqueOpts{ + ByArgs: true, + }, + } +} diff --git a/internal/service/worker/dashboard_suggestion_worker.go b/internal/service/worker/dashboard_suggestion_worker.go new file mode 100644 index 00000000..58d4d89f --- /dev/null +++ b/internal/service/worker/dashboard_suggestion_worker.go @@ -0,0 +1,427 @@ +package worker + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/compliance-framework/api/internal/config" + "github.com/compliance-framework/api/internal/service/llm" + suggestionrel "github.com/compliance-framework/api/internal/service/relational/suggestions" + "github.com/google/uuid" + "github.com/riverqueue/river" + "go.uber.org/zap" + "gorm.io/datatypes" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +const ( + dashboardSuggestionRunStatusPending = "pending" + dashboardSuggestionRunStatusRunning = "running" + dashboardSuggestionRunStatusCompleted = "completed" + dashboardSuggestionRunStatusFailed = "failed" + + dashboardSuggestionCellStatusPending = "pending" + dashboardSuggestionCellStatusCompleted = "completed" + dashboardSuggestionCellStatusFailed = "failed" +) + +type DashboardSuggestionWorker struct { + river.WorkerDefaults[DashboardSuggestionCellArgs] + + db *gorm.DB + suggestionService *suggestionrel.SuggestionService + llmClient llm.Client + aiCfg *config.AIConfig + logger *zap.SugaredLogger +} + +func NewDashboardSuggestionWorker(db *gorm.DB, llmClient llm.Client, aiCfg *config.AIConfig, logger *zap.SugaredLogger) *DashboardSuggestionWorker { + return &DashboardSuggestionWorker{ + db: db, + suggestionService: suggestionrel.NewSuggestionService(db), + llmClient: llmClient, + aiCfg: aiCfg, + logger: logger, + } +} + +func (w *DashboardSuggestionWorker) Timeout(job *river.Job[DashboardSuggestionCellArgs]) time.Duration { + requestTimeout := dashboardSuggestionRequestTimeout(w.aiCfg) + timeout := 2*requestTimeout + 30*time.Second + if timeout < 5*time.Minute { + return 5 * time.Minute + } + return timeout +} + +func (w *DashboardSuggestionWorker) Work(ctx context.Context, job *river.Job[DashboardSuggestionCellArgs]) error { + run, cell, ok, err := w.loadPendingCellAndStartRun(ctx, job.Args) + if err != nil { + return err + } + if !ok { + return nil + } + + gathered, err := w.suggestionService.GatherCellInput(run.SSPID, suggestionrel.GridCell{ + CellIndex: cell.CellIndex, + ControlKeys: []string(cell.ControlKeys), + LabelSetHashes: []string(cell.LabelSetHashes), + }, suggestionrel.GatherOptions{}) + if err != nil { + return w.handleAttemptFailure(ctx, job, err) + } + missingLabelSets := len(cell.LabelSetHashes) - len(gathered.LabelSets) + if missingLabelSets < 0 { + missingLabelSets = 0 + } + + prompt, err := suggestionrel.RenderPrompt(gathered) + if err != nil { + return w.handleAttemptFailure(ctx, job, err) + } + + response, err := w.completeWithOneRetry(ctx, prompt) + if err != nil { + if isNonRetryableLLMError(err) { + if markErr := w.failCellAndMaybeFinalize(ctx, job.Args, err); markErr != nil { + return markErr + } + return river.JobCancel(err) + } + return w.handleAttemptFailure(ctx, job, err) + } + + rawCount, err := rawMappingCount(response.Raw) + if err != nil { + err = fmt.Errorf("%w: %v", llm.ErrInvalidOutput, err) + if markErr := w.failCellAndMaybeFinalize(ctx, job.Args, err); markErr != nil { + return markErr + } + return river.JobCancel(err) + } + + validation, err := suggestionrel.ValidateMappings(gathered.CellInput(), response.Raw) + if err != nil { + err = fmt.Errorf("%w: %v", llm.ErrInvalidOutput, err) + if markErr := w.failCellAndMaybeFinalize(ctx, job.Args, err); markErr != nil { + return markErr + } + return river.JobCancel(err) + } + + if err := w.completeCell(ctx, run, cell, response, validation, rawCount, missingLabelSets); err != nil { + return w.handleAttemptFailure(ctx, job, err) + } + return nil +} + +func (w *DashboardSuggestionWorker) loadPendingCellAndStartRun(ctx context.Context, args DashboardSuggestionCellArgs) (suggestionrel.DashboardSuggestionRun, suggestionrel.DashboardSuggestionRunCell, bool, error) { + var run suggestionrel.DashboardSuggestionRun + var cell suggestionrel.DashboardSuggestionRunCell + err := w.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where("run_id = ? AND cell_index = ?", args.RunID, args.CellIndex). + First(&cell).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return err + } + if cell.Status != dashboardSuggestionCellStatusPending { + return nil + } + if err := tx.Where("id = ?", args.RunID).First(&run).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return err + } + now := time.Now().UTC() + update := tx.Model(&suggestionrel.DashboardSuggestionRun{}). + Where("id = ? AND status = ?", args.RunID, dashboardSuggestionRunStatusPending). + Updates(map[string]any{ + "status": dashboardSuggestionRunStatusRunning, + "started_at": now, + }) + if update.Error != nil { + return update.Error + } + if update.RowsAffected == 1 { + run.Status = dashboardSuggestionRunStatusRunning + run.StartedAt = &now + return suggestionrel.CreateRunEventTx(tx, &run, suggestionrel.DashboardSuggestionEventTypeRunStarted, datatypes.JSONMap{ + "model": run.Model, + "prompt_version": run.PromptVersion, + }) + } + return nil + }) + if err != nil { + return suggestionrel.DashboardSuggestionRun{}, suggestionrel.DashboardSuggestionRunCell{}, false, err + } + if cell.Status != dashboardSuggestionCellStatusPending || run.ID == nil { + return suggestionrel.DashboardSuggestionRun{}, suggestionrel.DashboardSuggestionRunCell{}, false, nil + } + return run, cell, true, nil +} + +func (w *DashboardSuggestionWorker) completeWithOneRetry(ctx context.Context, prompt string) (*llm.StructuredResponse, error) { + requestTimeout := dashboardSuggestionRequestTimeout(w.aiCfg) + + req := llm.StructuredRequest{ + System: suggestionrel.SystemPrompt, + Prompt: prompt, + Schema: suggestionrel.OutputSchema(), + MaxTokens: llm.DefaultAnthropicMaxTokens, + } + + var lastErr error + for attempt := 0; attempt < 2; attempt++ { + callCtx, cancel := context.WithTimeout(ctx, requestTimeout) + response, err := w.llmClient.CompleteStructured(callCtx, req) + cancel() + if err == nil { + return response, nil + } + lastErr = err + if !isRetryableLLMError(err) { + break + } + } + return nil, lastErr +} + +func dashboardSuggestionRequestTimeout(aiCfg *config.AIConfig) time.Duration { + if aiCfg != nil && aiCfg.RequestTimeout > 0 { + return aiCfg.RequestTimeout + } + return config.DefaultAIConfig().RequestTimeout +} + +func (w *DashboardSuggestionWorker) completeCell( + ctx context.Context, + run suggestionrel.DashboardSuggestionRun, + cell suggestionrel.DashboardSuggestionRunCell, + response *llm.StructuredResponse, + validation suggestionrel.ValidationResult, + rawCount int, + missingLabelSets int, +) error { + maxSuggestions := suggestionrel.DefaultMaxSuggestionsPerRun + if w.aiCfg != nil && w.aiCfg.MaxSuggestionsPerRun > 0 { + maxSuggestions = w.aiCfg.MaxSuggestionsPerRun + } + // mappings_rejected is an operational unserved count: invalid/excluded/capped + // mappings plus requested label sets that had no gathered evidence. + rejected := rawCount - len(validation.Mappings) + missingLabelSets + if rejected < 0 { + rejected = 0 + } + + return w.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var lockedCell suggestionrel.DashboardSuggestionRunCell + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where("run_id = ? AND cell_index = ?", *run.ID, cell.CellIndex). + First(&lockedCell).Error; err != nil { + return err + } + if lockedCell.Status != dashboardSuggestionCellStatusPending { + return nil + } + + inserted, err := w.suggestionService.InsertValidatedMappingsTx(tx, *run.ID, run.SSPID, run.PromptVersion, validation.Mappings, maxSuggestions) + if err != nil { + return err + } + rejected += inserted.Excluded + inserted.Capped + now := time.Now().UTC() + update := tx.Model(&suggestionrel.DashboardSuggestionRunCell{}). + Where("run_id = ? AND cell_index = ? AND status = ?", *run.ID, cell.CellIndex, dashboardSuggestionCellStatusPending). + Updates(map[string]any{ + "status": dashboardSuggestionCellStatusCompleted, + "error": nil, + "input_tokens": response.InputTokens, + "output_tokens": response.OutputTokens, + "mappings_returned": rawCount, + "mappings_rejected": rejected, + "completed_at": now, + }) + if update.Error != nil { + return update.Error + } + if update.RowsAffected == 0 { + return nil + } + return w.finalizeRunIfReady(tx, *run.ID) + }) +} + +func (w *DashboardSuggestionWorker) handleAttemptFailure(ctx context.Context, job *river.Job[DashboardSuggestionCellArgs], err error) error { + if isFinalAttempt(job) { + if markErr := w.failCellAndMaybeFinalize(ctx, job.Args, err); markErr != nil { + if w.logger != nil { + w.logger.Errorw("failed to mark dashboard suggestion cell failed on final attempt", + "run_id", job.Args.RunID, + "cell_index", job.Args.CellIndex, + "error", err, + "mark_error", markErr, + ) + } + return markErr + } + } + return err +} + +func (w *DashboardSuggestionWorker) failCellAndMaybeFinalize(ctx context.Context, args DashboardSuggestionCellArgs, cause error) error { + detachedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 10*time.Second) + defer cancel() + + return w.db.WithContext(detachedCtx).Transaction(func(tx *gorm.DB) error { + now := time.Now().UTC() + message := cause.Error() + update := tx.Model(&suggestionrel.DashboardSuggestionRunCell{}). + Where("run_id = ? AND cell_index = ? AND status = ?", args.RunID, args.CellIndex, dashboardSuggestionCellStatusPending). + Updates(map[string]any{ + "status": dashboardSuggestionCellStatusFailed, + "error": message, + "completed_at": now, + }) + if update.Error != nil { + return update.Error + } + if update.RowsAffected == 0 { + return nil + } + return w.finalizeRunIfReady(tx, args.RunID) + }) +} + +func (w *DashboardSuggestionWorker) finalizeRunIfReady(tx *gorm.DB, runID uuid.UUID) error { + var run suggestionrel.DashboardSuggestionRun + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ?", runID).First(&run).Error; err != nil { + return err + } + if run.Status == dashboardSuggestionRunStatusCompleted || run.Status == dashboardSuggestionRunStatusFailed { + return nil + } + + var pending int64 + if err := tx.Model(&suggestionrel.DashboardSuggestionRunCell{}). + Where("run_id = ? AND status = ?", runID, dashboardSuggestionCellStatusPending). + Count(&pending).Error; err != nil { + return err + } + if pending > 0 { + return nil + } + + type aggregateRow struct { + Completed int + Failed int + InputTokens int + OutputTokens int + MappingsReturned int + MappingsRejected int + } + var aggregate aggregateRow + if err := tx.Model(&suggestionrel.DashboardSuggestionRunCell{}). + Select(` + COALESCE(SUM(CASE WHEN status = ? THEN 1 ELSE 0 END), 0) AS completed, + COALESCE(SUM(CASE WHEN status = ? THEN 1 ELSE 0 END), 0) AS failed, + COALESCE(SUM(input_tokens), 0) AS input_tokens, + COALESCE(SUM(output_tokens), 0) AS output_tokens, + COALESCE(SUM(mappings_returned), 0) AS mappings_returned, + COALESCE(SUM(mappings_rejected), 0) AS mappings_rejected + `, dashboardSuggestionCellStatusCompleted, dashboardSuggestionCellStatusFailed). + Where("run_id = ?", runID). + Scan(&aggregate).Error; err != nil { + return err + } + + var failedCells []suggestionrel.DashboardSuggestionRunCell + if err := tx.Where("run_id = ? AND status = ?", runID, dashboardSuggestionCellStatusFailed). + Order("cell_index ASC"). + Find(&failedCells).Error; err != nil { + return err + } + + failedSummary := make([]any, 0, len(failedCells)) + for _, cell := range failedCells { + item := datatypes.JSONMap{"cell_index": cell.CellIndex} + if cell.Error != nil { + item["error"] = *cell.Error + } + failedSummary = append(failedSummary, item) + } + + stats := datatypes.JSONMap{} + for key, value := range run.Stats { + stats[key] = value + } + stats["cells_completed"] = aggregate.Completed + stats["cells_failed"] = aggregate.Failed + stats["failed_cells"] = failedSummary + stats["mappings_returned"] = aggregate.MappingsReturned + stats["mappings_rejected"] = aggregate.MappingsRejected + + status := dashboardSuggestionRunStatusCompleted + eventType := suggestionrel.DashboardSuggestionEventTypeRunCompleted + if aggregate.Completed == 0 { + status = dashboardSuggestionRunStatusFailed + eventType = suggestionrel.DashboardSuggestionEventTypeRunFailed + } + now := time.Now().UTC() + if err := tx.Model(&suggestionrel.DashboardSuggestionRun{}). + Where("id = ?", runID). + Updates(map[string]any{ + "status": status, + "completed_at": now, + "input_tokens": aggregate.InputTokens, + "output_tokens": aggregate.OutputTokens, + "stats": stats, + }).Error; err != nil { + return err + } + run.Status = status + run.CompletedAt = &now + run.InputTokens = aggregate.InputTokens + run.OutputTokens = aggregate.OutputTokens + run.Stats = stats + return suggestionrel.CreateRunEventTx(tx, &run, eventType, datatypes.JSONMap{ + "cells_completed": aggregate.Completed, + "cells_failed": aggregate.Failed, + }) +} + +func rawMappingCount(raw json.RawMessage) (int, error) { + var decoded suggestionrel.RawMappings + if err := json.Unmarshal(raw, &decoded); err != nil { + return 0, err + } + return len(decoded.Mappings), nil +} + +func isRetryableLLMError(err error) bool { + return errors.Is(err, llm.ErrRateLimited) || errors.Is(err, llm.ErrOverloaded) +} + +func isNonRetryableLLMError(err error) bool { + return errors.Is(err, llm.ErrAuth) || errors.Is(err, llm.ErrInvalidOutput) +} + +func isFinalAttempt(job *river.Job[DashboardSuggestionCellArgs]) bool { + if job == nil || job.JobRow == nil { + return false + } + maxAttempts := job.MaxAttempts + if maxAttempts <= 0 { + maxAttempts = DashboardSuggestionMaxAttempts + } + return job.Attempt >= maxAttempts +} diff --git a/internal/service/worker/dashboard_suggestion_worker_integration_test.go b/internal/service/worker/dashboard_suggestion_worker_integration_test.go new file mode 100644 index 00000000..09cfd85e --- /dev/null +++ b/internal/service/worker/dashboard_suggestion_worker_integration_test.go @@ -0,0 +1,232 @@ +//go:build integration + +package worker + +import ( + "context" + "encoding/json" + "regexp" + "sync" + "testing" + "time" + + "github.com/compliance-framework/api/internal/config" + "github.com/compliance-framework/api/internal/service/llm" + "github.com/compliance-framework/api/internal/service/relational" + suggestionrel "github.com/compliance-framework/api/internal/service/relational/suggestions" + "github.com/compliance-framework/api/internal/tests" + "github.com/google/uuid" + "github.com/riverqueue/river" + "github.com/riverqueue/river/rivertype" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + "gorm.io/datatypes" +) + +type DashboardSuggestionWorkerIntegrationSuite struct { + tests.IntegrationTestSuite +} + +func TestDashboardSuggestionWorkerIntegrationSuite(t *testing.T) { + suite.Run(t, new(DashboardSuggestionWorkerIntegrationSuite)) +} + +func (suite *DashboardSuggestionWorkerIntegrationSuite) SetupTest() { + suite.Require().NoError(suite.Migrator.Refresh()) +} + +func (suite *DashboardSuggestionWorkerIntegrationSuite) TestTwoByTwoGridConcurrentShuffledFinalizesAndRerunDoesNotDuplicate() { + ctx := context.Background() + runID, cells := suite.seedTwoByTwoSuggestionRun() + client := &promptMappingClient{} + worker := NewDashboardSuggestionWorker(suite.DB, client, &config.AIConfig{RequestTimeout: 120 * time.Second, MaxSuggestionsPerRun: 10}, zap.NewNop().Sugar()) + + order := []int{2, 0, 3, 1} + var wg sync.WaitGroup + errs := make(chan error, len(order)) + for _, cellIndex := range order { + wg.Add(1) + go func(cellIndex int) { + defer wg.Done() + errs <- worker.Work(ctx, dashboardSuggestionIntegrationJob(runID, cellIndex)) + }(cellIndex) + } + wg.Wait() + close(errs) + for err := range errs { + suite.Require().NoError(err) + } + + var run suggestionrel.DashboardSuggestionRun + suite.Require().NoError(suite.DB.First(&run, "id = ?", runID).Error) + suite.Equal(dashboardSuggestionRunStatusCompleted, run.Status) + suite.Equal(4, run.SuggestionCount) + suite.Equal(40, run.InputTokens) + suite.Equal(20, run.OutputTokens) + suite.NotNil(run.StartedAt) + suite.NotNil(run.CompletedAt) + + for _, cell := range cells { + var stored suggestionrel.DashboardSuggestionRunCell + suite.Require().NoError(suite.DB.First(&stored, "run_id = ? AND cell_index = ?", runID, cell.CellIndex).Error) + suite.Equal(dashboardSuggestionCellStatusCompleted, stored.Status) + suite.Equal(10, stored.InputTokens) + suite.Equal(5, stored.OutputTokens) + suite.Equal(1, stored.MappingsReturned) + suite.Equal(0, stored.MappingsRejected) + } + + var suggestionCount int64 + suite.Require().NoError(suite.DB.Model(&suggestionrel.DashboardSuggestion{}).Where("run_id = ?", runID).Count(&suggestionCount).Error) + suite.Equal(int64(4), suggestionCount) + var completedEvents int64 + suite.Require().NoError(suite.DB.Model(&suggestionrel.DashboardSuggestionEvent{}). + Where("run_id = ? AND event_type = ?", runID, suggestionrel.DashboardSuggestionEventTypeRunCompleted). + Count(&completedEvents).Error) + suite.Equal(int64(1), completedEvents) + + for _, cell := range cells { + suite.Require().NoError(worker.Work(ctx, dashboardSuggestionIntegrationJob(runID, cell.CellIndex))) + } + var afterRerunCount int64 + suite.Require().NoError(suite.DB.Model(&suggestionrel.DashboardSuggestion{}).Where("run_id = ?", runID).Count(&afterRerunCount).Error) + suite.Equal(suggestionCount, afterRerunCount) +} + +func (suite *DashboardSuggestionWorkerIntegrationSuite) seedTwoByTwoSuggestionRun() (uuid.UUID, []suggestionrel.GridCell) { + sspID := uuid.New() + runID := uuid.New() + catalogID := uuid.New() + profileID := uuid.New() + suite.Require().NoError(suite.DB.Create(&relational.SystemSecurityPlan{UUIDModel: relational.UUIDModel{ID: &sspID}}).Error) + suite.Require().NoError(suite.DB.Create(&relational.Profile{UUIDModel: relational.UUIDModel{ID: &profileID}}).Error) + suite.Require().NoError(suite.DB.Exec(`INSERT INTO ssp_profiles (system_security_plan_id, profile_id) VALUES (?, ?)`, sspID, profileID).Error) + + controlIDs := []string{"AC-1", "AC-2"} + controlKeys := make([]string, 0, len(controlIDs)) + implementationID := uuid.New() + suite.Require().NoError(suite.DB.Create(&relational.ControlImplementation{ + UUIDModel: relational.UUIDModel{ID: &implementationID}, + SystemSecurityPlanId: sspID, + }).Error) + for _, controlID := range controlIDs { + controlKeys = append(controlKeys, suggestionrel.ControlKey(catalogID, controlID)) + suite.Require().NoError(suite.DB.Create(&relational.Control{ + CatalogID: catalogID, + ID: controlID, + Title: controlID + " title", + Parts: datatypes.NewJSONSlice([]relational.Part{}), + }).Error) + suite.Require().NoError(suite.DB.Exec(`INSERT INTO profile_controls (profile_id, control_catalog_id, control_id) VALUES (?, ?, ?)`, profileID, catalogID, controlID).Error) + implementedRequirementID := uuid.New() + suite.Require().NoError(suite.DB.Create(&relational.ImplementedRequirement{ + UUIDModel: relational.UUIDModel{ID: &implementedRequirementID}, + ControlImplementationId: implementationID, + ControlId: controlID, + Remarks: controlID + " implementation", + }).Error) + } + + labelSets := []map[string]string{ + {"env": "prod", "service": "api"}, + {"env": "stage", "service": "worker"}, + } + labelSetHashes := make([]string, 0, len(labelSets)) + for i, labels := range labelSets { + hash := suggestionrel.CanonicalLabelSetHash(labels) + labelSetHashes = append(labelSetHashes, hash) + evidenceID := uuid.New() + streamID := uuid.New() + suite.Require().NoError(suite.DB.Exec( + `INSERT INTO evidences (id, uuid, title, description, start, "end") VALUES (?, ?, ?, ?, ?, ?)`, + evidenceID, + streamID, + "evidence", + "evidence", + time.Now().UTC().Add(time.Duration(i)*time.Minute), + time.Now().UTC().Add(time.Duration(i)*time.Minute), + ).Error) + for key, value := range labels { + suite.Require().NoError(suite.DB.Exec(`INSERT INTO evidence_labels (evidence_id, labels_name, labels_value) VALUES (?, ?, ?)`, evidenceID, key, value).Error) + } + } + + suite.Require().NoError(suite.DB.Create(&suggestionrel.DashboardSuggestionRun{ + UUIDModel: relational.UUIDModel{ID: &runID}, + SSPID: sspID, + Status: dashboardSuggestionRunStatusPending, + Model: "fake-model", + PromptVersion: suggestionrel.PromptVersion, + Scope: datatypes.JSONMap{"controlKeys": controlKeys, "labelSetHashes": labelSetHashes}, + PlannedCalls: 4, + SuggestionCount: 0, + Stats: datatypes.JSONMap{}, + }).Error) + + cells := suggestionrel.BuildGrid(suggestionrel.Snapshot{ControlKeys: controlKeys, LabelSetHashes: labelSetHashes}, suggestionrel.ChunkConfig{ + MaxControlsPerChunk: 1, + MaxLabelSetsPerChunk: 1, + }) + for _, cell := range cells { + suite.Require().NoError(suite.DB.Create(&suggestionrel.DashboardSuggestionRunCell{ + RunID: runID, + CellIndex: cell.CellIndex, + ControlKeys: datatypes.NewJSONSlice(cell.ControlKeys), + LabelSetHashes: datatypes.NewJSONSlice(cell.LabelSetHashes), + Status: dashboardSuggestionCellStatusPending, + }).Error) + } + return runID, cells +} + +type promptMappingClient struct { + mu sync.Mutex + requests int +} + +func (c *promptMappingClient) CompleteStructured(ctx context.Context, req llm.StructuredRequest) (*llm.StructuredResponse, error) { + c.mu.Lock() + c.requests++ + c.mu.Unlock() + + controlKey := firstPromptValue(req.Prompt, `"control_key": "([^"]+)"`) + labelSetHash := firstPromptValue(req.Prompt, `"hash": "([^"]+)"`) + raw, err := json.Marshal(suggestionrel.RawMappings{Mappings: []suggestionrel.RawMapping{ + { + ControlKey: controlKey, + LabelSetHash: labelSetHash, + Action: suggestionrel.MappingActionNewFilter, + ProposedFilterName: "Dashboard " + labelSetHash[:8], + Confidence: 0.9, + Reasoning: "Evidence satisfies the control and belongs to this system.", + }, + }}) + if err != nil { + return nil, err + } + return &llm.StructuredResponse{ + Raw: raw, + Model: "fake-model", + InputTokens: 10, + OutputTokens: 5, + }, nil +} + +func firstPromptValue(prompt string, pattern string) string { + match := regexp.MustCompile(pattern).FindStringSubmatch(prompt) + if len(match) < 2 { + return "" + } + return match[1] +} + +func dashboardSuggestionIntegrationJob(runID uuid.UUID, cellIndex int) *river.Job[DashboardSuggestionCellArgs] { + return &river.Job[DashboardSuggestionCellArgs]{ + JobRow: &rivertype.JobRow{ + ID: int64(cellIndex + 1), + Attempt: 1, + MaxAttempts: 3, + }, + Args: DashboardSuggestionCellArgs{RunID: runID, CellIndex: cellIndex}, + } +} diff --git a/internal/service/worker/dashboard_suggestion_worker_test.go b/internal/service/worker/dashboard_suggestion_worker_test.go new file mode 100644 index 00000000..1b42ba6d --- /dev/null +++ b/internal/service/worker/dashboard_suggestion_worker_test.go @@ -0,0 +1,338 @@ +package worker + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/compliance-framework/api/internal/config" + "github.com/compliance-framework/api/internal/service/llm" + "github.com/compliance-framework/api/internal/service/relational" + suggestionrel "github.com/compliance-framework/api/internal/service/relational/suggestions" + "github.com/google/uuid" + "github.com/riverqueue/river" + "github.com/riverqueue/river/rivertype" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "gorm.io/datatypes" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func TestDashboardSuggestionCellJobType(t *testing.T) { + t.Parallel() + + args := DashboardSuggestionCellArgs{RunID: uuid.New(), CellIndex: 4} + require.Equal(t, JobTypeDashboardSuggestionCell, args.Kind()) + require.Equal(t, 5*time.Minute, args.Timeout()) + + opts := JobInsertOptionsForDashboardSuggestionCell() + require.Equal(t, DashboardSuggestionQueue, opts.Queue) + require.Equal(t, DashboardSuggestionMaxAttempts, opts.MaxAttempts) + require.True(t, opts.UniqueOpts.ByArgs) +} + +func TestDashboardSuggestionWorkerTimeout(t *testing.T) { + t.Parallel() + + worker := NewDashboardSuggestionWorker(nil, nil, config.DefaultAIConfig(), zap.NewNop().Sugar()) + require.Equal(t, 5*time.Minute, worker.Timeout(nil)) + + worker = NewDashboardSuggestionWorker(nil, nil, &config.AIConfig{RequestTimeout: 3 * time.Minute}, zap.NewNop().Sugar()) + require.Equal(t, 6*time.Minute+30*time.Second, worker.Timeout(nil)) + + worker = NewDashboardSuggestionWorker(nil, nil, &config.AIConfig{}, zap.NewNop().Sugar()) + require.Equal(t, 5*time.Minute, worker.Timeout(nil)) +} + +func TestBuildRiverConfigSuggestionQueueGated(t *testing.T) { + t.Parallel() + + cfg := config.DefaultWorkerConfig() + withoutAI := buildRiverConfig(cfg, river.NewWorkers(), nil, nil) + _, ok := withoutAI.Queues[DashboardSuggestionQueue] + require.False(t, ok) + + withAI := buildRiverConfig(cfg, river.NewWorkers(), nil, &config.AIConfig{Enabled: true, QueueWorkers: 9}) + queue, ok := withAI.Queues[DashboardSuggestionQueue] + require.True(t, ok) + require.Equal(t, 9, queue.MaxWorkers) +} + +func TestServiceEnqueueDashboardSuggestionCellsDisabled(t *testing.T) { + t.Parallel() + + svc := &Service{config: config.DefaultWorkerConfig(), aiEnabled: false} + err := svc.EnqueueDashboardSuggestionCells(context.Background(), uuid.New(), 2) + require.ErrorIs(t, err, ErrDashboardSuggestionWorkerDisabled) +} + +func TestDashboardSuggestionWorkerNonPendingCellNoops(t *testing.T) { + db := newDashboardSuggestionWorkerTestDB(t) + runID, sspID := seedDashboardSuggestionRun(t, db, dashboardSuggestionRunStatusRunning, 1) + seedDashboardSuggestionCell(t, db, runID, 0, dashboardSuggestionCellStatusCompleted) + fake := &llm.FakeClient{Raw: json.RawMessage(`{"mappings":[]}`)} + worker := NewDashboardSuggestionWorker(db, fake, config.DefaultAIConfig(), zap.NewNop().Sugar()) + + err := worker.Work(context.Background(), dashboardSuggestionJob(runID, 0, 1, 3)) + require.NoError(t, err) + require.Empty(t, fake.Requests) + + var eventCount int64 + require.NoError(t, db.Model(&suggestionrel.DashboardSuggestionEvent{}).Where("run_id = ?", runID).Count(&eventCount).Error) + require.Zero(t, eventCount) + + var run suggestionrel.DashboardSuggestionRun + require.NoError(t, db.First(&run, "id = ? AND ssp_id = ?", runID, sspID).Error) + require.Equal(t, dashboardSuggestionRunStatusRunning, run.Status) +} + +func TestDashboardSuggestionWorkerRunStartedEmittedOnce(t *testing.T) { + db := newDashboardSuggestionWorkerTestDB(t) + runID, _ := seedDashboardSuggestionRun(t, db, dashboardSuggestionRunStatusPending, 3) + for i := range 3 { + seedDashboardSuggestionCell(t, db, runID, i, dashboardSuggestionCellStatusPending) + } + worker := NewDashboardSuggestionWorker(db, &llm.FakeClient{}, config.DefaultAIConfig(), zap.NewNop().Sugar()) + + for i := range 3 { + _, _, ok, err := worker.loadPendingCellAndStartRun(context.Background(), DashboardSuggestionCellArgs{RunID: runID, CellIndex: i}) + require.NoError(t, err) + require.True(t, ok) + } + + var eventCount int64 + require.NoError(t, db.Model(&suggestionrel.DashboardSuggestionEvent{}). + Where("run_id = ? AND event_type = ?", runID, suggestionrel.DashboardSuggestionEventTypeRunStarted). + Count(&eventCount).Error) + require.Equal(t, int64(1), eventCount) + + var run suggestionrel.DashboardSuggestionRun + require.NoError(t, db.First(&run, "id = ?", runID).Error) + require.Equal(t, dashboardSuggestionRunStatusRunning, run.Status) + require.NotNil(t, run.StartedAt) +} + +func TestDashboardSuggestionWorkerFinalizationMatrix(t *testing.T) { + t.Run("mixed success failure completes run", func(t *testing.T) { + db := newDashboardSuggestionWorkerTestDB(t) + runID, _ := seedDashboardSuggestionRun(t, db, dashboardSuggestionRunStatusRunning, 2) + seedDashboardSuggestionCellWithStats(t, db, runID, 0, dashboardSuggestionCellStatusCompleted, 11, 7, 2, 1, nil) + seedDashboardSuggestionCell(t, db, runID, 1, dashboardSuggestionCellStatusPending) + worker := NewDashboardSuggestionWorker(db, &llm.FakeClient{}, config.DefaultAIConfig(), zap.NewNop().Sugar()) + + require.NoError(t, worker.failCellAndMaybeFinalize(context.Background(), DashboardSuggestionCellArgs{RunID: runID, CellIndex: 1}, errors.New("provider failed"))) + + var run suggestionrel.DashboardSuggestionRun + require.NoError(t, db.First(&run, "id = ?", runID).Error) + require.Equal(t, dashboardSuggestionRunStatusCompleted, run.Status) + require.Equal(t, 11, run.InputTokens) + require.Equal(t, 7, run.OutputTokens) + require.Equal(t, "1", fmt.Sprint(run.Stats["cells_completed"])) + require.Equal(t, "1", fmt.Sprint(run.Stats["cells_failed"])) + require.NotNil(t, run.CompletedAt) + assertRunEventCount(t, db, runID, suggestionrel.DashboardSuggestionEventTypeRunCompleted, 1) + }) + + t.Run("all failed fails run", func(t *testing.T) { + db := newDashboardSuggestionWorkerTestDB(t) + runID, _ := seedDashboardSuggestionRun(t, db, dashboardSuggestionRunStatusRunning, 2) + seedDashboardSuggestionCellWithStats(t, db, runID, 0, dashboardSuggestionCellStatusFailed, 0, 0, 0, 0, ptrString("already failed")) + seedDashboardSuggestionCell(t, db, runID, 1, dashboardSuggestionCellStatusPending) + worker := NewDashboardSuggestionWorker(db, &llm.FakeClient{}, config.DefaultAIConfig(), zap.NewNop().Sugar()) + + require.NoError(t, worker.failCellAndMaybeFinalize(context.Background(), DashboardSuggestionCellArgs{RunID: runID, CellIndex: 1}, errors.New("auth failed"))) + + var run suggestionrel.DashboardSuggestionRun + require.NoError(t, db.First(&run, "id = ?", runID).Error) + require.Equal(t, dashboardSuggestionRunStatusFailed, run.Status) + require.NotNil(t, run.CompletedAt) + assertRunEventCount(t, db, runID, suggestionrel.DashboardSuggestionEventTypeRunFailed, 1) + }) +} + +func TestDashboardSuggestionWorkerFailCellDetachedFromCancelledContext(t *testing.T) { + db := newDashboardSuggestionWorkerTestDB(t) + runID, _ := seedDashboardSuggestionRun(t, db, dashboardSuggestionRunStatusRunning, 1) + seedDashboardSuggestionCell(t, db, runID, 0, dashboardSuggestionCellStatusPending) + worker := NewDashboardSuggestionWorker(db, &llm.FakeClient{}, config.DefaultAIConfig(), zap.NewNop().Sugar()) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + require.NoError(t, worker.failCellAndMaybeFinalize(ctx, DashboardSuggestionCellArgs{RunID: runID, CellIndex: 0}, errors.New("provider failed"))) + + var cell suggestionrel.DashboardSuggestionRunCell + require.NoError(t, db.First(&cell, "run_id = ? AND cell_index = ?", runID, 0).Error) + require.Equal(t, dashboardSuggestionCellStatusFailed, cell.Status) + require.NotNil(t, cell.CompletedAt) + require.NotNil(t, cell.Error) + require.Equal(t, "provider failed", *cell.Error) + + var run suggestionrel.DashboardSuggestionRun + require.NoError(t, db.First(&run, "id = ?", runID).Error) + require.Equal(t, dashboardSuggestionRunStatusFailed, run.Status) + require.NotNil(t, run.CompletedAt) + assertRunEventCount(t, db, runID, suggestionrel.DashboardSuggestionEventTypeRunFailed, 1) +} + +func TestDashboardSuggestionWorkerCompleteCellCountsMissingLabelSetsAsRejected(t *testing.T) { + db := newDashboardSuggestionWorkerTestDB(t) + runID, sspID := seedDashboardSuggestionRun(t, db, dashboardSuggestionRunStatusRunning, 1) + seedDashboardSuggestionCellWithLabelSets(t, db, runID, 0, dashboardSuggestionCellStatusPending, []string{"hash-with-evidence", "hash-without-evidence"}) + worker := NewDashboardSuggestionWorker(db, &llm.FakeClient{}, config.DefaultAIConfig(), zap.NewNop().Sugar()) + run := suggestionrel.DashboardSuggestionRun{ + UUIDModel: relational.UUIDModel{ID: &runID}, + SSPID: sspID, + PromptVersion: suggestionrel.PromptVersion, + Stats: datatypes.JSONMap{}, + } + cell := suggestionrel.DashboardSuggestionRunCell{RunID: runID, CellIndex: 0} + catalogID := uuid.New() + validation := suggestionrel.ValidationResult{ + Mappings: []suggestionrel.ValidatedMapping{{ + ControlKey: suggestionrel.ControlKey(catalogID, "AC-1"), + LabelSet: map[string]string{"component": "api"}, + LabelSetHash: "hash-with-evidence", + ProposedFilterName: "API scope", + Confidence: 0.8, + Reasoning: "evidence matches", + }}, + } + response := &llm.StructuredResponse{InputTokens: 2, OutputTokens: 3} + + require.NoError(t, worker.completeCell(context.Background(), run, cell, response, validation, 1, 1)) + + var stored suggestionrel.DashboardSuggestionRunCell + require.NoError(t, db.First(&stored, "run_id = ? AND cell_index = ?", runID, 0).Error) + require.Equal(t, dashboardSuggestionCellStatusCompleted, stored.Status) + require.Equal(t, 1, stored.MappingsReturned) + require.Equal(t, 1, stored.MappingsRejected) +} + +func TestDashboardSuggestionWorkerLLMRetryAndNonRetryableFailure(t *testing.T) { + t.Run("retryable error retries once in job", func(t *testing.T) { + fake := &llm.FakeClient{ + Errors: []error{llm.ErrRateLimited}, + Responses: []*llm.StructuredResponse{ + nil, + {Raw: json.RawMessage(`{"mappings":[]}`), Model: "fake", InputTokens: 3, OutputTokens: 5}, + }, + } + worker := NewDashboardSuggestionWorker(nil, fake, config.DefaultAIConfig(), zap.NewNop().Sugar()) + + response, err := worker.completeWithOneRetry(context.Background(), "prompt") + require.NoError(t, err) + require.Equal(t, 3, response.InputTokens) + require.Equal(t, 5, response.OutputTokens) + require.Len(t, fake.Requests, 2) + }) + + t.Run("non retryable error fails cell immediately", func(t *testing.T) { + fake := &llm.FakeClient{Err: llm.ErrAuth} + worker := NewDashboardSuggestionWorker(nil, fake, config.DefaultAIConfig(), zap.NewNop().Sugar()) + + _, err := worker.completeWithOneRetry(context.Background(), "prompt") + require.ErrorIs(t, err, llm.ErrAuth) + require.True(t, isNonRetryableLLMError(err)) + require.Len(t, fake.Requests, 1) + }) +} + +func newDashboardSuggestionWorkerTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + require.NoError(t, err) + require.NoError(t, db.AutoMigrate( + &relational.SystemSecurityPlan{}, + &relational.SystemCharacteristics{}, + &relational.SystemImplementation{}, + &relational.SystemComponent{}, + &relational.Control{}, + &relational.Filter{}, + &suggestionrel.DashboardSuggestionRun{}, + &suggestionrel.DashboardSuggestionRunCell{}, + &suggestionrel.DashboardSuggestion{}, + &suggestionrel.DashboardSuggestionEvent{}, + )) + return db +} + +func seedDashboardSuggestionRun(t *testing.T, db *gorm.DB, status string, plannedCalls int) (uuid.UUID, uuid.UUID) { + t.Helper() + runID := uuid.New() + sspID := uuid.New() + require.NoError(t, db.Create(&relational.SystemSecurityPlan{UUIDModel: relational.UUIDModel{ID: &sspID}}).Error) + require.NoError(t, db.Create(&suggestionrel.DashboardSuggestionRun{ + UUIDModel: relational.UUIDModel{ID: &runID}, + SSPID: sspID, + Status: status, + Model: "fake-model", + PromptVersion: suggestionrel.PromptVersion, + Scope: datatypes.JSONMap{"controlKeys": []string{}, "labelSetHashes": []string{}}, + PlannedCalls: plannedCalls, + SuggestionCount: 0, + Stats: datatypes.JSONMap{}, + }).Error) + return runID, sspID +} + +func seedDashboardSuggestionCell(t *testing.T, db *gorm.DB, runID uuid.UUID, cellIndex int, status string) { + t.Helper() + seedDashboardSuggestionCellWithStats(t, db, runID, cellIndex, status, 0, 0, 0, 0, nil) +} + +func seedDashboardSuggestionCellWithLabelSets(t *testing.T, db *gorm.DB, runID uuid.UUID, cellIndex int, status string, labelSetHashes []string) { + t.Helper() + require.NoError(t, db.Create(&suggestionrel.DashboardSuggestionRunCell{ + RunID: runID, + CellIndex: cellIndex, + ControlKeys: datatypes.NewJSONSlice([]string{}), + LabelSetHashes: datatypes.NewJSONSlice(labelSetHashes), + Status: status, + }).Error) +} + +func seedDashboardSuggestionCellWithStats(t *testing.T, db *gorm.DB, runID uuid.UUID, cellIndex int, status string, inputTokens int, outputTokens int, mappingsReturned int, mappingsRejected int, message *string) { + t.Helper() + completedAt := time.Now().UTC() + require.NoError(t, db.Create(&suggestionrel.DashboardSuggestionRunCell{ + RunID: runID, + CellIndex: cellIndex, + ControlKeys: datatypes.NewJSONSlice([]string{}), + LabelSetHashes: datatypes.NewJSONSlice([]string{}), + Status: status, + Error: message, + InputTokens: inputTokens, + OutputTokens: outputTokens, + MappingsReturned: mappingsReturned, + MappingsRejected: mappingsRejected, + CompletedAt: &completedAt, + }).Error) +} + +func dashboardSuggestionJob(runID uuid.UUID, cellIndex int, attempt int, maxAttempts int) *river.Job[DashboardSuggestionCellArgs] { + return &river.Job[DashboardSuggestionCellArgs]{ + JobRow: &rivertype.JobRow{ + ID: 1, + Attempt: attempt, + MaxAttempts: maxAttempts, + }, + Args: DashboardSuggestionCellArgs{RunID: runID, CellIndex: cellIndex}, + } +} + +func assertRunEventCount(t *testing.T, db *gorm.DB, runID uuid.UUID, eventType suggestionrel.DashboardSuggestionEventType, expected int64) { + t.Helper() + var count int64 + require.NoError(t, db.Model(&suggestionrel.DashboardSuggestionEvent{}). + Where("run_id = ? AND event_type = ?", runID, eventType). + Count(&count).Error) + require.Equal(t, expected, count) +} + +func ptrString(value string) *string { + return &value +} diff --git a/internal/service/worker/service.go b/internal/service/worker/service.go index 7751fbc7..e1103d97 100644 --- a/internal/service/worker/service.go +++ b/internal/service/worker/service.go @@ -2,6 +2,7 @@ package worker import ( "context" + "errors" "fmt" "log" "os" @@ -10,6 +11,7 @@ import ( "github.com/compliance-framework/api/internal/config" "github.com/compliance-framework/api/internal/service/email" + "github.com/compliance-framework/api/internal/service/llm" "github.com/compliance-framework/api/internal/service/notification" emailprovider "github.com/compliance-framework/api/internal/service/notification/providers/email" slackprovider "github.com/compliance-framework/api/internal/service/notification/providers/slack" @@ -45,6 +47,7 @@ type Service struct { startedMu sync.RWMutex pgxPool *pgxpool.Pool digestCfg *config.Config + aiEnabled bool webBaseURL string // Workflow services @@ -100,6 +103,7 @@ func NewServiceWithDigest( emailSvc: emailSvc, digestSvc: digestSvc, digestCfg: digestCfg, + aiEnabled: false, logger: logger, started: false, }, nil @@ -314,11 +318,24 @@ func NewServiceWithDigest( poamOpenDigestSchedulerWorker := NewPoamOpenDigestSchedulerWorker(db, clientProxy, poamCfg.OpenDigestWindow, logger) river.AddWorker(workers, river.WorkFunc(poamOpenDigestSchedulerWorker.Work)) + aiEnabled := digestCfg != nil && digestCfg.AI != nil && digestCfg.AI.Enabled + if aiEnabled { + llmClient := llm.NewAnthropicClient(llm.AnthropicConfig{ + Enabled: true, + APIKey: digestCfg.AI.APIKey, + Model: digestCfg.AI.Model, + BaseURL: digestCfg.AI.BaseURL, + RequestTimeout: digestCfg.AI.RequestTimeout, + }) + dashboardSuggestionWorker := NewDashboardSuggestionWorker(db, llmClient, digestCfg.AI, logger) + river.AddWorker(workers, dashboardSuggestionWorker) + } + // Configure periodic jobs periodicJobs := periodicJobsFromConfig(digestCfg, logger) // Create River client with pgxv5 driver - riverConfig := buildRiverConfig(cfg, workers, periodicJobs) + riverConfig := buildRiverConfig(cfg, workers, periodicJobs, aiConfigFromConfig(digestCfg)) // Create the client client, err := river.NewClient(riverpgxv5.New(pgxPool), &riverConfig) @@ -338,6 +355,7 @@ func NewServiceWithDigest( slackSvc: slackService, userRepo: userRepo, digestCfg: digestCfg, + aiEnabled: aiEnabled, webBaseURL: webBaseURL, logger: logger, started: false, @@ -698,8 +716,8 @@ func periodicJobsFromConfig(cfg *config.Config, logger *zap.SugaredLogger) []*ri return periodicJobs } -func buildRiverConfig(cfg *config.WorkerConfig, workers *river.Workers, periodicJobs []*river.PeriodicJob) river.Config { - return river.Config{ +func buildRiverConfig(cfg *config.WorkerConfig, workers *river.Workers, periodicJobs []*river.PeriodicJob, aiCfg *config.AIConfig) river.Config { + riverConfig := river.Config{ PollOnly: cfg.UsePolling, Queues: map[string]river.QueueConfig{ "email": { @@ -730,6 +748,23 @@ func buildRiverConfig(cfg *config.WorkerConfig, workers *river.Workers, periodic Workers: workers, PeriodicJobs: periodicJobs, } + if aiCfg != nil && aiCfg.Enabled { + maxWorkers := aiCfg.QueueWorkers + if maxWorkers <= 0 { + maxWorkers = config.DefaultAIConfig().QueueWorkers + } + riverConfig.Queues[DashboardSuggestionQueue] = river.QueueConfig{ + MaxWorkers: maxWorkers, + } + } + return riverConfig +} + +func aiConfigFromConfig(cfg *config.Config) *config.AIConfig { + if cfg == nil { + return nil + } + return cfg.AI } func (s *Service) emailQueue() string { @@ -1069,3 +1104,31 @@ func (s *Service) EnqueueOrphanedRiskCleanup(ctx context.Context, sspID uuid.UUI } return nil } + +func (s *Service) EnqueueDashboardSuggestionCells(ctx context.Context, runID uuid.UUID, cellCount int) error { + if s == nil || s.config == nil || !s.config.Enabled || s.client == nil || !s.aiEnabled { + return ErrDashboardSuggestionWorkerDisabled + } + if cellCount <= 0 { + return nil + } + + params := make([]river.InsertManyParams, 0, cellCount) + for cellIndex := 0; cellIndex < cellCount; cellIndex++ { + params = append(params, river.InsertManyParams{ + Args: DashboardSuggestionCellArgs{ + RunID: runID, + CellIndex: cellIndex, + }, + InsertOpts: JobInsertOptionsForDashboardSuggestionCell(), + }) + } + + if _, err := s.client.InsertMany(ctx, params); err != nil { + if errors.Is(err, ErrDashboardSuggestionWorkerDisabled) { + return err + } + return fmt.Errorf("failed to enqueue dashboard suggestion cell jobs for run %s: %w", runID, err) + } + return nil +} diff --git a/internal/service/worker/service_test.go b/internal/service/worker/service_test.go index 7a92bda1..ba78e5ff 100644 --- a/internal/service/worker/service_test.go +++ b/internal/service/worker/service_test.go @@ -741,7 +741,7 @@ func TestBuildRiverConfig_IncludesSendSlackQueue(t *testing.T) { Workers: 7, } - riverConfig := buildRiverConfig(cfg, river.NewWorkers(), nil) + riverConfig := buildRiverConfig(cfg, river.NewWorkers(), nil, nil) sendSlackQueue, ok := riverConfig.Queues["slack"] assert.True(t, ok) @@ -751,7 +751,7 @@ func TestBuildRiverConfig_IncludesSendSlackQueue(t *testing.T) { func TestBuildRiverConfig_PollOnlyDisabledByDefault(t *testing.T) { cfg := config.DefaultWorkerConfig() - riverConfig := buildRiverConfig(cfg, river.NewWorkers(), nil) + riverConfig := buildRiverConfig(cfg, river.NewWorkers(), nil, nil) assert.False(t, riverConfig.PollOnly) } @@ -760,7 +760,7 @@ func TestBuildRiverConfig_PollOnlyCanBeEnabled(t *testing.T) { cfg := config.DefaultWorkerConfig() cfg.UsePolling = true - riverConfig := buildRiverConfig(cfg, river.NewWorkers(), nil) + riverConfig := buildRiverConfig(cfg, river.NewWorkers(), nil, nil) assert.True(t, riverConfig.PollOnly) }