diff --git a/cli.go b/cli.go index 3f73977a8..cf58a66c8 100644 --- a/cli.go +++ b/cli.go @@ -77,6 +77,7 @@ import ( backupbackend "github.com/blackbirdworks/gopherstack/services/backup" batchbackend "github.com/blackbirdworks/gopherstack/services/batch" bedrockbackend "github.com/blackbirdworks/gopherstack/services/bedrock" + bedrockagentbackend "github.com/blackbirdworks/gopherstack/services/bedrockagent" bedrockruntimebackend "github.com/blackbirdworks/gopherstack/services/bedrockruntime" cebackend "github.com/blackbirdworks/gopherstack/services/ce" cloudcontrolbackend "github.com/blackbirdworks/gopherstack/services/cloudcontrol" @@ -2767,6 +2768,7 @@ func getMostRecentServiceProviders() []service.Provider { &fsxbackend.Provider{}, &vpclatticebackend.Provider{}, &omicsbackend.Provider{}, + &bedrockagentbackend.Provider{}, } } diff --git a/services/bedrockagent/backend.go b/services/bedrockagent/backend.go new file mode 100644 index 000000000..f310d0c34 --- /dev/null +++ b/services/bedrockagent/backend.go @@ -0,0 +1,2837 @@ +package bedrockagent + +import ( + "context" + "fmt" + "maps" + "sort" + "strconv" + "sync" + "time" + + "github.com/blackbirdworks/gopherstack/pkgs/arn" + "github.com/blackbirdworks/gopherstack/pkgs/awserr" +) + +// --------------------------------------------------------------------------- +// Sentinel errors +// --------------------------------------------------------------------------- + +var ( + // ErrNotFound is returned when a requested resource does not exist. + ErrNotFound = awserr.New("ResourceNotFoundException", awserr.ErrNotFound) + // ErrAlreadyExists is returned when a resource with the given name already exists. + ErrAlreadyExists = awserr.New("ConflictException", awserr.ErrAlreadyExists) + // ErrValidation is returned for invalid request parameters. + ErrValidation = awserr.New("ValidationException", awserr.ErrInvalidParameter) +) + +// --------------------------------------------------------------------------- +// Context key +// --------------------------------------------------------------------------- + +type regionKey struct{} + +func ctxRegion(ctx context.Context, dflt string) string { + if r, ok := ctx.Value(regionKey{}).(string); ok && r != "" { + return r + } + + return dflt +} + +// --------------------------------------------------------------------------- +// Status constants +// --------------------------------------------------------------------------- + +const ( + agentStatusNotPrepared = "NOT_PREPARED" + agentStatusPreparing = "PREPARING" + agentStatusPrepared = "PREPARED" + kbStatusActive = "ACTIVE" + dsStatusAvailable = "AVAILABLE" + aliasStatusPrepared = "PREPARED" + flowStatusPrepared = "PREPARED" + flowStatusNotPrepared = "NOT_PREPARED" + ingestionJobRunning = "IN_PROGRESS" + ingestionJobComplete = "COMPLETE" + actionGroupEnabled = "ENABLED" + collabEnabled = "ENABLED" + docStatusIndexed = "INDEXED" + defaultAgentVersion = "DRAFT" + + bedrockAgentService = "bedrock" +) + +// --------------------------------------------------------------------------- +// Config structs +// --------------------------------------------------------------------------- + +// AgentConfig holds fields for creating or updating an Agent. +type AgentConfig struct { + Tags map[string]string + Guardrail map[string]any + Memory map[string]any + AgentName string + Collaboration string + Description string + FoundationModel string + Instruction string + RoleARN string +} + +// ActionGroupConfig holds fields for creating or updating an AgentActionGroup. +type ActionGroupConfig struct { + ActionGroupExecutor map[string]any + APISchema map[string]any + FunctionSchema map[string]any + ActionGroupName string + Description string + ActionGroupState string +} + +// AliasConfig holds fields for creating or updating an AgentAlias. +type AliasConfig struct { + Tags map[string]string + AliasName string + Description string + RoutingConfiguration []AliasRouting +} + +// CollaboratorConfig holds fields for an AgentCollaborator. +type CollaboratorConfig struct { + AgentDescriptor map[string]any + CollaboratorName string + CollaborationInstruction string + RelayConversationHistory string +} + +// KnowledgeBaseConfig holds fields for creating or updating a KnowledgeBase. +type KnowledgeBaseConfig struct { + Tags map[string]string + KBConfiguration map[string]any + StorageConfiguration map[string]any + Name string + Description string + RoleARN string +} + +// DataSourceConfig holds fields for creating or updating a DataSource. +type DataSourceConfig struct { + DataSourceConfiguration map[string]any + VectorIngestionConfig map[string]any + Name string + Description string + DataDeletionPolicy string +} + +// FlowConfig holds fields for creating or updating a Flow. +type FlowConfig struct { + Tags map[string]string + Definition map[string]any + Name string + Description string + RoleARN string +} + +// FlowAliasConfig holds fields for creating or updating a FlowAlias. +type FlowAliasConfig struct { + Tags map[string]string + Name string + Description string + RoutingConfiguration []FlowAliasRouting +} + +// PromptConfig holds fields for creating or updating a Prompt. +type PromptConfig struct { + Tags map[string]string + Name string + Description string + DefaultVariant string + Variants []map[string]any +} + +// KBDocument is a knowledge base document for ingestion. +type KBDocument struct { + Metadata map[string]any + Content map[string]any + DocID string +} + +// --------------------------------------------------------------------------- +// Model types +// --------------------------------------------------------------------------- + +// Agent represents a Bedrock Agent. +type Agent struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + Tags map[string]string `json:"tags,omitempty"` + Guardrail map[string]any `json:"guardrailConfiguration,omitempty"` + Memory map[string]any `json:"memoryConfiguration,omitempty"` + AgentID string `json:"agentId"` + AgentARN string `json:"agentArn"` + AgentName string `json:"agentName"` + AgentVersion string `json:"agentVersion"` + AgentStatus string `json:"agentStatus"` + Collaboration string `json:"agentCollaboration,omitempty"` + Description string `json:"description,omitempty"` + FoundationModel string `json:"foundationModel,omitempty"` + Instruction string `json:"instruction,omitempty"` + RoleARN string `json:"agentResourceRoleArn,omitempty"` +} + +// AgentSummary is the condensed agent representation used in list responses. +type AgentSummary struct { + UpdatedAt time.Time `json:"updatedAt"` + AgentID string `json:"agentId"` + AgentName string `json:"agentName"` + AgentStatus string `json:"agentStatus"` + Description string `json:"description,omitempty"` +} + +// AgentVersion holds a snapshot version of an agent. +type AgentVersion struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + AgentID string `json:"agentId"` + AgentARN string `json:"agentArn"` + AgentName string `json:"agentName"` + AgentStatus string `json:"agentStatus"` + AgentVersion string `json:"agentVersion"` + Description string `json:"description,omitempty"` + FoundationModel string `json:"foundationModel,omitempty"` + Instruction string `json:"instruction,omitempty"` + RoleARN string `json:"agentResourceRoleArn,omitempty"` +} + +// AgentVersionSummary is used in list-agent-versions responses. +type AgentVersionSummary struct { + UpdatedAt time.Time `json:"updatedAt"` + AgentName string `json:"agentName"` + AgentStatus string `json:"agentStatus"` + AgentVersion string `json:"agentVersion"` + Description string `json:"description,omitempty"` +} + +// AgentActionGroup is an action group attached to an agent version. +type AgentActionGroup struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + ActionGroupExecutor map[string]any `json:"actionGroupExecutor,omitempty"` + APISchema map[string]any `json:"apiSchema,omitempty"` + FunctionSchema map[string]any `json:"functionSchema,omitempty"` + ActionGroupID string `json:"actionGroupId"` + ActionGroupName string `json:"actionGroupName"` + AgentID string `json:"agentId"` + AgentVersion string `json:"agentVersion"` + ActionGroupState string `json:"actionGroupState"` + Description string `json:"description,omitempty"` +} + +// ActionGroupSummary is used in list responses. +type ActionGroupSummary struct { + ActionGroupID string `json:"actionGroupId"` + ActionGroupName string `json:"actionGroupName"` + ActionGroupState string `json:"actionGroupState"` + Description string `json:"description,omitempty"` +} + +// AliasRouting maps an alias to an agent version. +type AliasRouting struct { + AgentVersion string `json:"agentVersion"` +} + +// AgentAlias routes traffic to a specific agent version. +type AgentAlias struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + Tags map[string]string `json:"tags,omitempty"` + AgentAliasID string `json:"agentAliasId"` + AgentAliasARN string `json:"agentAliasArn"` + AgentAliasName string `json:"agentAliasName"` + AgentAliasStatus string `json:"agentAliasStatus"` + AgentID string `json:"agentId"` + Description string `json:"description,omitempty"` + RoutingConfiguration []AliasRouting `json:"routingConfiguration"` +} + +// AgentAliasSummary is used in list responses. +type AgentAliasSummary struct { + AgentAliasID string `json:"agentAliasId"` + AgentAliasName string `json:"agentAliasName"` + AgentAliasStatus string `json:"agentAliasStatus"` + Description string `json:"description,omitempty"` +} + +// AgentCollaborator links two agents for multi-agent collaboration. +type AgentCollaborator struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + AgentDescriptor map[string]any `json:"agentDescriptor,omitempty"` + AgentID string `json:"agentId"` + AgentVersion string `json:"agentVersion"` + CollaboratorID string `json:"collaboratorId"` + CollaboratorName string `json:"collaboratorName"` + CollaborationInstruction string `json:"collaborationInstruction,omitempty"` + RelayConversationHistory string `json:"relayConversationHistory,omitempty"` + CollaboratorStatus string `json:"collaboratorStatus"` +} + +// KnowledgeBase is a Bedrock Knowledge Base. +type KnowledgeBase struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + Tags map[string]string `json:"tags,omitempty"` + KBConfiguration map[string]any `json:"knowledgeBaseConfiguration,omitempty"` + StorageConfiguration map[string]any `json:"storageConfiguration,omitempty"` + KnowledgeBaseID string `json:"knowledgeBaseId"` + KnowledgeBaseARN string `json:"knowledgeBaseArn"` + Name string `json:"name"` + Status string `json:"status"` + Description string `json:"description,omitempty"` + RoleARN string `json:"roleArn,omitempty"` +} + +// KnowledgeBaseSummary is used in list responses. +type KnowledgeBaseSummary struct { + UpdatedAt time.Time `json:"updatedAt"` + KnowledgeBaseID string `json:"knowledgeBaseId"` + Name string `json:"name"` + Status string `json:"status"` + Description string `json:"description,omitempty"` +} + +// AgentKnowledgeBase is the association between an agent and a knowledge base. +type AgentKnowledgeBase struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + AgentID string `json:"agentId"` + AgentVersion string `json:"agentVersion"` + KnowledgeBaseID string `json:"knowledgeBaseId"` + KBState string `json:"knowledgeBaseState"` + Description string `json:"description,omitempty"` +} + +// DataSource is a knowledge base data source. +type DataSource struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + DataSourceConfiguration map[string]any `json:"dataSourceConfiguration,omitempty"` + VectorIngestionConfig map[string]any `json:"vectorIngestionConfiguration,omitempty"` + DataSourceID string `json:"dataSourceId"` + KnowledgeBaseID string `json:"knowledgeBaseId"` + Name string `json:"name"` + DataSourceStatus string `json:"dataSourceStatus"` + Description string `json:"description,omitempty"` + DataDeletionPolicy string `json:"dataDeletionPolicy,omitempty"` +} + +// DataSourceSummary is used in list responses. +type DataSourceSummary struct { + UpdatedAt time.Time `json:"updatedAt"` + DataSourceID string `json:"dataSourceId"` + KnowledgeBaseID string `json:"knowledgeBaseId"` + Name string `json:"name"` + DataSourceStatus string `json:"dataSourceStatus"` + Description string `json:"description,omitempty"` +} + +// IngestionJob is a knowledge base data ingestion job. +type IngestionJob struct { + StartedAt time.Time `json:"startedAt"` + UpdatedAt time.Time `json:"updatedAt"` + IngestionJobID string `json:"ingestionJobId"` + KnowledgeBaseID string `json:"knowledgeBaseId"` + DataSourceID string `json:"dataSourceId"` + Status string `json:"status"` + Description string `json:"description,omitempty"` +} + +// Flow is a Bedrock prompt flow. +type Flow struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + Tags map[string]string `json:"tags,omitempty"` + Definition map[string]any `json:"definition,omitempty"` + FlowID string `json:"id"` + FlowARN string `json:"arn"` + Name string `json:"name"` + Status string `json:"status"` + Description string `json:"description,omitempty"` + RoleARN string `json:"executionRoleArn,omitempty"` + Version string `json:"version"` +} + +// FlowSummary is used in list responses. +type FlowSummary struct { + UpdatedAt time.Time `json:"updatedAt"` + FlowID string `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + Description string `json:"description,omitempty"` + Version string `json:"version"` +} + +// FlowVersion is a snapshot of a flow. +type FlowVersion struct { + CreatedAt time.Time `json:"createdAt"` + Definition map[string]any `json:"definition,omitempty"` + FlowARN string `json:"arn"` + FlowID string `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + Version string `json:"version"` + Description string `json:"description,omitempty"` +} + +// FlowVersionSummary is used in list responses. +type FlowVersionSummary struct { + CreatedAt time.Time `json:"createdAt"` + Arn string `json:"arn"` + FlowID string `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + Version string `json:"version"` + Description string `json:"description,omitempty"` +} + +// FlowAliasRouting maps a flow alias to a specific flow version. +type FlowAliasRouting struct { + FlowVersion string `json:"flowVersion"` +} + +// FlowAlias routes traffic to a specific flow version. +type FlowAlias struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + Tags map[string]string `json:"tags,omitempty"` + AliasID string `json:"id"` + AliasARN string `json:"arn"` + FlowID string `json:"flowId"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + RoutingConfiguration []FlowAliasRouting `json:"routingConfiguration,omitempty"` +} + +// FlowAliasSummary is used in list responses. +type FlowAliasSummary struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + AliasID string `json:"id"` + AliasARN string `json:"arn"` + FlowID string `json:"flowId"` + Name string `json:"name"` + Description string `json:"description,omitempty"` +} + +// FlowValidationError is a flow definition validation error. +type FlowValidationError struct { + Message string `json:"message"` + Severity string `json:"severity"` +} + +// Prompt is a Bedrock Prompt resource. +type Prompt struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + Tags map[string]string `json:"tags,omitempty"` + PromptID string `json:"id"` + PromptARN string `json:"arn"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + DefaultVariant string `json:"defaultVariant,omitempty"` + Version string `json:"version"` + Variants []map[string]any `json:"variants,omitempty"` +} + +// PromptSummary is used in list responses. +type PromptSummary struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + PromptID string `json:"id"` + PromptARN string `json:"arn"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Version string `json:"version"` +} + +// PromptVersion is an immutable snapshot of a prompt. +type PromptVersion struct { + CreatedAt time.Time `json:"createdAt"` + PromptARN string `json:"arn"` + PromptID string `json:"id"` + Name string `json:"name"` + Version string `json:"version"` + Description string `json:"description,omitempty"` + Variants []map[string]any `json:"variants,omitempty"` +} + +// KBDocumentDetail is the status of a knowledge base document operation. +type KBDocumentDetail struct { + DocumentID string `json:"documentId"` + KnowledgeBaseID string `json:"knowledgeBaseId"` + DataSourceID string `json:"dataSourceId"` + Status string `json:"status"` +} + +// --------------------------------------------------------------------------- +// InMemoryBackend +// --------------------------------------------------------------------------- + +// InMemoryBackend implements StorageBackend with in-memory maps, isolated by region. +type InMemoryBackend struct { + kbDocuments map[string]*KBDocumentDetail + agentsByName map[string]string + agentVersions map[string]map[string]*AgentVersion + actionGroups map[string]*AgentActionGroup + agentAliases map[string]*AgentAlias + agentCollaborators map[string]map[string]*AgentCollaborator + agentKBAssocs map[string]*AgentKnowledgeBase + knowledgeBases map[string]*KnowledgeBase + kbsByName map[string]string + dataSources map[string]*DataSource + ingestionJobs map[string]*IngestionJob + flows map[string]*Flow + flowsByName map[string]string + flowVersions map[string]map[string]*FlowVersion + flowAliases map[string]*FlowAlias + prompts map[string]*Prompt + promptVersions map[string]map[string]*PromptVersion + promptsByName map[string]string + promptVersionCtrs map[string]int + tags map[string]map[string]string + flowVersionCtrs map[string]int + agents map[string]*Agent + agentVersionCtrs map[string]int + accountID string + defaultRegion string + dsCounter int + collabCounter int + kbCounter int + flowCounter int + aliasCounter int + agentCounter int + actionGroupCounter int + flowAliasCounter int + promptCounter int + jobCounter int + mu sync.RWMutex +} + +var _ StorageBackend = (*InMemoryBackend)(nil) + +// NewInMemoryBackend creates and initialises an InMemoryBackend. +func NewInMemoryBackend(region, accountID string) *InMemoryBackend { + return &InMemoryBackend{ + agents: make(map[string]*Agent), + agentsByName: make(map[string]string), + agentVersions: make(map[string]map[string]*AgentVersion), + actionGroups: make(map[string]*AgentActionGroup), + agentAliases: make(map[string]*AgentAlias), + agentCollaborators: make(map[string]map[string]*AgentCollaborator), + agentKBAssocs: make(map[string]*AgentKnowledgeBase), + knowledgeBases: make(map[string]*KnowledgeBase), + kbsByName: make(map[string]string), + dataSources: make(map[string]*DataSource), + ingestionJobs: make(map[string]*IngestionJob), + flows: make(map[string]*Flow), + flowsByName: make(map[string]string), + flowVersions: make(map[string]map[string]*FlowVersion), + flowAliases: make(map[string]*FlowAlias), + prompts: make(map[string]*Prompt), + promptsByName: make(map[string]string), + promptVersions: make(map[string]map[string]*PromptVersion), + kbDocuments: make(map[string]*KBDocumentDetail), + tags: make(map[string]map[string]string), + agentVersionCtrs: make(map[string]int), + flowVersionCtrs: make(map[string]int), + promptVersionCtrs: make(map[string]int), + defaultRegion: region, + accountID: accountID, + } +} + +// Reset clears all backend state (used in tests). +func (b *InMemoryBackend) Reset() { + b.mu.Lock() + defer b.mu.Unlock() + + b.agents = make(map[string]*Agent) + b.agentsByName = make(map[string]string) + b.agentVersions = make(map[string]map[string]*AgentVersion) + b.actionGroups = make(map[string]*AgentActionGroup) + b.agentAliases = make(map[string]*AgentAlias) + b.agentCollaborators = make(map[string]map[string]*AgentCollaborator) + b.agentKBAssocs = make(map[string]*AgentKnowledgeBase) + b.knowledgeBases = make(map[string]*KnowledgeBase) + b.kbsByName = make(map[string]string) + b.dataSources = make(map[string]*DataSource) + b.ingestionJobs = make(map[string]*IngestionJob) + b.flows = make(map[string]*Flow) + b.flowsByName = make(map[string]string) + b.flowVersions = make(map[string]map[string]*FlowVersion) + b.flowAliases = make(map[string]*FlowAlias) + b.prompts = make(map[string]*Prompt) + b.promptsByName = make(map[string]string) + b.promptVersions = make(map[string]map[string]*PromptVersion) + b.kbDocuments = make(map[string]*KBDocumentDetail) + b.tags = make(map[string]map[string]string) + b.agentVersionCtrs = make(map[string]int) + b.flowVersionCtrs = make(map[string]int) + b.promptVersionCtrs = make(map[string]int) + b.agentCounter = 0 + b.actionGroupCounter = 0 + b.aliasCounter = 0 + b.collabCounter = 0 + b.kbCounter = 0 + b.dsCounter = 0 + b.jobCounter = 0 + b.flowCounter = 0 + b.flowAliasCounter = 0 + b.promptCounter = 0 +} + +// --------------------------------------------------------------------------- +// ID/ARN helpers +// --------------------------------------------------------------------------- + +func (b *InMemoryBackend) nextID(prefix string, counter *int) string { + *counter++ + + return fmt.Sprintf("%s-%08d", prefix, *counter) +} + +func (b *InMemoryBackend) buildAgentARN(region, agentID string) string { + return arn.Build(bedrockAgentService, region, b.accountID, "agent/"+agentID) +} + +func (b *InMemoryBackend) buildKBARN(region, kbID string) string { + return arn.Build(bedrockAgentService, region, b.accountID, "knowledge-base/"+kbID) +} + +func (b *InMemoryBackend) buildFlowARN(region, flowID string) string { + return arn.Build(bedrockAgentService, region, b.accountID, "flow/"+flowID) +} + +func (b *InMemoryBackend) buildPromptARN(region, promptID string) string { + return arn.Build(bedrockAgentService, region, b.accountID, "prompt/"+promptID) +} + +func (b *InMemoryBackend) buildAliasARN(region, agentID, aliasID string) string { + return arn.Build( + bedrockAgentService, + region, + b.accountID, + fmt.Sprintf("agent-alias/%s/%s", agentID, aliasID), + ) +} + +func (b *InMemoryBackend) buildFlowAliasARN(region, flowID, aliasID string) string { + return arn.Build( + bedrockAgentService, + region, + b.accountID, + fmt.Sprintf("flow-alias/%s/%s", flowID, aliasID), + ) +} + +// --------------------------------------------------------------------------- +// Agent CRUD +// --------------------------------------------------------------------------- + +// CreateAgent creates a new agent. +func (b *InMemoryBackend) CreateAgent(ctx context.Context, cfg AgentConfig) (*Agent, error) { + if cfg.AgentName == "" { + return nil, fmt.Errorf("%w: agentName is required", ErrValidation) + } + + region := ctxRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + if _, exists := b.agentsByName[cfg.AgentName]; exists { + return nil, fmt.Errorf("%w: agent %q already exists", ErrAlreadyExists, cfg.AgentName) + } + + id := b.nextID("agent", &b.agentCounter) + now := time.Now().UTC() + + a := &Agent{ + AgentID: id, + AgentARN: b.buildAgentARN(region, id), + AgentName: cfg.AgentName, + AgentVersion: defaultAgentVersion, + AgentStatus: agentStatusNotPrepared, + Collaboration: cfg.Collaboration, + Description: cfg.Description, + FoundationModel: cfg.FoundationModel, + Instruction: cfg.Instruction, + RoleARN: cfg.RoleARN, + Tags: maps.Clone(cfg.Tags), + Guardrail: cfg.Guardrail, + Memory: cfg.Memory, + CreatedAt: now, + UpdatedAt: now, + } + + b.agents[id] = a + b.agentsByName[cfg.AgentName] = id + b.tags[a.AgentARN] = maps.Clone(cfg.Tags) + + return agentCopy(a), nil +} + +// GetAgent returns an agent by ID. +func (b *InMemoryBackend) GetAgent(_ context.Context, agentID string) (*Agent, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + a, ok := b.agents[agentID] + if !ok { + return nil, fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + return agentCopy(a), nil +} + +// UpdateAgent updates an existing agent. +func (b *InMemoryBackend) UpdateAgent(_ context.Context, agentID string, cfg AgentConfig) (*Agent, error) { + b.mu.Lock() + defer b.mu.Unlock() + + a, ok := b.agents[agentID] + if !ok { + return nil, fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + if cfg.AgentName != "" && cfg.AgentName != a.AgentName { + if _, exists := b.agentsByName[cfg.AgentName]; exists { + return nil, fmt.Errorf("%w: agent name %q already in use", ErrAlreadyExists, cfg.AgentName) + } + + delete(b.agentsByName, a.AgentName) + b.agentsByName[cfg.AgentName] = agentID + a.AgentName = cfg.AgentName + } + + applyAgentConfig(a, cfg) + a.UpdatedAt = time.Now().UTC() + + return agentCopy(a), nil +} + +func applyAgentConfig(a *Agent, cfg AgentConfig) { + if cfg.Collaboration != "" { + a.Collaboration = cfg.Collaboration + } + + if cfg.Description != "" { + a.Description = cfg.Description + } + + if cfg.FoundationModel != "" { + a.FoundationModel = cfg.FoundationModel + } + + if cfg.Instruction != "" { + a.Instruction = cfg.Instruction + } + + if cfg.RoleARN != "" { + a.RoleARN = cfg.RoleARN + } + + if cfg.Guardrail != nil { + a.Guardrail = cfg.Guardrail + } + + if cfg.Memory != nil { + a.Memory = cfg.Memory + } +} + +// DeleteAgent deletes an agent. +func (b *InMemoryBackend) DeleteAgent(_ context.Context, agentID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + a, ok := b.agents[agentID] + if !ok { + return fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + delete(b.agentsByName, a.AgentName) + delete(b.agents, agentID) + delete(b.agentVersions, agentID) + delete(b.agentVersionCtrs, agentID) + delete(b.agentCollaborators, agentID) + + return nil +} + +// ListAgents returns a paginated list of agent summaries. +func (b *InMemoryBackend) ListAgents( + _ context.Context, maxResults int, nextToken string, +) ([]*AgentSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + ids := sortedKeys(b.agents) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*AgentSummary, 0, len(ids)) + + for _, id := range ids { + a := b.agents[id] + out = append(out, &AgentSummary{ + AgentID: a.AgentID, + AgentName: a.AgentName, + AgentStatus: a.AgentStatus, + Description: a.Description, + UpdatedAt: a.UpdatedAt, + }) + } + + return out, outToken, nil +} + +// PrepareAgent transitions agent to PREPARED status. +func (b *InMemoryBackend) PrepareAgent(_ context.Context, agentID string) (*Agent, error) { + b.mu.Lock() + defer b.mu.Unlock() + + a, ok := b.agents[agentID] + if !ok { + return nil, fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + a.AgentStatus = agentStatusPrepared + a.UpdatedAt = time.Now().UTC() + + return agentCopy(a), nil +} + +// --------------------------------------------------------------------------- +// Agent version CRUD +// --------------------------------------------------------------------------- + +// CreateAgentVersion creates a numbered snapshot of an agent. +func (b *InMemoryBackend) CreateAgentVersion( + _ context.Context, agentID, description string, +) (*AgentVersion, error) { + b.mu.Lock() + defer b.mu.Unlock() + + a, ok := b.agents[agentID] + if !ok { + return nil, fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + b.agentVersionCtrs[agentID]++ + versionNum := b.agentVersionCtrs[agentID] + version := strconv.Itoa(versionNum) + + if b.agentVersions[agentID] == nil { + b.agentVersions[agentID] = make(map[string]*AgentVersion) + } + + now := time.Now().UTC() + av := &AgentVersion{ + AgentID: agentID, + AgentARN: a.AgentARN, + AgentName: a.AgentName, + AgentVersion: version, + AgentStatus: agentStatusPrepared, + FoundationModel: a.FoundationModel, + Instruction: a.Instruction, + RoleARN: a.RoleARN, + Description: description, + CreatedAt: now, + UpdatedAt: now, + } + + b.agentVersions[agentID][version] = av + + return agentVersionCopy(av), nil +} + +// GetAgentVersion returns a specific agent version. +func (b *InMemoryBackend) GetAgentVersion( + _ context.Context, agentID, agentVersion string, +) (*AgentVersion, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + versions, ok := b.agentVersions[agentID] + if !ok { + return nil, fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + av, ok := versions[agentVersion] + if !ok { + return nil, fmt.Errorf("%w: agent version %q not found", ErrNotFound, agentVersion) + } + + return agentVersionCopy(av), nil +} + +// DeleteAgentVersion deletes an agent version. +func (b *InMemoryBackend) DeleteAgentVersion( + _ context.Context, agentID, agentVersion string, +) error { + b.mu.Lock() + defer b.mu.Unlock() + + versions, ok := b.agentVersions[agentID] + if !ok { + return fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + if _, exists := versions[agentVersion]; !exists { + return fmt.Errorf("%w: agent version %q not found", ErrNotFound, agentVersion) + } + + delete(versions, agentVersion) + + return nil +} + +// ListAgentVersions returns paginated agent version summaries. +func (b *InMemoryBackend) ListAgentVersions( + _ context.Context, agentID string, maxResults int, nextToken string, +) ([]*AgentVersionSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + if _, ok := b.agents[agentID]; !ok { + return nil, "", fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + versions := b.agentVersions[agentID] + keys := sortedKeys(versions) + keys, outToken := paginate(keys, nextToken, maxResults) + + out := make([]*AgentVersionSummary, 0, len(keys)) + + for _, k := range keys { + av := versions[k] + out = append(out, &AgentVersionSummary{ + AgentName: av.AgentName, + AgentVersion: av.AgentVersion, + AgentStatus: av.AgentStatus, + Description: av.Description, + UpdatedAt: av.UpdatedAt, + }) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Agent action group CRUD +// --------------------------------------------------------------------------- + +func agActionGroupKey(agentID, agentVersion, actionGroupID string) string { + return agentID + "/" + agentVersion + "/" + actionGroupID +} + +// CreateAgentActionGroup creates an action group for an agent version. +func (b *InMemoryBackend) CreateAgentActionGroup( + _ context.Context, agentID string, cfg ActionGroupConfig, +) (*AgentActionGroup, error) { + if cfg.ActionGroupName == "" { + return nil, fmt.Errorf("%w: actionGroupName is required", ErrValidation) + } + + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.agents[agentID]; !ok { + return nil, fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + id := b.nextID("ag", &b.actionGroupCounter) + agentVersion := defaultAgentVersion + now := time.Now().UTC() + + ag := &AgentActionGroup{ + ActionGroupID: id, + ActionGroupName: cfg.ActionGroupName, + AgentID: agentID, + AgentVersion: agentVersion, + ActionGroupState: actionGroupEnabled, + Description: cfg.Description, + ActionGroupExecutor: cfg.ActionGroupExecutor, + APISchema: cfg.APISchema, + FunctionSchema: cfg.FunctionSchema, + CreatedAt: now, + UpdatedAt: now, + } + + if cfg.ActionGroupState != "" { + ag.ActionGroupState = cfg.ActionGroupState + } + + b.actionGroups[agActionGroupKey(agentID, agentVersion, id)] = ag + + return actionGroupCopy(ag), nil +} + +// GetAgentActionGroup returns an action group. +func (b *InMemoryBackend) GetAgentActionGroup( + _ context.Context, agentID, agentVersion, actionGroupID string, +) (*AgentActionGroup, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + ag, ok := b.actionGroups[agActionGroupKey(agentID, agentVersion, actionGroupID)] + if !ok { + return nil, fmt.Errorf("%w: action group %q not found", ErrNotFound, actionGroupID) + } + + return actionGroupCopy(ag), nil +} + +// UpdateAgentActionGroup updates an action group. +func (b *InMemoryBackend) UpdateAgentActionGroup( + _ context.Context, agentID, agentVersion, actionGroupID string, cfg ActionGroupConfig, +) (*AgentActionGroup, error) { + b.mu.Lock() + defer b.mu.Unlock() + + key := agActionGroupKey(agentID, agentVersion, actionGroupID) + + ag, ok := b.actionGroups[key] + if !ok { + return nil, fmt.Errorf("%w: action group %q not found", ErrNotFound, actionGroupID) + } + + applyActionGroupConfig(ag, cfg) + ag.UpdatedAt = time.Now().UTC() + + return actionGroupCopy(ag), nil +} + +func applyActionGroupConfig(ag *AgentActionGroup, cfg ActionGroupConfig) { + if cfg.ActionGroupName != "" { + ag.ActionGroupName = cfg.ActionGroupName + } + + if cfg.Description != "" { + ag.Description = cfg.Description + } + + if cfg.ActionGroupState != "" { + ag.ActionGroupState = cfg.ActionGroupState + } + + if cfg.ActionGroupExecutor != nil { + ag.ActionGroupExecutor = cfg.ActionGroupExecutor + } + + if cfg.APISchema != nil { + ag.APISchema = cfg.APISchema + } + + if cfg.FunctionSchema != nil { + ag.FunctionSchema = cfg.FunctionSchema + } +} + +// DeleteAgentActionGroup deletes an action group. +func (b *InMemoryBackend) DeleteAgentActionGroup( + _ context.Context, agentID, agentVersion, actionGroupID string, +) error { + b.mu.Lock() + defer b.mu.Unlock() + + key := agActionGroupKey(agentID, agentVersion, actionGroupID) + + if _, ok := b.actionGroups[key]; !ok { + return fmt.Errorf("%w: action group %q not found", ErrNotFound, actionGroupID) + } + + delete(b.actionGroups, key) + + return nil +} + +// ListAgentActionGroups returns all action groups for an agent version. +func (b *InMemoryBackend) ListAgentActionGroups( + _ context.Context, agentID, agentVersion string, maxResults int, nextToken string, +) ([]*ActionGroupSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + prefix := agentID + "/" + agentVersion + "/" + + var ids []string + + for k := range b.actionGroups { + if k[:len(prefix)] == prefix { + ids = append(ids, k[len(prefix):]) + } + } + + sort.Strings(ids) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*ActionGroupSummary, 0, len(ids)) + + for _, id := range ids { + ag := b.actionGroups[agActionGroupKey(agentID, agentVersion, id)] + out = append(out, &ActionGroupSummary{ + ActionGroupID: ag.ActionGroupID, + ActionGroupName: ag.ActionGroupName, + ActionGroupState: ag.ActionGroupState, + Description: ag.Description, + }) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Agent alias CRUD +// --------------------------------------------------------------------------- + +func aliasKey(agentID, aliasID string) string { return agentID + "/" + aliasID } + +// CreateAgentAlias creates an alias for an agent. +func (b *InMemoryBackend) CreateAgentAlias( + ctx context.Context, agentID string, cfg AliasConfig, +) (*AgentAlias, error) { + if cfg.AliasName == "" { + return nil, fmt.Errorf("%w: agentAliasName is required", ErrValidation) + } + + region := ctxRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.agents[agentID]; !ok { + return nil, fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + id := b.nextID("alias", &b.aliasCounter) + now := time.Now().UTC() + + al := &AgentAlias{ + AgentAliasID: id, + AgentAliasARN: b.buildAliasARN(region, agentID, id), + AgentAliasName: cfg.AliasName, + AgentAliasStatus: aliasStatusPrepared, + AgentID: agentID, + Description: cfg.Description, + Tags: maps.Clone(cfg.Tags), + RoutingConfiguration: cfg.RoutingConfiguration, + CreatedAt: now, + UpdatedAt: now, + } + + b.agentAliases[aliasKey(agentID, id)] = al + + return aliasCopy(al), nil +} + +// GetAgentAlias returns an agent alias. +func (b *InMemoryBackend) GetAgentAlias(_ context.Context, agentID, aliasID string) (*AgentAlias, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + al, ok := b.agentAliases[aliasKey(agentID, aliasID)] + if !ok { + return nil, fmt.Errorf("%w: alias %q not found", ErrNotFound, aliasID) + } + + return aliasCopy(al), nil +} + +// UpdateAgentAlias updates an agent alias. +func (b *InMemoryBackend) UpdateAgentAlias( + _ context.Context, agentID, aliasID string, cfg AliasConfig, +) (*AgentAlias, error) { + b.mu.Lock() + defer b.mu.Unlock() + + al, ok := b.agentAliases[aliasKey(agentID, aliasID)] + if !ok { + return nil, fmt.Errorf("%w: alias %q not found", ErrNotFound, aliasID) + } + + if cfg.AliasName != "" { + al.AgentAliasName = cfg.AliasName + } + + if cfg.Description != "" { + al.Description = cfg.Description + } + + if cfg.RoutingConfiguration != nil { + al.RoutingConfiguration = cfg.RoutingConfiguration + } + + if cfg.Tags != nil { + al.Tags = maps.Clone(cfg.Tags) + } + + al.UpdatedAt = time.Now().UTC() + + return aliasCopy(al), nil +} + +// DeleteAgentAlias deletes an agent alias. +func (b *InMemoryBackend) DeleteAgentAlias(_ context.Context, agentID, aliasID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.agentAliases[aliasKey(agentID, aliasID)]; !ok { + return fmt.Errorf("%w: alias %q not found", ErrNotFound, aliasID) + } + + delete(b.agentAliases, aliasKey(agentID, aliasID)) + + return nil +} + +// ListAgentAliases returns paginated alias summaries for an agent. +func (b *InMemoryBackend) ListAgentAliases( + _ context.Context, agentID string, maxResults int, nextToken string, +) ([]*AgentAliasSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + prefix := agentID + "/" + + var ids []string + + for k := range b.agentAliases { + if len(k) > len(prefix) && k[:len(prefix)] == prefix { + ids = append(ids, k[len(prefix):]) + } + } + + sort.Strings(ids) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*AgentAliasSummary, 0, len(ids)) + + for _, id := range ids { + al := b.agentAliases[aliasKey(agentID, id)] + out = append(out, &AgentAliasSummary{ + AgentAliasID: al.AgentAliasID, + AgentAliasName: al.AgentAliasName, + AgentAliasStatus: al.AgentAliasStatus, + Description: al.Description, + }) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Agent collaborator CRUD +// --------------------------------------------------------------------------- + +// AssociateAgentCollaborator creates a collaborator association. +func (b *InMemoryBackend) AssociateAgentCollaborator( + _ context.Context, agentID, agentVersion string, cfg CollaboratorConfig, +) (*AgentCollaborator, error) { + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.agents[agentID]; !ok { + return nil, fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + id := b.nextID("collab", &b.collabCounter) + + if b.agentCollaborators[agentID+"/"+agentVersion] == nil { + b.agentCollaborators[agentID+"/"+agentVersion] = make(map[string]*AgentCollaborator) + } + + now := time.Now().UTC() + c := &AgentCollaborator{ + AgentID: agentID, + AgentVersion: agentVersion, + CollaboratorID: id, + CollaboratorName: cfg.CollaboratorName, + CollaborationInstruction: cfg.CollaborationInstruction, + RelayConversationHistory: cfg.RelayConversationHistory, + AgentDescriptor: cfg.AgentDescriptor, + CollaboratorStatus: collabEnabled, + CreatedAt: now, + UpdatedAt: now, + } + + b.agentCollaborators[agentID+"/"+agentVersion][id] = c + + return collabCopy(c), nil +} + +// GetAgentCollaborator returns a collaborator by ID. +func (b *InMemoryBackend) GetAgentCollaborator( + _ context.Context, agentID, agentVersion, collaboratorID string, +) (*AgentCollaborator, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + group, ok := b.agentCollaborators[agentID+"/"+agentVersion] + if !ok { + return nil, fmt.Errorf("%w: collaborator %q not found", ErrNotFound, collaboratorID) + } + + c, ok := group[collaboratorID] + if !ok { + return nil, fmt.Errorf("%w: collaborator %q not found", ErrNotFound, collaboratorID) + } + + return collabCopy(c), nil +} + +// UpdateAgentCollaborator updates a collaborator. +func (b *InMemoryBackend) UpdateAgentCollaborator( + _ context.Context, agentID, agentVersion, collaboratorID string, cfg CollaboratorConfig, +) (*AgentCollaborator, error) { + b.mu.Lock() + defer b.mu.Unlock() + + group, ok := b.agentCollaborators[agentID+"/"+agentVersion] + if !ok { + return nil, fmt.Errorf("%w: collaborator %q not found", ErrNotFound, collaboratorID) + } + + c, ok := group[collaboratorID] + if !ok { + return nil, fmt.Errorf("%w: collaborator %q not found", ErrNotFound, collaboratorID) + } + + if cfg.CollaboratorName != "" { + c.CollaboratorName = cfg.CollaboratorName + } + + if cfg.CollaborationInstruction != "" { + c.CollaborationInstruction = cfg.CollaborationInstruction + } + + if cfg.RelayConversationHistory != "" { + c.RelayConversationHistory = cfg.RelayConversationHistory + } + + if cfg.AgentDescriptor != nil { + c.AgentDescriptor = cfg.AgentDescriptor + } + + c.UpdatedAt = time.Now().UTC() + + return collabCopy(c), nil +} + +// DisassociateAgentCollaborator removes a collaborator. +func (b *InMemoryBackend) DisassociateAgentCollaborator( + _ context.Context, agentID, agentVersion, collaboratorID string, +) error { + b.mu.Lock() + defer b.mu.Unlock() + + group, ok := b.agentCollaborators[agentID+"/"+agentVersion] + if !ok { + return fmt.Errorf("%w: collaborator %q not found", ErrNotFound, collaboratorID) + } + + if _, exists := group[collaboratorID]; !exists { + return fmt.Errorf("%w: collaborator %q not found", ErrNotFound, collaboratorID) + } + + delete(group, collaboratorID) + + return nil +} + +// ListAgentCollaborators returns paginated collaborators. +func (b *InMemoryBackend) ListAgentCollaborators( + _ context.Context, agentID, agentVersion string, maxResults int, nextToken string, +) ([]*AgentCollaborator, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + group := b.agentCollaborators[agentID+"/"+agentVersion] + + ids := sortedKeys(group) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*AgentCollaborator, 0, len(ids)) + + for _, id := range ids { + out = append(out, collabCopy(group[id])) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Knowledge base CRUD +// --------------------------------------------------------------------------- + +// CreateKnowledgeBase creates a new knowledge base. +func (b *InMemoryBackend) CreateKnowledgeBase( + ctx context.Context, cfg KnowledgeBaseConfig, +) (*KnowledgeBase, error) { + if cfg.Name == "" { + return nil, fmt.Errorf("%w: name is required", ErrValidation) + } + + region := ctxRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + if _, exists := b.kbsByName[cfg.Name]; exists { + return nil, fmt.Errorf("%w: knowledge base %q already exists", ErrAlreadyExists, cfg.Name) + } + + id := b.nextID("kb", &b.kbCounter) + now := time.Now().UTC() + + kb := &KnowledgeBase{ + KnowledgeBaseID: id, + KnowledgeBaseARN: b.buildKBARN(region, id), + Name: cfg.Name, + Status: kbStatusActive, + Description: cfg.Description, + RoleARN: cfg.RoleARN, + KBConfiguration: cfg.KBConfiguration, + StorageConfiguration: cfg.StorageConfiguration, + Tags: maps.Clone(cfg.Tags), + CreatedAt: now, + UpdatedAt: now, + } + + b.knowledgeBases[id] = kb + b.kbsByName[cfg.Name] = id + b.tags[kb.KnowledgeBaseARN] = maps.Clone(cfg.Tags) + + return kbCopy(kb), nil +} + +// GetKnowledgeBase returns a knowledge base. +func (b *InMemoryBackend) GetKnowledgeBase(_ context.Context, kbID string) (*KnowledgeBase, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + kb, ok := b.knowledgeBases[kbID] + if !ok { + return nil, fmt.Errorf("%w: knowledge base %q not found", ErrNotFound, kbID) + } + + return kbCopy(kb), nil +} + +// UpdateKnowledgeBase updates a knowledge base. +func (b *InMemoryBackend) UpdateKnowledgeBase( + _ context.Context, kbID string, cfg KnowledgeBaseConfig, +) (*KnowledgeBase, error) { + b.mu.Lock() + defer b.mu.Unlock() + + kb, ok := b.knowledgeBases[kbID] + if !ok { + return nil, fmt.Errorf("%w: knowledge base %q not found", ErrNotFound, kbID) + } + + if cfg.Name != "" { + kb.Name = cfg.Name + } + + if cfg.Description != "" { + kb.Description = cfg.Description + } + + if cfg.RoleARN != "" { + kb.RoleARN = cfg.RoleARN + } + + if cfg.KBConfiguration != nil { + kb.KBConfiguration = cfg.KBConfiguration + } + + if cfg.StorageConfiguration != nil { + kb.StorageConfiguration = cfg.StorageConfiguration + } + + kb.UpdatedAt = time.Now().UTC() + + return kbCopy(kb), nil +} + +// DeleteKnowledgeBase deletes a knowledge base. +func (b *InMemoryBackend) DeleteKnowledgeBase(_ context.Context, kbID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + kb, ok := b.knowledgeBases[kbID] + if !ok { + return fmt.Errorf("%w: knowledge base %q not found", ErrNotFound, kbID) + } + + delete(b.kbsByName, kb.Name) + delete(b.knowledgeBases, kbID) + + return nil +} + +// ListKnowledgeBases returns paginated knowledge base summaries. +func (b *InMemoryBackend) ListKnowledgeBases( + _ context.Context, maxResults int, nextToken string, +) ([]*KnowledgeBaseSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + ids := sortedKeys(b.knowledgeBases) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*KnowledgeBaseSummary, 0, len(ids)) + + for _, id := range ids { + kb := b.knowledgeBases[id] + out = append(out, &KnowledgeBaseSummary{ + KnowledgeBaseID: kb.KnowledgeBaseID, + Name: kb.Name, + Status: kb.Status, + Description: kb.Description, + UpdatedAt: kb.UpdatedAt, + }) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Agent–knowledge base association CRUD +// --------------------------------------------------------------------------- + +func agKBKey(agentID, agentVersion, kbID string) string { + return agentID + "/" + agentVersion + "/" + kbID +} + +// AssociateAgentKnowledgeBase creates an agent–KB association. +func (b *InMemoryBackend) AssociateAgentKnowledgeBase( + _ context.Context, agentID, agentVersion, kbID, description, kbState string, +) (*AgentKnowledgeBase, error) { + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.agents[agentID]; !ok { + return nil, fmt.Errorf("%w: agent %q not found", ErrNotFound, agentID) + } + + if _, ok := b.knowledgeBases[kbID]; !ok { + return nil, fmt.Errorf("%w: knowledge base %q not found", ErrNotFound, kbID) + } + + now := time.Now().UTC() + state := "ENABLED" + + if kbState != "" { + state = kbState + } + + assoc := &AgentKnowledgeBase{ + AgentID: agentID, + AgentVersion: agentVersion, + KnowledgeBaseID: kbID, + KBState: state, + Description: description, + CreatedAt: now, + UpdatedAt: now, + } + + b.agentKBAssocs[agKBKey(agentID, agentVersion, kbID)] = assoc + + return agKBCopy(assoc), nil +} + +// GetAgentKnowledgeBase returns an agent–KB association. +func (b *InMemoryBackend) GetAgentKnowledgeBase( + _ context.Context, agentID, agentVersion, kbID string, +) (*AgentKnowledgeBase, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + assoc, ok := b.agentKBAssocs[agKBKey(agentID, agentVersion, kbID)] + if !ok { + return nil, fmt.Errorf("%w: association for kb %q not found", ErrNotFound, kbID) + } + + return agKBCopy(assoc), nil +} + +// UpdateAgentKnowledgeBase updates an agent–KB association. +func (b *InMemoryBackend) UpdateAgentKnowledgeBase( + _ context.Context, agentID, agentVersion, kbID, description, kbState string, +) (*AgentKnowledgeBase, error) { + b.mu.Lock() + defer b.mu.Unlock() + + key := agKBKey(agentID, agentVersion, kbID) + + assoc, ok := b.agentKBAssocs[key] + if !ok { + return nil, fmt.Errorf("%w: association for kb %q not found", ErrNotFound, kbID) + } + + if description != "" { + assoc.Description = description + } + + if kbState != "" { + assoc.KBState = kbState + } + + assoc.UpdatedAt = time.Now().UTC() + + return agKBCopy(assoc), nil +} + +// DisassociateAgentKnowledgeBase removes an agent–KB association. +func (b *InMemoryBackend) DisassociateAgentKnowledgeBase( + _ context.Context, agentID, agentVersion, kbID string, +) error { + b.mu.Lock() + defer b.mu.Unlock() + + key := agKBKey(agentID, agentVersion, kbID) + + if _, ok := b.agentKBAssocs[key]; !ok { + return fmt.Errorf("%w: association for kb %q not found", ErrNotFound, kbID) + } + + delete(b.agentKBAssocs, key) + + return nil +} + +// ListAgentKnowledgeBases returns paginated agent–KB associations. +func (b *InMemoryBackend) ListAgentKnowledgeBases( + _ context.Context, agentID, agentVersion string, maxResults int, nextToken string, +) ([]*AgentKnowledgeBase, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + prefix := agentID + "/" + agentVersion + "/" + + var ids []string + + for k := range b.agentKBAssocs { + if len(k) > len(prefix) && k[:len(prefix)] == prefix { + ids = append(ids, k[len(prefix):]) + } + } + + sort.Strings(ids) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*AgentKnowledgeBase, 0, len(ids)) + + for _, id := range ids { + out = append(out, agKBCopy(b.agentKBAssocs[agKBKey(agentID, agentVersion, id)])) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Data source CRUD +// --------------------------------------------------------------------------- + +func dsKey(kbID, dsID string) string { return kbID + "/" + dsID } + +// CreateDataSource creates a data source in a knowledge base. +func (b *InMemoryBackend) CreateDataSource( + _ context.Context, kbID string, cfg DataSourceConfig, +) (*DataSource, error) { + if cfg.Name == "" { + return nil, fmt.Errorf("%w: name is required", ErrValidation) + } + + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.knowledgeBases[kbID]; !ok { + return nil, fmt.Errorf("%w: knowledge base %q not found", ErrNotFound, kbID) + } + + id := b.nextID("ds", &b.dsCounter) + now := time.Now().UTC() + + ds := &DataSource{ + DataSourceID: id, + KnowledgeBaseID: kbID, + Name: cfg.Name, + DataSourceStatus: dsStatusAvailable, + Description: cfg.Description, + DataDeletionPolicy: cfg.DataDeletionPolicy, + DataSourceConfiguration: cfg.DataSourceConfiguration, + VectorIngestionConfig: cfg.VectorIngestionConfig, + CreatedAt: now, + UpdatedAt: now, + } + + b.dataSources[dsKey(kbID, id)] = ds + + return dsCopy(ds), nil +} + +// GetDataSource returns a data source. +func (b *InMemoryBackend) GetDataSource(_ context.Context, kbID, dsID string) (*DataSource, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + ds, ok := b.dataSources[dsKey(kbID, dsID)] + if !ok { + return nil, fmt.Errorf("%w: data source %q not found", ErrNotFound, dsID) + } + + return dsCopy(ds), nil +} + +// UpdateDataSource updates a data source. +func (b *InMemoryBackend) UpdateDataSource( + _ context.Context, kbID, dsID string, cfg DataSourceConfig, +) (*DataSource, error) { + b.mu.Lock() + defer b.mu.Unlock() + + ds, ok := b.dataSources[dsKey(kbID, dsID)] + if !ok { + return nil, fmt.Errorf("%w: data source %q not found", ErrNotFound, dsID) + } + + if cfg.Name != "" { + ds.Name = cfg.Name + } + + if cfg.Description != "" { + ds.Description = cfg.Description + } + + if cfg.DataDeletionPolicy != "" { + ds.DataDeletionPolicy = cfg.DataDeletionPolicy + } + + if cfg.DataSourceConfiguration != nil { + ds.DataSourceConfiguration = cfg.DataSourceConfiguration + } + + if cfg.VectorIngestionConfig != nil { + ds.VectorIngestionConfig = cfg.VectorIngestionConfig + } + + ds.UpdatedAt = time.Now().UTC() + + return dsCopy(ds), nil +} + +// DeleteDataSource deletes a data source. +func (b *InMemoryBackend) DeleteDataSource(_ context.Context, kbID, dsID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.dataSources[dsKey(kbID, dsID)]; !ok { + return fmt.Errorf("%w: data source %q not found", ErrNotFound, dsID) + } + + delete(b.dataSources, dsKey(kbID, dsID)) + + return nil +} + +// ListDataSources returns paginated data source summaries. +func (b *InMemoryBackend) ListDataSources( + _ context.Context, kbID string, maxResults int, nextToken string, +) ([]*DataSourceSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + prefix := kbID + "/" + + var ids []string + + for k := range b.dataSources { + if len(k) > len(prefix) && k[:len(prefix)] == prefix { + ids = append(ids, k[len(prefix):]) + } + } + + sort.Strings(ids) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*DataSourceSummary, 0, len(ids)) + + for _, id := range ids { + ds := b.dataSources[dsKey(kbID, id)] + out = append(out, &DataSourceSummary{ + DataSourceID: ds.DataSourceID, + KnowledgeBaseID: ds.KnowledgeBaseID, + Name: ds.Name, + DataSourceStatus: ds.DataSourceStatus, + Description: ds.Description, + UpdatedAt: ds.UpdatedAt, + }) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Ingestion job CRUD +// --------------------------------------------------------------------------- + +func jobKey(kbID, dsID, jobID string) string { return kbID + "/" + dsID + "/" + jobID } + +// StartIngestionJob creates and starts a new ingestion job. +func (b *InMemoryBackend) StartIngestionJob( + _ context.Context, kbID, dsID, description string, +) (*IngestionJob, error) { + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.dataSources[dsKey(kbID, dsID)]; !ok { + return nil, fmt.Errorf("%w: data source %q not found", ErrNotFound, dsID) + } + + id := b.nextID("job", &b.jobCounter) + now := time.Now().UTC() + + job := &IngestionJob{ + IngestionJobID: id, + KnowledgeBaseID: kbID, + DataSourceID: dsID, + Status: ingestionJobComplete, + Description: description, + StartedAt: now, + UpdatedAt: now, + } + + b.ingestionJobs[jobKey(kbID, dsID, id)] = job + + return jobCopy(job), nil +} + +// GetIngestionJob returns an ingestion job. +func (b *InMemoryBackend) GetIngestionJob( + _ context.Context, kbID, dsID, jobID string, +) (*IngestionJob, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + job, ok := b.ingestionJobs[jobKey(kbID, dsID, jobID)] + if !ok { + return nil, fmt.Errorf("%w: ingestion job %q not found", ErrNotFound, jobID) + } + + return jobCopy(job), nil +} + +// StopIngestionJob stops an ingestion job. +func (b *InMemoryBackend) StopIngestionJob( + _ context.Context, kbID, dsID, jobID string, +) (*IngestionJob, error) { + b.mu.Lock() + defer b.mu.Unlock() + + job, ok := b.ingestionJobs[jobKey(kbID, dsID, jobID)] + if !ok { + return nil, fmt.Errorf("%w: ingestion job %q not found", ErrNotFound, jobID) + } + + job.Status = "STOPPED" + job.UpdatedAt = time.Now().UTC() + + return jobCopy(job), nil +} + +// ListIngestionJobs returns paginated ingestion job summaries. +func (b *InMemoryBackend) ListIngestionJobs( + _ context.Context, kbID, dsID string, maxResults int, nextToken string, +) ([]*IngestionJob, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + prefix := kbID + "/" + dsID + "/" + + var ids []string + + for k := range b.ingestionJobs { + if len(k) > len(prefix) && k[:len(prefix)] == prefix { + ids = append(ids, k[len(prefix):]) + } + } + + sort.Strings(ids) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*IngestionJob, 0, len(ids)) + + for _, id := range ids { + out = append(out, jobCopy(b.ingestionJobs[jobKey(kbID, dsID, id)])) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Flow CRUD +// --------------------------------------------------------------------------- + +// CreateFlow creates a new flow. +func (b *InMemoryBackend) CreateFlow(ctx context.Context, cfg FlowConfig) (*Flow, error) { + if cfg.Name == "" { + return nil, fmt.Errorf("%w: name is required", ErrValidation) + } + + region := ctxRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + if _, exists := b.flowsByName[cfg.Name]; exists { + return nil, fmt.Errorf("%w: flow %q already exists", ErrAlreadyExists, cfg.Name) + } + + id := b.nextID("flow", &b.flowCounter) + now := time.Now().UTC() + + f := &Flow{ + FlowID: id, + FlowARN: b.buildFlowARN(region, id), + Name: cfg.Name, + Status: flowStatusNotPrepared, + Description: cfg.Description, + RoleARN: cfg.RoleARN, + Definition: cfg.Definition, + Tags: maps.Clone(cfg.Tags), + Version: "DRAFT", + CreatedAt: now, + UpdatedAt: now, + } + + b.flows[id] = f + b.flowsByName[cfg.Name] = id + b.tags[f.FlowARN] = maps.Clone(cfg.Tags) + + return flowCopy(f), nil +} + +// GetFlow returns a flow. +func (b *InMemoryBackend) GetFlow(_ context.Context, flowID string) (*Flow, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + f, ok := b.flows[flowID] + if !ok { + return nil, fmt.Errorf("%w: flow %q not found", ErrNotFound, flowID) + } + + return flowCopy(f), nil +} + +// UpdateFlow updates a flow. +func (b *InMemoryBackend) UpdateFlow(_ context.Context, flowID string, cfg FlowConfig) (*Flow, error) { + b.mu.Lock() + defer b.mu.Unlock() + + f, ok := b.flows[flowID] + if !ok { + return nil, fmt.Errorf("%w: flow %q not found", ErrNotFound, flowID) + } + + applyFlowConfig(f, cfg) + f.UpdatedAt = time.Now().UTC() + + return flowCopy(f), nil +} + +func applyFlowConfig(f *Flow, cfg FlowConfig) { + if cfg.Name != "" { + f.Name = cfg.Name + } + + if cfg.Description != "" { + f.Description = cfg.Description + } + + if cfg.RoleARN != "" { + f.RoleARN = cfg.RoleARN + } + + if cfg.Definition != nil { + f.Definition = cfg.Definition + } + + if cfg.Tags != nil { + f.Tags = maps.Clone(cfg.Tags) + } +} + +// DeleteFlow deletes a flow. +func (b *InMemoryBackend) DeleteFlow(_ context.Context, flowID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + f, ok := b.flows[flowID] + if !ok { + return fmt.Errorf("%w: flow %q not found", ErrNotFound, flowID) + } + + delete(b.flowsByName, f.Name) + delete(b.flows, flowID) + delete(b.flowVersions, flowID) + delete(b.flowVersionCtrs, flowID) + + return nil +} + +// ListFlows returns paginated flow summaries. +func (b *InMemoryBackend) ListFlows( + _ context.Context, maxResults int, nextToken string, +) ([]*FlowSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + ids := sortedKeys(b.flows) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*FlowSummary, 0, len(ids)) + + for _, id := range ids { + f := b.flows[id] + out = append(out, &FlowSummary{ + FlowID: f.FlowID, + Name: f.Name, + Status: f.Status, + Description: f.Description, + Version: f.Version, + UpdatedAt: f.UpdatedAt, + }) + } + + return out, outToken, nil +} + +// PrepareFlow transitions a flow to prepared status. +func (b *InMemoryBackend) PrepareFlow(_ context.Context, flowID string) (*Flow, error) { + b.mu.Lock() + defer b.mu.Unlock() + + f, ok := b.flows[flowID] + if !ok { + return nil, fmt.Errorf("%w: flow %q not found", ErrNotFound, flowID) + } + + f.Status = flowStatusPrepared + f.UpdatedAt = time.Now().UTC() + + return flowCopy(f), nil +} + +// ValidateFlowDefinition validates a flow definition (stub - always passes). +func (b *InMemoryBackend) ValidateFlowDefinition( + _ context.Context, _ map[string]any, +) ([]FlowValidationError, error) { + return []FlowValidationError{}, nil +} + +// --------------------------------------------------------------------------- +// Flow version CRUD +// --------------------------------------------------------------------------- + +// CreateFlowVersion creates a numbered snapshot of a flow. +func (b *InMemoryBackend) CreateFlowVersion( + _ context.Context, flowID, description string, +) (*FlowVersion, error) { + b.mu.Lock() + defer b.mu.Unlock() + + f, ok := b.flows[flowID] + if !ok { + return nil, fmt.Errorf("%w: flow %q not found", ErrNotFound, flowID) + } + + b.flowVersionCtrs[flowID]++ + vNum := b.flowVersionCtrs[flowID] + version := strconv.Itoa(vNum) + + if b.flowVersions[flowID] == nil { + b.flowVersions[flowID] = make(map[string]*FlowVersion) + } + + fv := &FlowVersion{ + FlowID: flowID, + FlowARN: f.FlowARN, + Name: f.Name, + Version: version, + Status: flowStatusPrepared, + Definition: f.Definition, + Description: description, + CreatedAt: time.Now().UTC(), + } + + b.flowVersions[flowID][version] = fv + + return flowVersionCopy(fv), nil +} + +// GetFlowVersion returns a flow version. +func (b *InMemoryBackend) GetFlowVersion( + _ context.Context, flowID, flowVersion string, +) (*FlowVersion, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + versions, ok := b.flowVersions[flowID] + if !ok { + return nil, fmt.Errorf("%w: flow %q not found", ErrNotFound, flowID) + } + + fv, ok := versions[flowVersion] + if !ok { + return nil, fmt.Errorf("%w: flow version %q not found", ErrNotFound, flowVersion) + } + + return flowVersionCopy(fv), nil +} + +// DeleteFlowVersion deletes a flow version. +func (b *InMemoryBackend) DeleteFlowVersion(_ context.Context, flowID, flowVersion string) error { + b.mu.Lock() + defer b.mu.Unlock() + + versions, ok := b.flowVersions[flowID] + if !ok { + return fmt.Errorf("%w: flow %q not found", ErrNotFound, flowID) + } + + if _, exists := versions[flowVersion]; !exists { + return fmt.Errorf("%w: flow version %q not found", ErrNotFound, flowVersion) + } + + delete(versions, flowVersion) + + return nil +} + +// ListFlowVersions returns paginated flow version summaries. +func (b *InMemoryBackend) ListFlowVersions( + _ context.Context, flowID string, maxResults int, nextToken string, +) ([]*FlowVersionSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + if _, ok := b.flows[flowID]; !ok { + return nil, "", fmt.Errorf("%w: flow %q not found", ErrNotFound, flowID) + } + + versions := b.flowVersions[flowID] + keys := sortedKeys(versions) + keys, outToken := paginate(keys, nextToken, maxResults) + + out := make([]*FlowVersionSummary, 0, len(keys)) + + for _, k := range keys { + fv := versions[k] + out = append(out, &FlowVersionSummary{ + FlowID: fv.FlowID, + Arn: fv.FlowARN, + Name: fv.Name, + Version: fv.Version, + Status: fv.Status, + Description: fv.Description, + CreatedAt: fv.CreatedAt, + }) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Flow alias CRUD +// --------------------------------------------------------------------------- + +func flowAliasKey(flowID, aliasID string) string { return flowID + "/" + aliasID } + +// CreateFlowAlias creates a flow alias. +func (b *InMemoryBackend) CreateFlowAlias( + ctx context.Context, flowID string, cfg FlowAliasConfig, +) (*FlowAlias, error) { + if cfg.Name == "" { + return nil, fmt.Errorf("%w: name is required", ErrValidation) + } + + region := ctxRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.flows[flowID]; !ok { + return nil, fmt.Errorf("%w: flow %q not found", ErrNotFound, flowID) + } + + id := b.nextID("falias", &b.flowAliasCounter) + now := time.Now().UTC() + + al := &FlowAlias{ + AliasID: id, + AliasARN: b.buildFlowAliasARN(region, flowID, id), + FlowID: flowID, + Name: cfg.Name, + Description: cfg.Description, + RoutingConfiguration: cfg.RoutingConfiguration, + Tags: maps.Clone(cfg.Tags), + CreatedAt: now, + UpdatedAt: now, + } + + b.flowAliases[flowAliasKey(flowID, id)] = al + + return flowAliasCopy(al), nil +} + +// GetFlowAlias returns a flow alias. +func (b *InMemoryBackend) GetFlowAlias(_ context.Context, flowID, aliasID string) (*FlowAlias, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + al, ok := b.flowAliases[flowAliasKey(flowID, aliasID)] + if !ok { + return nil, fmt.Errorf("%w: flow alias %q not found", ErrNotFound, aliasID) + } + + return flowAliasCopy(al), nil +} + +// UpdateFlowAlias updates a flow alias. +func (b *InMemoryBackend) UpdateFlowAlias( + _ context.Context, flowID, aliasID string, cfg FlowAliasConfig, +) (*FlowAlias, error) { + b.mu.Lock() + defer b.mu.Unlock() + + al, ok := b.flowAliases[flowAliasKey(flowID, aliasID)] + if !ok { + return nil, fmt.Errorf("%w: flow alias %q not found", ErrNotFound, aliasID) + } + + if cfg.Name != "" { + al.Name = cfg.Name + } + + if cfg.Description != "" { + al.Description = cfg.Description + } + + if cfg.RoutingConfiguration != nil { + al.RoutingConfiguration = cfg.RoutingConfiguration + } + + if cfg.Tags != nil { + al.Tags = maps.Clone(cfg.Tags) + } + + al.UpdatedAt = time.Now().UTC() + + return flowAliasCopy(al), nil +} + +// DeleteFlowAlias deletes a flow alias. +func (b *InMemoryBackend) DeleteFlowAlias(_ context.Context, flowID, aliasID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.flowAliases[flowAliasKey(flowID, aliasID)]; !ok { + return fmt.Errorf("%w: flow alias %q not found", ErrNotFound, aliasID) + } + + delete(b.flowAliases, flowAliasKey(flowID, aliasID)) + + return nil +} + +// ListFlowAliases returns paginated flow alias summaries. +func (b *InMemoryBackend) ListFlowAliases( + _ context.Context, flowID string, maxResults int, nextToken string, +) ([]*FlowAliasSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + prefix := flowID + "/" + + var ids []string + + for k := range b.flowAliases { + if len(k) > len(prefix) && k[:len(prefix)] == prefix { + ids = append(ids, k[len(prefix):]) + } + } + + sort.Strings(ids) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*FlowAliasSummary, 0, len(ids)) + + for _, id := range ids { + al := b.flowAliases[flowAliasKey(flowID, id)] + out = append(out, &FlowAliasSummary{ + AliasID: al.AliasID, + AliasARN: al.AliasARN, + FlowID: al.FlowID, + Name: al.Name, + Description: al.Description, + CreatedAt: al.CreatedAt, + UpdatedAt: al.UpdatedAt, + }) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Prompt CRUD +// --------------------------------------------------------------------------- + +// CreatePrompt creates a new prompt. +func (b *InMemoryBackend) CreatePrompt(ctx context.Context, cfg PromptConfig) (*Prompt, error) { + if cfg.Name == "" { + return nil, fmt.Errorf("%w: name is required", ErrValidation) + } + + region := ctxRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + if _, exists := b.promptsByName[cfg.Name]; exists { + return nil, fmt.Errorf("%w: prompt %q already exists", ErrAlreadyExists, cfg.Name) + } + + id := b.nextID("prompt", &b.promptCounter) + now := time.Now().UTC() + + p := &Prompt{ + PromptID: id, + PromptARN: b.buildPromptARN(region, id), + Name: cfg.Name, + Description: cfg.Description, + DefaultVariant: cfg.DefaultVariant, + Variants: cfg.Variants, + Tags: maps.Clone(cfg.Tags), + Version: "DRAFT", + CreatedAt: now, + UpdatedAt: now, + } + + b.prompts[id] = p + b.promptsByName[cfg.Name] = id + b.tags[p.PromptARN] = maps.Clone(cfg.Tags) + + return promptCopy(p), nil +} + +// GetPrompt returns a prompt. +func (b *InMemoryBackend) GetPrompt(_ context.Context, promptID string) (*Prompt, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + p, ok := b.prompts[promptID] + if !ok { + return nil, fmt.Errorf("%w: prompt %q not found", ErrNotFound, promptID) + } + + return promptCopy(p), nil +} + +// UpdatePrompt updates a prompt. +func (b *InMemoryBackend) UpdatePrompt( + _ context.Context, promptID string, cfg PromptConfig, +) (*Prompt, error) { + b.mu.Lock() + defer b.mu.Unlock() + + p, ok := b.prompts[promptID] + if !ok { + return nil, fmt.Errorf("%w: prompt %q not found", ErrNotFound, promptID) + } + + if cfg.Name != "" { + p.Name = cfg.Name + } + + if cfg.Description != "" { + p.Description = cfg.Description + } + + if cfg.DefaultVariant != "" { + p.DefaultVariant = cfg.DefaultVariant + } + + if cfg.Variants != nil { + p.Variants = cfg.Variants + } + + if cfg.Tags != nil { + p.Tags = maps.Clone(cfg.Tags) + } + + p.UpdatedAt = time.Now().UTC() + + return promptCopy(p), nil +} + +// DeletePrompt deletes a prompt. +func (b *InMemoryBackend) DeletePrompt(_ context.Context, promptID string) error { + b.mu.Lock() + defer b.mu.Unlock() + + p, ok := b.prompts[promptID] + if !ok { + return fmt.Errorf("%w: prompt %q not found", ErrNotFound, promptID) + } + + delete(b.promptsByName, p.Name) + delete(b.prompts, promptID) + delete(b.promptVersions, promptID) + delete(b.promptVersionCtrs, promptID) + + return nil +} + +// ListPrompts returns paginated prompt summaries. +func (b *InMemoryBackend) ListPrompts( + _ context.Context, maxResults int, nextToken string, +) ([]*PromptSummary, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + ids := sortedKeys(b.prompts) + ids, outToken := paginate(ids, nextToken, maxResults) + + out := make([]*PromptSummary, 0, len(ids)) + + for _, id := range ids { + p := b.prompts[id] + out = append(out, &PromptSummary{ + PromptID: p.PromptID, + PromptARN: p.PromptARN, + Name: p.Name, + Description: p.Description, + Version: p.Version, + CreatedAt: p.CreatedAt, + UpdatedAt: p.UpdatedAt, + }) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Prompt version CRUD +// --------------------------------------------------------------------------- + +// CreatePromptVersion creates a versioned snapshot of a prompt. +func (b *InMemoryBackend) CreatePromptVersion( + _ context.Context, promptID, description string, +) (*PromptVersion, error) { + b.mu.Lock() + defer b.mu.Unlock() + + p, ok := b.prompts[promptID] + if !ok { + return nil, fmt.Errorf("%w: prompt %q not found", ErrNotFound, promptID) + } + + b.promptVersionCtrs[promptID]++ + vNum := b.promptVersionCtrs[promptID] + version := strconv.Itoa(vNum) + + if b.promptVersions[promptID] == nil { + b.promptVersions[promptID] = make(map[string]*PromptVersion) + } + + pv := &PromptVersion{ + PromptID: promptID, + PromptARN: p.PromptARN, + Name: p.Name, + Version: version, + Variants: p.Variants, + Description: description, + CreatedAt: time.Now().UTC(), + } + + b.promptVersions[promptID][version] = pv + + return promptVersionCopy(pv), nil +} + +// GetPromptVersion returns a specific prompt version. +func (b *InMemoryBackend) GetPromptVersion( + _ context.Context, promptID, version string, +) (*PromptVersion, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + versions, ok := b.promptVersions[promptID] + if !ok { + return nil, fmt.Errorf("%w: prompt %q not found", ErrNotFound, promptID) + } + + pv, ok := versions[version] + if !ok { + return nil, fmt.Errorf("%w: prompt version %q not found", ErrNotFound, version) + } + + return promptVersionCopy(pv), nil +} + +// DeletePromptVersion deletes a prompt version. +func (b *InMemoryBackend) DeletePromptVersion( + _ context.Context, promptID, version string, +) error { + b.mu.Lock() + defer b.mu.Unlock() + + versions, ok := b.promptVersions[promptID] + if !ok { + return fmt.Errorf("%w: prompt %q not found", ErrNotFound, promptID) + } + + if _, exists := versions[version]; !exists { + return fmt.Errorf("%w: prompt version %q not found", ErrNotFound, version) + } + + delete(versions, version) + + return nil +} + +// --------------------------------------------------------------------------- +// Knowledge base document operations +// --------------------------------------------------------------------------- + +func kbDocKey(kbID, dsID, docID string) string { return kbID + "/" + dsID + "/" + docID } + +// IngestKnowledgeBaseDocuments ingests documents into a knowledge base data source. +func (b *InMemoryBackend) IngestKnowledgeBaseDocuments( + _ context.Context, kbID, dsID string, docs []KBDocument, +) ([]KBDocumentDetail, error) { + b.mu.Lock() + defer b.mu.Unlock() + + if _, ok := b.dataSources[dsKey(kbID, dsID)]; !ok { + return nil, fmt.Errorf("%w: data source %q not found", ErrNotFound, dsID) + } + + out := make([]KBDocumentDetail, 0, len(docs)) + + for _, doc := range docs { + detail := KBDocumentDetail{ + DocumentID: doc.DocID, + KnowledgeBaseID: kbID, + DataSourceID: dsID, + Status: docStatusIndexed, + } + b.kbDocuments[kbDocKey(kbID, dsID, doc.DocID)] = &detail + out = append(out, detail) + } + + return out, nil +} + +// GetKnowledgeBaseDocuments retrieves document details. +func (b *InMemoryBackend) GetKnowledgeBaseDocuments( + _ context.Context, kbID, dsID string, docIDs []string, +) ([]KBDocumentDetail, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + out := make([]KBDocumentDetail, 0, len(docIDs)) + + for _, id := range docIDs { + detail, ok := b.kbDocuments[kbDocKey(kbID, dsID, id)] + if !ok { + return nil, fmt.Errorf("%w: document %q not found", ErrNotFound, id) + } + + out = append(out, *detail) + } + + return out, nil +} + +// DeleteKnowledgeBaseDocuments deletes documents from a knowledge base data source. +func (b *InMemoryBackend) DeleteKnowledgeBaseDocuments( + _ context.Context, kbID, dsID string, docIDs []string, +) ([]KBDocumentDetail, error) { + b.mu.Lock() + defer b.mu.Unlock() + + out := make([]KBDocumentDetail, 0, len(docIDs)) + + for _, id := range docIDs { + key := kbDocKey(kbID, dsID, id) + + detail, ok := b.kbDocuments[key] + if !ok { + out = append(out, KBDocumentDetail{ + DocumentID: id, + KnowledgeBaseID: kbID, + DataSourceID: dsID, + Status: "NOT_FOUND", + }) + + continue + } + + delete(b.kbDocuments, key) + + d := *detail + d.Status = "DELETED" + out = append(out, d) + } + + return out, nil +} + +// ListKnowledgeBaseDocuments returns paginated document details. +func (b *InMemoryBackend) ListKnowledgeBaseDocuments( + _ context.Context, kbID, dsID string, maxResults int, nextToken string, +) ([]KBDocumentDetail, string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + prefix := kbID + "/" + dsID + "/" + + var keys []string + + for k := range b.kbDocuments { + if len(k) > len(prefix) && k[:len(prefix)] == prefix { + keys = append(keys, k) + } + } + + sort.Strings(keys) + keys, outToken := paginate(keys, nextToken, maxResults) + + out := make([]KBDocumentDetail, 0, len(keys)) + + for _, k := range keys { + out = append(out, *b.kbDocuments[k]) + } + + return out, outToken, nil +} + +// --------------------------------------------------------------------------- +// Tagging operations +// --------------------------------------------------------------------------- + +// ListTagsForResource returns tags for a resource ARN. +func (b *InMemoryBackend) ListTagsForResource( + _ context.Context, resourceARN string, +) (map[string]string, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + t, ok := b.tags[resourceARN] + if !ok { + return map[string]string{}, nil + } + + return maps.Clone(t), nil +} + +// TagResource adds or updates tags on a resource. +func (b *InMemoryBackend) TagResource( + _ context.Context, resourceARN string, tags map[string]string, +) error { + b.mu.Lock() + defer b.mu.Unlock() + + if b.tags[resourceARN] == nil { + b.tags[resourceARN] = make(map[string]string) + } + + maps.Copy(b.tags[resourceARN], tags) + + return nil +} + +// UntagResource removes tags from a resource. +func (b *InMemoryBackend) UntagResource( + _ context.Context, resourceARN string, tagKeys []string, +) error { + b.mu.Lock() + defer b.mu.Unlock() + + t := b.tags[resourceARN] + + for _, k := range tagKeys { + delete(t, k) + } + + return nil +} + +// --------------------------------------------------------------------------- +// Pagination helper +// --------------------------------------------------------------------------- + +const defaultPageSize = 100 + +func paginate(ids []string, nextToken string, maxResults int) ([]string, string) { + start := 0 + + if nextToken != "" { + for i, id := range ids { + if id == nextToken { + start = i + + break + } + } + } + + size := defaultPageSize + + if maxResults > 0 && maxResults < defaultPageSize { + size = maxResults + } + + end := min(start+size, len(ids)) + + page := ids[start:end] + + var outToken string + + if end < len(ids) { + outToken = ids[end] + } + + return page, outToken +} + +func sortedKeys[V any](m map[string]V) []string { + keys := make([]string, 0, len(m)) + + for k := range m { + keys = append(keys, k) + } + + sort.Strings(keys) + + return keys +} + +// --------------------------------------------------------------------------- +// Deep-copy helpers +// --------------------------------------------------------------------------- + +func agentCopy(a *Agent) *Agent { + cp := *a + cp.Tags = maps.Clone(a.Tags) + + return &cp +} + +func agentVersionCopy(av *AgentVersion) *AgentVersion { + cp := *av + + return &cp +} + +func actionGroupCopy(ag *AgentActionGroup) *AgentActionGroup { + cp := *ag + + return &cp +} + +func aliasCopy(al *AgentAlias) *AgentAlias { + cp := *al + cp.Tags = maps.Clone(al.Tags) + + if al.RoutingConfiguration != nil { + cp.RoutingConfiguration = append([]AliasRouting{}, al.RoutingConfiguration...) + } + + return &cp +} + +func collabCopy(c *AgentCollaborator) *AgentCollaborator { + cp := *c + + return &cp +} + +func kbCopy(kb *KnowledgeBase) *KnowledgeBase { + cp := *kb + cp.Tags = maps.Clone(kb.Tags) + + return &cp +} + +func agKBCopy(a *AgentKnowledgeBase) *AgentKnowledgeBase { + cp := *a + + return &cp +} + +func dsCopy(ds *DataSource) *DataSource { + cp := *ds + + return &cp +} + +func jobCopy(j *IngestionJob) *IngestionJob { + cp := *j + + return &cp +} + +func flowCopy(f *Flow) *Flow { + cp := *f + cp.Tags = maps.Clone(f.Tags) + + return &cp +} + +func flowVersionCopy(fv *FlowVersion) *FlowVersion { + cp := *fv + + return &cp +} + +func flowAliasCopy(al *FlowAlias) *FlowAlias { + cp := *al + cp.Tags = maps.Clone(al.Tags) + + if al.RoutingConfiguration != nil { + cp.RoutingConfiguration = append([]FlowAliasRouting{}, al.RoutingConfiguration...) + } + + return &cp +} + +func promptCopy(p *Prompt) *Prompt { + cp := *p + cp.Tags = maps.Clone(p.Tags) + + return &cp +} + +func promptVersionCopy(pv *PromptVersion) *PromptVersion { + cp := *pv + + return &cp +} diff --git a/services/bedrockagent/export_test.go b/services/bedrockagent/export_test.go new file mode 100644 index 000000000..c0b202645 --- /dev/null +++ b/services/bedrockagent/export_test.go @@ -0,0 +1,11 @@ +package bedrockagent + +// Exported for testing. + +func NewTestBackend(region, accountID string) *InMemoryBackend { + return NewInMemoryBackend(region, accountID) +} + +func NewTestHandler(b StorageBackend) *Handler { + return NewHandler(b) +} diff --git a/services/bedrockagent/handler.go b/services/bedrockagent/handler.go new file mode 100644 index 000000000..526ed5c16 --- /dev/null +++ b/services/bedrockagent/handler.go @@ -0,0 +1,2666 @@ +package bedrockagent + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "strings" + + "github.com/labstack/echo/v5" + + "github.com/blackbirdworks/gopherstack/pkgs/awserr" + "github.com/blackbirdworks/gopherstack/pkgs/httputils" + "github.com/blackbirdworks/gopherstack/pkgs/logger" + "github.com/blackbirdworks/gopherstack/pkgs/service" +) + +// --------------------------------------------------------------------------- +// Operation name constants +// --------------------------------------------------------------------------- + +const ( + opCreateAgent = "CreateAgent" + opGetAgent = "GetAgent" + opUpdateAgent = "UpdateAgent" + opDeleteAgent = "DeleteAgent" + opListAgents = "ListAgents" + opPrepareAgent = "PrepareAgent" + opCreateAgentVersion = "CreateAgentVersion" + opGetAgentVersion = "GetAgentVersion" + opDeleteAgentVersion = "DeleteAgentVersion" + opListAgentVersions = "ListAgentVersions" + opCreateAgentActionGroup = "CreateAgentActionGroup" + opGetAgentActionGroup = "GetAgentActionGroup" + opUpdateAgentActionGroup = "UpdateAgentActionGroup" + opDeleteAgentActionGroup = "DeleteAgentActionGroup" + opListAgentActionGroups = "ListAgentActionGroups" + opCreateAgentAlias = "CreateAgentAlias" + opGetAgentAlias = "GetAgentAlias" + opUpdateAgentAlias = "UpdateAgentAlias" + opDeleteAgentAlias = "DeleteAgentAlias" + opListAgentAliases = "ListAgentAliases" + opAssociateAgentCollaborator = "AssociateAgentCollaborator" + opGetAgentCollaborator = "GetAgentCollaborator" + opUpdateAgentCollaborator = "UpdateAgentCollaborator" + opDisassociateAgentCollaborator = "DisassociateAgentCollaborator" + opListAgentCollaborators = "ListAgentCollaborators" + opCreateKnowledgeBase = "CreateKnowledgeBase" + opGetKnowledgeBase = "GetKnowledgeBase" + opUpdateKnowledgeBase = "UpdateKnowledgeBase" + opDeleteKnowledgeBase = "DeleteKnowledgeBase" + opListKnowledgeBases = "ListKnowledgeBases" + opAssociateAgentKnowledgeBase = "AssociateAgentKnowledgeBase" + opGetAgentKnowledgeBase = "GetAgentKnowledgeBase" + opUpdateAgentKnowledgeBase = "UpdateAgentKnowledgeBase" + opDisassociateAgentKnowledgeBase = "DisassociateAgentKnowledgeBase" + opListAgentKnowledgeBases = "ListAgentKnowledgeBases" + opCreateDataSource = "CreateDataSource" + opGetDataSource = "GetDataSource" + opUpdateDataSource = "UpdateDataSource" + opDeleteDataSource = "DeleteDataSource" + opListDataSources = "ListDataSources" + opStartIngestionJob = "StartIngestionJob" + opGetIngestionJob = "GetIngestionJob" + opStopIngestionJob = "StopIngestionJob" + opListIngestionJobs = "ListIngestionJobs" + opCreateFlow = "CreateFlow" + opGetFlow = "GetFlow" + opUpdateFlow = "UpdateFlow" + opDeleteFlow = "DeleteFlow" + opListFlows = "ListFlows" + opPrepareFlow = "PrepareFlow" + opValidateFlowDefinition = "ValidateFlowDefinition" + opCreateFlowVersion = "CreateFlowVersion" + opGetFlowVersion = "GetFlowVersion" + opDeleteFlowVersion = "DeleteFlowVersion" + opListFlowVersions = "ListFlowVersions" + opCreateFlowAlias = "CreateFlowAlias" + opGetFlowAlias = "GetFlowAlias" + opUpdateFlowAlias = "UpdateFlowAlias" + opDeleteFlowAlias = "DeleteFlowAlias" + opListFlowAliases = "ListFlowAliases" + opCreatePrompt = "CreatePrompt" + opGetPrompt = "GetPrompt" + opUpdatePrompt = "UpdatePrompt" + opDeletePrompt = "DeletePrompt" + opListPrompts = "ListPrompts" + opCreatePromptVersion = "CreatePromptVersion" + opGetPromptVersion = "GetPromptVersion" + opDeletePromptVersion = "DeletePromptVersion" + opIngestKnowledgeBaseDocuments = "IngestKnowledgeBaseDocuments" + opGetKnowledgeBaseDocuments = "GetKnowledgeBaseDocuments" + opDeleteKnowledgeBaseDocuments = "DeleteKnowledgeBaseDocuments" + opListKnowledgeBaseDocuments = "ListKnowledgeBaseDocuments" + opListTagsForResource = "ListTagsForResource" + opTagResource = "TagResource" + opUntagResource = "UntagResource" +) + +// --------------------------------------------------------------------------- +// Path constants +// --------------------------------------------------------------------------- + +const ( + agentsBase = "/agents" + kbBase = "/knowledgebases" + flowsBase = "/flows" + promptsBase = "/prompts" + tagsBase = "/tags/" + baService = "bedrock-agent" + baPriority = 87 + splitTwo = 2 + splitThree = 3 + splitFour = 4 + maxPageDefault = 100 +) + +// --------------------------------------------------------------------------- +// Goconst string constants +// --------------------------------------------------------------------------- + +const ( + keyAgent = "agent" + keyAgentID = "agentId" + keyAgentStatus = "agentStatus" + keyAgentVersion = "agentVersion" + keyAgentActionGroup = "agentActionGroup" + keyAgentAlias = "agentAlias" + keyAgentCollaborator = "agentCollaborator" + keyKnowledgeBase = "knowledgeBase" + keyAgentKB = "agentKnowledgeBase" + keyDataSource = "dataSource" + keyIngestionJob = "ingestionJob" + keyDocumentDetails = "documentDetails" + keyNextToken = "nextToken" + keyStatus = "status" + statusDeleting = "DELETING" + opUnknown = "Unknown" +) + +// --------------------------------------------------------------------------- +// Handler +// --------------------------------------------------------------------------- + +// Handler is the HTTP handler for the Bedrock Agent REST API. +type Handler struct { + Backend StorageBackend + AccountID string + DefaultRegion string +} + +// NewHandler creates a new Bedrock Agent handler. +func NewHandler(backend StorageBackend) *Handler { + return &Handler{Backend: backend} +} + +// Reset clears handler state (delegates to backend). +func (h *Handler) Reset() { + if r, ok := h.Backend.(interface{ Reset() }); ok { + r.Reset() + } +} + +// Name returns the service name. +func (h *Handler) Name() string { return "BedrockAgent" } + +// GetSupportedOperations returns the list of supported operations. +func (h *Handler) GetSupportedOperations() []string { + return []string{ + opCreateAgent, opGetAgent, opUpdateAgent, opDeleteAgent, opListAgents, opPrepareAgent, + opCreateAgentVersion, opGetAgentVersion, opDeleteAgentVersion, opListAgentVersions, + opCreateAgentActionGroup, opGetAgentActionGroup, opUpdateAgentActionGroup, + opDeleteAgentActionGroup, opListAgentActionGroups, + opCreateAgentAlias, opGetAgentAlias, opUpdateAgentAlias, opDeleteAgentAlias, opListAgentAliases, + opAssociateAgentCollaborator, opGetAgentCollaborator, opUpdateAgentCollaborator, + opDisassociateAgentCollaborator, opListAgentCollaborators, + opCreateKnowledgeBase, opGetKnowledgeBase, opUpdateKnowledgeBase, + opDeleteKnowledgeBase, opListKnowledgeBases, + opAssociateAgentKnowledgeBase, opGetAgentKnowledgeBase, opUpdateAgentKnowledgeBase, + opDisassociateAgentKnowledgeBase, opListAgentKnowledgeBases, + opCreateDataSource, opGetDataSource, opUpdateDataSource, opDeleteDataSource, opListDataSources, + opStartIngestionJob, opGetIngestionJob, opStopIngestionJob, opListIngestionJobs, + opCreateFlow, opGetFlow, opUpdateFlow, opDeleteFlow, opListFlows, opPrepareFlow, + opValidateFlowDefinition, + opCreateFlowVersion, opGetFlowVersion, opDeleteFlowVersion, opListFlowVersions, + opCreateFlowAlias, opGetFlowAlias, opUpdateFlowAlias, opDeleteFlowAlias, opListFlowAliases, + opCreatePrompt, opGetPrompt, opUpdatePrompt, opDeletePrompt, opListPrompts, + opCreatePromptVersion, opGetPromptVersion, opDeletePromptVersion, + opIngestKnowledgeBaseDocuments, opGetKnowledgeBaseDocuments, + opDeleteKnowledgeBaseDocuments, opListKnowledgeBaseDocuments, + opListTagsForResource, opTagResource, opUntagResource, + } +} + +// ChaosServiceName returns the chaos service name. +func (h *Handler) ChaosServiceName() string { return baService } + +// ChaosOperations returns all operations. +func (h *Handler) ChaosOperations() []string { return h.GetSupportedOperations() } + +// ChaosRegions returns the supported regions. +func (h *Handler) ChaosRegions() []string { return []string{h.DefaultRegion} } + +// RouteMatcher returns a function matching Bedrock Agent requests. +func (h *Handler) RouteMatcher() service.Matcher { + return func(c *echo.Context) bool { + svc := httputils.ExtractServiceFromRequest(c.Request()) + if svc == baService { + return true + } + + path := c.Request().URL.Path + + return strings.HasPrefix(path, agentsBase) || + strings.HasPrefix(path, kbBase) || + strings.HasPrefix(path, flowsBase) || + strings.HasPrefix(path, promptsBase) || + strings.HasPrefix(path, tagsBase) + } +} + +// MatchPriority returns routing priority. +func (h *Handler) MatchPriority() int { return baPriority } + +// ExtractOperation determines the operation name from the request. +func (h *Handler) ExtractOperation(c *echo.Context) string { + return classifyPath(c.Request().Method, c.Request().URL.Path) +} + +// ExtractResource extracts an agent or flow ID from the request path. +func (h *Handler) ExtractResource(c *echo.Context) string { + path := c.Request().URL.Path + + for _, prefix := range []string{"/agents/", "/flows/", "/knowledgebases/", "/prompts/"} { + if rest, ok := strings.CutPrefix(path, prefix); ok { + parts := strings.SplitN(rest, "/", splitTwo) + + return parts[0] + } + } + + return "" +} + +// Handler returns the Echo handler function. +func (h *Handler) Handler() echo.HandlerFunc { + return func(c *echo.Context) error { + region := httputils.ExtractRegionFromRequest(c.Request(), h.DefaultRegion) + ctx := context.WithValue(c.Request().Context(), regionKey{}, region) + log := logger.Load(ctx) + path := strings.TrimSuffix(c.Request().URL.Path, "/") + method := c.Request().Method + query := c.Request().URL.Query() + + body, err := httputils.ReadBody(c.Request()) + if err != nil { + log.ErrorContext(ctx, "bedrockagent: failed to read body", "error", err) + + return c.JSON(http.StatusInternalServerError, errResp("InternalFailure", "internal server error")) + } + + return h.dispatch(ctx, c, path, method, query, body) + } +} + +// --------------------------------------------------------------------------- +// Dispatch +// --------------------------------------------------------------------------- + +func (h *Handler) dispatch( + ctx context.Context, c *echo.Context, path, method string, query url.Values, body []byte, +) error { + switch { + case strings.HasPrefix(path, agentsBase): + return h.dispatchAgents(ctx, c, path, method, body) + case strings.HasPrefix(path, kbBase): + return h.dispatchKB(ctx, c, path, method, body) + case strings.HasPrefix(path, flowsBase): + return h.dispatchFlows(ctx, c, path, method, body) + case strings.HasPrefix(path, promptsBase): + return h.dispatchPrompts(ctx, c, path, method, body) + case strings.HasPrefix(path, tagsBase): + return h.dispatchTags(ctx, c, path, method, query, body) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown: "+path)) +} + +// --------------------------------------------------------------------------- +// Agent dispatch +// --------------------------------------------------------------------------- + +func (h *Handler) dispatchAgents( + ctx context.Context, c *echo.Context, path, method string, body []byte, +) error { + if path == agentsBase { + return h.dispatchAgentRoot(ctx, c, method, body) + } + + rest, _ := strings.CutPrefix(path, agentsBase+"/") + parts := strings.SplitN(rest, "/", splitTwo) + agentID := parts[0] + suffix := "" + + if len(parts) == splitTwo { + suffix = "/" + parts[1] + } + + return h.dispatchAgentID(ctx, c, agentID, suffix, method, body) +} + +func (h *Handler) dispatchAgentRoot( + ctx context.Context, c *echo.Context, method string, body []byte, +) error { + switch method { + case http.MethodPut, http.MethodPost: + return h.handleCreateAgent(ctx, c, body) + case http.MethodGet: + return h.handleListAgents(ctx, c) + } + + return c.JSON(http.StatusMethodNotAllowed, errResp("MethodNotAllowedException", method)) +} + +func (h *Handler) dispatchAgentID( + ctx context.Context, c *echo.Context, agentID, suffix, method string, body []byte, +) error { + switch { + case suffix == "" && method == http.MethodGet: + return h.handleGetAgent(ctx, c, agentID) + case suffix == "" && method == http.MethodPut: + return h.handleUpdateAgent(ctx, c, agentID, body) + case suffix == "" && method == http.MethodDelete: + return h.handleDeleteAgent(ctx, c, agentID) + case suffix == "/prepare" && method == http.MethodPost: + return h.handlePrepareAgent(ctx, c, agentID) + case strings.HasPrefix(suffix, "/agentversions"): + return h.dispatchAgentVersions(ctx, c, agentID, suffix, method, body) + case strings.HasPrefix(suffix, "/agentaliases"): + return h.dispatchAgentAliases(ctx, c, agentID, suffix, method, body) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown agent op")) +} + +func (h *Handler) dispatchAgentVersions( + ctx context.Context, c *echo.Context, agentID, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/agentversions") + + if rest == "" { + switch method { + case http.MethodPost: + return h.handleCreateAgentVersion(ctx, c, agentID, body) + case http.MethodGet: + return h.handleListAgentVersions(ctx, c, agentID) + } + + return c.JSON(http.StatusMethodNotAllowed, errResp("MethodNotAllowedException", method)) + } + + parts := strings.SplitN(strings.TrimPrefix(rest, "/"), "/", splitTwo) + agentVersion := parts[0] + vSuffix := "" + + if len(parts) == splitTwo { + vSuffix = "/" + parts[1] + } + + return h.dispatchAgentVersionSuffix(ctx, c, agentID, agentVersion, vSuffix, method, body) +} + +func (h *Handler) dispatchAgentVersionSuffix( + ctx context.Context, c *echo.Context, agentID, agentVersion, vSuffix, method string, body []byte, +) error { + switch { + case vSuffix == "" && method == http.MethodGet: + return h.handleGetAgentVersion(ctx, c, agentID, agentVersion) + case vSuffix == "" && method == http.MethodDelete: + return h.handleDeleteAgentVersion(ctx, c, agentID, agentVersion) + case strings.HasPrefix(vSuffix, "/actiongroups"): + return h.dispatchActionGroups(ctx, c, agentID, agentVersion, vSuffix, method, body) + case strings.HasPrefix(vSuffix, "/agentcollaborators"): + return h.dispatchCollaborators(ctx, c, agentID, agentVersion, vSuffix, method, body) + case strings.HasPrefix(vSuffix, "/knowledgebases"): + return h.dispatchAgentKBs(ctx, c, agentID, agentVersion, vSuffix, method, body) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown version op")) +} + +func (h *Handler) dispatchActionGroups( + ctx context.Context, c *echo.Context, agentID, agentVersion, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/actiongroups") + + if rest == "" { + switch method { + case http.MethodPut, http.MethodPost: + return h.handleCreateAgentActionGroup(ctx, c, agentID, body) + case http.MethodGet: + return h.handleListAgentActionGroups(ctx, c, agentID, agentVersion) + } + } + + agID := strings.TrimPrefix(rest, "/") + + switch method { + case http.MethodGet: + return h.handleGetAgentActionGroup(ctx, c, agentID, agentVersion, agID) + case http.MethodPut: + return h.handleUpdateAgentActionGroup(ctx, c, agentID, agentVersion, agID, body) + case http.MethodDelete: + return h.handleDeleteAgentActionGroup(ctx, c, agentID, agentVersion, agID) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown action group op")) +} + +func (h *Handler) dispatchCollaborators( + ctx context.Context, c *echo.Context, agentID, agentVersion, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/agentcollaborators") + + if rest == "" { + switch method { + case http.MethodPut: + return h.handleAssociateCollaborator(ctx, c, agentID, agentVersion, body) + case http.MethodGet: + return h.handleListCollaborators(ctx, c, agentID, agentVersion) + } + } + + collaboratorID := strings.TrimPrefix(rest, "/") + + switch method { + case http.MethodGet: + return h.handleGetCollaborator(ctx, c, agentID, agentVersion, collaboratorID) + case http.MethodPut: + return h.handleUpdateCollaborator(ctx, c, agentID, agentVersion, collaboratorID, body) + case http.MethodDelete: + return h.handleDisassociateCollaborator(ctx, c, agentID, agentVersion, collaboratorID) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown collab op")) +} + +func (h *Handler) dispatchAgentKBs( + ctx context.Context, c *echo.Context, agentID, agentVersion, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/knowledgebases") + + if rest == "" { + switch method { + case http.MethodPut: + return h.handleAssociateAgentKB(ctx, c, agentID, agentVersion, body) + case http.MethodGet: + return h.handleListAgentKBs(ctx, c, agentID, agentVersion) + } + } + + kbID := strings.TrimPrefix(rest, "/") + + switch method { + case http.MethodGet: + return h.handleGetAgentKB(ctx, c, agentID, agentVersion, kbID) + case http.MethodPut: + return h.handleUpdateAgentKB(ctx, c, agentID, agentVersion, kbID, body) + case http.MethodDelete: + return h.handleDisassociateAgentKB(ctx, c, agentID, agentVersion, kbID) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown agent-kb op")) +} + +func (h *Handler) dispatchAgentAliases( + ctx context.Context, c *echo.Context, agentID, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/agentaliases") + + if rest == "" { + switch method { + case http.MethodPost, http.MethodPut: + return h.handleCreateAgentAlias(ctx, c, agentID, body) + case http.MethodGet: + return h.handleListAgentAliases(ctx, c, agentID) + } + } + + aliasID := strings.TrimPrefix(rest, "/") + + switch method { + case http.MethodGet: + return h.handleGetAgentAlias(ctx, c, agentID, aliasID) + case http.MethodPut: + return h.handleUpdateAgentAlias(ctx, c, agentID, aliasID, body) + case http.MethodDelete: + return h.handleDeleteAgentAlias(ctx, c, agentID, aliasID) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown alias op")) +} + +// --------------------------------------------------------------------------- +// Knowledge base dispatch +// --------------------------------------------------------------------------- + +func (h *Handler) dispatchKB( + ctx context.Context, c *echo.Context, path, method string, body []byte, +) error { + if path == kbBase { + switch method { + case http.MethodPut, http.MethodPost: + return h.handleCreateKB(ctx, c, body) + case http.MethodGet: + return h.handleListKBs(ctx, c) + } + } + + rest, _ := strings.CutPrefix(path, kbBase+"/") + parts := strings.SplitN(rest, "/", splitTwo) + kbID := parts[0] + suffix := "" + + if len(parts) == splitTwo { + suffix = "/" + parts[1] + } + + return h.dispatchKBID(ctx, c, kbID, suffix, method, body) +} + +func (h *Handler) dispatchKBID( + ctx context.Context, c *echo.Context, kbID, suffix, method string, body []byte, +) error { + switch { + case suffix == "" && method == http.MethodGet: + return h.handleGetKB(ctx, c, kbID) + case suffix == "" && method == http.MethodPut: + return h.handleUpdateKB(ctx, c, kbID, body) + case suffix == "" && method == http.MethodDelete: + return h.handleDeleteKB(ctx, c, kbID) + case strings.HasPrefix(suffix, "/datasources"): + return h.dispatchDataSources(ctx, c, kbID, suffix, method, body) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown kb op")) +} + +func (h *Handler) dispatchDataSources( + ctx context.Context, c *echo.Context, kbID, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/datasources") + + if rest == "" { + switch method { + case http.MethodPut, http.MethodPost: + return h.handleCreateDS(ctx, c, kbID, body) + case http.MethodGet: + return h.handleListDS(ctx, c, kbID) + } + } + + parts := strings.SplitN(strings.TrimPrefix(rest, "/"), "/", splitTwo) + dsID := parts[0] + dsSuffix := "" + + if len(parts) == splitTwo { + dsSuffix = "/" + parts[1] + } + + return h.dispatchDSID(ctx, c, kbID, dsID, dsSuffix, method, body) +} + +func (h *Handler) dispatchDSID( + ctx context.Context, c *echo.Context, kbID, dsID, suffix, method string, body []byte, +) error { + switch { + case suffix == "" && method == http.MethodGet: + return h.handleGetDS(ctx, c, kbID, dsID) + case suffix == "" && method == http.MethodPut: + return h.handleUpdateDS(ctx, c, kbID, dsID, body) + case suffix == "" && method == http.MethodDelete: + return h.handleDeleteDS(ctx, c, kbID, dsID) + case strings.HasPrefix(suffix, "/ingestionjobs"): + return h.dispatchIngestionJobs(ctx, c, kbID, dsID, suffix, method, body) + case strings.HasPrefix(suffix, "/documents"): + return h.dispatchKBDocuments(ctx, c, kbID, dsID, suffix, method, body) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown ds op")) +} + +func (h *Handler) dispatchIngestionJobs( + ctx context.Context, c *echo.Context, kbID, dsID, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/ingestionjobs") + + if rest == "" { + switch method { + case http.MethodPut, http.MethodPost: + return h.handleStartIngestionJob(ctx, c, kbID, dsID, body) + case http.MethodGet: + return h.handleListIngestionJobs(ctx, c, kbID, dsID) + } + } + + parts := strings.SplitN(strings.TrimPrefix(rest, "/"), "/", splitTwo) + jobID := parts[0] + + if len(parts) == splitTwo && parts[1] == "stop" { + return h.handleStopIngestionJob(ctx, c, kbID, dsID, jobID) + } + + if method == http.MethodGet { + return h.handleGetIngestionJob(ctx, c, kbID, dsID, jobID) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown ingestion op")) +} + +func (h *Handler) dispatchKBDocuments( + ctx context.Context, c *echo.Context, kbID, dsID, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/documents") + + switch { + case rest == "" && method == http.MethodPost: + return h.handleIngestKBDocs(ctx, c, kbID, dsID, body) + case rest == "" && method == http.MethodGet: + return h.handleListKBDocs(ctx, c, kbID, dsID) + case rest == "/deleteDocuments": + return h.handleDeleteKBDocs(ctx, c, kbID, dsID, body) + case rest == "/getDocuments": + return h.handleGetKBDocs(ctx, c, kbID, dsID, body) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown kb docs op")) +} + +// --------------------------------------------------------------------------- +// Flow dispatch +// --------------------------------------------------------------------------- + +func (h *Handler) dispatchFlows( + ctx context.Context, c *echo.Context, path, method string, body []byte, +) error { + if path == flowsBase { + switch method { + case http.MethodPost: + return h.handleCreateFlow(ctx, c, body) + case http.MethodGet: + return h.handleListFlows(ctx, c) + } + } + + if path == flowsBase+"/validate-definition" { + return h.handleValidateFlowDef(ctx, c, body) + } + + rest, _ := strings.CutPrefix(path, flowsBase+"/") + parts := strings.SplitN(rest, "/", splitTwo) + flowID := parts[0] + suffix := "" + + if len(parts) == splitTwo { + suffix = "/" + parts[1] + } + + return h.dispatchFlowID(ctx, c, flowID, suffix, method, body) +} + +func (h *Handler) dispatchFlowID( + ctx context.Context, c *echo.Context, flowID, suffix, method string, body []byte, +) error { + if suffix == "" { + switch method { + case http.MethodGet: + return h.handleGetFlow(ctx, c, flowID) + case http.MethodPut: + return h.handleUpdateFlow(ctx, c, flowID, body) + case http.MethodDelete: + return h.handleDeleteFlow(ctx, c, flowID) + } + } + + if suffix == "/prepare" && method == http.MethodPost { + return h.handlePrepareFlow(ctx, c, flowID) + } + + if strings.HasPrefix(suffix, "/versions") { + return h.dispatchFlowVersions(ctx, c, flowID, suffix, method, body) + } + + if strings.HasPrefix(suffix, "/aliases") { + return h.dispatchFlowAliases(ctx, c, flowID, suffix, method, body) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown flow op")) +} + +func (h *Handler) dispatchFlowVersions( + ctx context.Context, c *echo.Context, flowID, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/versions") + + if rest == "" { + switch method { + case http.MethodPost: + return h.handleCreateFlowVersion(ctx, c, flowID, body) + case http.MethodGet: + return h.handleListFlowVersions(ctx, c, flowID) + } + } + + flowVersion := strings.TrimPrefix(rest, "/") + + switch method { + case http.MethodGet: + return h.handleGetFlowVersion(ctx, c, flowID, flowVersion) + case http.MethodDelete: + return h.handleDeleteFlowVersion(ctx, c, flowID, flowVersion) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown flow version op")) +} + +func (h *Handler) dispatchFlowAliases( + ctx context.Context, c *echo.Context, flowID, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/aliases") + + if rest == "" { + switch method { + case http.MethodPost: + return h.handleCreateFlowAlias(ctx, c, flowID, body) + case http.MethodGet: + return h.handleListFlowAliases(ctx, c, flowID) + } + } + + aliasID := strings.TrimPrefix(rest, "/") + + switch method { + case http.MethodGet: + return h.handleGetFlowAlias(ctx, c, flowID, aliasID) + case http.MethodPut: + return h.handleUpdateFlowAlias(ctx, c, flowID, aliasID, body) + case http.MethodDelete: + return h.handleDeleteFlowAlias(ctx, c, flowID, aliasID) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown flow alias op")) +} + +// --------------------------------------------------------------------------- +// Prompt dispatch +// --------------------------------------------------------------------------- + +func (h *Handler) dispatchPrompts( + ctx context.Context, c *echo.Context, path, method string, body []byte, +) error { + if path == promptsBase { + switch method { + case http.MethodPost: + return h.handleCreatePrompt(ctx, c, body) + case http.MethodGet: + return h.handleListPrompts(ctx, c) + } + } + + rest, _ := strings.CutPrefix(path, promptsBase+"/") + parts := strings.SplitN(rest, "/", splitTwo) + promptID := parts[0] + suffix := "" + + if len(parts) == splitTwo { + suffix = "/" + parts[1] + } + + return h.dispatchPromptID(ctx, c, promptID, suffix, method, body) +} + +func (h *Handler) dispatchPromptID( + ctx context.Context, c *echo.Context, promptID, suffix, method string, body []byte, +) error { + switch { + case suffix == "" && method == http.MethodGet: + return h.handleGetPrompt(ctx, c, promptID) + case suffix == "" && method == http.MethodPut: + return h.handleUpdatePrompt(ctx, c, promptID, body) + case suffix == "" && method == http.MethodDelete: + return h.handleDeletePrompt(ctx, c, promptID) + case strings.HasPrefix(suffix, "/versions"): + return h.dispatchPromptVersions(ctx, c, promptID, suffix, method, body) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown prompt op")) +} + +func (h *Handler) dispatchPromptVersions( + ctx context.Context, c *echo.Context, promptID, suffix, method string, body []byte, +) error { + rest, _ := strings.CutPrefix(suffix, "/versions") + + if rest == "" && method == http.MethodPost { + return h.handleCreatePromptVersion(ctx, c, promptID, body) + } + + versionID := strings.TrimPrefix(rest, "/") + + switch method { + case http.MethodGet: + return h.handleGetPromptVersion(ctx, c, promptID, versionID) + case http.MethodDelete: + return h.handleDeletePromptVersion(ctx, c, promptID, versionID) + } + + return c.JSON(http.StatusNotFound, errResp("UnknownOperationException", "unknown prompt version op")) +} + +// --------------------------------------------------------------------------- +// Tag dispatch +// --------------------------------------------------------------------------- + +func (h *Handler) dispatchTags( + ctx context.Context, c *echo.Context, path, method string, query url.Values, body []byte, +) error { + resourceARN, _ := strings.CutPrefix(path, tagsBase) + + switch method { + case http.MethodGet: + return h.handleListTags(ctx, c, resourceARN) + case http.MethodPost: + return h.handleTagResource(ctx, c, resourceARN, body) + case http.MethodDelete: + return h.handleUntagResource(ctx, c, resourceARN, query) + } + + return c.JSON(http.StatusMethodNotAllowed, errResp("MethodNotAllowedException", method)) +} + +// --------------------------------------------------------------------------- +// Agent handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreateAgent(ctx context.Context, c *echo.Context, body []byte) error { + var req struct { + Tags map[string]string `json:"tags"` + Guardrail map[string]any `json:"guardrailConfiguration"` + Memory map[string]any `json:"memoryConfiguration"` + AgentName string `json:"agentName"` + Collaboration string `json:"agentCollaboration"` + Description string `json:"description"` + FoundationModel string `json:"foundationModel"` + Instruction string `json:"instruction"` + RoleARN string `json:"agentResourceRoleArn"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + agent, err := h.Backend.CreateAgent(ctx, AgentConfig{ + AgentName: req.AgentName, + Collaboration: req.Collaboration, + Description: req.Description, + FoundationModel: req.FoundationModel, + Instruction: req.Instruction, + RoleARN: req.RoleARN, + Tags: req.Tags, + Guardrail: req.Guardrail, + Memory: req.Memory, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgent: agent}) +} + +func (h *Handler) handleGetAgent(ctx context.Context, c *echo.Context, agentID string) error { + agent, err := h.Backend.GetAgent(ctx, agentID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgent: agent}) +} + +func (h *Handler) handleUpdateAgent( + ctx context.Context, c *echo.Context, agentID string, body []byte, +) error { + var req struct { + Tags map[string]string `json:"tags"` + Guardrail map[string]any `json:"guardrailConfiguration"` + Memory map[string]any `json:"memoryConfiguration"` + AgentName string `json:"agentName"` + Collaboration string `json:"agentCollaboration"` + Description string `json:"description"` + FoundationModel string `json:"foundationModel"` + Instruction string `json:"instruction"` + RoleARN string `json:"agentResourceRoleArn"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + agent, err := h.Backend.UpdateAgent(ctx, agentID, AgentConfig{ + AgentName: req.AgentName, + Collaboration: req.Collaboration, + Description: req.Description, + FoundationModel: req.FoundationModel, + Instruction: req.Instruction, + RoleARN: req.RoleARN, + Tags: req.Tags, + Guardrail: req.Guardrail, + Memory: req.Memory, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgent: agent}) +} + +func (h *Handler) handleDeleteAgent(ctx context.Context, c *echo.Context, agentID string) error { + if err := h.Backend.DeleteAgent(ctx, agentID); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentID: agentID, keyAgentStatus: statusDeleting}) +} + +func (h *Handler) handleListAgents(ctx context.Context, c *echo.Context) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + agents, outToken, err := h.Backend.ListAgents(ctx, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"agentSummaries": agents, keyNextToken: outToken}) +} + +func (h *Handler) handlePrepareAgent(ctx context.Context, c *echo.Context, agentID string) error { + agent, err := h.Backend.PrepareAgent(ctx, agentID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusAccepted, map[string]any{ + keyAgentID: agent.AgentID, + keyAgentStatus: agent.AgentStatus, + keyAgentVersion: agent.AgentVersion, + }) +} + +// --------------------------------------------------------------------------- +// Agent version handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreateAgentVersion( + ctx context.Context, c *echo.Context, agentID string, body []byte, +) error { + var req struct { + Description string `json:"description"` + } + + _ = json.Unmarshal(body, &req) + + av, err := h.Backend.CreateAgentVersion(ctx, agentID, req.Description) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentVersion: av}) +} + +func (h *Handler) handleGetAgentVersion( + ctx context.Context, c *echo.Context, agentID, version string, +) error { + av, err := h.Backend.GetAgentVersion(ctx, agentID, version) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentVersion: av}) +} + +func (h *Handler) handleDeleteAgentVersion( + ctx context.Context, c *echo.Context, agentID, version string, +) error { + if err := h.Backend.DeleteAgentVersion(ctx, agentID, version); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + keyAgentID: agentID, + keyAgentVersion: version, + keyAgentStatus: statusDeleting, + }) +} + +func (h *Handler) handleListAgentVersions( + ctx context.Context, c *echo.Context, agentID string, +) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + summaries, outToken, err := h.Backend.ListAgentVersions(ctx, agentID, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + "agentVersionSummaries": summaries, + keyNextToken: outToken, + }) +} + +// --------------------------------------------------------------------------- +// Agent action group handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreateAgentActionGroup( + ctx context.Context, c *echo.Context, agentID string, body []byte, +) error { + var req struct { + ActionGroupExecutor map[string]any `json:"actionGroupExecutor"` + APISchema map[string]any `json:"apiSchema"` + FunctionSchema map[string]any `json:"functionSchema"` + ActionGroupName string `json:"actionGroupName"` + Description string `json:"description"` + ActionGroupState string `json:"actionGroupState"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + ag, err := h.Backend.CreateAgentActionGroup(ctx, agentID, ActionGroupConfig{ + ActionGroupName: req.ActionGroupName, + Description: req.Description, + ActionGroupState: req.ActionGroupState, + ActionGroupExecutor: req.ActionGroupExecutor, + APISchema: req.APISchema, + FunctionSchema: req.FunctionSchema, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentActionGroup: ag}) +} + +func (h *Handler) handleGetAgentActionGroup( + ctx context.Context, c *echo.Context, agentID, agentVersion, agID string, +) error { + ag, err := h.Backend.GetAgentActionGroup(ctx, agentID, agentVersion, agID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentActionGroup: ag}) +} + +func (h *Handler) handleUpdateAgentActionGroup( + ctx context.Context, c *echo.Context, agentID, agentVersion, agID string, body []byte, +) error { + var req struct { + ActionGroupExecutor map[string]any `json:"actionGroupExecutor"` + APISchema map[string]any `json:"apiSchema"` + FunctionSchema map[string]any `json:"functionSchema"` + ActionGroupName string `json:"actionGroupName"` + Description string `json:"description"` + ActionGroupState string `json:"actionGroupState"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + ag, err := h.Backend.UpdateAgentActionGroup(ctx, agentID, agentVersion, agID, ActionGroupConfig{ + ActionGroupName: req.ActionGroupName, + Description: req.Description, + ActionGroupState: req.ActionGroupState, + ActionGroupExecutor: req.ActionGroupExecutor, + APISchema: req.APISchema, + FunctionSchema: req.FunctionSchema, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentActionGroup: ag}) +} + +func (h *Handler) handleDeleteAgentActionGroup( + ctx context.Context, c *echo.Context, agentID, agentVersion, agID string, +) error { + if err := h.Backend.DeleteAgentActionGroup(ctx, agentID, agentVersion, agID); err != nil { + return handleErr(c, err) + } + + return c.NoContent(http.StatusNoContent) +} + +func (h *Handler) handleListAgentActionGroups( + ctx context.Context, c *echo.Context, agentID, agentVersion string, +) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + summaries, outToken, err := h.Backend.ListAgentActionGroups(ctx, agentID, agentVersion, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + "actionGroupSummaries": summaries, + keyNextToken: outToken, + }) +} + +// --------------------------------------------------------------------------- +// Agent alias handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreateAgentAlias( + ctx context.Context, c *echo.Context, agentID string, body []byte, +) error { + var req struct { + Tags map[string]string `json:"tags"` + AgentAliasName string `json:"agentAliasName"` + Description string `json:"description"` + RoutingConfiguration []AliasRouting `json:"routingConfiguration"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + al, err := h.Backend.CreateAgentAlias(ctx, agentID, AliasConfig{ + AliasName: req.AgentAliasName, + Description: req.Description, + RoutingConfiguration: req.RoutingConfiguration, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentAlias: al}) +} + +func (h *Handler) handleGetAgentAlias( + ctx context.Context, c *echo.Context, agentID, aliasID string, +) error { + al, err := h.Backend.GetAgentAlias(ctx, agentID, aliasID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentAlias: al}) +} + +func (h *Handler) handleUpdateAgentAlias( + ctx context.Context, c *echo.Context, agentID, aliasID string, body []byte, +) error { + var req struct { + Tags map[string]string `json:"tags"` + AgentAliasName string `json:"agentAliasName"` + Description string `json:"description"` + RoutingConfiguration []AliasRouting `json:"routingConfiguration"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + al, err := h.Backend.UpdateAgentAlias(ctx, agentID, aliasID, AliasConfig{ + AliasName: req.AgentAliasName, + Description: req.Description, + RoutingConfiguration: req.RoutingConfiguration, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentAlias: al}) +} + +func (h *Handler) handleDeleteAgentAlias( + ctx context.Context, c *echo.Context, agentID, aliasID string, +) error { + if err := h.Backend.DeleteAgentAlias(ctx, agentID, aliasID); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + keyAgentID: agentID, + "agentAliasId": aliasID, + "agentAliasStatus": statusDeleting, + }) +} + +func (h *Handler) handleListAgentAliases( + ctx context.Context, c *echo.Context, agentID string, +) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + summaries, outToken, err := h.Backend.ListAgentAliases(ctx, agentID, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + "agentAliasSummaries": summaries, + keyNextToken: outToken, + }) +} + +// --------------------------------------------------------------------------- +// Collaborator handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleAssociateCollaborator( + ctx context.Context, c *echo.Context, agentID, agentVersion string, body []byte, +) error { + var req struct { + AgentDescriptor map[string]any `json:"agentDescriptor"` + CollaboratorName string `json:"collaboratorName"` + CollaborationInstruction string `json:"collaborationInstruction"` + RelayConversationHistory string `json:"relayConversationHistory"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + collab, err := h.Backend.AssociateAgentCollaborator(ctx, agentID, agentVersion, CollaboratorConfig{ + CollaboratorName: req.CollaboratorName, + CollaborationInstruction: req.CollaborationInstruction, + RelayConversationHistory: req.RelayConversationHistory, + AgentDescriptor: req.AgentDescriptor, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentCollaborator: collab}) +} + +func (h *Handler) handleGetCollaborator( + ctx context.Context, c *echo.Context, agentID, agentVersion, collaboratorID string, +) error { + collab, err := h.Backend.GetAgentCollaborator(ctx, agentID, agentVersion, collaboratorID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentCollaborator: collab}) +} + +func (h *Handler) handleUpdateCollaborator( + ctx context.Context, c *echo.Context, agentID, agentVersion, collaboratorID string, body []byte, +) error { + var req struct { + AgentDescriptor map[string]any `json:"agentDescriptor"` + CollaboratorName string `json:"collaboratorName"` + CollaborationInstruction string `json:"collaborationInstruction"` + RelayConversationHistory string `json:"relayConversationHistory"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + collab, err := h.Backend.UpdateAgentCollaborator(ctx, agentID, agentVersion, collaboratorID, CollaboratorConfig{ + CollaboratorName: req.CollaboratorName, + CollaborationInstruction: req.CollaborationInstruction, + RelayConversationHistory: req.RelayConversationHistory, + AgentDescriptor: req.AgentDescriptor, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentCollaborator: collab}) +} + +func (h *Handler) handleDisassociateCollaborator( + ctx context.Context, c *echo.Context, agentID, agentVersion, collaboratorID string, +) error { + if err := h.Backend.DisassociateAgentCollaborator(ctx, agentID, agentVersion, collaboratorID); err != nil { + return handleErr(c, err) + } + + return c.NoContent(http.StatusNoContent) +} + +func (h *Handler) handleListCollaborators( + ctx context.Context, c *echo.Context, agentID, agentVersion string, +) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + collabs, outToken, err := h.Backend.ListAgentCollaborators(ctx, agentID, agentVersion, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + "agentCollaboratorSummaries": collabs, + keyNextToken: outToken, + }) +} + +// --------------------------------------------------------------------------- +// Knowledge base handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreateKB(ctx context.Context, c *echo.Context, body []byte) error { + var req struct { + Tags map[string]string `json:"tags"` + KBConfiguration map[string]any `json:"knowledgeBaseConfiguration"` + StorageConfiguration map[string]any `json:"storageConfiguration"` + Name string `json:"name"` + Description string `json:"description"` + RoleARN string `json:"roleArn"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + kb, err := h.Backend.CreateKnowledgeBase(ctx, KnowledgeBaseConfig{ + Name: req.Name, + Description: req.Description, + RoleARN: req.RoleARN, + KBConfiguration: req.KBConfiguration, + StorageConfiguration: req.StorageConfiguration, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyKnowledgeBase: kb}) +} + +func (h *Handler) handleGetKB(ctx context.Context, c *echo.Context, kbID string) error { + kb, err := h.Backend.GetKnowledgeBase(ctx, kbID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyKnowledgeBase: kb}) +} + +func (h *Handler) handleUpdateKB(ctx context.Context, c *echo.Context, kbID string, body []byte) error { + var req struct { + Tags map[string]string `json:"tags"` + KBConfiguration map[string]any `json:"knowledgeBaseConfiguration"` + StorageConfiguration map[string]any `json:"storageConfiguration"` + Name string `json:"name"` + Description string `json:"description"` + RoleARN string `json:"roleArn"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + kb, err := h.Backend.UpdateKnowledgeBase(ctx, kbID, KnowledgeBaseConfig{ + Name: req.Name, + Description: req.Description, + RoleARN: req.RoleARN, + KBConfiguration: req.KBConfiguration, + StorageConfiguration: req.StorageConfiguration, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyKnowledgeBase: kb}) +} + +func (h *Handler) handleDeleteKB(ctx context.Context, c *echo.Context, kbID string) error { + if err := h.Backend.DeleteKnowledgeBase(ctx, kbID); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"knowledgeBaseId": kbID, keyStatus: statusDeleting}) +} + +func (h *Handler) handleListKBs(ctx context.Context, c *echo.Context) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + summaries, outToken, err := h.Backend.ListKnowledgeBases(ctx, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + "knowledgeBaseSummaries": summaries, + keyNextToken: outToken, + }) +} + +// --------------------------------------------------------------------------- +// Agent–KB association handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleAssociateAgentKB( + ctx context.Context, c *echo.Context, agentID, agentVersion string, body []byte, +) error { + var req struct { + KnowledgeBaseID string `json:"knowledgeBaseId"` + Description string `json:"description"` + KnowledgeBaseState string `json:"knowledgeBaseState"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + assoc, err := h.Backend.AssociateAgentKnowledgeBase( + ctx, agentID, agentVersion, req.KnowledgeBaseID, req.Description, req.KnowledgeBaseState, + ) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentKB: assoc}) +} + +func (h *Handler) handleGetAgentKB( + ctx context.Context, c *echo.Context, agentID, agentVersion, kbID string, +) error { + assoc, err := h.Backend.GetAgentKnowledgeBase(ctx, agentID, agentVersion, kbID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentKB: assoc}) +} + +func (h *Handler) handleUpdateAgentKB( + ctx context.Context, c *echo.Context, agentID, agentVersion, kbID string, body []byte, +) error { + var req struct { + Description string `json:"description"` + KnowledgeBaseState string `json:"knowledgeBaseState"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + assoc, err := h.Backend.UpdateAgentKnowledgeBase( + ctx, agentID, agentVersion, kbID, req.Description, req.KnowledgeBaseState, + ) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyAgentKB: assoc}) +} + +func (h *Handler) handleDisassociateAgentKB( + ctx context.Context, c *echo.Context, agentID, agentVersion, kbID string, +) error { + if err := h.Backend.DisassociateAgentKnowledgeBase(ctx, agentID, agentVersion, kbID); err != nil { + return handleErr(c, err) + } + + return c.NoContent(http.StatusNoContent) +} + +func (h *Handler) handleListAgentKBs( + ctx context.Context, c *echo.Context, agentID, agentVersion string, +) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + assocs, outToken, err := h.Backend.ListAgentKnowledgeBases(ctx, agentID, agentVersion, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + "agentKnowledgeBaseSummaries": assocs, + keyNextToken: outToken, + }) +} + +// --------------------------------------------------------------------------- +// Data source handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreateDS(ctx context.Context, c *echo.Context, kbID string, body []byte) error { + var req struct { + DataSourceConfiguration map[string]any `json:"dataSourceConfiguration"` + VectorIngestionConfig map[string]any `json:"vectorIngestionConfiguration"` + Name string `json:"name"` + Description string `json:"description"` + DataDeletionPolicy string `json:"dataDeletionPolicy"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + ds, err := h.Backend.CreateDataSource(ctx, kbID, DataSourceConfig{ + Name: req.Name, + Description: req.Description, + DataDeletionPolicy: req.DataDeletionPolicy, + DataSourceConfiguration: req.DataSourceConfiguration, + VectorIngestionConfig: req.VectorIngestionConfig, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyDataSource: ds}) +} + +func (h *Handler) handleGetDS(ctx context.Context, c *echo.Context, kbID, dsID string) error { + ds, err := h.Backend.GetDataSource(ctx, kbID, dsID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyDataSource: ds}) +} + +func (h *Handler) handleUpdateDS( + ctx context.Context, c *echo.Context, kbID, dsID string, body []byte, +) error { + var req struct { + DataSourceConfiguration map[string]any `json:"dataSourceConfiguration"` + VectorIngestionConfig map[string]any `json:"vectorIngestionConfiguration"` + Name string `json:"name"` + Description string `json:"description"` + DataDeletionPolicy string `json:"dataDeletionPolicy"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + ds, err := h.Backend.UpdateDataSource(ctx, kbID, dsID, DataSourceConfig{ + Name: req.Name, + Description: req.Description, + DataDeletionPolicy: req.DataDeletionPolicy, + DataSourceConfiguration: req.DataSourceConfiguration, + VectorIngestionConfig: req.VectorIngestionConfig, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyDataSource: ds}) +} + +func (h *Handler) handleDeleteDS(ctx context.Context, c *echo.Context, kbID, dsID string) error { + if err := h.Backend.DeleteDataSource(ctx, kbID, dsID); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + "dataSourceId": dsID, + "knowledgeBaseId": kbID, + keyStatus: statusDeleting, + }) +} + +func (h *Handler) handleListDS(ctx context.Context, c *echo.Context, kbID string) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + summaries, outToken, err := h.Backend.ListDataSources(ctx, kbID, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + "dataSourceSummaries": summaries, + keyNextToken: outToken, + }) +} + +// --------------------------------------------------------------------------- +// Ingestion job handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleStartIngestionJob( + ctx context.Context, c *echo.Context, kbID, dsID string, body []byte, +) error { + var req struct { + Description string `json:"description"` + } + + _ = json.Unmarshal(body, &req) + + job, err := h.Backend.StartIngestionJob(ctx, kbID, dsID, req.Description) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusAccepted, map[string]any{keyIngestionJob: job}) +} + +func (h *Handler) handleGetIngestionJob( + ctx context.Context, c *echo.Context, kbID, dsID, jobID string, +) error { + job, err := h.Backend.GetIngestionJob(ctx, kbID, dsID, jobID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyIngestionJob: job}) +} + +func (h *Handler) handleStopIngestionJob( + ctx context.Context, c *echo.Context, kbID, dsID, jobID string, +) error { + job, err := h.Backend.StopIngestionJob(ctx, kbID, dsID, jobID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyIngestionJob: job}) +} + +func (h *Handler) handleListIngestionJobs( + ctx context.Context, c *echo.Context, kbID, dsID string, +) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + jobs, outToken, err := h.Backend.ListIngestionJobs(ctx, kbID, dsID, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + "ingestionJobSummaries": jobs, + keyNextToken: outToken, + }) +} + +// --------------------------------------------------------------------------- +// Flow handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreateFlow(ctx context.Context, c *echo.Context, body []byte) error { + var req struct { + Tags map[string]string `json:"tags"` + Definition map[string]any `json:"definition"` + Name string `json:"name"` + Description string `json:"description"` + RoleARN string `json:"executionRoleArn"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + f, err := h.Backend.CreateFlow(ctx, FlowConfig{ + Name: req.Name, + Description: req.Description, + RoleARN: req.RoleARN, + Definition: req.Definition, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusCreated, f) +} + +func (h *Handler) handleGetFlow(ctx context.Context, c *echo.Context, flowID string) error { + f, err := h.Backend.GetFlow(ctx, flowID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, f) +} + +func (h *Handler) handleUpdateFlow( + ctx context.Context, c *echo.Context, flowID string, body []byte, +) error { + var req struct { + Tags map[string]string `json:"tags"` + Definition map[string]any `json:"definition"` + Name string `json:"name"` + Description string `json:"description"` + RoleARN string `json:"executionRoleArn"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + f, err := h.Backend.UpdateFlow(ctx, flowID, FlowConfig{ + Name: req.Name, + Description: req.Description, + RoleARN: req.RoleARN, + Definition: req.Definition, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, f) +} + +func (h *Handler) handleDeleteFlow(ctx context.Context, c *echo.Context, flowID string) error { + if err := h.Backend.DeleteFlow(ctx, flowID); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"id": flowID, keyStatus: "Deleting"}) +} + +func (h *Handler) handleListFlows(ctx context.Context, c *echo.Context) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + summaries, outToken, err := h.Backend.ListFlows(ctx, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"flowSummaries": summaries, keyNextToken: outToken}) +} + +func (h *Handler) handlePrepareFlow(ctx context.Context, c *echo.Context, flowID string) error { + f, err := h.Backend.PrepareFlow(ctx, flowID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusAccepted, f) +} + +func (h *Handler) handleValidateFlowDef(ctx context.Context, c *echo.Context, body []byte) error { + var req struct { + Definition map[string]any `json:"definition"` + } + + _ = json.Unmarshal(body, &req) + + errs, err := h.Backend.ValidateFlowDefinition(ctx, req.Definition) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"validations": errs}) +} + +// --------------------------------------------------------------------------- +// Flow version handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreateFlowVersion( + ctx context.Context, c *echo.Context, flowID string, body []byte, +) error { + var req struct { + Description string `json:"description"` + } + + _ = json.Unmarshal(body, &req) + + fv, err := h.Backend.CreateFlowVersion(ctx, flowID, req.Description) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusCreated, fv) +} + +func (h *Handler) handleGetFlowVersion( + ctx context.Context, c *echo.Context, flowID, flowVersion string, +) error { + fv, err := h.Backend.GetFlowVersion(ctx, flowID, flowVersion) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, fv) +} + +func (h *Handler) handleDeleteFlowVersion( + ctx context.Context, c *echo.Context, flowID, flowVersion string, +) error { + if err := h.Backend.DeleteFlowVersion(ctx, flowID, flowVersion); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"id": flowID, "version": flowVersion, keyStatus: "Deleting"}) +} + +func (h *Handler) handleListFlowVersions( + ctx context.Context, c *echo.Context, flowID string, +) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + summaries, outToken, err := h.Backend.ListFlowVersions(ctx, flowID, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"flowVersionSummaries": summaries, keyNextToken: outToken}) +} + +// --------------------------------------------------------------------------- +// Flow alias handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreateFlowAlias( + ctx context.Context, c *echo.Context, flowID string, body []byte, +) error { + var req struct { + Tags map[string]string `json:"tags"` + Name string `json:"name"` + Description string `json:"description"` + RoutingConfiguration []FlowAliasRouting `json:"routingConfiguration"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + al, err := h.Backend.CreateFlowAlias(ctx, flowID, FlowAliasConfig{ + Name: req.Name, + Description: req.Description, + RoutingConfiguration: req.RoutingConfiguration, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusCreated, al) +} + +func (h *Handler) handleGetFlowAlias( + ctx context.Context, c *echo.Context, flowID, aliasID string, +) error { + al, err := h.Backend.GetFlowAlias(ctx, flowID, aliasID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, al) +} + +func (h *Handler) handleUpdateFlowAlias( + ctx context.Context, c *echo.Context, flowID, aliasID string, body []byte, +) error { + var req struct { + Tags map[string]string `json:"tags"` + Name string `json:"name"` + Description string `json:"description"` + RoutingConfiguration []FlowAliasRouting `json:"routingConfiguration"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + al, err := h.Backend.UpdateFlowAlias(ctx, flowID, aliasID, FlowAliasConfig{ + Name: req.Name, + Description: req.Description, + RoutingConfiguration: req.RoutingConfiguration, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, al) +} + +func (h *Handler) handleDeleteFlowAlias( + ctx context.Context, c *echo.Context, flowID, aliasID string, +) error { + if err := h.Backend.DeleteFlowAlias(ctx, flowID, aliasID); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"id": aliasID, "flowId": flowID}) +} + +func (h *Handler) handleListFlowAliases( + ctx context.Context, c *echo.Context, flowID string, +) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + summaries, outToken, err := h.Backend.ListFlowAliases(ctx, flowID, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"flowAliasSummaries": summaries, keyNextToken: outToken}) +} + +// --------------------------------------------------------------------------- +// Prompt handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreatePrompt(ctx context.Context, c *echo.Context, body []byte) error { + var req struct { + Tags map[string]string `json:"tags"` + Name string `json:"name"` + Description string `json:"description"` + DefaultVariant string `json:"defaultVariant"` + Variants []map[string]any `json:"variants"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + p, err := h.Backend.CreatePrompt(ctx, PromptConfig{ + Name: req.Name, + Description: req.Description, + DefaultVariant: req.DefaultVariant, + Variants: req.Variants, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusCreated, p) +} + +func (h *Handler) handleGetPrompt(ctx context.Context, c *echo.Context, promptID string) error { + p, err := h.Backend.GetPrompt(ctx, promptID) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, p) +} + +func (h *Handler) handleUpdatePrompt( + ctx context.Context, c *echo.Context, promptID string, body []byte, +) error { + var req struct { + Tags map[string]string `json:"tags"` + Name string `json:"name"` + Description string `json:"description"` + DefaultVariant string `json:"defaultVariant"` + Variants []map[string]any `json:"variants"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + p, err := h.Backend.UpdatePrompt(ctx, promptID, PromptConfig{ + Name: req.Name, + Description: req.Description, + DefaultVariant: req.DefaultVariant, + Variants: req.Variants, + Tags: req.Tags, + }) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, p) +} + +func (h *Handler) handleDeletePrompt(ctx context.Context, c *echo.Context, promptID string) error { + if err := h.Backend.DeletePrompt(ctx, promptID); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"id": promptID}) +} + +func (h *Handler) handleListPrompts(ctx context.Context, c *echo.Context) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + summaries, outToken, err := h.Backend.ListPrompts(ctx, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"promptSummaries": summaries, keyNextToken: outToken}) +} + +// --------------------------------------------------------------------------- +// Prompt version handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleCreatePromptVersion( + ctx context.Context, c *echo.Context, promptID string, body []byte, +) error { + var req struct { + Description string `json:"description"` + } + + _ = json.Unmarshal(body, &req) + + pv, err := h.Backend.CreatePromptVersion(ctx, promptID, req.Description) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusCreated, pv) +} + +func (h *Handler) handleGetPromptVersion( + ctx context.Context, c *echo.Context, promptID, version string, +) error { + pv, err := h.Backend.GetPromptVersion(ctx, promptID, version) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, pv) +} + +func (h *Handler) handleDeletePromptVersion( + ctx context.Context, c *echo.Context, promptID, version string, +) error { + if err := h.Backend.DeletePromptVersion(ctx, promptID, version); err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"id": promptID, "version": version}) +} + +// --------------------------------------------------------------------------- +// KB document handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleIngestKBDocs( + ctx context.Context, c *echo.Context, kbID, dsID string, body []byte, +) error { + var req struct { + Documents []struct { + Metadata map[string]any `json:"metadata"` + Content map[string]any `json:"content"` + DocID string `json:"documentId"` + } `json:"documents"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + docs := make([]KBDocument, 0, len(req.Documents)) + + for _, d := range req.Documents { + docs = append(docs, KBDocument{ + DocID: d.DocID, + Metadata: d.Metadata, + Content: d.Content, + }) + } + + details, err := h.Backend.IngestKnowledgeBaseDocuments(ctx, kbID, dsID, docs) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusAccepted, map[string]any{keyDocumentDetails: details}) +} + +func (h *Handler) handleGetKBDocs( + ctx context.Context, c *echo.Context, kbID, dsID string, body []byte, +) error { + var req struct { + DocumentIDs []string `json:"documentIds"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + details, err := h.Backend.GetKnowledgeBaseDocuments(ctx, kbID, dsID, req.DocumentIDs) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{keyDocumentDetails: details}) +} + +func (h *Handler) handleDeleteKBDocs( + ctx context.Context, c *echo.Context, kbID, dsID string, body []byte, +) error { + var req struct { + DocumentIDs []string `json:"documentIds"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + details, err := h.Backend.DeleteKnowledgeBaseDocuments(ctx, kbID, dsID, req.DocumentIDs) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusAccepted, map[string]any{keyDocumentDetails: details}) +} + +func (h *Handler) handleListKBDocs( + ctx context.Context, c *echo.Context, kbID, dsID string, +) error { + maxResults, nextToken := pageParams(c.Request().URL.Query()) + + details, outToken, err := h.Backend.ListKnowledgeBaseDocuments(ctx, kbID, dsID, maxResults, nextToken) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{ + keyDocumentDetails: details, + keyNextToken: outToken, + }) +} + +// --------------------------------------------------------------------------- +// Tag handlers +// --------------------------------------------------------------------------- + +func (h *Handler) handleListTags( + ctx context.Context, c *echo.Context, resourceARN string, +) error { + tags, err := h.Backend.ListTagsForResource(ctx, resourceARN) + if err != nil { + return handleErr(c, err) + } + + return c.JSON(http.StatusOK, map[string]any{"tags": tags}) +} + +func (h *Handler) handleTagResource( + ctx context.Context, c *echo.Context, resourceARN string, body []byte, +) error { + var req struct { + Tags map[string]string `json:"tags"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return handleErr(c, err) + } + + if err := h.Backend.TagResource(ctx, resourceARN, req.Tags); err != nil { + return handleErr(c, err) + } + + return c.NoContent(http.StatusNoContent) +} + +func (h *Handler) handleUntagResource( + ctx context.Context, c *echo.Context, resourceARN string, query url.Values, +) error { + tagKeys := query["tagKeys"] + + if err := h.Backend.UntagResource(ctx, resourceARN, tagKeys); err != nil { + return handleErr(c, err) + } + + return c.NoContent(http.StatusNoContent) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func handleErr(c *echo.Context, err error) error { + var syntaxErr *json.SyntaxError + + var code string + var status int + + switch { + case errors.Is(err, awserr.ErrNotFound): + status = http.StatusNotFound + code = "ResourceNotFoundException" + case errors.Is(err, awserr.ErrAlreadyExists): + status = http.StatusConflict + code = "ConflictException" + case errors.Is(err, awserr.ErrInvalidParameter): + status = http.StatusBadRequest + code = "ValidationException" + case errors.As(err, &syntaxErr): + status = http.StatusBadRequest + code = "ValidationException" + default: + status = http.StatusInternalServerError + code = "InternalServerException" + } + + c.Response().Header().Set("X-Amzn-Errortype", code) + + return c.JSON(status, map[string]any{"message": err.Error()}) +} + +func errResp(code, msg string) map[string]any { + return map[string]any{"__type": code, "message": msg} +} + +func pageParams(query url.Values) (int, string) { + maxResults := maxPageDefault + nextToken := query.Get(keyNextToken) + + if mr := query.Get("maxResults"); mr != "" { + _, _ = fmt.Sscanf(mr, "%d", &maxResults) + } + + return maxResults, nextToken +} + +// classifyPath returns the operation name from method+path (used by ExtractOperation). +func classifyPath(method, path string) string { + path = strings.TrimSuffix(path, "/") + + switch { + case path == agentsBase && isWrite(method): + return opCreateAgent + case path == agentsBase: + return opListAgents + case path == kbBase && isWrite(method): + return opCreateKnowledgeBase + case path == kbBase: + return opListKnowledgeBases + case path == flowsBase && method == http.MethodPost: + return opCreateFlow + case path == flowsBase: + return opListFlows + case path == promptsBase && method == http.MethodPost: + return opCreatePrompt + case path == promptsBase: + return opListPrompts + } + + return classifySubPath(method, path) +} + +func classifySubPath(method, path string) string { + switch { + case strings.HasPrefix(path, agentsBase+"/"): + return classifyAgentPath(method, path) + case strings.HasPrefix(path, kbBase+"/"): + return classifyKBPath(method, path) + case strings.HasPrefix(path, flowsBase+"/"): + return classifyFlowPath(method, path) + case strings.HasPrefix(path, promptsBase+"/"): + return classifyPromptPath(method, path) + case strings.HasPrefix(path, tagsBase): + return classifyTagPath(method) + } + + return opUnknown +} + +// classifyAgentVersionedSubPath handles the collaborator, agentKB, alias, and actiongroup cases. +func classifyAgentVersionedSubPath(method string, segs []string) string { + switch { + case containsSeg(segs, "actiongroups"): + return classifyActionGroupPath(method, segs) + case containsSeg(segs, "agentcollaborators"): + return classifyCollabPath(method, segs) + case containsSeg(segs, "knowledgebases"): + return classifyAgentKBPath(method, segs) + default: + return classifyAgentVersionPath(method, segs) + } +} + +func classifyAgentPath(method, path string) string { + rest, _ := strings.CutPrefix(path, agentsBase+"/") + segs := strings.Split(rest, "/") + + switch { + case len(segs) == 1 && method == http.MethodGet: + return opGetAgent + case len(segs) == 1 && method == http.MethodPut: + return opUpdateAgent + case len(segs) == 1 && method == http.MethodDelete: + return opDeleteAgent + case len(segs) == 2 && segs[1] == "prepare": + return opPrepareAgent + case containsSeg(segs, "agentversions"): + return classifyAgentVersionedSubPath(method, segs) + case containsSeg(segs, "agentaliases"): + return classifyAliasPath(method, segs) + } + + return opUnknown +} + +func classifyActionGroupPath(method string, segs []string) string { + idx := indexOf(segs, "actiongroups") + hasID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasID { + switch method { + case http.MethodPut, http.MethodPost: + return opCreateAgentActionGroup + case http.MethodGet: + return opListAgentActionGroups + } + } + + switch method { + case http.MethodGet: + return opGetAgentActionGroup + case http.MethodPut: + return opUpdateAgentActionGroup + case http.MethodDelete: + return opDeleteAgentActionGroup + } + + return opUnknown +} + +func classifyCollabPath(method string, segs []string) string { + idx := indexOf(segs, "agentcollaborators") + hasID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasID { + switch method { + case http.MethodPut: + return opAssociateAgentCollaborator + case http.MethodGet: + return opListAgentCollaborators + } + } + + switch method { + case http.MethodGet: + return opGetAgentCollaborator + case http.MethodPut: + return opUpdateAgentCollaborator + case http.MethodDelete: + return opDisassociateAgentCollaborator + } + + return opUnknown +} + +func classifyAgentKBPath(method string, segs []string) string { + idx := indexOf(segs, "knowledgebases") + hasID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasID { + switch method { + case http.MethodPut: + return opAssociateAgentKnowledgeBase + case http.MethodGet: + return opListAgentKnowledgeBases + } + } + + switch method { + case http.MethodGet: + return opGetAgentKnowledgeBase + case http.MethodPut: + return opUpdateAgentKnowledgeBase + case http.MethodDelete: + return opDisassociateAgentKnowledgeBase + } + + return opUnknown +} + +func classifyAgentVersionPath(method string, segs []string) string { + idx := indexOf(segs, "agentversions") + hasVersionID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasVersionID { + switch method { + case http.MethodPost: + return opCreateAgentVersion + case http.MethodGet: + return opListAgentVersions + } + } + + switch method { + case http.MethodGet: + return opGetAgentVersion + case http.MethodDelete: + return opDeleteAgentVersion + } + + return opUnknown +} + +func classifyAliasPath(method string, segs []string) string { + idx := indexOf(segs, "agentaliases") + hasID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasID { + switch method { + case http.MethodPost, http.MethodPut: + return opCreateAgentAlias + case http.MethodGet: + return opListAgentAliases + } + } + + switch method { + case http.MethodGet: + return opGetAgentAlias + case http.MethodPut: + return opUpdateAgentAlias + case http.MethodDelete: + return opDeleteAgentAlias + } + + return opUnknown +} + +func classifyKBPath(method, path string) string { + rest, _ := strings.CutPrefix(path, kbBase+"/") + segs := strings.Split(rest, "/") + + switch { + case len(segs) == 1 && method == http.MethodGet: + return opGetKnowledgeBase + case len(segs) == 1 && method == http.MethodPut: + return opUpdateKnowledgeBase + case len(segs) == 1 && method == http.MethodDelete: + return opDeleteKnowledgeBase + case containsSeg(segs, "datasources"): + return classifyDSPath(method, segs) + } + + return opUnknown +} + +func classifyDSPath(method string, segs []string) string { + idx := indexOf(segs, "datasources") + hasDSID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasDSID { + switch method { + case http.MethodPut, http.MethodPost: + return opCreateDataSource + case http.MethodGet: + return opListDataSources + } + } + + dsSuffix := "" + + if len(segs) > idx+splitTwo { + dsSuffix = segs[idx+splitTwo] + } + + return classifyDSSuffix(method, segs[idx+1], dsSuffix, segs) +} + +func classifyDSSuffix(method, _, suffix string, segs []string) string { + switch suffix { + case "ingestionjobs": + return classifyJobPath(method, segs) + case "documents": + return classifyDocPath(method, segs) + case "": + switch method { + case http.MethodGet: + return opGetDataSource + case http.MethodPut: + return opUpdateDataSource + case http.MethodDelete: + return opDeleteDataSource + } + } + + return opUnknown +} + +func classifyJobPath(method string, segs []string) string { + idx := indexOf(segs, "ingestionjobs") + hasJobID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasJobID { + switch method { + case http.MethodPut, http.MethodPost: + return opStartIngestionJob + case http.MethodGet: + return opListIngestionJobs + } + } + + if len(segs) > idx+splitTwo && segs[idx+splitTwo] == "stop" { + return opStopIngestionJob + } + + return opGetIngestionJob +} + +func classifyDocPath(method string, segs []string) string { + idx := indexOf(segs, "documents") + + if len(segs) > idx+1 { + switch segs[idx+1] { + case "deleteDocuments": + return opDeleteKnowledgeBaseDocuments + case "getDocuments": + return opGetKnowledgeBaseDocuments + } + } + + switch method { + case http.MethodPost: + return opIngestKnowledgeBaseDocuments + case http.MethodGet: + return opListKnowledgeBaseDocuments + } + + return opUnknown +} + +func classifyFlowPath(method, path string) string { + rest, _ := strings.CutPrefix(path, flowsBase+"/") + segs := strings.Split(rest, "/") + + switch { + case len(segs) == 1 && method == http.MethodGet: + return opGetFlow + case len(segs) == 1 && method == http.MethodPut: + return opUpdateFlow + case len(segs) == 1 && method == http.MethodDelete: + return opDeleteFlow + case len(segs) == 2 && segs[1] == "prepare": + return opPrepareFlow + case containsSeg(segs, "versions"): + return classifyFlowVersionPath(method, segs) + case containsSeg(segs, "aliases"): + return classifyFlowAliasPath(method, segs) + } + + return opUnknown +} + +func classifyFlowVersionPath(method string, segs []string) string { + idx := indexOf(segs, "versions") + hasID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasID { + switch method { + case http.MethodPost: + return opCreateFlowVersion + case http.MethodGet: + return opListFlowVersions + } + } + + switch method { + case http.MethodGet: + return opGetFlowVersion + case http.MethodDelete: + return opDeleteFlowVersion + } + + return opUnknown +} + +func classifyFlowAliasPath(method string, segs []string) string { + idx := indexOf(segs, "aliases") + hasID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasID { + switch method { + case http.MethodPost: + return opCreateFlowAlias + case http.MethodGet: + return opListFlowAliases + } + } + + switch method { + case http.MethodGet: + return opGetFlowAlias + case http.MethodPut: + return opUpdateFlowAlias + case http.MethodDelete: + return opDeleteFlowAlias + } + + return opUnknown +} + +func classifyPromptPath(method, path string) string { + rest, _ := strings.CutPrefix(path, promptsBase+"/") + segs := strings.Split(rest, "/") + + switch { + case len(segs) == 1 && method == http.MethodGet: + return opGetPrompt + case len(segs) == 1 && method == http.MethodPut: + return opUpdatePrompt + case len(segs) == 1 && method == http.MethodDelete: + return opDeletePrompt + case containsSeg(segs, "versions"): + return classifyPromptVersionPath(method, segs) + } + + return opUnknown +} + +func classifyPromptVersionPath(method string, segs []string) string { + idx := indexOf(segs, "versions") + hasID := len(segs) > idx+1 && segs[idx+1] != "" + + if !hasID && method == http.MethodPost { + return opCreatePromptVersion + } + + switch method { + case http.MethodGet: + return opGetPromptVersion + case http.MethodDelete: + return opDeletePromptVersion + } + + return opUnknown +} + +func classifyTagPath(method string) string { + switch method { + case http.MethodGet: + return opListTagsForResource + case http.MethodPost: + return opTagResource + case http.MethodDelete: + return opUntagResource + } + + return opUnknown +} + +func isWrite(method string) bool { + return method == http.MethodPost || method == http.MethodPut +} + +func containsSeg(segs []string, seg string) bool { + return slices.Contains(segs, seg) +} + +func indexOf(segs []string, seg string) int { + for i, s := range segs { + if s == seg { + return i + } + } + + return -1 +} diff --git a/services/bedrockagent/handler_test.go b/services/bedrockagent/handler_test.go new file mode 100644 index 000000000..2c5d8c704 --- /dev/null +++ b/services/bedrockagent/handler_test.go @@ -0,0 +1,596 @@ +package bedrockagent_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v5" + + "github.com/blackbirdworks/gopherstack/services/bedrockagent" +) + +func setupHandler(t *testing.T) (*bedrockagent.Handler, *echo.Echo) { + t.Helper() + + b := bedrockagent.NewTestBackend("us-east-1", "123456789012") + h := bedrockagent.NewTestHandler(b) + h.AccountID = "123456789012" + h.DefaultRegion = "us-east-1" + + e := echo.New() + + return h, e +} + +func doRequest( + t *testing.T, h *bedrockagent.Handler, e *echo.Echo, method, path string, body any, +) *httptest.ResponseRecorder { + t.Helper() + + var bodyBytes []byte + + if body != nil { + var err error + + bodyBytes, err = json.Marshal(body) + if err != nil { + t.Fatalf("marshal body: %v", err) + } + } + + req := httptest.NewRequest(method, path, bytes.NewReader(bodyBytes)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if err := h.Handler()(c); err != nil { + t.Logf("handler returned error: %v", err) + } + + return rec +} + +func TestHandlerAgentCRUD(t *testing.T) { + t.Parallel() + + type tc struct { + body any + name string + method string + path string + expectedStatus int + } + + h, e := setupHandler(t) + + createBody := map[string]any{ + "agentName": "test-agent", + "foundationModel": "anthropic.claude-v2", + "agentResourceRoleArn": "arn:aws:iam::123456789012:role/AmazonBedrockRole", + } + + rec := doRequest(t, h, e, http.MethodPut, "/agents", createBody) + if rec.Code != http.StatusOK { + t.Fatalf("create agent got %d want 200: %s", rec.Code, rec.Body.String()) + } + + var createResp map[string]map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &createResp); err != nil { + t.Fatalf("unmarshal create response: %v", err) + } + + agentID, _ := createResp["agent"]["agentId"].(string) + if agentID == "" { + t.Fatal("no agentId in response") + } + + cases := []tc{ + {name: "list agents", method: http.MethodGet, path: "/agents", expectedStatus: http.StatusOK}, + {name: "get agent", method: http.MethodGet, path: "/agents/" + agentID, expectedStatus: http.StatusOK}, + { + name: "update agent", + method: http.MethodPut, + path: "/agents/" + agentID, + body: map[string]any{ + "agentName": "updated-agent", + "foundationModel": "anthropic.claude-v2", + "agentResourceRoleArn": "arn:aws:iam::123456789012:role/AmazonBedrockRole", + }, + expectedStatus: http.StatusOK, + }, + { + name: "prepare agent", + method: http.MethodPost, + path: "/agents/" + agentID + "/prepare", + expectedStatus: http.StatusAccepted, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + hLocal, eLocal := setupHandler(t) + // pre-create the agent for each sub-test + r := doRequest(t, hLocal, eLocal, http.MethodPut, "/agents", createBody) + if r.Code != http.StatusOK { + t.Fatalf("setup create: %d", r.Code) + } + + var sr map[string]map[string]any + _ = json.Unmarshal(r.Body.Bytes(), &sr) + aid := sr["agent"]["agentId"].(string) + + path := tc.path + if agentID != "" { + // substitute if path contains our original agentID + path = "/agents" + switch tc.name { + case "get agent", "update agent": + path = "/agents/" + aid + case "prepare agent": + path = "/agents/" + aid + "/prepare" + } + } + + result := doRequest(t, hLocal, eLocal, tc.method, path, tc.body) + if result.Code != tc.expectedStatus { + t.Errorf("got %d want %d: %s", result.Code, tc.expectedStatus, result.Body.String()) + } + }) + } +} + +func TestHandlerAgentNotFound(t *testing.T) { + t.Parallel() + + h, e := setupHandler(t) + + rec := doRequest(t, h, e, http.MethodGet, "/agents/nonexistent", nil) + if rec.Code != http.StatusNotFound { + t.Errorf("got %d want 404", rec.Code) + } +} + +func TestHandlerKnowledgeBaseCRUD(t *testing.T) { + t.Parallel() + + h, e := setupHandler(t) + + createBody := map[string]any{ + "name": "test-kb", + "roleArn": "arn:aws:iam::123456789012:role/KBRole", + "knowledgeBaseConfiguration": map[string]any{ + "type": "VECTOR", + }, + "storageConfiguration": map[string]any{ + "type": "OPENSEARCH_SERVERLESS", + }, + } + + rec := doRequest(t, h, e, http.MethodPut, "/knowledgebases", createBody) + if rec.Code != http.StatusOK { + t.Fatalf("create kb: %d %s", rec.Code, rec.Body.String()) + } + + var createResp map[string]map[string]any + _ = json.Unmarshal(rec.Body.Bytes(), &createResp) + kbID := createResp["knowledgeBase"]["knowledgeBaseId"].(string) + + t.Run("get kb", func(t *testing.T) { + t.Parallel() + + h2, e2 := setupHandler(t) + r := doRequest(t, h2, e2, http.MethodPut, "/knowledgebases", createBody) + + var resp map[string]map[string]any + _ = json.Unmarshal(r.Body.Bytes(), &resp) + id := resp["knowledgeBase"]["knowledgeBaseId"].(string) + + rec2 := doRequest(t, h2, e2, http.MethodGet, "/knowledgebases/"+id, nil) + if rec2.Code != http.StatusOK { + t.Errorf("got %d want 200", rec2.Code) + } + }) + + t.Run("list kbs", func(t *testing.T) { + t.Parallel() + + rec2 := doRequest(t, h, e, http.MethodGet, "/knowledgebases", nil) + if rec2.Code != http.StatusOK { + t.Errorf("got %d want 200", rec2.Code) + } + }) + + t.Run("delete kb", func(t *testing.T) { + t.Parallel() + + h2, e2 := setupHandler(t) + r := doRequest(t, h2, e2, http.MethodPut, "/knowledgebases", createBody) + + var resp map[string]map[string]any + _ = json.Unmarshal(r.Body.Bytes(), &resp) + id := resp["knowledgeBase"]["knowledgeBaseId"].(string) + + rec2 := doRequest(t, h2, e2, http.MethodDelete, "/knowledgebases/"+id, nil) + if rec2.Code != http.StatusOK { + t.Errorf("got %d want 200", rec2.Code) + } + }) + + _ = kbID +} + +func TestHandlerFlowCRUD(t *testing.T) { + t.Parallel() + + h, e := setupHandler(t) + + createBody := map[string]any{ + "name": "test-flow", + "executionRoleArn": "arn:aws:iam::123456789012:role/FlowRole", + "definition": map[string]any{ + "nodes": []any{}, + "connections": []any{}, + }, + } + + rec := doRequest(t, h, e, http.MethodPost, "/flows", createBody) + if rec.Code != http.StatusCreated { + t.Fatalf("create flow: %d %s", rec.Code, rec.Body.String()) + } + + var createResp map[string]any + _ = json.Unmarshal(rec.Body.Bytes(), &createResp) + flowID, _ := createResp["id"].(string) + + if flowID == "" { + t.Fatal("no id in flow response") + } + + t.Run("get flow", func(t *testing.T) { + t.Parallel() + + h2, e2 := setupHandler(t) + r := doRequest(t, h2, e2, http.MethodPost, "/flows", createBody) + + var resp map[string]any + _ = json.Unmarshal(r.Body.Bytes(), &resp) + id := resp["id"].(string) + + rec2 := doRequest(t, h2, e2, http.MethodGet, "/flows/"+id, nil) + if rec2.Code != http.StatusOK { + t.Errorf("got %d want 200", rec2.Code) + } + }) + + t.Run("list flows", func(t *testing.T) { + t.Parallel() + + rec2 := doRequest(t, h, e, http.MethodGet, "/flows", nil) + if rec2.Code != http.StatusOK { + t.Errorf("got %d want 200", rec2.Code) + } + }) + + t.Run("prepare flow", func(t *testing.T) { + t.Parallel() + + h2, e2 := setupHandler(t) + r := doRequest(t, h2, e2, http.MethodPost, "/flows", createBody) + + var resp map[string]any + _ = json.Unmarshal(r.Body.Bytes(), &resp) + id := resp["id"].(string) + + rec2 := doRequest(t, h2, e2, http.MethodPost, "/flows/"+id+"/prepare", nil) + if rec2.Code != http.StatusAccepted { + t.Errorf("got %d want 202", rec2.Code) + } + }) +} + +func TestHandlerPromptCRUD(t *testing.T) { + t.Parallel() + + h, e := setupHandler(t) + + createBody := map[string]any{ + "name": "test-prompt", + "defaultVariant": "v1", + "variants": []any{ + map[string]any{ + "name": "v1", + "templateType": "TEXT", + }, + }, + } + + rec := doRequest(t, h, e, http.MethodPost, "/prompts", createBody) + if rec.Code != http.StatusCreated { + t.Fatalf("create prompt: %d %s", rec.Code, rec.Body.String()) + } + + var createResp map[string]any + _ = json.Unmarshal(rec.Body.Bytes(), &createResp) + promptID, _ := createResp["id"].(string) + + if promptID == "" { + t.Fatal("no id in prompt response") + } + + t.Run("get prompt", func(t *testing.T) { + t.Parallel() + + h2, e2 := setupHandler(t) + r := doRequest(t, h2, e2, http.MethodPost, "/prompts", createBody) + + var resp map[string]any + _ = json.Unmarshal(r.Body.Bytes(), &resp) + id := resp["id"].(string) + + rec2 := doRequest(t, h2, e2, http.MethodGet, "/prompts/"+id, nil) + if rec2.Code != http.StatusOK { + t.Errorf("got %d want 200", rec2.Code) + } + }) + + t.Run("list prompts", func(t *testing.T) { + t.Parallel() + + rec2 := doRequest(t, h, e, http.MethodGet, "/prompts", nil) + if rec2.Code != http.StatusOK { + t.Errorf("got %d want 200", rec2.Code) + } + }) + + t.Run("create prompt version", func(t *testing.T) { + t.Parallel() + + h2, e2 := setupHandler(t) + r := doRequest(t, h2, e2, http.MethodPost, "/prompts", createBody) + + var resp map[string]any + _ = json.Unmarshal(r.Body.Bytes(), &resp) + id := resp["id"].(string) + + rec2 := doRequest(t, h2, e2, http.MethodPost, "/prompts/"+id+"/versions", map[string]any{ + "description": "v1", + }) + if rec2.Code != http.StatusCreated { + t.Errorf("got %d want 201: %s", rec2.Code, rec2.Body.String()) + } + }) +} + +func TestHandlerTagging(t *testing.T) { + t.Parallel() + + h, e := setupHandler(t) + + createBody := map[string]any{ + "agentName": "tagging-agent", + "foundationModel": "anthropic.claude-v2", + "agentResourceRoleArn": "arn:aws:iam::123456789012:role/AmazonBedrockRole", + } + + rec := doRequest(t, h, e, http.MethodPut, "/agents", createBody) + + var createResp map[string]map[string]any + _ = json.Unmarshal(rec.Body.Bytes(), &createResp) + arn := createResp["agent"]["agentArn"].(string) + + t.Run("tag resource", func(t *testing.T) { + t.Parallel() + + rec2 := doRequest(t, h, e, http.MethodPost, "/tags/"+arn, map[string]any{ + "tags": map[string]string{"env": "test"}, + }) + if rec2.Code != http.StatusNoContent { + t.Errorf("tag: got %d want 204", rec2.Code) + } + }) + + t.Run("list tags", func(t *testing.T) { + t.Parallel() + + rec2 := doRequest(t, h, e, http.MethodGet, "/tags/"+arn, nil) + if rec2.Code != http.StatusOK { + t.Errorf("list tags: got %d want 200", rec2.Code) + } + }) +} + +func TestHandlerDataSourceAndIngestion(t *testing.T) { + t.Parallel() + + h, e := setupHandler(t) + + kbBody := map[string]any{ + "name": "ingestion-kb", + "roleArn": "arn:aws:iam::123456789012:role/KBRole", + "knowledgeBaseConfiguration": map[string]any{"type": "VECTOR"}, + "storageConfiguration": map[string]any{"type": "OPENSEARCH_SERVERLESS"}, + } + + kbRec := doRequest(t, h, e, http.MethodPut, "/knowledgebases", kbBody) + if kbRec.Code != http.StatusOK { + t.Fatalf("create kb: %d", kbRec.Code) + } + + var kbResp map[string]map[string]any + _ = json.Unmarshal(kbRec.Body.Bytes(), &kbResp) + kbID := kbResp["knowledgeBase"]["knowledgeBaseId"].(string) + + dsBody := map[string]any{ + "name": "test-ds", + "dataSourceConfiguration": map[string]any{"type": "S3"}, + } + + dsRec := doRequest(t, h, e, http.MethodPut, "/knowledgebases/"+kbID+"/datasources", dsBody) + if dsRec.Code != http.StatusOK { + t.Fatalf("create ds: %d %s", dsRec.Code, dsRec.Body.String()) + } + + var dsResp map[string]map[string]any + _ = json.Unmarshal(dsRec.Body.Bytes(), &dsResp) + dsID := dsResp["dataSource"]["dataSourceId"].(string) + + t.Run("start ingestion job", func(t *testing.T) { + t.Parallel() + + rec := doRequest(t, h, e, http.MethodPut, + "/knowledgebases/"+kbID+"/datasources/"+dsID+"/ingestionjobs", nil) + if rec.Code != http.StatusAccepted { + t.Errorf("got %d want 202: %s", rec.Code, rec.Body.String()) + } + }) + + t.Run("list ingestion jobs", func(t *testing.T) { + t.Parallel() + + rec := doRequest(t, h, e, http.MethodGet, + "/knowledgebases/"+kbID+"/datasources/"+dsID+"/ingestionjobs", nil) + if rec.Code != http.StatusOK { + t.Errorf("got %d want 200", rec.Code) + } + }) +} + +func TestHandlerClassifyPath(t *testing.T) { + t.Parallel() + + b := bedrockagent.NewTestBackend("us-east-1", "123456789012") + h := bedrockagent.NewTestHandler(b) + h.AccountID = "123456789012" + h.DefaultRegion = "us-east-1" + e := echo.New() + + cases := []struct { + method string + path string + wantOp string + }{ + {http.MethodPut, "/agents", "CreateAgent"}, + {http.MethodGet, "/agents", "ListAgents"}, + {http.MethodGet, "/agents/abc123", "GetAgent"}, + {http.MethodDelete, "/agents/abc123", "DeleteAgent"}, + {http.MethodPut, "/knowledgebases", "CreateKnowledgeBase"}, + {http.MethodGet, "/knowledgebases", "ListKnowledgeBases"}, + {http.MethodPost, "/flows", "CreateFlow"}, + {http.MethodGet, "/flows", "ListFlows"}, + {http.MethodPost, "/prompts", "CreatePrompt"}, + {http.MethodGet, "/prompts", "ListPrompts"}, + {http.MethodGet, "/tags/arn:aws:bedrock:us-east-1::agent/abc", "ListTagsForResource"}, + {http.MethodPost, "/tags/arn:aws:bedrock:us-east-1::agent/abc", "TagResource"}, + {http.MethodDelete, "/tags/arn:aws:bedrock:us-east-1::agent/abc", "UntagResource"}, + } + + for _, tc := range cases { + t.Run(tc.method+":"+tc.path, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(tc.method, tc.path, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + got := h.ExtractOperation(c) + if got != tc.wantOp { + t.Errorf("got %q want %q", got, tc.wantOp) + } + }) + } +} + +func TestHandlerAgentVersions(t *testing.T) { + t.Parallel() + + h, e := setupHandler(t) + + createBody := map[string]any{ + "agentName": "version-agent", + "foundationModel": "anthropic.claude-v2", + "agentResourceRoleArn": "arn:aws:iam::123456789012:role/AmazonBedrockRole", + } + + rec := doRequest(t, h, e, http.MethodPut, "/agents", createBody) + + var createResp map[string]map[string]any + _ = json.Unmarshal(rec.Body.Bytes(), &createResp) + agentID := createResp["agent"]["agentId"].(string) + + // Prepare first so we can create a version + doRequest(t, h, e, http.MethodPost, "/agents/"+agentID+"/prepare", nil) + + t.Run("create version", func(t *testing.T) { + t.Parallel() + + h2, e2 := setupHandler(t) + r := doRequest(t, h2, e2, http.MethodPut, "/agents", createBody) + + var resp map[string]map[string]any + _ = json.Unmarshal(r.Body.Bytes(), &resp) + aid := resp["agent"]["agentId"].(string) + doRequest(t, h2, e2, http.MethodPost, "/agents/"+aid+"/prepare", nil) + + rec2 := doRequest(t, h2, e2, http.MethodPost, "/agents/"+aid+"/agentversions", map[string]any{ + "description": "initial version", + }) + if rec2.Code != http.StatusOK { + t.Errorf("got %d want 200: %s", rec2.Code, rec2.Body.String()) + } + }) + + t.Run("list versions", func(t *testing.T) { + t.Parallel() + + rec2 := doRequest(t, h, e, http.MethodGet, "/agents/"+agentID+"/agentversions", nil) + if rec2.Code != http.StatusOK { + t.Errorf("got %d want 200", rec2.Code) + } + }) +} + +func TestHandlerBackendReset(t *testing.T) { + t.Parallel() + + b := bedrockagent.NewTestBackend("us-east-1", "123456789012") + h := bedrockagent.NewTestHandler(b) + h.AccountID = "123456789012" + h.DefaultRegion = "us-east-1" + e := echo.New() + + createBody := map[string]any{ + "agentName": "reset-agent", + "foundationModel": "anthropic.claude-v2", + "agentResourceRoleArn": "arn:aws:iam::123456789012:role/AmazonBedrockRole", + } + + doRequest(t, h, e, http.MethodPut, "/agents", createBody) + + ctx := context.Background() + agents, _, err := b.ListAgents(ctx, 10, "") + if err != nil { + t.Fatal(err) + } + + if len(agents) == 0 { + t.Fatal("expected agent after create") + } + + h.Reset() + + agents, _, err = b.ListAgents(ctx, 10, "") + if err != nil { + t.Fatal(err) + } + + if len(agents) != 0 { + t.Fatalf("expected empty after reset, got %d", len(agents)) + } +} diff --git a/services/bedrockagent/interfaces.go b/services/bedrockagent/interfaces.go new file mode 100644 index 000000000..3ed2099a6 --- /dev/null +++ b/services/bedrockagent/interfaces.go @@ -0,0 +1,163 @@ +package bedrockagent + +import "context" + +// StorageBackend defines all persistence operations for the Bedrock Agent service. +type StorageBackend interface { + // Agent operations. + CreateAgent(ctx context.Context, cfg AgentConfig) (*Agent, error) + GetAgent(ctx context.Context, agentID string) (*Agent, error) + UpdateAgent(ctx context.Context, agentID string, cfg AgentConfig) (*Agent, error) + DeleteAgent(ctx context.Context, agentID string) error + ListAgents(ctx context.Context, maxResults int, nextToken string) ([]*AgentSummary, string, error) + PrepareAgent(ctx context.Context, agentID string) (*Agent, error) + + // Agent version operations. + CreateAgentVersion(ctx context.Context, agentID, description string) (*AgentVersion, error) + GetAgentVersion(ctx context.Context, agentID, agentVersion string) (*AgentVersion, error) + DeleteAgentVersion(ctx context.Context, agentID, agentVersion string) error + ListAgentVersions( + ctx context.Context, agentID string, maxResults int, nextToken string, + ) ([]*AgentVersionSummary, string, error) + + // Agent action group operations. + CreateAgentActionGroup( + ctx context.Context, agentID string, cfg ActionGroupConfig, + ) (*AgentActionGroup, error) + GetAgentActionGroup( + ctx context.Context, agentID, agentVersion, actionGroupID string, + ) (*AgentActionGroup, error) + UpdateAgentActionGroup( + ctx context.Context, agentID, agentVersion, actionGroupID string, cfg ActionGroupConfig, + ) (*AgentActionGroup, error) + DeleteAgentActionGroup( + ctx context.Context, agentID, agentVersion, actionGroupID string, + ) error + ListAgentActionGroups( + ctx context.Context, agentID, agentVersion string, maxResults int, nextToken string, + ) ([]*ActionGroupSummary, string, error) + + // Agent alias operations. + CreateAgentAlias(ctx context.Context, agentID string, cfg AliasConfig) (*AgentAlias, error) + GetAgentAlias(ctx context.Context, agentID, agentAliasID string) (*AgentAlias, error) + UpdateAgentAlias(ctx context.Context, agentID, agentAliasID string, cfg AliasConfig) (*AgentAlias, error) + DeleteAgentAlias(ctx context.Context, agentID, agentAliasID string) error + ListAgentAliases( + ctx context.Context, agentID string, maxResults int, nextToken string, + ) ([]*AgentAliasSummary, string, error) + + // Agent collaborator operations. + AssociateAgentCollaborator( + ctx context.Context, agentID, agentVersion string, cfg CollaboratorConfig, + ) (*AgentCollaborator, error) + GetAgentCollaborator( + ctx context.Context, agentID, agentVersion, collaboratorID string, + ) (*AgentCollaborator, error) + UpdateAgentCollaborator( + ctx context.Context, agentID, agentVersion, collaboratorID string, cfg CollaboratorConfig, + ) (*AgentCollaborator, error) + DisassociateAgentCollaborator( + ctx context.Context, agentID, agentVersion, collaboratorID string, + ) error + ListAgentCollaborators( + ctx context.Context, agentID, agentVersion string, maxResults int, nextToken string, + ) ([]*AgentCollaborator, string, error) + + // Knowledge base operations. + CreateKnowledgeBase(ctx context.Context, cfg KnowledgeBaseConfig) (*KnowledgeBase, error) + GetKnowledgeBase(ctx context.Context, kbID string) (*KnowledgeBase, error) + UpdateKnowledgeBase(ctx context.Context, kbID string, cfg KnowledgeBaseConfig) (*KnowledgeBase, error) + DeleteKnowledgeBase(ctx context.Context, kbID string) error + ListKnowledgeBases(ctx context.Context, maxResults int, nextToken string) ([]*KnowledgeBaseSummary, string, error) + + // Agent–knowledge base association operations. + AssociateAgentKnowledgeBase( + ctx context.Context, agentID, agentVersion, kbID, description, kbState string, + ) (*AgentKnowledgeBase, error) + GetAgentKnowledgeBase( + ctx context.Context, agentID, agentVersion, kbID string, + ) (*AgentKnowledgeBase, error) + UpdateAgentKnowledgeBase( + ctx context.Context, agentID, agentVersion, kbID, description, kbState string, + ) (*AgentKnowledgeBase, error) + DisassociateAgentKnowledgeBase( + ctx context.Context, agentID, agentVersion, kbID string, + ) error + ListAgentKnowledgeBases( + ctx context.Context, agentID, agentVersion string, maxResults int, nextToken string, + ) ([]*AgentKnowledgeBase, string, error) + + // Data source operations. + CreateDataSource(ctx context.Context, kbID string, cfg DataSourceConfig) (*DataSource, error) + GetDataSource(ctx context.Context, kbID, dataSourceID string) (*DataSource, error) + UpdateDataSource(ctx context.Context, kbID, dataSourceID string, cfg DataSourceConfig) (*DataSource, error) + DeleteDataSource(ctx context.Context, kbID, dataSourceID string) error + ListDataSources( + ctx context.Context, kbID string, maxResults int, nextToken string, + ) ([]*DataSourceSummary, string, error) + + // Ingestion job operations. + StartIngestionJob(ctx context.Context, kbID, dataSourceID, description string) (*IngestionJob, error) + GetIngestionJob(ctx context.Context, kbID, dataSourceID, ingestionJobID string) (*IngestionJob, error) + StopIngestionJob(ctx context.Context, kbID, dataSourceID, ingestionJobID string) (*IngestionJob, error) + ListIngestionJobs( + ctx context.Context, kbID, dataSourceID string, maxResults int, nextToken string, + ) ([]*IngestionJob, string, error) + + // Flow operations. + CreateFlow(ctx context.Context, cfg FlowConfig) (*Flow, error) + GetFlow(ctx context.Context, flowID string) (*Flow, error) + UpdateFlow(ctx context.Context, flowID string, cfg FlowConfig) (*Flow, error) + DeleteFlow(ctx context.Context, flowID string) error + ListFlows(ctx context.Context, maxResults int, nextToken string) ([]*FlowSummary, string, error) + PrepareFlow(ctx context.Context, flowID string) (*Flow, error) + ValidateFlowDefinition(ctx context.Context, definition map[string]any) ([]FlowValidationError, error) + + // Flow version operations. + CreateFlowVersion(ctx context.Context, flowID, description string) (*FlowVersion, error) + GetFlowVersion(ctx context.Context, flowID, flowVersion string) (*FlowVersion, error) + DeleteFlowVersion(ctx context.Context, flowID, flowVersion string) error + ListFlowVersions( + ctx context.Context, flowID string, maxResults int, nextToken string, + ) ([]*FlowVersionSummary, string, error) + + // Flow alias operations. + CreateFlowAlias(ctx context.Context, flowID string, cfg FlowAliasConfig) (*FlowAlias, error) + GetFlowAlias(ctx context.Context, flowID, aliasID string) (*FlowAlias, error) + UpdateFlowAlias(ctx context.Context, flowID, aliasID string, cfg FlowAliasConfig) (*FlowAlias, error) + DeleteFlowAlias(ctx context.Context, flowID, aliasID string) error + ListFlowAliases( + ctx context.Context, flowID string, maxResults int, nextToken string, + ) ([]*FlowAliasSummary, string, error) + + // Prompt operations. + CreatePrompt(ctx context.Context, cfg PromptConfig) (*Prompt, error) + GetPrompt(ctx context.Context, promptID string) (*Prompt, error) + UpdatePrompt(ctx context.Context, promptID string, cfg PromptConfig) (*Prompt, error) + DeletePrompt(ctx context.Context, promptID string) error + ListPrompts(ctx context.Context, maxResults int, nextToken string) ([]*PromptSummary, string, error) + + // Prompt version operations. + CreatePromptVersion(ctx context.Context, promptID, description string) (*PromptVersion, error) + GetPromptVersion(ctx context.Context, promptID, version string) (*PromptVersion, error) + DeletePromptVersion(ctx context.Context, promptID, version string) error + + // Knowledge base document operations. + IngestKnowledgeBaseDocuments( + ctx context.Context, kbID, dataSourceID string, docs []KBDocument, + ) ([]KBDocumentDetail, error) + GetKnowledgeBaseDocuments( + ctx context.Context, kbID, dataSourceID string, docIDs []string, + ) ([]KBDocumentDetail, error) + DeleteKnowledgeBaseDocuments( + ctx context.Context, kbID, dataSourceID string, docIDs []string, + ) ([]KBDocumentDetail, error) + ListKnowledgeBaseDocuments( + ctx context.Context, kbID, dataSourceID string, maxResults int, nextToken string, + ) ([]KBDocumentDetail, string, error) + + // Tagging operations. + ListTagsForResource(ctx context.Context, resourceARN string) (map[string]string, error) + TagResource(ctx context.Context, resourceARN string, tags map[string]string) error + UntagResource(ctx context.Context, resourceARN string, tagKeys []string) error +} diff --git a/services/bedrockagent/provider.go b/services/bedrockagent/provider.go new file mode 100644 index 000000000..a8c59365e --- /dev/null +++ b/services/bedrockagent/provider.go @@ -0,0 +1,43 @@ +// Package bedrockagent provides a local stub for the Amazon Bedrock Agent service. +package bedrockagent + +import ( + "errors" + + "github.com/blackbirdworks/gopherstack/pkgs/config" + "github.com/blackbirdworks/gopherstack/pkgs/service" +) + +// ErrNilAppContext is returned when a nil AppContext is passed to Provider.Init. +var ErrNilAppContext = errors.New("bedrockagent: AppContext must not be nil") + +// Provider implements service.Provider for the Bedrock Agent service. +type Provider struct{} + +// Name returns the provider name. +func (p *Provider) Name() string { return "BedrockAgent" } + +// Init initialises the Bedrock Agent backend and handler. +// +//nolint:ireturn,nolintlint // architecturally required to return interface +func (p *Provider) Init(ctx *service.AppContext) (service.Registerable, error) { + if ctx == nil { + return nil, ErrNilAppContext + } + + accountID := config.DefaultAccountID + region := config.DefaultRegion + + if cp, ok := ctx.Config.(config.Provider); ok { + cfg := cp.GetGlobalConfig() + accountID = cfg.GetAccountID() + region = cfg.GetRegion() + } + + backend := NewInMemoryBackend(region, accountID) + handler := NewHandler(backend) + handler.AccountID = accountID + handler.DefaultRegion = region + + return handler, nil +} diff --git a/services/bedrockagent/sdk_completeness_test.go b/services/bedrockagent/sdk_completeness_test.go new file mode 100644 index 000000000..c06fbc6cc --- /dev/null +++ b/services/bedrockagent/sdk_completeness_test.go @@ -0,0 +1,19 @@ +package bedrockagent_test + +import ( + "testing" + + bedrockagentsdk "github.com/aws/aws-sdk-go-v2/service/bedrockagent" + + "github.com/blackbirdworks/gopherstack/pkgs/sdkcheck" + "github.com/blackbirdworks/gopherstack/services/bedrockagent" +) + +func TestSDKCompleteness(t *testing.T) { + t.Parallel() + + b := bedrockagent.NewTestBackend("us-east-1", "123456789012") + h := bedrockagent.NewTestHandler(b) + + sdkcheck.CheckCompleteness(t, &bedrockagentsdk.Client{}, h.GetSupportedOperations(), []string{}) +} diff --git a/test/terraform/main_test.go b/test/terraform/main_test.go index 13c1afe7c..7c454aa62 100644 --- a/test/terraform/main_test.go +++ b/test/terraform/main_test.go @@ -32,6 +32,7 @@ import ( backupsvc "github.com/aws/aws-sdk-go-v2/service/backup" batchsvc "github.com/aws/aws-sdk-go-v2/service/batch" bedrocksvc "github.com/aws/aws-sdk-go-v2/service/bedrock" + bedrockagentsvc "github.com/aws/aws-sdk-go-v2/service/bedrockagent" bedrockruntimesvc "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" cloudcontrolsvc "github.com/aws/aws-sdk-go-v2/service/cloudcontrol" cfnsvc "github.com/aws/aws-sdk-go-v2/service/cloudformation" @@ -2368,3 +2369,20 @@ func createS3TablesClient(t *testing.T) *s3tablessvc.Client { o.BaseEndpoint = aws.String(endpoint) }) } + +func createBedrockAgentClient(t *testing.T) *bedrockagentsvc.Client { + t.Helper() + + cfg, err := config.LoadDefaultConfig( + t.Context(), + config.WithRegion("us-east-1"), + config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider("test", "test", ""), + ), + ) + require.NoError(t, err, "unable to load SDK config") + + return bedrockagentsvc.NewFromConfig(cfg, func(o *bedrockagentsvc.Options) { + o.BaseEndpoint = aws.String(endpoint) + }) +} diff --git a/test/terraform/terraform_test.go b/test/terraform/terraform_test.go index f8c87e597..b392a4193 100644 --- a/test/terraform/terraform_test.go +++ b/test/terraform/terraform_test.go @@ -39,6 +39,7 @@ import ( backupsvc "github.com/aws/aws-sdk-go-v2/service/backup" batchsvc "github.com/aws/aws-sdk-go-v2/service/batch" bedrocksvc "github.com/aws/aws-sdk-go-v2/service/bedrock" + bedrockagentsvc "github.com/aws/aws-sdk-go-v2/service/bedrockagent" bedrockruntimesvc "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" bedrockruntimetypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" cloudcontrolsvc "github.com/aws/aws-sdk-go-v2/service/cloudcontrol" @@ -6842,3 +6843,39 @@ func TestTerraform_CachingMessagingComprehensive(t *testing.T) { }) } } + +// TestTerraform_MegaBatch4 provisions Bedrock Agent resources and verifies they exist. +func TestTerraform_MegaBatch4(t *testing.T) { + t.Parallel() + + tests := []tfTestCase{ + { + name: "success", + fixture: "mega-batch-4", + setup: func(t *testing.T, _ string) map[string]any { + t.Helper() + + return map[string]any{} + }, + verify: func(t *testing.T, ctx context.Context, vars map[string]any) { + t.Helper() + client := createBedrockAgentClient(t) + + agentsOut, err := client.ListAgents(ctx, &bedrockagentsvc.ListAgentsInput{}) + require.NoError(t, err, "ListAgents should succeed") + require.NotEmpty(t, agentsOut.AgentSummaries, "at least one agent should exist") + + kbOut, err := client.ListKnowledgeBases(ctx, &bedrockagentsvc.ListKnowledgeBasesInput{}) + require.NoError(t, err, "ListKnowledgeBases should succeed") + require.NotEmpty(t, kbOut.KnowledgeBaseSummaries, "at least one knowledge base should exist") + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + runTFTest(t, tc) + }) + } +}