Skip to content

fix: replace deprecated torch.cuda.amp with torch.amp#445

Open
haoyu-haoyu wants to merge 1 commit intoRosettaCommons:mainfrom
haoyu-haoyu:fix/replace-deprecated-torch-cuda-amp
Open

fix: replace deprecated torch.cuda.amp with torch.amp#445
haoyu-haoyu wants to merge 1 commit intoRosettaCommons:mainfrom
haoyu-haoyu:fix/replace-deprecated-torch-cuda-amp

Conversation

@haoyu-haoyu
Copy link

Summary

  • Replace all torch.cuda.amp.autocasttorch.amp.autocast('cuda', ...) (3 instances)
  • Replace torch.cuda.amp.GradScalertorch.amp.GradScaler('cuda', ...) (1 instance)

The torch.cuda.amp.* APIs were deprecated in PyTorch 1.13 (migration guide) and emit FutureWarning in PyTorch 2.4+. The new torch.amp.* equivalents require an explicit device_type argument.

Files changed

File Change
rfdiffusion/Track_module.py:236 @torch.cuda.amp.autocast(enabled=False)@torch.amp.autocast('cuda', enabled=False)
env/SE3Transformer/.../inference.py:52 context manager
env/SE3Transformer/.../training.py:93 context manager
env/SE3Transformer/.../training.py:130 GradScaler constructor

Test plan

  • Verify no FutureWarning with python -W error::FutureWarning
  • Run existing test suite (tests/test_diffusion.py)
  • Confirm inference output is unchanged

`torch.cuda.amp.autocast`, `torch.cuda.amp.GradScaler` were deprecated
in PyTorch 1.13 and will be removed in a future release. Replace with
the device-explicit `torch.amp.autocast('cuda', ...)` and
`torch.amp.GradScaler('cuda', ...)` equivalents.

Files changed:
- rfdiffusion/Track_module.py (decorator on Str2Str.forward)
- env/SE3Transformer/se3_transformer/runtime/inference.py
- env/SE3Transformer/se3_transformer/runtime/training.py (2 instances)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant