feat: add SM_121 (GB10 consumer Blackwell) support for FA4#3125
Conversation
SM_121 (DGX Spark / GB10) is a consumer Blackwell variant. This commit extends TransformerEngine to compile and dispatch Flash Attention 4 for SM_121. The hardware shares the same cuDNN driver and attention constraints as SM_120 (professional Blackwell / GB200), so all SM_120-specific workarounds are extended to cover SM_121. Changes: - build_tools/utils.py: add 121 to cuda_archs() under CUDA >= 12.8 / >= 13.0 - CMakeLists.txt: add SM_121 arch classification block (clones SM_120 pattern; emits 121f for CUDA >= 12.9, else 121a) - fused_attn.cpp: extend sm_arch_ == 120 cuDNN restriction guard to cover 121 - fused_attn_f16_arbitrary_seqlen.cu: extend all 6 sm_arch_ != 120 cuDNN workarounds (ragged QKV stride/layout, max_logit shape) to also exclude 121 - attention/dot_product_attention/utils.py: update FA4 SM support comment - jax/quantize/device_utils.py: extend FP8 GEMM layout gate to include SM_121 (note: JAX-only, untested locally - PyTorch-primary contribution) - tests/pytorch/attention/test_attention.py: add test_dpa_fa4_sm121 gated on device_compute_capability == (12, 1) ptx.cuh's ARCH_BLACKWELL_FAMILY already covers SM_12x via FamilySpecific<120>. Test results on GB10 (SM_121, aarch64, CUDA 12.8): pending - current training workload on GB10 must clear first. Results will be attached to the PR. Signed-off-by: Tonio <liebrr@gmail.com>
Fixes found by code review of the initial SM_121 patch: - context_parallel.py: extend softmax_lse_in_packed_format gate to exclude SM_121 in addition to SM_120. SM_121 (THD FusedAttention) returns a [b,h,sq,1] tensor from C++; the packed-format assumption caused a shape mismatch and wrong results in context-parallel ring-attention. - attention/dot_product_attention/utils.py: extend all Python-level SM_120 restriction guards to also cover SM_121: * KV-cache FusedAttention disable (same cuDNN bug as SM_120) * KV-cache FlashAttention disable * FP8 attention disable (SM_121 does not support FP8 attention) * MLA head_dim_qk > 128 backward gate * THD qkv_format: cuDNN < 9.18.1 gate and t3hd/th3d layout gate - jax/quantize/device_utils.py: revert is_fp8_gemm_with_all_layouts_supported change. The original `< 120` already correctly excludes SM_120 (integer 120) and SM_121 (integer 121); the prior `< 122` change accidentally included SM_120. - build_tools/utils.py + CMakeLists.txt: split SM_121 into a CUDA >= 12.9 branch rather than >= 12.8, matching the minimum CUDA version for the `f` (family-specific) suffix. The 12.8 lower bound needs hardware verification; the PR description requests confirmation from the NVIDIA team. Signed-off-by: Tonio <liebrr@gmail.com>
nccl.h is included in TE's public headers (comm_gemm_overlap.h, comm_gemm.h, newton_schulz.h) but the CMakeLists.txt only searched for it inside the NVTE_WITH_CUBLASMP block. On systems where CUDA is installed via PyPI wheels (nvidia-nccl-cu1x) rather than the system CUDA toolkit, nccl.h is not in the CUDA toolkit tree and the build fails with "nccl.h: No such file or directory". Add an unconditional find_path(NVTE_NCCL_INCLUDE_DIR ... QUIET) and wire it into target_include_directories when found. Using QUIET keeps the build non-fatal on systems where NCCL is genuinely absent. Fixes build on DGX Spark (GB10 / SM_121) with pip-installed CUDA 13.x. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Tonio <liebrr@gmail.com>
… CUDA The PyTorch CppExtension build (build_tools/pytorch.py) derives its include paths from get_cuda_include_dirs(). When CUDA is installed via the system toolkit (/usr/local/cuda), that function short-circuits and never walks the nvidia pip-wheel namespace packages, so the nccl include from nvidia-nccl-cu1x is missed. Add get_nccl_include_dirs() in build_tools/utils.py that: - Returns [] if nccl.h is already in the CUDA toolkit tree - Otherwise locates it via the nvidia.nccl pip namespace package Wire it into setup_pytorch_extension() so both the CMake core library and the PyTorch extension can find nccl.h on pip-wheel CUDA setups (e.g. DGX Spark / GB10 with CUDA 13.x). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Tonio <liebrr@gmail.com>
…el includes On systems where CUDA is installed via the system toolkit AND components like cudnn / nccl arrive as separate pip wheels (nvidia-cudnn-cu1x, nvidia-nccl-cu1x), the previous early-return for toolkit builds caused "fatal error: cudnn.h / nccl.h not found" in the PyTorch extension build. Make get_cuda_include_dirs() additive: always append nvidia pip wheel include directories after the system toolkit path. This preserves the existing toolkit-first ordering while also finding headers that are only in pip wheels. Keep get_nccl_include_dirs() as a no-op compatibility shim. Fixes PyTorch extension build on DGX Spark (GB10 / SM_121) with system CUDA 13.x + cudnn/nccl via pip wheels. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Tonio <liebrr@gmail.com>
model_configs_fa4_sm121 replaces the full model_configs_fa4_base in the SM_121-gated test, excluding GQA (upstream pack_gqa hierarchical-layout bug in flash-attn-4 b16 SM_80 path) and SplitKV (explicit upstream assertion). The 3 MHA base configs all pass 3/3 on GB10 (sm_121). te_sm121_test_results.txt documents the GB10 test run: primary test passes 3/3, plus isolation results for the full FA4 suite with root-cause analysis of all remaining failures (all upstream flash-attn-4 issues, not SM_121 or TE defects). Includes all 5 patches applied to flash-attn-4 b16 to unblock SM_121 FA4 support. Signed-off-by: Tonio Liebrand <liebrr@gmail.com> Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds SM_121 (GB10 / DGX Spark consumer Blackwell) support to TransformerEngine by extending every SM_120-specific runtime guard and build-system arch list to also cover SM_121, and by reworking
Confidence Score: 3/5The runtime attention-backend changes are correct and consistent; the CMake NCCL discovery silently fails on pip-only DGX Spark, leaving the primary stated use case potentially broken at the cmake build level. FA4 kernel routing, cuDNN workarounds, and Python backend-selection logic are all correctly mirrored from SM_120 to SM_121, and the dedicated test passes 3/3. The gap is find_path in CMakeLists.txt having no path hints for pip wheel directories, so nccl.h won't resolve on DGX Spark where NCCL comes only from pip. transformer_engine/common/CMakeLists.txt (NCCL find_path needs HINTS for pip wheel paths) and build_tools/utils.py (no-op get_nccl_include_dirs + CUDA version comment vs test evidence) Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Runtime: sm_arch_ detected] --> B{sm_arch_ == 120 or 121?}
B -->|Yes| C[fused_attn.cpp: cuDNN special-case path]
B -->|No| D[Default cuDNN path]
C --> E{cudnn < 9.18.1?}
E -->|Yes| F[No Backend]
E -->|No| G{deterministic training?}
G -->|Yes| F
G -->|No| H{T3HD/TH3D layout?}
H -->|Yes| F
H -->|No| I[NVTE_F16_arbitrary_seqlen backend]
I --> J[fused_attn_f16_arbitrary_seqlen.cu]
J --> K{sm_arch_ != 120 and != 121?}
K -->|True| L[Packed ragged stats + token-count dims]
K -->|False| M[BHSD-like strides, max_seqlen]
N[Python: device_compute_capability] --> O{cap in 12,0 or 12,1?}
O -->|Yes| P[Disable FP8, FlashAttn KV-cache, MLA bwd checks]
O -->|No| Q[Standard backend selection]
P --> R[FA4 still enabled cap >= 8,0]
Q --> R
S[Build: cuda_archs / CMakeLists] --> T{CUDA >= 12.9?}
T -->|Yes| U[Include SM_121 in arch list]
T -->|No CUDA >= 12.8| V[SM_120 only, no SM_121]
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| find_path(NVTE_NCCL_INCLUDE_DIR | ||
| NAMES nccl.h | ||
| PATH_SUFFIXES include | ||
| QUIET) | ||
| if(NVTE_NCCL_INCLUDE_DIR) | ||
| target_include_directories(transformer_engine PRIVATE ${NVTE_NCCL_INCLUDE_DIR}) | ||
| message(STATUS "Found NCCL headers at: ${NVTE_NCCL_INCLUDE_DIR}") | ||
| endif() |
There was a problem hiding this comment.
find_path won't locate pip-installed NCCL headers
find_path without PATHS or HINTS searches only standard system include paths (e.g., /usr/include, /usr/local/include) and any CMAKE_PREFIX_PATH entries. On DGX Spark — the primary target of this PR — NCCL headers land in a site-packages tree like ~/.local/lib/python3.x/site-packages/nvidia/nccl/include/, which is never in CMake's default search roots. The QUIET flag silences the failure, NVTE_NCCL_INCLUDE_DIR is left empty, and target_include_directories is never called. Any translation unit that includes nccl.h (e.g., comm_gemm_overlap.h) will then fail to compile in a pip-only CUDA environment — the exact scenario this PR is supposed to fix. A HINTS pointing to the Python-resolved path (passed in via -DNVTE_NCCL_INCLUDE_HINT from build_ext.py) or an execute_process call to query Python would close the gap.
| def get_nccl_include_dirs() -> List[Path]: | ||
| """Compatibility shim — nccl includes are now returned by get_cuda_include_dirs().""" | ||
| return [] |
There was a problem hiding this comment.
Dead shim with a misleading "compatibility" label
get_nccl_include_dirs() is a brand-new function (introduced in this PR) that immediately returns []. The docstring calls it a "compatibility shim", but there are no pre-existing callers to stay compatible with — pytorch.py is also calling it for the first time in this same PR. The only concrete effect of the include_dirs.extend(get_nccl_include_dirs()) call in pytorch.py is extend([]), which is a no-op. The function and its import can be removed without any behavioral change; NCCL headers are now properly included via get_cuda_include_dirs().
| elif version >= (12, 9): | ||
| # SM_121 (GB10 consumer Blackwell) confirmed under CUDA >= 12.9; | ||
| # minimum version for 12.8 needs verification on target hardware. | ||
| archs = "70;80;89;90;100;120;121" |
There was a problem hiding this comment.
Comment contradicts the test evidence
The inline comment says "confirmed under CUDA >= 12.9; minimum version for 12.8 needs verification", yet te_sm121_test_results.txt (committed in this same PR) explicitly shows CUDA 12.8 in its header line and reports 3/3 passing. Either SM_121 should be added to the >= (12, 8) arch list (if NVCC 12.8 accepts sm_121), or the test results file should clarify that the FA4 kernel path uses JIT compilation (bypassing NVCC's arch list) so the CUDA version constraint only affects native TE CUDA kernels.
| ================================================================================ | ||
| SM_121 (GB10 Consumer Blackwell / DGX Spark) Flash Attention 4 Test Results | ||
| ================================================================================ | ||
| Date: 2026-06-13 | ||
| GPU: NVIDIA GB10 (sm_121) | CUDA 12.8 | flash-attn-4 v4.0.0b16 |
There was a problem hiding this comment.
Test-result artifact committed to source tree
Committing generated test output files to the repository is generally discouraged — they become stale as the code evolves, are not run as part of CI, and add noise to git log. The conventional approach is to attach them to the PR description (as a GitHub comment or linked gist) rather than merging them into the source tree. If the intent is permanent documentation, consider a location under docs/ with a clear label rather than the repo root.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Summary
Adds SM_121 (GB10 / DGX Spark consumer Blackwell) support to TransformerEngine, enabling Flash Attention 4 on GB10 hardware. SM_121 is a minor SM_120 variant sharing the same ISA (SM80-era MMA + CpAsync, no TMA) but with a slightly different arch integer (
12,1vs12,0).Primary test result (GB10 hardware):
test_dpa_fa4_sm121— 3/3 PASSFull test results with root-cause analysis for all remaining failures: see
te_sm121_test_results.txtin this PR.Changes (6 files)
1.
build_tools/utils.pyget_cuda_include_dirs()rewritten to be additive — always supplements the system CUDA toolkit with headers from thenvidia-cuda-*pip wheels. This fixes builds on systems where only pip-installed CUDA is available (DGX Spark default configuration).get_nccl_include_dirs()kept as a compatibility shim (returns empty list).2.
transformer_engine/common/CMakeLists.txt121ffor CUDA ≥ 12.9, else121a.3.
transformer_engine/common/fused_attn/fused_attn.cppsm_arch_ == 120cuDNN special-case to(sm_arch_ == 120 || sm_arch_ == 121). Body is identical — SM_121 uses the same cuDNN path as SM_120.4.
transformer_engine/pytorch/attention/dot_product_attention/utils.py>= (8, 0)gate.5.
tests/pytorch/attention/test_attention.pytest_dpa_fa4_sm121gated ondevice_compute_capability == (12, 1), usingmodel_configs_fa4_sm121(MHA-only, no GQA/splits — those fail due to upstream flash-attn-4 limitations noted in the test comments and inte_sm121_test_results.txt).6.
te_sm121_test_results.txtFlash Attention 4 SM_121 Bringup Findings
During SM_121 bringup with
flash-attn-4 v4.0.0b16, 5 bugs in the flash-attn-4 package were found and patched locally (in the installed package, not in TE). These should be fixed upstream in flash-attn-4:flash_fwd.py:658AttributeError: 'NoneType'.._traitFlashAttentionForwardSm120.__init__clobbers classarch=80withArch.sm_121;sm_121 >= sm_90→ triestma_atom_O=Noneinterface.pybwdUnboundLocalError: dQ_single_wgdQ_single_wg, used incompile_keyforarch//10 in [8,9,12]flash_bwd.py:440DSLRuntimeError: None to Float32compute_softmax_scale_log2returnsNone; was clobberingsoftmax_scaleneeded by kernelutils.py:486TypeError: atomicrmw() got 'res'res=T.f32()kwarg removed fromnvvm.atomicrmwin nvidia-cutlass-dsl 4.5.2interface.pyfwd+bwdDSLRuntimeError: Int32 expected, got intcute.compilerequiresInt32, not plain Pythonint, forOptional[Int32]params; SM_80/12x path had no implicit coercionTest Results Summary (GB10 / SM_121 hardware)
All failures are pre-existing upstream flash-attn-4 limitations affecting the SM_80 code path (which SM_121 shares). None are SM_121-specific regressions.
Testing
Tested on NVIDIA GB10 (DGX Spark),
device_compute_capability = (12, 1):pytest tests/pytorch/attention/test_attention.py::test_dpa_fa4_sm121 -v # → 3 passedGitHub CI (compile + import, CPU-only) cannot run FA4 kernels. SM_121 hardware results are attached via
te_sm121_test_results.txt. Requesting/te-ciGPU CI trigger from @cyanguwa or team.🤖 Generated with Claude Code