From 2a7aa2d3a49c4fd3b52bd4cb95e6cc61de90d0d1 Mon Sep 17 00:00:00 2001 From: SexyERIC0723 Date: Fri, 20 Mar 2026 17:37:34 +0000 Subject: [PATCH] refactor: extract magic numbers into named constants module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add rfdiffusion/constants.py centralizing magic numbers used across the codebase, with documentation of each constant's meaning and provenance. Replace inline values in 4 files: - Cbeta reconstruction coefficients (-0.58273431, 0.56802827, -0.54067466) used in util.py (2x), Embeddings.py, coords6d.py → CBETA_A/B/C - Amino acid token indices (21=mask, 7=glycine) in run_inference.py → AA_MASK_TOKEN, AA_GLYCINE Also documents additional constants (NO_CONTACT_DIST, CHAIN_BREAK_*, SE3_*_SCALE, diffusion schedule params) for future refactoring. --- rfdiffusion/Embeddings.py | 3 +- rfdiffusion/constants.py | 68 +++++++++++++++++++++++++++++++++++++++ rfdiffusion/coords6d.py | 3 +- rfdiffusion/util.py | 5 +-- scripts/run_inference.py | 7 ++-- 5 files changed, 79 insertions(+), 7 deletions(-) create mode 100644 rfdiffusion/constants.py diff --git a/rfdiffusion/Embeddings.py b/rfdiffusion/Embeddings.py index be359480..ab25b0dc 100644 --- a/rfdiffusion/Embeddings.py +++ b/rfdiffusion/Embeddings.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from opt_einsum import contract as einsum import torch.utils.checkpoint as checkpoint +from rfdiffusion.constants import CBETA_A, CBETA_B, CBETA_C from rfdiffusion.util import get_tips from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal, find_breaks from rfdiffusion.Attention_module import Attention, FeedForwardLayer, AttentionWithBias @@ -317,7 +318,7 @@ def forward(self, seq, msa, pair, xyz, state): b = Ca - N c = C - Ca a = torch.cross(b, c, dim=-1) - Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca + Cb = CBETA_A*a + CBETA_B*b + CBETA_C*c + Ca dist = rbf(torch.cdist(Cb, Cb)) dist = torch.cat((dist, 
left, right), dim=-1) diff --git a/rfdiffusion/constants.py b/rfdiffusion/constants.py new file mode 100644 index 00000000..3818caa9 --- /dev/null +++ b/rfdiffusion/constants.py @@ -0,0 +1,68 @@ +"""Named constants for RFdiffusion. + +Centralizes magic numbers used across the codebase. Each constant +documents its meaning and where it originates. +""" + +# ===== Amino Acid Encoding ===== +NUM_AA_CLASSES = 22 # 20 standard amino acids + UNK + MASK +AA_MASK_TOKEN = 21 # Index of the MASK token (a separate class from UNK) +AA_GLYCINE = 7 # Index for glycine in the alphabet + +# ===== Atom Counts ===== +N_BACKBONE_ATOMS = 4 # N, CA, C, O +N_HEAVY = 14 # Heavy atoms per residue (backbone + sidechain) +N_ALLATOM = 27 # All atoms per residue including hydrogens + +# ===== Virtual Cbeta Reconstruction ===== +# Coefficients for computing virtual Cbeta from backbone N, CA, C atoms. +# Expressed in the local basis a = (CA-N) x (C-CA), b = CA-N, c = C-CA: Cb = A*a + B*b + C*c + CA. +# Used by generate_Cbeta() and get_tips() in util.py, forward() in Embeddings.py, and get_coords6d() in coords6d.py. +CBETA_A = -0.58273431 # coefficient for cross product vector +CBETA_B = 0.56802827 # coefficient for (CA - N) vector +CBETA_C = -0.54067466 # coefficient for (C - CA) vector + +# ===== Distance Sentinel ===== +# Large distance value used to indicate "no contact" or to mask +# self-interactions in distance matrices and top-k graphs. +NO_CONTACT_DIST = 999.9 + +# ===== Chain/Contig ===== +# Index jump inserted between chains in the residue index array. +# Used by ContigMap and positional encodings to detect chain boundaries. +CHAIN_BREAK_INDEX_JUMP = 200 + +# Chain break detection threshold: gaps in residue index larger than +# this value are treated as chain breaks. +CHAIN_BREAK_DETECTION_THRESH = 35 + +# Maximum number of attempts when randomly sampling a valid contig +# length from a specified range. 
+CONTIG_MAX_SAMPLE_ATTEMPTS = 100_000 + +# ===== SE(3) Prediction Scaling ===== +# Divisors applied to raw SE(3) transformer outputs to bring +# translations and rotations into physical scale. +SE3_TRANSLATION_SCALE = 10.0 +SE3_ROTATION_SCALE = 100.0 + +# ===== Diffusion Schedule ===== +# Reference number of timesteps for beta schedule scaling. +# When T != 200, betas are rescaled: beta *= 200/T. +BETA_SCHEDULE_REF_T = 200 + +# Minimum number of diffusion steps required for the schedule +# approximation to remain valid. +MIN_DIFFUSION_STEPS = 15 + +# ===== IGSO3 (Rotation Diffusion) ===== +# Number of discrete sigma values for the IGSO3 distribution. +IGSO3_NUM_SIGMA = 500 + +# Truncation level L for the power series expansion of the +# IGSO3 probability density. +IGSO3_TRUNCATION_LEVEL = 2000 + +# ===== Peptide Geometry ===== +# Ideal C-N peptide bond length in Angstroms. +PEPTIDE_BOND_LENGTH = 1.33 diff --git a/rfdiffusion/coords6d.py b/rfdiffusion/coords6d.py index d3224543..3b080579 100644 --- a/rfdiffusion/coords6d.py +++ b/rfdiffusion/coords6d.py @@ -2,6 +2,7 @@ import scipy import scipy.spatial from rfdiffusion.kinematics import get_dih +from rfdiffusion.constants import CBETA_A, CBETA_B, CBETA_C # calculate planar angles defined by 3 sets of points def get_angles(a, b, c): @@ -31,7 +32,7 @@ def get_coords6d(xyz, dmax): b = Ca - N c = C - Ca a = np.cross(b, c) - Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca + Cb = CBETA_A*a + CBETA_B*b + CBETA_C*c + Ca # fast neighbors search to collect all # Cb-Cb pairs within dmax diff --git a/rfdiffusion/util.py b/rfdiffusion/util.py index 19c30f5f..199a9a17 100644 --- a/rfdiffusion/util.py +++ b/rfdiffusion/util.py @@ -1,6 +1,7 @@ import scipy.sparse from rfdiffusion.chemical import * from rfdiffusion.scoring import * +from rfdiffusion.constants import CBETA_A, CBETA_B, CBETA_C def generate_Cbeta(N, Ca, C): @@ -9,7 +10,7 @@ def generate_Cbeta(N, Ca, C): c = C - Ca a = torch.cross(b, c, dim=-1) # These are the 
values used during training - Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca + Cb = CBETA_A*a + CBETA_B*b + CBETA_C*c + Ca # fd: below matches sidechain generator (=Rosetta params) # Cb = -0.57910144 * a + 0.5689693 * b - 0.5441217 * c + Ca @@ -239,7 +240,7 @@ def get_tips(xyz, seq): b = Ca - N c = C - Ca a = torch.cross(b, c, dim=-1) - Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca + Cb = CBETA_A * a + CBETA_B * b + CBETA_C * c + Ca xyz_tips = torch.where(torch.isnan(xyz_tips), Cb, xyz_tips) return xyz_tips, mask diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 3fb6466e..d9ce0a8e 100755 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -22,6 +22,7 @@ import hydra import logging from rfdiffusion.util import writepdb_multi, writepdb +from rfdiffusion.constants import AA_MASK_TOKEN, AA_GLYCINE from rfdiffusion.inference import utils as iu from hydra.core.hydra_config import HydraConfig import numpy as np @@ -124,12 +125,12 @@ def main(conf: HydraConfig) -> None: # Output glycines, except for motif region final_seq = torch.where( - torch.argmax(seq_init, dim=-1) == 21, 7, torch.argmax(seq_init, dim=-1) - ) # 7 is glycine + torch.argmax(seq_init, dim=-1) == AA_MASK_TOKEN, AA_GLYCINE, torch.argmax(seq_init, dim=-1) + ) bfacts = torch.ones_like(final_seq.squeeze()) # make bfact=0 for diffused coordinates - bfacts[torch.where(torch.argmax(seq_init, dim=-1) == 21, True, False)] = 0 + bfacts[torch.where(torch.argmax(seq_init, dim=-1) == AA_MASK_TOKEN, True, False)] = 0 # pX0 last step out = f"{out_prefix}.pdb"