Log RL metrics per environment#3446
Conversation
Phlip79
left a comment
There was a problem hiding this comment.
Thank you for adding great tests!
|
/ok to test 51e0005 |
| env_idx = [i for i, eidx in enumerate(group_stats.env_ids) if eidx == env_id] | ||
|
|
||
| # Advantages are flattened, we need to be more careful with those. | ||
| group_turn_counts = [sum(nt) for nt in num_turns] |
There was a problem hiding this comment.
This doesn't need to be inside the for loop
| rewards: list[float] | ||
| rewards: list[list[float]] # inner list is for a group | ||
| env_ids: list[int] # same length as len(rewards) | ||
| turn_lens: list[list[int]] # |
| for g in rollouts: | ||
| if g[0].env_id not in example_groups: | ||
| example_groups[g[0].env_id] = g |
There was a problem hiding this comment.
RewardOnlyAgent says env_id: str | None = None. But in practice, all of our environments do set an env_id, so in practice it's not a problem.
Still though, we should either None-check here, or we should enforce env_id as required in RewardOnlyAgent.
There was a problem hiding this comment.
I do not know why we had None there. I added env_id to Countdown which was the only None-id env probably because we forget. Maybe @ArEsKay3 remembers why we need None there.
| return_log_probs = bool(req.get("logprobs", False)) | ||
| top_n_logprobs = int(req.get("top_logprobs", 0)) if return_log_probs else 0 | ||
| skip_prompt_log_probs = bool(req.get("skip_prompt_log_probs", False)) | ||
| skip_prompt_log_probs = bool(req.get("skip_prompt_log_probs", True)) |
There was a problem hiding this comment.
Please remove this diff; it's already been fixed in main, and having this diff adds extra reviewer burden on other teams.
|
/ok to test a8abff4 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22144219623 |
This PR makes 2 main things: