Add eval_batch_size knob for faster post-train RL evaluation #4030
Open
py4 wants to merge 1 commit into
Post-train RL evaluation is batched at `trainer_config.batch_size`, which for GRPO is intentionally small (e.g. 4 prompts per training step × 8 generations = 32 trajectories; anything larger blows out the trainer's KV cache). At eval time this is wasteful: the vLLM rollout has many DP replicas sitting idle when only 4 prompts are dispatched per batch. Add an `eval_batch_size` knob (default `-1` = use `batch_size`, preserving old behavior) that overrides the batch dimension during dataset preparation for the test split. Setting it to e.g. 128 on a sampler with 8 DP replicas gives a ~32x eval throughput improvement on TPU without affecting training behavior. Total eval examples = `num_test_batches * eval_batch_size`, so users should adjust `num_test_batches` when increasing `eval_batch_size` to keep the total eval set size constant.
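For reviewers, the fallback rule can be sketched as follows. This is an illustration only, not the code in `train_rl.py`: the helper names `resolve_eval_batch_size` and `batch_test_split` are hypothetical, and plain Python lists stand in for the real dataset object and its `.batch(...)` call.

```python
from typing import List, Sequence


def resolve_eval_batch_size(batch_size: int, eval_batch_size: int = -1) -> int:
    """Hypothetical helper: pick the batch size for the test split.

    A non-positive eval_batch_size (the default -1) falls back to the
    training batch_size, which is the pre-existing behavior.
    """
    return eval_batch_size if eval_batch_size > 0 else batch_size


def batch_test_split(
    examples: Sequence[int],
    num_test_batches: int,
    batch_size: int,
    eval_batch_size: int = -1,
) -> List[List[int]]:
    """Hypothetical stand-in for the test-split slice + batch step."""
    bs = resolve_eval_batch_size(batch_size, eval_batch_size)
    # Slice first so that exactly num_test_batches * bs examples are used.
    sliced = examples[: num_test_batches * bs]
    # Group the sliced examples into consecutive batches of size bs.
    return [list(sliced[i : i + bs]) for i in range(0, len(sliced), bs)]


if __name__ == "__main__":
    data = list(range(1000))
    default = batch_test_split(data, num_test_batches=2, batch_size=4)
    larger = batch_test_split(data, num_test_batches=2, batch_size=4, eval_batch_size=128)
    print(len(default), len(default[0]))
    print(len(larger), len(larger[0]))
```

The training split never consults `eval_batch_size` in this sketch, which mirrors the claim that the training path is unaffected.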
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Post-train RL evaluation is currently batched at `trainer_config.batch_size`, which for GRPO is intentionally small (e.g. 4 prompts per step × 8 generations = 32 trajectories; anything larger blows out trainer-side KV cache). At eval time this is wasteful: vLLM rollout has many DP replicas sitting idle when only 4 prompts are dispatched per call.

This PR adds an `rl.eval_batch_size` knob (default `-1` = use `batch_size`, preserving old behavior) that overrides the batch dimension during dataset preparation for the test split. Setting it to e.g. 128 on a sampler with 8 DP replicas gives a ~32× eval throughput improvement on TPU without affecting training behavior.

Changes (3 files, +32/-2 lines):
- `src/maxtext/configs/types.py`: new Pydantic field `RLDataset.eval_batch_size: int = -1`
- `src/maxtext/configs/post_train/rl.yml`: default `eval_batch_size: -1` + comment
- `src/maxtext/trainers/post_train/rl/train_rl.py`: `prepare_datasets`: when set and positive, use `eval_batch_size` for the test split's slice + `.batch(...)` call

NOTE: total eval examples = `num_test_batches * eval_batch_size`, so users adjusting `eval_batch_size` should adjust `num_test_batches` to keep total eval set size constant.

Backward compatible: default `-1` falls back to `batch_size` (identical to old behavior). No effect on training path.

Checklist
- `eval_batch_size=128` (8 DP replicas), eval over 1408 examples completes in ~3 min vs ~30+ min at `batch_size=4`
- `-1` preserves existing behavior bit-for-bit
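Because total eval examples = `num_test_batches * eval_batch_size`, raising the eval batch size without touching `num_test_batches` grows the eval set. A minimal sketch of the adjustment, assuming the 1408-example run in the checklist used 352 batches of 4 (that batch count is inferred from 1408 / 4, not stated in the PR); `rescale_num_test_batches` is a hypothetical helper, not part of MaxText:

```python
import math


def rescale_num_test_batches(
    num_test_batches: int, old_batch_size: int, eval_batch_size: int
) -> int:
    """Hypothetical helper: keep the total eval set size (roughly) constant.

    Rounds up, so the new total is never smaller than the old one; it is
    exactly equal when eval_batch_size divides the old total.
    """
    if eval_batch_size <= 0:
        # A non-positive value means "use batch_size": nothing to rescale.
        return num_test_batches
    total_examples = num_test_batches * old_batch_size
    return math.ceil(total_examples / eval_batch_size)


if __name__ == "__main__":
    # 352 batches of 4 prompts = 1408 examples; at 128 per batch that is 11 batches.
    print(rescale_num_test_batches(352, old_batch_size=4, eval_batch_size=128))  # 11
```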