9 changes: 7 additions & 2 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
@@ -595,9 +595,14 @@ def main(
   hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token, revision=revision)
   hf_model = get_hf_model(model_id, token=hf_token, revision=revision)
   hf_state_dict_numpy = hf_model.state_dict()
-  # Convert all to numpy immediately in eager mode
+  # Convert all to numpy immediately in eager mode.
+  # torch.Tensor.numpy() does not support bfloat16, so cast to float32 first.
+  import torch  # pylint: disable=g-import-not-at-top
@shuningjin (Collaborator), Mar 8, 2026:

move this import to the top?

@phu0ngng (Contributor, Author), Mar 9, 2026:

We can, but then torch would be required in the lazy-loading path as well, even though it is not needed there. I think we should keep the current implementation to avoid a torch requirement for lazy loading.

@phu0ngng (Contributor, Author):

@shuningjin what do you think?

@phu0ngng (Contributor, Author):

@shuningjin, friendly reminder.

   for k, v in hf_state_dict_numpy.items():
-    hf_state_dict_numpy[k] = v.numpy()
+    if v.dtype == torch.bfloat16:
+      hf_state_dict_numpy[k] = v.float().numpy()
+    else:
+      hf_state_dict_numpy[k] = v.numpy()
   del hf_model
   max_logging.log("HuggingFace model loaded and converted to NumPy.")
   print_ram_usage("After full HF model load")
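A minimal standalone sketch of the workaround in this diff, assuming torch and numpy are installed (the tensor names here are illustrative, not from the PR):

```python
import numpy as np
import torch

# Illustrative state dict mixing bfloat16 and float32 tensors.
state_dict = {
    "weight_bf16": torch.ones(2, 2, dtype=torch.bfloat16),
    "weight_f32": torch.ones(2, 2, dtype=torch.float32),
}

for k, v in state_dict.items():
  # torch.Tensor.numpy() raises a TypeError for bfloat16 tensors,
  # since NumPy has no native bfloat16 dtype; upcast to float32 first.
  if v.dtype == torch.bfloat16:
    state_dict[k] = v.float().numpy()
  else:
    state_dict[k] = v.numpy()
```

After the loop both entries are NumPy arrays; the bfloat16 tensor comes out as float32, which doubles its memory footprint but keeps the values exact (every bfloat16 value is representable in float32).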