feat: add pretrained orbax cache for fast WAN inference loads#406
feat: add pretrained orbax cache for fast WAN inference loads#406csgoogle wants to merge 1 commit into
Conversation
Perseus14
left a comment
There was a problem hiding this comment.
Can you also add this for the other WAN models?
1e3be2d to
4e2c7f1
Compare
done |
|
🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This Pull Request introduces a pretrained Orbax cache for WAN models, providing a significant performance boost (~10×) for subsequent inference loads by skipping the slow diffusers load. The implementation is well-structured, with a clean refactoring of the shared Orbax logic into a base WanCheckpointer class, and includes comprehensive tests for the new caching strategy.
🔍 General Feedback
- Architecture: The refactoring of
WanCheckpointerinto an abstract base class is a great improvement, making it easy to add support for new WAN model variants with minimal boilerplate. - Performance: The automatic population of the cache on the first run is a user-friendly feature that drastically improves the startup time for repeated tasks.
- Correctness: The restructuring of multi-transformer checkpoints (e.g., WAN 2.2) is handled correctly to maintain compatibility with existing pipeline loading logic.
- Robustness: I've suggested adding exception handling to the cache-saving logic to ensure that intermittent write failures (e.g., GCS permissions) do not disrupt the main execution flow.
|
Would this change have an impact on any training runs? |
It won't there are unit tests covering the training load, also the flag for the change only get's enabled if it's |
dd62933 to
c9fa891
Compare
Made it simple, now we are not touching training code. |
c9fa891 to
874f2c1
Compare
Pretrained Orbax cache for fast WAN inference loads
Caches WAN weights as an Orbax checkpoint so inference can skip the slow diffusers load on repeat runs (~10× faster). The first run loads from diffusers and writes the cache; subsequent runs restore directly from it. Works with both local paths and GCS buckets.
Changes
pretrained_orbax_dirconfig option to all 6 WAN configs (empty by default, so the feature is off unless set). Accepts a local path or ags://bucket path.generate_wanandgenerate_wan_animatenow load withuse_pretrained_cache=True.WanCheckpointerbase class; subclasses now only declaremodel_name,checkpoint_state_item_names, andpretrained_state_sources.