From 7bf11b5e5796f184cb36c5f52143f62ab0426813 Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Fri, 6 Mar 2026 16:26:59 -0800
Subject: [PATCH] fix bf16 type conversion

Signed-off-by: Phuong Nguyen
---
 src/maxtext/checkpoint_conversion/to_maxtext.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py
index 3292b33565..c48ece773a 100644
--- a/src/maxtext/checkpoint_conversion/to_maxtext.py
+++ b/src/maxtext/checkpoint_conversion/to_maxtext.py
@@ -595,9 +595,14 @@ def main(
   hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token, revision=revision)
   hf_model = get_hf_model(model_id, token=hf_token, revision=revision)
   hf_state_dict_numpy = hf_model.state_dict()
-  # Convert all to numpy immediately in eager mode
+  # Convert all to numpy immediately in eager mode.
+  # torch.Tensor.numpy() does not support bfloat16, so cast to float32 first.
+  import torch  # pylint: disable=g-import-not-at-top
   for k, v in hf_state_dict_numpy.items():
-    hf_state_dict_numpy[k] = v.numpy()
+    if v.dtype == torch.bfloat16:
+      hf_state_dict_numpy[k] = v.float().numpy()
+    else:
+      hf_state_dict_numpy[k] = v.numpy()
   del hf_model
   max_logging.log("HuggingFace model loaded and converted to NumPy.")
   print_ram_usage("After full HF model load")
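
Note (not part of the patch): a minimal standalone sketch of the failure mode
the new comment describes, assuming only a local PyTorch install:

  import torch

  # NumPy has no native bfloat16 dtype, so .numpy() raises a TypeError
  # for bfloat16 tensors.
  t = torch.ones(2, 2, dtype=torch.bfloat16)
  try:
    t.numpy()
  except TypeError as e:
    print(f"direct conversion fails: {e}")

  # Upcasting to float32 first is lossless (every bfloat16 value is
  # exactly representable in float32) and yields a plain NumPy array.
  arr = t.float().numpy()
  print(arr.dtype)  # float32

The float32 detour doubles the transient memory for each bfloat16 tensor, but
only one tensor is upcast at a time inside the loop, so peak RAM grows by at
most the size of the largest single parameter.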