|
32 | 32 | from torch.cuda.amp import autocast, GradScaler |
33 | 33 | warnings.filterwarnings("ignore", category=UserWarning) |
34 | 34 |
|
| 35 | + |
def build_notebook_mono_configs(args, converter, hidden_size, ndim_godnode, ndim_fft2r, ndim_fft2i):
    """Assemble the per-decoder configuration mapping for MultiMonoDecoder.

    The result always contains a ``'sequence_transformer'`` entry. Three more
    entries are optional, each switched on by one attribute of ``args``:

    * ``'geometry_transformer'`` when ``args.use_geometry_transformer`` is truthy
    * ``'geometry_cnn'`` when ``args.use_geometry_cnn`` is truthy
    * ``'foldx'`` when ``args.output_foldx`` is truthy

    Args:
        args: Namespace of run options. Attributes read here: ``embedding_dim``,
            ``nconv_layers``, ``output_rt``, ``output_fft``, ``output_foldx``,
            ``use_geometry_transformer``, ``use_geometry_cnn``,
            ``sequence_use_cnn_decoder``, ``sequence_output_ss``,
            ``geometry_output_rt``, ``geometry_output_ss``,
            ``geometry_output_angles``, ``geometry_cnn_output_fft``,
            ``geometry_cnn_output_rt``, ``geometry_cnn_output_angles``,
            ``geometry_cnn_output_ss`` and ``geometry_cnn_output_edge_logits``.
        converter: Object providing ``aaindex`` (stored as ``'amino_mapper'``)
            and ``metadata`` (stored as ``'metadata'``); both are passed
            through by reference, not copied.
        hidden_size: Base layer width. Narrower widths are derived from it by
            integer division.
        ndim_godnode: Stored as the ``'godnode4decoder'`` input channel count.
        ndim_fft2r: Stored as the ``'fft2r'`` input channel count.
        ndim_fft2i: Stored as the ``'fft2i'`` input channel count.

    Returns:
        dict mapping decoder name to that decoder's config dict.
    """
    # A per-decoder override left at None means "not specified": inherit the
    # corresponding global switch instead.
    rt_for_geometry = args.geometry_output_rt
    if rt_for_geometry is None:
        rt_for_geometry = args.output_rt

    fft_for_cnn = args.geometry_cnn_output_fft
    if fft_for_cnn is None:
        fft_for_cnn = args.output_fft

    def full_width(count):
        # Returns a new list on every call so config entries never alias.
        return [hidden_size] * count

    def backbone_edges(depth):
        # Forward and reverse backbone edge types share the same widths.
        return {
            ('res', 'backbone', 'res'): full_width(depth),
            ('res', 'backbonerev', 'res'): full_width(depth),
        }

    sequence_cfg = {
        'in_channels': {'res': args.embedding_dim},
        'xdim': 20,
        'concat_positions': False,
        'hidden_channels': backbone_edges(1),
        'layers': 2,
        'AAdecoder_hidden': full_width(3),
        'amino_mapper': converter.aaindex,
        'nheads': 5,
        'dropout': 0.001,
        'normalize': False,
        'residual': False,
        'use_cnn_decoder': args.sequence_use_cnn_decoder,
        'output_ss': args.sequence_output_ss,
        'learn_positions': True,
    }
    mono_configs = {'sequence_transformer': sequence_cfg}

    if args.use_geometry_transformer:
        mono_configs['geometry_transformer'] = {
            'in_channels': {'res': args.embedding_dim},
            'concat_positions': False,
            'hidden_channels': backbone_edges(1),
            'layers': 2,
            'nheads': 5,
            'RTdecoder_hidden': full_width(3),
            'ssdecoder_hidden': full_width(3),
            'anglesdecoder_hidden': full_width(3),
            'dropout': 0.001,
            'normalize': False,
            'residual': False,
            'learn_positions': True,
            'output_rt': rt_for_geometry,
            'output_ss': args.geometry_output_ss,
            'output_angles': args.geometry_output_angles,
        }

    if args.use_geometry_cnn:
        half = hidden_size // 2
        quarter = hidden_size // 4
        cnn_inputs = {
            'res': args.embedding_dim,
            'godnode4decoder': ndim_godnode,
            'foldx': 23,
            'fft2r': ndim_fft2r,
            'fft2i': ndim_fft2i,
        }
        mono_configs['geometry_cnn'] = {
            'in_channels': cnn_inputs,
            'concat_positions': False,
            'conv_channels': [2 * hidden_size, hidden_size, hidden_size],
            'kernel_sizes': [3] * args.nconv_layers,
            'FFT2decoder_hidden': [half, half],
            'contactdecoder_hidden': [half, quarter],
            'ssdecoder_hidden': [half, half],
            'Xdecoder_hidden': full_width(2),
            'anglesdecoder_hidden': [hidden_size, hidden_size, half],
            'RTdecoder_hidden': [half, quarter],
            'metadata': converter.metadata,
            'dropout': 0.001,
            'output_fft': fft_for_cnn,
            'output_rt': args.geometry_cnn_output_rt,
            'output_angles': args.geometry_cnn_output_angles,
            'output_ss': args.geometry_cnn_output_ss,
            'normalize': True,
            'residual': False,
            'output_edge_logits': args.geometry_cnn_output_edge_logits,
            'ncat': 8,
            'contact_mlp': False,
            'pool_type': 'global_mean',
            'learn_positions': True,
        }

    if args.output_foldx:
        foldx_edges = backbone_edges(3)
        foldx_edges[('res', 'informs', 'godnode4decoder')] = full_width(3)
        foldx_edges[('godnode4decoder', 'informs', 'res')] = full_width(3)
        mono_configs['foldx'] = {
            'in_channels': {
                'res': args.embedding_dim,
                'godnode4decoder': ndim_godnode,
                'foldx': 23,
            },
            'concat_positions': False,
            'hidden_channels': foldx_edges,
            'layers': 3,
            'foldx_hidden': [hidden_size, hidden_size // 2],
            'nheads': 2,
            'metadata': converter.metadata,
            'flavor': 'sage',
            'dropout': 0.005,
            'normalize': True,
            'residual': False,
        }

    return mono_configs
| 124 | + |
35 | 125 | # Try to import Muon optimizer |
36 | 126 | try: |
37 | 127 | from muon import MuonWithAuxAdam |
@@ -125,6 +215,30 @@ def print_about(): |
125 | 215 | help='Train the model with rotation and translation output') |
parser.add_argument('--output-foldx' , action='store_true',
                    help='Train the model with Foldx energy prediction output')
# Per-decoder switches. Each uses argparse.BooleanOptionalAction (Python 3.9+),
# which also registers a matching ``--no-<name>`` form, so every default below
# can be turned off as well as on from the command line.
parser.add_argument('--use-geometry-transformer', action=argparse.BooleanOptionalAction, default=True,
                    help='Enable the notebook-style geometry_transformer decoder (default: True)')
parser.add_argument('--use-geometry-cnn', action=argparse.BooleanOptionalAction, default=True,
                    help='Enable the notebook-style geometry_cnn decoder (default: True)')
parser.add_argument('--sequence-use-cnn-decoder', action=argparse.BooleanOptionalAction, default=True,
                    help='Use the CNN AA head inside sequence_transformer, matching the notebook (default: True)')
parser.add_argument('--sequence-output-ss', action=argparse.BooleanOptionalAction, default=False,
                    help='Allow sequence_transformer to emit secondary structure predictions (default: False)')
# default=None marks "not given on the command line"; the consumers of this
# option test for None and substitute the value of --output-rt in that case.
parser.add_argument('--geometry-output-rt', action=argparse.BooleanOptionalAction, default=None,
                    help='Allow geometry_transformer to emit rotation/translation outputs (default: follows --output-rt)')
parser.add_argument('--geometry-output-ss', action=argparse.BooleanOptionalAction, default=True,
                    help='Allow geometry_transformer to emit secondary structure predictions (default: True)')
parser.add_argument('--geometry-output-angles', action=argparse.BooleanOptionalAction, default=True,
                    help='Allow geometry_transformer to emit bond angle predictions (default: True)')
# Same None-means-unset convention as --geometry-output-rt, falling back to
# the value of --output-fft.
parser.add_argument('--geometry-cnn-output-fft', action=argparse.BooleanOptionalAction, default=None,
                    help='Allow geometry_cnn to emit FFT predictions (default: follows --output-fft)')
parser.add_argument('--geometry-cnn-output-rt', action=argparse.BooleanOptionalAction, default=False,
                    help='Allow geometry_cnn to emit rotation/translation outputs (default: False)')
parser.add_argument('--geometry-cnn-output-angles', action=argparse.BooleanOptionalAction, default=False,
                    help='Allow geometry_cnn to emit bond angle predictions (default: False)')
parser.add_argument('--geometry-cnn-output-ss', action=argparse.BooleanOptionalAction, default=False,
                    help='Allow geometry_cnn to emit secondary structure predictions (default: False)')
parser.add_argument('--geometry-cnn-output-edge-logits', action=argparse.BooleanOptionalAction, default=True,
                    help='Allow geometry_cnn to emit contact edge logits/probabilities (default: True)')
parser.add_argument('--seed', type=int, default=0,
                    help='Random seed for reproducibility')
130 | 244 | parser.add_argument('--hetero-gae', action='store_true', |
@@ -432,6 +546,18 @@ def decode_batch_reconstruction(encoder, decoder, z_batch, device, converter, ve |
# Run-configuration summary: one line per switch, rendered as Enabled/Disabled.
print(f" Output FFT: {'Enabled' if args.output_fft else 'Disabled'}")
print(f" Output RT: {'Enabled' if args.output_rt else 'Disabled'}")
print(f" Output Foldx: {'Enabled' if args.output_foldx else 'Disabled'}")
# Per-decoder switches.
print(f" Sequence CNN Head: {'Enabled' if args.sequence_use_cnn_decoder else 'Disabled'}")
print(f" Geometry Transformer: {'Enabled' if args.use_geometry_transformer else 'Disabled'}")
print(f" Geometry CNN: {'Enabled' if args.use_geometry_cnn else 'Disabled'}")
print(f" Sequence Output SS: {'Enabled' if args.sequence_output_ss else 'Disabled'}")
# geometry_output_rt may be None (unset); show the effective value, which is
# then taken from args.output_rt.
print(f" Geometry Output RT: {'Enabled' if (args.output_rt if args.geometry_output_rt is None else args.geometry_output_rt) else 'Disabled'}")
print(f" Geometry Output SS: {'Enabled' if args.geometry_output_ss else 'Disabled'}")
print(f" Geometry Output Angles: {'Enabled' if args.geometry_output_angles else 'Disabled'}")
# geometry_cnn_output_fft may be None (unset); show the effective value, which
# is then taken from args.output_fft.
print(f" Geometry CNN Output FFT: {'Enabled' if (args.output_fft if args.geometry_cnn_output_fft is None else args.geometry_cnn_output_fft) else 'Disabled'}")
print(f" Geometry CNN Output RT: {'Enabled' if args.geometry_cnn_output_rt else 'Disabled'}")
print(f" Geometry CNN Output Angles: {'Enabled' if args.geometry_cnn_output_angles else 'Disabled'}")
print(f" Geometry CNN Output SS: {'Enabled' if args.geometry_cnn_output_ss else 'Disabled'}")
print(f" Geometry CNN Edge Logits: {'Enabled' if args.geometry_cnn_output_edge_logits else 'Disabled'}")
print(f" Hetero GAE Decoder: {'Enabled' if args.hetero_gae else 'Disabled'}")
print(f" Gradient Clipping: {'Enabled' if args.clip_grad else 'Disabled'}")
print(f" Burn-in Period: {args.burn_in} epochs")
@@ -644,73 +770,8 @@ def decode_batch_reconstruction(encoder, decoder, z_batch, device, converter, ve |
644 | 770 | ) |
645 | 771 | else: |
646 | 772 | # MultiMonoDecoder for sequence and geometry |
647 | | - print("Using standard decoders") |
648 | | - mono_configs = { |
649 | | - 'sequence_transformer': { |
650 | | - 'in_channels': {'res': args.embedding_dim}, |
651 | | - 'xdim': 20, |
652 | | - 'concat_positions': False, |
653 | | - 'hidden_channels': {('res','backbone','res'): [hidden_size], ('res','backbonerev','res'): [hidden_size]}, |
654 | | - 'layers': 2, |
655 | | - 'AAdecoder_hidden': [hidden_size, hidden_size, hidden_size//2], |
656 | | - 'amino_mapper': converter.aaindex, |
657 | | - 'nheads': 5, |
658 | | - 'dropout': 0.001, |
659 | | - 'normalize': False, |
660 | | - 'residual': False, |
661 | | - 'use_cnn_decoder': True, |
662 | | - 'output_ss': False, # Don't output SS from sequence decoder |
663 | | - 'learn_positions': True, |
664 | | - 'concat_positions': False |
665 | | - }, |
666 | | - |
667 | | - 'geometry_transformer': { |
668 | | - 'in_channels': {'res': args.embedding_dim}, |
669 | | - 'concat_positions': False, |
670 | | - 'hidden_channels': {('res','backbone','res'): [hidden_size], ('res','backbonerev','res'): [hidden_size]}, |
671 | | - 'layers': 2, |
672 | | - 'nheads': 5, |
673 | | - 'RTdecoder_hidden': [hidden_size, hidden_size, hidden_size//2], |
674 | | - 'ssdecoder_hidden': [hidden_size,hidden_size, hidden_size//2], |
675 | | - 'anglesdecoder_hidden': [hidden_size, hidden_size,hidden_size//2], |
676 | | - 'dropout': 0.001, |
677 | | - 'normalize': False, |
678 | | - 'residual': False, |
679 | | - 'learn_positions': True, |
680 | | - 'use_cnn_decoder':True, |
681 | | - 'concat_positions': False, |
682 | | - 'output_rt': False, # Enable if you want rotation-translation |
683 | | - 'output_ss': True, # Secondary structure prediction |
684 | | - 'output_angles': True # Bond angles prediction |
685 | | - }, |
686 | | - |
687 | | - 'geometry_cnn': { |
688 | | - 'in_channels': {'res': args.embedding_dim, 'godnode4decoder': ndim_godnode, 'foldx': 23, 'fft2r': ndim_fft2r, 'fft2i': ndim_fft2i}, |
689 | | - 'concat_positions': False, |
690 | | - 'conv_channels': [hidden_size, hidden_size, hidden_size], |
691 | | - 'kernel_sizes': [3]*args.nconv_layers, |
692 | | - 'FFT2decoder_hidden': [hidden_size//2, hidden_size//2], |
693 | | - 'contactdecoder_hidden': [hidden_size//2, hidden_size//4], |
694 | | - 'ssdecoder_hidden': [hidden_size//2, hidden_size//2], |
695 | | - 'Xdecoder_hidden': [hidden_size, hidden_size], |
696 | | - 'anglesdecoder_hidden': [hidden_size, hidden_size, hidden_size//2], |
697 | | - 'RTdecoder_hidden': [hidden_size//2, hidden_size//4], |
698 | | - 'metadata': converter.metadata, |
699 | | - 'dropout': 0.001, |
700 | | - 'output_fft': False, |
701 | | - 'output_rt': False, |
702 | | - 'output_angles': False, # Don't duplicate angles from geometry_transformer |
703 | | - 'output_ss': False, # Don't duplicate SS from geometry_transformer |
704 | | - 'normalize': True, |
705 | | - 'residual': False, |
706 | | - 'output_edge_logits': True, |
707 | | - 'ncat': 8, |
708 | | - 'contact_mlp': False, |
709 | | - 'pool_type': 'global_mean', |
710 | | - 'learn_positions': True, |
711 | | - 'concat_positions': False |
712 | | - }, |
713 | | - } |
| 773 | + print("Using notebook-style decoder configuration") |
| 774 | + mono_configs = build_notebook_mono_configs(args, converter, hidden_size, ndim_godnode, ndim_fft2r, ndim_fft2i) |
714 | 775 | # Initialize decoder |
715 | 776 | decoder = MultiMonoDecoder( configs=mono_configs) |
716 | 777 |
|
|
0 commit comments