diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py index c2b1e5e5d2..67ebc88116 100644 --- a/src/maxtext/inference/vllm_decode.py +++ b/src/maxtext/inference/vllm_decode.py @@ -124,7 +124,7 @@ def decode_with_vllm(config: Config) -> None: token=config.hf_access_token, ) - prompts = [config.prompt] + prompts = [config.prompt] * int(config.per_device_batch_size) if config.use_chat_template: # Format the prompt using chat template if specified messages = [