Add all-gather + matmul ring tutorial#124
Open
yongweiy wants to merge 1 commit into
Open
Conversation
e4e1b8c to
cac3576
Compare
cac3576 to
2e40c58
Compare
Demonstrates a fused all-gather + matmul along a ring of ranks, using
nki.collectives.collective_permute_implicit (CPI) to overlap
communication with compute. Each step, a rank computes a local matmul
on the LHS fragment in its ring buffer, then passes the fragment to the
next rank while receiving the previous rank's fragment. The scheduler
places the matmul of step (i) and the CPI of step (i) on disjoint
engines so they run concurrently.
Sized for world_size=8, LNC=1 on a single trn2 device (8 physical cores
forming the ring). The kernel has no LNC=2-specific code path, so LNC=1
is the canonical config; M_LOCAL=128 matches the Tensor Engine
partition-dim limit of a single core, so each rank fully utilizes its
core. Validated end-to-end: all 8 ranks PASS at rel_err < 0.01
(observed 0.0006-0.0025).
Run command:
NEURON_CC_FLAGS="--lnc=1" \
NEURON_LOGICAL_NC_CONFIG=1 \
NEURONCORE_NUM_DEVICES=8 \
python allgather_matmul_ring_torch.py
2e40c58 to
a89c931
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a new tutorial under
src/nki_samples/tutorials/allgather_matmul_ring/demonstrating a fused all-gather + matmul along a ring of ranks, usingnki.collectives.collective_permute_implicit(CPI) to overlap communication with compute.Each ring step, a rank computes a local matmul against the LHS fragment currently in its ring buffer, then passes the fragment on to the next rank and receives the previous rank's fragment — the scheduler places the matmul of step i and the CPI of step i on disjoint engines so they run concurrently. The matmul is row-parallel (LHS row-sharded, RHS column-sharded). After
RANK_Nring steps every rank has computed one(M_LOCAL, N_LOCAL)slot of the fully-gathered output for every source rank.Follows the existing tutorial layout:
allgather_matmul_ring_nki_kernels.py— NKI kernel (guarded withNKI_EXAMPLE_AGMM_RING_*markers). Module-level constants (RANK_N,M_LOCAL,N_LOCAL,K,K_TILE,N_TILE) keep the kernel signature small:allgather_matmul_ring(lhs_shard, rhs_shard, replica_group).allgather_matmul_ring_torch.py— PyTorch/XLA runner:xmp.spawnacross ranks, builds deterministic LHS/RHS shards, validates each rank's output against a reference matmul computed on the host.Sized for
world_size=8,LNC=1on a single trn2 device — the kernel has no LNC=2-specific code path, so LNC=1 is the canonical config (M_LOCAL=128 matches the Tensor Engine partition-dim limit of a single physical core, so each rank fully utilizes its core).Test plan
rel_errin0.0006-0.0025(threshold< 0.01).Other
world_size × LNCcombinations may fail if the replica group does not map to a valid CPI ring topology on the hardware; documented in the tutorial's module docstring.