diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 52c7dc067c..3d2f99b51b 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -627,9 +627,8 @@ def run_dpa_with_cp( cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) - num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - ( - cu_seqlens_q_padded - cu_seqlens_q - )[:-1] + cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q + num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size cu_seqlens_kv = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 59b0e0bdbf..681ee5e6e0 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -45,7 +45,7 @@ torch.manual_seed(seed) torch.cuda.manual_seed(seed) -test_essential = True +test_essential = bool(int(os.getenv("NVTE_TEST_ESSENTIAL", "1"))) model_configs_flash_attn = { # test: ModelConfig(b, sq, hq, dqk) @@ -319,6 +319,14 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type if cp_comm_type == "a2a+p2p": pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!") + if pad_between_seqs: + if qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") + if not FlashAttentionUtils.v3_is_installed: + pytest.skip("pad_between_seqs with CP requires Flash Attention v3!") + if cp_comm_type == "a2a+p2p": + pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!") + config = model_configs_flash_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type @@ -328,8 +336,20 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") - if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: - pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + if qkv_format == "thd" and cp_comm_type == "a2a+p2p": + pytest.skip( + "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" + " yet!" + ) + if ( + qkv_format == "thd" + and cp_comm_type == "all_gather" + and not FlashAttentionUtils.v3_is_installed + ): + pytest.skip( + "THD + all_gather requires FA3 (seqused_k) to separate tensor offsets from" + " visibility limits in the gathered KV buffer." + ) if ( config.window_size != (-1, 0) @@ -538,8 +558,11 @@ def test_cp_with_fused_attention( if config.attn_bias_type != "no_bias" and cp_comm_type in ["all_gather", "a2a", "a2a+p2p"]: pytest.skip("No support for bias with cp_comm_type={all_gather, a2a, a2a+p2p}!") - if qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: - pytest.skip("No support for THD format with cp_comm_type={all_gather, a2a+p2p}!") + if qkv_format == "thd" and cp_comm_type == "a2a+p2p": + pytest.skip( + "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" + " yet!" + ) if (config.window_size[0] != -1 or config.window_size[1] not in [-1, 0]) and cp_comm_type in [ "p2p", diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py index e5051aab36..c3a423cef5 100644 --- a/tests/pytorch/attention/test_cp_utils.py +++ b/tests/pytorch/attention/test_cp_utils.py @@ -3,6 +3,8 @@ # See LICENSE for license information. """Unit tests for context parallel utils.""" + +import itertools import torch import unittest from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( @@ -11,9 +13,16 @@ generate_positional_ids_for_cp, ) +try: + import transformer_engine_torch as tex +except ImportError: + tex = None + class TestSequencePadding(unittest.TestCase): - def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self): + def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor( + self, + ): """Test with custom padding values for all tensors.""" # Setup @@ -467,7 +476,36 @@ def test_sequences_longer_than_divisibility_factor(self): ) expected_positional_ids = torch.tensor( - [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7] + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + ] ) expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28]) @@ -710,5 +748,152 @@ def test_integration_with_padding_and_cp_slicing(self): self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) +def _legacy_reorder_thd_to_cp_rank_order(x, cu_seqlens, cp_size, seq_dim=0): + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence + + indices = [ + ( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=cu_seqlens.device, + ), + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=cu_seqlens.device, + ), + ) + for cp_rank in range(cp_size) + for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1]) + ] + + indices = list(itertools.chain(*indices)) + indices = torch.cat(indices) + return x.index_select(seq_dim, indices) + + +def _legacy_reorder_thd_to_sequence_order(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0): + max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size + cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size + + indices = [ + torch.arange( + ( + start + max_cum_seqlen_per_cp_rank * (chunk_id // 2) + if loc < cp_size + else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2) + ), + ( + (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2) + if loc < cp_size + else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2) + ), + device=cu_seqlens.device, + ) + for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:]) + for loc, chunk_id in enumerate(seq_chunk_ids) + ] + + indices = torch.cat(indices) + return x.index_select(seq_dim, indices) + + +def _legacy_valid_copy(out, inp, cu_seqlens_padded, cu_seqlens): + batch_size = cu_seqlens.shape[0] - 1 + for b in range(batch_size): + s = cu_seqlens_padded[b].item() + sz = (cu_seqlens[b + 1] - cu_seqlens[b]).item() + if sz > 0: + out[s : s + sz].copy_(inp[s : s + sz]) + + +@unittest.skipIf( + not torch.cuda.is_available() or tex is None, + "THD kernel tests require CUDA and transformer_engine_torch", +) +class TestTHDKernels(unittest.TestCase): + def test_thd_sequence_cp_rank_order_roundtrip_matches_legacy_python_reorder(self): + cp_size = 4 + cu_seqlens = torch.tensor([0, 8, 24, 40], dtype=torch.int32, device="cuda") + x = torch.arange(40 * 2 * 4, dtype=torch.float16, device="cuda").view(40, 2, 4) + + cp_rank_order = tex.thd_sequence_order_to_cp_rank_order(x, cu_seqlens, cp_size, x.shape[0]) + ref_cp_rank_order = _legacy_reorder_thd_to_cp_rank_order(x, cu_seqlens, cp_size) + self.assertTrue(torch.equal(cp_rank_order, ref_cp_rank_order)) + + seq_chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device="cuda") + for rank in range(cp_size): + seq_chunk_ids[rank] = 2 * rank + seq_chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 + sequence_order = tex.thd_cp_rank_order_to_sequence_order( + cp_rank_order, cu_seqlens, cp_size, cp_rank_order.shape[0] + ) + ref_sequence_order = _legacy_reorder_thd_to_sequence_order( + cp_rank_order, cu_seqlens, seq_chunk_ids, cp_size + ) + self.assertTrue(torch.equal(sequence_order, ref_sequence_order)) + self.assertTrue(torch.equal(sequence_order, x)) + + def test_thd_get_partitioned_indices_matches_dual_chunk_expected_indices(self): + cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda") + + rank0 = tex.thd_get_partitioned_indices(cu_seqlens, 16, 2, 0) + rank1 = tex.thd_get_partitioned_indices(cu_seqlens, 16, 2, 1) + + expected_rank0 = torch.tensor([0, 1, 6, 7, 8, 9, 14, 15], dtype=torch.int32, device="cuda") + expected_rank1 = torch.tensor( + [2, 3, 4, 5, 10, 11, 12, 13], dtype=torch.int32, device="cuda" + ) + self.assertTrue(torch.equal(rank0, expected_rank0)) + self.assertTrue(torch.equal(rank1, expected_rank1)) + + def test_thd_copy_valid_tokens_from_per_split_matches_legacy_slice_copy_loop(self): + cu_seqlens_padded = torch.tensor([2, 6, 12], dtype=torch.int32, device="cuda") + cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda") + inp = torch.arange(12 * 2 * 4, dtype=torch.float16, device="cuda").view(12, 2, 4) + out = torch.full_like(inp, -1) + expected = torch.full_like(inp, -1) + + _legacy_valid_copy(expected, inp, cu_seqlens_padded, cu_seqlens) + tex.thd_copy_valid_tokens_from_per_split_to_rank_local( + out, inp, cu_seqlens_padded, cu_seqlens + ) + self.assertTrue(torch.equal(out, expected)) + + def test_thd_read_half_tensor_reads_each_sequence_half(self): + cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda") + q = torch.arange(16 * 2 * 4, dtype=torch.float16, device="cuda").view(16, 2, 4) + kv = torch.arange(2 * 16 * 2 * 4, dtype=torch.float16, device="cuda").view(2, 16, 2, 4) + + q_first = tex.thd_read_half_tensor(q, cu_seqlens, 0) + q_second = tex.thd_read_half_tensor(q, cu_seqlens, 1) + kv_first = tex.thd_read_half_tensor(kv, cu_seqlens, 0) + kv_second = tex.thd_read_half_tensor(kv, cu_seqlens, 1) + + expected_first = torch.cat([q[0:4], q[8:12]], dim=0) + expected_second = torch.cat([q[4:8], q[12:16]], dim=0) + self.assertTrue(torch.equal(q_first, expected_first)) + self.assertTrue(torch.equal(q_second, expected_second)) + self.assertTrue(torch.equal(kv_first, torch.stack([expected_first, expected_first + 128]))) + self.assertTrue( + torch.equal(kv_second, torch.stack([expected_second, expected_second + 128])) + ) + + def test_thd_read_second_half_lse_handles_packed_and_batch_major_lse(self): + cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda") + lse = torch.arange(2 * 2 * 8, dtype=torch.float32, device="cuda").view(2, 2, 8) + packed_lse = torch.arange(2 * 16, dtype=torch.float32, device="cuda").view(2, 16) + + second_half_lse = tex.thd_read_second_half_lse(lse, cu_seqlens, False, 4) + packed_second_half_lse = tex.thd_read_second_half_lse(packed_lse, cu_seqlens, True, 8) + + expected = lse[:, :, 4:8] + expected_packed = torch.cat([packed_lse[:, 4:8], packed_lse[:, 12:16]], dim=1) + self.assertTrue(torch.equal(second_half_lse, expected)) + self.assertTrue(torch.equal(packed_second_half_lse, expected_packed)) + + if __name__ == "__main__": unittest.main() diff --git a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index cf1fffd94f..0f7b820bbd 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -77,6 +77,18 @@ __forceinline__ __device__ int binary_search(int target, int *array, int len) { return left - 1; } +// Dual-chunk source index for THD CP partitioning. cu_seqlens_s must already be divided by +// world_size. Single source of truth shared by thd_partition_indices_kernel and +// thd_reorder_between_sequence_and_cp_rank_order_kernel so the two never diverge. +__forceinline__ __device__ int thd_partition_src_index(int token_id, int *cu_seqlens_s, int batch, + int world_size, int rank) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + int index = token_id - cu_seqlens_s[seq_id]; + int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; + return index + cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; +} + /*************************************************************************************************** * Support THD format for Context Parallel: Generate partitioned indices for input tokens **************************************************************************************************/ @@ -96,12 +108,78 @@ __global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int b int num_threads = blockDim.x * gridDim.x; for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - int index = token_id - cu_seqlens_s[seq_id]; - int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; - index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; - output[token_id] = index; + output[token_id] = thd_partition_src_index(token_id, cu_seqlens_s, batch, world_size, rank); + } +} + +/*************************************************************************************************** + * Fused dual-chunk THD reorder. Computes src inline and copies one leading-dimension + * token entry per warp. + * cp_rank_to_sequence_order=false: out[gi]=inp[src(gi)]. + * cp_rank_to_sequence_order=true: out[src(gi)]=inp[gi]. + **************************************************************************************************/ +__global__ void thd_reorder_between_sequence_and_cp_rank_order_kernel( + void *out, void *inp, int *cu_seqlens, int batch, int total_tokens, int world_size, + int hidden_size_in_bytes, bool cp_rank_to_sequence_order) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / world_size; + } + __syncthreads(); + + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int tpr = total_tokens / world_size; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); + + for (int gi = warpid; gi < total_tokens; gi += num_warps) { + int rank = gi / tpr; + int token_id = gi % tpr; + int src = thd_partition_src_index(token_id, cu_seqlens_s, batch, world_size, rank); + int rd = cp_rank_to_sequence_order ? gi : src; + int wr = cp_rank_to_sequence_order ? src : gi; + float4 *src_tok = reinterpret_cast(reinterpret_cast(inp) + + static_cast(rd) * hidden_size_in_bytes); + float4 *dst_tok = reinterpret_cast(reinterpret_cast(out) + + static_cast(wr) * hidden_size_in_bytes); + for (int idx = laneid; idx < num_float4s_per_token; idx += 32) dst_tok[idx] = src_tok[idx]; + } +} + +/*************************************************************************************************** + * Copy valid token entries from a per-split THD tensor into a rank-local accumulator. + * cu_seqlens_padded gives padded THD token offsets; cu_seqlens gives valid lengths. + **************************************************************************************************/ +__global__ void thd_copy_valid_tokens_from_per_split_to_rank_local_kernel( + void *out, void *inp, int *cu_seqlens_padded, int *cu_seqlens, int batch, int total_tokens, + int hidden_size_in_bytes) { + extern __shared__ int padded_s[]; // [0..batch] padded boundaries + int *valid_s = padded_s + (batch + 1); // [0..batch] valid boundaries + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + padded_s[i] = cu_seqlens_padded[i]; + valid_s[i] = cu_seqlens[i]; + } + __syncthreads(); + + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); + + for (int token_id = warpid; token_id < total_tokens; token_id += num_warps) { + int seq_id = binary_search(token_id, padded_s, batch + 1); + if (seq_id < 0 || seq_id >= batch) continue; + int local = token_id - padded_s[seq_id]; + int valid_len = valid_s[seq_id + 1] - valid_s[seq_id]; + // Later split offsets can shift a sequence start past earlier tokens; skip those entries. + if (local >= 0 && local < valid_len) { + float4 *src_tok = reinterpret_cast( + reinterpret_cast(inp) + static_cast(token_id) * hidden_size_in_bytes); + float4 *dst_tok = reinterpret_cast( + reinterpret_cast(out) + static_cast(token_id) * hidden_size_in_bytes); + for (int idx = laneid; idx < num_float4s_per_token; idx += 32) dst_tok[idx] = src_tok[idx]; + } } } @@ -678,6 +756,63 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to NVTE_CHECK_CUDA(cudaGetLastError()); } +void thd_reorder_between_sequence_and_cp_rank_order(const Tensor &inp, const Tensor &cu_seqlens, + Tensor &out, int world_size, + bool cp_rank_to_sequence_order, + int total_tokens, cudaStream_t stream) { + using namespace transformer_engine; + NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32); + NVTE_CHECK(cu_seqlens.dim() == 1); + auto cu_seqlens_shape = cu_seqlens.shape(); + NVTE_CHECK(cu_seqlens_shape[0] >= 2); + NVTE_CHECK(world_size > 0); + NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0); + + auto inp_shape = inp.shape(); + size_t row_elems = 1; + for (int i = 1; i < inp.dim(); i++) row_elems *= inp_shape[i]; + int hidden_size_in_bytes = (row_elems * typeToNumBits(inp.dtype())) / 8; + NVTE_CHECK(hidden_size_in_bytes % 16 == 0); // 128-bit load/store + + int batch = cu_seqlens_shape[0] - 1; + constexpr unsigned int block = 256; + unsigned int grid = (static_cast(total_tokens) * 32 + block - 1) / block; + thd_reorder_between_sequence_and_cp_rank_order_kernel<<>>( + out.data.dptr, inp.data.dptr, reinterpret_cast(cu_seqlens.data.dptr), batch, + total_tokens, world_size, hidden_size_in_bytes, cp_rank_to_sequence_order); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void thd_copy_valid_tokens_from_per_split_to_rank_local(const Tensor &inp, + const Tensor &cu_seqlens_padded, + const Tensor &cu_seqlens, Tensor &out, + int total_tokens, cudaStream_t stream) { + using namespace transformer_engine; + NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32); + NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32); + NVTE_CHECK(cu_seqlens.dim() == 1 && cu_seqlens_padded.dim() == 1); + auto cu_seqlens_shape = cu_seqlens.shape(); + NVTE_CHECK(cu_seqlens_shape[0] >= 2); + NVTE_CHECK(cu_seqlens_padded.shape()[0] == cu_seqlens_shape[0]); + NVTE_CHECK(total_tokens > 0); + + auto inp_shape = inp.shape(); + size_t row_elems = 1; + for (int i = 1; i < inp.dim(); i++) row_elems *= inp_shape[i]; + int hidden_size_in_bytes = (row_elems * typeToNumBits(inp.dtype())) / 8; + NVTE_CHECK(hidden_size_in_bytes % 16 == 0); // 128-bit load/store + + int batch = cu_seqlens_shape[0] - 1; + constexpr unsigned int block = 256; + unsigned int grid = (static_cast(total_tokens) * 32 + block - 1) / block; + thd_copy_valid_tokens_from_per_split_to_rank_local_kernel<<< + grid, block, sizeof(int) * 2 * (batch + 1), stream>>>( + out.data.dptr, inp.data.dptr, reinterpret_cast(cu_seqlens_padded.data.dptr), + reinterpret_cast(cu_seqlens.data.dptr), batch, total_tokens, hidden_size_in_bytes); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace context_parallel } // namespace transformer_engine @@ -750,3 +885,38 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso *convertNVTETensorCheck(output), total_tokens, world_size, rank, stream); } + +void nvte_thd_sequence_order_to_cp_rank_order(const NVTETensor &inp, const NVTETensor &cu_seqlens, + NVTETensor out, int world_size, int total_tokens, + cudaStream_t stream) { + NVTE_API_CALL(nvte_thd_sequence_order_to_cp_rank_order); + using namespace transformer_engine; + + context_parallel::thd_reorder_between_sequence_and_cp_rank_order( + *convertNVTETensorCheck(inp), *convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(out), world_size, false, total_tokens, stream); +} + +void nvte_thd_cp_rank_order_to_sequence_order(const NVTETensor &inp, const NVTETensor &cu_seqlens, + NVTETensor out, int world_size, int total_tokens, + cudaStream_t stream) { + NVTE_API_CALL(nvte_thd_cp_rank_order_to_sequence_order); + using namespace transformer_engine; + + context_parallel::thd_reorder_between_sequence_and_cp_rank_order( + *convertNVTETensorCheck(inp), *convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(out), world_size, true, total_tokens, stream); +} + +void nvte_thd_copy_valid_tokens_from_per_split_to_rank_local(const NVTETensor &inp, + const NVTETensor &cu_seqlens_padded, + const NVTETensor &cu_seqlens, + NVTETensor out, int total_tokens, + cudaStream_t stream) { + NVTE_API_CALL(nvte_thd_copy_valid_tokens_from_per_split_to_rank_local); + using namespace transformer_engine; + + context_parallel::thd_copy_valid_tokens_from_per_split_to_rank_local( + *convertNVTETensorCheck(inp), *convertNVTETensorCheck(cu_seqlens_padded), + *convertNVTETensorCheck(cu_seqlens), *convertNVTETensorCheck(out), total_tokens, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index d9d2786623..41e4b136bd 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -533,6 +533,56 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso int total_tokens, int world_size, int rank, cudaStream_t stream); +/*! \brief Reorder THD tensor from sequence order to dual-chunk CP rank order. + * + * Uses the padded THD sequence lengths to place each sequence's two CP chunks + * in the order consumed by each CP rank. + * + * \param[in] inp Input THD tensor [total_tokens, ...]. + * \param[in] cu_seqlens Padded cumulative sequence lengths, [batch_size + 1], int32. + * \param[out] out Output tensor, same shape/dtype as inp. + * \param[in] world_size Context-parallel size. + * \param[in] total_tokens Total padded tokens (= inp.shape[0]). + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_thd_sequence_order_to_cp_rank_order(const NVTETensor &inp, const NVTETensor &cu_seqlens, + NVTETensor out, int world_size, int total_tokens, + cudaStream_t stream); + +/*! \brief Reorder THD tensor from dual-chunk CP rank order to sequence order. + * + * Uses the padded THD sequence lengths to restore each sequence's dual-chunk + * CP entries to sequence order. + * + * \param[in] inp Input THD tensor [total_tokens, ...]. + * \param[in] cu_seqlens Padded cumulative sequence lengths, [batch_size + 1], int32. + * \param[out] out Output tensor, same shape/dtype as inp. + * \param[in] world_size Context-parallel size. + * \param[in] total_tokens Total padded tokens (= inp.shape[0]). + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_thd_cp_rank_order_to_sequence_order(const NVTETensor &inp, const NVTETensor &cu_seqlens, + NVTETensor out, int world_size, int total_tokens, + cudaStream_t stream); + +/*! \brief Copy valid token entries from a per-split THD tensor to a rank-local accumulator. + * + * For each dual-chunk CP step/split, copies each sequence's valid range at + * its padded THD token offsets and leaves padded entries untouched. + * + * \param[in] inp Per-split THD source tensor [total_tokens, ...]. + * \param[in] cu_seqlens_padded Padded cumulative sequence lengths, [batch_size + 1], int32. + * \param[in] cu_seqlens Valid cumulative sequence lengths, [batch_size + 1], int32. + * \param[in,out] out Rank-local accumulator, same shape/dtype as inp. + * \param[in] total_tokens Total padded tokens (= inp.shape[0]). + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_thd_copy_valid_tokens_from_per_split_to_rank_local(const NVTETensor &inp, + const NVTETensor &cu_seqlens_padded, + const NVTETensor &cu_seqlens, + NVTETensor out, int total_tokens, + cudaStream_t stream); + /*! \brief Convert tensor from THD to BSHD format. * * \warning This API is **experimental** and subject to change. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 36847e40ed..61a46a8652 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4,7 +4,6 @@ """Context Parallelism.""" import os -import itertools from typing import List, Union, Tuple import torch import transformer_engine_torch as tex @@ -54,6 +53,7 @@ _seq_chunk_ids_cache_for_reordering_after_attn = {} _softmax_offset_chunk_ids_cache = {} + # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" @@ -295,14 +295,13 @@ def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size return x -def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim=0): +def thd_sequence_order_to_cp_rank_order(x, cu_seqlens, cp_size, seq_dim=0): """ - Reorder sequence chunks for A2A communication that happens after attention - compute. + Reorder a THD tensor from sequence order to dual-chunk CP rank order. Args: x: The input tensor to be reordered. - cu_seqlens: The cumulative sequence lengths of the input tensor. + cu_seqlens: The padded cumulative sequence lengths of the input tensor. cp_size: The number of ranks participating in context parallelism. seq_dim: The dimension in which to reorder. @@ -321,10 +320,8 @@ def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim 10., 11., 3., 4., 3., 4., 3., 4., 6., 7., 8., 9.] - This logic is similar to how the DualChunking is done to split the sequence - for each rank. Here, the indices of sequence chunks for all those ranks - are concatenated together. So the returned tensor ends up looking like as if - the chunks from all the ranks are concatenated together. + This follows the same dual-chunk CP partitioning used to split each sequence + across ranks. The output concatenates each rank's two chunks in CP rank order. e.g. [ 0., 7., 0., 7., 0., 7., 0., 1., 14., 15., # chunk on rank 0 @@ -333,43 +330,19 @@ def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim 3., 4., 3., 4., 3., 4., 6., 7., 8., 9. # chunk on rank 3 ] """ - total_slices_of_any_sequence = 2 * cp_size - slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence - - indices = [ - ( - # 1st segment - torch.arange( - seq_start + (cp_rank * slice_size), - seq_start + ((cp_rank + 1) * slice_size), - device=cu_seqlens.device, - ), - # 2nd segment - torch.arange( - seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), - seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), - device=cu_seqlens.device, - ), - ) - for cp_rank in range(cp_size) - for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1]) - ] - - # flatten the list of tuples to a list - indices = list(itertools.chain(*indices)) - indices = torch.cat(indices) - return x.index_select(seq_dim, indices) + assert ( + seq_dim == 0 + ), "tex.thd_sequence_order_to_cp_rank_order operates on the leading THD token dimension" + return tex.thd_sequence_order_to_cp_rank_order(x, cu_seqlens, cp_size, x.shape[seq_dim]) -def reorder_seq_chunks_after_a2a_before_attn_thd(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0): +def thd_cp_rank_order_to_sequence_order(x, cu_seqlens, cp_size, seq_dim=0): """ - Reorder sequence chunks for A2A communication that happens before attention - compute. + Reorder a THD tensor from dual-chunk CP rank order to sequence order. Args: x: The input tensor to be reordered. - cu_seqlens: The cumulative sequence lengths of the input tensor. - seq_chunk_ids: The sequence chunk ids of the input `x` which is to be reordered. + cu_seqlens: The padded cumulative sequence lengths of the input tensor. cp_size: The number of ranks participating in context parallelism. seq_dim: The dimension in which to reorder. @@ -381,15 +354,14 @@ def reorder_seq_chunks_after_a2a_before_attn_thd(x, cu_seqlens, seq_chunk_ids, c 1., 6., 2., 3., 12., 13., 2., 5., 2., 5., 2., 5., 4., 5., 10., 11., 3., 4., 3., 4., 3., 4., 6., 7., 8., 9.] cu_seqlens: [ 0, 8, 16, 24, 40] - seq_chunk_ids: [ 0, 2, 4, 6, 7, 5, 3, 1] cp_size: 4 Returns: [ 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.] - Note that the input sequences (x) are arranged after A2A communication as if DualChunked - chunks on all the ranks are concatenated together in the `seq_dim`. + Note that the input sequences (x) are arranged as if dual-chunk CP chunks on all + ranks are concatenated together in `seq_dim`. e.g. [ 0., 7., 0., 7., 0., 7., 0., 1., 14., 15., # chunk on rank 0 @@ -398,41 +370,18 @@ def reorder_seq_chunks_after_a2a_before_attn_thd(x, cu_seqlens, seq_chunk_ids, c 3., 4., 3., 4., 3., 4., 6., 7., 8., 9. # chunk on rank 3 ] - Then the logic to serialize the sequences is: - 1. For every sequence segment on any rank (denoted by `start` and `end`): + Then the logic to restore sequence order is: + 1. For every sequence range on any rank (denoted by `start` and `end`): 1a. For every chunk (in `chunk_id` and the total of those are twice as many as the number of CP ranks) : 1aa. The first `cp_size` number of chunks form the first half of the whole sequence. Get those indices. 1ab. The second `cp_size` number of chunks form the second half of the whole sequence. Get those indices. 1b. Concatenate the indices of the first half and the second half. 2. Reorder the entire input tensor by those indices. """ - - max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size - cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size - - # Go through all the sequence segments (the sizes should be the same from all the ranks) - indices = [ - torch.arange( - # Calculate 'left' boundary - ( - start + max_cum_seqlen_per_cp_rank * (chunk_id // 2) - if loc < cp_size - else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2) - ), - # Calculate 'right' boundary - ( - (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2) - if loc < cp_size - else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2) - ), - device=cu_seqlens.device, - ) - for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:]) - for loc, chunk_id in enumerate(seq_chunk_ids) - ] - - indices = torch.cat(indices) - return x.index_select(seq_dim, indices) + assert ( + seq_dim == 0 + ), "tex.thd_cp_rank_order_to_sequence_order operates on the leading THD token dimension" + return tex.thd_cp_rank_order_to_sequence_order(x, cu_seqlens, cp_size, x.shape[seq_dim]) def flash_attn_a2a_communicate( @@ -501,8 +450,8 @@ def flash_attn_a2a_communicate( # [cp, t, h//cp, d] -> [cp*t, h//cp, d] x = x.view(-1, *x.shape[2:]) # reorder the sequence chunks - a2a_outputs[i - 2] = reorder_seq_chunks_after_a2a_before_attn_thd( - x, cu_seqlens_padded, chunk_ids_for_a2a, cp_size + a2a_outputs[i - 2] = thd_cp_rank_order_to_sequence_order( + x, cu_seqlens_padded, cp_size ) if i < len(a2a_inputs): @@ -547,7 +496,7 @@ def flash_attn_a2a_communicate( else cu_seqlens_kv_padded ) # reorder the sequence chunks - x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens_padded, cp_size) + x = thd_sequence_order_to_cp_rank_order(x, cu_seqlens_padded, cp_size) # [cp*t, h//cp, d] -> [cp, t, h//cp, d] a2a_inputs[i] = x.view(cp_size, -1, *x.shape[-2:]) if i > 1: @@ -2991,13 +2940,23 @@ def backward(ctx, dout, *_args): def get_kv_seq_info_after_all_gather( local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal ): - """Compute KV sequence index range and update window size after all-gather.""" + """Return the visible KV range and adjusted window for one AG CP step. + + bshd/sbhd slices K/V with ``kv_range``. THD keeps full K/V and uses the end + bound as ``max_seqlen_kv`` because per-sequence visibility is passed through + ``cu_seqlens`` or ``seqused_k``. ``adjusted_window`` compensates for + bottom-right alignment after KV trimming. + + Returns: + Tuple of ``(kv_range, adjusted_window)``. + """ local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv full_seq_end_idx = max_seqlen_kv * cp_size * 2 if window_size is None: window_size = (-1, 0) if causal else (-1, -1) + # Right boundary: how far past chunk_end the kernel can see. if window_size[1] == -1: seq_end_idx = full_seq_end_idx window_size_right = -1 @@ -3005,6 +2964,7 @@ def get_kv_seq_info_after_all_gather( seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1]) window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx + # Left boundary: how far before chunk_end - q_len the kernel can see. if window_size[0] == -1: seq_start_idx = 0 window_size_left = -1 @@ -3019,6 +2979,14 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): """ Attention implementation with context parallelism. KV all-gather between CP ranks is exposed. Refer section 3.3.2 of `The Llama 3 Herd of Models `_. + + THD all-gather needs separate tensor offsets and visible lengths. After + K/V all-gather, each sequence occupies its padded offset in the reordered + buffer, while causal/SWA steps may expose only a prefix of valid tokens. + FusedAttention carries this split with ``cu_seqlens`` plus + ``cu_seqlens_padded``; FlashAttention v3 uses layout ``cu_seqlens`` plus + ``seqused_k``. FlashAttention v2 cannot represent both values, so THD + all-gather is restricted to FusedAttention or FlashAttention v3. """ @staticmethod @@ -3029,9 +2997,11 @@ def forward( k, v, cu_seqlens_q, + cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_padded, + cu_seqlens_kv_padded, dropout_p, softmax_scale, qkv_format, @@ -3045,6 +3015,7 @@ def forward( cp_group, cp_stream, use_flash_attn_3, + pad_between_seqs, fp8, fp8_meta, quantizers, @@ -3061,9 +3032,17 @@ def forward( if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - assert qkv_format != "thd", f"No support for cp_comm_type='all_gather' and {qkv_format=}." + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + if qkv_format == "thd": + # THD always uses padding mask types; per-step masks set internally + assert padding, f"THD format requires padding mask type, got {attn_mask_type}!" + # AG CP uses shorter per-step Q against longer KV, so causal masks need + # bottom-right alignment for both sliced and THD paths. + if use_fused_attention and causal and "bottom_right" not in attn_mask_type: + attn_mask_type = attn_mask_type + "_bottom_right" assert ( - "padding" not in attn_mask_type + qkv_format == "thd" or "padding" not in attn_mask_type ), f"No support for cp_comm_type='all_gather' and {attn_mask_type=}." assert ( attn_bias_type == "no_bias" @@ -3113,11 +3092,24 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + if qkv_format == "thd": + # Save original global cu_seqlens before division + cu_seqlens_q_original = cu_seqlens_q.clone() + cu_seqlens_kv_original = cu_seqlens_kv.clone() + else: + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" + + # Divide by 2*cp_size to get per-chunk values max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) - if use_fused_attention or qkv_format == "thd": + if use_fused_attention and qkv_format != "thd": cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - if cu_seqlens_q_padded is not None and qkv_format == "thd": + if qkv_format == "thd": cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) else: cu_seqlens_q_padded = None @@ -3166,34 +3158,48 @@ def forward( orig_q_shape, _, orig_v_shape = q.shape, k.shape, v.shape orig_o_shape = orig_q_shape[:-1] + orig_v_shape[-1:] - # q, k, v: - # FP8DS/CS: torch.uint8 - # MXFP8/F16: torch.float16 or torch.bfloat16 - # reshape: split s - # [b, s, h, d] -> [b, 2, s//2, h, d] - # [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view( - *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] - ) - # s dim first for all-gather - # [b, s, h, d]/[s, b, h, d] -> [s, b, h, d] - k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] + if qkv_format != "thd": + # q, k, v: + # FP8DS/CS: torch.uint8 + # MXFP8/F16: torch.float16 or torch.bfloat16 + # reshape: split s + # [b, s, h, d] -> [b, 2, s//2, h, d] + # [s, b, h, d] -> [2, s//2, b, h, d] + q = q.view( + *q.shape[:seq_dim_qkv], 2, q.shape[seq_dim_qkv] // 2, *q.shape[(seq_dim_qkv + 1) :] + ) + # s dim first for all-gather + # [b, s, h, d]/[s, b, h, d] -> [s, b, h, d] + k, v = [x.movedim(seq_dim_qkv, 0).contiguous() for x in [k, v]] - # gather along s: [s, b, h, d] -> [cp, s, b, h, d] + # AllGather K/V across CP ranks + # gather along s or t: [s, b, h, d] -> [cp, s, b, h, d] or [t, h, d] -> [cp*t, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - # split s:[cp, s, b, h, d] -> [cp*2, s//2, b, h, d] - k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) - v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - # pick out specific chunks for each rank - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) - k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) - v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # reshape/flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - k_ag = k_ag.view(-1, *k.shape[1:]) - v_ag = v_ag.view(-1, *v.shape[1:]) + + if qkv_format == "thd": + # [cp*t, h, d] -> reorder to sequence order -> [t_full, h, d] + # The padded cu_seqlens are global sequence offsets. Reorder uses them to + # derive per-sequence chunk boundaries. + k_ag = thd_cp_rank_order_to_sequence_order(k_ag, cu_seqlens_kv_padded, cp_size) + v_ag = thd_cp_rank_order_to_sequence_order(v_ag, cu_seqlens_kv_padded, cp_size) + else: + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + # cp_stream is used for step 1 of the per-step loop and must wait until + # k_ag/v_ag preparation finishes on the current stream — otherwise step 1 + # races against AG/reorder writes. Manifests at high cp_size where reorder + # is large enough to outlast cp_stream's launch (e.g. bucket128k @ cp=8). cp_stream.wait_stream(torch.cuda.current_stream()) + # THD all_gather only reaches this path for f16/bf16 attention today. # q: [b, 2, s//2, h, d] or [2, s//2, b, h, d] # k: [s, b, h, d] # v: [s, b, h, d] @@ -3214,17 +3220,144 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] + # THD per-split kernels may leave padded entries untouched; valid-copy only + # writes valid token entries, so keep the final accumulator zero-initialized. + out = torch.zeros(o_shape, dtype=fwd_nominal_dtype, device=q.device) max_logit_per_step = [None, None] max_logit = None + # Pre-compute THD-specific per-step cu_seqlens + if qkv_format == "thd": + # Rank-level padded offsets (2 chunks per sequence on this rank) + cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2 + + # Per-step Q cu_seqlens (non-padded): different per step since different + # chunks may have different valid token counts for non-divisible seqlens. + thd_cu_seqlens_q_per_step = [ + get_cu_seqlens_on_cp_rank( + cu_seqlens_q_original, + cu_seqlens_q_padded_rank, + cp_size, + rank, + True, + False, + ), + get_cu_seqlens_on_cp_rank( + cu_seqlens_q_original, + cu_seqlens_q_padded_rank, + cp_size, + rank, + False, + True, + ), + ] + + # Per-step Q cu_seqlens_padded: offset-based approach — pass full Q tensor + # and vary cu_seqlens_q_padded to point kernel at the correct chunk. + # cuDNN uses back-padding (valid tokens at beginning of padded allocation). + padded_chunk_sizes_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] + + # Step 0: kernel reads from start of each seq's 2-chunk allocation (first chunk) + # Step 1: kernel reads from midpoint of each seq's allocation (second chunk) + cu_seqlens_q_padded_step_1 = cu_seqlens_q_padded_rank.clone() + cu_seqlens_q_padded_step_1[:-1] += padded_chunk_sizes_q + thd_cu_seqlens_q_padded_per_step = [ + cu_seqlens_q_padded_rank, + cu_seqlens_q_padded_step_1, + ] + + thd_cu_seqlens_kv_per_step = [ + cu_seqlens_kv_original.clone(), + cu_seqlens_kv_original.clone(), + ] + + sliding_window_attn = ( + window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) + ) + if causal or sliding_window_attn: + actual_seqlens_kv = cu_seqlens_kv_original[1:] - cu_seqlens_kv_original[:-1] + padded_chunk_sizes_kv = (cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]) // ( + 2 * cp_size + ) + # Visible KV covers chunks 0..chunk_id so bottom-right alignment + # places this Q chunk at the right offset. + visible_padded = [ + padded_chunk_sizes_kv * (chunk_id + 1) for chunk_id in local_seq_chunk_ids + ] + # Right-window SWA extends visibility past the chunk boundary. + if window_size is not None and window_size[1] > 0: + visible_padded = [vp + window_size[1] for vp in visible_padded] + visible_actual = [ + torch.minimum(actual_seqlens_kv, visible_padded_split) + for visible_padded_split in visible_padded + ] + thd_cu_seqlens_kv_per_step = [ + torch.zeros_like(cu_seqlens_kv_original) for _ in range(2) + ] + # Adjust chunks for each step + thd_cu_seqlens_kv_per_step[0][1:] = visible_actual[0].cumsum(0) + thd_cu_seqlens_kv_per_step[1][1:] = visible_actual[1].cumsum(0) + for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): + # FA3 uses internal per-call workspace. Consecutive AG per-step + # calls are serialized on GPU streams so that workspace lifetimes + # do not overlap. FusedAttention keeps the existing per-step overlap. + if i > 0 and use_flash_attn_3: + flash_attn_streams[i].wait_stream(flash_attn_streams[i - 1]) with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, s//2, h, d] -> [b, s//2, h, d] - # [2, s//2, b, h, d] -> [s//2, b, h, d] - q_part = q.select(seq_dim_qkv, i).contiguous() - kv_seq_range_per_step[i], window_size_per_step[i] = ( - get_kv_seq_info_after_all_gather( + new_qkv_layout = qkv_layout + qkv_scale_inv_format = None + if qkv_format in ["bshd", "sbhd"]: + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, i).contiguous() + kv_seq_range_per_step[i], window_size_per_step[i] = ( + get_kv_seq_info_after_all_gather( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + window_size, + causal, + ) + ) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv_ = seq_end_idx - seq_start_idx + + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] + if use_fused_attention: + cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( + cu_seqlens_q.shape[0] - 1, max_seqlen_kv_, q.device + ) + if fp8: + if not fp8_recipe.mxfp8(): + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part] + ) + ] + else: + q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = ( + combine_and_quantize( + qkv_layout, q_part, k_part, v_part, QKV_quantizer + ) + ) + elif qkv_format == "thd": + # THD passes full Q/KV; per-step cu_seqlens select chunks. + q_part = q + k_part = k_ag + v_part = v_ag + kv_range, window_size_per_step[i] = get_kv_seq_info_after_all_gather( local_seq_chunk_ids[i], cp_size, max_seqlen_q, @@ -3232,37 +3365,18 @@ def forward( window_size, causal, ) - ) - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i][0], - kv_seq_range_per_step[i][1], - ) - max_seqlen_kv_ = seq_end_idx - seq_start_idx - if use_fused_attention or qkv_format == "thd": - cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( - k.shape[1], max_seqlen_kv_, k.device - ) - # select range: [s_range, b, h, d] - k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] - k_part, v_part = [ - x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] - ] + max_seqlen_kv_ = kv_range[1] + cu_seqlens_kv_per_step[i] = thd_cu_seqlens_kv_per_step[i] if use_fused_attention: - new_qkv_layout = qkv_layout - qkv_scale_inv_format = None - if fp8: - if not fp8_recipe.mxfp8(): - q_part, k_part, v_part = [ - Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) - for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) - ] - else: - q_part, k_part, v_part, new_qkv_layout, qkv_scale_inv_format = ( - combine_and_quantize( - qkv_layout, q_part, k_part, v_part, QKV_quantizer - ) - ) + # Set per-step parameters for THD vs bshd/sbhd + if qkv_format == "thd": + cu_seqlens_q_ = thd_cu_seqlens_q_per_step[i] + cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + else: + cu_seqlens_q_ = cu_seqlens_q + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_per_step[i] ( out_per_step[i], aux_ctx_tensors, @@ -3271,7 +3385,7 @@ def forward( is_training, max_seqlen_q, max_seqlen_kv_, - cu_seqlens_q, + cu_seqlens_q_, cu_seqlens_kv_per_step[i], q_part, k_part, @@ -3285,8 +3399,8 @@ def forward( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), @@ -3302,14 +3416,31 @@ def forward( if fp8 and isinstance(out_per_step[i], QuantizedTensorStorage): out_per_step[i] = out_per_step[i].dequantize(dtype=fwd_nominal_dtype) else: + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = ( + thd_cu_seqlens_q_per_step[i] if qkv_format == "thd" else cu_seqlens_q + ) + fa_cu_seqlens_kv = cu_seqlens_kv_per_step[i] + if use_flash_attn_3 and qkv_format == "thd": + seqused_q = ( + thd_cu_seqlens_q_per_step[i][1:] - thd_cu_seqlens_q_per_step[i][:-1] + ) + seqused_k = ( + cu_seqlens_kv_per_step[i][1:] - cu_seqlens_kv_per_step[i][:-1] + ) + fa_cu_seqlens_q = thd_cu_seqlens_q_padded_per_step[i] + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv_per_step[i], + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv_, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size"] = window_size_per_step[i] @@ -3346,6 +3477,16 @@ def forward( out_f16[:, i - 1].copy_(out_per_step[i - 1]) elif o_format == "sbhd": out_f16[i - 1].copy_(out_per_step[i - 1]) + elif qkv_format == "thd": + # Copy every sequence's valid token range from this split's output. + # Each split writes to distinct positions. + tex.thd_copy_valid_tokens_from_per_split_to_rank_local( + out, + out_per_step[i - 1], + thd_cu_seqlens_q_padded_per_step[i - 1], + thd_cu_seqlens_q_per_step[i - 1], + ) + if return_max_logit: # max_logit_per_step[i-1] was written on flash_attn_streams[i-1] # (cp_stream for i-1=1). The torch.maximum below runs on the @@ -3363,10 +3504,13 @@ def forward( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - # out_f16: fwd_nominal_dtype - # [b, 2, s//2, h, d] -> [b, s, h, d] - # [2, s//2, b, h, d] -> [s, b, h, d] - out_f16 = out_f16.view(orig_o_shape) + if qkv_format == "thd": + out_f16 = out + else: + # out_f16: fwd_nominal_dtype + # [b, 2, s//2, h, d] -> [b, s, h, d] + # [2, s//2, b, h, d] -> [s, b, h, d] + out_f16 = out_f16.view(orig_o_shape) # prepare for forward output and backward saves of out out_fp8 = None @@ -3464,6 +3608,13 @@ def forward( ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.pad_between_seqs = pad_between_seqs + ctx.window_size = window_size + if qkv_format == "thd": + ctx.max_seqlen_kv = max_seqlen_kv + ctx.cu_seqlens_kv_padded = cu_seqlens_kv_padded + ctx.thd_cu_seqlens_q_per_step = thd_cu_seqlens_q_per_step + ctx.thd_cu_seqlens_q_padded_per_step = thd_cu_seqlens_q_padded_per_step ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 @@ -3571,7 +3722,11 @@ def backward(ctx, dout, *_args): # dq: fwd_nominal_dtype, [b, 2, s//2, h, d] or [2, s//2, b, h, d] # dk: fwd_nominal_dtype, [cp*s, b, h, d] # dv: fwd_nominal_dtype, [cp*s, b, h, d] - dq = torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) + dq = ( + torch.zeros(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) + if ctx.qkv_format == "thd" + else torch.empty(ctx.q_shape, dtype=ctx.fwd_nominal_dtype, device=q.device) + ) dk = torch.zeros( (ctx.k_shape[0] * cp_size, *ctx.k_shape[1:]), dtype=ctx.fwd_nominal_dtype, @@ -3591,19 +3746,30 @@ def backward(ctx, dout, *_args): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # gather k and v along s: [s, b, h, d] -> [cp, s, b, h, d] + # gather k and v along s or t: [s, b, h, d] -> [cp, s, b, h, d] or [t, h, d] -> [cp*t, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - # split s: [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] - k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) - v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - # select appropriate chunks for each rank - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) - k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) - v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - k_ag = k_ag.view(-1, *k.shape[1:]) - v_ag = v_ag.view(-1, *v.shape[1:]) + + if ctx.qkv_format == "thd": + cu_seqlens_kv_padded = ctx.cu_seqlens_kv_padded + thd_cu_seqlens_q_per_step = ctx.thd_cu_seqlens_q_per_step + # [cp*t, h, d] -> reorder to sequence order + # Use padded cu_seqlens (divisible by 2*cp_size) for correct reorder + k_ag = thd_cp_rank_order_to_sequence_order(k_ag, cu_seqlens_kv_padded, cp_size) + v_ag = thd_cp_rank_order_to_sequence_order(v_ag, cu_seqlens_kv_padded, cp_size) + + thd_cu_seqlens_q_padded_per_step = ctx.thd_cu_seqlens_q_padded_per_step + else: + # split s: [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + # select appropriate chunks for each rank + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) # set up flash_attn_bwd @@ -3641,26 +3807,60 @@ def backward(ctx, dout, *_args): local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): + # FA3 uses internal per-call workspace. Consecutive AG per-step + # backward calls are serialized on GPU streams so that workspace + # lifetimes do not overlap. FusedAttention keeps the existing + # per-step overlap. + if i > 0 and ctx.use_flash_attn_3: + flash_attn_streams[i].wait_stream(flash_attn_streams[i - 1]) with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, s//2, h, d] -> [b, s//2, h, d] - # [2, s//2, b, h, d] -> [s//2, b, h, d] - q_part = q.select(seq_dim_qkv, i).contiguous() - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i][0], - kv_seq_range_per_step[i][1], - ) - max_seqlen_kv = seq_end_idx - seq_start_idx - # select range: [s_range, b, h, d] - k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] - k_part, v_part = [ - x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] - ] - # [b, 2, s//2, h, d] -> [b, s//2, h, d] - # [2, s//2, b, h, d] -> [s//2, b, h, d] - out_part = out.select(seq_dim_o, i).contiguous() - dout_part = dout.select(seq_dim_o, i).contiguous() + if ctx.qkv_format == "thd": + # THD passes full Q/dout; per-step cu_seqlens select chunks. + q_part = q + k_part = k_ag + v_part = v_ag + kv_range, _ = get_kv_seq_info_after_all_gather( + local_seq_chunk_ids[i], + cp_size, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + ctx.window_size, + "causal" in ctx.attn_mask_type, + ) + max_seqlen_kv = kv_range[1] + out_part = out + dout_part = dout + else: + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + q_part = q.select(seq_dim_qkv, i).contiguous() + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv = seq_end_idx - seq_start_idx + # select range: [s_range, b, h, d] + k_part, v_part = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # reshape to original format: [b, s_range, h, d] or [s_range, b, h, d] + k_part, v_part = [ + x.movedim(0, seq_dim_qkv).contiguous() for x in [k_part, v_part] + ] + # [b, 2, s//2, h, d] -> [b, s//2, h, d] + # [2, s//2, b, h, d] -> [s//2, b, h, d] + out_part = out.select(seq_dim_o, i).contiguous() + dout_part = dout.select(seq_dim_o, i).contiguous() + if ctx.use_fused_attention: + # Set per-step parameters for THD + if ctx.qkv_format == "thd": + cu_seqlens_q_ = thd_cu_seqlens_q_per_step[i] + cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + else: + cu_seqlens_q_ = cu_seqlens_q + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_per_step[i] + aux_ctx_tensors = [ softmax_lse_per_step[i], rng_states[i], @@ -3714,7 +3914,7 @@ def backward(ctx, dout, *_args): dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, + cu_seqlens_q_, cu_seqlens_kv_per_step[i], q_part, k_part, @@ -3724,8 +3924,8 @@ def backward(ctx, dout, *_args): ctx.fwd_nominal_dtype, aux_ctx_tensors, fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=new_qkv_layout, @@ -3750,20 +3950,44 @@ def backward(ctx, dout, *_args): for x in [dq_per_step[i], dk_per_step[i], dv_per_step[i]] ] else: - dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ - torch.empty_like(x) for x in [q_part, k_part, v_part] - ] + if ctx.use_flash_attn_3 and ctx.qkv_format == "thd": + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + torch.zeros_like(x) for x in [q_part, k_part, v_part] + ] + else: + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + torch.empty_like(x) for x in [q_part, k_part, v_part] + ] + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = ( + thd_cu_seqlens_q_per_step[i] + if ctx.qkv_format == "thd" + else cu_seqlens_q + ) + fa_cu_seqlens_kv = cu_seqlens_kv_per_step[i] + if ctx.use_flash_attn_3 and ctx.qkv_format == "thd": + seqused_q = ( + thd_cu_seqlens_q_per_step[i][1:] - thd_cu_seqlens_q_per_step[i][:-1] + ) + seqused_k = ( + cu_seqlens_kv_per_step[i][1:] - cu_seqlens_kv_per_step[i][:-1] + ) + fa_cu_seqlens_q = thd_cu_seqlens_q_padded_per_step[i] + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - ctx.dqkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv_per_step[i], + ctx.qkv_format, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=max_seqlen_kv, dq=dq_per_step[i], dk=dk_per_step[i], dv=dv_per_step[i], + seqused_q=seqused_q, + seqused_k=seqused_k, ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] @@ -3790,55 +4014,83 @@ def backward(ctx, dout, *_args): if i > 0: # dq/dk/dv, dq_per_step/dk_per_step/dv_per_step: ctx.fwd_nominal_dtype with torch.cuda.stream(flash_attn_streams[i - 1]): - # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] - # dq_per_step[i]: [b, s//2, h, d] or [s//2, b, h, d] - if ctx.dqkv_format == "bshd": - dq[:, i - 1].copy_(dq_per_step[i - 1]) - elif ctx.dqkv_format == "sbhd": - dq[i - 1].copy_(dq_per_step[i - 1]) - # dk/dv: [cp*s, b, h, d] - # dk_per_step[i - 1]/dv_per_step[i - 1]: [s_range, b, h, d] or [b, s_range, h, d] - # move s to first dim: [s_range, b, h, d] - dk_per_step[i - 1], dv_per_step[i - 1] = [ - x.movedim(seq_dim_dqkv, 0).contiguous() - for x in [dk_per_step[i - 1], dv_per_step[i - 1]] - ] - # wait until dkv update of last step is done - if i > 1: - flash_attn_streams[i - 1].wait_event(dkv_update_done) - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i - 1][0], - kv_seq_range_per_step[i - 1][1], - ) - # add to dk/dv: [cp*s, b, h, d] - dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) - dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) - if i < len(local_seq_chunk_ids): - flash_attn_streams[i - 1].record_event(dkv_update_done) + if ctx.qkv_format == "thd": + # dQ: copy every sequence's valid token range from this split's dQ. + tex.thd_copy_valid_tokens_from_per_split_to_rank_local( + dq, + dq_per_step[i - 1], + thd_cu_seqlens_q_padded_per_step[i - 1], + thd_cu_seqlens_q_per_step[i - 1], + ) + # dK/dV: accumulate full packed tensors. Padded entries may accumulate, + # but valid token entries are independent and only valid entries are + # reordered after the per-step reductions. + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + dk.add_(dk_per_step[i - 1]) + dv.add_(dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) + else: + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] + # dq_per_step[i]: [b, s//2, h, d] or [s//2, b, h, d] + if ctx.dqkv_format == "bshd": + dq[:, i - 1].copy_(dq_per_step[i - 1]) + elif ctx.dqkv_format == "sbhd": + dq[i - 1].copy_(dq_per_step[i - 1]) + # dk/dv: [cp*s, b, h, d] + # dk_per_step[i - 1]/dv_per_step[i - 1]: [s_range, b, h, d] or [b, s_range, h, d] + # move s to first dim: [s_range, b, h, d] + dk_per_step[i - 1], dv_per_step[i - 1] = [ + x.movedim(seq_dim_dqkv, 0).contiguous() + for x in [dk_per_step[i - 1], dv_per_step[i - 1]] + ] + # wait until dkv update of last step is done + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i - 1][0], + kv_seq_range_per_step[i - 1][1], + ) + # add to dk/dv: [cp*s, b, h, d] + dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) + dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # split s:[cp*s, b, h, d] -> [cp*2, s//2, b, h, d] - dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) - dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) - # put back together the right chunks for each rank - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) - dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) - dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - dk = dk.view(-1, *dk.shape[-3:]) - dv = dv.view(-1, *dv.shape[-3:]) - # reduce scatter: [cp*s, b, h, d] -> [s, b, h, d] - dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) - dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - - # reshape to original format: - # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] -> [b, s, h, d] or [s, b, h, d] - # dk: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] - # dv: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] - dq = dq.view(*dq.shape[:seq_dim_dqkv], -1, *dq.shape[(seq_dim_dqkv + 2) :]) - dk = dk.movedim(0, seq_dim_dqkv).contiguous() - dv = dv.movedim(0, seq_dim_dqkv).contiguous() + if ctx.qkv_format == "thd": + # Reorder dK/dV from sequence order back to dual-chunk CP rank order, + # then reduce-scatter across CP ranks. + # Use padded cu_seqlens for correct slice boundaries. + dk = thd_sequence_order_to_cp_rank_order(dk, cu_seqlens_kv_padded, cp_size) + dv = thd_sequence_order_to_cp_rank_order(dv, cu_seqlens_kv_padded, cp_size) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + # dQ is already [t_rank, h, d], no reshape needed + else: + # split s:[cp*s, b, h, d] -> [cp*2, s//2, b, h, d] + dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) + dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + # put back together the right chunks for each rank + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) + dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) + dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) + # flatten: [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + # reduce scatter: [cp*s, b, h, d] -> [s, b, h, d] + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + + # reshape to original format: + # dq: [b, 2, s//2, h, d] or [2, s//2, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dk: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + # dv: [s, b, h, d] -> [b, s, h, d] or [s, b, h, d] + dq = dq.view(*dq.shape[:seq_dim_dqkv], -1, *dq.shape[(seq_dim_dqkv + 2) :]) + dk = dk.movedim(0, seq_dim_dqkv).contiguous() + dv = dv.movedim(0, seq_dim_dqkv).contiguous() # quantize if necessary if ctx.fp8 and ctx.is_input_fp8: @@ -3871,6 +4123,9 @@ def backward(ctx, dout, *_args): None, None, None, + None, + None, + None, ) @@ -4817,13 +5072,12 @@ def attn_forward_func_with_cp( ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": - args.pop(5) - args.pop(8) args += [ window_size, cp_group, cp_stream, use_flash_attn_3, + pad_between_seqs, fp8, fp8_meta, quantizers, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 8a07d7af79..9913b78dfc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1159,7 +1159,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt cp_comm_type, ) use_fused_attention = False - elif qkv_format == "thd" and cp_comm_type in ["all_gather", "a2a+p2p"]: + elif qkv_format == "thd" and cp_comm_type in ["a2a+p2p"]: logger.debug( "Disabling FusedAttention as it does not support context parallelism with THD" " format and cp_comm_type = %s", diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 62dcaadc96..c9f9c11d4e 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Python interface for fused attention extensions""" + import math from typing import Tuple, List, Union, Optional import torch @@ -18,7 +19,6 @@ from ..quantized_tensor import Quantizer from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, DType - __all__ = [ "fused_attn_fwd", "fused_attn_bwd", @@ -304,7 +304,6 @@ def fused_attn_fwd( raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel - output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, @@ -364,25 +363,18 @@ def fused_attn_fwd( max_tensor = max_tensor.masked_fill(~valid, float("-inf")) elif max_tensor.ndim == 3: if cu_seqlens_q_padded is not None: - # For THD + pad_between_seqs=True + non-sm120 + cuDNN>9.6, Max tensor is [tq, h, 1] - # and padding positions could be uninitialized. Exclude those padded positions when - # computing max_logit. + # Exclude padded THD rows; CP may pass nonzero padded offsets. actual_seqlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to( device=max_tensor.device ) - padded_seqlens = (cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]).to( - device=max_tensor.device - ) - pad_lens = (padded_seqlens - actual_seqlens).to(device=max_tensor.device) - b = pad_lens.shape[0] - - # Stack [actual, pad] per batch into counts: e.g. [3,1, 3,1, 2,2, 7,1] - counts = torch.stack([actual_seqlens, pad_lens], dim=1).flatten() - # Tile [T, F] per sequence: [T,F, T,F, T,F, T,F] - values = torch.tensor([True, False], device=max_tensor.device).repeat(b) - # Expand: T×3, F×1, T×3, F×1, T×2, F×2, T×7, F×1 → TTTF|TTTF|TTFF|TTTTTTTF - valid = torch.repeat_interleave(values, counts) - # Finally, replace invalid (F) positions with -inf + tq = max_tensor.shape[0] + starts = cu_seqlens_q_padded[:-1].to(device=max_tensor.device) + ends = (starts + actual_seqlens).clamp(max=tq) + delta = torch.zeros(tq + 1, dtype=torch.int32, device=max_tensor.device) + updates = torch.ones_like(starts, dtype=torch.int32) + delta.scatter_add_(0, starts.clamp(max=tq), updates) + delta.scatter_add_(0, ends, -updates) + valid = delta[:-1].cumsum(0) > 0 max_tensor = max_tensor.masked_fill(~valid.view(-1, 1, 1), float("-inf")) # Max -> max_logit [h] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b4f899e1d..c71abc5b0b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -557,6 +557,16 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, int world_size, int rank); +at::Tensor thd_sequence_order_to_cp_rank_order(const at::Tensor &inp, const at::Tensor &cu_seqlens, + int cp_size, int total_tokens); + +at::Tensor thd_cp_rank_order_to_sequence_order(const at::Tensor &inp, const at::Tensor &cu_seqlens, + int cp_size, int total_tokens); + +void thd_copy_valid_tokens_from_per_split_to_rank_local(at::Tensor out, const at::Tensor &inp, + const at::Tensor &cu_seqlens_padded, + const at::Tensor &cu_seqlens); + /*************************************************************************************************** * multi_tensor_* kernels **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 7e8018b3fd..eb8813d4a0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -975,6 +975,76 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t return output; } +at::Tensor thd_reorder_between_sequence_and_cp_rank_order(const at::Tensor &inp, + const at::Tensor &cu_seqlens, int cp_size, + bool cp_rank_to_sequence_order, + int total_tokens) { + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens.size(0) >= 2); + NVTE_CHECK(cp_size > 0); + NVTE_CHECK(total_tokens > 0 && total_tokens % (cp_size * 2) == 0); + NVTE_CHECK(inp.dim() >= 1 && inp.size(0) == total_tokens); + + auto inp_c = inp.contiguous(); + at::Tensor out = at::empty_like(inp_c); + + auto te_inp = makeTransformerEngineTensor(inp_c); + auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); + auto te_out = makeTransformerEngineTensor(out); + + if (cp_rank_to_sequence_order) { + nvte_thd_cp_rank_order_to_sequence_order(te_inp.data(), te_cu_seqlens.data(), te_out.data(), + cp_size, total_tokens, + at::cuda::getCurrentCUDAStream()); + } else { + nvte_thd_sequence_order_to_cp_rank_order(te_inp.data(), te_cu_seqlens.data(), te_out.data(), + cp_size, total_tokens, + at::cuda::getCurrentCUDAStream()); + } + + return out; +} + +at::Tensor thd_sequence_order_to_cp_rank_order(const at::Tensor &inp, const at::Tensor &cu_seqlens, + int cp_size, int total_tokens) { + return thd_reorder_between_sequence_and_cp_rank_order(inp, cu_seqlens, cp_size, false, + total_tokens); +} + +at::Tensor thd_cp_rank_order_to_sequence_order(const at::Tensor &inp, const at::Tensor &cu_seqlens, + int cp_size, int total_tokens) { + return thd_reorder_between_sequence_and_cp_rank_order(inp, cu_seqlens, cp_size, true, + total_tokens); +} + +void thd_copy_valid_tokens_from_per_split_to_rank_local(at::Tensor out, const at::Tensor &inp, + const at::Tensor &cu_seqlens_padded, + const at::Tensor &cu_seqlens) { + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1 && cu_seqlens_padded.dim() == 1); + NVTE_CHECK(cu_seqlens.size(0) >= 2); + NVTE_CHECK(cu_seqlens_padded.size(0) == cu_seqlens.size(0)); + NVTE_CHECK(inp.dim() >= 1); + NVTE_CHECK(out.sizes() == inp.sizes() && out.scalar_type() == inp.scalar_type()); + NVTE_CHECK(out.is_contiguous(), + "thd_copy_valid_tokens_from_per_split_to_rank_local output must be contiguous."); + + auto inp_c = inp.contiguous(); + auto cu_seqlens_padded_c = cu_seqlens_padded.contiguous(); + auto cu_seqlens_c = cu_seqlens.contiguous(); + int total_tokens = inp_c.size(0); + auto te_inp = makeTransformerEngineTensor(inp_c); + auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded_c); + auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens_c); + auto te_out = makeTransformerEngineTensor(out); + + nvte_thd_copy_valid_tokens_from_per_split_to_rank_local( + te_inp.data(), te_cu_seqlens_padded.data(), te_cu_seqlens.data(), te_out.data(), total_tokens, + at::cuda::getCurrentCUDAStream()); +} + /*************************************************************************************************** * KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..756b34b102 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -560,6 +560,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("thd_get_partitioned_indices", &transformer_engine::pytorch::thd_get_partitioned_indices, "Generate partitioned indices for inputs in THD format", py::call_guard()); + m.def("thd_sequence_order_to_cp_rank_order", + &transformer_engine::pytorch::thd_sequence_order_to_cp_rank_order, + "Reorder a THD tensor from sequence order to dual-chunk CP rank order", + py::call_guard()); + m.def("thd_cp_rank_order_to_sequence_order", + &transformer_engine::pytorch::thd_cp_rank_order_to_sequence_order, + "Reorder a THD tensor from dual-chunk CP rank order to sequence order", + py::call_guard()); + m.def("thd_copy_valid_tokens_from_per_split_to_rank_local", + &transformer_engine::pytorch::thd_copy_valid_tokens_from_per_split_to_rank_local, + "Copy valid THD token entries from a per-split tensor into a rank-local accumulator", + py::call_guard()); // nvshmem functions m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend,