Skip to content

Ring + Ulysses 2D context parallelism#404

Open
csgoogle wants to merge 1 commit into
mainfrom
wan-ulysses-bshd-attention
Open

Ring + Ulysses 2D context parallelism#404
csgoogle wants to merge 1 commit into
mainfrom
wan-ulysses-bshd-attention

Conversation

@csgoogle
Copy link
Copy Markdown
Collaborator

@csgoogle csgoogle commented May 12, 2026

Description
This PR adds support for a new ulysses_ring attention mode for WAN models. The implementation keeps the public sequence sharding on the existing context mesh axis, then internally reshapes that axis into private ring and ulysses axes so the attention path can combine Ulysses all-to-all head sharding with ring-based KV rotation.

Changes

  • Adds ulysses_ring attention kernel registration and routing.
  • Introduces internal mesh reshaping from context into hidden ring and ulysses axes.
  • Adds ulysses_shards config plumbing through WAN pipeline, WAN transformer blocks, and attention ops.
  • Adds logical axis rules for Ulysses ring attention.
  • Updates WAN config files to document ulysses_ring support and add ulysses_shards.
  • Adds tests covering layout round-trip behavior and validation errors for invalid Ulysses shard settings.

Testing

  • Added unit tests in src/maxdiffusion/tests/attention_test.py for ulysses_ring behavior and validation.

Results
image

@google-cla
Copy link
Copy Markdown

google-cla Bot commented May 12, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@csgoogle csgoogle force-pushed the wan-ulysses-bshd-attention branch from d6638b6 to 2582cf1 Compare May 15, 2026 11:36
@csgoogle csgoogle changed the title Wan ulysses bshd attention Ring + Ulysses 2D context parallelism May 15, 2026
@csgoogle csgoogle force-pushed the wan-ulysses-bshd-attention branch 3 times, most recently from 48a900e to d10bf07 Compare May 15, 2026 19:27
@github-actions
Copy link
Copy Markdown

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@csgoogle csgoogle force-pushed the wan-ulysses-bshd-attention branch from d10bf07 to 6521956 Compare May 15, 2026 19:29
Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This Pull Request introduces a hybrid Ulysses + Ring 2D context parallelism for attention, which is a significant addition for handling long sequences in large-scale diffusion models. The implementation correctly sets up the internal mesh and handles the necessary all_to_all communications. However, there are some concerns regarding the masking logic and the handling of the attention_mask parameter in the new kernel that should be addressed to ensure correctness across all use cases.

🔍 General Feedback

  • Positive Highlights: The refactoring of Attention classes to use a more flexible attention_config dictionary is a great improvement for maintainability and reduces the complexity of method signatures.
  • Robustness: The validation checks for ulysses_shards and mesh axes are well-implemented and will help users avoid misconfiguration.
  • Testing: New tests for the Ulysses-Ring layout are included, although more comprehensive tests covering the actual attention computation with padding would further improve confidence.

Comment thread src/maxdiffusion/configs/base_wan_27b.yml Outdated
Comment thread src/maxdiffusion/models/attention_flax.py

block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "tokamax_ring")

q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 The use of np.tile(kv_shard_valid, num_ring_shards) assumes that every shard in the ring has an identical valid/padding pattern. In sequence parallelism, padding is typically only at the end of the global sequence (affecting only the last shard). Tiling a local mask will incorrectly mask valid tokens in earlier shards or fail to mask padding tokens in the last shard correctly.

Consider calculating the mask based on global sequence positions:

    # Example logic for global masking
    ring_idx = jax.lax.axis_index(ring_axis)
    q_global_start = ring_idx * query_seq_len
    # ... logic to create a mask that correctly accounts for orig_q_seq_len across the ring

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, fixed it.

@csgoogle csgoogle force-pushed the wan-ulysses-bshd-attention branch from 6521956 to 5c4d053 Compare May 15, 2026 19:39
@csgoogle csgoogle marked this pull request as ready for review May 15, 2026 19:57
@csgoogle csgoogle requested a review from entrpn as a code owner May 15, 2026 19:57
@github-actions
Copy link
Copy Markdown

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This PR successfully implements the ulysses_ring attention mode, a sophisticated 2D parallelism strategy that combines Ulysses all-to-all head sharding with Ring KV rotation. The implementation is clean, well-integrated into the existing WAN model infrastructure, and includes a comprehensive suite of unit tests verifying both functional correctness and validation logic.

🔍 General Feedback

  • Elegant Mesh Reshaping: The internal reshaping of the context mesh axis into private ring and ulysses axes within shard_map is a great architectural choice, keeping the public configuration simple.
  • Robust Masking: The global sequence masking logic in the hybrid kernel correctly handles padding, which is critical for accuracy in sequence-parallel settings.
  • Thorough Testing: The addition of detailed round-trip and validation tests in attention_test.py is excellent and ensures the new feature is robust against edge cases.

Comment thread src/maxdiffusion/common_types.py Outdated
Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/tests/attention_test.py
@csgoogle csgoogle force-pushed the wan-ulysses-bshd-attention branch 2 times, most recently from 14b997a to a219866 Compare May 19, 2026 10:35
@github-actions
Copy link
Copy Markdown

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This PR introduces a robust implementation of the ulysses_ring attention mode, effectively combining Ulysses all-to-all head sharding with Ring-based KV rotation for WAN models. The implementation is well-architected, leveraging internal mesh reshaping and shard_map to provide a clean public interface while handling complex 2D parallelism internally.

🔍 General Feedback

  • High Quality Implementation: The use of internal axes (ring, ulysses) within shard_map is a clever and effective way to manage 2D parallelism.
  • Comprehensive Testing: The added unit tests are thorough and cover important edge cases, including global sequence padding and validation logic.
  • Plumbing Consistency: The configuration plumbing is consistently applied across all WAN model variants and pipelines.
  • Safety: The fallback mechanism for cross-attention and the rigorous validation of shard divisibility ensure the stability of the new feature.

Comment thread src/maxdiffusion/pyconfig.py
Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/models/attention_flax.py Outdated
@csgoogle csgoogle force-pushed the wan-ulysses-bshd-attention branch 2 times, most recently from 4541243 to e88ad77 Compare May 25, 2026 11:43
@github-actions
Copy link
Copy Markdown

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This PR introduces a significant new feature: hybrid Ulysses + Ring attention for WAN models, enabling better scalability for long-sequence video generation. The implementation of the attention logic and configuration plumbing is well-structured. However, there are a few regressions in default configuration behavior and profiling logic that should be addressed before merging.

🔍 General Feedback

  • Attention Implementation: The new ulysses_ring kernel logic looks solid and follows the established patterns in the codebase for Pallas-based attention.
  • Config Regressions: The change to the default vae_spatial value from -1 to 1 is a high-priority concern as it might cause OOMs for users relying on automatic spatial parallelism.
  • Profiling Accuracy: The removal of block_until_ready() in the pipelines will likely break the accuracy of the profiling traces.
  • Test Restructuring: The relocation of WAN tests to the top-level tests/ directory is a large change; ensure that all CI paths are updated to reflect this new structure.
  • Dependency Cleanup: The removal of g++ and av from dependencies should be double-checked to ensure no hidden dependencies rely on them.

Comment thread src/maxdiffusion/tests/generate_wan_smoke_test.py
Comment thread src/maxdiffusion/common_types.py Outdated
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Comment thread maxdiffusion_dependencies.Dockerfile Outdated
Comment thread src/maxdiffusion/pyconfig.py Outdated
@csgoogle csgoogle force-pushed the wan-ulysses-bshd-attention branch from e88ad77 to 1bd27d2 Compare May 25, 2026 11:49
@csgoogle csgoogle force-pushed the wan-ulysses-bshd-attention branch 2 times, most recently from fce03c6 to 74058f5 Compare May 25, 2026 11:58
@csgoogle csgoogle requested a review from mbohlool May 25, 2026 12:01
@github-actions
Copy link
Copy Markdown

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This Pull Request introduces a hybrid ulysses_ring attention mode for WAN models, allowing for 2D context parallelism by combining Ulysses all-to-all head sharding with ring-based KV rotation. The implementation is technically sound, particularly in how it handles internal mesh reshaping and kernel routing.

🔍 General Feedback

  • Consistency: The new ulysses_ring attention mode is correctly registered and integrated into the AttentionBlock routing logic.
  • VACE Support: The VACE components in transformer_wan_vace.py are currently missing the plumbing for ulysses_shards, which will prevent them from using the new mode if requested.
  • Axis Rules: There is a minor inconsistency in the CROSS_ATTN_KV_LENGTH axis rules for the new mode compared to existing sequence-parallel modes.
  • Padding & Masking: There are some subtle potential issues with attention_mask slicing and segment_ids rotation inside shard_map that could affect correctness in specific edge cases or with custom attention masks.

Comment thread src/maxdiffusion/common_types.py Outdated
Comment thread src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py
Comment thread src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py
Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/models/attention_flax.py
@github-actions
Copy link
Copy Markdown

🤖 I'm sorry @csgoogle, but I was unable to process your request. Please see the logs for more details.

@csgoogle csgoogle force-pushed the wan-ulysses-bshd-attention branch from 74058f5 to 163e43c Compare May 26, 2026 18:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant