diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py
index 663b76633a..a4edd5320b 100644
--- a/src/maxtext/configs/pyconfig.py
+++ b/src/maxtext/configs/pyconfig.py
@@ -29,7 +29,7 @@
 import omegaconf
 
 from maxtext.configs import pyconfig_deprecated
-from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
+from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
 from maxtext.common.common_types import DecoderBlockType, ShardMode
 from maxtext.configs import types
 from maxtext.configs.types import MaxTextConfig
@@ -46,6 +46,82 @@
 # Don't log the following keys.
 KEYS_NO_LOGGING = ("hf_access_token",)
 
+# Module paths to their default config file (relative to MAXTEXT_CONFIGS_DIR).
+_CONFIG_FILE_MAPPING: dict[str, str] = {
+    "maxtext.trainers.pre_train.train": "base.yml",
+    "maxtext.trainers.pre_train.train_compile": "base.yml",
+    "maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml",
+    "maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
+    "maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
+    "maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml",
+    "maxtext.inference.decode": "base.yml",
+    "maxtext.inference.decode_multi": "base.yml",
+    "maxtext.inference.inference_microbenchmark": "base.yml",
+    "maxtext.inference.inference_microbenchmark_sweep": "base.yml",
+    "maxtext.inference.maxengine.maxengine_server": "base.yml",
+    "maxtext.inference.mlperf.microbenchmarks.benchmark_chunked_prefill": "base.yml",
+    "maxtext.inference.vllm_decode": "base.yml",
+    "maxtext.checkpoint_conversion.to_maxtext": "base.yml",
+    "maxtext.checkpoint_conversion.to_huggingface": "base.yml",
+}
+
+# File suffixes that mark argv[1] as an explicit config file rather than a CLI override.
+_CONFIG_FILE_SUFFIXES = (".yml", ".yaml")
+
+
+def _module_from_path(path: str) -> str | None:
+  """Convert a script file path to its dotted module path for config inference.
+
+  Args:
+    path: File path of the running script, typically `argv[0]`.
+
+  Returns:
+    The dotted module path relative to the parent directory of MAXTEXT_PKG_DIR,
+    or None if `path` is empty or does not resolve to a file under that directory.
+  """
+  if not path:
+    return None
+  real_path = os.path.realpath(path)
+  pkg_parent = os.path.realpath(os.path.dirname(MAXTEXT_PKG_DIR))
+  try:
+    relative = os.path.relpath(real_path, pkg_parent)
+  except ValueError:
+    # os.path.relpath raises on Windows when the two paths are on different drives.
+    return None
+  # relpath expresses a location outside pkg_parent with a leading "..".
+  if relative in (os.curdir, os.pardir) or relative.startswith(os.pardir + os.sep):
+    return None
+  return relative.replace(os.sep, ".").removesuffix(".py")
+
+
+def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
+  """Resolve the config file named in argv, or infer the default one for the running module.
+
+  Args:
+    argv: Command-line arguments: the script path, optionally followed by a YAML config
+      path, then `key=value` overrides.
+
+  Returns:
+    A `(config_path, cli_args)` tuple, where `cli_args` are the arguments left after the
+    script path and, when given, the config path.
+
+  Raises:
+    ValueError: If no config file is given and `_CONFIG_FILE_MAPPING` has no default for
+      the module derived from `argv[0]`.
+  """
+  if len(argv) >= 2 and argv[1].endswith(_CONFIG_FILE_SUFFIXES):
+    return resolve_config_path(argv[1]), argv[2:]
+  # Tolerate an empty argv so the caller gets the ValueError below instead of an IndexError.
+  script = argv[0] if argv else ""
+  module = _module_from_path(script)
+  if module is None or module not in _CONFIG_FILE_MAPPING:
+    raise ValueError(
+        f"No config file provided and no default config found for module '{module}' (script: '{script}')"
+    )
+  config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
+  logger.warning("No config file provided, using default config mapping: %s", config_path)
+  return config_path, argv[1:]
+
 
 def yaml_key_to_env_key(s: str) -> str:
   return _MAX_PREFIX + s.upper()
@@ -227,11 +303,11 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
   Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
   """
   # 1. Load base and inherited configs from file(s)
-  config_path = resolve_config_path(argv[1])
+  config_path, cli_args = _resolve_or_infer_config(argv)
   base_yml_config = _load_config(config_path)
 
   # 2. Get overrides from CLI and kwargs
-  cli_cfg = omegaconf.OmegaConf.from_cli(argv[2:])
+  cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
   kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
   overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)
 
diff --git a/tests/unit/pyconfig_test.py b/tests/unit/pyconfig_test.py
index 3e7f7be975..9aed5cc195 100644
--- a/tests/unit/pyconfig_test.py
+++ b/tests/unit/pyconfig_test.py
@@ -19,7 +19,7 @@
 import unittest
 
 from maxtext.configs import pyconfig
-from maxtext.configs.pyconfig import resolve_config_path
+from maxtext.configs.pyconfig import resolve_config_path, _CONFIG_FILE_MAPPING, _module_from_path
 from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
 from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path
 
@@ -115,6 +115,28 @@ def test_resolve_config_path_pip_install(self):
     finally:
       os.chdir(orig)
 
+  def test_config_file_mapping(self):
+    for module, relative_path in _CONFIG_FILE_MAPPING.items():
+      full_path = os.path.join(MAXTEXT_CONFIGS_DIR, relative_path)
+      self.assertTrue(os.path.isfile(full_path), f"Default config for '{module}' not found at {full_path}")
+
+  def test_module_from_path(self):
+    import maxtext.trainers.pre_train.train as train_module
+    module_file = train_module.__file__
+    result = _module_from_path(module_file)
+    self.assertEqual(result, "maxtext.trainers.pre_train.train")
+
+  def test_unknown_module_raises(self):
+    with self.assertRaises(ValueError):
+      pyconfig.initialize_pydantic(["/custom_rl/module.py", "run_name=test"])
+
+  def test_module_from_path_empty(self):
+    self.assertIsNone(_module_from_path(""))
+
+  def test_empty_argv_raises(self):
+    with self.assertRaises(ValueError):
+      pyconfig.initialize_pydantic([])
+
 
 if __name__ == "__main__":
   unittest.main()