diff --git a/rfdiffusion/inference/utils.py b/rfdiffusion/inference/utils.py index 3fb14112..4e318118 100644 --- a/rfdiffusion/inference/utils.py +++ b/rfdiffusion/inference/utils.py @@ -6,7 +6,7 @@ from rfdiffusion.diffusion import get_beta_schedule from scipy.spatial.transform import Rotation as scipy_R from rfdiffusion.util import rigid_from_3_points -from rfdiffusion.util_module import ComputeAllAtomCoords + from rfdiffusion import util import random import logging @@ -152,7 +152,6 @@ def get_next_ca( noise_scale: scale factor for the noise being added """ - get_allatom = ComputeAllAtomCoords().to(device=xt.device) L = len(xt) # bring to origin after global alignment (when don't have a motif) or replace input motif and bring to origin, and then scale @@ -435,7 +434,6 @@ def get_next_pose( include_motif_sidechains (bool): Provide sidechains of the fixed motif to the model """ - get_allatom = ComputeAllAtomCoords().to(device=xt.device) L, n_atom = xt.shape[:2] assert (xt.shape[1] == 14) or (xt.shape[1] == 27) assert (px0.shape[1] == 14) or (px0.shape[1] == 27) diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 3fb6466e..623516e1 100755 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -94,10 +94,13 @@ def main(conf: HydraConfig) -> None: px0, x_t, seq_t, plddt = sampler.sample_step( t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step ) - px0_xyz_stack.append(px0) - denoised_xyz_stack.append(x_t) - seq_stack.append(seq_t) - plddt_stack.append(plddt[0]) # remove singleton leading dimension + # Move trajectory data to CPU immediately to free GPU memory + # for the next denoising step. x_t/seq_t are cloned inside + # sample_step, so moving previous outputs to CPU is safe. + px0_xyz_stack.append(px0.cpu()) + denoised_xyz_stack.append(x_t.cpu()) + seq_stack.append(seq_t.cpu()) + plddt_stack.append(plddt[0].cpu()) # remove singleton leading dimension # Flip order for better visualization in pymol denoised_xyz_stack = torch.stack(denoised_xyz_stack)