Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
db44fc2
[PyTorch][CP] Fix THD AllGather CP: offset-based approach with proper…
sudhakarsingh27 Apr 7, 2026
1a5ca4c
[PyTorch][CP] Enable THD+all_gather tests in test_attention_with_cp
sudhakarsingh27 Apr 7, 2026
b4db9eb
[PyTorch][Fused Attn] Fix max_logit masking for non-zero-starting cu_…
sudhakarsingh27 Apr 7, 2026
7491ab6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2026
b957725
some cleanup of ag+thd impl and gate e e te test for flash+ag+thd
sudhakarsingh27 Apr 10, 2026
c89173c
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 10, 2026
18e41bd
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 Apr 10, 2026
0b48746
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2026
608106d
improve the logic and remvoe for loop from the code
sudhakarsingh27 Apr 13, 2026
4b95130
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2026
15af3af
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 Apr 13, 2026
5bec5b3
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 13, 2026
89b1066
AG+THD SWA: extend KV visibility for right window and rename a2a-spec…
sudhakarsingh27 Apr 16, 2026
55fc2cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2026
f499f59
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 Apr 20, 2026
2569a65
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 20, 2026
4e4212f
resolved merge conflicts with main
sudhakarsingh27 Apr 23, 2026
10e4cfc
[PyTorch] Add pad_between_seqs support for FlashAttention 3 with CP
sudhakarsingh27 Apr 24, 2026
2a49dee
[PyTorch] Add pad_between_seqs tests for CP and non-CP FlashAttention
sudhakarsingh27 Apr 24, 2026
34e3d62
[QA] Add CP deterministic tests to L3 and support TE_PATH in FA test
sudhakarsingh27 Apr 24, 2026
4745f98
[PyTorch] Fix FA3 deterministic gate to match upstream backward const…
sudhakarsingh27 Apr 24, 2026
4be004f
[PyTorch] Disable FlashAttention 4 for pad_between_seqs with THD
sudhakarsingh27 Apr 24, 2026
c476f15
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 24, 2026
a2b0f1b
[QA] Fix cutlass-dsl utils shadow in FA versions test
sudhakarsingh27 Apr 25, 2026
0ee22c7
merge conflicts with main
sudhakarsingh27 Apr 26, 2026
dfc1472
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 Apr 26, 2026
ac38d4f
merge flash attn pad bw seqs
sudhakarsingh27 Apr 26, 2026
b94e175
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 28, 2026
7ebe3d9
fixes after merging with flash_attn_pad_bw_seqs branchj
sudhakarsingh27 Apr 28, 2026
ddaa196
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine in…
sudhakarsingh27 Apr 28, 2026
fc9182f
skip tests which OOM in deterministic+backward+hopper+large_configs a…
sudhakarsingh27 Apr 29, 2026
636666f
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 Apr 29, 2026
7928bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
1585ebb
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 Apr 29, 2026
7ecad01
[PyTorch][CP] Replace Python-loop THD reorder with kernel-backed perm…
sudhakarsingh27 Apr 29, 2026
d8bf5c5
Merge remote-tracking branch 'sudhakar_repo/flash_attn_pad_bw_seqs' i…
sudhakarsingh27 Apr 29, 2026
cc104d3
[PyTorch][CP] Fix AllGather SBHD forward: set cu_seqlens_kv_per_step
sudhakarsingh27 Apr 29, 2026
2464f43
make cp det and nondet tests run in parallel whenever possible
sudhakarsingh27 Apr 30, 2026
26e9f6f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2026
611d876
[PyTorch][CP] Fix THD AllGather forward stream race on k_ag/v_ag
sudhakarsingh27 Apr 30, 2026
0aae820
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 Apr 30, 2026
08a5239
Merge branch 'main' into flash_attn_pad_bw_seqs
sudhakarsingh27 May 1, 2026
0a32185
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into flash…
sudhakarsingh27 May 4, 2026
e173807
Merge branch 'flash_attn_pad_bw_seqs' of github.com:sudhakarsingh27/T…
sudhakarsingh27 May 4, 2026
88951c1
Merge remote-tracking branch 'sudhakar_repo/flash_attn_pad_bw_seqs' i…
sudhakarsingh27 May 4, 2026
6e3d1bd
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 May 5, 2026
f934111
Add THD + FlashAttention v3 support to AllGather CP backend
sudhakarsingh27 May 5, 2026
d3a9903
Refactor AG THD window logic into shared get_kv_seq_info_after_all_ga…
sudhakarsingh27 May 6, 2026
f334657
[PyTorch][CP] Address PR 2829 self-review: clarify THD mask/cu_seqlens
sudhakarsingh27 May 22, 2026
683e277
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 May 28, 2026
bdcdf65
[PyTorch] Fused thd_reorder kernel + sync-free CP THD reorder
sudhakarsingh27 May 30, 2026
5c1ad10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2026
e6e8354
[PyTorch] Sync-free thd_valid_copy kernel for AllGather CP THD fwd/bwd
sudhakarsingh27 May 30, 2026
cfc754d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2026
c101665
[PyTorch] Fix FA3 all_gather THD allocator-reuse race in fused reorder
sudhakarsingh27 Jun 1, 2026
ce48413
[PyTorch] Serialize FA3 AG calls on GPU
sudhakarsingh27 Jun 2, 2026
f5ccaa0
[PyTorch] Avoid D2H sync in THD max-logit mask
sudhakarsingh27 Jun 2, 2026
f6e8487
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2026
17641d3
[PyTorch] Address THD AG review and lint issues
sudhakarsingh27 Jun 3, 2026
eb7aaf3
Merge NVIDIA main into CP THD SWA branch
sudhakarsingh27 Jun 3, 2026
d6ef2de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
e834206
[PyTorch] Add THD helper kernel tests
sudhakarsingh27 Jun 3, 2026
8ed69d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
0f75ad4
Merge remote-tracking branch 'origin/main' into codex/pr2829-review-c…
sudhakarsingh27 Jun 5, 2026
b71d7f8
[PyTorch] Clean up THD AG review comments
sudhakarsingh27 Jun 5, 2026
1561a1c
[PyTorch] Address THD AG review follow-ups
sudhakarsingh27 Jun 10, 2026
6e8b668
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2026
46b764e
[PyTorch] Remove duplicate AllGather padding assert
sudhakarsingh27 Jun 10, 2026
695a01f
Merge branch 'main' into cp_thd_swa_with_ag
sudhakarsingh27 Jun 11, 2026
732801f
Address THD CP review cleanup
sudhakarsingh27 Jun 11, 2026
be665fe
Merge remote-tracking branch 'origin/main' into codex/pr2829-review-c…
sudhakarsingh27 Jun 11, 2026
85064f4
Merge remote-tracking branch 'sudhakar_repo/cp_thd_swa_with_ag' into …
sudhakarsingh27 Jun 11, 2026
ed3cc5c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2026
e4d1ecc
Clarify THD CP helper API names
sudhakarsingh27 Jun 12, 2026
4277f31
Merge remote-tracking branch 'https-origin/main' into codex/pr2829-re…
sudhakarsingh27 Jun 12, 2026
10eb1aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,9 +627,8 @@ def run_dpa_with_cp(
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - (
cu_seqlens_q_padded - cu_seqlens_q
)[:-1]
cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
Expand Down
33 changes: 28 additions & 5 deletions tests/pytorch/attention/test_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

test_essential = True
test_essential = bool(int(os.getenv("NVTE_TEST_ESSENTIAL", "1")))

model_configs_flash_attn = {
# test: ModelConfig(b, sq, hq, dqk)
Expand Down Expand Up @@ -319,6 +319,14 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type
if cp_comm_type == "a2a+p2p":
pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")

if pad_between_seqs:
if qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if not FlashAttentionUtils.v3_is_installed:
pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")

config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
Expand All @@ -328,8 +336,20 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type
if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]:
pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!")

if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]:
pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!")
if qkv_format == "thd" and cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if (
qkv_format == "thd"
and cp_comm_type == "all_gather"
and not FlashAttentionUtils.v3_is_installed
):
pytest.skip(
"THD + all_gather requires FA3 (seqused_k) to separate tensor offsets from"
" visibility limits in the gathered KV buffer."
)

if (
config.window_size != (-1, 0)
Expand Down Expand Up @@ -538,8 +558,11 @@ def test_cp_with_fused_attention(
if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]:
pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!")

if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]:
pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!")
if qkv_format == "thd" and cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)

if (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type in [
"p2p",
Expand Down
189 changes: 187 additions & 2 deletions tests/pytorch/attention/test_cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# See LICENSE for license information.

"""Unit tests for context parallel utils."""

import itertools
import torch
import unittest
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
Expand All @@ -11,9 +13,16 @@
generate_positional_ids_for_cp,
)

try:
import transformer_engine_torch as tex
except ImportError:
tex = None


class TestSequencePadding(unittest.TestCase):
def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self):
def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(
self,
):
"""Test with custom padding values for all tensors."""
# Setup

Expand Down Expand Up @@ -467,7 +476,36 @@ def test_sequences_longer_than_divisibility_factor(self):
)

expected_positional_ids = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7]
[
0,
1,
2,
3,
4,
5,
6,
7,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
0,
1,
2,
3,
4,
5,
6,
7,
]
)

expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28])
Expand Down Expand Up @@ -710,5 +748,152 @@ def test_integration_with_padding_and_cp_slicing(self):
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0))


def _legacy_reorder_thd_to_cp_rank_order(x, cu_seqlens, cp_size, seq_dim=0):
total_slices_of_any_sequence = 2 * cp_size
slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence

indices = [
(
torch.arange(
seq_start + (cp_rank * slice_size),
seq_start + ((cp_rank + 1) * slice_size),
device=cu_seqlens.device,
),
torch.arange(
seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
device=cu_seqlens.device,
),
)
for cp_rank in range(cp_size)
for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
]

indices = list(itertools.chain(*indices))
indices = torch.cat(indices)
return x.index_select(seq_dim, indices)


def _legacy_reorder_thd_to_sequence_order(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0):
max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size
cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size

indices = [
torch.arange(
(
start + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
if loc < cp_size
else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
),
(
(start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
if loc < cp_size
else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
),
device=cu_seqlens.device,
)
for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:])
for loc, chunk_id in enumerate(seq_chunk_ids)
]

indices = torch.cat(indices)
return x.index_select(seq_dim, indices)


def _legacy_valid_copy(out, inp, cu_seqlens_padded, cu_seqlens):
batch_size = cu_seqlens.shape[0] - 1
for b in range(batch_size):
s = cu_seqlens_padded[b].item()
sz = (cu_seqlens[b + 1] - cu_seqlens[b]).item()
if sz > 0:
out[s : s + sz].copy_(inp[s : s + sz])


@unittest.skipIf(
not torch.cuda.is_available() or tex is None,
"THD kernel tests require CUDA and transformer_engine_torch",
)
class TestTHDKernels(unittest.TestCase):
def test_thd_sequence_cp_rank_order_roundtrip_matches_legacy_python_reorder(self):
cp_size = 4
cu_seqlens = torch.tensor([0, 8, 24, 40], dtype=torch.int32, device="cuda")
x = torch.arange(40 * 2 * 4, dtype=torch.float16, device="cuda").view(40, 2, 4)

cp_rank_order = tex.thd_sequence_order_to_cp_rank_order(x, cu_seqlens, cp_size, x.shape[0])
ref_cp_rank_order = _legacy_reorder_thd_to_cp_rank_order(x, cu_seqlens, cp_size)
self.assertTrue(torch.equal(cp_rank_order, ref_cp_rank_order))

seq_chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device="cuda")
for rank in range(cp_size):
seq_chunk_ids[rank] = 2 * rank
seq_chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
sequence_order = tex.thd_cp_rank_order_to_sequence_order(
cp_rank_order, cu_seqlens, cp_size, cp_rank_order.shape[0]
)
ref_sequence_order = _legacy_reorder_thd_to_sequence_order(
cp_rank_order, cu_seqlens, seq_chunk_ids, cp_size
)
self.assertTrue(torch.equal(sequence_order, ref_sequence_order))
self.assertTrue(torch.equal(sequence_order, x))

def test_thd_get_partitioned_indices_matches_dual_chunk_expected_indices(self):
cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda")

rank0 = tex.thd_get_partitioned_indices(cu_seqlens, 16, 2, 0)
rank1 = tex.thd_get_partitioned_indices(cu_seqlens, 16, 2, 1)

expected_rank0 = torch.tensor([0, 1, 6, 7, 8, 9, 14, 15], dtype=torch.int32, device="cuda")
expected_rank1 = torch.tensor(
[2, 3, 4, 5, 10, 11, 12, 13], dtype=torch.int32, device="cuda"
)
self.assertTrue(torch.equal(rank0, expected_rank0))
self.assertTrue(torch.equal(rank1, expected_rank1))

def test_thd_copy_valid_tokens_from_per_split_matches_legacy_slice_copy_loop(self):
cu_seqlens_padded = torch.tensor([2, 6, 12], dtype=torch.int32, device="cuda")
cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda")
inp = torch.arange(12 * 2 * 4, dtype=torch.float16, device="cuda").view(12, 2, 4)
out = torch.full_like(inp, -1)
expected = torch.full_like(inp, -1)

_legacy_valid_copy(expected, inp, cu_seqlens_padded, cu_seqlens)
tex.thd_copy_valid_tokens_from_per_split_to_rank_local(
out, inp, cu_seqlens_padded, cu_seqlens
)
self.assertTrue(torch.equal(out, expected))

def test_thd_read_half_tensor_reads_each_sequence_half(self):
cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda")
q = torch.arange(16 * 2 * 4, dtype=torch.float16, device="cuda").view(16, 2, 4)
kv = torch.arange(2 * 16 * 2 * 4, dtype=torch.float16, device="cuda").view(2, 16, 2, 4)

q_first = tex.thd_read_half_tensor(q, cu_seqlens, 0)
q_second = tex.thd_read_half_tensor(q, cu_seqlens, 1)
kv_first = tex.thd_read_half_tensor(kv, cu_seqlens, 0)
kv_second = tex.thd_read_half_tensor(kv, cu_seqlens, 1)

expected_first = torch.cat([q[0:4], q[8:12]], dim=0)
expected_second = torch.cat([q[4:8], q[12:16]], dim=0)
self.assertTrue(torch.equal(q_first, expected_first))
self.assertTrue(torch.equal(q_second, expected_second))
self.assertTrue(torch.equal(kv_first, torch.stack([expected_first, expected_first + 128])))
self.assertTrue(
torch.equal(kv_second, torch.stack([expected_second, expected_second + 128]))
)

def test_thd_read_second_half_lse_handles_packed_and_batch_major_lse(self):
cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda")
lse = torch.arange(2 * 2 * 8, dtype=torch.float32, device="cuda").view(2, 2, 8)
packed_lse = torch.arange(2 * 16, dtype=torch.float32, device="cuda").view(2, 16)

second_half_lse = tex.thd_read_second_half_lse(lse, cu_seqlens, False, 4)
packed_second_half_lse = tex.thd_read_second_half_lse(packed_lse, cu_seqlens, True, 8)

expected = lse[:, :, 4:8]
expected_packed = torch.cat([packed_lse[:, 4:8], packed_lse[:, 12:16]], dim=1)
self.assertTrue(torch.equal(second_half_lse, expected))
self.assertTrue(torch.equal(packed_second_half_lse, expected_packed))


if __name__ == "__main__":
unittest.main()
Loading
Loading