diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e2e6d09c29..e9146e5d5d 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -14,6 +14,7 @@ all_files_in_dir, cuda_version, get_cuda_include_dirs, + get_nccl_include_dirs, debug_build_enabled, setup_mpi_flags, ) @@ -49,6 +50,7 @@ def setup_pytorch_extension( # Header files - include_dirs = get_cuda_include_dirs() + include_dirs = list(get_cuda_include_dirs()) # copy: never mutate the lru_cache'd list + include_dirs.extend(get_nccl_include_dirs()) include_dirs.extend( [ common_header_files, diff --git a/build_tools/utils.py b/build_tools/utils.py index f2548b4de6..11f781898c 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -225,30 +225,46 @@ def nvcc_path() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) -def get_cuda_include_dirs() -> Tuple[str, str]: - """Returns the CUDA header directory.""" +def get_cuda_include_dirs() -> List[Path]: + """Returns all CUDA-related header directories. + + Combines the system CUDA toolkit path (if present) with include directories + from nvidia pip wheel packages (nvidia-cudnn-cu1x, nvidia-nccl-cu1x, …). + + On modern setups these are NOT mutually exclusive: the system toolkit may + provide core CUDA headers while cudnn and nccl arrive as separate pip wheels. + Returning both avoids "fatal error: cudnn.h / nccl.h not found" when a + system CUDA toolkit is present but those components are pip-distributed. + """ + dirs: List[Path] = [] force_wheels = bool(int(os.getenv("NVTE_BUILD_USE_NVIDIA_WHEELS", "0"))) - # If cuda is installed via toolkit, all necessary headers - # are bundled inside the top level cuda directory. if not force_wheels and cuda_toolkit_include_path() is not None: - return [cuda_toolkit_include_path()] + dirs.append(cuda_toolkit_include_path()) - # Use pip wheels to include all headers. + # Always supplement with nvidia pip wheel includes so components like + # cudnn and nccl are found even when the system toolkit doesn't have them.
try: import nvidia - except ModuleNotFoundError as e: + + cuda_root = Path(nvidia.__path__[0]) # namespace package — no __file__ + dirs.extend( + subdir / "include" + for subdir in cuda_root.iterdir() + if subdir.is_dir() and (subdir / "include").is_dir() + ) + except (ImportError, StopIteration, IndexError, AttributeError): + pass + + if not dirs: raise RuntimeError("CUDA not found.") - if nvidia.__file__ is not None: - cuda_root = Path(nvidia.__file__).parent - else: - cuda_root = Path(nvidia.__path__[0]) # namespace - return [ - subdir / "include" - for subdir in cuda_root.iterdir() - if subdir.is_dir() and (subdir / "include").is_dir() - ] + return dirs + + +def get_nccl_include_dirs() -> List[Path]: + """Compatibility shim — nccl includes are now returned by get_cuda_include_dirs().""" + return [] @functools.lru_cache(maxsize=None) @@ -257,7 +273,11 @@ def cuda_archs() -> str: if archs is None: version = cuda_version() if version >= (13, 0): - archs = "75;80;89;90;100;120" + archs = "75;80;89;90;100;120;121" + elif version >= (12, 9): + # SM_121 (GB10 consumer Blackwell) confirmed under CUDA >= 12.9; + # minimum version for 12.8 needs verification on target hardware. 
+ archs = "70;80;89;90;100;120;121" elif version >= (12, 8): archs = "70;80;89;90;100;120" else: diff --git a/te_sm121_test_results.txt b/te_sm121_test_results.txt new file mode 100644 index 0000000000..926495acd2 --- /dev/null +++ b/te_sm121_test_results.txt @@ -0,0 +1,110 @@ +================================================================================ +SM_121 (GB10 Consumer Blackwell / DGX Spark) Flash Attention 4 Test Results +================================================================================ +Date: 2026-06-13 +GPU: NVIDIA GB10 (sm_121) | CUDA 12.8 | flash-attn-4 v4.0.0b16 +nvidia-cutlass-dsl: 4.5.2 | TransformerEngine: feature/sm121-blackwell-support + +-------------------------------------------------------------------------------- +PRIMARY TEST: test_dpa_fa4_sm121 (SM_121-gated, unambiguous CI signal) +-------------------------------------------------------------------------------- +3/3 PASS + + PASSED test_dpa_fa4_sm121[fa4_base_1-model_configs0-dtype0] # b=4, sq=128, hq=16, hdim=64 + PASSED test_dpa_fa4_sm121[fa4_base_2-model_configs0-dtype0] # b=2, sq=2048, hq=24, hdim=128, causal + PASSED test_dpa_fa4_sm121[fa4_base_3-model_configs0-dtype0] # b=2, sq=1024, hq=8, hdim=96, causal + +These configs cover MHA forward+backward, causal masking, head dims 64/96/128, +sequence lengths 128–2048 — all exercising the SM_121 FA4 CuTe-DSL kernel path. 
+ +-------------------------------------------------------------------------------- +FULL FA4 SUITE — ISOLATION RESULTS (each category run in a separate process; +full-suite run has MLIR context corruption from upstream GQA compile errors) +-------------------------------------------------------------------------------- + + test_dpa_fa4_sm121 : 3/3 PASS (SM_121-gated; our primary signal) + test_dpa_fa4_base : 3/7 PASS (gqa_1/2 + splits_1/2 → upstream FA4 bugs) + test_dpa_fa4_mask : 3/6 PASS (padding variants → TE cuDNN backend) + test_dpa_fa4_varlen : 3/8 PASS (thd layout + gqa_varlen → upstream FA4/cuDNN) + test_dpa_fa4_sliding_window : 0/8 FAIL (SM_80 bwd has no local mask → see below) + test_dpa_fa4_hdim256 : 1/1 SKIP (SM100-only kernel; expected on SM_121) + + TOTAL ISOLATION PASSING: 12 / 32 (the 1 skipped hdim256 test is not counted) + +-------------------------------------------------------------------------------- +KNOWN LIMITATIONS (1-3: upstream flash-attn-4 bugs, not SM_121 or TE defects; 4-5: cuDNN backend gaps) +-------------------------------------------------------------------------------- + +1. GQA / pack_gqa (fa4_gqa_1, fa4_gqa_2, fa4_varlen_3) + Error: ValueError: Operation creation failed + MLIR: unable to compute crd2idx with layout '(?):(1)' and coord '((?,?))' + at pack_gqa.py:139 compute_ptr() + Cause: pack_gqa.store_O(mO_cur, ...) expects mO_cur shaped + ((qhead_per_kvhead, seqlen_q), headdim) (hierarchical), but + FlashAttentionForwardSm80.__call__ passes flat (seqlen_q, headdim). + Scope: SM_80 code path (SM_80 and SM_121 equally); flash-attn-4 b16 regression. + +2. SplitKV (fa4_splits_1, fa4_splits_2) + Error: AssertionError: SplitKV not supported on SM 12.0 in this PR + Cause: Explicit upstream stub — SplitKV kernel not yet ported to SM_12x. + Scope: SM_120 + SM_121; deliberate placeholder. + +3.
Sliding window backward (all 8 test_dpa_fa4_sliding_window tests) + Error: Numerical mismatch — dQ wrong (abs diff up to 0.43, tol 0.015) + Cause: FlashAttentionBackwardSm80.kernel() creates AttentionMask with + mask_causal=self.is_causal only, never mask_local=True nor window_size. + Local window is ignored in backward → incorrect gradients. + Forward is CORRECT (AttentionMask full local mask support, is_local=True). + Note: Before this patch these tests crashed earlier with: + DSLRuntimeError: argument #13 (window_size_left): (Int32, NoneType) but got int. + That crash (interface.py missing Int32 cast) is now fixed; the pre-existing + backward limitation is now visible as the failure mode. + Scope: SM_80 code path; upstream flash-attn-4 limitation. + +4. THD varlen layout — cuDNN fallback (thd_thd_thd fa4_varlen_1..4) + Error: RuntimeError: cuDNN Error: No valid engine configs + Cause: TE routes thd_thd_thd varlen to cuDNN fused-attn backend (not FA4). + cuDNN has no valid engine for SM_121 + this combination. + Scope: TE cuDNN backend; not flash-attn-4. + +5. Padding mask tests — cuDNN fallback (fa4_mask_padding*) + Error: RuntimeError: cuDNN Error: No valid engine configs + Cause: Same — TE routes padding+varlen mask combos to cuDNN backend. + +-------------------------------------------------------------------------------- +PATCHES APPLIED TO flash-attn-4 v4.0.0b16 TO ENABLE SM_121 +(to be reported as separate upstream flash-attn-4 issues) +-------------------------------------------------------------------------------- + +Patch 1 — flash_fwd.py:658 use_tma_O guard + Old: self.use_tma_O = self.arch >= Arch.sm_90 + New: self.use_tma_O = Arch.sm_90 <= self.arch < Arch.sm_120 + Why: FlashAttentionForwardSm120 sets class attr arch=80, but __init__ clobbers it + with Arch.sm_121 via BaseDSL._get_dsl().get_arch_enum(). sm_121 >= sm_90 → True + → tries tma_atom_O=None → AttributeError: 'NoneType'.._trait. 
+ +Patch 2 — interface.py dQ_single_wg in SM_12x backward branch + Added: dQ_single_wg = False inside "if arch // 10 == 12:" block. + Why: Used in compile_key for arch//10 in [8,9,12] but never set for arch//10==12 + → UnboundLocalError. + +Patch 3 — flash_bwd.py:440 softmax_scale preservation + Old: softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(...) + New: softmax_scale_log2, _ = utils.compute_softmax_scale_log2(...) + Why: compute_softmax_scale_log2 returns None when score_mod is None, clobbering the + original Float32 value the kernel requires. + +Patch 4 — utils.py:486 nvvm.atomicrmw API + Old: nvvm.atomicrmw(res=T.f32(), op=..., ptr=..., a=...) + New: nvvm.atomicrmw(op=..., ptr=..., a=...) + Why: 'res' kwarg removed in nvidia-cutlass-dsl 4.5.2. + +Patch 5 — interface.py window_size Int32 cast in cute.compile paths + Added: window_size_left_typed = Int32(wsl) if wsl is not None else None (fwd + bwd) + Used window_size_*_typed in compile_args instead of raw Python int. + Why: cute.compile requires Int32, not plain int, for Optional[Int32] params. + SM_90 path implicitly coerces; SM_80/12x else-branch does not. + +================================================================================ +END +================================================================================ diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 2dbf94fc20..f69a1eb760 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -352,6 +352,37 @@ def test_dpa_fa4_base(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) +model_configs_fa4_sm121 = { + # MHA-only configs confirmed working on SM_121 (GB10) via flash-attn-4 SM_80 code path. + # GQA excluded: flash-attn-4 pack_gqa store_O has a hierarchical-layout mismatch on SM_80/12x + # (upstream flash-attn-4 issue, not a TE or SM_121 defect). 
+ # num_splits excluded: explicit upstream assertion "SplitKV not supported on SM 12.0". + "fa4_base_1": ModelConfig(4, 128, 16, 64), + "fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), + "fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"), +} + + +@pytest.mark.skipif( + not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required." +) +@pytest.mark.skipif( + device_compute_capability != (12, 1), + reason="SM_121 (GB10 consumer Blackwell / DGX Spark) specific FA4 correctness test.", +) +@pytest.mark.parametrize("dtype", param_types_lean) +@pytest.mark.parametrize("model_configs", [model_configs_fa4_sm121]) +@pytest.mark.parametrize("model", model_configs_fa4_sm121.keys()) +def test_dpa_fa4_sm121(dtype, model_configs, model): + """Test DotProductAttention with FA4 on SM_121 (GB10 consumer Blackwell). + + SM_121 is architecturally a variant of SM_120 (professional Blackwell). This test + gates explicitly on (12, 1) to ensure the SM_121 SASS path is exercised and to + provide an unambiguous CI signal for GB10 hardware. + """ + test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False) + + # head_dim=256 is supported only on SM100 via FA4's dedicated kernel # (flash_attn/cute/sm100_hd256_2cta_fmha_*.py), available in flash-attn-4 > 4.0.0b10. 
# On other architectures, _validate_head_dims rejects (256, 256), FA4 is disabled, and diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8f96432ed8..ac43f88ba1 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -28,7 +28,10 @@ endif() # Process GPU architectures if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120 121) + elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) + # SM_121 (GB10 consumer Blackwell) placed under >= 12.9; lower bound to be confirmed. + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120 121) elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) else () @@ -83,6 +86,18 @@ if(NOT arch_120_index EQUAL -1) endif() endif() +# Check for architecture 121 (GB10 consumer Blackwell / DGX Spark) +list(FIND CMAKE_CUDA_ARCHITECTURES "121" arch_121_index) +if(NOT arch_121_index EQUAL -1) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "121") + list(APPEND NVTE_GENERIC_ARCHS "121") + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) + list(APPEND NVTE_SPECIFIC_ARCHS "121f") + else() + list(APPEND NVTE_SPECIFIC_ARCHS "121a") + endif() +endif() + # Move remaining standard (pre-Blackwell) architectures into NVTE_STANDARD_ARCHS. # These are applied to all CUDA sources (both generic and arch-specific). set(NVTE_STANDARD_ARCHS ${CMAKE_CUDA_ARCHITECTURES}) @@ -369,6 +384,18 @@ target_include_directories(transformer_engine PRIVATE ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_INCLUDE_DIR}) +# NCCL headers are included in TE's public API (e.g. comm_gemm_overlap.h). +# On pip-based CUDA installs (nvidia-nccl-* wheels) nccl.h is not in the +# CUDA toolkit tree, so find it explicitly and add it to the build. 
+find_path(NVTE_NCCL_INCLUDE_DIR + NAMES nccl.h + PATH_SUFFIXES include + DOC "Directory containing nccl.h") +if(NVTE_NCCL_INCLUDE_DIR) + target_include_directories(transformer_engine PRIVATE ${NVTE_NCCL_INCLUDE_DIR}) + message(STATUS "Found NCCL headers at: ${NVTE_NCCL_INCLUDE_DIR}") +endif() + # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) if (NVTE_UB_WITH_MPI) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index fc21771297..a302dc9a10 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -502,23 +502,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( "Please upgrade your cuDNN version if possible." << std::endl; } - if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen && sm_arch_ == 120) { + if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen && + (sm_arch_ == 120 || sm_arch_ == 121)) { if (cudnn_runtime_version < 91801) { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of sm_arch_ == 120 and cudnn_runtime_version < " - "91801 is not supported. " + std::cout << "Warning: Given combination of sm_arch_ == " << sm_arch_ + << " and cudnn_runtime_version < 91801 is not supported. " << " Please upgrade your cuDNN version if possible." << std::endl; } else if (deterministic && is_training) { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Deterministic fused attention on SM120 is not supported." - << std::endl; + std::cout << "Warning: Deterministic fused attention on SM" << sm_arch_ + << " is not supported."
<< std::endl; } else { - // Known missing support for T3HD/TH3D layouts on SM120 + // Known missing support for T3HD/TH3D layouts on SM120/SM121 const bool is_t3hd_or_th3d = (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D); if (is_t3hd_or_th3d) { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of T3HD/TH3D layouts on SM120 is not supported. " + std::cout << "Warning: Given combination of T3HD/TH3D layouts on SM" << sm_arch_ + << " is not supported. " << " Please consider using other THD layouts if possible." << std::endl; } } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 6df7ad35c8..0803e5e34f 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -86,7 +86,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); - bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120; + bool use_ragged_stats = + is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120 && sm_arch_ != 121; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); @@ -101,7 +102,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( // On SM 120, cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3] // as interleaved and rejects them. Use BHSD-like dimensions/strides with max_seqlen at plan build // so the check passes; ragged offset still provides variable-length boundaries. 
- if (sm_arch_ != 120) { + if (sm_arch_ != 120 && sm_arch_ != 121) { // replace batch size and maximum sequence lengths with maximum token counts // for query and key/value so the graph is static within each quantization bucket b = max_b; @@ -590,7 +591,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); - bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120; + bool use_ragged_stats = + is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120 && sm_arch_ != 121; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); @@ -603,7 +605,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); // On SM 120, cuDNN support check requires BHSD-like strides with max_seqlen (see fwd). 
- if (sm_arch_ != 120) { + if (sm_arch_ != 120 && sm_arch_ != 121) { // replace batch size and maximum sequence lengths with maximum token counts // for query and key/value so the graph is static within each quantization bucket b = max_b; @@ -805,7 +807,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (use_ragged_stats) { sdpa_backward_options.set_max_total_seq_len_q(s_q); } - if (is_ragged_kv && cudnn_runtime_version >= 90600 && sm_arch_ != 120) { + if (is_ragged_kv && cudnn_runtime_version >= 90600 && sm_arch_ != 120 && sm_arch_ != 121) { sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } @@ -1153,7 +1155,7 @@ void fused_attn_arbitrary_seqlen_fwd( Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && - (sm_arch_ != 120)) { + (sm_arch_ != 120 && sm_arch_ != 121)) { output_Max->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 36847e40ed..6ae72b342b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1671,7 +1671,7 @@ def forward( 9, 6, 0, - ) and get_device_compute_capability() != (12, 0) + ) and get_device_compute_capability() not in ((12, 0), (12, 1)) else: softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 8a07d7af79..6f320b1f97 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -496,7 +496,7 @@ def _disable_all_flash_attention() -> None: if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for compute capability != sm90") use_flash_attention_3 = False - # FA4 supports SM80, SM90, SM100, SM120 + # FA4 supports SM80, SM90, SM100, SM120, SM121 if device_compute_capability < (8, 0): if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed: logger.debug("Disabling FlashAttention 4 for compute capability < sm80") @@ -636,16 +636,16 @@ def _disable_all_flash_attention() -> None: logger.debug("Disabling FusedAttention for %s", fp8_recipe.__class__.__name__) use_fused_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability in ((12, 0), (12, 1)): if use_flash_attention: logger.debug( "Disabling FlashAttention as FP8 is not supported" - " for compute capability = sm120" + " for compute capability = sm120/sm121" ) if use_fused_attention: logger.debug( "Disabling FusedAttention as FP8 is not supported" - " for compute capability = sm120" + " for compute capability = sm120/sm121" ) use_flash_attention = False use_fused_attention = False @@ -754,14 +754,14 @@ def _disable_all_flash_attention() -> None: # Flash v4 | FP16/BF16 | TODO | sm80+ | bshd,sbhd,thd | TODO # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - # Temporarily disabling fused attention for kv caching for sm89/sm120 irrespective of + # Temporarily disabling fused attention for kv caching for sm89/sm120/sm121 irrespective of # cuDNN version until the cuDNN bug is resolved. 
- if device_compute_capability in ((8, 9), (12, 0)): - logger.debug("Disabling FusedAttention for KV caching for sm89/sm120") + if device_compute_capability in ((8, 9), (12, 0), (12, 1)): + logger.debug("Disabling FusedAttention for KV caching for sm89/sm120/sm121") use_fused_attention = False - # Temporarily disable FlashAttention for KV caching on sm120 - if device_compute_capability == (12, 0): - logger.debug("Disabling FlashAttention for KV caching for sm120") + # Temporarily disable FlashAttention for KV caching on sm120/sm121 + if device_compute_capability in ((12, 0), (12, 1)): + logger.debug("Disabling FlashAttention for KV caching for sm120/sm121") use_flash_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") @@ -824,15 +824,15 @@ def _disable_all_flash_attention() -> None: ) use_fused_attention = False if ( - device_compute_capability == (12, 0) + device_compute_capability in ((12, 0), (12, 1)) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) and is_training ): if use_fused_attention: logger.debug( "Disabling FusedAttention as MLA for backward pass is not supported for compute" - " capability = sm120 for a head_dim_qk > 128 or head_dim_qk %%8 != 0. Found:" - " head_dim_qk = %s", + " capability = sm120/sm121 for a head_dim_qk > 128 or head_dim_qk %%8 != 0." 
+ " Found: head_dim_qk = %s", head_dim_qk, ) use_fused_attention = False @@ -985,19 +985,19 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if use_unfused_attention: logger.debug("Disabling UnfusedDotProductAttention for pad_between_seqs = True") use_unfused_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability in ((12, 0), (12, 1)): if cudnn_version < (9, 18, 1): if use_fused_attention: logger.debug( - "Disabling FusedAttention as qkv_format = thd is" - " not supported for compute capability = sm120 and cuDNN version < 9.18.1" + "Disabling FusedAttention as qkv_format = thd is not supported for compute" + " capability = sm120/sm121 and cuDNN version < 9.18.1" ) use_fused_attention = False elif qkv_layout in {"t3hd", "th3d"}: if use_fused_attention: logger.debug( "Disabling FusedAttention as qkv_layout = %s is not supported for" - " compute capability = sm120", + " compute capability = sm120/sm121", qkv_layout, ) use_fused_attention = False