Hopper (GH200 / sm_90) inference acceleration: opt-in, lossless by default#108
Open
hamuzhan wants to merge 8 commits into
Open
Hopper (GH200 / sm_90) inference acceleration: opt-in, lossless by default#108hamuzhan wants to merge 8 commits into
hamuzhan wants to merge 8 commits into
Conversation
The LQ-projector CausalConv3d kernels (4x3x3, stride (2,1,1), large channels)
run at ~27 TFLOP/s (~2.8% peak) on GH200 because cuDNN can't pick a tensor-core
implicit-GEMM for these shapes. They account for ~50% of denoise time.
Reformulate the core convolution as explicit im2col (unfold) + bf16 GEMM, which
saturates Hopper WGMMA tensor cores (~400+ TFLOP/s, ~15-17x faster) with
bit-identical math (parity cosine 0.999996).
- Phase 1 (D1): _conv3d_gemm + FLASHVSR_CONV3D_BACKEND={auto|gemm} knob +
sm_90 guard + silent fallback. Causal replicate-pad + streaming cache
preserved; only the padding-free core conv is rerouted. Single-point change
benefits conv1/conv2 in both Buffer_/Causal_ projectors.
- Phase 2 (D1.5): chunked im2col over the output-H axis bounded by
FLASHVSR_CONV3D_IM2COL_BUDGET_GB (default 2 GB) -> conv transient mem -51%
at 1920x2560 with no speed/parity loss.
E2E (v1.1 Tiny): 1.92x @1536, 1.91x @2560x1920; norm-FPS ~17 -> ~33-34;
PSNR(auto,gemm) 47.6-49.5 dB. Default 'auto' = no regression.
Also fix transformers PretrainedConfig import (renamed -> PreTrainedConfig)
so the pipeline import chain works on current transformers.
Compares GH200 against the README's A100 reference (~17 FPS @ 768x1408) at the exact same resolution, with both conv3d backends. Results @ 768x1408: - GH200 auto (cuDNN): 16.54 FPS (0.97x A100 - parity, no Hopper gain) - GH200 gemm (tensor-core): 31.56 FPS (1.86x A100) - gemm vs auto: 1.91x Confirms the thesis: before the im2col+GEMM backend, GH200 was stuck at A100 parity despite 3x the hardware; the tensor-core conv path unlocks the real Hopper advantage.
…eneck tooling After the conv3d im2col+GEMM work (1.86x A100), profiled the full denoise pipeline @768x1408 to find the next bottlenecks. GPU is ~91% busy (compute- bound). Breakdown: attention ~27%, GEMM ~21%, TCDecoder conv2d ~19%, norm/elementwise ~17%, copy/layout ~15%. 3-A) TCDecoder channels_last (NHWC): The TCDecoder is a pure Conv2d (TAEHV) graph running contiguous (NCHW), which made cuDNN insert nchwToNhwc/nhwcToNchw around every bf16 conv (~226ms /9%) and run the convs ~1.5x slower. Run the decoder in channels_last: - isolated decode: 231 -> 189 ms (1.22x), bit-identical (max|diff|=0) - E2E: 31.6 -> 32.9 FPS, 1.86 -> 1.93x A100, -0.5 GB peak Knob FLASHVSR_TCDECODER_CHANNELS_LAST (default on). 3-B) Adaptive attention backend (FLASHVSR_ATTN_BACKEND, default 'sparse'): Measured the real self-attn: seq=25344, 12 heads, dim=128, block-mask density ~0.606 (only 39% sparse). At that density cuDNN fused dense SDPA (6.5ms, 605 TFLOP/s) beats block_sparse (7.3ms) AND uses full context; FA2 dense is 10.6ms (cuDNN wins by 1.64x). Crossover is density ~0.5. Added density-adaptive routing (opt-in). E2E gain at default topk is negligible (~+0.5%) and dense changes the output (full vs trained sparse pattern, PSNR 41 dB), so default stays 'sparse' (no quality change). Tooling added (isolated + E2E): profile_e2e_bottlenecks.py, probe_attention_shapes.py, test_tcdecoder_channels_last.py, test_attention_backend.py, test_topk_sweep.py Note (not committed as default): topk/sparsity sweep shows sparse_ratio=1.5 gives 2.07x A100 @ PSNR 42.8 dB vs baseline (a documented 'faster' setting).
…100) The DiT block spends ~17% of denoise GPU time in memory-bound elementwise kernels: RMSNorm (q/k, with an fp32 up/down cast), modulate (x*(1+scale)+shift) and the gate (x + gate*residual). Fuse each via torch.compile(dynamic=True). Isolated (dim=1536, seq=25344): RMSNorm 0.556 -> 0.263 ms (2.12x, cos 1.000000) modulate+gate 0.438 -> 0.313 ms (1.40x, cos 0.999993) E2E @768x1408 (conv3d=gemm, TCDecoder NHWC): 32.9 -> 35.5 FPS, 1.94 -> 2.09x A100, PSNR(off,on) 49.2 dB (bf16-level). These fused fns contain no attention / custom kernels, so they don't interact with the block_sparse path or streaming cache. Knob FLASHVSR_FUSE_NORM (default off, opt-in due to compile warmup + tiny bf16 reorder diff). Adds test_fuse_norm.py (E2E parity + speed).
The bundled block_sparse_attn CUDA kernel is FlashAttention-2 style and emits only Ampere HMMA tensor-core ops even when compiled for sm_90 (verified in SASS: 380928 HMMA, 0 WGMMA), reaching only ~33% of bf16 peak (~327 TFLOP/s) at the real self-attn shape (seq=25344, 12 heads, dim=128, block-mask density ~0.606). cuDNN dense reaches ~62% peak via WGMMA but can't express FlashVSR's 2D-spatial block mask (cuDNN block_mask unsupported on sm_90; 1D band masks don't cover a 2D-local pattern efficiently). Add a Triton block-sparse FlashAttention kernel that honors the exact per-(q_block, kv_block) boolean mask and compiles to Hopper WGMMA (verified in PTX/SASS). It uses a CSR-style per-q-block kept-kv-index list so masked tiles are skipped entirely. Isolated (vs the real block_sparse_attn kernel, same mask): cos 0.99999, max|diff| 0.0005; 7.36 -> 6.06 ms (1.21x), ~41% peak. E2E @768x1408 (conv3d=gemm, TCDecoder NHWC, fuse_norm on): 35.5 -> 38.0 FPS, 2.09 -> 2.23x A100; PSNR(sparse,triton) 49.97 dB (bf16-level). Opt-in via FLASHVSR_ATTN_BACKEND=triton, guarded to sm_90, silent fallback to the original block_sparse kernel on non-Hopper / triton-missing / any error. Default remains 'sparse' (zero regression; Ampere keeps the original path).
Add a Tensor Memory Accelerator (TMA) fast path to the Triton block-sparse attention kernel. Q/K/V tiles are loaded via device TMA descriptors (triton TensorDescriptor), overlapping bulk global->shared loads with WGMMA. Isolated (real shape 25344x12x128, density 0.606): 6.11 -> 5.6 ms, ~40% -> ~44% of bf16 peak, parity cos 0.99999 (math unchanged; TMA only changes how memory is fetched). E2E @768x1408 (conv3d=gemm, TCDecoder NHWC, fuse_norm on): 2.23 -> 2.29x A100; PSNR(sparse, triton+TMA) 49.97 dB. TMA is the default when available (Triton TensorDescriptor present); guarded by FLASHVSR_ATTN_TMA (set 0 to disable) and falls back silently to the non-TMA kernel on older Triton / SMEM limits / any error. Still gated to sm_90 via the attention backend; Ampere keeps the original block_sparse path.
Two opt-in, bit-identical caches for per-block elementwise results that are recomputed every denoise step but depend only on fixed inputs (not q/k/x): - FLASHVSR_CACHE_MOD (default 0): cache (modulation + t_mod).chunk(...) per DiTBlock and Head. t_mod is computed once in init_cross_kv and is constant. - FLASHVSR_CACHE_MASK_BIAS (default 0): cache the geometry-only 0/-inf additive bias in generate_draft_block_mask (repeat + masked_fill), keyed on shape only. Both default OFF, pure elementwise, silent fallback; verified max|diff|==0 on all test clips (test_cache_lossless.py). These remove ~2169 small kernels per denoise. Also: profile_e2e_bottlenecks.py categorize() now classifies the Triton _bsfa kernel as attention and xmma_fprop as TCDecoder conv.
Add a "Hopper Acceleration" section to the README so users can discover and enable the opt-in fast paths: the env var table (all default OFF), the recommended full-speed one-liner, and notes on sm_90 guarding and parity (bit-identical vs ~49-50 dB PSNR paths).
|
Dude you’ve really done an amazing job |
|
amazing would any of these work on other gpus too? like rtx 3000 4000 series? 5000 series? |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Hopper (GH200 / sm_90) inference acceleration: opt-in, lossless by default
This PR adds a set of opt-in Hopper-specific fast paths for FlashVSR v1.1 Tiny inference. Every optimization is gated behind an environment knob and defaults to OFF, so the default output is bit-for-bit identical to current FlashVSR and Ampere / A100 behaviour is completely unchanged.
Motivation
On a Hopper GPU (NVIDIA GH200, sm_90), the stock inference path leaves most of the hardware idle: the LQ-projector
CausalConv3druns through a cuDNN path that doesn't use the tensor cores, and the block-sparse attention kernel emits Ampere-styleHMMAinstead of HopperWGMMA. As a result the much stronger GH200 ran the same workload at roughly the same speed as the A100 baseline. This PR closes that gap.Results (NVIDIA GH200, FlashVSR v1.1 Tiny, 768x1408, before -> after)
All numbers are isolated single-config runs (best-of-8, one config per process to avoid warmup/ordering bias) on the same GH200, same clip, denoise FPS. "Before" = stock path on GH200, "after" = all knobs enabled. The stack is cumulative.
channels_last(NHWC)torch.compile)(For reference, the README quotes ~17 FPS at this resolution on an A100; the "before" GH200 number matches that, i.e. the stock path does not benefit from the newer hardware.)
What's added (all env knobs, default OFF)
FLASHVSR_CONV3D_BACKENDauto,gemmautogemm= im2col + WGMMA conv3d for the LQ projectorFLASHVSR_CONV3D_IM2COL_BUDGET_GB2.0FLASHVSR_TCDECODER_CHANNELS_LAST0,10FLASHVSR_FUSE_NORM0,10torch.compilefusionFLASHVSR_ATTN_BACKENDsparse,triton,auto,densesparsetriton= Hopper WGMMA block-sparse kernelFLASHVSR_ATTN_TMA0,11*tritonbackend)FLASHVSR_CACHE_MOD0,10(modulation + t_mod).chunk(...)(bit-identical)FLASHVSR_CACHE_MASK_BIAS0,10*
FLASHVSR_ATTN_TMAonly matters whenFLASHVSR_ATTN_BACKEND=tritonis selected (it has no effect on the defaultsparsebackend). Every knob is OFF by default, so with no environment variables set the output is bit-for-bit the original FlashVSR.Suggested full-speed config (single line):
Design rationale: why each path was chosen
Each optimization started from a measured bottleneck (
torch.profiler+ Nsight Systems/Compute on GH200), and several alternatives were tried and rejected before landing the version here.1. conv3d im2col + GEMM (
FLASHVSR_CONV3D_BACKEND=gemm)The LQ-projector
CausalConv3dlayers were ~50% of denoise time. cuDNN cannot pick a tensor-core implicit-GEMM for these shapes (kernel(4,3,3), stride(2,1,1)), leaving the tensor cores idle. Reformulating the conv as im2col + a bf16 GEMM (addmm) puts it on the tensor cores: on the largest conv (2048->3072) we measured cuDNN 298 ms -> GEMM 17.5 ms (~17x) on the kernel, paritycos 0.999996, causal replicate-pad + streaming cache preserved exactly. E2E this is the single biggest win (17.2 -> 31.5 FPS, see table).channels_last_3don the conv3d: it made the conv slower, so it was dropped.FLASHVSR_CONV3D_IM2COL_BUDGET_GB). Measured on the largest conv (2048->3072) at a big latent: the transient peak drops 6.07 GB -> 3.70 GB (2 GB budget) -> 2.17 GB (0.5 GB budget), bit-identical and with no speed loss. Note this bounds the conv3d transient, not necessarily the end-to-end peak; at typical resolutions the overall peak is dominated by the TCDecoder, so chunking is mainly insurance against the conv3d patch blowing up at very high resolution. A Triton fused im2col was prototyped but turned out unnecessary, plain-PyTorch chunking was enough.2. TCDecoder
channels_last(FLASHVSR_TCDECODER_CHANNELS_LAST)The TCDecoder is a pure
nn.Conv2dgraph running contiguous (NCHW). On bf16, cuDNN inserted annchwToNhwc/nhwcToNchwlayout conversion around every conv (~226 ms, ~9% of denoise GPU) and the convs themselves were ~1.5x slower. Running the decoder in NHWC removes the layout kernels entirely (226 ms -> 0.3 ms) and is bit-identical (max|diff| = 0).3. norm/elementwise fusion (
FLASHVSR_FUSE_NORM)The DiT block (30 layers x many steps) is full of memory-bound elementwise kernels: RMSNorm (with an fp32 up/down cast),
modulate(x*(1+scale)+shift), and the gate (x+gate*residual), together ~17% of denoise GPU.torch.compilefuses each chain into a single kernel (severalx faster in isolation, the exact ratio depending on tensor shape). It's default OFF because it has a compile warmup and a tiny bf16 reorder difference; measured 49.2 dB PSNR end-to-end vs the unfused path (test_fuse_norm.py). It touches only pure elementwise math, never the attention or streaming-cache path.4. Triton WGMMA block-sparse attention (
FLASHVSR_ATTN_BACKEND=triton+FLASHVSR_ATTN_TMA)At the real shape (seq 25344, 12 heads, dim 128, block-mask density ~0.6) the bundled
block_sparse_attnkernel compiles to AmpereHMMAeven on sm_90 (0WGMMA), reaching only ~33% of peak. We considered several routes:WGMMA(~62% peak) but cannot express the arbitrary 2D block mask, and routing to dense changes the output (it's full attention, not the trained sparse pattern), rejected for the default.block_maskis the exact API we'd want but is Blackwell-only in the cuDNN builds available here (confirmedARCH_MISMATCHon Hopper), not usable.The shipped kernel is a Triton block-sparse FlashAttention that honours the exact same 2D block mask and compiles to real Hopper
WGMMA(verified in PTX/SASS), skipping the masked blocks (work scales with density). At the kernel level it matchesblock_sparse_attnvery closely (cos ~ 0.99999); end-to-end, swapping it in measures ~49.7 dB PSNR vs thesparsebackend (the small difference is the WGMMA vs HMMA accumulation order, not a change in the mask). Isolated single-config FPS: sparse 35.5 -> triton 37.7 -> triton+TMA 38.8 (so triton is ~+9% over the bundled kernel; TMA adds a further ~2% end-to-end (it's ~10% at the isolated kernel level but most of that is hidden behind other work in the full pipeline). It is sm_90-guarded and silently falls back toblock_sparse_attnon Ampere or any error; theblock_sparse_attndependency is unchanged, the Triton kernel is just an alternate path.5. Lossless step-invariant caches (
FLASHVSR_CACHE_MOD,FLASHVSR_CACHE_MASK_BIAS)Two per-block results are recomputed every denoise step but depend only on fixed inputs, not on
x/q/k: the(modulation + t_mod).chunk(...)split (the timestep modulation is computed once ininit_cross_kvand is constant), and the geometry-only0/-infadditive attention bias. Caching them is bit-identical (max|diff| = 0) and removes ~2169 small kernel launches per denoise.Safety
block_sparse_attndependency is unchanged; the Triton kernel is an alternate path, not a replacement.channels_last,CACHE_MOD,CACHE_MASK_BIAS) are verifiedmax|diff| == 0(bit-identical). The opt-in approximate paths are verified end-to-end:FUSE_NORM~ 49.2 dB PSNR, Triton attention ~ 49.7 dB PSNR vs the original backends.Docs
The README gets a new "Hopper Acceleration" section documenting all env knobs (the table, the recommended full-speed one-liner, and the parity/
sm_90notes), so the fast paths are discoverable without reading the code.Tests
The PR includes isolated parity + benchmark scripts under
examples/WanVSR/(test_conv3d_gemm_parity.py,test_tcdecoder_channels_last.py,test_fuse_norm.py,test_attention_backend.py,test_cache_lossless.py,test_a100_ref_768x1408.py, plus profiling helpers). Each new fast path is checked for both parity and speed.Notes
stepvideo_text_encoder.pyis included: newertransformers(verified on 5.5.0) no longer exposesPretrainedConfigfromtransformers.modeling_utils, which breaksimport diffsynth. The fix falls back to importing it fromtransformersdirectly. Without it the package does not import on recenttransformers.