From 5735aff0e6d512ffa26a4d787613820c2c9bf16f Mon Sep 17 00:00:00 2001 From: SexyERIC0723 Date: Fri, 20 Mar 2026 17:35:19 +0000 Subject: [PATCH] perf: reduce GPU memory usage during inference - Move trajectory data (px0, x_t, seq, plddt) to CPU immediately after each denoising step instead of accumulating on GPU. This frees GPU memory for the next forward pass, reducing peak memory usage proportional to the number of diffusion steps. - Remove two unused ComputeAllAtomCoords() instantiations in get_next_ca() and get_next_pose() that were created every timestep but never referenced, wasting memory and compute. --- rfdiffusion/inference/utils.py | 4 +--- scripts/run_inference.py | 11 +++++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/rfdiffusion/inference/utils.py b/rfdiffusion/inference/utils.py index 3fb14112..4e318118 100644 --- a/rfdiffusion/inference/utils.py +++ b/rfdiffusion/inference/utils.py @@ -6,7 +6,7 @@ from rfdiffusion.diffusion import get_beta_schedule from scipy.spatial.transform import Rotation as scipy_R from rfdiffusion.util import rigid_from_3_points -from rfdiffusion.util_module import ComputeAllAtomCoords + from rfdiffusion import util import random import logging @@ -152,7 +152,6 @@ def get_next_ca( noise_scale: scale factor for the noise being added """ - get_allatom = ComputeAllAtomCoords().to(device=xt.device) L = len(xt) # bring to origin after global alignment (when don't have a motif) or replace input motif and bring to origin, and then scale @@ -435,7 +434,6 @@ def get_next_pose( include_motif_sidechains (bool): Provide sidechains of the fixed motif to the model """ - get_allatom = ComputeAllAtomCoords().to(device=xt.device) L, n_atom = xt.shape[:2] assert (xt.shape[1] == 14) or (xt.shape[1] == 27) assert (px0.shape[1] == 14) or (px0.shape[1] == 27) diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 3fb6466e..623516e1 100755 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -94,10 +94,13 @@ def main(conf: HydraConfig) -> None: px0, x_t, seq_t, plddt = sampler.sample_step( t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step ) - px0_xyz_stack.append(px0) - denoised_xyz_stack.append(x_t) - seq_stack.append(seq_t) - plddt_stack.append(plddt[0]) # remove singleton leading dimension + # Move trajectory data to CPU immediately to free GPU memory + # for the next denoising step. x_t/seq_t are cloned inside + # sample_step, so moving previous outputs to CPU is safe. + px0_xyz_stack.append(px0.cpu()) + denoised_xyz_stack.append(x_t.cpu()) + seq_stack.append(seq_t.cpu()) + plddt_stack.append(plddt[0].cpu()) # remove singleton leading dimension # Flip order for better visualization in pymol denoised_xyz_stack = torch.stack(denoised_xyz_stack)