Skip to content

[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism#2829

Open
sudhakarsingh27 wants to merge 76 commits into
NVIDIA:mainfrom
sudhakarsingh27:cp_thd_swa_with_ag
Open

[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism#2829
sudhakarsingh27 wants to merge 76 commits into
NVIDIA:mainfrom
sudhakarsingh27:cp_thd_swa_with_ag

Conversation

@sudhakarsingh27

@sudhakarsingh27 sudhakarsingh27 commented Apr 3, 2026

Copy link
Copy Markdown
Member

Description

Add THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather. Previously, AllGather-based CP only supported fixed-length formats (bshd/sbhd). THD format packs variable-length sequences into a single [t, h, d] tensor tracked by cu_seqlens, which is needed for workloads with heterogeneous sequence lengths.

The key challenge is that AllGather CP splits Q across 2 steps (one per local chunk), but THD tensors cannot be naively sliced like fixed-length formats. This PR uses an offset-based approach: the full Q tensor is passed to the attention kernel each step, with per-step cu_seqlens_q_padded values directing the kernel to read the correct chunk. This avoids tensor slicing and follows the padded THD convention used by the backends.

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • Offset-based Q chunking: Per-step cu_seqlens_q_padded selects which chunk the kernel reads from the full Q tensor, instead of slicing Q per step.
  • Per-step KV sequence metadata: Computes visible KV token counts, per-step cu_seqlens_kv, and window ranges for causal, full/no-window, and SWA cases.
  • THD AllGather helper kernels: Adds sync-free THD reorder / valid-copy helper kernels for the AllGather CP THD path.
  • FlashAttention v3 AllGather THD support: Enables FlashAttention v3 AG+THD coverage, including padded THD cases through pad_between_seqs and seqused_k handling.
  • max_logit masking fix: Handles non-zero-starting cu_seqlens_q_padded in the valid-token mask without .item() D2H synchronizations.
  • Tests: Adds/extends CP THD coverage for FusedAttention, FlashAttention v3, AllGather, SWA/full/causal masks, padded Flash THD, and helper-kernel unit coverage in test_cp_utils.py.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Latest validation update (2026-06-11)

CP THD AllGather coverage was validated across FusedAttention and FlashAttention v3.

Backend Coverage Pad Result CI/local test
Flash v3 AG+THD causal False/True PASS CI 54405200 H100 L3 FA3 CP; local focused H100 run
Flash v3 AG+THD full/no-window False/True PASS Local focused H100 run
Flash v3 AG+THD causal SWA (128,0) False/True PASS CI 54405200 H100 L3 FA3 CP; local focused H100 run
Fused AG+THD causal n/a PASS CI 54405200 H100 L1 CP; local focused H100 run
Fused AG+THD full/no-window n/a PASS Local focused H100 run
Fused AG+THD causal SWA (128,0) n/a PASS CI 54405200 H100 L1 CP; local focused H100 run
Fused AG+THD causal SWA (512,512) n/a PASS Local focused H100 run

test_cp_utils.py also passed 14/14 on both H100 and B200 L1 in CI pipeline 54405200. It covers the THD helper kernels, including reorder/copy-valid tests against the legacy Python reference paths.

FusedAttention does not expose the FlashAttention-specific pad_between_seqs axis; Fused THD padding semantics are covered through the THD/padding mask path.

Helper-kernel microbenchmarks, single H100, bf16, batch=16, seqlen=4096, heads=16, dim=64:

Kernel path cp sizes Legacy wall ms New kernel wall ms Speedup
thd_cp_reorder_sequences contiguous->rank 2, 4, 8 11.6829-37.4912 0.0954-0.1090 122.42x-346.08x
thd_cp_reorder_sequences rank->contiguous 2, 4, 8 11.0553-19.2267 0.0966-0.1094 114.43x-177.87x
thd_cp_copy_valid_tokens 2, 4, 8 0.7070-9.4218 0.0911-0.1052 7.76x-89.59x

Note: the helper-kernel numbers are fixed-shape microbenchmarks, not a sweep over sequence length or number of sequences.

… cu_seqlens

- Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing
- Use padded cu_seqlens_kv for K/V reordering (ensures divisibility)
- Add cu_seqlens_kv and cu_seqlens_kv_padded to AllGather function signature
- Compute per-step Q and KV cu_seqlens correctly from actual seqlens
- Support non-causal attention (all KV visible)
- Zero-initialize out/dq for THD to avoid garbage in padding regions
- Save per-step cu_seqlens in ctx for backward (avoid recomputation)

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Remove skip gates that blocked THD format with all_gather CP comm type.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…seqlens_q_padded

The interleaved valid mask computation assumed cu_seqlens_q_padded starts
at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at
a non-zero offset, causing a size mismatch. Use absolute positions from
cu_seqlens_q_padded to build the valid mask instead.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
if qkv_format == "thd":
# [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d]
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
k_ag = reorder_seq_chunks_after_a2a_before_attn_thd(

@sudhakarsingh27 sudhakarsingh27 Apr 3, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reorder_seq_chunks_after_a2a_before_attn_thd and the other releated method are not "a2a" specific now, rename them to something like dualchunk_to_contiguous_order_thd and the other one contiguous_to_dualchunk_order_thd

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved on the current branch, with final cleanup in 0e926c42. The THD reorder entry points are now reorder_thd_sequences_to_rank_sharded and reorder_thd_sequences_to_contiguous, and the stale Python permutation helpers were removed. Both wrappers call the fused tex.thd_reorder path, so this logic is no longer A2A-named or A2A-specific.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still resolved at PR head af2bd1c3. The Python wrappers are not A2A-specific (reorder_thd_sequences_to_contiguous / reorder_thd_sequences_to_rank_sharded) and now call the renamed fused binding tex.thd_cp_reorder_sequences.

Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
@sudhakarsingh27 sudhakarsingh27 changed the title Cp thd swa with ag [PyTorch][CP] Add THD format support for AllGather-based Context Parallelism Apr 13, 2026
@sudhakarsingh27 sudhakarsingh27 marked this pull request as ready for review April 13, 2026 21:53
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
@greptile-apps

greptile-apps Bot commented Apr 13, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather. It introduces an offset-based approach — passing the full Q tensor each step with per-step cu_seqlens_q_padded offsets directing the kernel to the correct chunk — along with new CUDA helper kernels for THD reordering and valid-token copying.

  • New CUDA kernels: thd_reorder_between_sequence_and_cp_rank_order_kernel (fused sequence↔CP-rank reorder) and thd_copy_valid_tokens_from_per_split_to_rank_local_kernel (valid-token scatter into accumulator), exposed through pybind as tex.thd_sequence_order_to_cp_rank_order, tex.thd_cp_rank_order_to_sequence_order, and tex.thd_copy_valid_tokens_from_per_split_to_rank_local.
  • Per-step THD metadata: Pre-computes thd_cu_seqlens_q_per_step, thd_cu_seqlens_q_padded_per_step, and thd_cu_seqlens_kv_per_step for both forward and backward, covering causal, full-window, and SWA cases.
  • FA3 + THD AllGather path: Uses seqused_q/seqused_k alongside padded cu_seqlens to separate tensor offsets from visibility limits; FA2 varlen cannot represent this split so it is test-skipped when FA3 is absent.

Confidence Score: 4/5

Safe to merge for FusedAttention and FA3 paths; FA2 fallback for THD+AllGather silently produces wrong output when cuDNN and FA3 are both unavailable.

The FusedAttention and FA3 forward/backward paths have been thoroughly validated. CUDA kernels are correct, stream synchronizations are in place, and per-step cu_seqlens accounting is sound. The one outstanding gap is that a user whose environment falls back to FA2 will silently compute wrong attention outputs for THD+AllGather — the test suite skips this combination but there is no runtime assertion blocking it.

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py — specifically the non-fused, non-FA3 branch of the AllGather forward/backward for THD format.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Core logic change: removes THD assert for AllGather, adds offset-based per-step Q chunking, new forward/backward THD metadata pre-computation, and delegates reorder/copy to new CUDA kernels. FA2 fallback path for THD is silently broken.
transformer_engine/common/fused_attn/context_parallel.cu Adds thd_reorder and thd_copy_valid_tokens CUDA kernels with warp-level parallelism, shared memory for cu_seqlens, float4 loads, and binary_search for sequence identification. NVTE_CHECK enforces 128-bit alignment.
transformer_engine/pytorch/cpp_extensions/fused_attn.py Replaces D2H .item() loop with GPU-only scatter_add_ delta approach for valid-token mask construction, eliminating device synchronization.
transformer_engine/pytorch/csrc/extensions/attention.cpp Adds C++ pybind wrappers for the three new CUDA kernel entry points with proper NVTE_CHECK input validation.
transformer_engine/pytorch/attention/dot_product_attention/utils.py Removes 'all_gather' from the THD FusedAttention disablement list, enabling the new code path.
tests/pytorch/attention/test_cp_utils.py New unit tests for the CUDA reorder and copy kernels covering sequence-order↔CP-rank-order round-trips and valid-token copy correctness.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant Forward as AttnFuncWithCPAndKVAllGather.forward
    participant AG as gather_along_first_dim
    participant Reorder as thd_cp_rank_order_to_sequence_order
    participant Attn0 as Attn Step 0 (current_stream)
    participant Attn1 as Attn Step 1 (cp_stream)
    participant Copy as thd_copy_valid_tokens

    Caller->>Forward: "q[t_q,h,d], k[t_k,h,d], v[t_k,h,d], cu_seqlens_*, cu_seqlens_*_padded"
    Forward->>Forward: Clone cu_seqlens_q_original, cu_seqlens_kv_original
    Forward->>Forward: "Divide cu_seqlens_q_padded, max_seqlen by 2*cp_size"
    Forward->>AG: "k, v → k_ag[cp*t_k,h,d], v_ag[cp*t_k,h,d]"
    AG-->>Forward: k_ag, v_ag (gathered)
    Forward->>Reorder: k_ag, v_ag + cu_seqlens_kv_padded (global)
    Reorder-->>Forward: k_ag, v_ag in sequence order
    Forward->>Forward: cp_stream.wait_stream(current_stream)
    Forward->>Forward: Pre-compute per-step cu_seqlens_q/kv, cu_seqlens_q_padded
    Forward->>Attn0: "q_part=q, k_ag, v_ag, thd_cu_seqlens_q_padded_per_step[0], seqused_q/k"
    Forward->>Attn1: "q_part=q, k_ag, v_ag, thd_cu_seqlens_q_padded_per_step[1], seqused_q/k"
    Attn0-->>Forward: out_per_step[0], softmax_lse[0]
    Attn1-->>Forward: out_per_step[1], softmax_lse[1]
    Forward->>Copy: thd_copy_valid_tokens(out, out_per_step[0], padded_step0, cu_seqlens_q_step0)
    Forward->>Copy: thd_copy_valid_tokens(out, out_per_step[1], padded_step1, cu_seqlens_q_step1)
    Copy-->>Forward: out[t_q,h,d] filled at valid positions
    Forward-->>Caller: out (combined attention output)
Loading

Reviews (19): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
sudhakarsingh27 and others added 2 commits April 16, 2026 11:28
…ific helpers

The AllGather THD path was not extending KV visibility beyond the causal
boundary when window_size had a right component > 0, meaning tokens right
of the diagonal were invisible to the kernel. Fix by adding window_size[1]
to visible_padded (clamped at actual seqlen) and max_seqlen_kv_.

Also rename reorder helpers to backend-neutral names since AllGather now
uses them too, and add a clarifying comment for non-causal KV cu_seqlens.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Comment thread tests/pytorch/attention/test_attention_with_cp.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread tests/pytorch/attention/test_attention_with_cp.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
Comment on lines +3347 to +3350
actual_seqlens_kv = cu_seqlens_kv_original[1:] - cu_seqlens_kv_original[:-1]
padded_chunk_sizes_kv = (cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]) // (
2 * cp_size
)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in the pending local patch. visible_padded and visible_actual are now computed inside the if causal or sliding_window_attn: block where they are used.

…leanup

# Conflicts:
#	transformer_engine/pytorch/csrc/extensions/pybind.cpp
Remove redundant runtime checks and stale local-experiment wording from the THD AllGather path because the PR review requested keeping this path focused on existing support gates and code behavior. Keep the non-essential CP matrix opt-in via NVTE_TEST_ESSENTIAL so offline validation does not require source edits.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Comment thread tests/pytorch/attention/test_attention_with_cp.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
flash_attn_streams[i - 1].record_event(dkv_update_done)
if ctx.qkv_format == "thd":
# dQ: copy every segment's valid token range from this step's dQ.
tex.thd_valid_copy(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: would thd_copy_valid_tokens be a more descriptive name?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed in 27d2b84a, with pre-commit follow-up af2bd1c3: renamed the Python-visible/C++ wrapper to thd_cp_copy_valid_tokens.

Comment thread transformer_engine/pytorch/cpp_extensions/fused_attn.py
return out;
}

void thd_valid_copy(at::Tensor out, const at::Tensor &inp, const at::Tensor &cu_seqlens_padded,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should cp_ be in the names as well? i.e. thd_cp_reorder_sequences/cp_thd_reorder_sequences or thd_cp_copy_valid_tokens/cp_thd_copy_valid_tokens? (Just a suggestion)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed in 27d2b84a, with pre-commit follow-up af2bd1c3: the Python-visible/C++ wrapper names are now thd_cp_reorder_sequences and thd_cp_copy_valid_tokens. The underlying C API names remain nvte_cp_thd_*.

"Fused dual-chunk THD reorder for context parallel (gather/scatter), inline index",
py::call_guard<py::gil_scoped_release>());
m.def("thd_valid_copy", &transformer_engine::pytorch::thd_valid_copy,
"Sync-free copy of valid THD token rows into an accumulator (CP AllGather fwd/bwd)",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's "token rows" here? Sequences? Ranks?

Also, what's "inline index" in thd_reorder? :)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed in 27d2b84a, with pre-commit follow-up af2bd1c3: cleaned up the pybind docstrings. They no longer use “token rows” or “inline index”; they now describe reordering between CP rank-sharded dual-chunk order and contiguous per-sequence order, and copying valid THD sequence entries from a padded tensor.

Comment thread transformer_engine/pytorch/csrc/extensions/pybind.cpp Outdated
@cyanguwa

cyanguwa commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

I feel the code changes are complex enough that I have to rely on the tests. Please make sure the tests cover both Fused and Flash v3 backends, SWA/full attention/causal mask, pad_between_seqs=T/F combinations, and post the performance numbers if there's any (for the new reorder/valid_copy kernels vs the old PyTorch-based functions). Thanks!

@cyanguwa

cyanguwa commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

Also, please fix the DCO, and make the comments a bit more succinct if you can (some comments are quite thorough but also very long :) ). Thanks!

sudhakarsingh27 and others added 3 commits June 10, 2026 13:36
Resolve review comments for PR 2829 by tightening the THD all_gather output shape, renaming the new THD CP helper bindings, removing the unrelated pybind helper extraction, and aligning the FP8 t3hd aux handling with the post-FP8DS code path.

Also clean up the THD CP test skip logic and remove an unnecessary dtype conversion from the THD max-logit mask construction.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The AllGather CP path already has a support assert that rejects padding masks for non-THD inputs. Keep the earlier THD-specific padding requirement, and rely on the later AllGather support assert for the non-THD padding case so the check is not duplicated.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
]
# Adjust chunks for each step
thd_cu_seqlens_kv_per_step[0][1:] = visible_actual[0].cumsum(0)
thd_cu_seqlens_kv_per_step[1][1:] = visible_actual[1].cumsum(0)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L3286-L3370 could be cudafied if this becomes a performance bottleneck.

@sudhakarsingh27

Copy link
Copy Markdown
Member Author

/te-ci pytorch L3

sudhakarsingh27 and others added 7 commits June 11, 2026 12:42
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…leanup

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…codex/pr2829-review-cleanup

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Rename the new THD CP reorder helpers so the source and destination layouts are encoded in the API names instead of a direction boolean. Also rename the valid-token copy helper to describe its per-split to rank-local accumulator role.

Guard copy-valid tokens that precede the first padded THD offset before indexing shared cu_seqlens arrays; later split offsets can legitimately leave those token positions outside any valid sequence range.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
(cherry picked from commit 55f9f18747510d92791ef2650c65d93d9d90c27c)
…view-cleanup

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27

Copy link
Copy Markdown
Member Author

/te-ci pytorch L3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants