get_optimizer: respect learning_rate_schedule_steps config knob#4029
Open
py4 wants to merge 1 commit into
Open
get_optimizer: respect learning_rate_schedule_steps config knob#4029py4 wants to merge 1 commit into
py4 wants to merge 1 commit into
Conversation
base.yml documents learning_rate_schedule_steps as the LR schedule shape
control ("By default the length of the schedule is set to the number of
steps", but configurable to a longer/different value). The post_train RL
get_optimizer ignored this knob and always used max_train_steps directly,
silently dropping any non-default value.
This matters for GPU<->TPU recipe parity: when reproducing a GPU recipe
with NUM_BATCHES different from the GPU's, you need to keep the LR
schedule SHAPE the same (e.g., warmup=50, decay=500 like NeMo-RL's
lr_warmup_iters/lr_decay_iters) regardless of how many TPU steps you
run. Without this fix, integrated LR scales linearly with NUM_BATCHES.
Backward-compatible: default learning_rate_schedule_steps=-1 (or unset)
falls back to max_train_steps, identical to old behavior.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
A9isha
approved these changes
Jun 2, 2026
SurbhiJainUSC
approved these changes
Jun 2, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
base.ymldocumentslearning_rate_schedule_stepsas the LR schedule shape control ("By default the length of the schedule is set to the number of steps", but configurable to a longer/different value). The post-train RLget_optimizerignored this knob and always usedmax_train_stepsdirectly, silently dropping any non-default value.This matters for recipe parity when reproducing a recipe from one stack on another with a different
num_batches: you typically want to keep the LR schedule SHAPE the same (e.g., warmup=50, decay=500) regardless of how many steps you actually run. Without this fix, integrated LR scales linearly withnum_batches, breaking trajectory matching.Changes (single file, +18/-5 lines):
src/maxtext/trainers/post_train/rl/utils_rl.py:get_optimizer: readtmvp_config.learning_rate_schedule_stepsviagetattr(with default-1). When set and positive, use it for bothwarmup_steps(=warmup_steps_fraction × schedule_steps) anddecay_steps. When-1/unset/None, fall back tomax_train_steps— identical to the old behavior.Backward compatible: default
learning_rate_schedule_steps=-1(already the documented default) preserves the existing behavior bit-for-bit. Only callers that explicitly set a positive value see a behavior change, which is exactly what the config knob's documentation promises.Aligns the RL
get_optimizerwith the rest of the codebase:maxtext.utils.maxtext_utils.create_learning_rate_schedule(used by pre-train, SFT, DPO, distillation) already respectslearning_rate_schedule_steps. Only the RL-specific copy was missing the plumbing.Checklist
-1preserves old behavior (no test regressions expected)learning_rate_schedule_steps=500,num_batches=50,warmup_steps_fraction=0.1→ produced expected 50-step warmup over 500-step schedule (vs 5-step warmup over 50-step schedule without the fix)utils_rl.get_optimizertouched; SFT/distillation/pre-train use a differentget_optimizer)