Add distillation launchers for qwen3-30b-a3b-base and gpt-oss-20b#4028
Add distillation launchers for qwen3-30b-a3b-base and gpt-oss-20b#4028gagika wants to merge 1 commit into
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
94bcf37 to
e202120
Compare
|
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This Pull Request introduces distillation launchers and configurations for qwen3-30b-a3b-base and gpt-oss-20b models on TPU v7x. The additions are useful for standardizing distillation runs, but there are a few issues regarding redundancy and hardcoded personal paths.
🔍 General Feedback
- Redundant Patch File: The file
distillation-wrappers.patchappears to be a redundant diff of the entire PR and should be removed. - Hardcoded Defaults: Several scripts and configuration files contain default GCS paths and images pointing to personal buckets (
agagik-us,yujiedeng-maxtext-dev). These should ideally be replaced with generic placeholders or public resources to improve maintainability and portability for other users. - Environment Management: The use of
/dev/shmforTMPDIRand Hugging Face caches is a good performance optimization to avoid ephemeral storage limits, but setting it globally asTMPDIRshould be done with caution.
e202120 to
1a00405
Compare
|
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This PR introduces comprehensive one-command distillation launchers for qwen3-30b-a3b-base and gpt-oss-20b on TPU v7x. The additions include performance-tuned XLA flags, optimized YAML configurations (including activation offload for higher batch sizes), and enhancements to the shared XPK submission script to handle tokenizer staging and HF caching efficiently.
🔍 General Feedback
- Robustness: The shared
run_distill_xpk.shwas improved to handle HF caching in/dev/shm, which is a great optimization for TPU workloads. I've suggested some minor quoting fixes to ensure these scripts handle paths with spaces or special characters reliably. - Documentation: The scripts and YAML files include helpful comments explaining specific model quirks (e.g., the
distill_beta=0requirement forgpt-oss). - Defaults: While demo defaults are provided, I recommended using more generic placeholders for buckets and images to prevent accidental use of dev resources by other users.
| export XPK_ZONE="${XPK_ZONE:-us-central1}" | ||
| export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}" | ||
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-gs://agagik-us/distillation}" | ||
| export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" |
There was a problem hiding this comment.
| export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" | |
| export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" |
| "$image_flag=$XPK_BASE_IMAGE" \ | ||
| --command "export PYTHONPATH=/deps/src:/app/src; \ | ||
| export BASE_OUTPUT_DIRECTORY=${OUTPUT_DIR}; \ | ||
| export LIBTPU_INIT_ARGS='${libtpu_init_args}'; \ |
There was a problem hiding this comment.
| export LIBTPU_INIT_ARGS='${libtpu_init_args}'; \ | |
| export HF_HOME=\"${XPK_HF_CACHE_DIR}\"; export HF_DATASETS_CACHE=\"${XPK_HF_CACHE_DIR}/datasets\"; mkdir -p \"${XPK_HF_CACHE_DIR}/datasets\"; \ |
| --xla_tpu_aggressive_opt_barrier_removal=ENABLED \ | ||
| --xla_lhs_prioritize_async_depth_over_stall=ENABLED \ | ||
| --xla_tpu_enable_ag_backward_pipelining=true \ | ||
| --xla_should_allow_loop_variant_parameter_in_chain=ENABLED \ |
There was a problem hiding this comment.
| --xla_should_allow_loop_variant_parameter_in_chain=ENABLED \ | |
| libtpu_init_args=$(printf -- '%s' "${XPK_LIBTPU_INIT_ARGS:-$default_libtpu_args}" | tr -s '[:space:]' ' ') |
|
|
||
| # Optional: stage HF tokenizer files from GCS for models whose tokenizer isn't | ||
| # baked into the image (e.g. gpt-oss). | ||
| tokenizer_prelude="" |
There was a problem hiding this comment.
| tokenizer_prelude="" | |
| tokenizer_prelude="mkdir -p \"${XPK_TOKENIZER_LOCAL}\" && gcloud storage rsync \"${XPK_TOKENIZER_GCS}\" \"${XPK_TOKENIZER_LOCAL}\";" |
| export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}" | ||
| export XPK_ZONE="${XPK_ZONE:-us-central1}" | ||
| export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}" | ||
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-gs://agagik-us/distillation}" |
There was a problem hiding this comment.
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-gs://agagik-us/distillation}" | |
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-gs://YOUR-BUCKET/distillation}" |
| export XPK_ZONE="${XPK_ZONE:-us-central1}" | ||
| export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}" | ||
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-gs://agagik-us/distillation}" | ||
| export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" |
There was a problem hiding this comment.
| export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" | |
| export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" |
| export XPK_PROJECT="${XPK_PROJECT:-cloud-tpu-multipod-dev}" | ||
| export XPK_ZONE="${XPK_ZONE:-us-central1}" | ||
| export XPK_DEVICE_TYPE="${XPK_DEVICE_TYPE:-tpu7x-4x4x4}" | ||
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-gs://agagik-us/distillation}" |
There was a problem hiding this comment.
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-gs://agagik-us/distillation}" | |
| export XPK_BASE_OUTPUT_DIR="${XPK_BASE_OUTPUT_DIR:-gs://YOUR-BUCKET/distillation}" |
1a00405 to
bc5cc4a
Compare
|
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
bc5cc4a to
fb8d0fb
Compare
There was a problem hiding this comment.
This PR introduces comprehensive distillation launchers and configurations for qwen3-30b-a3b-base and gpt-oss-20b models on TPU v7x. The updates to the underlying run_distill_xpk.sh script, including HF cache management and GCS-based asset staging, significantly improve the robustness and ease of use for distillation workloads.
🔍 General Feedback
- Out-of-the-box Usability: While the scripts are well-structured, several defaults point to specific user buckets (
gs://agagik-us/). Replacing these with generic placeholders or documenting them as mandatory overrides would improve the experience for the broader team. - Parallelism Consistency: The
gpt-oss-20bconfiguration has a hardcodedici_fsdp_parallelismthat conflicts with the default cluster size in its launcher script. Usingauto(-1) is preferred for better scalability. - XLA Tuning: The inclusion of tuned XLA flags for both models is a great addition, providing clear performance targets (~17-20% MFU).
| @@ -290,6 +330,11 @@ submit_workload() { | |||
| "$image_flag=$XPK_BASE_IMAGE" \ | |||
| --command "export PYTHONPATH=/deps/src:/app/src; \ | |||
There was a problem hiding this comment.
| --command "export PYTHONPATH=/deps/src:/app/src; \ | |
| export LIBTPU_INIT_ARGS=\"${libtpu_init_args}\"; \ |
| export XPK_DATASET_BUCKET="${XPK_DATASET_BUCKET:-maxtext-dataset}" | ||
| export XPK_DATASET_SUBPATH="${XPK_DATASET_SUBPATH:-array-record/climbmix/*.arrayrecord}" | ||
|
|
||
| # Stage HF tokenizer files (not in the image for gpt-oss). |
There was a problem hiding this comment.
| # Stage HF tokenizer files (not in the image for gpt-oss). | |
| export XPK_YAML_GCS="${XPK_YAML_GCS:-gs://YOUR-BUCKET/distill-configs/distillation_gpt_oss_20b.yml}" |
| export XPK_BASE_IMAGE="${XPK_BASE_IMAGE:-gcr.io/cloud-tpu-multipod-dev/maxtext_base_image:agagik-distill}" | ||
| export XPK_PRIORITY="${XPK_PRIORITY:-high}" | ||
|
|
||
| export XPK_USE_GCSFUSE=1 |
There was a problem hiding this comment.
| export XPK_USE_GCSFUSE=1 | |
| export XPK_YAML_GCS="${XPK_YAML_GCS:-gs://YOUR-BUCKET/distill-configs/distillation_qwen3_30b_base.yml}" |
| # distill_beta=0: decoder feature loss is broken on gpt-oss. | ||
| export DISTILL_ALPHA="${DISTILL_ALPHA:-0.5}" | ||
| export DISTILL_TEMPERATURE="${DISTILL_TEMPERATURE:-1.0}" | ||
| export DISTILL_BETA="${DISTILL_BETA:-0}" |
There was a problem hiding this comment.
| export DISTILL_BETA="${DISTILL_BETA:-0}" | |
| export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=61440 \ |
|
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This PR introduces well-structured distillation launchers for qwen3-30b-a3b-base and gpt-oss-20b models, significantly simplifying the setup for these workloads on TPU v7x. The inclusion of tuned XLA flags and optimized configuration files demonstrates a strong focus on performance (MFU).
🔍 General Feedback
- Robustness: The shell scripts could be made more robust by consistently quoting paths and environment variables to handle potential special characters or spaces.
- Consistency: A few XLA flags use
trueinstead of the more standardENABLEDvalue found elsewhere in the repository; aligning these improves maintainability. - Explicit Overrides: Explicitly passing the staged tokenizer path to the training script ensures that the workload uses the intended assets regardless of the pod's working directory.
- Documentation: The scripts include helpful comments and usage examples, which is great for usability.
| # by latency_hiding_layer_scheduler. | ||
| export XPK_LIBTPU_INIT_ARGS="${XPK_LIBTPU_INIT_ARGS:---xla_tpu_scoped_vmem_limit_kib=65536 \ | ||
| --xla_tpu_impure_enable_packed_bf16_math_ops=true \ | ||
| --xla_tpu_aggressive_opt_barrier_removal=true \ |
There was a problem hiding this comment.
| --xla_tpu_aggressive_opt_barrier_removal=true \ | |
| --xla_tpu_aggressive_opt_barrier_removal=ENABLED \ |
| export BASE_OUTPUT_DIRECTORY=${OUTPUT_DIR}; \ | ||
| export LIBTPU_INIT_ARGS='${libtpu_init_args}'; \ | ||
| export TMPDIR=/dev/shm; export JAX_COMPILATION_CACHE_DIR=/dev/shm/jax_cache; \ | ||
| export HF_HOME=${XPK_HF_CACHE_DIR}; export HF_DATASETS_CACHE=${XPK_HF_CACHE_DIR}/datasets; mkdir -p ${XPK_HF_CACHE_DIR}/datasets; \ |
There was a problem hiding this comment.
| export HF_HOME=${XPK_HF_CACHE_DIR}; export HF_DATASETS_CACHE=${XPK_HF_CACHE_DIR}/datasets; mkdir -p ${XPK_HF_CACHE_DIR}/datasets; \ | |
| export HF_HOME='${XPK_HF_CACHE_DIR}'; export HF_DATASETS_CACHE='${XPK_HF_CACHE_DIR}/datasets'; mkdir -p '${XPK_HF_CACHE_DIR}/datasets'; \ |
| export HF_HOME=${XPK_HF_CACHE_DIR}; export HF_DATASETS_CACHE=${XPK_HF_CACHE_DIR}/datasets; mkdir -p ${XPK_HF_CACHE_DIR}/datasets; \ | ||
| ${yaml_prelude} \ | ||
| ${tokenizer_prelude} \ | ||
| ${gcsfuse_prelude} \ |
There was a problem hiding this comment.
| ${gcsfuse_prelude} \ | |
| python3 -m maxtext.trainers.post_train.distillation.train_distill ${XPK_DISTILL_CONFIG} \ | |
| run_name=${XPK_RUN_NAME} \ | |
| ${grain_files_override} \ | |
| ${steps_override} \ | |
| ${checkpoint_period_override} \ | |
| tokenizer_path=${XPK_TOKENIZER_LOCAL:-} \ | |
| distill_alpha=${DISTILL_ALPHA} \ | |
| distill_temperature=${DISTILL_TEMPERATURE} \ | |
| distill_beta=${DISTILL_BETA} \ | |
| distill_layer_indices="${DISTILL_LAYER_INDICES}" |
| # Optional: stage HF tokenizer files from GCS for models whose tokenizer isn't | ||
| # baked into the image (e.g. gpt-oss). | ||
| tokenizer_prelude="" | ||
| if [ -n "${XPK_TOKENIZER_GCS:-}" ] && [ -n "${XPK_TOKENIZER_LOCAL:-}" ]; then |
There was a problem hiding this comment.
| if [ -n "${XPK_TOKENIZER_GCS:-}" ] && [ -n "${XPK_TOKENIZER_LOCAL:-}" ]; then | |
| tokenizer_prelude="mkdir -p '${XPK_TOKENIZER_LOCAL}' && gcloud storage rsync '${XPK_TOKENIZER_GCS}' '${XPK_TOKENIZER_LOCAL}';" |
| grain_files_override="grain_train_files=gs://${XPK_DATASET_BUCKET}/${XPK_DATASET_SUBPATH}" | ||
| fi | ||
|
|
||
| # Optional: stage the YAML from GCS instead of baking via upload_runner. |
There was a problem hiding this comment.
| # Optional: stage the YAML from GCS instead of baking via upload_runner. | |
| yaml_prelude="gcloud storage cp '${XPK_YAML_GCS}' '${XPK_DISTILL_CONFIG}';" |
c6e4983 to
eba25b0
Compare
eba25b0 to
03dafe7
Compare
Description
One-command launchers for running distillation on TPU v7x. Each script sets the
right XLA flags, mounts a grain arrayrecord dataset via gcsfuse (ClimbMix by
default; configurable via
XPK_DATASET_BUCKET/XPK_DATASET_SUBPATH),configures distillation knobs, stages the HF tokenizer when needed, and submits
a workload via XPK.
Usage
Each launcher takes a mode argument (default
submit):submit— stage the YAML to GCS and create the xpk workloadmonitor— stream logs for the last submitted workloadresume_until_done— auto-resubmit on failure until the run completesTests
End to end test for both gpt-oss and qwen3-30b models.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.