Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
all_files_in_dir,
cuda_version,
get_cuda_include_dirs,
get_nccl_include_dirs,
debug_build_enabled,
setup_mpi_flags,
)
Expand Down Expand Up @@ -49,6 +50,7 @@ def setup_pytorch_extension(

# Header files
include_dirs = get_cuda_include_dirs()
include_dirs.extend(get_nccl_include_dirs())
include_dirs.extend(
[
common_header_files,
Expand Down
54 changes: 37 additions & 17 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,30 +225,46 @@ def nvcc_path() -> Tuple[str, str]:


@functools.lru_cache(maxsize=None)
def get_cuda_include_dirs() -> Tuple[str, str]:
"""Returns the CUDA header directory."""
def get_cuda_include_dirs() -> List[Path]:
"""Returns all CUDA-related header directories.

Combines the system CUDA toolkit path (if present) with include directories
from nvidia pip wheel packages (nvidia-cudnn-cu1x, nvidia-nccl-cu1x, …).

On modern setups these are NOT mutually exclusive: the system toolkit may
provide core CUDA headers while cudnn and nccl arrive as separate pip wheels.
Returning both avoids "fatal error: cudnn.h / nccl.h not found" when a
system CUDA toolkit is present but those components are pip-distributed.
"""
dirs: List[Path] = []

force_wheels = bool(int(os.getenv("NVTE_BUILD_USE_NVIDIA_WHEELS", "0")))
# If cuda is installed via toolkit, all necessary headers
# are bundled inside the top level cuda directory.
if not force_wheels and cuda_toolkit_include_path() is not None:
return [cuda_toolkit_include_path()]
dirs.append(cuda_toolkit_include_path())

# Use pip wheels to include all headers.
# Always supplement with nvidia pip wheel includes so components like
# cudnn and nccl are found even when the system toolkit doesn't have them.
try:
import nvidia
except ModuleNotFoundError as e:

cuda_root = Path(nvidia.__path__[0]) # namespace package — no __file__
dirs.extend(
subdir / "include"
for subdir in cuda_root.iterdir()
if subdir.is_dir() and (subdir / "include").is_dir()
)
except (ImportError, StopIteration, IndexError, AttributeError):
pass

if not dirs:
raise RuntimeError("CUDA not found.")

if nvidia.__file__ is not None:
cuda_root = Path(nvidia.__file__).parent
else:
cuda_root = Path(nvidia.__path__[0]) # namespace
return [
subdir / "include"
for subdir in cuda_root.iterdir()
if subdir.is_dir() and (subdir / "include").is_dir()
]
return dirs


def get_nccl_include_dirs() -> List[Path]:
"""Compatibility shim — nccl includes are now returned by get_cuda_include_dirs()."""
return []
Comment on lines +265 to +267

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Dead shim with a misleading "compatibility" label

get_nccl_include_dirs() is a brand-new function (introduced in this PR) that immediately returns []. The docstring calls it a "compatibility shim", but there are no pre-existing callers to stay compatible with — pytorch.py is also calling it for the first time in this same PR. The only concrete effect of the include_dirs.extend(get_nccl_include_dirs()) call in pytorch.py is extend([]), which is a no-op. The function and its import can be removed without any behavioral change; NCCL headers are now properly included via get_cuda_include_dirs().



@functools.lru_cache(maxsize=None)
Expand All @@ -257,7 +273,11 @@ def cuda_archs() -> str:
if archs is None:
version = cuda_version()
if version >= (13, 0):
archs = "75;80;89;90;100;120"
archs = "75;80;89;90;100;120;121"
elif version >= (12, 9):
# SM_121 (GB10 consumer Blackwell) confirmed under CUDA >= 12.9;
# minimum version for 12.8 needs verification on target hardware.
archs = "70;80;89;90;100;120;121"
Comment on lines +277 to +280

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Comment contradicts the test evidence

The inline comment says "confirmed under CUDA >= 12.9; minimum version for 12.8 needs verification", yet te_sm121_test_results.txt (committed in this same PR) explicitly shows CUDA 12.8 in its header line and reports 3/3 passing. Either SM_121 should be added to the >= (12, 8) arch list (if NVCC 12.8 accepts sm_121), or the test results file should clarify that the FA4 kernel path uses JIT compilation (bypassing NVCC's arch list) so the CUDA version constraint only affects native TE CUDA kernels.

elif version >= (12, 8):
archs = "70;80;89;90;100;120"
else:
Expand Down
110 changes: 110 additions & 0 deletions te_sm121_test_results.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
================================================================================
SM_121 (GB10 Consumer Blackwell / DGX Spark) Flash Attention 4 Test Results
================================================================================
Date: 2026-06-13
GPU: NVIDIA GB10 (sm_121) | CUDA 12.8 | flash-attn-4 v4.0.0b16
Comment on lines +1 to +5

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Test-result artifact committed to source tree

Committing generated test output files to the repository is generally discouraged — they become stale as the code evolves, are not run as part of CI, and add noise to git log. The conventional approach is to attach them to the PR description (as a GitHub comment or linked gist) rather than merging them into the source tree. If the intent is permanent documentation, consider a location under docs/ with a clear label rather than the repo root.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

nvidia-cutlass-dsl: 4.5.2 | TransformerEngine: feature/sm121-blackwell-support

--------------------------------------------------------------------------------
PRIMARY TEST: test_dpa_fa4_sm121 (SM_121-gated, unambiguous CI signal)
--------------------------------------------------------------------------------
3/3 PASS

PASSED test_dpa_fa4_sm121[fa4_base_1-model_configs0-dtype0] # b=4, sq=128, hq=16, hdim=64
PASSED test_dpa_fa4_sm121[fa4_base_2-model_configs0-dtype0] # b=2, sq=2048, hq=24, hdim=128, causal
PASSED test_dpa_fa4_sm121[fa4_base_3-model_configs0-dtype0] # b=2, sq=1024, hq=8, hdim=96, causal

These configs cover MHA forward+backward, causal masking, head dims 64/96/128,
sequence lengths 128–2048 — all exercising the SM_121 FA4 CuTe-DSL kernel path.

--------------------------------------------------------------------------------
FULL FA4 SUITE — ISOLATION RESULTS (each category run in a separate process;
full-suite run has MLIR context corruption from upstream GQA compile errors)
--------------------------------------------------------------------------------

test_dpa_fa4_sm121 : 3/3 PASS (SM_121-gated; our primary signal)
test_dpa_fa4_base : 3/7 PASS (gqa_1/2 + splits_1/2 → upstream FA4 bugs)
test_dpa_fa4_mask : 3/6 PASS (padding variants → TE cuDNN backend)
test_dpa_fa4_varlen : 3/8 PASS (thd layout + gqa_varlen → upstream FA4/cuDNN)
test_dpa_fa4_sliding_window : 0/8 FAIL (SM_80 bwd has no local mask → see below)
test_dpa_fa4_hdim256 : 1/1 SKIP (SM100-only kernel; expected on SM_121)

TOTAL ISOLATION PASSING: 12 / 32 executed (the 1 skipped hdim256 test is not counted)

--------------------------------------------------------------------------------
KNOWN LIMITATIONS (upstream flash-attn-4 bugs — NOT SM_121 or TE defects)
--------------------------------------------------------------------------------

1. GQA / pack_gqa (fa4_gqa_1, fa4_gqa_2, fa4_varlen_3)
Error: ValueError: Operation creation failed
MLIR: unable to compute crd2idx with layout '(?):(1)' and coord '((?,?))'
at pack_gqa.py:139 compute_ptr()
Cause: pack_gqa.store_O(mO_cur, ...) expects mO_cur shaped
((qhead_per_kvhead, seqlen_q), headdim) (hierarchical), but
FlashAttentionForwardSm80.__call__ passes flat (seqlen_q, headdim).
Scope: SM_80 code path (SM_80 and SM_121 equally); flash-attn-4 b16 regression.

2. SplitKV (fa4_splits_1, fa4_splits_2)
Error: AssertionError: SplitKV not supported on SM 12.0 in this PR
Cause: Explicit upstream stub — SplitKV kernel not yet ported to SM_12x.
Scope: SM_120 + SM_121; deliberate placeholder.

3. Sliding window backward (all 8 test_dpa_fa4_sliding_window tests)
Error: Numerical mismatch — dQ wrong (abs diff up to 0.43, tol 0.015)
Cause: FlashAttentionBackwardSm80.kernel() creates AttentionMask with
mask_causal=self.is_causal only, never mask_local=True nor window_size.
Local window is ignored in backward → incorrect gradients.
Forward is CORRECT (AttentionMask full local mask support, is_local=True).
Note: Before this patch these tests crashed earlier with:
DSLRuntimeError: argument #13 (window_size_left): (Int32, NoneType) but got int.
That crash (interface.py missing Int32 cast) is now fixed; the pre-existing
backward limitation is now visible as the failure mode.
Scope: SM_80 code path; upstream flash-attn-4 limitation.

4. THD varlen layout — cuDNN fallback (thd_thd_thd fa4_varlen_1..4)
Error: RuntimeError: cuDNN Error: No valid engine configs
Cause: TE routes thd_thd_thd varlen to cuDNN fused-attn backend (not FA4).
cuDNN has no valid engine for SM_121 + this combination.
Scope: TE cuDNN backend; not flash-attn-4.

5. Padding mask tests — cuDNN fallback (fa4_mask_padding*)
Error: RuntimeError: cuDNN Error: No valid engine configs
Cause: Same — TE routes padding+varlen mask combos to cuDNN backend.

--------------------------------------------------------------------------------
PATCHES APPLIED TO flash-attn-4 v4.0.0b16 TO ENABLE SM_121
(to be reported as separate upstream flash-attn-4 issues)
--------------------------------------------------------------------------------

Patch 1 — flash_fwd.py:658 use_tma_O guard
Old: self.use_tma_O = self.arch >= Arch.sm_90
New: self.use_tma_O = Arch.sm_90 <= self.arch < Arch.sm_120
Why: FlashAttentionForwardSm120 sets class attr arch=80, but __init__ clobbers it
with Arch.sm_121 via BaseDSL._get_dsl().get_arch_enum(). sm_121 >= sm_90 → True
→ tries tma_atom_O=None → AttributeError: 'NoneType'.._trait.

Patch 2 — interface.py dQ_single_wg in SM_12x backward branch
Added: dQ_single_wg = False inside "if arch // 10 == 12:" block.
Why: Used in compile_key for arch//10 in [8,9,12] but never set for arch//10==12
→ UnboundLocalError.

Patch 3 — flash_bwd.py:440 softmax_scale preservation
Old: softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(...)
New: softmax_scale_log2, _ = utils.compute_softmax_scale_log2(...)
Why: compute_softmax_scale_log2 returns None when score_mod is None, clobbering the
original Float32 value the kernel requires.

Patch 4 — utils.py:486 nvvm.atomicrmw API
Old: nvvm.atomicrmw(res=T.f32(), op=..., ptr=..., a=...)
New: nvvm.atomicrmw(op=..., ptr=..., a=...)
Why: 'res' kwarg removed in nvidia-cutlass-dsl 4.5.2.

Patch 5 — interface.py window_size Int32 cast in cute.compile paths
Added: window_size_left_typed = Int32(wsl) if wsl is not None else None (fwd + bwd)
Used window_size_*_typed in compile_args instead of raw Python int.
Why: cute.compile requires Int32, not plain int, for Optional[Int32] params.
SM_90 path implicitly coerces; SM_80/12x else-branch does not.

================================================================================
END
================================================================================
31 changes: 31 additions & 0 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,37 @@ def test_dpa_fa4_base(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_fa4_sm121 = {
# MHA-only configs confirmed working on SM_121 (GB10) via flash-attn-4 SM_80 code path.
# GQA excluded: flash-attn-4 pack_gqa store_O has a hierarchical-layout mismatch on SM_80/12x
# (upstream flash-attn-4 issue, not a TE or SM_121 defect).
# num_splits excluded: explicit upstream assertion "SplitKV not supported on SM 12.0".
"fa4_base_1": ModelConfig(4, 128, 16, 64),
"fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
"fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
}


@pytest.mark.skipif(
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(
device_compute_capability != (12, 1),
reason="SM_121 (GB10 consumer Blackwell / DGX Spark) specific FA4 correctness test.",
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_fa4_sm121])
@pytest.mark.parametrize("model", model_configs_fa4_sm121.keys())
def test_dpa_fa4_sm121(dtype, model_configs, model):
"""Test DotProductAttention with FA4 on SM_121 (GB10 consumer Blackwell).

SM_121 is architecturally a variant of SM_120 (professional Blackwell). This test
gates explicitly on (12, 1) to ensure the SM_121 SASS path is exercised and to
provide an unambiguous CI signal for GB10 hardware.
"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


# head_dim=256 is supported only on SM100 via FA4's dedicated kernel
# (flash_attn/cute/sm100_hd256_2cta_fmha_*.py), available in flash-attn-4 > 4.0.0b10.
# On other architectures, _validate_head_dims rejects (256, 256), FA4 is disabled, and
Expand Down
29 changes: 28 additions & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ endif()
# Process GPU architectures
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120 121)
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
# SM_121 (GB10 consumer Blackwell) placed under >= 12.9; lower bound to be confirmed.
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120 121)
elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
Expand Down Expand Up @@ -83,6 +86,18 @@ if(NOT arch_120_index EQUAL -1)
endif()
endif()

# Check for architecture 121 (GB10 consumer Blackwell / DGX Spark)
list(FIND CMAKE_CUDA_ARCHITECTURES "121" arch_121_index)
if(NOT arch_121_index EQUAL -1)
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "121")
list(APPEND NVTE_GENERIC_ARCHS "121")
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
list(APPEND NVTE_SPECIFIC_ARCHS "121f")
else()
list(APPEND NVTE_SPECIFIC_ARCHS "121a")
endif()
endif()

# Move remaining standard (pre-Blackwell) architectures into NVTE_STANDARD_ARCHS.
# These are applied to all CUDA sources (both generic and arch-specific).
set(NVTE_STANDARD_ARCHS ${CMAKE_CUDA_ARCHITECTURES})
Expand Down Expand Up @@ -369,6 +384,18 @@ target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
${CUTLASS_TOOLS_INCLUDE_DIR})

# NCCL headers are included in TE's public API (e.g. comm_gemm_overlap.h).
# On pip-based CUDA installs (nvidia-nccl-* wheels) nccl.h is not in the
# CUDA toolkit tree, so find it explicitly and add it to the build.
find_path(NVTE_NCCL_INCLUDE_DIR
NAMES nccl.h
PATH_SUFFIXES include
QUIET)
if(NVTE_NCCL_INCLUDE_DIR)
target_include_directories(transformer_engine PRIVATE ${NVTE_NCCL_INCLUDE_DIR})
message(STATUS "Found NCCL headers at: ${NVTE_NCCL_INCLUDE_DIR}")
endif()
Comment on lines +390 to +397

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 find_path won't locate pip-installed NCCL headers

find_path without PATHS or HINTS searches only standard system include paths (e.g., /usr/include, /usr/local/include) and any CMAKE_PREFIX_PATH entries. On DGX Spark — the primary target of this PR — NCCL headers land in a site-packages tree like ~/.local/lib/python3.x/site-packages/nvidia/nccl/include/, which is never in CMake's default search roots. The QUIET flag silences the failure, NVTE_NCCL_INCLUDE_DIR is left empty, and target_include_directories is never called. Any translation unit that includes nccl.h (e.g., comm_gemm_overlap.h) will then fail to compile in a pip-only CUDA environment — the exact scenario this PR is supposed to fix. A HINTS pointing to the Python-resolved path (passed in via -DNVTE_NCCL_INCLUDE_HINT from build_ext.py) or an execute_process call to query Python would close the gap.


# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
if (NVTE_UB_WITH_MPI)
Expand Down
16 changes: 9 additions & 7 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,23 +502,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
"Please upgrade your cuDNN version if possible."
<< std::endl;
}
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen && sm_arch_ == 120) {
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen &&
(sm_arch_ == 120 || sm_arch_ == 121)) {
if (cudnn_runtime_version < 91801) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: Given combination of sm_arch_ == 120 and cudnn_runtime_version < "
"91801 is not supported. "
std::cout << "Warning: Given combination of sm_arch_ == " << sm_arch_
<< " and cudnn_runtime_version < 91801 is not supported. "
<< " Please upgrade your cuDNN version if possible." << std::endl;
} else if (deterministic && is_training) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: Deterministic fused attention on SM120 is not supported."
<< std::endl;
std::cout << "Warning: Deterministic fused attention on SM" << sm_arch_
<< " is not supported." << std::endl;
} else {
// Known missing support for T3HD/TH3D layouts on SM120
// Known missing support for T3HD/TH3D layouts on SM120/SM121
const bool is_t3hd_or_th3d =
(qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D);
if (is_t3hd_or_th3d) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: Given combination of T3HD/TH3D layouts on SM120 is not supported. "
std::cout << "Warning: Given combination of T3HD/TH3D layouts on SM" << sm_arch_
<< " is not supported. "
<< " Please consider using other THD layouts if possible." << std::endl;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
bool use_ragged_stats =
is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120 && sm_arch_ != 121;

NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
Expand All @@ -101,7 +102,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
// On SM 120 (SM 121 is handled the same way below), cuDNN support check treats layouts with
// stride[0] > dim[1]*dim[2]*dim[3] as interleaved and rejects them. Use BHSD-like
// dimensions/strides with max_seqlen at plan build so the check passes; ragged offset still
// provides variable-length boundaries.
if (sm_arch_ != 120) {
if (sm_arch_ != 120 && sm_arch_ != 121) {
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
Expand Down Expand Up @@ -590,7 +591,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
bool use_ragged_stats =
is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120 && sm_arch_ != 121;

NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
Expand All @@ -603,7 +605,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// On SM 120 (SM 121 is handled the same way below), cuDNN support check requires BHSD-like
// strides with max_seqlen (see fwd).
if (sm_arch_ != 120) {
if (sm_arch_ != 120 && sm_arch_ != 121) {
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
Expand Down Expand Up @@ -805,7 +807,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
if (use_ragged_stats) {
sdpa_backward_options.set_max_total_seq_len_q(s_q);
}
if (is_ragged_kv && cudnn_runtime_version >= 90600 && sm_arch_ != 120) {
if (is_ragged_kv && cudnn_runtime_version >= 90600 && sm_arch_ != 120 && sm_arch_ != 121) {
sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
}

Expand Down Expand Up @@ -1153,7 +1155,7 @@ void fused_attn_arbitrary_seqlen_fwd(
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Max->data.dptr = nullptr;
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
(sm_arch_ != 120)) {
(sm_arch_ != 120 && sm_arch_ != 121)) {
output_Max->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1671,7 +1671,7 @@ def forward(
9,
6,
0,
) and get_device_compute_capability() != (12, 0)
) and get_device_compute_capability() not in ((12, 0), (12, 1))
else:
softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3

Expand Down
Loading