Skip to content

feat: add pretrained orbax cache for fast WAN inference loads#406

Open
csgoogle wants to merge 1 commit into
mainfrom
wan-orbax-checkpoint-cache
Open

feat: add pretrained orbax cache for fast WAN inference loads#406
csgoogle wants to merge 1 commit into
mainfrom
wan-orbax-checkpoint-cache

Conversation

@csgoogle
Copy link
Copy Markdown
Collaborator

@csgoogle csgoogle commented May 15, 2026

Pretrained Orbax cache for fast WAN inference loads

Caches WAN weights as an Orbax checkpoint so inference can skip the slow diffusers load on repeat runs (~10× faster). The first run loads from diffusers and writes the cache; subsequent runs restore directly from it. Works with both local paths and GCS buckets.

Changes

  • Added a new pretrained_orbax_dir config option to all 6 WAN configs (empty by default, so the feature is off unless set). Accepts a local path or a gs:// bucket path.
  • Load priority: training checkpoint → pretrained cache → diffusers (the cache is populated automatically on a miss).
  • generate_wan and generate_wan_animate now load with use_pretrained_cache=True.
  • Refactored the shared Orbax logic into the WanCheckpointer base class; subclasses now only declare model_name, checkpoint_state_item_names, and pretrained_state_sources.
  • Added an animate checkpointer along with tests covering the cache hit / miss / save paths.
image

@github-actions
Copy link
Copy Markdown

Copy link
Copy Markdown
Collaborator

@Perseus14 Perseus14 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add this for the other WAN models?

Comment thread src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py Outdated
Comment thread src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py Outdated
@csgoogle csgoogle force-pushed the wan-orbax-checkpoint-cache branch from 1e3be2d to 4e2c7f1 Compare May 25, 2026 13:13
@csgoogle csgoogle marked this pull request as ready for review May 25, 2026 13:19
@csgoogle csgoogle requested a review from entrpn as a code owner May 25, 2026 13:19
@csgoogle
Copy link
Copy Markdown
Collaborator Author

Can you also add this for the other WAN models?

done

@github-actions
Copy link
Copy Markdown

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This Pull Request introduces a pretrained Orbax cache for WAN models, providing a significant performance boost (~10×) for subsequent inference loads by skipping the slow diffusers load. The implementation is well-structured, with a clean refactoring of the shared Orbax logic into a base WanCheckpointer class, and includes comprehensive tests for the new caching strategy.

🔍 General Feedback

  • Architecture: The refactoring of WanCheckpointer into an abstract base class is a great improvement, making it easy to add support for new WAN model variants with minimal boilerplate.
  • Performance: The automatic population of the cache on the first run is a user-friendly feature that drastically improves the startup time for repeated tasks.
  • Correctness: The restructuring of multi-transformer checkpoints (e.g., WAN 2.2) is handled correctly to maintain compatibility with existing pipeline loading logic.
  • Robustness: I've suggested adding exception handling to the cache-saving logic to ensure that intermittent write failures (e.g., GCS permissions) do not disrupt the main execution flow.

Comment thread src/maxdiffusion/checkpointing/wan_checkpointer.py Outdated
Comment thread src/maxdiffusion/checkpointing/wan_checkpointer.py Outdated
@Perseus14
Copy link
Copy Markdown
Collaborator

Would this change have an impact on any training runs?

@csgoogle
Copy link
Copy Markdown
Collaborator Author

Would this change have an impact on any training runs?

It won't there are unit tests covering the training load, also the flag for the change only get's enabled if it's true

@csgoogle csgoogle force-pushed the wan-orbax-checkpoint-cache branch 5 times, most recently from dd62933 to c9fa891 Compare May 26, 2026 19:44
@csgoogle
Copy link
Copy Markdown
Collaborator Author

Would this change have an impact on any training runs?

Made it simple, now we are not touching training code.

@csgoogle csgoogle force-pushed the wan-orbax-checkpoint-cache branch from c9fa891 to 874f2c1 Compare May 26, 2026 19:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants