[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism #2829
sudhakarsingh27 wants to merge 76 commits into
… cu_seqlens
- Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing
- Use padded cu_seqlens_kv for K/V reordering (ensures divisibility)
- Add cu_seqlens_kv and cu_seqlens_kv_padded to AllGather function signature
- Compute per-step Q and KV cu_seqlens correctly from actual seqlens
- Support non-causal attention (all KV visible)
- Zero-initialize out/dq for THD to avoid garbage in padding regions
- Save per-step cu_seqlens in ctx for backward (avoid recomputation)
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
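The first bullet (selecting Q chunks through per-step offsets rather than slicing) can be illustrated with a minimal sketch. It assumes each rank packs two equal chunks per sequence back to back; the helper names `per_step_q_offsets` and `read_chunks` are hypothetical and are not taken from the PR.

```python
from itertools import accumulate


def per_step_q_offsets(padded_lens_local, step):
    """Return (start, length) of each sequence's chunk for one step.

    padded_lens_local: padded length of each sequence on this rank.
    Assumption: every length is even, because the rank holds two chunks
    per sequence and step 0 / step 1 address the first / second chunk.
    """
    starts = [0] + list(accumulate(padded_lens_local))[:-1]
    chunks = [n // 2 for n in padded_lens_local]
    return [(s + step * c, c) for s, c in zip(starts, chunks)]


def read_chunks(q, offsets):
    """Emulate a kernel that reads each segment at its offset in the full Q."""
    return [q[start:start + length] for start, length in offsets]


# Two sequences with local padded lengths 4 and 8, packed as 12 tokens.
q = list(range(12))
step0 = read_chunks(q, per_step_q_offsets([4, 8], step=0))
step1 = read_chunks(q, per_step_q_offsets([4, 8], step=1))
```

The full `q` is handed to both steps unchanged; only the offsets differ, which is what lets a packed layout avoid per-step slicing.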
Remove skip gates that blocked THD format with all_gather CP comm type. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…seqlens_q_padded The interleaved valid mask computation assumed cu_seqlens_q_padded starts at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at a non-zero offset, causing a size mismatch. Use absolute positions from cu_seqlens_q_padded to build the valid mask instead. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
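A minimal sketch of building a valid-token mask from absolute positions, so that a first segment starting at a non-zero offset is handled. The function name `valid_token_mask` and the plain-list representation are assumptions made for illustration; a tensor implementation would vectorize this rather than loop.

```python
def valid_token_mask(total_tokens, seg_starts, seg_lens):
    """Mark token positions covered by any [start, start + len) segment.

    seg_starts are absolute positions in the packed tensor, so the mask
    always has total_tokens entries regardless of where segment 0 begins.
    """
    mask = [False] * total_tokens
    for start, length in zip(seg_starts, seg_lens):
        for pos in range(start, min(start + length, total_tokens)):
            mask[pos] = True
    return mask


# Step-1 style offsets: the first segment starts at 2, not 0.
mask = valid_token_mask(12, seg_starts=[2, 8], seg_lens=[1, 3])
valid_positions = [i for i, is_valid in enumerate(mask) if is_valid]
```

Because positions are absolute, the mask length is fixed by the tensor size and not by the difference between the last and first offsets, which is the mismatch the commit message describes.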
1164a15 to b4db9eb
for more information, see https://pre-commit.ci
if qkv_format == "thd":
    # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d]
    chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
    k_ag = reorder_seq_chunks_after_a2a_before_attn_thd(
This reorder_seq_chunks_after_a2a_before_attn_thd and the other related method are not "a2a"-specific now; rename them to something like dualchunk_to_contiguous_order_thd and contiguous_to_dualchunk_order_thd.
Resolved on the current branch, with final cleanup in 0e926c42. The THD reorder entry points are now reorder_thd_sequences_to_rank_sharded and reorder_thd_sequences_to_contiguous, and the stale Python permutation helpers were removed. Both wrappers call the fused tex.thd_reorder path, so this logic is no longer A2A-named or A2A-specific.
Still resolved at PR head af2bd1c3. The Python wrappers are not A2A-specific (reorder_thd_sequences_to_contiguous / reorder_thd_sequences_to_rank_sharded) and now call the renamed fused binding tex.thd_cp_reorder_sequences.
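For readers unfamiliar with the two layouts named in this thread, here is a minimal pure-Python sketch of the reordering. It assumes the load-balanced dual-chunk scheme in which each padded sequence is split into 2*cp_size chunks and rank r holds chunks r and 2*cp_size-1-r; the function names and list-based representation are hypothetical and do not reproduce the fused kernel.

```python
def contiguous_to_rank_sharded(tokens, padded_lens, cp_size):
    """Contiguous per-sequence order -> concatenation of rank-local shards."""
    n_chunks = 2 * cp_size
    out = []
    for rank in range(cp_size):
        seq_start = 0
        for padded_len in padded_lens:
            chunk = padded_len // n_chunks  # assumes divisibility by 2*cp_size
            for chunk_id in (rank, n_chunks - 1 - rank):
                begin = seq_start + chunk_id * chunk
                out.extend(tokens[begin:begin + chunk])
            seq_start += padded_len
    return out


def rank_sharded_to_contiguous(gathered, padded_lens, cp_size):
    """Inverse mapping, e.g. applied to K/V after an all-gather."""
    n_chunks = 2 * cp_size
    local_total = sum(padded_lens) // cp_size  # tokens held by each rank
    out = []
    for seq_idx, padded_len in enumerate(padded_lens):
        chunk = padded_len // n_chunks
        local_start = sum(padded_lens[:seq_idx]) // cp_size
        for chunk_id in range(n_chunks):
            # Chunk c sits in slot 0 of rank c, or slot 1 of rank 2*cp-1-c.
            if chunk_id < cp_size:
                rank, slot = chunk_id, 0
            else:
                rank, slot = n_chunks - 1 - chunk_id, 1
            begin = rank * local_total + local_start + slot * chunk
            out.extend(gathered[begin:begin + chunk])
    return out


tokens = list(range(12))  # two sequences with padded lengths 4 and 8
sharded = contiguous_to_rank_sharded(tokens, [4, 8], cp_size=2)
restored = rank_sharded_to_contiguous(sharded, [4, 8], cp_size=2)
```

The two functions are inverses of each other, which is a convenient property to check when comparing a fused kernel against a reference path.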
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…formerEngine into cp_thd_swa_with_ag
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…formerEngine into cp_thd_swa_with_ag
Greptile Summary
This PR adds THD (variable-length sequence) format support to

Confidence Score: 4/5
Safe to merge for FusedAttention and FA3 paths; the FA2 fallback for THD+AllGather silently produces wrong output when cuDNN and FA3 are both unavailable.

The FusedAttention and FA3 forward/backward paths have been thoroughly validated. CUDA kernels are correct, stream synchronizations are in place, and per-step cu_seqlens accounting is sound. The one outstanding gap is that a user whose environment falls back to FA2 will silently compute wrong attention outputs for THD+AllGather; the test suite skips this combination, but there is no runtime assertion blocking it.

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: specifically the non-fused, non-FA3 branch of the AllGather forward/backward for THD format.

Important Files Changed

Sequence Diagram
sequenceDiagram
participant Caller
participant Forward as AttnFuncWithCPAndKVAllGather.forward
participant AG as gather_along_first_dim
participant Reorder as thd_cp_rank_order_to_sequence_order
participant Attn0 as Attn Step 0 (current_stream)
participant Attn1 as Attn Step 1 (cp_stream)
participant Copy as thd_copy_valid_tokens
Caller->>Forward: "q[t_q,h,d], k[t_k,h,d], v[t_k,h,d], cu_seqlens_*, cu_seqlens_*_padded"
Forward->>Forward: Clone cu_seqlens_q_original, cu_seqlens_kv_original
Forward->>Forward: "Divide cu_seqlens_q_padded, max_seqlen by 2*cp_size"
Forward->>AG: "k, v → k_ag[cp*t_k,h,d], v_ag[cp*t_k,h,d]"
AG-->>Forward: k_ag, v_ag (gathered)
Forward->>Reorder: k_ag, v_ag + cu_seqlens_kv_padded (global)
Reorder-->>Forward: k_ag, v_ag in sequence order
Forward->>Forward: cp_stream.wait_stream(current_stream)
Forward->>Forward: Pre-compute per-step cu_seqlens_q/kv, cu_seqlens_q_padded
Forward->>Attn0: "q_part=q, k_ag, v_ag, thd_cu_seqlens_q_padded_per_step[0], seqused_q/k"
Forward->>Attn1: "q_part=q, k_ag, v_ag, thd_cu_seqlens_q_padded_per_step[1], seqused_q/k"
Attn0-->>Forward: out_per_step[0], softmax_lse[0]
Attn1-->>Forward: out_per_step[1], softmax_lse[1]
Forward->>Copy: thd_copy_valid_tokens(out, out_per_step[0], padded_step0, cu_seqlens_q_step0)
Forward->>Copy: thd_copy_valid_tokens(out, out_per_step[1], padded_step1, cu_seqlens_q_step1)
Copy-->>Forward: out[t_q,h,d] filled at valid positions
Forward-->>Caller: out (combined attention output)
Reviews (19): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
…ific helpers The AllGather THD path was not extending KV visibility beyond the causal boundary when window_size had a right component > 0, meaning tokens right of the diagonal were invisible to the kernel. Fix by adding window_size[1] to visible_padded (clamped at actual seqlen) and max_seqlen_kv_. Also rename reorder helpers to backend-neutral names since AllGather now uses them too, and add a clarifying comment for non-causal KV cu_seqlens. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
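The visibility fix described in this commit message can be sketched as a single clamp. This is an illustration under stated assumptions, not the PR's code: `visible_kv_len` is a hypothetical helper, chunk ids follow the dual-chunk convention, and only the right window component is modeled (the left bound is ignored here).

```python
def visible_kv_len(chunk_id, chunk_size, actual_len, window_right=0):
    """Number of KV tokens visible to query chunk `chunk_id` of one sequence.

    Causal attention exposes KV up to the end of the query chunk. A right
    window component extends that boundary, clamped at the actual length.
    """
    causal_end = (chunk_id + 1) * chunk_size
    return min(causal_end + window_right, actual_len)


causal_only = visible_kv_len(1, 4, 14)
with_right_window = visible_kv_len(1, 4, 14, window_right=3)
clamped = visible_kv_len(3, 4, 14, window_right=3)
```

With `window_right=0` the result matches the purely causal boundary, so the extension leaves the existing causal case unchanged.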
…formerEngine into cp_thd_swa_with_ag
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
actual_seqlens_kv = cu_seqlens_kv_original[1:] - cu_seqlens_kv_original[:-1]
padded_chunk_sizes_kv = (cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]) // (
    2 * cp_size
)
Addressed in the pending local patch. visible_padded and visible_actual are now computed inside the if causal or sliding_window_attn: block where they are used.
…leanup # Conflicts: # transformer_engine/pytorch/csrc/extensions/pybind.cpp
Remove redundant runtime checks and stale local-experiment wording from the THD AllGather path because the PR review requested keeping this path focused on existing support gates and code behavior. Keep the non-essential CP matrix opt-in via NVTE_TEST_ESSENTIAL so offline validation does not require source edits. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
flash_attn_streams[i - 1].record_event(dkv_update_done)
if ctx.qkv_format == "thd":
    # dQ: copy every segment's valid token range from this step's dQ.
    tex.thd_valid_copy(
Nit: would thd_copy_valid_tokens be a more descriptive name?
Pushed in 27d2b84a, with pre-commit follow-up af2bd1c3: renamed the Python-visible/C++ wrapper to thd_cp_copy_valid_tokens.
  return out;
}

void thd_valid_copy(at::Tensor out, const at::Tensor &inp, const at::Tensor &cu_seqlens_padded,
Should cp_ be in the names as well? i.e. thd_cp_reorder_sequences/cp_thd_reorder_sequences or thd_cp_copy_valid_tokens/cp_thd_copy_valid_tokens? (Just a suggestion)
Pushed in 27d2b84a, with pre-commit follow-up af2bd1c3: the Python-visible/C++ wrapper names are now thd_cp_reorder_sequences and thd_cp_copy_valid_tokens. The underlying C API names remain nvte_cp_thd_*.
      "Fused dual-chunk THD reorder for context parallel (gather/scatter), inline index",
      py::call_guard<py::gil_scoped_release>());
m.def("thd_valid_copy", &transformer_engine::pytorch::thd_valid_copy,
      "Sync-free copy of valid THD token rows into an accumulator (CP AllGather fwd/bwd)",
What's "token rows" here? Sequences? Ranks?
Also, what's "inline index" in thd_reorder? :)
Pushed in 27d2b84a, with pre-commit follow-up af2bd1c3: cleaned up the pybind docstrings. They no longer use “token rows” or “inline index”; they now describe reordering between CP rank-sharded dual-chunk order and contiguous per-sequence order, and copying valid THD sequence entries from a padded tensor.
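As a reading aid for "copying valid THD sequence entries from a padded tensor", here is a minimal reference-style sketch. The argument order mirrors the `(out, inp, cu_seqlens_padded, ...)` prefix visible in the reviewed signature, but the fourth argument and the exact semantics are assumptions, and plain lists stand in for tensors.

```python
def copy_valid_tokens(out, inp, cu_seqlens_padded, cu_seqlens):
    """Copy each sequence's valid tokens from `inp` into `out`, in place.

    cu_seqlens_padded gives where each sequence starts in the padded
    layout; consecutive differences of cu_seqlens give how many tokens
    of it are valid. Padding positions in `out` are left untouched.
    """
    for i in range(len(cu_seqlens) - 1):
        start = cu_seqlens_padded[i]
        valid = cu_seqlens[i + 1] - cu_seqlens[i]
        out[start:start + valid] = inp[start:start + valid]
    return out


accumulator = [0] * 8            # zero-initialized output
step_output = list(range(1, 9))  # padded per-step result
copy_valid_tokens(accumulator, step_output, [0, 4, 8], [0, 3, 5])
```

Starting from a zero-initialized accumulator keeps padding regions at zero, which matches the intent stated elsewhere in this PR of avoiding garbage in padded positions.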
I feel the code changes are complex enough that I have to rely on the tests. Please make sure the tests cover both the Fused and Flash v3 backends, SWA/full attention/causal mask, and pad_between_seqs=T/F combinations, and post the performance numbers if there are any (for the new reorder/valid_copy kernels vs. the old PyTorch-based functions). Thanks!

Also, please fix the DCO, and make the comments a bit more succinct if you can (some comments are quite thorough but also very long :) ). Thanks!
Resolve review comments for PR 2829 by tightening the THD all_gather output shape, renaming the new THD CP helper bindings, removing the unrelated pybind helper extraction, and aligning the FP8 t3hd aux handling with the post-FP8DS code path. Also clean up the THD CP test skip logic and remove an unnecessary dtype conversion from the THD max-logit mask construction. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The AllGather CP path already has a support assert that rejects padding masks for non-THD inputs. Keep the earlier THD-specific padding requirement, and rely on the later AllGather support assert for the non-THD padding case so the check is not duplicated. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
af2bd1c to 46b764e
]
# Adjust chunks for each step
thd_cu_seqlens_kv_per_step[0][1:] = visible_actual[0].cumsum(0)
thd_cu_seqlens_kv_per_step[1][1:] = visible_actual[1].cumsum(0)
L3286-L3370 could be cudafied if this becomes a performance bottleneck.
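The cumulative-sum step in the reviewed snippet turns per-sequence visible lengths into cumulative offsets with a leading zero. A minimal sketch of that conversion, using made-up lengths and a hypothetical helper name:

```python
from itertools import accumulate


def lengths_to_cu_seqlens(lengths):
    """[l0, l1, ...] -> [0, l0, l0 + l1, ...], the usual cu_seqlens form."""
    return [0] + list(accumulate(lengths))


# Hypothetical visible KV lengths for two sequences, one list per step.
visible_actual = [[4, 6], [8, 10]]
cu_seqlens_kv_per_step = [lengths_to_cu_seqlens(v) for v in visible_actual]
```

The snippet under review writes the cumulative sum into `[1:]` of a preallocated array whose first entry is already zero, which has the same effect as prepending the zero here.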
/te-ci pytorch L3
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…leanup Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…codex/pr2829-review-cleanup Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Rename the new THD CP reorder helpers so the source and destination layouts are encoded in the API names instead of a direction boolean. Also rename the valid-token copy helper to describe its per-split to rank-local accumulator role. Guard copy-valid tokens that precede the first padded THD offset before indexing shared cu_seqlens arrays; later split offsets can legitimately leave those token positions outside any valid sequence range. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com> (cherry picked from commit 55f9f18747510d92791ef2650c65d93d9d90c27c)
…view-cleanup Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L3
Description
Add THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather. Previously, AllGather-based CP only supported fixed-length formats (bshd/sbhd). THD format packs variable-length sequences into a single [t, h, d] tensor tracked by cu_seqlens, which is needed for workloads with heterogeneous sequence lengths.

The key challenge is that AllGather CP splits Q across 2 steps (one per local chunk), but THD tensors cannot be naively sliced like fixed-length formats. This PR uses an offset-based approach: the full Q tensor is passed to the attention kernel each step, with per-step cu_seqlens_q_padded values directing the kernel to read the correct chunk. This avoids tensor slicing and follows the padded THD convention used by the backends.

Type of change
Changes
- cu_seqlens_q_padded selects which chunk the kernel reads from the full Q tensor, instead of slicing Q per step.
- … cu_seqlens_kv, and window ranges for causal, full/no-window, and SWA cases.
- pad_between_seqs and seqused_k handling.
- … cu_seqlens_q_padded in the valid-token mask without .item() D2H synchronizations.
- … test_cp_utils.py.

Checklist:
Latest validation update (2026-06-11)
CP THD AllGather coverage was validated across FusedAttention and FlashAttention v3.
[Coverage table not recoverable; surviving cell values: False/True, False/True, (128,0), False/True, (128,0), (512,512)]

test_cp_utils.py also passed 14/14 on both H100 and B200 L1 in CI pipeline 54405200. It covers the THD helper kernels, including reorder/copy-valid tests against the legacy Python reference paths.

FusedAttention does not expose the FlashAttention-specific pad_between_seqs axis; Fused THD padding semantics are covered through the THD/padding mask path.

Helper-kernel microbenchmarks, single H100, bf16, batch=16, seqlen=4096, heads=16, dim=64:
[Benchmark table; timings not recoverable. Rows: thd_cp_reorder_sequences contiguous->rank, thd_cp_reorder_sequences rank->contiguous, thd_cp_copy_valid_tokens]

Note: the helper-kernel numbers are fixed-shape microbenchmarks, not a sweep over sequence length or number of sequences.