From 21dfe5d22cb272f7d600b6e9578337d858f5e2b7 Mon Sep 17 00:00:00 2001 From: learning-to-play <66660475+learning-to-play@users.noreply.github.com> Date: Sat, 16 May 2026 18:21:48 -0700 Subject: [PATCH] Gemma and Gemma 2 Native Serving Registration Maps GemmaForCausalLM and Gemma2ForCausalLM to the high-performance JAX-native adapter, preventing vLLM from falling back to PyTorch and crashing or running inefficiently. --- src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py index 216737a7fc..5f9474ea0b 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py @@ -30,4 +30,6 @@ def register(): """ logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.") register_model("MaxTextForCausalLM", MaxTextForCausalLM) - logger.info("Successfully registered MaxTextForCausalLM model.") + register_model("GemmaForCausalLM", MaxTextForCausalLM) + register_model("Gemma2ForCausalLM", MaxTextForCausalLM) + logger.info("Successfully registered MaxTextForCausalLM, Gemma2ForCausalLM, and GemmaForCausalLM model.")