diff --git a/pkg/deviceclaimingserver/gateways/gateways.go b/pkg/deviceclaimingserver/gateways/gateways.go index 919995f679..9e31ac45ef 100644 --- a/pkg/deviceclaimingserver/gateways/gateways.go +++ b/pkg/deviceclaimingserver/gateways/gateways.go @@ -92,10 +92,10 @@ func ParseGatewayEUIRanges(conf map[string][]string) (map[string][]dcstypes.EUI6 type Claimer interface { // Claim claims a gateway. Claim( - ctx context.Context, eui types.EUI64, ownerToken string, clusterAddress string, + ctx context.Context, ids *ttnpb.GatewayIdentifiers, ownerToken string, clusterAddress string, ) (*dcstypes.GatewayMetadata, error) // Unclaim unclaims a gateway. - Unclaim(ctx context.Context, eui types.EUI64) error + Unclaim(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error // IsManagedGateway returns true if the gateway is a managed gateway. IsManagedGateway(ctx context.Context, eui types.EUI64) (bool, error) } diff --git a/pkg/deviceclaimingserver/gateways/ttgc/lbscups.go b/pkg/deviceclaimingserver/gateways/ttgc/lbscups.go index 4d79613f80..5902b2f561 100644 --- a/pkg/deviceclaimingserver/gateways/ttgc/lbscups.go +++ b/pkg/deviceclaimingserver/gateways/ttgc/lbscups.go @@ -38,13 +38,10 @@ var ( ) func (u *Upstream) claimLBSCUPSGateway( - ctx context.Context, eui types.EUI64, ownerToken, clusterAddress string, + ctx context.Context, ids *ttnpb.GatewayIdentifiers, ownerToken, clusterAddress string, ) (*dcstypes.GatewayMetadata, error) { logger := log.FromContext(ctx) - - ids := &ttnpb.GatewayIdentifiers{ - Eui: eui.Bytes(), - } + eui := types.MustEUI64(ids.Eui).OrZero() // Create CUPS and LNS API keys for the gateway. The CUPS key will be used as gateway token when claiming on TTGC and // the LNS key will be returned in the metadata. The caller is responsible for updating the LNS key in the gateway. @@ -165,7 +162,7 @@ func (u *Upstream) createAPIKeys( cupsKey, err = gatewayAccess.CreateAPIKey(ctx, &ttnpb.CreateGatewayAPIKeyRequest{ GatewayIds: ids, - Name: fmt.Sprintf("LBS CUPS Key (TTGC claim), generated %s", time.Now().UTC().Format(time.RFC3339)), + Name: fmt.Sprintf("LBS CUPS Key (TTGC), %s", time.Now().UTC().Format(time.RFC3339)), Rights: []ttnpb.Right{ ttnpb.Right_RIGHT_GATEWAY_INFO, ttnpb.Right_RIGHT_GATEWAY_SETTINGS_BASIC, @@ -179,7 +176,7 @@ func (u *Upstream) createAPIKeys( lnsKey, err = gatewayAccess.CreateAPIKey(ctx, &ttnpb.CreateGatewayAPIKeyRequest{ GatewayIds: ids, - Name: fmt.Sprintf("LBS LNS Key (TTGC claim), generated %s", time.Now().UTC().Format(time.RFC3339)), + Name: fmt.Sprintf("LBS LNS Key (TTGC), %s", time.Now().UTC().Format(time.RFC3339)), Rights: []ttnpb.Right{ ttnpb.Right_RIGHT_GATEWAY_LINK, }, diff --git a/pkg/deviceclaimingserver/gateways/ttgc/ttgc.go b/pkg/deviceclaimingserver/gateways/ttgc/ttgc.go index db42b96bfe..99a631c606 100644 --- a/pkg/deviceclaimingserver/gateways/ttgc/ttgc.go +++ b/pkg/deviceclaimingserver/gateways/ttgc/ttgc.go @@ -70,13 +70,15 @@ func New(ctx context.Context, c component, config ttgc.Config) (*Upstream, error type claimOption struct { protocol northboundv1.GatewayProtocolIdentifier authMethod northboundv1.AuthenticationMethod - handler func(context.Context, types.EUI64, string, string) (*dcstypes.GatewayMetadata, error) + handler func(context.Context, *ttnpb.GatewayIdentifiers, string, string) (*dcstypes.GatewayMetadata, error) } // Claim implements gateways.GatewayClaimer. func (u *Upstream) Claim( - ctx context.Context, eui types.EUI64, ownerToken, clusterAddress string, + ctx context.Context, ids *ttnpb.GatewayIdentifiers, ownerToken, clusterAddress string, ) (*dcstypes.GatewayMetadata, error) { + eui := types.MustEUI64(ids.Eui).OrZero() + // Get the gateway description to verify what protocol it supports. gtwClient := northboundv1.NewGatewayServiceClient(u.client) desc, err := gtwClient.Describe(ctx, &northboundv1.GatewayServiceDescribeRequest{ @@ -103,7 +105,7 @@ func (u *Upstream) Claim( // Select the first supported claiming option and use its handler. for _, option := range claimPreferences { if u.supportsOption(desc, option) { - return option.handler(ctx, eui, ownerToken, clusterAddress) + return option.handler(ctx, ids, ownerToken, clusterAddress) } } @@ -127,9 +129,11 @@ func (*Upstream) supportsOption( } // Unclaim implements gateways.GatewayClaimer. -func (u *Upstream) Unclaim(ctx context.Context, eui types.EUI64) error { +func (u *Upstream) Unclaim(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error { + eui := types.MustEUI64(ids.Eui).OrZero() + // Delete the CUPS and LNS API keys for the gateway. - if err := u.deleteAPIKeys(ctx, &ttnpb.GatewayIdentifiers{Eui: eui.Bytes()}); err != nil { + if err := u.deleteAPIKeys(ctx, ids); err != nil { // Don't fail unclaiming if deleting the API keys fails. log.FromContext(ctx).WithError(err).Warn("Failed to delete API keys for gateway") } diff --git a/pkg/deviceclaimingserver/gateways/ttgc/ttiv1.go b/pkg/deviceclaimingserver/gateways/ttgc/ttiv1.go index 12d32b965c..b4c4cb147d 100644 --- a/pkg/deviceclaimingserver/gateways/ttgc/ttiv1.go +++ b/pkg/deviceclaimingserver/gateways/ttgc/ttiv1.go @@ -36,9 +36,10 @@ import ( // 3. Upsert a Geolocation profile // 4. Update the gateway with the profiles func (u *Upstream) claimTTIV1Gateway( - ctx context.Context, eui types.EUI64, ownerToken, clusterAddress string, + ctx context.Context, ids *ttnpb.GatewayIdentifiers, ownerToken, clusterAddress string, ) (*dcstypes.GatewayMetadata, error) { logger := log.FromContext(ctx) + eui := types.MustEUI64(ids.Eui).OrZero() // Claim the gateway. gtwClient := northboundv1.NewGatewayServiceClient(u.client) diff --git a/pkg/deviceclaimingserver/grpc_gateways.go b/pkg/deviceclaimingserver/grpc_gateways.go index 35ad56fb29..389551afad 100644 --- a/pkg/deviceclaimingserver/grpc_gateways.go +++ b/pkg/deviceclaimingserver/grpc_gateways.go @@ -17,6 +17,7 @@ package deviceclaimingserver import ( "context" "fmt" + "strings" "go.thethings.network/lorawan-stack/v3/pkg/deviceclaimingserver/gateways" "go.thethings.network/lorawan-stack/v3/pkg/deviceclaimingserver/observability" @@ -95,9 +96,13 @@ func (gcls *gatewayClaimingServer) Claim( logger = logger.WithFields(log.Fields( "gateway_eui", gatewayEUI, )) + gatewayID := req.TargetGatewayId + if gatewayID == "" { + gatewayID = strings.ToLower(gatewayEUI.String()) + } ids = &ttnpb.GatewayIdentifiers{ Eui: gatewayEUI.Bytes(), - GatewayId: req.TargetGatewayId, + GatewayId: gatewayID, } // Check if the gateway already exists. @@ -113,21 +118,24 @@ func (gcls *gatewayClaimingServer) Claim( Ids: ids, } - _, err = gcls.registry.Create(ctx, &ttnpb.CreateGatewayRequest{ + created, err := gcls.registry.Create(ctx, &ttnpb.CreateGatewayRequest{ Gateway: gateway, Collaborator: req.GetCollaborator(), }) if err != nil { return nil, errCreateGateway.WithCause(err) } - defer func() { + if createdIDs := created.GetIds(); createdIDs != nil { + ids = createdIDs + } + defer func(ids *ttnpb.GatewayIdentifiers) { if retErr != nil { logger.Warn("Failed to claim gateway, deleting created gateway") if _, delErr := gcls.registry.Delete(ctx, ids); delErr != nil { logger.WithError(delErr).Warn("Failed to delete created gateway after failed claim") } } - }() + }(ids) // Support clients that only set a single frequency plan. if len(req.TargetFrequencyPlanIds) == 0 && req.TargetFrequencyPlanId != "" { // nolint:staticcheck @@ -141,7 +149,7 @@ func (gcls *gatewayClaimingServer) Claim( } // Claim the gateway on the upstream. - res, err := claimer.Claim(ctx, gatewayEUI, string(authCode), req.TargetGatewayServerAddress) + res, err := claimer.Claim(ctx, ids, string(authCode), req.TargetGatewayServerAddress) if err != nil { observability.RegisterFailClaim(ctx, ids.GetEntityIdentifiers(), err) return nil, errClaim.WithCause(err) @@ -151,7 +159,7 @@ func (gcls *gatewayClaimingServer) Claim( defer func(ids *ttnpb.GatewayIdentifiers) { if retErr != nil { observability.RegisterAbortClaim(ctx, ids.GetEntityIdentifiers(), retErr) - if err := claimer.Unclaim(ctx, gatewayEUI); err != nil { + if err := claimer.Unclaim(ctx, ids); err != nil { logger.WithError(err).Warn("Failed to unclaim gateway") } return @@ -256,7 +264,7 @@ func (gcls gatewayClaimingServer) Unclaim(ctx context.Context, req *ttnpb.Gatewa return nil, errGatewayClaimingNotSupported.WithAttributes("eui", gatewayEUI) } - if err := claimer.Unclaim(ctx, gatewayEUI); err != nil { + if err := claimer.Unclaim(ctx, gtw.Ids); err != nil { observability.RegisterFailUnclaim(ctx, gtw.GetEntityIdentifiers(), err) return nil, err } diff --git a/pkg/deviceclaimingserver/grpc_gateways_test.go b/pkg/deviceclaimingserver/grpc_gateways_test.go index 0bc10e5362..780c68d4b9 100644 --- a/pkg/deviceclaimingserver/grpc_gateways_test.go +++ b/pkg/deviceclaimingserver/grpc_gateways_test.go @@ -181,10 +181,10 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest Name string Req *ttnpb.ClaimGatewayRequest CallOpt grpc.CallOption - ClaimFunc func(context.Context, types.EUI64, string, string) (*dcstypes.GatewayMetadata, error) + ClaimFunc func(context.Context, *ttnpb.GatewayIdentifiers, string, string) (*dcstypes.GatewayMetadata, error) CreateFunc func(context.Context, *ttnpb.CreateGatewayRequest) (*ttnpb.Gateway, error) UpdateFunc func(context.Context, *ttnpb.UpdateGatewayRequest) (*ttnpb.Gateway, error) - UnclaimFunc func(context.Context, types.EUI64) error + UnclaimFunc func(context.Context, *ttnpb.GatewayIdentifiers) error DeleteFunc func(context.Context, *ttnpb.GatewayIdentifiers) (*emptypb.Empty, error) ErrorAssertion func(error) bool }{ @@ -302,7 +302,7 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest CreateFunc: func(_ context.Context, in *ttnpb.CreateGatewayRequest) (*ttnpb.Gateway, error) { return in.Gateway, nil }, - ClaimFunc: func(_ context.Context, _ types.EUI64, _, _ string) (*dcstypes.GatewayMetadata, error) { + ClaimFunc: func(_ context.Context, _ *ttnpb.GatewayIdentifiers, _, _ string) (*dcstypes.GatewayMetadata, error) { return nil, errClaim.New() }, UpdateFunc: func(_ context.Context, in *ttnpb.UpdateGatewayRequest) (*ttnpb.Gateway, error) { @@ -327,7 +327,7 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest TargetGatewayServerAddress: "things.example.com", }, CallOpt: authorizedCallOpt, - ClaimFunc: func(context.Context, types.EUI64, string, string) (*dcstypes.GatewayMetadata, error) { + ClaimFunc: func(context.Context, *ttnpb.GatewayIdentifiers, string, string) (*dcstypes.GatewayMetadata, error) { return &dcstypes.GatewayMetadata{}, nil }, CreateFunc: func(context.Context, *ttnpb.CreateGatewayRequest) (*ttnpb.Gateway, error) { @@ -339,8 +339,8 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest DeleteFunc: func(_ context.Context, _ *ttnpb.GatewayIdentifiers) (*emptypb.Empty, error) { return &emptypb.Empty{}, nil }, - UnclaimFunc: func(_ context.Context, eui types.EUI64) error { - if eui.Equal(supportedEUI) { + UnclaimFunc: func(_ context.Context, ids *ttnpb.GatewayIdentifiers) error { + if types.MustEUI64(ids.Eui).OrZero().Equal(supportedEUI) { return nil } return errUnclaim.New() @@ -361,7 +361,7 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest TargetGatewayServerAddress: "things.example.com", }, CallOpt: authorizedCallOpt, - ClaimFunc: func(context.Context, types.EUI64, string, string) (*dcstypes.GatewayMetadata, error) { + ClaimFunc: func(context.Context, *ttnpb.GatewayIdentifiers, string, string) (*dcstypes.GatewayMetadata, error) { return &dcstypes.GatewayMetadata{}, nil }, CreateFunc: func(context.Context, *ttnpb.CreateGatewayRequest) (*ttnpb.Gateway, error) { @@ -373,7 +373,7 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest DeleteFunc: func(_ context.Context, _ *ttnpb.GatewayIdentifiers) (*emptypb.Empty, error) { return &emptypb.Empty{}, nil }, - UnclaimFunc: func(context.Context, types.EUI64) error { + UnclaimFunc: func(context.Context, *ttnpb.GatewayIdentifiers) error { return errUnclaim.New() }, ErrorAssertion: errors.IsAborted, @@ -391,7 +391,66 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest TargetGatewayId: "test-gateway", TargetGatewayServerAddress: "things.example.com", }, - ClaimFunc: func(context.Context, types.EUI64, string, string) (*dcstypes.GatewayMetadata, error) { + ClaimFunc: func(context.Context, *ttnpb.GatewayIdentifiers, string, string) (*dcstypes.GatewayMetadata, error) { + return &dcstypes.GatewayMetadata{}, nil + }, + CreateFunc: func(_ context.Context, in *ttnpb.CreateGatewayRequest) (*ttnpb.Gateway, error) { + return in.Gateway, nil + }, + UpdateFunc: func(_ context.Context, in *ttnpb.UpdateGatewayRequest) (*ttnpb.Gateway, error) { + return in.Gateway, nil + }, + CallOpt: authorizedCallOpt, + }, + { + Name: "Claim/EmptyTargetGatewayIDDefaultsToEUIAndDeletesOnFailedClaim", + Req: &ttnpb.ClaimGatewayRequest{ + Collaborator: userID.GetOrganizationOrUserIdentifiers(), + SourceGateway: &ttnpb.ClaimGatewayRequest_AuthenticatedIdentifiers_{ + AuthenticatedIdentifiers: &ttnpb.ClaimGatewayRequest_AuthenticatedIdentifiers{ + GatewayEui: supportedEUI.Bytes(), + AuthenticationCode: claimAuthCode, + }, + }, + TargetGatewayServerAddress: "things.example.com", + }, + CallOpt: authorizedCallOpt, + ClaimFunc: func( + _ context.Context, ids *ttnpb.GatewayIdentifiers, _, _ string, + ) (*dcstypes.GatewayMetadata, error) { + a.So(ids.GatewayId, should.Equal, "58a0cbfffe800001") + a.So(ids.Eui, should.Resemble, supportedEUI.Bytes()) + return nil, errClaim.New() + }, + CreateFunc: func(_ context.Context, in *ttnpb.CreateGatewayRequest) (*ttnpb.Gateway, error) { + a.So(in.Gateway.GetIds().GetGatewayId(), should.Equal, "58a0cbfffe800001") + return in.Gateway, nil + }, + DeleteFunc: func(_ context.Context, ids *ttnpb.GatewayIdentifiers) (*emptypb.Empty, error) { + a.So(ids.GatewayId, should.Equal, "58a0cbfffe800001") + a.So(ids.Eui, should.Resemble, supportedEUI.Bytes()) + return &emptypb.Empty{}, nil + }, + ErrorAssertion: errors.IsAborted, + }, + { + Name: "Claim/ForwardsGatewayIdentifiers", + Req: &ttnpb.ClaimGatewayRequest{ + Collaborator: userID.GetOrganizationOrUserIdentifiers(), + SourceGateway: &ttnpb.ClaimGatewayRequest_AuthenticatedIdentifiers_{ + AuthenticatedIdentifiers: &ttnpb.ClaimGatewayRequest_AuthenticatedIdentifiers{ + GatewayEui: supportedEUI.Bytes(), + AuthenticationCode: claimAuthCode, + }, + }, + TargetGatewayId: "forwarded-gateway", + TargetGatewayServerAddress: "things.example.com", + }, + ClaimFunc: func( + _ context.Context, ids *ttnpb.GatewayIdentifiers, _, _ string, + ) (*dcstypes.GatewayMetadata, error) { + a.So(ids.GatewayId, should.Equal, "forwarded-gateway") + a.So(ids.Eui, should.Resemble, supportedEUI.Bytes()) return &dcstypes.GatewayMetadata{}, nil }, CreateFunc: func(_ context.Context, in *ttnpb.CreateGatewayRequest) (*ttnpb.Gateway, error) { @@ -437,7 +496,7 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest Req *ttnpb.GatewayIdentifiers CallOpt grpc.CallOption GetFunc func(context.Context, *ttnpb.GetGatewayRequest) (*ttnpb.Gateway, error) - UnclaimFunc func(context.Context, types.EUI64) error + UnclaimFunc func(context.Context, *ttnpb.GatewayIdentifiers) error ErrorAssertion func(error) bool }{ { @@ -508,7 +567,7 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest GatewayServerAddress: "test.example.com", }, nil }, - UnclaimFunc: func(context.Context, types.EUI64) error { + UnclaimFunc: func(context.Context, *ttnpb.GatewayIdentifiers) error { return errUnclaim.New() }, CallOpt: authorizedCallOpt, @@ -528,7 +587,28 @@ func TestGatewayClaimingServer(t *testing.T) { //nolint:paralleltest GatewayServerAddress: "test.example.com", }, nil }, - UnclaimFunc: func(context.Context, types.EUI64) error { + UnclaimFunc: func(context.Context, *ttnpb.GatewayIdentifiers) error { + return nil + }, + CallOpt: authorizedCallOpt, + }, + { + Name: "Unclaim/ForwardsGatewayIdentifiers", + Req: &ttnpb.GatewayIdentifiers{ + GatewayId: "forwarded-gateway", + }, + GetFunc: func(context.Context, *ttnpb.GetGatewayRequest) (*ttnpb.Gateway, error) { + return &ttnpb.Gateway{ + Ids: &ttnpb.GatewayIdentifiers{ + GatewayId: "forwarded-gateway", + Eui: supportedEUI.Bytes(), + }, + GatewayServerAddress: "test.example.com", + }, nil + }, + UnclaimFunc: func(_ context.Context, ids *ttnpb.GatewayIdentifiers) error { + a.So(ids.GatewayId, should.Equal, "forwarded-gateway") + a.So(ids.Eui, should.Resemble, supportedEUI.Bytes()) return nil }, CallOpt: authorizedCallOpt, diff --git a/pkg/deviceclaimingserver/util_test.go b/pkg/deviceclaimingserver/util_test.go index f5cda85220..5f0d119a30 100644 --- a/pkg/deviceclaimingserver/util_test.go +++ b/pkg/deviceclaimingserver/util_test.go @@ -25,7 +25,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -// MockEndDeviceClaimer is a mock Claimer. +// MockEndDeviceClaimer is a mock end device Claimer. type MockEndDeviceClaimer struct { JoinEUIs []types.EUI64 @@ -81,24 +81,26 @@ func (m MockEndDeviceClaimer) BatchUnclaim( type MockGatewayClaimer struct { EUIs []types.EUI64 - ClaimFunc func(context.Context, types.EUI64, string, string) (*dcstypes.GatewayMetadata, error) - UnclaimFunc func(context.Context, types.EUI64) error + ClaimFunc func( + context.Context, *ttnpb.GatewayIdentifiers, string, string, + ) (*dcstypes.GatewayMetadata, error) + UnclaimFunc func(context.Context, *ttnpb.GatewayIdentifiers) error IsManagedGatewayFunc func(context.Context, types.EUI64) (bool, error) } // Claim implements gateways.Claimer. func (claimer MockGatewayClaimer) Claim( ctx context.Context, - eui types.EUI64, + ids *ttnpb.GatewayIdentifiers, ownerToken string, clusterAddress string, ) (*dcstypes.GatewayMetadata, error) { - return claimer.ClaimFunc(ctx, eui, ownerToken, clusterAddress) + return claimer.ClaimFunc(ctx, ids, ownerToken, clusterAddress) } // Unclaim implements gateways.Claimer. -func (claimer MockGatewayClaimer) Unclaim(ctx context.Context, eui types.EUI64) error { - return claimer.UnclaimFunc(ctx, eui) +func (claimer MockGatewayClaimer) Unclaim(ctx context.Context, ids *ttnpb.GatewayIdentifiers) error { + return claimer.UnclaimFunc(ctx, ids) } // IsManagedGateway implements gateways.Claimer.