diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py index 216737a7fc..5f9474ea0b 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py @@ -30,4 +30,6 @@ def register(): """ logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.") register_model("MaxTextForCausalLM", MaxTextForCausalLM) - logger.info("Successfully registered MaxTextForCausalLM model.") + register_model("GemmaForCausalLM", MaxTextForCausalLM) + register_model("Gemma2ForCausalLM", MaxTextForCausalLM) + logger.info("Successfully registered MaxTextForCausalLM, Gemma2ForCausalLM, and GemmaForCausalLM model.")