Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions rfdiffusion/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rfdiffusion.diffusion import get_beta_schedule
from scipy.spatial.transform import Rotation as scipy_R
from rfdiffusion.util import rigid_from_3_points
from rfdiffusion.util_module import ComputeAllAtomCoords

from rfdiffusion import util
import random
import logging
Expand Down Expand Up @@ -152,7 +152,6 @@ def get_next_ca(
noise_scale: scale factor for the noise being added

"""
get_allatom = ComputeAllAtomCoords().to(device=xt.device)
L = len(xt)

# bring to origin after global alignment (when don't have a motif) or replace input motif and bring to origin, and then scale
Expand Down Expand Up @@ -435,7 +434,6 @@ def get_next_pose(
include_motif_sidechains (bool): Provide sidechains of the fixed motif to the model
"""

get_allatom = ComputeAllAtomCoords().to(device=xt.device)
L, n_atom = xt.shape[:2]
assert (xt.shape[1] == 14) or (xt.shape[1] == 27)
assert (px0.shape[1] == 14) or (px0.shape[1] == 27)
Expand Down
11 changes: 7 additions & 4 deletions scripts/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,13 @@ def main(conf: HydraConfig) -> None:
px0, x_t, seq_t, plddt = sampler.sample_step(
t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step
)
px0_xyz_stack.append(px0)
denoised_xyz_stack.append(x_t)
seq_stack.append(seq_t)
plddt_stack.append(plddt[0]) # remove singleton leading dimension
# Move trajectory data to CPU immediately to free GPU memory
# for the next denoising step. x_t/seq_t are cloned inside
# sample_step, so moving previous outputs to CPU is safe.
px0_xyz_stack.append(px0.cpu())
denoised_xyz_stack.append(x_t.cpu())
seq_stack.append(seq_t.cpu())
plddt_stack.append(plddt[0].cpu()) # remove singleton leading dimension

# Flip order for better visualization in pymol
denoised_xyz_stack = torch.stack(denoised_xyz_stack)
Expand Down