Skip to content

Add eval_batch_size knob for faster post-train RL evaluation#4030

Open
py4 wants to merge 1 commit into
mainfrom
pr/eval-batch-size
Open

Add eval_batch_size knob for faster post-train RL evaluation#4030
py4 wants to merge 1 commit into
mainfrom
pr/eval-batch-size

Conversation

@py4
Copy link
Copy Markdown
Collaborator

@py4 py4 commented Jun 1, 2026

Post-train RL evaluation is currently batched at trainer_config.batch_size, which for GRPO is intentionally small (e.g. 4 prompts per step × 8 generations = 32 trajectories — anything larger blows out trainer-side KV cache). At eval time this is wasteful: vLLM rollout has many DP replicas sitting idle when only 4 prompts are dispatched per call.

This PR adds an rl.eval_batch_size knob (default -1 = use batch_size, preserving old behavior) that overrides the batch dimension during dataset preparation for the test split. Setting it to e.g. 128 on a sampler with 8 DP replicas gives a ~32× eval throughput improvement on TPU without affecting training behavior.

Changes (3 files, +32/-2 lines):

  • src/maxtext/configs/types.py: new Pydantic field RLDataset.eval_batch_size: int = -1
  • src/maxtext/configs/post_train/rl.yml: default eval_batch_size: -1 + comment
  • src/maxtext/trainers/post_train/rl/train_rl.py:prepare_datasets: when set and positive, use eval_batch_size for the test split's slice + .batch(...) call

NOTE: total eval examples = num_test_batches * eval_batch_size, so users adjusting eval_batch_size should adjust num_test_batches to keep total eval set size constant.

Backward compatible: default -1 falls back to batch_size (identical to old behavior). No effect on training path.

Checklist

  • Tested locally on TPU v6e 8×8: with eval_batch_size=128 (8 DP replicas), eval over 1408 examples completes in ~3 min vs ~30+ min at batch_size=4
  • Backward compatible: default -1 preserves existing behavior bit-for-bit
  • No effect on training path (only eval-side dataset preparation touched)
  • No effect on non-RL paths (only RL Pydantic config + RL trainer touched)

Post-train RL evaluation batched at trainer_config.batch_size, which
for GRPO is intentionally small (e.g. 4 prompts per training step ×
8 generations = 32 trajectories — anything larger blows out KV cache
for the trainer). At eval time this is wasteful: vLLM rollout has many
DP replicas sitting idle when only 4 prompts are dispatched per batch.

Add an `eval_batch_size` knob (default -1 = use batch_size, preserving
old behavior) that overrides the batch dimension during dataset
preparation for the test split. Setting it to e.g. 128 on a sampler
with 8 DP replicas gives a ~32x eval throughput improvement on TPU
without affecting training behavior.

Total eval examples = num_test_batches * eval_batch_size, so users
should adjust num_test_batches when increasing eval_batch_size to keep
total eval set size constant.
@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 1, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant