diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index a2d52dd033..f95fb69a35 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -436,7 +436,7 @@ def get_remat_policy(self): ) else: assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" - policy = None + policy = jax.checkpoint_policies.nothing_saveable return policy def get_decoder_layers(self):