Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions src/maxtext/configs/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import omegaconf

from maxtext.configs import pyconfig_deprecated
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
from maxtext.common.common_types import DecoderBlockType, ShardMode
from maxtext.configs import types
from maxtext.configs.types import MaxTextConfig
Expand All @@ -46,6 +46,49 @@
# Config keys that must not be logged.
# NOTE(review): presumably excluded because the value is a credential — confirm.
KEYS_NO_LOGGING = ("hf_access_token",)

# Module paths to their default config file (relative to MAXTEXT_CONFIGS_DIR).
_CONFIG_FILE_MAPPING: dict[str, str] = {
"maxtext.trainers.pre_train.train": "base.yml",
"maxtext.trainers.pre_train.train_compile": "base.yml",
"maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml",
"maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
"maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
"maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml",
"maxtext.inference.decode": "base.yml",
"maxtext.inference.decode_multi": "base.yml",
"maxtext.inference.inference_microbenchmark": "base.yml",
"maxtext.inference.inference_microbenchmark_sweep": "base.yml",
"maxtext.inference.maxengine.maxengine_server": "base.yml",
"maxtext.inference.mlperf.microbenchmarks.benchmark_chunked_prefill": "base.yml",
"maxtext.inference.vllm_decode": "base.yml",
"maxtext.checkpoint_conversion.to_maxtext": "base.yml",
"maxtext.checkpoint_conversion.to_huggingface": "base.yml",
}


def _module_from_path(path: str) -> str | None:
  """Translate a source file path into a dotted module path.

  Both `path` and the directory containing `MAXTEXT_PKG_DIR` are resolved with
  `os.path.realpath` before they are compared.

  Args:
    path: File path of a Python source file.

  Returns:
    The path relative to the parent directory of `MAXTEXT_PKG_DIR`, with path
    separators replaced by dots and a trailing ".py" removed; `None` when
    `path` does not lie under that parent directory.
  """
  package_root = os.path.realpath(os.path.dirname(MAXTEXT_PKG_DIR))
  resolved = os.path.realpath(path)
  if not resolved.startswith(package_root + os.sep):
    return None
  dotted = os.path.relpath(resolved, package_root).replace(os.sep, ".")
  return dotted.removesuffix(".py")


def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
"""Resolves or infers config file path from module."""
if len(argv) >= 2 and argv[1].endswith(".yml"):
return resolve_config_path(argv[1]), argv[2:]
module = _module_from_path(argv[0])
if module not in _CONFIG_FILE_MAPPING:
raise ValueError(
f"No config file provided and no default config found for module '{module}'"
)
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
logger.warning("No config file provided, using default config mapping: %s", config_path)
return config_path, argv[1:]


def yaml_key_to_env_key(s: str) -> str:
  """Return `_MAX_PREFIX` followed by the upper-cased form of `s`."""
  upper_key = s.upper()
  return _MAX_PREFIX + upper_key
Expand Down Expand Up @@ -227,11 +270,11 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
"""
# 1. Load base and inherited configs from file(s)
config_path = resolve_config_path(argv[1])
config_path, cli_args = _resolve_or_infer_config(argv)
base_yml_config = _load_config(config_path)

# 2. Get overrides from CLI and kwargs
cli_cfg = omegaconf.OmegaConf.from_cli(argv[2:])
cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)

Expand Down
17 changes: 16 additions & 1 deletion tests/unit/pyconfig_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import unittest

from maxtext.configs import pyconfig
from maxtext.configs.pyconfig import resolve_config_path
from maxtext.configs.pyconfig import resolve_config_path, _CONFIG_FILE_MAPPING, _module_from_path
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
from tests.utils.test_helpers import get_test_config_path, get_post_train_test_config_path

Expand Down Expand Up @@ -115,6 +115,21 @@ def test_resolve_config_path_pip_install(self):
finally:
os.chdir(orig)

def test_config_file_mapping(self):
  """Every default config named in `_CONFIG_FILE_MAPPING` exists under `MAXTEXT_CONFIGS_DIR`."""
  for module, config_name in _CONFIG_FILE_MAPPING.items():
    full_path = os.path.join(MAXTEXT_CONFIGS_DIR, config_name)
    is_present = os.path.isfile(full_path)
    self.assertTrue(is_present, f"Default config for '{module}' not found at {full_path}")

def test_module_from_path(self):
  """The file of an importable maxtext module maps back to its dotted module path."""
  import maxtext.trainers.pre_train.train as pre_train_module

  inferred = _module_from_path(pre_train_module.__file__)
  self.assertEqual(inferred, "maxtext.trainers.pre_train.train")

def test_unknown_module_raises(self):
  """`initialize_pydantic` raises ValueError when no .yml argument is given and the entry point is unmapped."""
  unmapped_argv = ["/custom_rl/module.py", "run_name=test"]
  self.assertRaises(ValueError, pyconfig.initialize_pydantic, unmapped_argv)


if __name__ == "__main__":
unittest.main()
Loading