Description
Bug report
When running a full end-to-end GRPO training pipeline using VllmRollout on a TPU v5e-8 slice, the rollout engine crashes during the first micro-batch step due to two consecutive API mismatches between tunix/rl/rl_cluster.py and tunix/rl/rollout/vllm_rollout.py.
Issue 1: Unexpected Keyword Argument completion_mask
RLCluster.get_old_per_token_logps passes a completion_mask argument down to the rollout engine, but the VllmRollout.get_per_token_logps method signature does not accept it.
Traceback:
File "/usr/local/lib/python3.12/site-packages/tunix/rl/grpo/grpo_learner.py", line 267, in _generate_and_compute_advantage
old_per_token_logps = self.rl_cluster.get_old_per_token_logps(...)
TypeError: VllmRollout.get_per_token_logps() got an unexpected keyword argument 'completion_mask'
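The mismatch can be reproduced without a TPU or vLLM. The sketch below uses two hypothetical stand-in classes, OldRollout and PatchedRollout. They are not the real tunix classes, and their positional parameters are placeholders. The sketch also assumes that completion_mask is a per-token 0/1 mask marking real (non-padding) completion tokens; the semantics RLCluster actually expects may differ. One possible fix is to give the rollout method an optional completion_mask parameter, so the existing call site in RLCluster.get_old_per_token_logps keeps working.

```python
from typing import List, Optional, Sequence


class OldRollout:
    """Hypothetical stand-in for the current signature (no completion_mask)."""

    def get_per_token_logps(self, prompt_tokens, completion_tokens):
        # Placeholder log-probs: one value per completion token.
        return [[-1.0] * len(seq) for seq in completion_tokens]


class PatchedRollout:
    """Hypothetical patched signature: completion_mask is accepted and optional."""

    def get_per_token_logps(
        self,
        prompt_tokens: Sequence[Sequence[int]],
        completion_tokens: Sequence[Sequence[int]],
        completion_mask: Optional[Sequence[Sequence[int]]] = None,
    ) -> List[List[float]]:
        # Placeholder log-probs: one value per completion token.
        logps = [[-1.0] * len(seq) for seq in completion_tokens]
        if completion_mask is not None:
            # Assumption: 1 marks a real token, 0 marks padding; padded positions get 0.0.
            logps = [
                [lp if m else 0.0 for lp, m in zip(row, mask_row)]
                for row, mask_row in zip(logps, completion_mask)
            ]
        return logps
```

Calling OldRollout().get_per_token_logps(..., completion_mask=...) raises the same kind of TypeError as in the traceback above. PatchedRollout accepts the keyword and still works when it is omitted.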
Issue 2: jnp.concatenate TypeError on Python Lists
If the completion_mask issue is bypassed, the pipeline crashes immediately afterwards on the returned values. VllmRollout returns native Python lists, but rl_cluster.py passes them directly to jnp.concatenate(). JAX API functions such as jnp.concatenate deliberately reject Python lists as array arguments. This is standard JAX behavior, not something specific to 0.4.25, so the call raises a TypeError instead of converting the lists to arrays.
Traceback:
File "/usr/local/lib/python3.12/site-packages/tunix/rl/rl_cluster.py", line 993, in get_old_per_token_logps
per_token_logps = jnp.concatenate(outs, axis=0)
TypeError: concatenate requires ndarray or scalar arguments, got <class 'list'> at position 0.
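One possible fix is to convert each micro-batch output to a JAX array before concatenating. The helper name concat_per_token_logps is hypothetical. The sketch assumes every element of outs is a rectangular [batch, seq_len] nested list or array, and that all elements share the same seq_len. Ragged outputs would need padding first.

```python
from typing import Sequence, Union

import jax.numpy as jnp
import numpy as np

# A micro-batch output may be a JAX array, a NumPy array, or a nested Python list.
ArrayOrList = Union[jnp.ndarray, np.ndarray, list]


def concat_per_token_logps(outs: Sequence[ArrayOrList]) -> jnp.ndarray:
    """Hypothetical helper: concatenate per-micro-batch log-probs along the batch axis."""
    # jnp.concatenate rejects Python lists, so convert each element explicitly.
    # Assumption: each element is rectangular [batch, seq_len] with a shared seq_len.
    arrays = [jnp.asarray(o, dtype=jnp.float32) for o in outs]
    return jnp.concatenate(arrays, axis=0)
```

Where this conversion belongs is a design choice. It could go in VllmRollout, so the method returns arrays, or in rl_cluster.py, so outputs are normalized at the call site. Converting inside the rollout keeps the return type consistent across rollout backends, which is presumably what rl_cluster.py already assumes.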
Environment Information
Hardware: GKE TPU v5e-8 slice
JAX Version: 0.4.25 (Pinned to maintain mesh stability with with_sharding_constraint)
vLLM Version: 0.17.0rc1
Algorithm: GRPO