-
Notifications
You must be signed in to change notification settings - Fork 749
feat: add SM_121 (GB10 consumer Blackwell) support for FA4 #3125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8a86608
9a8817d
45f2d76
43e5c9b
268f5ed
db898eb
ceaad74
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -225,30 +225,46 @@ def nvcc_path() -> Tuple[str, str]: | |
|
|
||
|
|
||
| @functools.lru_cache(maxsize=None) | ||
| def get_cuda_include_dirs() -> Tuple[str, str]: | ||
| """Returns the CUDA header directory.""" | ||
| def get_cuda_include_dirs() -> List[Path]: | ||
| """Returns all CUDA-related header directories. | ||
|
|
||
| Combines the system CUDA toolkit path (if present) with include directories | ||
| from nvidia pip wheel packages (nvidia-cudnn-cu1x, nvidia-nccl-cu1x, …). | ||
|
|
||
| On modern setups these are NOT mutually exclusive: the system toolkit may | ||
| provide core CUDA headers while cudnn and nccl arrive as separate pip wheels. | ||
| Returning both avoids "fatal error: cudnn.h / nccl.h not found" when a | ||
| system CUDA toolkit is present but those components are pip-distributed. | ||
| """ | ||
| dirs: List[Path] = [] | ||
|
|
||
| force_wheels = bool(int(os.getenv("NVTE_BUILD_USE_NVIDIA_WHEELS", "0"))) | ||
| # If cuda is installed via toolkit, all necessary headers | ||
| # are bundled inside the top level cuda directory. | ||
| if not force_wheels and cuda_toolkit_include_path() is not None: | ||
| return [cuda_toolkit_include_path()] | ||
| dirs.append(cuda_toolkit_include_path()) | ||
|
|
||
| # Use pip wheels to include all headers. | ||
| # Always supplement with nvidia pip wheel includes so components like | ||
| # cudnn and nccl are found even when the system toolkit doesn't have them. | ||
| try: | ||
| import nvidia | ||
| except ModuleNotFoundError as e: | ||
|
|
||
| cuda_root = Path(nvidia.__path__[0]) # namespace package — no __file__ | ||
| dirs.extend( | ||
| subdir / "include" | ||
| for subdir in cuda_root.iterdir() | ||
| if subdir.is_dir() and (subdir / "include").is_dir() | ||
| ) | ||
| except (ImportError, StopIteration, IndexError, AttributeError): | ||
| pass | ||
|
|
||
| if not dirs: | ||
| raise RuntimeError("CUDA not found.") | ||
|
|
||
| if nvidia.__file__ is not None: | ||
| cuda_root = Path(nvidia.__file__).parent | ||
| else: | ||
| cuda_root = Path(nvidia.__path__[0]) # namespace | ||
| return [ | ||
| subdir / "include" | ||
| for subdir in cuda_root.iterdir() | ||
| if subdir.is_dir() and (subdir / "include").is_dir() | ||
| ] | ||
| return dirs | ||
|
|
||
|
|
||
| def get_nccl_include_dirs() -> List[Path]: | ||
| """Compatibility shim — nccl includes are now returned by get_cuda_include_dirs().""" | ||
| return [] | ||
|
|
||
|
|
||
| @functools.lru_cache(maxsize=None) | ||
|
|
@@ -257,7 +273,11 @@ def cuda_archs() -> str: | |
| if archs is None: | ||
| version = cuda_version() | ||
| if version >= (13, 0): | ||
| archs = "75;80;89;90;100;120" | ||
| archs = "75;80;89;90;100;120;121" | ||
| elif version >= (12, 9): | ||
| # SM_121 (GB10 consumer Blackwell) confirmed under CUDA >= 12.9; | ||
| # minimum version for 12.8 needs verification on target hardware. | ||
| archs = "70;80;89;90;100;120;121" | ||
|
Comment on lines
+277
to
+280
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The inline comment says "confirmed under CUDA >= 12.9; minimum version for 12.8 needs verification", yet |
||
| elif version >= (12, 8): | ||
| archs = "70;80;89;90;100;120" | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| ================================================================================ | ||
| SM_121 (GB10 Consumer Blackwell / DGX Spark) Flash Attention 4 Test Results | ||
| ================================================================================ | ||
| Date: 2026-06-13 | ||
| GPU: NVIDIA GB10 (sm_121) | CUDA 12.8 | flash-attn-4 v4.0.0b16 | ||
|
Comment on lines
+1
to
+5
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Committing generated test output files to the repository is generally discouraged — they become stale as the code evolves, are not run as part of CI, and add noise to Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||
| nvidia-cutlass-dsl: 4.5.2 | TransformerEngine: feature/sm121-blackwell-support | ||
|
|
||
| -------------------------------------------------------------------------------- | ||
| PRIMARY TEST: test_dpa_fa4_sm121 (SM_121-gated, unambiguous CI signal) | ||
| -------------------------------------------------------------------------------- | ||
| 3/3 PASS | ||
|
|
||
| PASSED test_dpa_fa4_sm121[fa4_base_1-model_configs0-dtype0] # b=4, sq=128, hq=16, hdim=64 | ||
| PASSED test_dpa_fa4_sm121[fa4_base_2-model_configs0-dtype0] # b=2, sq=2048, hq=24, hdim=128, causal | ||
| PASSED test_dpa_fa4_sm121[fa4_base_3-model_configs0-dtype0] # b=2, sq=1024, hq=8, hdim=96, causal | ||
|
|
||
| These configs cover MHA forward+backward, causal masking, head dims 64/96/128, | ||
| sequence lengths 128–2048 — all exercising the SM_121 FA4 CuTe-DSL kernel path. | ||
|
|
||
| -------------------------------------------------------------------------------- | ||
| FULL FA4 SUITE — ISOLATION RESULTS (each category run in a separate process; | ||
| full-suite run has MLIR context corruption from upstream GQA compile errors) | ||
| -------------------------------------------------------------------------------- | ||
|
|
||
| test_dpa_fa4_sm121 : 3/3 PASS (SM_121-gated; our primary signal) | ||
| test_dpa_fa4_base : 3/7 PASS (gqa_1/2 + splits_1/2 → upstream FA4 bugs) | ||
| test_dpa_fa4_mask : 3/6 PASS (padding variants → TE cuDNN backend) | ||
| test_dpa_fa4_varlen : 3/8 PASS (thd layout + gqa_varlen → upstream FA4/cuDNN) | ||
| test_dpa_fa4_sliding_window : 0/8 FAIL (SM_80 bwd has no local mask → see below) | ||
| test_dpa_fa4_hdim256 : 1/1 SKIP (SM100-only kernel; expected on SM_121) | ||
|
|
||
| TOTAL ISOLATION PASSING: 12 / 32 | ||
|
|
||
| -------------------------------------------------------------------------------- | ||
| KNOWN LIMITATIONS (upstream flash-attn-4 bugs — NOT SM_121 or TE defects) | ||
| -------------------------------------------------------------------------------- | ||
|
|
||
| 1. GQA / pack_gqa (fa4_gqa_1, fa4_gqa_2, fa4_varlen_3) | ||
| Error: ValueError: Operation creation failed | ||
| MLIR: unable to compute crd2idx with layout '(?):(1)' and coord '((?,?))' | ||
| at pack_gqa.py:139 compute_ptr() | ||
| Cause: pack_gqa.store_O(mO_cur, ...) expects mO_cur shaped | ||
| ((qhead_per_kvhead, seqlen_q), headdim) (hierarchical), but | ||
| FlashAttentionForwardSm80.__call__ passes flat (seqlen_q, headdim). | ||
| Scope: SM_80 code path (SM_80 and SM_121 equally); flash-attn-4 b16 regression. | ||
|
|
||
| 2. SplitKV (fa4_splits_1, fa4_splits_2) | ||
| Error: AssertionError: SplitKV not supported on SM 12.0 in this PR | ||
| Cause: Explicit upstream stub — SplitKV kernel not yet ported to SM_12x. | ||
| Scope: SM_120 + SM_121; deliberate placeholder. | ||
|
|
||
| 3. Sliding window backward (all 8 test_dpa_fa4_sliding_window tests) | ||
| Error: Numerical mismatch — dQ wrong (abs diff up to 0.43, tol 0.015) | ||
| Cause: FlashAttentionBackwardSm80.kernel() creates AttentionMask with | ||
| mask_causal=self.is_causal only, never mask_local=True nor window_size. | ||
| Local window is ignored in backward → incorrect gradients. | ||
| Forward is CORRECT (AttentionMask full local mask support, is_local=True). | ||
| Note: Before this patch these tests crashed earlier with: | ||
| DSLRuntimeError: argument #13 (window_size_left): (Int32, NoneType) but got int. | ||
| That crash (interface.py missing Int32 cast) is now fixed; the pre-existing | ||
| backward limitation is now visible as the failure mode. | ||
| Scope: SM_80 code path; upstream flash-attn-4 limitation. | ||
|
|
||
| 4. THD varlen layout — cuDNN fallback (thd_thd_thd fa4_varlen_1..4) | ||
| Error: RuntimeError: cuDNN Error: No valid engine configs | ||
| Cause: TE routes thd_thd_thd varlen to cuDNN fused-attn backend (not FA4). | ||
| cuDNN has no valid engine for SM_121 + this combination. | ||
| Scope: TE cuDNN backend; not flash-attn-4. | ||
|
|
||
| 5. Padding mask tests — cuDNN fallback (fa4_mask_padding*) | ||
| Error: RuntimeError: cuDNN Error: No valid engine configs | ||
| Cause: Same — TE routes padding+varlen mask combos to cuDNN backend. | ||
|
|
||
| -------------------------------------------------------------------------------- | ||
| PATCHES APPLIED TO flash-attn-4 v4.0.0b16 TO ENABLE SM_121 | ||
| (to be reported as separate upstream flash-attn-4 issues) | ||
| -------------------------------------------------------------------------------- | ||
|
|
||
| Patch 1 — flash_fwd.py:658 use_tma_O guard | ||
| Old: self.use_tma_O = self.arch >= Arch.sm_90 | ||
| New: self.use_tma_O = Arch.sm_90 <= self.arch < Arch.sm_120 | ||
| Why: FlashAttentionForwardSm120 sets class attr arch=80, but __init__ clobbers it | ||
| with Arch.sm_121 via BaseDSL._get_dsl().get_arch_enum(). sm_121 >= sm_90 → True | ||
| → tries tma_atom_O=None → AttributeError: 'NoneType'.._trait. | ||
|
|
||
| Patch 2 — interface.py dQ_single_wg in SM_12x backward branch | ||
| Added: dQ_single_wg = False inside "if arch // 10 == 12:" block. | ||
| Why: Used in compile_key for arch//10 in [8,9,12] but never set for arch//10==12 | ||
| → UnboundLocalError. | ||
|
|
||
| Patch 3 — flash_bwd.py:440 softmax_scale preservation | ||
| Old: softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(...) | ||
| New: softmax_scale_log2, _ = utils.compute_softmax_scale_log2(...) | ||
| Why: compute_softmax_scale_log2 returns None when score_mod is None, clobbering the | ||
| original Float32 value the kernel requires. | ||
|
|
||
| Patch 4 — utils.py:486 nvvm.atomicrmw API | ||
| Old: nvvm.atomicrmw(res=T.f32(), op=..., ptr=..., a=...) | ||
| New: nvvm.atomicrmw(op=..., ptr=..., a=...) | ||
| Why: 'res' kwarg removed in nvidia-cutlass-dsl 4.5.2. | ||
|
|
||
| Patch 5 — interface.py window_size Int32 cast in cute.compile paths | ||
| Added: window_size_left_typed = Int32(wsl) if wsl is not None else None (fwd + bwd) | ||
| Used window_size_*_typed in compile_args instead of raw Python int. | ||
| Why: cute.compile requires Int32, not plain int, for Optional[Int32] params. | ||
| SM_90 path implicitly coerces; SM_80/12x else-branch does not. | ||
|
|
||
| ================================================================================ | ||
| END | ||
| ================================================================================ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,7 +28,10 @@ endif() | |
| # Process GPU architectures | ||
| if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) | ||
| if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) | ||
| set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) | ||
| set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120 121) | ||
| elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) | ||
| # SM_121 (GB10 consumer Blackwell) placed under >= 12.9; lower bound to be confirmed. | ||
| set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120 121) | ||
| elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) | ||
| set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) | ||
| else () | ||
|
|
@@ -83,6 +86,18 @@ if(NOT arch_120_index EQUAL -1) | |
| endif() | ||
| endif() | ||
|
|
||
| # Check for architecture 121 (GB10 consumer Blackwell / DGX Spark) | ||
| list(FIND CMAKE_CUDA_ARCHITECTURES "121" arch_121_index) | ||
| if(NOT arch_121_index EQUAL -1) | ||
| list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "121") | ||
| list(APPEND NVTE_GENERIC_ARCHS "121") | ||
| if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) | ||
| list(APPEND NVTE_SPECIFIC_ARCHS "121f") | ||
| else() | ||
| list(APPEND NVTE_SPECIFIC_ARCHS "121a") | ||
| endif() | ||
| endif() | ||
|
|
||
| # Move remaining standard (pre-Blackwell) architectures into NVTE_STANDARD_ARCHS. | ||
| # These are applied to all CUDA sources (both generic and arch-specific). | ||
| set(NVTE_STANDARD_ARCHS ${CMAKE_CUDA_ARCHITECTURES}) | ||
|
|
@@ -369,6 +384,18 @@ target_include_directories(transformer_engine PRIVATE | |
| ${CUTLASS_INCLUDE_DIR} | ||
| ${CUTLASS_TOOLS_INCLUDE_DIR}) | ||
|
|
||
| # NCCL headers are included in TE's public API (e.g. comm_gemm_overlap.h). | ||
| # On pip-based CUDA installs (nvidia-nccl-* wheels) nccl.h is not in the | ||
| # CUDA toolkit tree, so find it explicitly and add it to the build. | ||
| find_path(NVTE_NCCL_INCLUDE_DIR | ||
| NAMES nccl.h | ||
| PATH_SUFFIXES include | ||
| QUIET) | ||
| if(NVTE_NCCL_INCLUDE_DIR) | ||
| target_include_directories(transformer_engine PRIVATE ${NVTE_NCCL_INCLUDE_DIR}) | ||
| message(STATUS "Found NCCL headers at: ${NVTE_NCCL_INCLUDE_DIR}") | ||
| endif() | ||
|
Comment on lines
+390
to
+397
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI | ||
| option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) | ||
| if (NVTE_UB_WITH_MPI) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_nccl_include_dirs()is a brand-new function (introduced in this PR) that immediately returns[]. The docstring calls it a "compatibility shim", but there are no pre-existing callers to stay compatible with —pytorch.pyis also calling it for the first time in this same PR. The only concrete effect of theinclude_dirs.extend(get_nccl_include_dirs())call inpytorch.pyisextend([]), which is a no-op. The function and its import can be removed without any behavioral change; NCCL headers are now properly included viaget_cuda_include_dirs().