perf: reduce GPU memory usage during inference#447

Open
haoyu-haoyu wants to merge 1 commit into RosettaCommons:main from haoyu-haoyu:perf/reduce-gpu-memory-usage

Conversation

@haoyu-haoyu

Summary

  • Offload trajectory to CPU per step: Move px0, x_t, seq, plddt to CPU immediately after each denoising step instead of accumulating all steps on GPU. Peak GPU memory is reduced from O(T × L × 14 × 3) to O(L × 14 × 3) for trajectory storage.

  • Remove unused ComputeAllAtomCoords() instantiations: get_next_ca() and get_next_pose() in inference/utils.py created a ComputeAllAtomCoords() module every timestep but never used it. Removed both instantiations and the now-unused import.
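The before/after memory behavior of the trajectory buffers described in the first bullet can be sketched as follows (the shapes come from the PR; the tensor contents are dummies, and the sketch falls back to CPU when no GPU is present):

```python
import torch

T, L = 50, 200  # diffusion steps, residues (the PR's example design)
dev = "cuda" if torch.cuda.is_available() else "cpu"

# Before: every saved frame stayed on the sampling device for all T steps,
# so trajectory storage grew as O(T * L * 14 * 3) on the GPU.
gpu_traj = [torch.randn(L, 14, 3, device=dev) for _ in range(T)]

# After: each frame is copied to CPU the moment it is saved, so the GPU
# only ever holds the live step's tensors: O(L * 14 * 3).
cpu_traj = [torch.randn(L, 14, 3, device=dev).cpu() for _ in range(T)]
```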

Why this is safe

x_t and seq_t are produced fresh by sampler.sample_step() each iteration — the active computation tensors remain on GPU. Only the previous step's saved copies (for trajectory visualization) are moved to CPU. The final torch.stack() and torch.flip() operate on CPU tensors, which is fine since they're only used for PDB writing.
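A minimal sketch of the loop structure this argument relies on (the `sample_step` callable here is a stand-in, not the actual sampler API):

```python
import torch

def denoise_with_offload(sample_step, x_init, T):
    """Run T denoising steps, offloading each saved frame to CPU.

    `sample_step` stands in for the sampler: it returns a fresh x_t each
    iteration, so moving the previous saved copy to CPU never touches the
    tensors the next forward pass needs.
    """
    traj = []
    x_t = x_init
    for t in range(T, 0, -1):
        x_t = sample_step(x_t, t)        # fresh tensor, stays on device
        traj.append(x_t.detach().cpu())  # saved copy moves to CPU now
    # torch.stack / torch.flip run on CPU tensors; the result is only
    # used for writing the trajectory PDB, so this is fine.
    return torch.flip(torch.stack(traj), [0])
```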

Impact

For a 200-residue, 50-step design, this reduces GPU memory for trajectory storage from ~50× to 1× the single-step footprint. This is especially important for longer proteins (>300 residues) that can OOM on 16GB GPUs.

Test plan

  • Compare output PDB coordinates with and without this change (should be identical)
  • Profile GPU memory with torch.cuda.max_memory_allocated()
  • Run tests/test_diffusion.py
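The memory check in the second bullet could be scripted with a small helper like this (the wrapper itself is hypothetical; `torch.cuda.reset_peak_memory_stats()` and `torch.cuda.max_memory_allocated()` are real PyTorch APIs):

```python
import torch

def peak_gpu_mb(fn, *args, **kwargs):
    """Run fn and report its peak GPU allocation in MiB (0.0 without CUDA)."""
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.reset_peak_memory_stats()
    result = fn(*args, **kwargs)
    peak = torch.cuda.max_memory_allocated() / 2**20 if has_cuda else 0.0
    return result, peak
```

Running the same design before and after the change and comparing the two peaks should show the trajectory-storage share dropping from roughly T× to 1× the single-step footprint.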

- Move trajectory data (px0, x_t, seq, plddt) to CPU immediately after
  each denoising step instead of accumulating on GPU. This frees GPU
  memory for the next forward pass, reducing peak memory usage
  proportional to the number of diffusion steps.

- Remove two unused ComputeAllAtomCoords() instantiations in
  get_next_ca() and get_next_pose() that were created every timestep
  but never referenced, wasting memory and compute.
