get_optimizer: respect learning_rate_schedule_steps config knob #4029

Open
py4 wants to merge 1 commit into main from pr/lr-schedule-steps

py4 commented Jun 1, 2026

base.yml documents learning_rate_schedule_steps as the control for the LR schedule shape: "By default the length of the schedule is set to the number of steps", and it can be set to a longer or otherwise different value. The post-train RL get_optimizer ignored this knob and always used max_train_steps directly, silently dropping any non-default value.

This matters for recipe parity when reproducing a recipe from one stack on another with a different num_batches: you typically want to keep the LR schedule SHAPE the same (e.g., warmup=50, decay=500) regardless of how many steps you actually run. Without this fix, integrated LR scales linearly with num_batches, breaking trajectory matching.
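
To make the trajectory mismatch concrete, here is a minimal sketch that compares the per-step LR when the schedule length is tied to the run length against a fixed schedule shape. The linear-warmup plus cosine-decay form and the helper name `warmup_cosine_lr` are assumptions for illustration only; the schedule that get_optimizer actually builds may differ.

```python
import math


def warmup_cosine_lr(step, peak_lr, warmup_steps, decay_steps):
    """Hypothetical schedule: linear warmup to peak_lr, then cosine decay to 0.

    decay_steps is the total schedule length, including warmup.
    """
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the post-warmup decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


def compare_schedules(num_steps, peak_lr=1.0):
    """Return (step, tied_lr, fixed_lr) rows for a run of num_steps steps."""
    rows = []
    for step in range(num_steps):
        # Old behavior: schedule length follows the run length (warmup=5, decay=50).
        tied = warmup_cosine_lr(step, peak_lr, warmup_steps=5, decay_steps=50)
        # Fixed shape from the recipe (warmup=50, decay=500).
        fixed = warmup_cosine_lr(step, peak_lr, warmup_steps=50, decay_steps=500)
        rows.append((step, tied, fixed))
    return rows
```

Under these assumptions, a 50-step run with the tied schedule has already decayed to near zero by its last step, while the fixed-shape schedule is still in warmup, so the two runs do not follow the same LR trajectory.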

Changes (single file, +18/-5 lines):

  • src/maxtext/trainers/post_train/rl/utils_rl.py:get_optimizer: read tmvp_config.learning_rate_schedule_steps via getattr (with default -1). When set and positive, use it for both warmup_steps (= warmup_steps_fraction × schedule_steps) and decay_steps. When -1/unset/None, fall back to max_train_steps — identical to the old behavior.
  • Updated docstring + inline comments to describe the new behavior.
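
The fallback logic described in the first bullet can be sketched roughly as follows. This is not the code from the diff: the helper name `resolve_schedule_steps`, the use of `int()` for rounding the warmup length, and the `SimpleNamespace` stand-in for the config object are all assumptions for illustration.

```python
from types import SimpleNamespace


def resolve_schedule_steps(config, max_train_steps):
    """Hypothetical helper: return (warmup_steps, decay_steps) for the LR schedule."""
    # getattr with a default of -1 tolerates configs that lack the attribute.
    schedule_steps = getattr(config, "learning_rate_schedule_steps", -1)
    # -1, unset, or None all fall back to the run length (the old behavior).
    if schedule_steps is None or schedule_steps <= 0:
        schedule_steps = max_train_steps
    # Assumption: the warmup length is truncated to an integer step count.
    warmup_steps = int(config.warmup_steps_fraction * schedule_steps)
    return warmup_steps, schedule_steps


# Stand-in config objects; the real config is not a SimpleNamespace.
explicit = SimpleNamespace(learning_rate_schedule_steps=500, warmup_steps_fraction=0.1)
default = SimpleNamespace(learning_rate_schedule_steps=-1, warmup_steps_fraction=0.1)
```

With max_train_steps=50, the `explicit` config resolves to a 50-step warmup over a 500-step schedule and the `default` config to a 5-step warmup over a 50-step schedule, matching the numbers reported in the checklist.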

Backward compatible: default learning_rate_schedule_steps=-1 (already the documented default) preserves the existing behavior bit-for-bit. Only callers that explicitly set a positive value see a behavior change, which is exactly what the config knob's documentation promises.

Aligns the RL get_optimizer with the rest of the codebase: maxtext.utils.maxtext_utils.create_learning_rate_schedule (used by pre-train, SFT, DPO, distillation) already respects learning_rate_schedule_steps. Only the RL-specific copy was missing the plumbing.

Checklist

  • Backward compatible: default -1 preserves old behavior (no test regressions expected)
  • Tested locally on TPU with learning_rate_schedule_steps=500, num_batches=50, warmup_steps_fraction=0.1 → produced expected 50-step warmup over 500-step schedule (vs 5-step warmup over 50-step schedule without the fix)
  • No effect on non-RL paths (only utils_rl.get_optimizer touched; SFT/distillation/pre-train use a different get_optimizer)

base.yml documents learning_rate_schedule_steps as the LR schedule shape
control ("By default the length of the schedule is set to the number of
steps", but configurable to a longer/different value). The post_train RL
get_optimizer ignored this knob and always used max_train_steps directly,
silently dropping any non-default value.

This matters for GPU<->TPU recipe parity: when reproducing a GPU recipe
with NUM_BATCHES different from the GPU's, you need to keep the LR
schedule SHAPE the same (e.g., warmup=50, decay=500 like NeMo-RL's
lr_warmup_iters/lr_decay_iters) regardless of how many TPU steps you
run. Without this fix, integrated LR scales linearly with NUM_BATCHES.

Backward-compatible: default learning_rate_schedule_steps=-1 (or unset)
falls back to max_train_steps, identical to old behavior.

codecov Bot commented Jun 1, 2026

Codecov Report

❌ Patch coverage is 66.66667% with 1 line in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/trainers/post_train/rl/utils_rl.py | 66.66% | 0 Missing and 1 partial ⚠️ |
