Skip to content

Commit 5dfdcdd

Browse files
committed
chased down nans in training. stable again
1 parent 13866bc commit 5dfdcdd

9 files changed

Lines changed: 839 additions & 432 deletions

File tree

foldtree2/learn_lightning.py

Lines changed: 133 additions & 153 deletions
Large diffs are not rendered by default.

foldtree2/learn_monodecoder.py

Lines changed: 128 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,96 @@
3232
from torch.cuda.amp import autocast, GradScaler
3333
warnings.filterwarnings("ignore", category=UserWarning)
3434

35+
36+
def build_notebook_mono_configs(args, converter, hidden_size, ndim_godnode, ndim_fft2r, ndim_fft2i):
    """Build MultiMonoDecoder configs matching the notebook defaults as closely as possible.

    The ``sequence_transformer`` entry is always present.  The
    ``geometry_transformer``, ``geometry_cnn`` and ``foldx`` entries are added
    (in that order) only when ``args.use_geometry_transformer``,
    ``args.use_geometry_cnn`` and ``args.output_foldx`` are truthy.

    Args:
        args: Parsed command-line namespace.  Attributes read here:
            ``embedding_dim``, ``output_rt``, ``output_fft``,
            ``geometry_output_rt``, ``geometry_cnn_output_fft``,
            ``sequence_use_cnn_decoder``, ``sequence_output_ss``,
            ``use_geometry_transformer``, ``geometry_output_ss``,
            ``geometry_output_angles``, ``use_geometry_cnn``, ``nconv_layers``,
            ``geometry_cnn_output_rt``, ``geometry_cnn_output_angles``,
            ``geometry_cnn_output_ss``, ``geometry_cnn_output_edge_logits``
            and ``output_foldx``.
        converter: Object providing ``aaindex`` (passed through as
            ``amino_mapper``) and ``metadata`` (passed through as
            ``metadata``).
        hidden_size: Base width from which every hidden-layer size is derived.
        ndim_godnode: Input channel count for the ``'godnode4decoder'`` entry.
        ndim_fft2r: Input channel count for the ``'fft2r'`` entry.
        ndim_fft2i: Input channel count for the ``'fft2i'`` entry.

    Returns:
        dict: Decoder name -> keyword configuration dict.  Every list and
        nested dict is a fresh object, so entries never alias each other.
    """
    # ``None`` on the decoder-specific flag means "not set on the command
    # line": fall back to the corresponding global switch.
    geometry_output_rt = args.output_rt if args.geometry_output_rt is None else args.geometry_output_rt
    geometry_cnn_output_fft = args.output_fft if args.geometry_cnn_output_fft is None else args.geometry_cnn_output_fft

    mono_configs = {
        'sequence_transformer': {
            'in_channels': {'res': args.embedding_dim},
            'xdim': 20,
            'concat_positions': False,
            'hidden_channels': {
                ('res', 'backbone', 'res'): [hidden_size],
                ('res', 'backbonerev', 'res'): [hidden_size],
            },
            'layers': 2,
            'AAdecoder_hidden': [hidden_size, hidden_size, hidden_size],
            'amino_mapper': converter.aaindex,
            'nheads': 5,
            'dropout': 0.001,
            'normalize': False,
            'residual': False,
            'use_cnn_decoder': args.sequence_use_cnn_decoder,
            'output_ss': args.sequence_output_ss,
            'learn_positions': True,
        },
    }

    if args.use_geometry_transformer:
        mono_configs['geometry_transformer'] = {
            'in_channels': {'res': args.embedding_dim},
            'concat_positions': False,
            'hidden_channels': {
                ('res', 'backbone', 'res'): [hidden_size],
                ('res', 'backbonerev', 'res'): [hidden_size],
            },
            'layers': 2,
            'nheads': 5,
            'RTdecoder_hidden': [hidden_size, hidden_size, hidden_size],
            'ssdecoder_hidden': [hidden_size, hidden_size, hidden_size],
            'anglesdecoder_hidden': [hidden_size, hidden_size, hidden_size],
            'dropout': 0.001,
            'normalize': False,
            'residual': False,
            'learn_positions': True,
            'output_rt': geometry_output_rt,
            'output_ss': args.geometry_output_ss,
            'output_angles': args.geometry_output_angles,
        }

    if args.use_geometry_cnn:
        mono_configs['geometry_cnn'] = {
            'in_channels': {
                'res': args.embedding_dim,
                'godnode4decoder': ndim_godnode,
                'foldx': 23,
                'fft2r': ndim_fft2r,
                'fft2i': ndim_fft2i,
            },
            'concat_positions': False,
            'conv_channels': [2 * hidden_size, hidden_size, hidden_size],
            # NOTE(review): kernel_sizes has args.nconv_layers entries while
            # conv_channels always has three -- confirm the decoder accepts
            # nconv_layers != 3.
            'kernel_sizes': [3] * args.nconv_layers,
            'FFT2decoder_hidden': [hidden_size // 2, hidden_size // 2],
            'contactdecoder_hidden': [hidden_size // 2, hidden_size // 4],
            'ssdecoder_hidden': [hidden_size // 2, hidden_size // 2],
            'Xdecoder_hidden': [hidden_size, hidden_size],
            'anglesdecoder_hidden': [hidden_size, hidden_size, hidden_size // 2],
            'RTdecoder_hidden': [hidden_size // 2, hidden_size // 4],
            'metadata': converter.metadata,
            'dropout': 0.001,
            'output_fft': geometry_cnn_output_fft,
            'output_rt': args.geometry_cnn_output_rt,
            'output_angles': args.geometry_cnn_output_angles,
            'output_ss': args.geometry_cnn_output_ss,
            'normalize': True,
            'residual': False,
            'output_edge_logits': args.geometry_cnn_output_edge_logits,
            'ncat': 8,
            'contact_mlp': False,
            'pool_type': 'global_mean',
            'learn_positions': True,
        }

    if args.output_foldx:
        mono_configs['foldx'] = {
            'in_channels': {
                'res': args.embedding_dim,
                'godnode4decoder': ndim_godnode,
                'foldx': 23,
            },
            'concat_positions': False,
            'hidden_channels': {
                ('res', 'backbone', 'res'): [hidden_size] * 3,
                ('res', 'backbonerev', 'res'): [hidden_size] * 3,
                ('res', 'informs', 'godnode4decoder'): [hidden_size] * 3,
                ('godnode4decoder', 'informs', 'res'): [hidden_size] * 3,
            },
            'layers': 3,
            'foldx_hidden': [hidden_size, hidden_size // 2],
            'nheads': 2,
            'metadata': converter.metadata,
            'flavor': 'sage',
            'dropout': 0.005,
            'normalize': True,
            'residual': False,
        }

    return mono_configs
124+
35125
# Try to import Muon optimizer
36126
try:
37127
from muon import MuonWithAuxAdam
@@ -125,6 +215,30 @@ def print_about():
125215
help='Train the model with rotation and translation output')
126216
parser.add_argument('--output-foldx' , action='store_true',
127217
help='Train the model with Foldx energy prediction output')
218+
parser.add_argument('--use-geometry-transformer', action=argparse.BooleanOptionalAction, default=True,
219+
help='Enable the notebook-style geometry_transformer decoder (default: True)')
220+
parser.add_argument('--use-geometry-cnn', action=argparse.BooleanOptionalAction, default=True,
221+
help='Enable the notebook-style geometry_cnn decoder (default: True)')
222+
parser.add_argument('--sequence-use-cnn-decoder', action=argparse.BooleanOptionalAction, default=True,
223+
help='Use the CNN AA head inside sequence_transformer, matching the notebook (default: True)')
224+
parser.add_argument('--sequence-output-ss', action=argparse.BooleanOptionalAction, default=False,
225+
help='Allow sequence_transformer to emit secondary structure predictions (default: False)')
226+
parser.add_argument('--geometry-output-rt', action=argparse.BooleanOptionalAction, default=None,
227+
help='Allow geometry_transformer to emit rotation/translation outputs (default: follows --output-rt)')
228+
parser.add_argument('--geometry-output-ss', action=argparse.BooleanOptionalAction, default=True,
229+
help='Allow geometry_transformer to emit secondary structure predictions (default: True)')
230+
parser.add_argument('--geometry-output-angles', action=argparse.BooleanOptionalAction, default=True,
231+
help='Allow geometry_transformer to emit bond angle predictions (default: True)')
232+
parser.add_argument('--geometry-cnn-output-fft', action=argparse.BooleanOptionalAction, default=None,
233+
help='Allow geometry_cnn to emit FFT predictions (default: follows --output-fft)')
234+
parser.add_argument('--geometry-cnn-output-rt', action=argparse.BooleanOptionalAction, default=False,
235+
help='Allow geometry_cnn to emit rotation/translation outputs (default: False)')
236+
parser.add_argument('--geometry-cnn-output-angles', action=argparse.BooleanOptionalAction, default=False,
237+
help='Allow geometry_cnn to emit bond angle predictions (default: False)')
238+
parser.add_argument('--geometry-cnn-output-ss', action=argparse.BooleanOptionalAction, default=False,
239+
help='Allow geometry_cnn to emit secondary structure predictions (default: False)')
240+
parser.add_argument('--geometry-cnn-output-edge-logits', action=argparse.BooleanOptionalAction, default=True,
241+
help='Allow geometry_cnn to emit contact edge logits/probabilities (default: True)')
128242
parser.add_argument('--seed', type=int, default=0,
129243
help='Random seed for reproducibility')
130244
parser.add_argument('--hetero-gae', action='store_true',
@@ -432,6 +546,18 @@ def decode_batch_reconstruction(encoder, decoder, z_batch, device, converter, ve
432546
print(f" Output FFT: {'Enabled' if args.output_fft else 'Disabled'}")
433547
print(f" Output RT: {'Enabled' if args.output_rt else 'Disabled'}")
434548
print(f" Output Foldx: {'Enabled' if args.output_foldx else 'Disabled'}")
549+
print(f" Sequence CNN Head: {'Enabled' if args.sequence_use_cnn_decoder else 'Disabled'}")
550+
print(f" Geometry Transformer: {'Enabled' if args.use_geometry_transformer else 'Disabled'}")
551+
print(f" Geometry CNN: {'Enabled' if args.use_geometry_cnn else 'Disabled'}")
552+
print(f" Sequence Output SS: {'Enabled' if args.sequence_output_ss else 'Disabled'}")
553+
print(f" Geometry Output RT: {'Enabled' if (args.output_rt if args.geometry_output_rt is None else args.geometry_output_rt) else 'Disabled'}")
554+
print(f" Geometry Output SS: {'Enabled' if args.geometry_output_ss else 'Disabled'}")
555+
print(f" Geometry Output Angles: {'Enabled' if args.geometry_output_angles else 'Disabled'}")
556+
print(f" Geometry CNN Output FFT: {'Enabled' if (args.output_fft if args.geometry_cnn_output_fft is None else args.geometry_cnn_output_fft) else 'Disabled'}")
557+
print(f" Geometry CNN Output RT: {'Enabled' if args.geometry_cnn_output_rt else 'Disabled'}")
558+
print(f" Geometry CNN Output Angles: {'Enabled' if args.geometry_cnn_output_angles else 'Disabled'}")
559+
print(f" Geometry CNN Output SS: {'Enabled' if args.geometry_cnn_output_ss else 'Disabled'}")
560+
print(f" Geometry CNN Edge Logits: {'Enabled' if args.geometry_cnn_output_edge_logits else 'Disabled'}")
435561
print(f" Hetero GAE Decoder: {'Enabled' if args.hetero_gae else 'Disabled'}")
436562
print(f" Gradient Clipping: {'Enabled' if args.clip_grad else 'Disabled'}")
437563
print(f" Burn-in Period: {args.burn_in} epochs")
@@ -644,73 +770,8 @@ def decode_batch_reconstruction(encoder, decoder, z_batch, device, converter, ve
644770
)
645771
else:
646772
# MultiMonoDecoder for sequence and geometry
647-
print("Using standard decoders")
648-
mono_configs = {
649-
'sequence_transformer': {
650-
'in_channels': {'res': args.embedding_dim},
651-
'xdim': 20,
652-
'concat_positions': False,
653-
'hidden_channels': {('res','backbone','res'): [hidden_size], ('res','backbonerev','res'): [hidden_size]},
654-
'layers': 2,
655-
'AAdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
656-
'amino_mapper': converter.aaindex,
657-
'nheads': 5,
658-
'dropout': 0.001,
659-
'normalize': False,
660-
'residual': False,
661-
'use_cnn_decoder': True,
662-
'output_ss': False, # Don't output SS from sequence decoder
663-
'learn_positions': True,
664-
'concat_positions': False
665-
},
666-
667-
'geometry_transformer': {
668-
'in_channels': {'res': args.embedding_dim},
669-
'concat_positions': False,
670-
'hidden_channels': {('res','backbone','res'): [hidden_size], ('res','backbonerev','res'): [hidden_size]},
671-
'layers': 2,
672-
'nheads': 5,
673-
'RTdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
674-
'ssdecoder_hidden': [hidden_size,hidden_size, hidden_size//2],
675-
'anglesdecoder_hidden': [hidden_size, hidden_size,hidden_size//2],
676-
'dropout': 0.001,
677-
'normalize': False,
678-
'residual': False,
679-
'learn_positions': True,
680-
'use_cnn_decoder':True,
681-
'concat_positions': False,
682-
'output_rt': False, # Enable if you want rotation-translation
683-
'output_ss': True, # Secondary structure prediction
684-
'output_angles': True # Bond angles prediction
685-
},
686-
687-
'geometry_cnn': {
688-
'in_channels': {'res': args.embedding_dim, 'godnode4decoder': ndim_godnode, 'foldx': 23, 'fft2r': ndim_fft2r, 'fft2i': ndim_fft2i},
689-
'concat_positions': False,
690-
'conv_channels': [hidden_size, hidden_size, hidden_size],
691-
'kernel_sizes': [3]*args.nconv_layers,
692-
'FFT2decoder_hidden': [hidden_size//2, hidden_size//2],
693-
'contactdecoder_hidden': [hidden_size//2, hidden_size//4],
694-
'ssdecoder_hidden': [hidden_size//2, hidden_size//2],
695-
'Xdecoder_hidden': [hidden_size, hidden_size],
696-
'anglesdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
697-
'RTdecoder_hidden': [hidden_size//2, hidden_size//4],
698-
'metadata': converter.metadata,
699-
'dropout': 0.001,
700-
'output_fft': False,
701-
'output_rt': False,
702-
'output_angles': False, # Don't duplicate angles from geometry_transformer
703-
'output_ss': False, # Don't duplicate SS from geometry_transformer
704-
'normalize': True,
705-
'residual': False,
706-
'output_edge_logits': True,
707-
'ncat': 8,
708-
'contact_mlp': False,
709-
'pool_type': 'global_mean',
710-
'learn_positions': True,
711-
'concat_positions': False
712-
},
713-
}
773+
print("Using notebook-style decoder configuration")
774+
mono_configs = build_notebook_mono_configs(args, converter, hidden_size, ndim_godnode, ndim_fft2r, ndim_fft2i)
714775
# Initialize decoder
715776
decoder = MultiMonoDecoder( configs=mono_configs)
716777

0 commit comments

Comments
 (0)