
Add all-gather + matmul ring tutorial #124

Open
yongweiy wants to merge 1 commit into aws-neuron:main from yongweiy:allgather-matmul-ring


@yongweiy yongweiy commented May 8, 2026


Summary

Adds a new tutorial under src/nki_samples/tutorials/allgather_matmul_ring/ demonstrating a fused all-gather + matmul along a ring of ranks, using nki.collectives.collective_permute_implicit (CPI) to overlap communication with compute.

At each ring step, a rank computes a local matmul against the LHS fragment currently in its ring buffer, then passes that fragment to the next rank and receives the previous rank's fragment. The scheduler places the matmul of step i and the CPI of step i on disjoint engines, so they run concurrently. The matmul is row-parallel (LHS row-sharded, RHS column-sharded). After RANK_N ring steps, every rank has computed one (M_LOCAL, N_LOCAL) slot of the fully gathered output for each source rank.
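
The ring schedule described above can be sketched on the host in plain Python, with no Neuron APIs involved. This is only an illustration of the data movement, not the kernel in this PR: the sizes are tiny and hypothetical, the helpers `matmul`, `make_shard`, and `ring_allgather_matmul` are made up for the example, and the send direction (rank r to rank r + 1) is an assumption. On real hardware the matmul and the CPI of a step overlap; here they simply run one after the other.

```python
# Host-side simulation of a ring all-gather + matmul (no Neuron APIs involved).
# Sizes are tiny and hypothetical; the tutorial itself uses M_LOCAL=128.
RANK_N, M_LOCAL, K, N_LOCAL = 4, 2, 3, 2


def matmul(a, b):
    """Plain (rows x inner) @ (inner x cols) product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def make_shard(rows, cols, seed):
    """Deterministic small-integer shard so results compare exactly."""
    return [[(seed * 7 + r * cols + c) % 5 - 2 for c in range(cols)]
            for r in range(rows)]


lhs_shards = [make_shard(M_LOCAL, K, rank) for rank in range(RANK_N)]
rhs_shards = [make_shard(K, N_LOCAL, rank + 100) for rank in range(RANK_N)]


def ring_allgather_matmul(lhs_shards, rhs_shards):
    n = len(lhs_shards)
    # ring_buf[r] is the LHS fragment currently held by rank r.
    ring_buf = list(lhs_shards)
    # out[r][src] is the (M_LOCAL, N_LOCAL) slot on rank r for source rank src.
    out = [[None] * n for _ in range(n)]
    for step in range(n):
        for rank in range(n):
            # Assumed direction: fragments travel rank -> rank + 1, so after
            # `step` hops rank r holds the fragment that started on r - step.
            src = (rank - step) % n
            out[rank][src] = matmul(ring_buf[rank], rhs_shards[rank])
        # Ring hop: every rank receives the buffer of its predecessor.
        # (The hop after the final step is redundant and kept for simplicity.)
        ring_buf = [ring_buf[(rank - 1) % n] for rank in range(n)]
    return out


out = ring_allgather_matmul(lhs_shards, rhs_shards)
```

After `RANK_N` steps, `out[rank][src]` equals the product of source rank `src`'s LHS shard with `rank`'s own RHS shard, which is the same result a rank would get by first all-gathering every LHS shard and then multiplying.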

Follows the existing tutorial layout:

  • allgather_matmul_ring_nki_kernels.py — NKI kernel (guarded with NKI_EXAMPLE_AGMM_RING_* markers). Module-level constants (RANK_N, M_LOCAL, N_LOCAL, K, K_TILE, N_TILE) keep the kernel signature small: allgather_matmul_ring(lhs_shard, rhs_shard, replica_group).
  • allgather_matmul_ring_torch.py — PyTorch/XLA runner: xmp.spawn across ranks, builds deterministic LHS/RHS shards, validates each rank's output against a reference matmul computed on the host.

Sized for world_size=8, LNC=1 on a single trn2 device. The kernel has no LNC=2-specific code path, so LNC=1 is the canonical config. M_LOCAL=128 matches the Tensor Engine partition-dim limit of a single physical core, so each rank fully utilizes its core.

Test plan

  • trn2, world_size=8, LNC=1 (single device, 8 cores):
    NEURON_CC_FLAGS="--lnc=1" \
    NEURON_LOGICAL_NC_CONFIG=1 \
    NEURONCORE_NUM_DEVICES=8 \
      python allgather_matmul_ring_torch.py
    
    All 8 ranks PASS with rel_err in 0.0006-0.0025 (threshold < 0.01).
  • Profile inspection on NeuronExplorer confirms Tensor-engine matmul and CPI overlap on disjoint engines per ring step.
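
For readers who want to reproduce the pass/fail check, a relative-error comparison against a host reference might look like the sketch below. The PR does not state how `rel_err` is defined, so the Frobenius-norm ratio used here is an assumption, and `rel_err`, `THRESHOLD`, and the sample matrices are illustrative names and values rather than code from the runner.

```python
import math


def rel_err(actual, expected):
    """Frobenius-norm relative error: ||actual - expected|| / ||expected||.

    Assumes `expected` is not all zeros; a real runner would guard that case.
    """
    diff_sq = sum((a - e) ** 2
                  for row_a, row_e in zip(actual, expected)
                  for a, e in zip(row_a, row_e))
    ref_sq = sum(e ** 2 for row in expected for e in row)
    return math.sqrt(diff_sq) / math.sqrt(ref_sq)


THRESHOLD = 0.01  # the pass threshold quoted in the test plan

# Made-up sample data standing in for one rank's output and its reference.
expected = [[1.0, 2.0], [3.0, 4.0]]
actual = [[1.001, 2.0], [3.0, 3.998]]

err = rel_err(actual, expected)
passed = err < THRESHOLD
```

A norm-based ratio is less sensitive to individual near-zero reference entries than an elementwise relative error, which is one reason it is a common choice for reduced-precision matmul checks.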

Other world_size × LNC combinations may fail if the replica group does not map to a valid CPI ring topology on the hardware; this limitation is documented in the tutorial's module docstring.
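
As a rough illustration of what a logical ring over a replica group means, the sketch below builds the (source, target) pairs for a ring and checks that they form a single cycle. It assumes the replica group is a flat list of integer rank IDs, which this PR does not show, and `ring_pairs` and `is_single_ring` are hypothetical helpers. The check covers only the logical shape of the ring; it says nothing about whether the hardware can route it.

```python
def ring_pairs(replica_group):
    """(source, target) pairs sending each rank's fragment to its successor."""
    n = len(replica_group)
    return [(replica_group[i], replica_group[(i + 1) % n]) for i in range(n)]


def is_single_ring(pairs):
    """True if the pairs form one cycle that visits every rank exactly once."""
    if not pairs:
        return False
    nxt = dict(pairs)
    # Every rank must send exactly once and receive exactly once.
    if len(nxt) != len(pairs) or set(nxt) != set(nxt.values()):
        return False
    start = pairs[0][0]
    seen, cur = set(), start
    while cur not in seen:
        seen.add(cur)
        cur = nxt[cur]
    # A single ring returns to the start only after covering all ranks.
    return cur == start and len(seen) == len(pairs)


pairs = ring_pairs(list(range(8)))        # world_size=8, as in the test plan
two_small_rings = [(0, 1), (1, 0), (2, 3), (3, 2)]
```

Here `is_single_ring(pairs)` holds for the 8-rank ring, while `two_small_rings` is rejected because its fragments would never reach every rank.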

@yongweiy yongweiy force-pushed the allgather-matmul-ring branch from e4e1b8c to cac3576 Compare June 18, 2026 17:03
@yongweiy yongweiy force-pushed the allgather-matmul-ring branch from cac3576 to 2e40c58 Compare June 18, 2026 17:11
@yongweiy yongweiy requested a review from liralon June 18, 2026 17:11
Demonstrates a fused all-gather + matmul along a ring of ranks, using
nki.collectives.collective_permute_implicit (CPI) to overlap
communication with compute. Each step, a rank computes a local matmul
on the LHS fragment in its ring buffer, then passes the fragment to the
next rank while receiving the previous rank's fragment. The scheduler
places the matmul of step (i) and the CPI of step (i) on disjoint
engines so they run concurrently.

Sized for world_size=8, LNC=1 on a single trn2 device (8 physical cores
forming the ring). The kernel has no LNC=2-specific code path, so LNC=1
is the canonical config; M_LOCAL=128 matches the Tensor Engine
partition-dim limit of a single core, so each rank fully utilizes its
core. Validated end-to-end: all 8 ranks PASS at rel_err < 0.01
(observed 0.0006-0.0025).

Run command:

  NEURON_CC_FLAGS="--lnc=1" \
  NEURON_LOGICAL_NC_CONFIG=1 \
  NEURONCORE_NUM_DEVICES=8 \
    python allgather_matmul_ring_torch.py
@yongweiy yongweiy force-pushed the allgather-matmul-ring branch from 2e40c58 to a89c931 Compare June 18, 2026 17:15