Add support of Grain input pipeline for DPO.#4009
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
d63a9cc to
3b27bbf
Compare
|
🤖 Hi @igorts-git, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This pull request adds Grain input pipeline support for DPO (Direct Preference Optimization). The implementation is consistent with the project's existing Grain-based SFT patterns and includes comprehensive unit tests for different DPO data formats.
🔍 General Feedback
- Good Coverage: The addition of
TestGrainDPOPipelineProcessingindpo_data_processing_test.pycovers key edge cases like 2-column (common prefix) and 3-column datasets. - Defensive Design: The validation check in
src/maxtext/configs/types.pyprevents unsupported configurations early. - Defaults: The updated defaults in
dpo.ymlprovide a smoother "out-of-the-box" experience for Tunix-based DPO. - Batching Consistency: It is recommended to use the
get_local_batch_sizeutility to ensure all global configuration flags (like real data expansion) are respected during batching.
3b27bbf to
86f715f
Compare
86f715f to
89d9e3c
Compare
| max_prompt_length: int | None = None | ||
|
|
||
| def __post_init__(self): | ||
| if self.max_prompt_length is None: |
There was a problem hiding this comment.
The new logic that I added in types.py guarantees that this value is not None. However, for cases like unit tests it is still useful to have this default computed. Let' me know if you want it removed.
Description
Add support of Grain input reading when using the new Tunix-based DPO/ORPO.
Modify
dpo.ymlto have some reasonable defaults that allow invoking DPO without extra config parameters.Correct the case where
config.dpo.max_prompt_lengthis not set. The default value ismax_target_length // 2. It should be the same value that is passed to both the input pipeline and to the Tunix DPOTrainer class. Thus, I moved its computation totypes.py.In a follow up PR I will add a detailed logits comparison test for DPO.
BUGS: b/485626968
Tests
CI tests.
Ran DPO/ORPO while reading using Grain from parquet files.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.