diff --git a/env/SE3Transformer/se3_transformer/runtime/inference.py b/env/SE3Transformer/se3_transformer/runtime/inference.py
index 21e9125b..4ecc47fb 100644
--- a/env/SE3Transformer/se3_transformer/runtime/inference.py
+++ b/env/SE3Transformer/se3_transformer/runtime/inference.py
@@ -49,7 +49,7 @@ def evaluate(model: nn.Module,
         for callback in callbacks:
             callback.on_batch_start()
 
-        with torch.cuda.amp.autocast(enabled=args.amp):
+        with torch.amp.autocast('cuda', enabled=args.amp):
             pred = model(*input)
 
         for callback in callbacks:
diff --git a/env/SE3Transformer/se3_transformer/runtime/training.py b/env/SE3Transformer/se3_transformer/runtime/training.py
index 53122779..5931be16 100644
--- a/env/SE3Transformer/se3_transformer/runtime/training.py
+++ b/env/SE3Transformer/se3_transformer/runtime/training.py
@@ -90,7 +90,7 @@ def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimi
         for callback in callbacks:
             callback.on_batch_start()
 
-        with torch.cuda.amp.autocast(enabled=args.amp):
+        with torch.amp.autocast('cuda', enabled=args.amp):
             pred = model(*inputs)
             loss = loss_fn(pred, target) / args.accumulate_grad_batches
 
@@ -127,7 +127,7 @@ def train(model: nn.Module,
         model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
 
     model.train()
-    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
+    grad_scaler = torch.amp.GradScaler('cuda', enabled=args.amp)
     if args.optimizer == 'adam':
         optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                               weight_decay=args.weight_decay)
diff --git a/rfdiffusion/Track_module.py b/rfdiffusion/Track_module.py
index 27511e5d..727c0a8a 100644
--- a/rfdiffusion/Track_module.py
+++ b/rfdiffusion/Track_module.py
@@ -233,7 +233,7 @@ def reset_parameter(self):
         nn.init.zeros_(self.embed_e1.bias)
         nn.init.zeros_(self.embed_e2.bias)
 
-    @torch.cuda.amp.autocast(enabled=False)
+    @torch.amp.autocast('cuda', enabled=False)
     def forward(self, msa, pair, R_in, T_in,
                 xyz, state, idx, motif_mask, cyclic_reses=None, top_k=64, eps=1e-5):
         B, N, L = msa.shape[:3]