From 54708cc851c3e07a926898ffb596b8a0c8e67de5 Mon Sep 17 00:00:00 2001 From: learning-to-play <66660475+learning-to-play@users.noreply.github.com> Date: Sat, 16 May 2026 18:09:26 -0700 Subject: [PATCH] Updates 'full' remat policy to use `nothing_saveable` Maps 'full' remat policy to `jax.checkpoint_policies.nothing_saveable` to fix previous mapping to None, which JAX interpreted as "save all activations." --- src/maxtext/layers/decoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index a2d52dd033..f95fb69a35 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -436,7 +436,7 @@ def get_remat_policy(self): ) else: assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" - policy = None + policy = jax.checkpoint_policies.nothing_saveable return policy def get_decoder_layers(self):