Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rfdiffusion/Embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.nn.functional as F
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from rfdiffusion.constants import CBETA_A, CBETA_B, CBETA_C
from rfdiffusion.util import get_tips
from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal, find_breaks
from rfdiffusion.Attention_module import Attention, FeedForwardLayer, AttentionWithBias
Expand Down Expand Up @@ -317,7 +318,7 @@ def forward(self, seq, msa, pair, xyz, state):
b = Ca - N
c = C - Ca
a = torch.cross(b, c, dim=-1)
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
Cb = CBETA_A*a + CBETA_B*b + CBETA_C*c + Ca

dist = rbf(torch.cdist(Cb, Cb))
dist = torch.cat((dist, left, right), dim=-1)
Expand Down
68 changes: 68 additions & 0 deletions rfdiffusion/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Named constants for RFdiffusion.

Centralizes magic numbers used across the codebase. Each constant
documents its meaning and where it originates.
"""

# ===== Amino Acid Encoding =====
NUM_AA_CLASSES = 22 # 20 standard amino acids + UNK + MASK
AA_MASK_TOKEN = 21 # Index for masked/unknown residue
AA_GLYCINE = 7 # Index for glycine in the alphabet

# ===== Atom Counts =====
N_BACKBONE_ATOMS = 4 # N, CA, C, O
N_HEAVY = 14 # Heavy atoms per residue (backbone + sidechain)
N_ALLATOM = 27 # All atoms per residue including hydrogens

# ===== Virtual Cbeta Reconstruction =====
# Coefficients for computing virtual Cbeta from backbone N, CA, C atoms.
# Derived from the cross product (CA-N) x (C-CA) basis vectors.
# Used in generate_Cbeta() across util.py, Embeddings.py, coords6d.py.
CBETA_A = -0.58273431 # coefficient for cross product vector
CBETA_B = 0.56802827 # coefficient for (CA - N) vector
CBETA_C = -0.54067466 # coefficient for (C - CA) vector

# ===== Distance Sentinel =====
# Large distance value used to indicate "no contact" or to mask
# self-interactions in distance matrices and top-k graphs.
NO_CONTACT_DIST = 999.9

# ===== Chain/Contig =====
# Index jump inserted between chains in the residue index array.
# Used by ContigMap and positional encodings to detect chain boundaries.
CHAIN_BREAK_INDEX_JUMP = 200

# Chain break detection threshold: gaps in residue index larger than
# this value are treated as chain breaks.
CHAIN_BREAK_DETECTION_THRESH = 35

# Maximum number of attempts when randomly sampling a valid contig
# length from a specified range.
CONTIG_MAX_SAMPLE_ATTEMPTS = 100_000

# ===== SE(3) Prediction Scaling =====
# Divisors applied to raw SE(3) transformer outputs to bring
# translations and rotations into physical scale.
SE3_TRANSLATION_SCALE = 10.0
SE3_ROTATION_SCALE = 100.0

# ===== Diffusion Schedule =====
# Reference number of timesteps for beta schedule scaling.
# When T != 200, betas are rescaled: beta *= 200/T.
BETA_SCHEDULE_REF_T = 200

# Minimum number of diffusion steps required for the schedule
# approximation to remain valid.
MIN_DIFFUSION_STEPS = 15

# ===== IGSO3 (Rotation Diffusion) =====
# Number of discrete sigma values for the IGSO3 distribution.
IGSO3_NUM_SIGMA = 500

# Truncation level L for the power series expansion of the
# IGSO3 probability density.
IGSO3_TRUNCATION_LEVEL = 2000

# ===== Peptide Geometry =====
# Ideal C-N peptide bond length in Angstroms.
PEPTIDE_BOND_LENGTH = 1.33
3 changes: 2 additions & 1 deletion rfdiffusion/coords6d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import scipy
import scipy.spatial
from rfdiffusion.kinematics import get_dih
from rfdiffusion.constants import CBETA_A, CBETA_B, CBETA_C

# calculate planar angles defined by 3 sets of points
def get_angles(a, b, c):
Expand Down Expand Up @@ -31,7 +32,7 @@ def get_coords6d(xyz, dmax):
b = Ca - N
c = C - Ca
a = np.cross(b, c)
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
Cb = CBETA_A*a + CBETA_B*b + CBETA_C*c + Ca

# fast neighbors search to collect all
# Cb-Cb pairs within dmax
Expand Down
5 changes: 3 additions & 2 deletions rfdiffusion/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import scipy.sparse
from rfdiffusion.chemical import *
from rfdiffusion.scoring import *
from rfdiffusion.constants import CBETA_A, CBETA_B, CBETA_C


def generate_Cbeta(N, Ca, C):
Expand All @@ -9,7 +10,7 @@ def generate_Cbeta(N, Ca, C):
c = C - Ca
a = torch.cross(b, c, dim=-1)
# These are the values used during training
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
Cb = CBETA_A*a + CBETA_B*b + CBETA_C*c + Ca
# fd: below matches sidechain generator (=Rosetta params)
# Cb = -0.57910144 * a + 0.5689693 * b - 0.5441217 * c + Ca

Expand Down Expand Up @@ -239,7 +240,7 @@ def get_tips(xyz, seq):
b = Ca - N
c = C - Ca
a = torch.cross(b, c, dim=-1)
Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca
Cb = CBETA_A * a + CBETA_B * b + CBETA_C * c + Ca

xyz_tips = torch.where(torch.isnan(xyz_tips), Cb, xyz_tips)
return xyz_tips, mask
Expand Down
7 changes: 4 additions & 3 deletions scripts/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import hydra
import logging
from rfdiffusion.util import writepdb_multi, writepdb
from rfdiffusion.constants import AA_MASK_TOKEN, AA_GLYCINE
from rfdiffusion.inference import utils as iu
from hydra.core.hydra_config import HydraConfig
import numpy as np
Expand Down Expand Up @@ -124,12 +125,12 @@ def main(conf: HydraConfig) -> None:

# Output glycines, except for motif region
final_seq = torch.where(
torch.argmax(seq_init, dim=-1) == 21, 7, torch.argmax(seq_init, dim=-1)
) # 7 is glycine
torch.argmax(seq_init, dim=-1) == AA_MASK_TOKEN, AA_GLYCINE, torch.argmax(seq_init, dim=-1)
)

bfacts = torch.ones_like(final_seq.squeeze())
# make bfact=0 for diffused coordinates
bfacts[torch.where(torch.argmax(seq_init, dim=-1) == 21, True, False)] = 0
bfacts[torch.where(torch.argmax(seq_init, dim=-1) == AA_MASK_TOKEN, True, False)] = 0
# pX0 last step
out = f"{out_prefix}.pdb"

Expand Down