diff --git a/cli.go b/cli.go index 01c31e8f5..23e2260d7 100644 --- a/cli.go +++ b/cli.go @@ -146,6 +146,7 @@ import ( mqbackend "github.com/blackbirdworks/gopherstack/services/mq" mwaabackend "github.com/blackbirdworks/gopherstack/services/mwaa" neptunebackend "github.com/blackbirdworks/gopherstack/services/neptune" + networkmonitorbackend "github.com/blackbirdworks/gopherstack/services/networkmonitor" omicsbackend "github.com/blackbirdworks/gopherstack/services/omics" opensearchbackend "github.com/blackbirdworks/gopherstack/services/opensearch" organizationsbackend "github.com/blackbirdworks/gopherstack/services/organizations" @@ -271,6 +272,7 @@ type CLI struct { resourcegroupstaggingHandler service.Registerable swfHandler service.Registerable firehoseHandler service.Registerable + networkmonitorHandler service.Registerable schedulerHandler service.Registerable servicediscoveryHandler service.Registerable transcribeHandler service.Registerable @@ -1145,6 +1147,11 @@ func (c *CLI) GetSWFHandler() service.Registerable { return c.swfHandler } //nolint:ireturn // architecturally required to return interface func (c *CLI) GetFirehoseHandler() service.Registerable { return c.firehoseHandler } +// GetNetworkMonitorHandler returns the NetworkMonitor handler. +// +//nolint:ireturn // architecturally required to return interface +func (c *CLI) GetNetworkMonitorHandler() service.Registerable { return c.networkmonitorHandler } + // GetSchedulerHandler returns the Scheduler handler (dashboard.AWSSDKProvider). // //nolint:ireturn // architecturally required to return interface @@ -2258,6 +2265,7 @@ func storeCLIHandlers(cli *CLI, services []service.Registerable) { cli.resourcegroupstaggingHandler = byName["ResourceGroupsTaggingAPI"] cli.swfHandler = byName["SWF"] cli.firehoseHandler = byName["Firehose"] + cli.networkmonitorHandler = byName["NetworkMonitor"] cli.schedulerHandler = byName["Scheduler"] cli.route53resolverHandler = byName["Route53Resolver"] cli.rdsHandler = byName["RDS"] @@ -2634,6 +2642,7 @@ func getRemainingServiceProviders() []service.Provider { &resourcegroupstaggingapibackend.Provider{}, &swfbackend.Provider{}, &firehosebackend.Provider{}, + &networkmonitorbackend.Provider{}, &schedulerbackend.Provider{}, &route53resolverbackend.Provider{}, &rdsbackend.Provider{}, @@ -2699,7 +2708,6 @@ func getRemainingServiceProviders() []service.Provider { &mediaconvertbackend.Provider{}, &mqbackend.Provider{}, &mediastorebackend.Provider{}, - &mediastoredatabackend.Provider{}, }, getLatestServiceProviders()...) } @@ -2707,6 +2715,7 @@ func getRemainingServiceProviders() []service.Provider { // Extracted from getServiceProviders to satisfy the funlen limit. func getLatestServiceProviders() []service.Provider { return append([]service.Provider{ + &mediastoredatabackend.Provider{}, &memorydbbackend.Provider{}, }, getNewestServiceProviders()...) } diff --git a/go.mod b/go.mod index 6076f704a..e46977fd7 100644 --- a/go.mod +++ b/go.mod @@ -202,7 +202,9 @@ require ( github.com/aws/aws-sdk-go-v2/service/appstream v1.60.3 ) -require github.com/aws/aws-sdk-go-v2/service/omics v1.45.0 // indirect +require github.com/aws/aws-sdk-go-v2/service/networkmonitor v1.14.6 + +require github.com/aws/aws-sdk-go-v2/service/omics v1.45.0 require ( github.com/antlr/antlr4 v0.0.0-20181218183524-be58ebffde8e // indirect diff --git a/go.sum b/go.sum index bea54cbe6..9df608d85 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,6 @@ github.com/aws/aws-dax-go v1.2.15 h1:30rH3+QgjpjemrVg0NGIG5FnB1izJZ7jUZuBb1Fy8ak github.com/aws/aws-dax-go v1.2.15/go.mod h1:4f/qGLBQlPYd+fmAfG4n4oSvN19JdKNYYmsr90/MPso= github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/aws/aws-sdk-go-v2 v1.41.11 h1:9PRf7jyTMEUM6fuNRAJa2mO/skJfrF50rENJwf2LXqw= -github.com/aws/aws-sdk-go-v2 v1.41.11/go.mod h1:iiUX27gOXRuYaoeUVXhUpPwjJHzISfPAjjcuhUbLSVs= github.com/aws/aws-sdk-go-v2 v1.42.0 h1:XvXMJTkFQtpBKIWZnmr9ZEOc2InWM2yldjXEJ/bymhA= github.com/aws/aws-sdk-go-v2 v1.42.0/go.mod h1:27+ACypSLljLAEKsCYOmrjKh83vuTRkuAe9Uv/3A4bg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho= @@ -42,12 +40,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.19.17 h1:gP2nkGsS+KMvF/jfFz2Vv2qiiOq github.com/aws/aws-sdk-go-v2/credentials v1.19.17/go.mod h1:Bsew3S/moG5iT77giPj1q8wb/s0RE5/QfH+ASjYtuQc= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.27 h1:8sPbKi1/KRHwl5oR3qN9mUXestCeHuaRutxylnr/eVY= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.27/go.mod h1:QV9IVIopJ1dpQUno0f9VYDUwOEjj8u0iEJ4JiZVre3Y= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.29 h1:f3vKqSo13fhTYb+JEcXwXefZQE26I1FB5eTSniU67ko= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.29/go.mod h1:MzoLFUArKGpGD+ukmPiTPG1X5x4o6M2kq4v2dr1FiEc= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.27 h1:9d8AoASQY9UwrOSmiJ7uSM0MGUPFhnenwSvpaFfat2c= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.27/go.mod h1:x0rldpsnUQaQIs4Rh+Vwm9Z/0vI6BxadGtsgJfZFb8s= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.29 h1:RdwIf/CuUsvJX3RgJagbOyotl/cxoLY4xviKuE7p2GY= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.29/go.mod h1:71wt8W2EgswdZy9Mf9KNnzxZ3TiZlv4caKghPktDOkA= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA= @@ -252,6 +246,8 @@ github.com/aws/aws-sdk-go-v2/service/mwaa v1.40.1 h1:fVJ6uZsjHWiMbT4tPsNPmA5VEln github.com/aws/aws-sdk-go-v2/service/mwaa v1.40.1/go.mod h1:O2kdYrYDkdHdUpvYcrWQoYaJgPJCGHdaLF06ep3HzKI= github.com/aws/aws-sdk-go-v2/service/neptune v1.44.1 h1:Akp3PA62O+s8Tze5/4f3YthrAT6OHhay9QhUPv7zw6Y= github.com/aws/aws-sdk-go-v2/service/neptune v1.44.1/go.mod h1:YNCq3tFuuD8sxXKR9NC7n8gdafguPzxp+WFKuivplTw= +github.com/aws/aws-sdk-go-v2/service/networkmonitor v1.14.6 h1:SE8N02CtLWVY/3UM+mqAA5tcnmC6fKhUg3R9gVyk120= +github.com/aws/aws-sdk-go-v2/service/networkmonitor v1.14.6/go.mod h1:0OibDBxn2Uj3XhWpXkqAQkom2usZf200CpfV2K1PirI= github.com/aws/aws-sdk-go-v2/service/omics v1.45.0 h1:eYMKiWBNj6Q61yQCN25B0SLfb1VFLV0axyrrB6rXW9w= github.com/aws/aws-sdk-go-v2/service/omics v1.45.0/go.mod h1:czP0k4tesyOIDgyiUY0jP8AxVtxCy9keGlLm1IheCmQ= github.com/aws/aws-sdk-go-v2/service/opensearch v1.59.0 h1:rhlLa2SwSeKva0DEUrbRc5DN2bAsIPvcmW7N/c2MUgI= @@ -362,8 +358,6 @@ github.com/aws/aws-sdk-go-v2/service/workspaces v1.68.3 h1:VdduyWoOF4l/GUaNfSIFE github.com/aws/aws-sdk-go-v2/service/workspaces v1.68.3/go.mod h1:CuyzqbKdY8lN//0RPBb7OkQ9YRFYBFpK5SQjlANpWJI= github.com/aws/aws-sdk-go-v2/service/xray v1.36.20 h1:5V3CHiHP3OHaeB6e1tOC2hw5FrHkxepAho+4MEJG4QM= github.com/aws/aws-sdk-go-v2/service/xray v1.36.20/go.mod h1:sgjg2v2UIv+sDFiig3tbkJ4sGSQrXQ2f+YgWg8TLOu4= -github.com/aws/smithy-go v1.27.0 h1:ZoFioDKJxkSIW2otF9T0aPtNlUwhdVCcuZh/rzH9Hus= -github.com/aws/smithy-go v1.27.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/aws/smithy-go v1.27.1 h1:4T340VFndXtADGF52gYa1POyL7s9E4Z1OeZ1hCscIw8= github.com/aws/smithy-go v1.27.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= diff --git a/internal/teststack/teststack.go b/internal/teststack/teststack.go index 229e508b8..179f1bf01 100644 --- a/internal/teststack/teststack.go +++ b/internal/teststack/teststack.go @@ -94,6 +94,7 @@ import ( mqbackend "github.com/blackbirdworks/gopherstack/services/mq" mwaabackend "github.com/blackbirdworks/gopherstack/services/mwaa" neptunebackend "github.com/blackbirdworks/gopherstack/services/neptune" + networkmonitorbackend "github.com/blackbirdworks/gopherstack/services/networkmonitor" opensearchbackend "github.com/blackbirdworks/gopherstack/services/opensearch" organizationsbackend "github.com/blackbirdworks/gopherstack/services/organizations" pinpointbackend "github.com/blackbirdworks/gopherstack/services/pinpoint" @@ -189,6 +190,7 @@ type Stack struct { ResourceGroupsTaggingHandler *rgtabackend.Handler SWFHandler *swfbackend.Handler FirehoseHandler *firehosebackend.Handler + NetworkMonitorHandler *networkmonitorbackend.Handler SchedulerHandler *schedulerbackend.Handler Route53ResolverHandler *route53resolverbackend.Handler TranscribeHandler *transcribebackend.Handler @@ -562,6 +564,7 @@ func registerLatestServices(registry *service.Registry, h handlers) { _ = registry.Register(h.wafv2) _ = registry.Register(h.xray) _ = registry.Register(h.s3tables) + _ = registry.Register(h.networkmonitor) } // handlers bundles all service handlers created for a test stack. @@ -606,6 +609,7 @@ type handlers struct { rgtagging *rgtabackend.Handler swf *swfbackend.Handler firehose *firehosebackend.Handler + networkmonitor *networkmonitorbackend.Handler scheduler *schedulerbackend.Handler route53resolver *route53resolverbackend.Handler transcribe *transcribebackend.Handler @@ -788,6 +792,11 @@ func populateExtendedHandlers(h *handlers) { h.firehose = firehosebackend.NewHandler( firehosebackend.NewInMemoryBackend(config.DefaultAccountID, config.DefaultRegion), ) + h.networkmonitor = networkmonitorbackend.NewHandler( + networkmonitorbackend.NewInMemoryBackend(config.DefaultRegion, config.DefaultAccountID), + ) + h.networkmonitor.AccountID = config.DefaultAccountID + h.networkmonitor.DefaultRegion = config.DefaultRegion h.scheduler = schedulerbackend.NewHandler( schedulerbackend.NewInMemoryBackend(config.DefaultAccountID, config.DefaultRegion), ) @@ -1231,6 +1240,7 @@ func buildStack( ResourceGroupsTaggingHandler: h.rgtagging, SWFHandler: h.swf, FirehoseHandler: h.firehose, + NetworkMonitorHandler: h.networkmonitor, SchedulerHandler: h.scheduler, Route53ResolverHandler: h.route53resolver, TranscribeHandler: h.transcribe, diff --git a/services/networkmonitor/backend.go b/services/networkmonitor/backend.go new file mode 100644 index 000000000..ede8c6906 --- /dev/null +++ b/services/networkmonitor/backend.go @@ -0,0 +1,808 @@ +package networkmonitor + +import ( + "context" + "fmt" + "maps" + "regexp" + "sort" + "strings" + "sync" + "time" + + "github.com/blackbirdworks/gopherstack/pkgs/arn" + "github.com/blackbirdworks/gopherstack/pkgs/awserr" +) + +// regionContextKey is the context key under which the per-request AWS region is stored. +type regionContextKey struct{} + +// getRegion extracts the region from ctx, falling back to defaultRegion when unset. +func getRegion(ctx context.Context, defaultRegion string) string { + if r, ok := ctx.Value(regionContextKey{}).(string); ok && r != "" { + return r + } + + return defaultRegion +} + +var ( + // ErrNotFound is returned when a monitor or probe does not exist. + ErrNotFound = awserr.New("ResourceNotFoundException", awserr.ErrNotFound) + // ErrAlreadyExists is returned when a monitor already exists. + ErrAlreadyExists = awserr.New("ConflictException", awserr.ErrAlreadyExists) + // ErrValidation is returned for invalid input parameters. + ErrValidation = awserr.New("ValidationException", awserr.ErrInvalidParameter) +) + +const ( + monitorStateActive = "ACTIVE" + monitorStatePending = "PENDING" + monitorStateDeleted = "DELETED" + + probeStateActive = "ACTIVE" + probeStatePending = "PENDING" + + defaultAggregationPeriod = int64(60) + minAggregationPeriod = int64(30) + + networkmonitorService = "networkmonitor" + + arnColonParts = 6 + probePathParts = 2 +) + +var monitorNameRE = regexp.MustCompile(`^[a-zA-Z0-9_-]{1,200}$`) + +// StorageBackend is the interface for the Network Monitor in-memory backend. +type StorageBackend interface { + CreateMonitor( + ctx context.Context, + name string, + aggregationPeriod *int64, + probes []createMonitorProbeInput, + tags map[string]string, + ) (*Monitor, error) + DeleteMonitor(ctx context.Context, name string) error + GetMonitor(ctx context.Context, name string) (*Monitor, error) + UpdateMonitor(ctx context.Context, name string, aggregationPeriod int64) (*Monitor, error) + ListMonitors( + ctx context.Context, + state, nextToken string, + maxResults int, + ) ([]monitorSummary, string, error) + CreateProbe( + ctx context.Context, + monitorName string, + probe *probeInput, + tags map[string]string, + ) (*Probe, error) + DeleteProbe(ctx context.Context, monitorName, probeID string) error + GetProbe(ctx context.Context, monitorName, probeID string) (*Probe, error) + UpdateProbe( + ctx context.Context, + monitorName, probeID string, + req *updateProbeRequest, + ) (*Probe, error) + ListTagsForResource(ctx context.Context, resourceARN string) (map[string]string, error) + TagResource(ctx context.Context, resourceARN string, tags map[string]string) error + UntagResource(ctx context.Context, resourceARN string, tagKeys []string) error +} + +// InMemoryBackend is the in-memory implementation of StorageBackend. +// Resources are nested by region for isolation. +type InMemoryBackend struct { + monitors map[string]map[string]*Monitor + arnIndex map[string]map[string]string + accountID string + defaultRegion string + nextProbeSeq int64 + mu sync.RWMutex +} + +var _ StorageBackend = (*InMemoryBackend)(nil) + +// NewInMemoryBackend creates a new InMemoryBackend. +func NewInMemoryBackend(region, accountID string) *InMemoryBackend { + return &InMemoryBackend{ + monitors: make(map[string]map[string]*Monitor), + arnIndex: make(map[string]map[string]string), + accountID: accountID, + defaultRegion: region, + } +} + +// Reset clears all backend state. +func (b *InMemoryBackend) Reset() { + b.mu.Lock() + defer b.mu.Unlock() + + b.monitors = make(map[string]map[string]*Monitor) + b.arnIndex = make(map[string]map[string]string) + b.nextProbeSeq = 0 +} + +func (b *InMemoryBackend) regionMonitors(region string) map[string]*Monitor { + if _, ok := b.monitors[region]; !ok { + b.monitors[region] = make(map[string]*Monitor) + } + + return b.monitors[region] +} + +func (b *InMemoryBackend) regionARNIndex(region string) map[string]string { + if _, ok := b.arnIndex[region]; !ok { + b.arnIndex[region] = make(map[string]string) + } + + return b.arnIndex[region] +} + +func (b *InMemoryBackend) buildMonitorARN(region, monitorName string) string { + return arn.Build(networkmonitorService, region, b.accountID, "monitor/"+monitorName) +} + +func (b *InMemoryBackend) buildProbeARN(region, monitorName, probeID string) string { + return arn.Build( + networkmonitorService, + region, + b.accountID, + fmt.Sprintf("probe/%s/%s", monitorName, probeID), + ) +} + +func (b *InMemoryBackend) nextProbeID() string { + b.nextProbeSeq++ + + return fmt.Sprintf("probe-%08d", b.nextProbeSeq) +} + +// CreateMonitor creates a new monitor. +func (b *InMemoryBackend) CreateMonitor( + ctx context.Context, + name string, + aggregationPeriod *int64, + probeInputs []createMonitorProbeInput, + tags map[string]string, +) (*Monitor, error) { + if err := validateMonitorName(name); err != nil { + return nil, err + } + + region := getRegion(ctx, b.defaultRegion) + period := defaultAggregationPeriod + + if aggregationPeriod != nil { + if *aggregationPeriod != 30 && *aggregationPeriod != 60 { + return nil, fmt.Errorf("%w: aggregationPeriod must be 30 or 60", ErrValidation) + } + + period = *aggregationPeriod + } + + b.mu.Lock() + defer b.mu.Unlock() + + rm := b.regionMonitors(region) + + if _, exists := rm[name]; exists { + return nil, fmt.Errorf("%w: monitor %q already exists", ErrAlreadyExists, name) + } + + now := time.Now().UTC() + monARN := b.buildMonitorARN(region, name) + + var probes []*Probe + + for _, pi := range probeInputs { + if err := validateProbeInput(pi.Destination, pi.Protocol, pi.SourceArn, pi.DestinationPort); err != nil { + return nil, err + } + + probeID := b.nextProbeID() + probeARN := b.buildProbeARN(region, name, probeID) + af := detectAddressFamily(pi.Destination) + + probeTags := make(map[string]string, len(pi.Tags)) + maps.Copy(probeTags, pi.Tags) + + probes = append(probes, &Probe{ + ProbeID: probeID, + ProbeArn: probeARN, + SourceArn: pi.SourceArn, + Destination: pi.Destination, + Protocol: pi.Protocol, + DestinationPort: pi.DestinationPort, + PacketSize: pi.PacketSize, + State: probeStateActive, + AddressFamily: af, + CreatedAt: &now, + ModifiedAt: &now, + Tags: probeTags, + }) + } + + tagsCopy := make(map[string]string, len(tags)) + maps.Copy(tagsCopy, tags) + + m := &Monitor{ + MonitorArn: monARN, + MonitorName: name, + State: monitorStateActive, + AggregationPeriod: period, + Probes: probes, + Tags: tagsCopy, + CreatedAt: &now, + ModifiedAt: &now, + } + + rm[name] = m + b.regionARNIndex(region)[monARN] = name + + return monitorCopy(m), nil +} + +// DeleteMonitor deletes a monitor and its probes. +func (b *InMemoryBackend) DeleteMonitor(ctx context.Context, name string) error { + region := getRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + rm := b.regionMonitors(region) + + m, exists := rm[name] + if !exists { + return fmt.Errorf("%w: monitor %q not found", ErrNotFound, name) + } + + delete(b.regionARNIndex(region), m.MonitorArn) + delete(rm, name) + + return nil +} + +// GetMonitor returns a monitor by name. +func (b *InMemoryBackend) GetMonitor(ctx context.Context, name string) (*Monitor, error) { + region := getRegion(ctx, b.defaultRegion) + + b.mu.RLock() + defer b.mu.RUnlock() + + rm := b.monitors[region] + + m, exists := rm[name] + if !exists { + return nil, fmt.Errorf("%w: monitor %q not found", ErrNotFound, name) + } + + return monitorCopy(m), nil +} + +// UpdateMonitor updates a monitor's aggregation period. +func (b *InMemoryBackend) UpdateMonitor( + ctx context.Context, + name string, + aggregationPeriod int64, +) (*Monitor, error) { + if aggregationPeriod != 30 && aggregationPeriod != 60 { + return nil, fmt.Errorf("%w: aggregationPeriod must be 30 or 60", ErrValidation) + } + + region := getRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + rm := b.monitors[region] + + m, exists := rm[name] + if !exists { + return nil, fmt.Errorf("%w: monitor %q not found", ErrNotFound, name) + } + + now := time.Now().UTC() + m.AggregationPeriod = aggregationPeriod + m.ModifiedAt = &now + + return monitorCopy(m), nil +} + +// ListMonitors returns a filtered, paginated list of monitor summaries. +func (b *InMemoryBackend) ListMonitors( + ctx context.Context, + state, nextToken string, + maxResults int, +) ([]monitorSummary, string, error) { + region := getRegion(ctx, b.defaultRegion) + + b.mu.RLock() + defer b.mu.RUnlock() + + rm := b.monitors[region] + + names := make([]string, 0, len(rm)) + + for n := range rm { + names = append(names, n) + } + + sort.Strings(names) + + startIdx := 0 + + if nextToken != "" { + for i, n := range names { + if n > nextToken { + startIdx = i + + break + } + } + } + + if maxResults <= 0 || maxResults > 100 { + maxResults = 100 + } + + var summaries []monitorSummary + + for i := startIdx; i < len(names) && len(summaries) < maxResults; i++ { + m := rm[names[i]] + if state != "" && !strings.EqualFold(m.State, state) { + continue + } + + period := m.AggregationPeriod + s := monitorSummary{ + MonitorArn: m.MonitorArn, + MonitorName: m.MonitorName, + State: m.State, + AggregationPeriod: &period, + Tags: maps.Clone(m.Tags), + } + + summaries = append(summaries, s) + } + + var outToken string + + if len(summaries) == maxResults && startIdx+maxResults < len(names) { + outToken = summaries[len(summaries)-1].MonitorName + } + + if summaries == nil { + summaries = []monitorSummary{} + } + + return summaries, outToken, nil +} + +// CreateProbe adds a probe to an existing monitor. +func (b *InMemoryBackend) CreateProbe( + ctx context.Context, + monitorName string, + pi *probeInput, + tags map[string]string, +) (*Probe, error) { + if pi == nil { + return nil, fmt.Errorf("%w: probe is required", ErrValidation) + } + + if err := validateProbeInput(pi.Destination, pi.Protocol, pi.SourceArn, pi.DestinationPort); err != nil { + return nil, err + } + + region := getRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + rm := b.monitors[region] + + m, exists := rm[monitorName] + if !exists { + return nil, fmt.Errorf("%w: monitor %q not found", ErrNotFound, monitorName) + } + + now := time.Now().UTC() + probeID := b.nextProbeID() + probeARN := b.buildProbeARN(region, monitorName, probeID) + af := detectAddressFamily(pi.Destination) + + tagsCopy := make(map[string]string, len(pi.Tags)+len(tags)) + maps.Copy(tagsCopy, pi.Tags) + maps.Copy(tagsCopy, tags) + + probe := &Probe{ + ProbeID: probeID, + ProbeArn: probeARN, + SourceArn: pi.SourceArn, + Destination: pi.Destination, + Protocol: pi.Protocol, + DestinationPort: pi.DestinationPort, + PacketSize: pi.PacketSize, + State: probeStateActive, + AddressFamily: af, + CreatedAt: &now, + ModifiedAt: &now, + Tags: tagsCopy, + } + + m.Probes = append(m.Probes, probe) + m.ModifiedAt = &now + + return probeCopy(probe), nil +} + +// DeleteProbe removes a probe from a monitor. +func (b *InMemoryBackend) DeleteProbe(ctx context.Context, monitorName, probeID string) error { + region := getRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + rm := b.monitors[region] + + m, exists := rm[monitorName] + if !exists { + return fmt.Errorf("%w: monitor %q not found", ErrNotFound, monitorName) + } + + idx := findProbeIndex(m.Probes, probeID) + if idx < 0 { + return fmt.Errorf("%w: probe %q not found in monitor %q", ErrNotFound, probeID, monitorName) + } + + now := time.Now().UTC() + m.Probes = append(m.Probes[:idx], m.Probes[idx+1:]...) + m.ModifiedAt = &now + + return nil +} + +// GetProbe returns a probe from a monitor. +func (b *InMemoryBackend) GetProbe( + ctx context.Context, + monitorName, probeID string, +) (*Probe, error) { + region := getRegion(ctx, b.defaultRegion) + + b.mu.RLock() + defer b.mu.RUnlock() + + rm := b.monitors[region] + + m, exists := rm[monitorName] + if !exists { + return nil, fmt.Errorf("%w: monitor %q not found", ErrNotFound, monitorName) + } + + idx := findProbeIndex(m.Probes, probeID) + if idx < 0 { + return nil, fmt.Errorf( + "%w: probe %q not found in monitor %q", + ErrNotFound, + probeID, + monitorName, + ) + } + + return probeCopy(m.Probes[idx]), nil +} + +// UpdateProbe updates a probe's attributes. +func (b *InMemoryBackend) UpdateProbe( + ctx context.Context, + monitorName, probeID string, + req *updateProbeRequest, +) (*Probe, error) { + if req == nil { + return nil, fmt.Errorf("%w: update request is required", ErrValidation) + } + + region := getRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + rm := b.monitors[region] + + m, exists := rm[monitorName] + if !exists { + return nil, fmt.Errorf("%w: monitor %q not found", ErrNotFound, monitorName) + } + + idx := findProbeIndex(m.Probes, probeID) + if idx < 0 { + return nil, fmt.Errorf( + "%w: probe %q not found in monitor %q", + ErrNotFound, + probeID, + monitorName, + ) + } + + probe := m.Probes[idx] + now := time.Now().UTC() + + if req.Destination != "" { + probe.Destination = req.Destination + probe.AddressFamily = detectAddressFamily(req.Destination) + } + + if req.Protocol != "" { + probe.Protocol = req.Protocol + } + + if req.DestinationPort != nil { + probe.DestinationPort = req.DestinationPort + } + + if req.PacketSize != nil { + probe.PacketSize = req.PacketSize + } + + if req.State != "" { + probe.State = req.State + } + + if req.Tags != nil { + maps.Copy(probe.Tags, req.Tags) + } + + probe.ModifiedAt = &now + m.ModifiedAt = &now + + return probeCopy(probe), nil +} + +// ListTagsForResource returns tags for a monitor or probe by ARN. +func (b *InMemoryBackend) ListTagsForResource( + ctx context.Context, + resourceARN string, +) (map[string]string, error) { + region := getRegion(ctx, b.defaultRegion) + + b.mu.RLock() + defer b.mu.RUnlock() + + return b.lookupTagsByARN(region, resourceARN) +} + +// TagResource adds or updates tags on a monitor or probe. +func (b *InMemoryBackend) TagResource( + ctx context.Context, + resourceARN string, + tags map[string]string, +) error { + region := getRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + m, probe, err := b.findResourceByARN(region, resourceARN) + if err != nil { + return err + } + + if probe != nil { + if probe.Tags == nil { + probe.Tags = make(map[string]string) + } + + maps.Copy(probe.Tags, tags) + } else if m != nil { + if m.Tags == nil { + m.Tags = make(map[string]string) + } + + maps.Copy(m.Tags, tags) + } + + return nil +} + +// UntagResource removes tags from a monitor or probe. +func (b *InMemoryBackend) UntagResource( + ctx context.Context, + resourceARN string, + tagKeys []string, +) error { + region := getRegion(ctx, b.defaultRegion) + + b.mu.Lock() + defer b.mu.Unlock() + + m, probe, err := b.findResourceByARN(region, resourceARN) + if err != nil { + return err + } + + if probe != nil { + for _, k := range tagKeys { + delete(probe.Tags, k) + } + } else if m != nil { + for _, k := range tagKeys { + delete(m.Tags, k) + } + } + + return nil +} + +func (b *InMemoryBackend) lookupTagsByARN(region, resourceARN string) (map[string]string, error) { + m, probe, err := b.findResourceByARN(region, resourceARN) + if err != nil { + return nil, err + } + + if probe != nil { + return maps.Clone(probe.Tags), nil + } + + return maps.Clone(m.Tags), nil +} + +// findResourceByARN resolves an ARN to either a monitor or a probe (not both). +// Must be called with b.mu held (read or write). +func (b *InMemoryBackend) findResourceByARN(region, resourceARN string) (*Monitor, *Probe, error) { + // ARN formats: + // monitor: arn:aws:networkmonitor:{region}:{acct}:monitor/{name} + // probe: arn:aws:networkmonitor:{region}:{acct}:probe/{monitorName}/{probeId} + parts := strings.SplitN(resourceARN, ":", arnColonParts) + if len(parts) < arnColonParts { + return nil, nil, fmt.Errorf("%w: invalid resource ARN", ErrNotFound) + } + + resource := parts[arnColonParts-1] + rm := b.monitors[region] + + if monitorName, ok := strings.CutPrefix(resource, "monitor/"); ok { + m, exists := rm[monitorName] + if !exists { + return nil, nil, fmt.Errorf("%w: resource %q not found", ErrNotFound, resourceARN) + } + + return m, nil, nil + } + + if rest, ok := strings.CutPrefix(resource, "probe/"); ok { + segments := strings.SplitN(rest, "/", probePathParts) + if len(segments) != probePathParts { + return nil, nil, fmt.Errorf("%w: invalid probe ARN", ErrNotFound) + } + + monitorName, probeID := segments[0], segments[1] + + m, exists := rm[monitorName] + if !exists { + return nil, nil, fmt.Errorf("%w: resource %q not found", ErrNotFound, resourceARN) + } + + idx := findProbeIndex(m.Probes, probeID) + if idx < 0 { + return nil, nil, fmt.Errorf("%w: resource %q not found", ErrNotFound, resourceARN) + } + + return nil, m.Probes[idx], nil + } + + return nil, nil, fmt.Errorf("%w: resource %q not found", ErrNotFound, resourceARN) +} + +func findProbeIndex(probes []*Probe, probeID string) int { + for i, p := range probes { + if p.ProbeID == probeID { + return i + } + } + + return -1 +} + +func validateMonitorName(name string) error { + if !monitorNameRE.MatchString(name) { + return fmt.Errorf("%w: monitorName must match [a-zA-Z0-9_-]{1,200}", ErrValidation) + } + + return nil +} + +func validateProbeInput(destination, protocol, sourceARN string, destPort *int32) error { + if destination == "" { + return fmt.Errorf("%w: probe destination is required", ErrValidation) + } + + if protocol == "" { + return fmt.Errorf("%w: probe protocol is required", ErrValidation) + } + + proto := strings.ToUpper(protocol) + if proto != "TCP" && proto != "ICMP" { + return fmt.Errorf("%w: probe protocol must be TCP or ICMP", ErrValidation) + } + + if sourceARN == "" { + return fmt.Errorf("%w: probe sourceArn is required", ErrValidation) + } + + if proto == "TCP" && destPort == nil { + return fmt.Errorf("%w: destinationPort is required when protocol is TCP", ErrValidation) + } + + if destPort != nil && (*destPort < 1 || *destPort > 65535) { + return fmt.Errorf("%w: destinationPort must be between 1 and 65535", ErrValidation) + } + + return nil +} + +// detectAddressFamily returns "IPV4" or "IPV6" based on destination format. +func detectAddressFamily(destination string) string { + if strings.Contains(destination, ":") { + return "IPV6" + } + + return "IPV4" +} + +func monitorCopy(m *Monitor) *Monitor { + if m == nil { + return nil + } + + cp := *m + cp.Tags = maps.Clone(m.Tags) + + if m.Probes != nil { + cp.Probes = make([]*Probe, len(m.Probes)) + for i, p := range m.Probes { + cp.Probes[i] = probeCopy(p) + } + } + + if m.CreatedAt != nil { + t := *m.CreatedAt + cp.CreatedAt = &t + } + + if m.ModifiedAt != nil { + t := *m.ModifiedAt + cp.ModifiedAt = &t + } + + return &cp +} + +func probeCopy(p *Probe) *Probe { + if p == nil { + return nil + } + + cp := *p + cp.Tags = maps.Clone(p.Tags) + + if p.CreatedAt != nil { + t := *p.CreatedAt + cp.CreatedAt = &t + } + + if p.ModifiedAt != nil { + t := *p.ModifiedAt + cp.ModifiedAt = &t + } + + if p.DestinationPort != nil { + v := *p.DestinationPort + cp.DestinationPort = &v + } + + if p.PacketSize != nil { + v := *p.PacketSize + cp.PacketSize = &v + } + + return &cp +} diff --git a/services/networkmonitor/backend_test.go b/services/networkmonitor/backend_test.go new file mode 100644 index 000000000..a98992577 --- /dev/null +++ b/services/networkmonitor/backend_test.go @@ -0,0 +1,495 @@ +package networkmonitor_test + +import ( + "context" + "testing" + + "github.com/blackbirdworks/gopherstack/services/networkmonitor" +) + +func newTestBackend(t *testing.T) *networkmonitor.InMemoryBackend { + t.Helper() + + return networkmonitor.NewInMemoryBackend("us-east-1", "000000000000") +} + +func ptr[T any](v T) *T { + p := new(T) + *p = v + + return p +} + +func TestCreateMonitor(t *testing.T) { + t.Parallel() + + tests := []struct { + aggregationPeriod *int64 + name string + monitorName string + wantState string + wantPeriod int64 + wantErr bool + }{ + { + name: "valid monitor no period", + monitorName: "test-monitor", + wantState: "ACTIVE", + wantPeriod: 60, + }, + { + name: "valid monitor period 30", + monitorName: "monitor-30", + aggregationPeriod: ptr(int64(30)), + wantState: "ACTIVE", + wantPeriod: 30, + }, + { + name: "valid monitor period 60", + monitorName: "monitor-60", + aggregationPeriod: ptr(int64(60)), + wantState: "ACTIVE", + wantPeriod: 60, + }, + { + name: "invalid period", + monitorName: "bad-period", + aggregationPeriod: ptr(int64(45)), + wantErr: true, + }, + { + name: "invalid monitor name", + monitorName: "bad name!", + wantErr: true, + }, + { + name: "empty name", + monitorName: "", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + m, err := b.CreateMonitor( + context.Background(), + tc.monitorName, + tc.aggregationPeriod, + nil, + nil, + ) + + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if m.MonitorName != tc.monitorName { + t.Errorf("name: got %q, want %q", m.MonitorName, tc.monitorName) + } + + if m.State != tc.wantState { + t.Errorf("state: got %q, want %q", m.State, tc.wantState) + } + + if m.AggregationPeriod != tc.wantPeriod { + t.Errorf("period: got %d, want %d", m.AggregationPeriod, tc.wantPeriod) + } + + if m.MonitorArn == "" { + t.Error("expected non-empty MonitorArn") + } + }) + } +} + +func TestCreateMonitorDuplicate(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + ctx := context.Background() + + if _, err := b.CreateMonitor(ctx, "dup-monitor", nil, nil, nil); err != nil { + t.Fatalf("first create: %v", err) + } + + if _, err := b.CreateMonitor(ctx, "dup-monitor", nil, nil, nil); err == nil { + t.Fatal("expected conflict error, got nil") + } +} + +func TestGetMonitor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + monitorName string + create bool + wantErr bool + }{ + { + name: "existing monitor", + create: true, + monitorName: "exists", + }, + { + name: "missing monitor", + create: false, + monitorName: "missing", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + ctx := context.Background() + + if tc.create { + if _, err := b.CreateMonitor(ctx, tc.monitorName, nil, nil, nil); err != nil { + t.Fatalf("create: %v", err) + } + } + + m, err := b.GetMonitor(ctx, tc.monitorName) + + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if m.MonitorName != tc.monitorName { + t.Errorf("name: got %q, want %q", m.MonitorName, tc.monitorName) + } + }) + } +} + +func TestDeleteMonitor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + monitorName string + create bool + wantErr bool + }{ + { + name: "delete existing", + create: true, + monitorName: "to-delete", + }, + { + name: "delete missing", + create: false, + monitorName: "ghost", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + ctx := context.Background() + + if tc.create { + if _, err := b.CreateMonitor(ctx, tc.monitorName, nil, nil, nil); err != nil { + t.Fatalf("create: %v", err) + } + } + + err := b.DeleteMonitor(ctx, tc.monitorName) + + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if _, getErr := b.GetMonitor(ctx, tc.monitorName); getErr == nil { + t.Fatal("expected not-found after delete, got nil") + } + }) + } +} + +func TestUpdateMonitor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + monitorName string + aggregationPeriod int64 + wantErr bool + wantPeriod int64 + }{ + { + name: "update to 30", + monitorName: "mon", + aggregationPeriod: 30, + wantPeriod: 30, + }, + { + name: "update to 60", + monitorName: "mon", + aggregationPeriod: 60, + wantPeriod: 60, + }, + { + name: "invalid period", + monitorName: "mon", + aggregationPeriod: 45, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + ctx := context.Background() + + if _, err := b.CreateMonitor(ctx, tc.monitorName, nil, nil, nil); err != nil { + t.Fatalf("create: %v", err) + } + + m, err := b.UpdateMonitor(ctx, tc.monitorName, tc.aggregationPeriod) + + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if m.AggregationPeriod != tc.wantPeriod { + t.Errorf("period: got %d, want %d", m.AggregationPeriod, tc.wantPeriod) + } + }) + } +} + +func TestListMonitors(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + ctx := context.Background() + + for _, name := range []string{"alpha", "beta", "gamma"} { + if _, err := b.CreateMonitor(ctx, name, nil, nil, nil); err != nil { + t.Fatalf("create %s: %v", name, err) + } + } + + summaries, token, err := b.ListMonitors(ctx, "", "", 0) + if err != nil { + t.Fatalf("list: %v", err) + } + + if len(summaries) != 3 { + t.Errorf("count: got %d, want 3", len(summaries)) + } + + if token != "" { + t.Errorf("unexpected next token: %q", token) + } +} + +func TestListMonitorsPagination(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + ctx := context.Background() + + for _, name := range []string{"a-monitor", "b-monitor", "c-monitor", "d-monitor", "e-monitor"} { + if _, err := b.CreateMonitor(ctx, name, nil, nil, nil); err != nil { + t.Fatalf("create %s: %v", name, err) + } + } + + page1, token, err := b.ListMonitors(ctx, "", "", 2) + if err != nil { + t.Fatalf("page1: %v", err) + } + + if len(page1) != 2 { + t.Errorf("page1 count: got %d, want 2", len(page1)) + } + + if token == "" { + t.Fatal("expected next token for page2") + } + + page2, _, err := b.ListMonitors(ctx, "", token, 2) + if err != nil { + t.Fatalf("page2: %v", err) + } + + if len(page2) != 2 { + t.Errorf("page2 count: got %d, want 2", len(page2)) + } +} + +func TestProbeLifecycle(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + ctx := context.Background() + + if _, err := b.CreateMonitor(ctx, "probe-mon", nil, nil, nil); err != nil { + t.Fatalf("create monitor: %v", err) + } + + probe, err := b.CreateProbe(ctx, "probe-mon", &networkmonitor.ProbeInputForTest{ + Destination: "10.0.0.1", + Protocol: "ICMP", + SourceArn: "arn:aws:ec2:us-east-1:000000000000:subnet/subnet-abc", + }, nil) + + if err != nil { + t.Fatalf("create probe: %v", err) + } + + if probe.ProbeID == "" { + t.Error("expected non-empty ProbeID") + } + + got, err := b.GetProbe(ctx, "probe-mon", probe.ProbeID) + if err != nil { + t.Fatalf("get probe: %v", err) + } + + if got.Destination != "10.0.0.1" { + t.Errorf("destination: got %q, want 10.0.0.1", got.Destination) + } + + if delErr := b.DeleteProbe(ctx, "probe-mon", probe.ProbeID); delErr != nil { + t.Fatalf("delete probe: %v", delErr) + } + + if _, getErr := b.GetProbe(ctx, "probe-mon", probe.ProbeID); getErr == nil { + t.Fatal("expected not-found after delete") + } +} + +func TestCreateProbeTCPRequiresPort(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + ctx := context.Background() + + if _, err := b.CreateMonitor(ctx, "tcp-mon", nil, nil, nil); err != nil { + t.Fatalf("create monitor: %v", err) + } + + _, err := b.CreateProbe(ctx, "tcp-mon", &networkmonitor.ProbeInputForTest{ + Destination: "10.0.0.1", + Protocol: "TCP", + SourceArn: "arn:aws:ec2:us-east-1:000000000000:subnet/subnet-abc", + }, nil) + + if err == nil { + t.Fatal("expected validation error: TCP requires port") + } +} + +func TestTagging(t *testing.T) { + t.Parallel() + + b := newTestBackend(t) + ctx := context.Background() + + m, err := b.CreateMonitor(ctx, "tagged-mon", nil, nil, map[string]string{"env": "test"}) + if err != nil { + t.Fatalf("create: %v", err) + } + + tags, err := b.ListTagsForResource(ctx, m.MonitorArn) + if err != nil { + t.Fatalf("list tags: %v", err) + } + + if tags["env"] != "test" { + t.Errorf("tag env: got %q, want test", tags["env"]) + } + + if tagErr := b.TagResource(ctx, m.MonitorArn, map[string]string{"team": "sre"}); tagErr != nil { + t.Fatalf("tag resource: %v", tagErr) + } + + tags, err = b.ListTagsForResource(ctx, m.MonitorArn) + if err != nil { + t.Fatalf("list tags after add: %v", err) + } + + if tags["team"] != "sre" { + t.Errorf("tag team: got %q, want sre", tags["team"]) + } + + if untagErr := b.UntagResource(ctx, m.MonitorArn, []string{"env"}); untagErr != nil { + t.Fatalf("untag: %v", untagErr) + } + + tags, err = b.ListTagsForResource(ctx, m.MonitorArn) + if err != nil { + t.Fatalf("list tags after remove: %v", err) + } + + if _, ok := tags["env"]; ok { + t.Error("expected env tag removed") + } +} + +func TestRegionIsolation(t *testing.T) { + t.Parallel() + + b := networkmonitor.NewInMemoryBackend("us-east-1", "000000000000") + + ctxEast := networkmonitor.WithRegion("us-east-1") + ctxWest := networkmonitor.WithRegion("us-west-2") + + if _, err := b.CreateMonitor(ctxEast, "regional-mon", nil, nil, nil); err != nil { + t.Fatalf("create in us-east-1: %v", err) + } + + if _, err := b.GetMonitor(ctxWest, "regional-mon"); err == nil { + t.Fatal("expected not-found in us-west-2") + } + + if _, err := b.GetMonitor(ctxEast, "regional-mon"); err != nil { + t.Fatalf("expected found in us-east-1: %v", err) + } +} diff --git a/services/networkmonitor/export_test.go b/services/networkmonitor/export_test.go new file mode 100644 index 000000000..23bd68d5f --- /dev/null +++ b/services/networkmonitor/export_test.go @@ -0,0 +1,30 @@ +package networkmonitor + +import "context" + +// ProbeInputForTest is a test alias for probeInput so external test packages can create probes. +type ProbeInputForTest = probeInput + +// MonitorCount returns the number of monitors stored across all regions. +func MonitorCount(b *InMemoryBackend) int { + b.mu.RLock() + defer b.mu.RUnlock() + + total := 0 + + for _, regionMons := range b.monitors { + total += len(regionMons) + } + + return total +} + +// HandlerOpsLen returns the number of operations the handler supports. +func HandlerOpsLen(h *Handler) int { + return len(h.GetSupportedOperations()) +} + +// WithRegion returns a context with the given region set. +func WithRegion(region string) context.Context { + return context.WithValue(context.Background(), regionContextKey{}, region) +} diff --git a/services/networkmonitor/exports.go b/services/networkmonitor/exports.go new file mode 100644 index 000000000..e4b48de1d --- /dev/null +++ b/services/networkmonitor/exports.go @@ -0,0 +1,4 @@ +package networkmonitor + +// ExportedMonitor is a compatibility alias used by the dashboard package. +type ExportedMonitor = Monitor diff --git a/services/networkmonitor/handler.go b/services/networkmonitor/handler.go new file mode 100644 index 000000000..773fc8f69 --- /dev/null +++ b/services/networkmonitor/handler.go @@ -0,0 +1,621 @@ +package networkmonitor + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/labstack/echo/v5" + + "github.com/blackbirdworks/gopherstack/pkgs/awserr" + "github.com/blackbirdworks/gopherstack/pkgs/httputils" + "github.com/blackbirdworks/gopherstack/pkgs/logger" + "github.com/blackbirdworks/gopherstack/pkgs/service" +) + +const ( + opCreateMonitor = "CreateMonitor" + opDeleteMonitor = "DeleteMonitor" + opGetMonitor = "GetMonitor" + opUpdateMonitor = "UpdateMonitor" + opListMonitors = "ListMonitors" + opCreateProbe = "CreateProbe" + opDeleteProbe = "DeleteProbe" + opGetProbe = "GetProbe" + opUpdateProbe = "UpdateProbe" + opListTagsForResource = "ListTagsForResource" + opTagResource = "TagResource" + opUntagResource = "UntagResource" +) + +const ( + nmService = "networkmonitor" + nmMatchPriority = 88 + nmPathMonitors = "/monitors" + nmPathTags = "/tags/" + opUnknown = "Unknown" + splitTwo = 2 + splitThree = 3 + splitFour = 4 +) + +var errUnknownAction = errors.New("unknown action") + +// Handler is the HTTP handler for the CloudWatch Network Monitor REST API. +type Handler struct { + Backend StorageBackend + AccountID string + DefaultRegion string +} + +// NewHandler creates a new Network Monitor handler. +func NewHandler(backend StorageBackend) *Handler { + return &Handler{Backend: backend} +} + +// Reset clears handler state (delegates to backend if supported). +func (h *Handler) Reset() { + if r, ok := h.Backend.(interface{ Reset() }); ok { + r.Reset() + } +} + +// Name returns the service name. +func (h *Handler) Name() string { return "NetworkMonitor" } + +// GetSupportedOperations returns the list of supported Network Monitor operations. +func (h *Handler) GetSupportedOperations() []string { + return []string{ + opCreateMonitor, + opDeleteMonitor, + opGetMonitor, + opUpdateMonitor, + opListMonitors, + opCreateProbe, + opDeleteProbe, + opGetProbe, + opUpdateProbe, + opListTagsForResource, + opTagResource, + opUntagResource, + } +} + +// ChaosServiceName returns the lowercase AWS service name for fault rule matching. +func (h *Handler) ChaosServiceName() string { return nmService } + +// ChaosOperations returns all operations that can be fault-injected. +func (h *Handler) ChaosOperations() []string { return h.GetSupportedOperations() } + +// ChaosRegions returns all regions this handler handles. +func (h *Handler) ChaosRegions() []string { return []string{h.DefaultRegion} } + +// RouteMatcher returns a function that matches Network Monitor requests by service + path. +func (h *Handler) RouteMatcher() service.Matcher { + return func(c *echo.Context) bool { + if httputils.ExtractServiceFromRequest(c.Request()) != nmService { + return false + } + + path := c.Request().URL.Path + + return strings.HasPrefix(path, nmPathMonitors) || strings.HasPrefix(path, nmPathTags) + } +} + +// MatchPriority returns the routing priority. +func (h *Handler) MatchPriority() int { return nmMatchPriority } + +// ExtractOperation determines the operation name from the HTTP request. +func (h *Handler) ExtractOperation(c *echo.Context) string { + method := c.Request().Method + path := c.Request().URL.Path + + if strings.HasPrefix(path, nmPathTags) { + return extractTagOp(method) + } + + return extractMonitorOp(method, path) +} + +func extractTagOp(method string) string { + switch method { + case http.MethodGet: + return opListTagsForResource + case http.MethodPost: + return opTagResource + case http.MethodDelete: + return opUntagResource + } + + return opUnknown +} + +func extractMonitorOp(method, path string) string { + trimmed := strings.TrimSuffix(strings.TrimPrefix(path, "/monitors"), "/") + segments := strings.SplitN(strings.TrimPrefix(trimmed, "/"), "/", splitFour) + + if trimmed == "" || (len(segments) == 1 && segments[0] == "") { + return extractMonitorListOp(method) + } + + if len(segments) == 1 { + return extractMonitorCRUDOp(method) + } + + if len(segments) >= 2 && segments[1] == "probes" { + return extractProbeOp(method, segments) + } + + return opUnknown +} + +func extractMonitorListOp(method string) string { + switch method { + case http.MethodPost: + return opCreateMonitor + case http.MethodGet: + return opListMonitors + } + + return opUnknown +} + +func extractMonitorCRUDOp(method string) string { + switch method { + case http.MethodGet: + return opGetMonitor + case http.MethodDelete: + return opDeleteMonitor + case http.MethodPatch: + return opUpdateMonitor + } + + return opUnknown +} + +func extractProbeOp(method string, segments []string) string { + if len(segments) == 2 || segments[2] == "" { + if method == http.MethodPost { + return opCreateProbe + } + + return opUnknown + } + + switch method { + case http.MethodGet: + return opGetProbe + case http.MethodDelete: + return opDeleteProbe + case http.MethodPatch: + return opUpdateProbe + } + + return opUnknown +} + +// ExtractResource extracts the monitor name from the request path. +func (h *Handler) ExtractResource(c *echo.Context) string { + path := c.Request().URL.Path + trimmed := strings.TrimPrefix(path, "/monitors/") + + if trimmed == path { + return "" + } + + parts := strings.SplitN(trimmed, "/", splitTwo) + + return parts[0] +} + +// Handler returns the Echo handler function. +func (h *Handler) Handler() echo.HandlerFunc { + return func(c *echo.Context) error { + region := httputils.ExtractRegionFromRequest(c.Request(), h.DefaultRegion) + ctx := context.WithValue(c.Request().Context(), regionContextKey{}, region) + log := logger.Load(ctx) + + path := c.Request().URL.Path + query := c.Request().URL.Query() + + body, err := httputils.ReadBody(c.Request()) + if err != nil { + log.ErrorContext(ctx, "networkmonitor: failed to read request body", "error", err) + + return c.String(http.StatusInternalServerError, "internal server error") + } + + op := h.ExtractOperation(c) + + result, dispErr := h.dispatch(ctx, op, path, query, body) + if dispErr != nil { + return h.handleError(c, dispErr) + } + + if result == nil { + return c.NoContent(http.StatusOK) + } + + return c.JSONBlob(http.StatusOK, result) + } +} + +func (h *Handler) dispatch( + ctx context.Context, + op, path string, + query url.Values, + body []byte, +) ([]byte, error) { + switch op { + case opCreateMonitor: + return h.handleCreateMonitor(ctx, body) + case opDeleteMonitor: + return h.handleDeleteMonitor(ctx, path) + case opGetMonitor: + return h.handleGetMonitor(ctx, path) + case opUpdateMonitor: + return h.handleUpdateMonitor(ctx, path, body) + case opListMonitors: + return h.handleListMonitors(ctx, query) + case opCreateProbe: + return h.handleCreateProbe(ctx, path, body) + case opDeleteProbe: + return h.handleDeleteProbe(ctx, path) + case opGetProbe: + return h.handleGetProbe(ctx, path) + case opUpdateProbe: + return h.handleUpdateProbe(ctx, path, body) + case opListTagsForResource: + return h.handleListTagsForResource(ctx, path) + case opTagResource: + return h.handleTagResource(ctx, path, body) + case opUntagResource: + return h.handleUntagResource(ctx, path, query) + default: + return nil, fmt.Errorf("%w: %s", errUnknownAction, op) + } +} + +func (h *Handler) handleError(c *echo.Context, err error) error { + var syntaxErr *json.SyntaxError + var typeErr *json.UnmarshalTypeError + + var code string + var status int + + switch { + case errors.Is(err, awserr.ErrNotFound): + status = http.StatusNotFound + code = "ResourceNotFoundException" + case errors.Is(err, awserr.ErrAlreadyExists): + status = http.StatusConflict + code = "ConflictException" + case errors.Is(err, awserr.ErrInvalidParameter): + status = http.StatusBadRequest + code = "ValidationException" + case errors.Is(err, errUnknownAction), + errors.As(err, &syntaxErr), + errors.As(err, &typeErr): + status = http.StatusBadRequest + code = "ValidationException" + default: + status = http.StatusInternalServerError + code = "InternalServerException" + } + + c.Response().Header().Set("X-Amzn-Errortype", code) + + return c.JSON(status, errorResponse{Message: err.Error()}) +} + +// extractMonitorName extracts the monitor name from /monitors/{name}[/...]. +func extractMonitorName(path string) string { + trimmed := strings.TrimPrefix(path, "/monitors/") + if trimmed == path { + return "" + } + + parts := strings.SplitN(trimmed, "/", splitTwo) + + return parts[0] +} + +// extractProbeID extracts the probe ID from /monitors/{name}/probes/{probeId}. +func extractProbeID(path string) string { + // path format: /monitors/{name}/probes/{probeId} + trimmed := strings.TrimPrefix(path, "/monitors/") + if trimmed == path { + return "" + } + + parts := strings.SplitN(trimmed, "/", splitThree) + if len(parts) < 3 || parts[1] != "probes" { + return "" + } + + return parts[2] +} + +// extractTagResourceARN extracts the resource ARN from /tags/{resourceArn}. +func extractTagResourceARN(path string) string { + trimmed := strings.TrimPrefix(path, nmPathTags) + if trimmed == path { + return "" + } + + return trimmed +} + +func (h *Handler) handleCreateMonitor(ctx context.Context, body []byte) ([]byte, error) { + var req createMonitorRequest + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + if req.MonitorName == "" { + return nil, fmt.Errorf("%w: monitorName is required", ErrValidation) + } + + m, err := h.Backend.CreateMonitor( + ctx, + req.MonitorName, + req.AggregationPeriod, + req.Probes, + req.Tags, + ) + if err != nil { + return nil, err + } + + period := m.AggregationPeriod + resp := createMonitorResponse{ + MonitorArn: m.MonitorArn, + MonitorName: m.MonitorName, + State: m.State, + AggregationPeriod: &period, + Tags: m.Tags, + } + + return json.Marshal(resp) +} + +func (h *Handler) handleDeleteMonitor(ctx context.Context, path string) ([]byte, error) { + name := extractMonitorName(path) + if name == "" { + return nil, fmt.Errorf("%w: monitorName is required", ErrValidation) + } + + if err := h.Backend.DeleteMonitor(ctx, name); err != nil { + return nil, err + } + + return nil, nil +} + +func (h *Handler) handleGetMonitor(ctx context.Context, path string) ([]byte, error) { + name := extractMonitorName(path) + if name == "" { + return nil, fmt.Errorf("%w: monitorName is required", ErrValidation) + } + + m, err := h.Backend.GetMonitor(ctx, name) + if err != nil { + return nil, err + } + + resp := getMonitorResponse{ + MonitorArn: m.MonitorArn, + MonitorName: m.MonitorName, + State: m.State, + AggregationPeriod: m.AggregationPeriod, + Probes: m.Probes, + Tags: m.Tags, + CreatedAt: m.CreatedAt, + ModifiedAt: m.ModifiedAt, + } + + return json.Marshal(resp) +} + +func (h *Handler) handleUpdateMonitor( + ctx context.Context, + path string, + body []byte, +) ([]byte, error) { + name := extractMonitorName(path) + if name == "" { + return nil, fmt.Errorf("%w: monitorName is required", ErrValidation) + } + + var req updateMonitorRequest + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + m, err := h.Backend.UpdateMonitor(ctx, name, req.AggregationPeriod) + if err != nil { + return nil, err + } + + period := m.AggregationPeriod + resp := updateMonitorResponse{ + MonitorArn: m.MonitorArn, + MonitorName: m.MonitorName, + State: m.State, + AggregationPeriod: &period, + Tags: m.Tags, + } + + return json.Marshal(resp) +} + +func (h *Handler) handleListMonitors(ctx context.Context, query url.Values) ([]byte, error) { + state := query.Get("state") + nextToken := query.Get("nextToken") + maxResults := 0 + + if mr := query.Get("maxResults"); mr != "" { + if _, err := fmt.Sscanf(mr, "%d", &maxResults); err != nil { + maxResults = 0 + } + } + + summaries, outToken, err := h.Backend.ListMonitors(ctx, state, nextToken, maxResults) + if err != nil { + return nil, err + } + + resp := listMonitorsResponse{ + Monitors: summaries, + NextToken: outToken, + } + + return json.Marshal(resp) +} + +func (h *Handler) handleCreateProbe(ctx context.Context, path string, body []byte) ([]byte, error) { + monitorName := extractMonitorName(path) + if monitorName == "" { + return nil, fmt.Errorf("%w: monitorName is required", ErrValidation) + } + + var req createProbeRequest + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + if req.Probe == nil { + return nil, fmt.Errorf("%w: probe is required", ErrValidation) + } + + probe, err := h.Backend.CreateProbe(ctx, monitorName, req.Probe, req.Tags) + if err != nil { + return nil, err + } + + return json.Marshal(probe) +} + +func (h *Handler) handleDeleteProbe(ctx context.Context, path string) ([]byte, error) { + monitorName := extractMonitorName(path) + probeID := extractProbeID(path) + + if monitorName == "" { + return nil, fmt.Errorf("%w: monitorName is required", ErrValidation) + } + + if probeID == "" { + return nil, fmt.Errorf("%w: probeId is required", ErrValidation) + } + + if err := h.Backend.DeleteProbe(ctx, monitorName, probeID); err != nil { + return nil, err + } + + return nil, nil +} + +func (h *Handler) handleGetProbe(ctx context.Context, path string) ([]byte, error) { + monitorName := extractMonitorName(path) + probeID := extractProbeID(path) + + if monitorName == "" { + return nil, fmt.Errorf("%w: monitorName is required", ErrValidation) + } + + if probeID == "" { + return nil, fmt.Errorf("%w: probeId is required", ErrValidation) + } + + probe, err := h.Backend.GetProbe(ctx, monitorName, probeID) + if err != nil { + return nil, err + } + + return json.Marshal(probe) +} + +func (h *Handler) handleUpdateProbe(ctx context.Context, path string, body []byte) ([]byte, error) { + monitorName := extractMonitorName(path) + probeID := extractProbeID(path) + + if monitorName == "" { + return nil, fmt.Errorf("%w: monitorName is required", ErrValidation) + } + + if probeID == "" { + return nil, fmt.Errorf("%w: probeId is required", ErrValidation) + } + + var req updateProbeRequest + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + probe, err := h.Backend.UpdateProbe(ctx, monitorName, probeID, &req) + if err != nil { + return nil, err + } + + return json.Marshal(probe) +} + +func (h *Handler) handleListTagsForResource(ctx context.Context, path string) ([]byte, error) { + resourceARN := extractTagResourceARN(path) + if resourceARN == "" { + return nil, fmt.Errorf("%w: resourceArn is required", ErrValidation) + } + + tags, err := h.Backend.ListTagsForResource(ctx, resourceARN) + if err != nil { + return nil, err + } + + if tags == nil { + tags = map[string]string{} + } + + return json.Marshal(listTagsForResourceResponse{Tags: tags}) +} + +func (h *Handler) handleTagResource(ctx context.Context, path string, body []byte) ([]byte, error) { + resourceARN := extractTagResourceARN(path) + if resourceARN == "" { + return nil, fmt.Errorf("%w: resourceArn is required", ErrValidation) + } + + var req tagResourceRequest + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + if err := h.Backend.TagResource(ctx, resourceARN, req.Tags); err != nil { + return nil, err + } + + return nil, nil +} + +func (h *Handler) handleUntagResource( + ctx context.Context, + path string, + query url.Values, +) ([]byte, error) { + resourceARN := extractTagResourceARN(path) + if resourceARN == "" { + return nil, fmt.Errorf("%w: resourceArn is required", ErrValidation) + } + + tagKeys := query["tagKeys"] + + if err := h.Backend.UntagResource(ctx, resourceARN, tagKeys); err != nil { + return nil, err + } + + return nil, nil +} diff --git a/services/networkmonitor/handler_test.go b/services/networkmonitor/handler_test.go new file mode 100644 index 000000000..8e14d704f --- /dev/null +++ b/services/networkmonitor/handler_test.go @@ -0,0 +1,435 @@ +package networkmonitor_test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v5" + + "github.com/blackbirdworks/gopherstack/services/networkmonitor" +) + +func newTestHandler(t *testing.T) *networkmonitor.Handler { + t.Helper() + + b := networkmonitor.NewInMemoryBackend("us-east-1", "000000000000") + h := networkmonitor.NewHandler(b) + h.AccountID = "000000000000" + h.DefaultRegion = "us-east-1" + + return h +} + +func doNMRequest( + t *testing.T, + h *networkmonitor.Handler, + method, path string, + body any, +) *httptest.ResponseRecorder { + t.Helper() + + var bodyBytes []byte + + if body != nil { + var err error + bodyBytes, err = json.Marshal(body) + if err != nil { + t.Fatalf("marshal body: %v", err) + } + } + + e := echo.New() + req := httptest.NewRequest(method, path, bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set( + "Authorization", + "AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/networkmonitor/aws4_request", + ) + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetRequest(req) + + if err := h.Handler()(c); err != nil { + t.Logf("handler returned error: %v", err) + } + + return rec +} + +func TestHandlerCreateMonitor(t *testing.T) { + t.Parallel() + + tests := []struct { + body map[string]any + name string + wantStatus int + }{ + { + name: "valid create", + body: map[string]any{"monitorName": "test-mon"}, + wantStatus: http.StatusOK, + }, + { + name: "missing name", + body: map[string]any{}, + wantStatus: http.StatusBadRequest, + }, + { + name: "invalid period", + body: map[string]any{"monitorName": "x", "aggregationPeriod": 45}, + wantStatus: http.StatusBadRequest, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + h := newTestHandler(t) + rr := doNMRequest(t, h, http.MethodPost, "/monitors", tc.body) + + if rr.Code != tc.wantStatus { + t.Errorf( + "status: got %d, want %d — body: %s", + rr.Code, + tc.wantStatus, + rr.Body.String(), + ) + } + }) + } +} + +func TestHandlerGetMonitor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + monName string + wantStatus int + create bool + }{ + { + name: "existing monitor", + create: true, + monName: "my-mon", + wantStatus: http.StatusOK, + }, + { + name: "missing monitor", + create: false, + monName: "ghost", + wantStatus: http.StatusNotFound, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + h := newTestHandler(t) + + if tc.create { + rr := doNMRequest( + t, + h, + http.MethodPost, + "/monitors", + map[string]any{"monitorName": tc.monName}, + ) + if rr.Code != http.StatusOK { + t.Fatalf("create: status %d", rr.Code) + } + } + + rr := doNMRequest(t, h, http.MethodGet, "/monitors/"+tc.monName, nil) + + if rr.Code != tc.wantStatus { + t.Errorf( + "status: got %d, want %d — body: %s", + rr.Code, + tc.wantStatus, + rr.Body.String(), + ) + } + }) + } +} + +func TestHandlerDeleteMonitor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + monName string + wantStatus int + create bool + }{ + { + name: "delete existing", + create: true, + monName: "del-mon", + wantStatus: http.StatusOK, + }, + { + name: "delete missing", + create: false, + monName: "ghost", + wantStatus: http.StatusNotFound, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + h := newTestHandler(t) + + if tc.create { + rr := doNMRequest( + t, + h, + http.MethodPost, + "/monitors", + map[string]any{"monitorName": tc.monName}, + ) + if rr.Code != http.StatusOK { + t.Fatalf("create: status %d", rr.Code) + } + } + + rr := doNMRequest(t, h, http.MethodDelete, "/monitors/"+tc.monName, nil) + + if rr.Code != tc.wantStatus { + t.Errorf( + "status: got %d, want %d — body: %s", + rr.Code, + tc.wantStatus, + rr.Body.String(), + ) + } + }) + } +} + +func TestHandlerListMonitors(t *testing.T) { + t.Parallel() + + h := newTestHandler(t) + + for _, name := range []string{"first-mon", "second-mon"} { + rr := doNMRequest(t, h, http.MethodPost, "/monitors", map[string]any{"monitorName": name}) + if rr.Code != http.StatusOK { + t.Fatalf("create %s: status %d", name, rr.Code) + } + } + + rr := doNMRequest(t, h, http.MethodGet, "/monitors", nil) + + if rr.Code != http.StatusOK { + t.Fatalf("list: status %d — body: %s", rr.Code, rr.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + monitors, ok := resp["monitors"].([]any) + if !ok { + t.Fatalf("monitors field missing or wrong type in: %s", rr.Body.String()) + } + + if len(monitors) != 2 { + t.Errorf("count: got %d, want 2", len(monitors)) + } +} + +func TestHandlerUpdateMonitor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + period int64 + wantStatus int + }{ + { + name: "update to 30", + period: 30, + wantStatus: http.StatusOK, + }, + { + name: "invalid period", + period: 45, + wantStatus: http.StatusBadRequest, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + h := newTestHandler(t) + rr := doNMRequest( + t, + h, + http.MethodPost, + "/monitors", + map[string]any{"monitorName": "upd-mon"}, + ) + + if rr.Code != http.StatusOK { + t.Fatalf("create: status %d", rr.Code) + } + + rr = doNMRequest( + t, + h, + http.MethodPatch, + "/monitors/upd-mon", + map[string]any{"aggregationPeriod": tc.period}, + ) + + if rr.Code != tc.wantStatus { + t.Errorf( + "status: got %d, want %d — body: %s", + rr.Code, + tc.wantStatus, + rr.Body.String(), + ) + } + }) + } +} + +func TestHandlerProbeLifecycle(t *testing.T) { + t.Parallel() + + h := newTestHandler(t) + + rr := doNMRequest( + t, + h, + http.MethodPost, + "/monitors", + map[string]any{"monitorName": "probe-mon"}, + ) + if rr.Code != http.StatusOK { + t.Fatalf("create monitor: status %d", rr.Code) + } + + rr = doNMRequest(t, h, http.MethodPost, "/monitors/probe-mon/probes", map[string]any{ + "probe": map[string]any{ + "destination": "10.0.0.2", + "protocol": "ICMP", + "sourceArn": "arn:aws:ec2:us-east-1:000000000000:subnet/subnet-abc", + }, + }) + + if rr.Code != http.StatusOK { + t.Fatalf("create probe: status %d — body: %s", rr.Code, rr.Body.String()) + } + + var probeResp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &probeResp); err != nil { + t.Fatalf("unmarshal probe: %v", err) + } + + probeID, _ := probeResp["probeId"].(string) + if probeID == "" { + t.Fatal("expected non-empty probeId") + } + + rr = doNMRequest(t, h, http.MethodGet, "/monitors/probe-mon/probes/"+probeID, nil) + if rr.Code != http.StatusOK { + t.Fatalf("get probe: status %d", rr.Code) + } + + rr = doNMRequest(t, h, http.MethodPatch, "/monitors/probe-mon/probes/"+probeID, map[string]any{ + "destination": "10.0.0.3", + }) + if rr.Code != http.StatusOK { + t.Fatalf("update probe: status %d — %s", rr.Code, rr.Body.String()) + } + + rr = doNMRequest(t, h, http.MethodDelete, "/monitors/probe-mon/probes/"+probeID, nil) + if rr.Code != http.StatusOK { + t.Fatalf("delete probe: status %d", rr.Code) + } + + rr = doNMRequest(t, h, http.MethodGet, "/monitors/probe-mon/probes/"+probeID, nil) + if rr.Code != http.StatusNotFound { + t.Errorf("expected 404 after delete, got %d", rr.Code) + } +} + +func TestHandlerTags(t *testing.T) { + t.Parallel() + + h := newTestHandler(t) + + rr := doNMRequest(t, h, http.MethodPost, "/monitors", map[string]any{ + "monitorName": "tagged", + "tags": map[string]any{"env": "prod"}, + }) + + if rr.Code != http.StatusOK { + t.Fatalf("create: status %d", rr.Code) + } + + var mon map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &mon); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + monARN, _ := mon["monitorArn"].(string) + if monARN == "" { + t.Fatal("expected monitorArn in response") + } + + rr = doNMRequest(t, h, http.MethodGet, "/tags/"+monARN, nil) + if rr.Code != http.StatusOK { + t.Fatalf("list tags: status %d", rr.Code) + } + + var tagResp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &tagResp); err != nil { + t.Fatalf("unmarshal tags: %v", err) + } + + tags, _ := tagResp["tags"].(map[string]any) + if tags["env"] != "prod" { + t.Errorf("tag env: got %v, want prod", tags["env"]) + } + + rr = doNMRequest(t, h, http.MethodPost, "/tags/"+monARN, map[string]any{ + "tags": map[string]any{"team": "sre"}, + }) + if rr.Code != http.StatusOK { + t.Fatalf("tag: status %d", rr.Code) + } + + rr = doNMRequest(t, h, http.MethodDelete, "/tags/"+monARN+"?tagKeys=env", nil) + if rr.Code != http.StatusOK { + t.Fatalf("untag: status %d", rr.Code) + } + + rr = doNMRequest(t, h, http.MethodGet, "/tags/"+monARN, nil) + if err := json.Unmarshal(rr.Body.Bytes(), &tagResp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + tags, _ = tagResp["tags"].(map[string]any) + + if _, ok := tags["env"]; ok { + t.Error("env tag should be removed") + } + + if tags["team"] != "sre" { + t.Errorf("team tag: got %v, want sre", tags["team"]) + } +} diff --git a/services/networkmonitor/models.go b/services/networkmonitor/models.go new file mode 100644 index 000000000..d5d764054 --- /dev/null +++ b/services/networkmonitor/models.go @@ -0,0 +1,143 @@ +package networkmonitor + +import "time" + +// Monitor represents an AWS CloudWatch Network Monitor monitor. +type Monitor struct { + CreatedAt *time.Time `json:"createdAt,omitempty"` + ModifiedAt *time.Time `json:"modifiedAt,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + MonitorArn string `json:"monitorArn"` + MonitorName string `json:"monitorName"` + State string `json:"state"` + Probes []*Probe `json:"probes,omitempty"` + AggregationPeriod int64 `json:"aggregationPeriod"` +} + +// Probe represents a network monitor probe. +type Probe struct { + CreatedAt *time.Time `json:"createdAt,omitempty"` + ModifiedAt *time.Time `json:"modifiedAt,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + PacketSize *int32 `json:"packetSize,omitempty"` + DestinationPort *int32 `json:"destinationPort,omitempty"` + Destination string `json:"destination"` + SourceArn string `json:"sourceArn"` + Protocol string `json:"protocol"` + State string `json:"state"` + AddressFamily string `json:"addressFamily,omitempty"` + VpcID string `json:"vpcId,omitempty"` + ProbeID string `json:"probeId,omitempty"` + ProbeArn string `json:"probeArn,omitempty"` +} + +// monitorSummary is the short form returned by ListMonitors. +type monitorSummary struct { + Tags map[string]string `json:"tags,omitempty"` + AggregationPeriod *int64 `json:"aggregationPeriod,omitempty"` + MonitorArn string `json:"monitorArn"` + MonitorName string `json:"monitorName"` + State string `json:"state"` +} + +// createMonitorProbeInput is the probe input nested in CreateMonitor. +type createMonitorProbeInput struct { + Tags map[string]string `json:"probeTags,omitempty"` + DestinationPort *int32 `json:"destinationPort,omitempty"` + PacketSize *int32 `json:"packetSize,omitempty"` + Destination string `json:"destination"` + Protocol string `json:"protocol"` + SourceArn string `json:"sourceArn"` +} + +// createMonitorRequest is the request body for POST /monitors. +type createMonitorRequest struct { + Tags map[string]string `json:"tags,omitempty"` + AggregationPeriod *int64 `json:"aggregationPeriod,omitempty"` + MonitorName string `json:"monitorName"` + ClientToken string `json:"clientToken,omitempty"` + Probes []createMonitorProbeInput `json:"probes,omitempty"` +} + +// createMonitorResponse is the response body for POST /monitors. +type createMonitorResponse struct { + Tags map[string]string `json:"tags,omitempty"` + AggregationPeriod *int64 `json:"aggregationPeriod,omitempty"` + MonitorArn string `json:"monitorArn"` + MonitorName string `json:"monitorName"` + State string `json:"state"` +} + +// updateMonitorRequest is the request body for PATCH /monitors/{monitorName}. +type updateMonitorRequest struct { + AggregationPeriod int64 `json:"aggregationPeriod"` +} + +// updateMonitorResponse is the response body for PATCH /monitors/{monitorName}. +type updateMonitorResponse struct { + Tags map[string]string `json:"tags,omitempty"` + AggregationPeriod *int64 `json:"aggregationPeriod,omitempty"` + MonitorArn string `json:"monitorArn"` + MonitorName string `json:"monitorName"` + State string `json:"state"` +} + +// listMonitorsResponse is the response body for GET /monitors. +type listMonitorsResponse struct { + NextToken string `json:"nextToken,omitempty"` + Monitors []monitorSummary `json:"monitors"` +} + +// probeInput is the probe input for CreateProbe. +type probeInput struct { + Tags map[string]string `json:"tags,omitempty"` + DestinationPort *int32 `json:"destinationPort,omitempty"` + PacketSize *int32 `json:"packetSize,omitempty"` + Destination string `json:"destination"` + Protocol string `json:"protocol"` + SourceArn string `json:"sourceArn"` +} + +// createProbeRequest is the request body for POST /monitors/{monitorName}/probes. +type createProbeRequest struct { + Tags map[string]string `json:"tags,omitempty"` + Probe *probeInput `json:"probe"` + ClientToken string `json:"clientToken,omitempty"` +} + +// updateProbeRequest is the request body for PATCH /monitors/{monitorName}/probes/{probeId}. +type updateProbeRequest struct { + Tags map[string]string `json:"tags,omitempty"` + DestinationPort *int32 `json:"destinationPort,omitempty"` + PacketSize *int32 `json:"packetSize,omitempty"` + Destination string `json:"destination,omitempty"` + Protocol string `json:"protocol,omitempty"` + State string `json:"state,omitempty"` +} + +// getMonitorResponse is the response body for GET /monitors/{monitorName}. +type getMonitorResponse struct { + CreatedAt *time.Time `json:"createdAt"` + ModifiedAt *time.Time `json:"modifiedAt"` + Tags map[string]string `json:"tags,omitempty"` + MonitorArn string `json:"monitorArn"` + MonitorName string `json:"monitorName"` + State string `json:"state"` + Probes []*Probe `json:"probes,omitempty"` + AggregationPeriod int64 `json:"aggregationPeriod"` +} + +// listTagsForResourceResponse is the response body for GET /tags/{resourceArn}. +type listTagsForResourceResponse struct { + Tags map[string]string `json:"tags"` +} + +// tagResourceRequest is the request body for POST /tags/{resourceArn}. +type tagResourceRequest struct { + Tags map[string]string `json:"tags"` +} + +// errorResponse is the standard error response body. +type errorResponse struct { + Message string `json:"message"` +} diff --git a/services/networkmonitor/persistence.go b/services/networkmonitor/persistence.go new file mode 100644 index 000000000..fc387d89a --- /dev/null +++ b/services/networkmonitor/persistence.go @@ -0,0 +1,75 @@ +package networkmonitor + +import ( + "encoding/json" + "log/slog" +) + +type backendSnapshot struct { + Monitors map[string]map[string]*Monitor `json:"monitors"` + NextProbeSeq int64 `json:"next_probe_seq"` +} + +// Snapshot serialises the backend state to JSON. +// It implements persistence.Persistable. +func (b *InMemoryBackend) Snapshot() []byte { + b.mu.RLock() + defer b.mu.RUnlock() + + monsCopy := make(map[string]map[string]*Monitor, len(b.monitors)) + + for region, regionMons := range b.monitors { + monsCopy[region] = make(map[string]*Monitor, len(regionMons)) + + for k, v := range regionMons { + monsCopy[region][k] = monitorCopy(v) + } + } + + snap := backendSnapshot{ + Monitors: monsCopy, + NextProbeSeq: b.nextProbeSeq, + } + + data, err := json.Marshal(snap) + if err != nil { + slog.Default().Warn("networkmonitor: failed to marshal snapshot", "error", err) + + return nil + } + + return data +} + +// Restore loads backend state from a JSON snapshot. +// It implements persistence.Persistable. +func (b *InMemoryBackend) Restore(data []byte) error { + var snap backendSnapshot + + if err := json.Unmarshal(data, &snap); err != nil { + return err + } + + b.mu.Lock() + defer b.mu.Unlock() + + if snap.Monitors != nil { + b.monitors = snap.Monitors + } else { + b.monitors = make(map[string]map[string]*Monitor) + } + + b.arnIndex = make(map[string]map[string]string) + + for region, regionMons := range b.monitors { + b.arnIndex[region] = make(map[string]string, len(regionMons)) + + for name, m := range regionMons { + b.arnIndex[region][m.MonitorArn] = name + } + } + + b.nextProbeSeq = snap.NextProbeSeq + + return nil +} diff --git a/services/networkmonitor/provider.go b/services/networkmonitor/provider.go new file mode 100644 index 000000000..e157664c5 --- /dev/null +++ b/services/networkmonitor/provider.go @@ -0,0 +1,42 @@ +package networkmonitor + +import ( + "errors" + + "github.com/blackbirdworks/gopherstack/pkgs/config" + "github.com/blackbirdworks/gopherstack/pkgs/service" +) + +// ErrNilAppContext is returned by Provider.Init when a nil AppContext is supplied. +var ErrNilAppContext = errors.New("networkmonitor: AppContext must not be nil") + +// Provider implements service.Provider for the Network Monitor service. +type Provider struct{} + +// Name returns the provider name. +func (p *Provider) Name() string { return "NetworkMonitor" } + +// Init initializes the Network Monitor service backend and handler. +// +//nolint:ireturn,nolintlint // architecturally required to return interface +func (p *Provider) Init(ctx *service.AppContext) (service.Registerable, error) { + if ctx == nil { + return nil, ErrNilAppContext + } + + accountID := config.DefaultAccountID + region := config.DefaultRegion + + if cp, ok := ctx.Config.(config.Provider); ok { + cfg := cp.GetGlobalConfig() + accountID = cfg.GetAccountID() + region = cfg.GetRegion() + } + + backend := NewInMemoryBackend(region, accountID) + handler := NewHandler(backend) + handler.AccountID = accountID + handler.DefaultRegion = region + + return handler, nil +} diff --git a/services/networkmonitor/sdk_completeness_test.go b/services/networkmonitor/sdk_completeness_test.go new file mode 100644 index 000000000..5db7201b2 --- /dev/null +++ b/services/networkmonitor/sdk_completeness_test.go @@ -0,0 +1,27 @@ +package networkmonitor_test + +import ( + "testing" + + networkmonitorsdk "github.com/aws/aws-sdk-go-v2/service/networkmonitor" + + "github.com/blackbirdworks/gopherstack/pkgs/sdkcheck" + "github.com/blackbirdworks/gopherstack/services/networkmonitor" +) + +// TestSDKCompleteness verifies that every operation exposed by the AWS SDK v2 +// networkmonitor client is either listed in GetSupportedOperations() or explicitly +// acknowledged in the notImplemented slice. The test fails when the upstream +// SDK adds a new operation that gopherstack has not yet handled. +func TestSDKCompleteness(t *testing.T) { + t.Parallel() + + backend := networkmonitor.NewInMemoryBackend("us-east-1", "000000000000") + h := networkmonitor.NewHandler(backend) + sdkcheck.CheckCompleteness( + t, + &networkmonitorsdk.Client{}, + h.GetSupportedOperations(), + []string{}, + ) +}