From bfa5ad2736ef4667721c8f9e774efaf410f2b70e Mon Sep 17 00:00:00 2001 From: Ahmed Khaled Date: Fri, 10 Apr 2026 20:24:52 -0700 Subject: [PATCH] internal PiperOrigin-RevId: 898012702 --- init2winit/hyperparameters.py | 47 ++++++++++++++++--- init2winit/model_lib/adabelief_densenet.py | 19 +------- init2winit/model_lib/adabelief_resnet.py | 17 ------- init2winit/model_lib/adabelief_vgg.py | 19 +------- init2winit/model_lib/autoencoder.py | 21 --------- init2winit/model_lib/conformer.py | 36 -------------- .../model_lib/convolutional_autoencoder.py | 17 ------- init2winit/model_lib/deepspeech.py | 28 ----------- init2winit/model_lib/dlrm.py | 16 ------- init2winit/model_lib/fully_connected.py | 20 +------- init2winit/model_lib/gnn.py | 19 -------- .../model_lib/local_attention_transformer.py | 27 +---------- init2winit/model_lib/lstm_lm.py | 17 ------- init2winit/model_lib/max_pooling_cnn.py | 17 ------- init2winit/model_lib/mdlm_rope_nanodo.py | 15 ------ init2winit/model_lib/mlperf_resnet.py | 40 ---------------- init2winit/model_lib/nanodo.py | 15 ------ init2winit/model_lib/nqm.py | 11 ----- init2winit/model_lib/resnet.py | 17 ------- init2winit/model_lib/rope_nanodo.py | 15 ------ init2winit/model_lib/simple_cnn.py | 15 ------ init2winit/model_lib/test_models.py | 7 ++- init2winit/model_lib/transformer_lm.py | 21 --------- init2winit/model_lib/transformer_stu_lm.py | 21 --------- .../model_lib/transformer_stu_tensordot_lm.py | 21 --------- init2winit/model_lib/unet.py | 45 ------------------ init2winit/model_lib/vit.py | 19 -------- init2winit/model_lib/wide_resnet.py | 17 ------- init2winit/model_lib/xformer_translate.py | 44 ----------------- .../model_lib/xformer_translate_binary.py | 22 --------- .../xformer_translate_mlc_variant.py | 22 --------- init2winit/test_hyperparameters.py | 10 ++-- init2winit/trainer_lib/test_trainer.py | 13 +++-- 33 files changed, 65 insertions(+), 645 deletions(-) diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py index b145599d..50231444 100644 --- a/init2winit/hyperparameters.py +++ b/init2winit/hyperparameters.py @@ -24,6 +24,32 @@ from ml_collections.config_dict import config_dict +# Default hyperparameters for training and optimization. +DEFAULT_TRAINING_HPARAMS = config_dict.ConfigDict( + dict( + optimizer='adam', + opt_hparams={ + 'beta1': 0.9, + 'beta2': 0.999, + 'epsilon': 1e-8, + 'grad_clip': None, + }, + lr_hparams={ + 'base_lr': 0.01, + 'schedule': 'constant', + }, + batch_size=128, + total_accumulated_batch_size=None, + l2_decay_factor=None, + l2_decay_rank_threshold=2, + label_smoothing=None, + rng_seed=-1, + use_shallue_label_smoothing=False, + layer_rescale_factors={}, + ) +) + + def expand_key(hparams, key_pieces, index, value): """Util to safely expand dotted keys in a dictionary. @@ -125,22 +151,31 @@ def build_hparams(model_name, initializer_hps = initializers.get_initializer_hparams(initializer_name) dataset_hps = datasets.get_dataset_hparams(dataset_name) input_pipeline_hps = input_pipeline_hps or config_dict.ConfigDict() + training_hps = DEFAULT_TRAINING_HPARAMS merged_dict = {} hps_dicts = [ hps.to_dict() - for hps in [model_hps, initializer_hps, dataset_hps, input_pipeline_hps] + for hps in [ + training_hps, + model_hps, + initializer_hps, + dataset_hps, + input_pipeline_hps, + ] ] - total_hps = 0 for hps_dict in hps_dicts: merged_dict.update(hps_dict) - total_hps += len(hps_dict.keys()) - # Check that all provided have no overlap. 
- if total_hps != len(merged_dict.keys()): - raise ValueError('There is overlap in the provided hparams.') + # Check that all provided hps have no overlap. + seen_keys = set() + for hps_dict in hps_dicts: + overlap = seen_keys.intersection(hps_dict.keys()) + if overlap: + raise ValueError(f'There is overlap in the provided hparams: {overlap}') + seen_keys.update(hps_dict.keys()) # Convert to the Shallue and Lee label smoothing style. if merged_dict.get('use_shallue_label_smoothing', False): diff --git a/init2winit/model_lib/adabelief_densenet.py b/init2winit/model_lib/adabelief_densenet.py index af64b028..9b06ac5e 100644 --- a/init2winit/model_lib/adabelief_densenet.py +++ b/init2winit/model_lib/adabelief_densenet.py @@ -48,27 +48,12 @@ # results in a large Dense matrix in the readout layer and unstable # training. use_kernel_size_as_stride_in_pooling=True, - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, normalizer='batch_norm', - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, normalize_classifier_input='none', classification_scale_factor=1.0, - )) + ) +) class BottleneckBlock(nn.Module): diff --git a/init2winit/model_lib/adabelief_resnet.py b/init2winit/model_lib/adabelief_resnet.py index 72d6282b..6954e771 100644 --- a/init2winit/model_lib/adabelief_resnet.py +++ b/init2winit/model_lib/adabelief_resnet.py @@ -49,29 +49,12 @@ dict( num_filters=16, num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, batch_norm_momentum=0.9, batch_norm_epsilon=1e-5, # Make this a string to avoid having to import jnp into the configs. 
model_dtype='float32', virtual_batch_size=None, - total_accumulated_batch_size=None, data_format='NHWC', - grad_clip=None, )) diff --git a/init2winit/model_lib/adabelief_vgg.py b/init2winit/model_lib/adabelief_vgg.py index d7e1b143..95422ea1 100644 --- a/init2winit/model_lib/adabelief_vgg.py +++ b/init2winit/model_lib/adabelief_vgg.py @@ -39,25 +39,10 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( num_layers=11, # Must be one of [11, 13, 16, 19] - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, normalizer='none', - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, - )) + ) +) def classifier(x, num_outputs, dropout_rate, deterministic): diff --git a/init2winit/model_lib/autoencoder.py b/init2winit/model_lib/autoencoder.py index 36c803d7..9e6c7bb0 100644 --- a/init2winit/model_lib/autoencoder.py +++ b/init2winit/model_lib/autoencoder.py @@ -37,27 +37,7 @@ hid_sizes=[128, 64, 32, 64, 128], activation_function=['relu', 'relu', 'relu', 'relu', 'relu'], kernel_scales=[1.0] * 6, - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - optimizer='hessian_free', - opt_hparams={ - 'cg_max_iter': 250, - 'cg_iter_tracking_method': 'back_tracking', - 'use_line_search': True, - 'init_damping': 50.0, - 'damping_ub': 10 ** 2, - 'damping_lb': 10 ** -6, - }, - batch_size=128, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - l2_decay_factor=2e-5, - l2_decay_rank_threshold=1, )) @@ -82,4 +62,3 @@ def get_fake_inputs(self, hps): jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype) ] return dummy_inputs - \ No newline at end of file diff --git a/init2winit/model_lib/conformer.py b/init2winit/model_lib/conformer.py index c732d372..146ef857 100644 --- a/init2winit/model_lib/conformer.py +++ b/init2winit/model_lib/conformer.py @@ -46,25 +46,8 @@ MLCOMMONS_DEFAULT_HPARAMS = config_dict.ConfigDict( dict( activation_function='swish', - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - batch_size=256, eval_batch_size=128, - l2_decay_factor=1e-6, - l2_decay_rank_threshold=0, - use_shallue_label_smoothing=False, - rng_seed=-1, model_dtype='float32', - grad_clip=5.0, encoder_dim=512, num_attention_heads=8, num_encoder_layers=16, @@ -84,7 +67,6 @@ enable_decoder_pre_layer_norm=True, enable_conformer_post_layer_norm=True, use_lingvo_attention=False, - total_accumulated_batch_size=None, attn_temperature=1.0, )) @@ -92,25 +74,8 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( activation_function='swish', - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - batch_size=256, eval_batch_size=128, - l2_decay_factor=1e-6, - l2_decay_rank_threshold=0, - use_shallue_label_smoothing=False, - rng_seed=-1, model_dtype='float32', - grad_clip=5.0, encoder_dim=512, num_attention_heads=8, num_encoder_layers=16, @@ -128,7 +93,6 @@ enable_decoder_pre_layer_norm=True, enable_conformer_post_layer_norm=True, use_lingvo_attention=False, - total_accumulated_batch_size=None, attn_temperature=1.0)) diff --git 
a/init2winit/model_lib/convolutional_autoencoder.py b/init2winit/model_lib/convolutional_autoencoder.py index f04fc459..607b9bf5 100644 --- a/init2winit/model_lib/convolutional_autoencoder.py +++ b/init2winit/model_lib/convolutional_autoencoder.py @@ -49,25 +49,8 @@ 'paddings': ['SAME', ((1, 0), (1, 0)), 'SAME', 'SAME'], 'activations': ['relu', 'relu', 'relu', 'id'], }, - activation_function='relu', - lr_hparams={ - 'base_lr': 0.02, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - optimizer='momentum', - opt_hparams={ - 'momentum': 0, - }, - batch_size=128, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, )) diff --git a/init2winit/model_lib/deepspeech.py b/init2winit/model_lib/deepspeech.py index 3c4fba49..89b33b8b 100644 --- a/init2winit/model_lib/deepspeech.py +++ b/init2winit/model_lib/deepspeech.py @@ -47,21 +47,8 @@ MLCOMMONS_DEFAULT_HPARAMS = config_dict.ConfigDict( dict( activation='relu', - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - batch_size=256, eval_batch_size=128, - l2_decay_factor=1e-6, - l2_decay_rank_threshold=0, - use_shallue_label_smoothing=False, - rng_seed=-1, model_dtype='float32', - grad_clip=10.0, num_lstm_layers=4, num_ffn_layers=3, encoder_dim=512, @@ -79,7 +66,6 @@ enable_residual_connections=False, enable_decoder_layer_norm=False, bidirectional=True, - total_accumulated_batch_size=None, enable_subsampling_batchnorm=False, enable_synced_batchnorm=False, layernorm_everywhere=False)) @@ -88,21 +74,8 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( activation='relu', - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - batch_size=256, eval_batch_size=128, - l2_decay_factor=1e-6, - l2_decay_rank_threshold=0, - use_shallue_label_smoothing=False, - rng_seed=-1, model_dtype='float32', - grad_clip=10.0, num_lstm_layers=4, num_ffn_layers=3, encoder_dim=512, @@ -119,7 +92,6 @@ enable_residual_connections=False, enable_decoder_layer_norm=False, bidirectional=True, - total_accumulated_batch_size=None, enable_subsampling_batchnorm=False, enable_synced_batchnorm=False, layernorm_everywhere=False)) diff --git a/init2winit/model_lib/dlrm.py b/init2winit/model_lib/dlrm.py index b340ab56..786b37f5 100644 --- a/init2winit/model_lib/dlrm.py +++ b/init2winit/model_lib/dlrm.py @@ -33,7 +33,6 @@ dict( activation_function='relu', embedding_init_multiplier=None, - rng_seed=-1, model_dtype='float32', vocab_size=32 * 128 * 1024, mlp_bottom_dims=[128, 128], @@ -41,22 +40,7 @@ output_shape=(1,), embed_dim=64, keep_diags=True, - optimizer='adam', - batch_size=128, num_dense_features=13, - lr_hparams={ - 'base_lr': 0.01, - 'schedule': 'constant' - }, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - }, - l2_decay_factor=1e-5, - l2_decay_rank_threshold=2, - total_accumulated_batch_size=None, - grad_clip=None, dropout_rate=0.0, normalizer='none', # dropout will exist only if there are at least two top mlp layers diff --git a/init2winit/model_lib/fully_connected.py b/init2winit/model_lib/fully_connected.py index 641e057e..4bd60ef9 100644 --- a/init2winit/model_lib/fully_connected.py +++ b/init2winit/model_lib/fully_connected.py @@ -30,26 +30,10 @@ dict( hid_sizes=[20, 10], kernel_scales=[1.0, 1.0, 1.0], - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - 
optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - total_accumulated_batch_size=None, activation_function='relu', - l2_decay_factor=.0005, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, - )) + ) +) class FullyConnected(nn.Module): diff --git a/init2winit/model_lib/gnn.py b/init2winit/model_lib/gnn.py index 33ff062e..9f7da0b7 100644 --- a/init2winit/model_lib/gnn.py +++ b/init2winit/model_lib/gnn.py @@ -37,32 +37,13 @@ # small hparams used for unit tests DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - rng_seed=-1, model_dtype='float32', latent_dim=256, - optimizer='adam', hidden_dims=(256,), - batch_size=256, - lr_hparams={ - 'base_lr': 0.01, - 'schedule': 'constant' - }, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0.0, - }, activation_function='relu', - l2_decay_factor=.0005, - l2_decay_rank_threshold=2, num_message_passing_steps=5, normalizer='layer_norm', dropout_rate=0.1, - total_accumulated_batch_size=None, - grad_clip=None, - label_smoothing=0.0, - use_shallue_label_smoothing=False, )) diff --git a/init2winit/model_lib/local_attention_transformer.py b/init2winit/model_lib/local_attention_transformer.py index 9709da74..b489cd01 100644 --- a/init2winit/model_lib/local_attention_transformer.py +++ b/init2winit/model_lib/local_attention_transformer.py @@ -79,31 +79,8 @@ feedforward_dropout=0.0, feedforward_depths=[4096, 1032], model_dtype='float32', - batch_size=8, - grad_clip=None, - lr_hparams={ - 'base_lr': 0.01, - 'defer_steps': 10000, - 'schedule': 't2t_rsqrt_normalized_decay', - }, - optimizer='adafactor', - opt_hparams={ - 'adafactor_decay_rate': 0.8, - 'clipping_threshold': 1.0, - 'factored': True, - 'min_dim_size_to_factor': 128, - # The 2 hyperparameters cause errors with optax.inject_hyperparams - # In this case it is not relevant since the default - # adafactors values are needed - # 'adafactor_momentum': 0.0, - # 'multiply_by_parameter_scale': True, - }, - # Below hyperparameters needed only to make the model - # compatible with init2winit library - rng_seed=-1, - label_smoothing=None, - weight_decay=None, - l2_decay_factor=None,)) + ) +) Tensor = Union[np.array, jnp.ndarray] diff --git a/init2winit/model_lib/lstm_lm.py b/init2winit/model_lib/lstm_lm.py index f4276131..5bc2b554 100644 --- a/init2winit/model_lib/lstm_lm.py +++ b/init2winit/model_lib/lstm_lm.py @@ -35,9 +35,6 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - # training params - batch_size=256, - rng_seed=-1, # model architecture params model_dtype='float32', bidirectional=False, @@ -49,20 +46,6 @@ recurrent_dropout_rate=0.1, tie_embeddings=False, projection_layer=False, - # optimizer params - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 1e-3, - }, - l2_decay_factor=None, - grad_clip=None, - optimizer='adam', - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0, - }, ) ) diff --git a/init2winit/model_lib/max_pooling_cnn.py b/init2winit/model_lib/max_pooling_cnn.py index 1520c7a9..d73d0a64 100644 --- a/init2winit/model_lib/max_pooling_cnn.py +++ b/init2winit/model_lib/max_pooling_cnn.py @@ -38,26 +38,9 @@ window_paddings=['SAME', 'SAME', 'SAME'], strides=[2, 2, 2], num_dense_units=[512, 256], - lr_hparams={ - 'base_lr': 0.001, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, activation_fn='relu', 
normalizer='none', - l2_decay_factor=.0005, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, - total_accumulated_batch_size=None, )) diff --git a/init2winit/model_lib/mdlm_rope_nanodo.py b/init2winit/model_lib/mdlm_rope_nanodo.py index 97815764..74c3d80b 100644 --- a/init2winit/model_lib/mdlm_rope_nanodo.py +++ b/init2winit/model_lib/mdlm_rope_nanodo.py @@ -37,23 +37,8 @@ num_heads=8, num_layers=12, mlp_dim=2048, - rng_seed=-1, computation_dtype='bfloat16', model_dtype='float32', - optimizer='adam', - batch_size=256, - lr_hparams={'base_lr': 0.01, 'schedule': 'constant'}, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0.0, - }, - l2_decay_factor=0.0005, - l2_decay_rank_threshold=2, - grad_clip=None, - label_smoothing=0.0, - use_shallue_label_smoothing=False, normalization='rmsnorm', mlp_activation='glu', qk_norm=True, diff --git a/init2winit/model_lib/mlperf_resnet.py b/init2winit/model_lib/mlperf_resnet.py index 59c8d562..5bbf85f3 100644 --- a/init2winit/model_lib/mlperf_resnet.py +++ b/init2winit/model_lib/mlperf_resnet.py @@ -30,33 +30,10 @@ FAKE_MODEL_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( num_filters=16, num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - layer_rescale_factors={}, - lr_hparams={ - 'batch_size': 128, - 'base_lr': 10.0, - 'decay_end': -1, - 'end_lr': 1e-4, - 'power': 2.0, - 'schedule': 'mlperf_polynomial', - 'start_lr': 0.0, - 'steps_per_epoch': 10009.250000000002, - 'warmup_steps': 18, - }, - optimizer='mlperf_lars_resnet', - opt_hparams={ - 'weight_decay': 2e-4, - 'beta': 0.9 - }, - batch_size=128, - l2_decay_factor=None, - l2_decay_rank_threshold=2, - label_smoothing=.1, - use_shallue_label_smoothing=False, model_dtype='float32', virtual_batch_size=64, data_format='NHWC', activation_function='relu', - grad_clip=None, dropout_rate=0.0, )) @@ -66,30 +43,13 @@ num_filters=16, # We set default to 18 for faster unit tests. 
num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, bn_output_scale=0.0, - l2_decay_factor=None, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, batch_norm_momentum=0.9, batch_norm_epsilon=1e-5, model_dtype='float32', virtual_batch_size=64, - total_accumulated_batch_size=None, data_format='NHWC', activation_function='relu', - grad_clip=None, dropout_rate=0.0, )) diff --git a/init2winit/model_lib/nanodo.py b/init2winit/model_lib/nanodo.py index 99fb0a0e..dac53a55 100644 --- a/init2winit/model_lib/nanodo.py +++ b/init2winit/model_lib/nanodo.py @@ -41,23 +41,8 @@ num_heads=8, # num attention heads num_layers=6, # number of transformer block layers mlp_dim=2048, # FF inner dimension - rng_seed=-1, computation_dtype='bfloat16', model_dtype='float32', - optimizer='adam', - batch_size=256, - lr_hparams={'base_lr': 0.01, 'schedule': 'constant'}, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0.0, - }, - l2_decay_factor=0.0005, - l2_decay_rank_threshold=2, - grad_clip=None, - label_smoothing=0.0, - use_shallue_label_smoothing=False, ) ) diff --git a/init2winit/model_lib/nqm.py b/init2winit/model_lib/nqm.py index 818bddf7..baa8a350 100644 --- a/init2winit/model_lib/nqm.py +++ b/init2winit/model_lib/nqm.py @@ -27,22 +27,11 @@ # small hparams used for unit tests DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - optimizer='momentum', - opt_hparams={ - 'momentum': 0.0, - }, - lr_hparams={ - 'base_lr': 0.1, - 'schedule': 'constant' - }, - batch_size=128, - rng_seed=-1, # Note the dimension is set by input_shape. hessian_decay_power=1, noise_decay_power=1, nqm_mode='diagH_diagC', model_dtype='float32', - l2_decay_factor=None, )) diff --git a/init2winit/model_lib/resnet.py b/init2winit/model_lib/resnet.py index ae312232..4198259f 100644 --- a/init2winit/model_lib/resnet.py +++ b/init2winit/model_lib/resnet.py @@ -29,33 +29,16 @@ DEFAULT_HPARAMS = config_dict.ConfigDict(dict( num_filters=16, num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - layer_rescale_factors={}, - lr_hparams={ - 'schedule': 'constant', - 'base_lr': 0.2, - }, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, batch_norm_momentum=0.9, batch_norm_epsilon=1e-5, # Make this a string to avoid having to import jnp into the configs. 
model_dtype='float32', virtual_batch_size=64, - total_accumulated_batch_size=None, data_format='NHWC', block_type='post_activation', # either pre_activation or post_activation bn_relu_conv=True, # only used for block_type='pre_activation' use_bn=True, dropout_rate=0.0, - grad_clip=None, activation_function='relu', extra_norm_on_residual=False, )) diff --git a/init2winit/model_lib/rope_nanodo.py b/init2winit/model_lib/rope_nanodo.py index 0b43ac8d..803cfff9 100644 --- a/init2winit/model_lib/rope_nanodo.py +++ b/init2winit/model_lib/rope_nanodo.py @@ -42,23 +42,8 @@ num_heads=8, # num attention heads num_layers=12, # number of transformer block layers mlp_dim=2048, # FF inner dimension - rng_seed=-1, computation_dtype='bfloat16', model_dtype='float32', - optimizer='adam', - batch_size=256, - lr_hparams={'base_lr': 0.01, 'schedule': 'constant'}, - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 0.0, - }, - l2_decay_factor=0.0005, - l2_decay_rank_threshold=2, - grad_clip=None, - label_smoothing=0.0, - use_shallue_label_smoothing=False, normalization='rmsnorm', mlp_activation='glu', qk_norm=True, diff --git a/init2winit/model_lib/simple_cnn.py b/init2winit/model_lib/simple_cnn.py index 9cbdb5e6..e1d35b53 100644 --- a/init2winit/model_lib/simple_cnn.py +++ b/init2winit/model_lib/simple_cnn.py @@ -29,22 +29,7 @@ DEFAULT_HPARAMS = config_dict.ConfigDict(dict( num_filters=[20, 10], kernel_sizes=[3, 3], - lr_hparams={ - 'base_lr': 0.001, - 'schedule': 'constant' - }, - layer_rescale_factors={}, - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, activation_function='relu', - l2_decay_factor=.0005, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', )) diff --git a/init2winit/model_lib/test_models.py b/init2winit/model_lib/test_models.py index 5f7155c0..c3decf5d 100644 --- a/init2winit/model_lib/test_models.py +++ b/init2winit/model_lib/test_models.py @@ -25,6 +25,7 @@ from absl.testing import absltest from absl.testing import parameterized import flax.linen as nn +from init2winit import hyperparameters from init2winit.init_lib import initializers from init2winit.model_lib import model_utils from init2winit.model_lib import models @@ -425,7 +426,8 @@ def _get_fake_inputs_for_initialization(model, hps): def _initialize_model(model_str, model_dtype): """Initialize a model given a registry name and dtype.""" model_cls = models.get_model(model_str) - hps = models.get_model_hparams(model_str) + hps = copy.deepcopy(hyperparameters.DEFAULT_TRAINING_HPARAMS) + hps.update(models.get_model_hparams(model_str)) hps.update(DATA_HPS[model_str]) if 'input_edge_shape' in hps and 'input_node_shape' in hps: hps.input_shape = (hps.input_node_shape, hps.input_edge_shape) @@ -462,7 +464,8 @@ def test_classification_models(self, model_str): model_hps = models.get_model_hparams(model_str) loss = 'cross_entropy' metrics = 'classification_metrics' - hps = copy.copy(model_hps) + hps = copy.deepcopy(hyperparameters.DEFAULT_TRAINING_HPARAMS) + hps.update(model_hps) hps.update({'output_shape': OUTPUT_SHAPE['classification']}) rng = jax.random.PRNGKey(0) dropout_rng, params_rng = jax.random.split(rng) diff --git a/init2winit/model_lib/transformer_lm.py b/init2winit/model_lib/transformer_lm.py index 0e310ff8..f2fb4b3d 100644 --- a/init2winit/model_lib/transformer_lm.py +++ b/init2winit/model_lib/transformer_lm.py @@ -37,7 +37,6 @@ # These reproduce the flax example. 
DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=512, emb_dim=128, num_heads=8, num_layers=6, @@ -45,28 +44,8 @@ mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.1 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.0016, - 'warmup_steps': 1000, - 'squash_steps': 1000, - 'schedule': 'rsqrt_normalized_decay_warmup' - }, - label_smoothing=None, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, decode=False, normalize_attention=False, )) diff --git a/init2winit/model_lib/transformer_stu_lm.py b/init2winit/model_lib/transformer_stu_lm.py index 07e41acf..2a5d2fcd 100644 --- a/init2winit/model_lib/transformer_stu_lm.py +++ b/init2winit/model_lib/transformer_stu_lm.py @@ -40,7 +40,6 @@ # These reproduce the flax example. DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=512, emb_dim=128, num_heads=8, num_layers=6, @@ -48,28 +47,8 @@ mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.1 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.0016, - 'warmup_steps': 1000, - 'squash_steps': 1000, - 'schedule': 'rsqrt_normalized_decay_warmup' - }, - label_smoothing=None, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, decode=False, normalize_attention=False, input_len=128, diff --git a/init2winit/model_lib/transformer_stu_tensordot_lm.py b/init2winit/model_lib/transformer_stu_tensordot_lm.py index 42a1ae52..351cde4d 100644 --- a/init2winit/model_lib/transformer_stu_tensordot_lm.py +++ b/init2winit/model_lib/transformer_stu_tensordot_lm.py @@ -40,7 +40,6 @@ # These reproduce the flax example. DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=512, emb_dim=128, num_heads=8, num_layers=6, @@ -48,28 +47,8 @@ mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.1 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.0016, - 'warmup_steps': 1000, - 'squash_steps': 1000, - 'schedule': 'rsqrt_normalized_decay_warmup' - }, - label_smoothing=None, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, decode=False, normalize_attention=False, input_len=128, diff --git a/init2winit/model_lib/unet.py b/init2winit/model_lib/unet.py index 3a4e3741..b12b4982 100644 --- a/init2winit/model_lib/unet.py +++ b/init2winit/model_lib/unet.py @@ -33,43 +33,6 @@ from ml_collections import config_dict -# NOTE(dsuo): We use the Kitchen Sink optimizer to match the RMSProp -# implementation found in the reference FastMRI U-Net code. Specifically, -# epsilon in optax's scale_by_rms places its epsilon inside the square root, -# whereas the reference code epsilon outside. -opt_hparams = { - 'weight_decay': 0.0, - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, -} - -# NOTE(dsuo): This lives here because decay_events / decay_factors is too large -# to pass via the config file. 
-_FASTMRI_TRAIN_SIZE = 34742 -_FASTMRI_VALID_SIZE = 7135 - -batch_size = 8 -num_epochs = 50 -steps_per_epoch = int(_FASTMRI_TRAIN_SIZE / batch_size) -num_train_steps = num_epochs * steps_per_epoch -lr_gamma = 0.1 -lr_step_size = 40 * steps_per_epoch -decay_events = list(range(lr_step_size, num_train_steps, lr_step_size)) -decay_factors = [lr_gamma] * len(decay_events) -decay_factors = [ - decay_factor**i - for decay_factor, i in zip(decay_factors, range(1, - len(decay_events) + 1)) -] - -lr_hparams = { - 'schedule': 'piecewise_constant', - 'base_lr': 1e-3, - 'decay_events': decay_events, - 'decay_factors': decay_factors -} - DEFAULT_HPARAMS = config_dict.ConfigDict( dict( out_chans=1, @@ -77,15 +40,7 @@ num_pool_layers=4, dropout_rate=0.0, activation='leaky_relu', - optimizer='adam', - opt_hparams=opt_hparams, - lr_hparams=lr_hparams, - l2_decay_factor=None, - batch_size=batch_size, - rng_seed=-1, model_dtype='float32', - grad_clip=None, - total_accumulated_batch_size=None, normalizer='unet_instance_norm', )) diff --git a/init2winit/model_lib/vit.py b/init2winit/model_lib/vit.py index 6bcad8f9..275f9dcc 100644 --- a/init2winit/model_lib/vit.py +++ b/init2winit/model_lib/vit.py @@ -43,27 +43,8 @@ pool_type='gap', posemb='sincos2d', head_zeroinit=True, - lr_hparams={ - 'base_lr': 1e-3, - 'schedule': 'cosine_warmup', - }, - optimizer='adam', - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.999, - 'epsilon': 1e-8, - 'weight_decay': 1e-1, - }, - l2_decay_factor=None, - l2_decay_rank_threshold=2, - batch_size=1024, - rng_seed=-1, model_dtype='float32', - grad_clip=None, - total_accumulated_batch_size=None, dropout_rate=0.0, - label_smoothing=0.0, - use_shallue_label_smoothing=False, normalizer='pre_layer_norm', activation='gelu', resnet_style_residual=False, diff --git a/init2winit/model_lib/wide_resnet.py b/init2winit/model_lib/wide_resnet.py index 8df134a9..f9379b1f 100644 --- a/init2winit/model_lib/wide_resnet.py +++ b/init2winit/model_lib/wide_resnet.py @@ -30,31 +30,14 @@ dict( blocks_per_group=3, channel_multiplier=2, - lr_hparams={ - 'base_lr': 0.001, - 'schedule': 'cosine' - }, normalizer='batch_norm', - layer_rescale_factors={}, conv_kernel_scale=1.0, dense_kernel_scale=1.0, dropout_rate=0.0, conv_kernel_init='lecun_normal', dense_kernel_init='lecun_normal', - optimizer='momentum', - opt_hparams={ - 'momentum': 0.9, - }, - batch_size=128, virtual_batch_size=None, - total_accumulated_batch_size=None, - l2_decay_factor=0.0001, - l2_decay_rank_threshold=2, - label_smoothing=None, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, activation_function='relu', group_strides=[(1, 1), (2, 2), (2, 2)]) ) diff --git a/init2winit/model_lib/xformer_translate.py b/init2winit/model_lib/xformer_translate.py index 348117b8..ff7e5984 100644 --- a/init2winit/model_lib/xformer_translate.py +++ b/init2winit/model_lib/xformer_translate.py @@ -42,7 +42,6 @@ MLCOMMONS_DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=64, share_embeddings=False, logits_via_embedding=False, emb_dim=512, @@ -56,40 +55,18 @@ dropout_rate=0.1, aux_dropout_rate=0.1, tie_dropouts=False, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.05, - 'warmup_steps': 8000, - 'factors': 'constant * linear_warmup * rsqrt_decay', - 'schedule': 'compound' - }, - label_smoothing=0.1, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - 
use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, enc_self_attn_kernel_init='xavier_uniform', dec_self_attn_kernel_init='xavier_uniform', dec_cross_attn_kernel_init='xavier_uniform', decode=False, - total_accumulated_batch_size=None, normalize_attention=False, )) DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=64, share_embeddings=False, logits_via_embedding=False, emb_dim=512, @@ -102,33 +79,12 @@ mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.05, - 'warmup_steps': 8000, - 'factors': 'constant * linear_warmup * rsqrt_decay', - 'schedule': 'compound' - }, - label_smoothing=0.1, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, enc_self_attn_kernel_init='xavier_uniform', dec_self_attn_kernel_init='xavier_uniform', dec_cross_attn_kernel_init='xavier_uniform', decode=False, - total_accumulated_batch_size=None, normalize_attention=False, )) diff --git a/init2winit/model_lib/xformer_translate_binary.py b/init2winit/model_lib/xformer_translate_binary.py index bd5b4fbd..ace8d5e4 100644 --- a/init2winit/model_lib/xformer_translate_binary.py +++ b/init2winit/model_lib/xformer_translate_binary.py @@ -43,7 +43,6 @@ def _default_binarize_hparams(): DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=64, share_embeddings=False, logits_via_embedding=False, emb_dim=512, @@ -54,33 +53,12 @@ def _default_binarize_hparams(): mlp_dim=512, dropout_rate=0.1, attention_dropout_rate=0.1, - optimizer='adam', - opt_hparams={ - 'beta1': 0.9, - 'beta2': 0.98, - 'epsilon': 1e-9, - 'weight_decay': 0.0, - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.05, - 'warmup_steps': 8000, - 'factors': 'constant * linear_warmup * rsqrt_decay', - 'schedule': 'compound', - }, - label_smoothing=0.1, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, enc_self_attn_kernel_init='xavier_uniform', dec_self_attn_kernel_init='xavier_uniform', dec_cross_attn_kernel_init='xavier_uniform', decode=False, - total_accumulated_batch_size=None, binarize_hparams=_default_binarize_hparams(), quant_steps={ # training step at which model is partially binarized 'ff_weights': 90e3, diff --git a/init2winit/model_lib/xformer_translate_mlc_variant.py b/init2winit/model_lib/xformer_translate_mlc_variant.py index 9e825328..6cea66c1 100644 --- a/init2winit/model_lib/xformer_translate_mlc_variant.py +++ b/init2winit/model_lib/xformer_translate_mlc_variant.py @@ -42,7 +42,6 @@ DEFAULT_HPARAMS = config_dict.ConfigDict( dict( - batch_size=64, share_embeddings=False, logits_via_embedding=False, emb_dim=512, @@ -56,34 +55,13 @@ dropout_rate=0.1, aux_dropout_rate=0.1, tie_dropouts=False, - optimizer='adam', - opt_hparams={ - 'beta1': .9, - 'beta2': .98, - 'epsilon': 1e-9, - 'weight_decay': 0.0 - }, - layer_rescale_factors={}, normalizer='layer_norm', - lr_hparams={ - 'base_lr': 0.05, - 'warmup_steps': 8000, - 'factors': 'constant * linear_warmup * rsqrt_decay', - 'schedule': 'compound' - }, - label_smoothing=0.1, - l2_decay_factor=None, - l2_decay_rank_threshold=0, - rng_seed=-1, - use_shallue_label_smoothing=False, model_dtype='float32', - grad_clip=None, 
enc_self_attn_kernel_init='xavier_uniform', dec_self_attn_kernel_init='xavier_uniform', dec_cross_attn_kernel_init='xavier_uniform', attn_kernel_scale=1.0, decode=False, - total_accumulated_batch_size=None, normalize_attention=False, glu=False, ffn_activation='relu', diff --git a/init2winit/test_hyperparameters.py b/init2winit/test_hyperparameters.py index 0b85286e..43419351 100644 --- a/init2winit/test_hyperparameters.py +++ b/init2winit/test_hyperparameters.py @@ -84,18 +84,14 @@ def test_dot_override(self): hparam_overrides=hps_overrides, ) - self.assertEqual( - merged_hps.lr_hparams['schedule'], 'rsqrt_normalized_decay_warmup' - ) + self.assertEqual(merged_hps.lr_hparams['schedule'], 'constant') expected_lr_hparams = { 'base_lr': 77.0, - 'warmup_steps': 1000, - 'squash_steps': 1000, - 'schedule': 'rsqrt_normalized_decay_warmup', + 'schedule': 'constant', } self.assertEqual( set(merged_hps.lr_hparams.keys()), - set(['schedule', 'warmup_steps', 'base_lr', 'squash_steps']), + set(['schedule', 'base_lr']), ) self.assertEqual(merged_hps.lr_hparams.to_dict(), expected_lr_hparams) diff --git a/init2winit/trainer_lib/test_trainer.py b/init2winit/trainer_lib/test_trainer.py index ed126a15..2d02fa47 100644 --- a/init2winit/trainer_lib/test_trainer.py +++ b/init2winit/trainer_lib/test_trainer.py @@ -389,7 +389,8 @@ def test_graph_model_trainer(self): rng = jax.random.PRNGKey(1337) model_str = 'gnn' model_cls = models.get_model(model_str) - hps = models.get_model_hparams(model_str) + hps = copy.deepcopy(hyperparameters.DEFAULT_TRAINING_HPARAMS) + hps.update(models.get_model_hparams(model_str)) hps.update({ 'batch_size': 2, 'input_edge_shape': (7,), @@ -455,7 +456,8 @@ def test_dlrm_model_trainer(self): model_str = 'dlrm' dataset_str = 'criteo1tb' model_cls = models.get_model(model_str) - model_hps = models.get_model_hparams(model_str) + model_hps = copy.deepcopy(hyperparameters.DEFAULT_TRAINING_HPARAMS) + model_hps.update(models.get_model_hparams(model_str)) model_hps.vocab_size = 1024 dataset_hps = datasets.get_dataset_hparams(dataset_str) dataset_hps.update({ @@ -961,9 +963,10 @@ def test_early_stopping(self, min_steps): initializer = initializers.get_initializer(initializer_name) dataset_builder = datasets.get_dataset(dataset_name) hparam_overrides = { - 'lr_hparams': { - 'base_lr': 0.1, - 'schedule': 'cosine' + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'cosine'}, + 'optimizer': 'momentum', + 'opt_hparams': { + 'momentum': 0.9, }, 'batch_size': 8, 'train_size': 160,
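
Usage note: this change moves the shared training/optimization defaults
(optimizer, opt_hparams, lr_hparams, batch_size, l2 decay, rng_seed, etc.) out
of the per-model DEFAULT_HPARAMS and into
hyperparameters.DEFAULT_TRAINING_HPARAMS, so code that previously built a
complete hparam set from models.get_model_hparams() alone must now merge the
shared defaults itself, as the updated tests do. A minimal sketch of both
patterns follows; the build_hparams keyword arguments and the
'fully_connected', 'noop', and 'mnist' registry names are illustrative
assumptions, not part of this change.

    import copy

    from init2winit import hyperparameters
    from init2winit.model_lib import models

    # Pattern used by the updated unit tests: start from the shared training
    # defaults, then layer the model's architecture hparams on top.
    hps = copy.deepcopy(hyperparameters.DEFAULT_TRAINING_HPARAMS)
    hps.update(models.get_model_hparams('fully_connected'))  # hypothetical model name

    # build_hparams merges training, model, initializer, dataset, and input
    # pipeline hparams; a key defined by more than one source now raises
    # ValueError naming the overlapping keys instead of the old generic message.
    merged_hps = hyperparameters.build_hparams(
        model_name='fully_connected',  # hypothetical registry names
        initializer_name='noop',
        dataset_name='mnist',
        hparam_file=None,
        hparam_overrides={'lr_hparams.base_lr': 0.05},  # dotted keys expand via expand_key
    )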