Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/api/handler/oscal/system_security_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
// Defined here to avoid a circular import between the oscal handler and worker packages.
type SSPJobEnqueuer interface {
EnqueueOrphanedRiskCleanup(ctx context.Context, sspID uuid.UUID, oldProfileID, newProfileID *uuid.UUID) error
EnqueueDashboardSuggestionCells(ctx context.Context, runID uuid.UUID, cellCount int) error
}

// profileSummary is a lightweight DTO returned by the multi-profile list endpoint.
Expand Down
15 changes: 15 additions & 0 deletions internal/service/llm/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,34 @@ package llm
import (
"context"
"encoding/json"
"sync"
)

type FakeClient struct {
mu sync.Mutex

Raw json.RawMessage
Model string
InputTokens int
OutputTokens int
Err error
Requests []StructuredRequest
Responses []*StructuredResponse
Errors []error
}

func (f *FakeClient) CompleteStructured(ctx context.Context, req StructuredRequest) (*StructuredResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()

f.Requests = append(f.Requests, req)
index := len(f.Requests) - 1
if index < len(f.Errors) && f.Errors[index] != nil {
return nil, f.Errors[index]
}
if index < len(f.Responses) && f.Responses[index] != nil {
return f.Responses[index], nil
}
if f.Err != nil {
return nil, f.Err
}
Expand Down
158 changes: 96 additions & 62 deletions internal/service/relational/suggestions/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,75 +137,82 @@ func (g GatheredInput) CellInput() CellInput {
}

func (s *SuggestionService) InsertValidatedMappings(runID uuid.UUID, sspID uuid.UUID, promptVersion string, mappings []ValidatedMapping, maxSuggestionsPerRun int) (InsertMappingsResult, error) {
result := InsertMappingsResult{}
err := s.db.Transaction(func(tx *gorm.DB) error {
var err error
result, err = s.InsertValidatedMappingsTx(tx, runID, sspID, promptVersion, mappings, maxSuggestionsPerRun)
return err
})
return result, err
}

func (s *SuggestionService) InsertValidatedMappingsTx(tx *gorm.DB, runID uuid.UUID, sspID uuid.UUID, promptVersion string, mappings []ValidatedMapping, maxSuggestionsPerRun int) (InsertMappingsResult, error) {
if maxSuggestionsPerRun <= 0 {
maxSuggestionsPerRun = DefaultMaxSuggestionsPerRun
}
result := InsertMappingsResult{}
err := s.db.Transaction(func(tx *gorm.DB) error {
var run DashboardSuggestionRun
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ? AND ssp_id = ?", runID, sspID).First(&run).Error; err != nil {
return err
}
capacity := maxSuggestionsPerRun - run.SuggestionCount
var run DashboardSuggestionRun
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ? AND ssp_id = ?", runID, sspID).First(&run).Error; err != nil {
return result, err
}
capacity := maxSuggestionsPerRun - run.SuggestionCount
if capacity <= 0 {
result.Capped = len(mappings)
return result, nil
}

for _, mapping := range mappings {
if capacity <= 0 {
result.Capped = len(mappings)
return nil
result.Capped++
continue
}

for _, mapping := range mappings {
if capacity <= 0 {
result.Capped++
continue
}
excluded, err := s.mappingExcluded(tx, sspID, promptVersion, mapping)
if err != nil {
return err
}
if excluded {
result.Excluded++
continue
}
catalogID, controlID, err := ParseControlKey(mapping.ControlKey)
if err != nil {
return err
}
suggestion := DashboardSuggestion{
RunID: runID,
SSPID: sspID,
ControlCatalogID: catalogID,
ControlID: controlID,
LabelSet: labelsToJSONMap(mapping.LabelSet),
LabelSetHash: mapping.LabelSetHash,
TargetFilterID: mapping.TargetFilterID,
ProposedFilterName: mapping.ProposedFilterName,
Reasoning: mapping.Reasoning,
Confidence: mapping.Confidence,
Status: DashboardSuggestionStatusPending,
}
create := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&suggestion)
if create.Error != nil {
return create.Error
}
if create.RowsAffected == 0 {
result.Excluded++
continue
}
result.Inserted++
capacity--
if err := tx.Model(&DashboardSuggestionRun{}).
Where("id = ?", runID).
UpdateColumn("suggestion_count", gorm.Expr("suggestion_count + 1")).Error; err != nil {
return err
}
if err := createSuggestionEvent(tx, &suggestion, DashboardSuggestionEventTypeSuggestionCreated, nil, datatypes.JSONMap{
"prompt_version": promptVersion,
}); err != nil {
return err
}
excluded, err := s.mappingExcluded(tx, sspID, promptVersion, mapping)
if err != nil {
return result, err
}
return nil
})
return result, err
if excluded {
result.Excluded++
continue
}
catalogID, controlID, err := ParseControlKey(mapping.ControlKey)
if err != nil {
return result, err
}
suggestion := DashboardSuggestion{
RunID: runID,
SSPID: sspID,
ControlCatalogID: catalogID,
ControlID: controlID,
LabelSet: labelsToJSONMap(mapping.LabelSet),
LabelSetHash: mapping.LabelSetHash,
TargetFilterID: mapping.TargetFilterID,
ProposedFilterName: mapping.ProposedFilterName,
Reasoning: mapping.Reasoning,
Confidence: mapping.Confidence,
Status: DashboardSuggestionStatusPending,
}
create := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&suggestion)
if create.Error != nil {
return result, create.Error
}
if create.RowsAffected == 0 {
result.Excluded++
continue
}
result.Inserted++
capacity--
if err := tx.Model(&DashboardSuggestionRun{}).
Where("id = ?", runID).
UpdateColumn("suggestion_count", gorm.Expr("suggestion_count + 1")).Error; err != nil {
return result, err
}
if err := createSuggestionEvent(tx, &suggestion, DashboardSuggestionEventTypeSuggestionCreated, nil, datatypes.JSONMap{
"prompt_version": promptVersion,
}); err != nil {
return result, err
}
}
return result, nil
}

func (s *SuggestionService) Accept(sspID uuid.UUID, suggestionIDs []uuid.UUID, actorID uuid.UUID) error {
Expand Down Expand Up @@ -463,6 +470,21 @@ func createSuggestionEvent(tx *gorm.DB, suggestion *DashboardSuggestion, eventTy
return tx.Create(&event).Error
}

func CreateRunEventTx(tx *gorm.DB, run *DashboardSuggestionRun, eventType DashboardSuggestionEventType, payload datatypes.JSONMap) error {
snapshot, err := runSnapshot(run)
if err != nil {
return err
}
event := DashboardSuggestionEvent{
RunID: run.ID,
EventType: string(eventType),
OccurredAt: time.Now().UTC(),
Payload: payload,
Snapshot: snapshot,
}
return tx.Create(&event).Error
}

func suggestionSnapshot(suggestion *DashboardSuggestion) (datatypes.JSONMap, error) {
raw, err := json.Marshal(suggestion)
if err != nil {
Expand All @@ -475,6 +497,18 @@ func suggestionSnapshot(suggestion *DashboardSuggestion) (datatypes.JSONMap, err
return snapshot, nil
}

func runSnapshot(run *DashboardSuggestionRun) (datatypes.JSONMap, error) {
raw, err := json.Marshal(run)
if err != nil {
return nil, err
}
var snapshot datatypes.JSONMap
if err := json.Unmarshal(raw, &snapshot); err != nil {
return nil, err
}
return snapshot, nil
}

func selectDimension(requested []string, all []string, allSet map[string]struct{}) ([]string, []string) {
if len(requested) == 0 {
return append([]string(nil), all...), nil
Expand Down
36 changes: 36 additions & 0 deletions internal/service/worker/dashboard_suggestion_job_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package worker

import (
"errors"
"time"

"github.com/google/uuid"
"github.com/riverqueue/river"
)

const (
JobTypeDashboardSuggestionCell = "dashboard_suggestion_cell"
DashboardSuggestionQueue = "suggestion"
DashboardSuggestionMaxAttempts = 3
)

var ErrDashboardSuggestionWorkerDisabled = errors.New("dashboard suggestion worker is disabled")

type DashboardSuggestionCellArgs struct {
RunID uuid.UUID `json:"run_id" river:"unique"`
CellIndex int `json:"cell_index" river:"unique"`
}

func (DashboardSuggestionCellArgs) Kind() string { return JobTypeDashboardSuggestionCell }

func (DashboardSuggestionCellArgs) Timeout() time.Duration { return 5 * time.Minute }

func JobInsertOptionsForDashboardSuggestionCell() *river.InsertOpts {
return &river.InsertOpts{
Queue: DashboardSuggestionQueue,
MaxAttempts: DashboardSuggestionMaxAttempts,
UniqueOpts: river.UniqueOpts{
ByArgs: true,
},
}
}
Loading
Loading