
Hopper (GH200 / sm_90) inference acceleration: opt-in, lossless by default #108

Open
hamuzhan wants to merge 8 commits into OpenImagingLab:main from hamuzhan:hopper-acceleration

@hamuzhan

Hopper (GH200 / sm_90) inference acceleration: opt-in, lossless by default

This PR adds a set of opt-in Hopper-specific fast paths for FlashVSR v1.1 Tiny inference. Every optimization is gated behind an environment knob and defaults to OFF, so the default output is bit-for-bit identical to current FlashVSR and Ampere / A100 behaviour is completely unchanged.

Motivation

On a Hopper GPU (NVIDIA GH200, sm_90), the stock inference path leaves most of the hardware idle: the LQ-projector CausalConv3d runs through a cuDNN path that doesn't use the tensor cores, and the block-sparse attention kernel emits Ampere-style HMMA instead of Hopper WGMMA. As a result the much stronger GH200 ran the same workload at roughly the same speed as the A100 baseline. This PR closes that gap.

Results (NVIDIA GH200, FlashVSR v1.1 Tiny, 768x1408, before -> after)

All numbers are denoise FPS from isolated single-config runs (best-of-8, one config per process to avoid warmup/ordering bias) on the same GH200 and the same clip. "Before" = stock path on GH200, "after" = all knobs enabled. The stack is cumulative.

| Stage | Optimization | Denoise FPS | Speedup vs before |
|---|---|---|---|
| before | stock (cuDNN conv3d, HMMA attn) | 17.2 | 1.00x |
| +conv3d | im2col + WGMMA conv3d (LQ projector) | 31.5 | 1.84x |
| +decoder | TCDecoder channels_last (NHWC) | 32.9 | 1.92x |
| +norm | norm/elementwise fusion (torch.compile) | 35.6 | 2.07x |
| +attn | Triton WGMMA block-sparse attention | 37.7 | 2.20x |
| +TMA | TMA bulk loads for the attention kernel | 38.8 | ~2.26x |

(For reference, the README quotes ~17 FPS at this resolution on an A100; the "before" GH200 number matches that, i.e. the stock path does not benefit from the newer hardware.)

What's added (all env knobs, default OFF)

| Knob | Values | Default | Effect |
|---|---|---|---|
| FLASHVSR_CONV3D_BACKEND | auto, gemm | auto | gemm = im2col + WGMMA conv3d for the LQ projector |
| FLASHVSR_CONV3D_IM2COL_BUDGET_GB | float | 2.0 | chunked im2col memory budget |
| FLASHVSR_TCDECODER_CHANNELS_LAST | 0, 1 | 0 | NHWC TCDecoder (bit-identical) |
| FLASHVSR_FUSE_NORM | 0, 1 | 0 | norm / modulate / gate torch.compile fusion |
| FLASHVSR_ATTN_BACKEND | sparse, triton, auto, dense | sparse | triton = Hopper WGMMA block-sparse kernel |
| FLASHVSR_ATTN_TMA | 0, 1 | 1* | TMA bulk load (only used by the triton backend) |
| FLASHVSR_CACHE_MOD | 0, 1 | 0 | cache step-invariant (modulation + t_mod).chunk(...) (bit-identical) |
| FLASHVSR_CACHE_MASK_BIAS | 0, 1 | 0 | cache the geometry-only attention bias (bit-identical) |

* FLASHVSR_ATTN_TMA defaults to 1 but only matters when FLASHVSR_ATTN_BACKEND=triton is selected (it has no effect on the default sparse backend). Every fast path is therefore OFF by default: with no environment variables set, the output is bit-for-bit the original FlashVSR.

Suggested full-speed config (one command):

FLASHVSR_CONV3D_BACKEND=gemm FLASHVSR_TCDECODER_CHANNELS_LAST=1 \
FLASHVSR_FUSE_NORM=1 FLASHVSR_ATTN_BACKEND=triton \
python infer_flashvsr_v1.1_tiny.py
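
For readers who want to see how such knobs might be resolved, here is a minimal sketch. The knob names, allowed values and defaults come from the table above; the `read_knobs` helper and its dictionary layout are hypothetical and are not taken from this PR's code.

```python
import os

# Names, allowed values and defaults follow the knob table in this PR.
# The parsing helper itself is an illustrative assumption.
_CHOICE_KNOBS = {
    "FLASHVSR_CONV3D_BACKEND": (("auto", "gemm"), "auto"),
    "FLASHVSR_ATTN_BACKEND": (("sparse", "triton", "auto", "dense"), "sparse"),
}
_FLAG_KNOBS = {
    "FLASHVSR_TCDECODER_CHANNELS_LAST": False,
    "FLASHVSR_FUSE_NORM": False,
    "FLASHVSR_ATTN_TMA": True,  # only consulted by the triton backend
    "FLASHVSR_CACHE_MOD": False,
    "FLASHVSR_CACHE_MASK_BIAS": False,
}


def read_knobs(env=None):
    """Resolve every knob, falling back to the default on missing or bad values."""
    env = os.environ if env is None else env
    knobs = {}
    for name, (choices, default) in _CHOICE_KNOBS.items():
        value = env.get(name, default).strip().lower()
        knobs[name] = value if value in choices else default
    for name, default in _FLAG_KNOBS.items():
        raw = env.get(name)
        knobs[name] = default if raw is None else raw.strip() == "1"
    try:
        budget = float(env.get("FLASHVSR_CONV3D_IM2COL_BUDGET_GB", "2.0"))
    except ValueError:
        budget = 2.0
    knobs["FLASHVSR_CONV3D_IM2COL_BUDGET_GB"] = budget
    return knobs
```

Calling `read_knobs({})` returns the all-default configuration, which corresponds to the stock path.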

Design rationale: why each path was chosen

Each optimization started from a measured bottleneck (torch.profiler + Nsight Systems/Compute on GH200), and several alternatives were tried and rejected before landing the version here.

1. conv3d im2col + GEMM (FLASHVSR_CONV3D_BACKEND=gemm)
The LQ-projector CausalConv3d layers were ~50% of denoise time. cuDNN cannot pick a tensor-core implicit-GEMM for these shapes (kernel (4,3,3), stride (2,1,1)), leaving the tensor cores idle. Reformulating the conv as im2col + a bf16 GEMM (addmm) puts it on the tensor cores: on the largest conv (2048->3072) we measured cuDNN 298 ms -> GEMM 17.5 ms (~17x) on the kernel, parity cos 0.999996, causal replicate-pad + streaming cache preserved exactly. E2E this is the single biggest win (17.2 -> 31.5 FPS, see table).

  • Rejected: channels_last_3d on the conv3d: it made the conv slower, so it was dropped.
  • The naive im2col patch tensor gets large at high resolution, so we tile it along the output-H axis under a memory budget (FLASHVSR_CONV3D_IM2COL_BUDGET_GB). Measured on the largest conv (2048->3072) at a big latent: the transient peak drops 6.07 GB -> 3.70 GB (2 GB budget) -> 2.17 GB (0.5 GB budget), bit-identical to the unchunked GEMM path and with no speed loss. Note this bounds the conv3d transient, not necessarily the end-to-end peak; at typical resolutions the overall peak is dominated by the TCDecoder, so chunking is mainly insurance against the conv3d patch blowing up at very high resolution. A Triton fused im2col was prototyped but turned out to be unnecessary: plain-PyTorch chunking was enough.
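
The idea of im2col + GEMM with tiling along output-H can be sketched on a toy problem. This is not the PR's `_conv3d_gemm`: it is a pure-Python 2D simplification (no temporal axis, no causal padding, no bf16), and `rows_for_budget` is a hypothetical way to turn a byte budget into a tile height.

```python
def conv2d_direct(x, w):
    """Reference valid convolution, stride 1. x: [C][H][W], w: [O][C][KH][KW]."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    O, KH, KW = len(w), len(w[0][0]), len(w[0][0][0])
    Ho, Wo = H - KH + 1, W - KW + 1
    return [[[sum(w[o][c][i][j] * x[c][y + i][xx + j]
                  for c in range(C) for i in range(KH) for j in range(KW))
              for xx in range(Wo)] for y in range(Ho)] for o in range(O)]


def rows_for_budget(budget_gb, patch_len, w_out, bytes_per_elem=2):
    """Output rows per tile so the patch matrix stays under the budget (assumed formula)."""
    bytes_per_row = patch_len * w_out * bytes_per_elem
    return max(1, int(budget_gb * 1024 ** 3) // bytes_per_row)


def conv2d_im2col_chunked(x, w, rows_per_chunk):
    """Same convolution as conv2d_direct, computed as tiled im2col + matrix multiply."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    O, KH, KW = len(w), len(w[0][0]), len(w[0][0][0])
    Ho, Wo = H - KH + 1, W - KW + 1
    # Flatten each filter into one row of an [O, K] matrix, K = C*KH*KW.
    wmat = [[w[o][c][i][j] for c in range(C) for i in range(KH) for j in range(KW)]
            for o in range(O)]
    out = [[[0] * Wo for _ in range(Ho)] for _ in range(O)]
    for y0 in range(0, Ho, rows_per_chunk):
        y1 = min(y0 + rows_per_chunk, Ho)
        # im2col for this tile only: one length-K patch per output pixel.
        patches = [[x[c][y + i][xx + j]
                    for c in range(C) for i in range(KH) for j in range(KW)]
                   for y in range(y0, y1) for xx in range(Wo)]
        # GEMM: [O, K] times [K, pixels-in-tile].
        for o in range(O):
            for p, patch in enumerate(patches):
                out[o][y0 + p // Wo][p % Wo] = sum(a * b for a, b in zip(wmat[o], patch))
    return out
```

Because each tile touches a disjoint set of output rows, the tile height only changes peak memory, not the result, which is why chunking can be bit-identical to the unchunked GEMM.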

2. TCDecoder channels_last (FLASHVSR_TCDECODER_CHANNELS_LAST)
The TCDecoder is a pure nn.Conv2d graph running contiguous (NCHW). On bf16, cuDNN inserted an nchwToNhwc/nhwcToNchw layout conversion around every conv (~226 ms, ~9% of denoise GPU) and the convs themselves were ~1.5x slower. Running the decoder in NHWC removes the layout kernels entirely (226 ms -> 0.3 ms) and is bit-identical (max|diff| = 0).
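
To illustrate what "running in NHWC" changes, the sketch below computes element strides for the two layouts. It assumes the usual PyTorch convention in which a channels_last tensor keeps its logical (N, C, H, W) shape and only the strides differ; the helper names are illustrative.

```python
def contiguous_strides(n, c, h, w):
    """Element strides in (N, C, H, W) order for the default NCHW layout."""
    return (c * h * w, h * w, w, 1)


def channels_last_strides(n, c, h, w):
    """Element strides in (N, C, H, W) order when memory is physically NHWC."""
    return (h * w * c, 1, w * c, c)


def offset(index, strides):
    """Linear memory offset of a logical (n, c, y, x) index."""
    return sum(i * s for i, s in zip(index, strides))
```

In NHWC the channels of one pixel sit next to each other in memory, which is the layout the bf16 tensor-core conv kernels consume, so no conversion kernel is needed around each conv.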

3. norm/elementwise fusion (FLASHVSR_FUSE_NORM)
The DiT block (30 layers x many steps) is full of memory-bound elementwise kernels: RMSNorm (with an fp32 up/down cast), modulate (x*(1+scale)+shift), and the gate (x+gate*residual), together ~17% of denoise GPU. torch.compile fuses each chain into a single kernel (roughly 1.4x to 2.1x faster in isolation, the exact ratio depending on the op and tensor shape). It's default OFF because it has a compile warmup and a tiny bf16 reorder difference; measured 49.2 dB PSNR end-to-end vs the unfused path (test_fuse_norm.py). It touches only pure elementwise math, never the attention or streaming-cache path.
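
As a reference for the three elementwise ops named above, here is a plain-Python sketch, plus an "unfused" modulate that materializes intermediates the way separate kernels would. The formulas for modulate and gate are the ones quoted in this PR; the RMSNorm form and its `eps` value are the common definition and are assumptions here. This only illustrates what fusion removes; it does not use torch.compile.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """x / sqrt(mean(x^2) + eps) * weight, over one feature vector."""
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv * w for v, w in zip(x, weight)]


def modulate_unfused(x, scale, shift):
    """Three passes over memory, two temporaries: what separate kernels do."""
    one_plus = [1.0 + s for s in scale]
    scaled = [v * s for v, s in zip(x, one_plus)]
    return [v + b for v, b in zip(scaled, shift)]


def modulate_fused(x, scale, shift):
    """One pass, no temporaries: x*(1+scale)+shift per element."""
    return [v * (1.0 + s) + b for v, s, b in zip(x, scale, shift)]


def gate(x, g, residual):
    """x + gate*residual per element."""
    return [v + a * r for v, a, r in zip(x, g, residual)]
```

On a GPU the fused form reads and writes each element once, which is where the saving comes from for memory-bound ops.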

4. Triton WGMMA block-sparse attention (FLASHVSR_ATTN_BACKEND=triton + FLASHVSR_ATTN_TMA)
At the real shape (seq 25344, 12 heads, dim 128, block-mask density ~0.6) the bundled block_sparse_attn kernel compiles to Ampere HMMA even on sm_90 (0 WGMMA), reaching only ~33% of peak. We considered several routes:

  • cuDNN dense fused attention uses WGMMA (~62% peak) but cannot express the arbitrary 2D block mask, and routing to dense changes the output (it's full attention, not the trained sparse pattern), rejected for the default.
  • cuDNN block_mask is the exact API we'd want but is Blackwell-only in the cuDNN builds available here (confirmed ARCH_MISMATCH on Hopper), not usable.
  • cuDNN diagonal-band / sliding-window masks are 1-D; FlashVSR's mask is 2-D-spatial-local, so covering it with a band raises density too much to help, rejected.

The shipped kernel is a Triton block-sparse FlashAttention that honours the exact same 2D block mask and compiles to real Hopper WGMMA (verified in PTX/SASS), skipping the masked blocks (work scales with density). At the kernel level it matches block_sparse_attn very closely (cos ~ 0.99999); end-to-end, swapping it in measures ~49.7 dB PSNR vs the sparse backend (the small difference is the WGMMA vs HMMA accumulation order, not a change in the mask). Isolated single-config FPS: sparse 35.5 -> triton 37.7 -> triton+TMA 38.8, so triton is ~+6% over the bundled kernel and TMA adds a further ~3% end-to-end (it's ~10% at the isolated kernel level, but most of that is hidden behind other work in the full pipeline). It is sm_90-guarded and silently falls back to block_sparse_attn on Ampere or any error; the block_sparse_attn dependency is unchanged, the Triton kernel is just an alternate path.
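
The block-skipping logic can be sketched in plain Python. This is not the Triton kernel: it is a scalar reference that walks a per-q-block list of kept kv blocks (the CSR-style list mentioned in the commit message) with a FlashAttention-style online softmax, and checks against dense attention with the same mask applied as -inf. Function names are illustrative, and every q block is assumed to keep at least one kv block.

```python
import math


def dense_masked_attention(q, k, v, mask, block):
    """Reference: full score row per query, masked blocks set to -inf."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) * scale
                  if mask[i // block][j // block] else -math.inf
                  for j, kj in enumerate(k)]
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        z = sum(p)
        out.append([sum(pj * vj[d] for pj, vj in zip(p, v)) / z
                    for d in range(len(v[0]))])
    return out


def block_sparse_attention(q, k, v, mask, block):
    """Visit only kept kv blocks; running max/sum keeps the softmax exact."""
    scale = 1.0 / math.sqrt(len(q[0]))
    # CSR-style: for each q block, the indices of the kv blocks to visit.
    kept = [[kb for kb, keep in enumerate(row) if keep] for row in mask]
    out = []
    for i, qi in enumerate(q):
        m, z, acc = -math.inf, 0.0, [0.0] * len(v[0])
        for kb in kept[i // block]:
            for j in range(kb * block, min((kb + 1) * block, len(k))):
                s = sum(a * b for a, b in zip(qi, k[j])) * scale
                m_new = max(m, s)
                corr = math.exp(m - m_new)  # rescales what was accumulated so far
                p = math.exp(s - m_new)
                z = z * corr + p
                acc = [a * corr + p * vd for a, vd in zip(acc, v[j])]
                m = m_new
        out.append([a / z for a in acc])
    return out
```

Masked tiles are never touched, so the work is proportional to mask density, while the result matches dense attention under the same mask up to floating-point summation order.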

5. Lossless step-invariant caches (FLASHVSR_CACHE_MOD, FLASHVSR_CACHE_MASK_BIAS)
Two per-block results are recomputed every denoise step but depend only on fixed inputs, not on x/q/k: the (modulation + t_mod).chunk(...) split (the timestep modulation is computed once in init_cross_kv and is constant), and the geometry-only 0/-inf additive attention bias. Caching them is bit-identical (max|diff| = 0) and removes ~2169 small kernel launches per denoise.

Safety

  • Every fast path is sm_90-guarded with a silent fallback to the original code (on Ampere or on any error it transparently reverts).
  • The block_sparse_attn dependency is unchanged; the Triton kernel is an alternate path, not a replacement.
  • Lossless paths (channels_last, CACHE_MOD, CACHE_MASK_BIAS) are verified max|diff| == 0 (bit-identical). The opt-in approximate paths are verified end-to-end: FUSE_NORM ~ 49.2 dB PSNR, Triton attention ~ 49.7 dB PSNR vs the original backends.
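
The guard-plus-fallback pattern described above can be sketched as a small wrapper. The `guarded` helper is hypothetical, and the capability tuple is passed in explicitly; in PyTorch it would typically come from `torch.cuda.get_device_capability()`, though how the PR obtains it is not shown here.

```python
def guarded(fast_fn, fallback_fn, capability):
    """Use fast_fn only on sm_90; on other GPUs or on any error, use fallback_fn."""
    def run(*args, **kwargs):
        if tuple(capability) == (9, 0):
            try:
                return fast_fn(*args, **kwargs)
            except Exception:
                pass  # silent fallback: the caller never sees the fast-path error
        return fallback_fn(*args, **kwargs)
    return run
```

One trade-off of a silent fallback is that a broken fast path only shows up as lost speed, so a benchmark or a debug log is needed to notice it.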

Docs

The README gets a new "Hopper Acceleration" section documenting all env knobs (the table, the recommended full-speed one-liner, and the parity/sm_90 notes), so the fast paths are discoverable without reading the code.

Tests

The PR includes isolated parity + benchmark scripts under examples/WanVSR/ (test_conv3d_gemm_parity.py, test_tcdecoder_channels_last.py, test_fuse_norm.py, test_attention_backend.py, test_cache_lossless.py, test_a100_ref_768x1408.py, plus profiling helpers). Each new fast path is checked for both parity and speed.
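
For reference, the three parity metrics quoted throughout this PR (max|diff|, cosine similarity, PSNR) can be computed as in the sketch below. It works on flat lists of floats and assumes a signal peak of 1.0 for PSNR; the peak and the exact reductions used by the test scripts are not stated here, so treat both as assumptions.

```python
import math


def max_abs_diff(a, b):
    """Largest elementwise difference; 0 means bit-identical values."""
    return max(abs(x - y) for x, y in zip(a, b))


def cosine_similarity(a, b):
    """Cosine of the angle between the two flattened outputs."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def psnr(a, b, peak=1.0):
    """Peak signal-to-noise ratio in dB; infinite when the inputs are identical."""
    mse = sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
    return math.inf if mse == 0 else 10.0 * math.log10(peak * peak / mse)
```

With this definition, the lossless paths correspond to `max_abs_diff == 0` (and infinite PSNR), while the ~49 to 50 dB paths correspond to an RMS error of roughly 0.3% of the peak.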

Notes

  • Tested on GH200 (sm_90). The kernels also run on H200 but the speedup is smaller due to scheduling differences. Other GPUs fall back to the original path.
  • A small, unrelated import-compat fix for stepvideo_text_encoder.py is included: newer transformers (verified on 5.5.0) no longer exposes PretrainedConfig from transformers.modeling_utils, which breaks import diffsynth. The fix falls back to importing it from transformers directly. Without it the package does not import on recent transformers.
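
The import fallback described in the last note follows a common pattern, sketched here with a generic helper so it runs without transformers installed. `import_first` is a hypothetical name; the actual fix in stepvideo_text_encoder.py may simply be a try/except around the two imports.

```python
import importlib


def import_first(name, module_paths):
    """Return attribute `name` from the first module in module_paths that has it."""
    for path in module_paths:
        try:
            return getattr(importlib.import_module(path), name)
        except (ImportError, AttributeError):
            continue  # try the next candidate location
    raise ImportError(f"{name} not found in any of {list(module_paths)}")
```

Applied to this case, the call would look like `import_first("PretrainedConfig", ["transformers.modeling_utils", "transformers"])`, which tries the old location first and then the top-level package.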

hamuzhan added 8 commits June 17, 2026 11:15
The LQ-projector CausalConv3d kernels (4x3x3, stride (2,1,1), large channels)
run at ~27 TFLOP/s (~2.8% peak) on GH200 because cuDNN can't pick a tensor-core
implicit-GEMM for these shapes. They account for ~50% of denoise time.

Reformulate the core convolution as explicit im2col (unfold) + bf16 GEMM, which
runs on Hopper WGMMA tensor cores (~400+ TFLOP/s, ~15-17x faster) with
the same math up to bf16 rounding (parity cosine 0.999996).

- Phase 1 (D1): _conv3d_gemm + FLASHVSR_CONV3D_BACKEND={auto|gemm} knob +
  sm_90 guard + silent fallback. Causal replicate-pad + streaming cache
  preserved; only the padding-free core conv is rerouted. Single-point change
  benefits conv1/conv2 in both Buffer_/Causal_ projectors.
- Phase 2 (D1.5): chunked im2col over the output-H axis bounded by
  FLASHVSR_CONV3D_IM2COL_BUDGET_GB (default 2 GB) -> conv transient mem -51%
  at 1920x2560 with no speed/parity loss.

E2E (v1.1 Tiny): 1.92x @1536, 1.91x @2560x1920; norm-FPS ~17 -> ~33-34;
PSNR(auto,gemm) 47.6-49.5 dB. Default 'auto' = no regression.

Also fix transformers PretrainedConfig import (renamed -> PreTrainedConfig)
so the pipeline import chain works on current transformers.

Compares GH200 against the README's A100 reference (~17 FPS @ 768x1408) at the
exact same resolution, with both conv3d backends.

Results @ 768x1408:
- GH200 auto (cuDNN):       16.54 FPS  (0.97x A100 - parity, no Hopper gain)
- GH200 gemm (tensor-core): 31.56 FPS  (1.86x A100)
- gemm vs auto:             1.91x

Confirms the thesis: before the im2col+GEMM backend, GH200 was stuck at A100
parity despite 3x the hardware; the tensor-core conv path unlocks the real
Hopper advantage.

…eneck tooling

After the conv3d im2col+GEMM work (1.86x A100), profiled the full denoise
pipeline @768x1408 to find the next bottlenecks. GPU is ~91% busy (compute-
bound). Breakdown: attention ~27%, GEMM ~21%, TCDecoder conv2d ~19%,
norm/elementwise ~17%, copy/layout ~15%.

3-A) TCDecoder channels_last (NHWC):
  The TCDecoder is a pure Conv2d (TAEHV) graph running contiguous (NCHW), which
  made cuDNN insert nchwToNhwc/nhwcToNchw around every bf16 conv (~226ms /9%)
  and run the convs ~1.5x slower. Run the decoder in channels_last:
  - isolated decode: 231 -> 189 ms (1.22x), bit-identical (max|diff|=0)
  - E2E: 31.6 -> 32.9 FPS, 1.86 -> 1.93x A100, -0.5 GB peak
  Knob FLASHVSR_TCDECODER_CHANNELS_LAST (default on).

3-B) Adaptive attention backend (FLASHVSR_ATTN_BACKEND, default 'sparse'):
  Measured the real self-attn: seq=25344, 12 heads, dim=128, block-mask
  density ~0.606 (only 39% sparse). At that density cuDNN fused dense SDPA
  (6.5ms, 605 TFLOP/s) beats block_sparse (7.3ms) AND uses full context;
  FA2 dense is 10.6ms (cuDNN wins by 1.64x). Crossover is density ~0.5.
  Added density-adaptive routing (opt-in). E2E gain at default topk is
  negligible (~+0.5%) and dense changes the output (full vs trained sparse
  pattern, PSNR 41 dB), so default stays 'sparse' (no quality change).
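
  The density-adaptive routing can be sketched as below. The 0.5 crossover is the value measured above; the function names and the exact comparison are assumptions, not the PR's code.

```python
def mask_density(block_mask):
    """Fraction of (q_block, kv_block) tiles that are kept."""
    cells = [keep for row in block_mask for keep in row]
    return sum(cells) / len(cells)


def choose_attention_backend(block_mask, crossover=0.5):
    """Pick dense attention when the mask is too dense for sparsity to pay off."""
    return "dense" if mask_density(block_mask) >= crossover else "sparse"
```

  Note that choosing "dense" is a quality trade-off as well as a speed one, since dense attention ignores the trained sparse pattern.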

Tooling added (isolated + E2E):
  profile_e2e_bottlenecks.py, probe_attention_shapes.py,
  test_tcdecoder_channels_last.py, test_attention_backend.py, test_topk_sweep.py

Note (not committed as default): topk/sparsity sweep shows sparse_ratio=1.5
gives 2.07x A100 @ PSNR 42.8 dB vs baseline (a documented 'faster' setting).

…100)

The DiT block spends ~17% of denoise GPU time in memory-bound elementwise
kernels: RMSNorm (q/k, with an fp32 up/down cast), modulate (x*(1+scale)+shift)
and the gate (x + gate*residual). Fuse each via torch.compile(dynamic=True).

Isolated (dim=1536, seq=25344):
  RMSNorm        0.556 -> 0.263 ms (2.12x, cos 1.000000)
  modulate+gate  0.438 -> 0.313 ms (1.40x, cos 0.999993)

E2E @768x1408 (conv3d=gemm, TCDecoder NHWC):
  32.9 -> 35.5 FPS, 1.94 -> 2.09x A100, PSNR(off,on) 49.2 dB (bf16-level).

These fused fns contain no attention / custom kernels, so they don't interact
with the block_sparse path or streaming cache. Knob FLASHVSR_FUSE_NORM
(default off, opt-in due to compile warmup + tiny bf16 reorder diff).

Adds test_fuse_norm.py (E2E parity + speed).

The bundled block_sparse_attn CUDA kernel is FlashAttention-2 style and emits
only Ampere HMMA tensor-core ops even when compiled for sm_90 (verified in SASS:
380928 HMMA, 0 WGMMA), reaching only ~33% of bf16 peak (~327 TFLOP/s) at the
real self-attn shape (seq=25344, 12 heads, dim=128, block-mask density ~0.606).
cuDNN dense reaches ~62% peak via WGMMA but can't express FlashVSR's 2D-spatial
block mask (cuDNN block_mask unsupported on sm_90; 1D band masks don't cover a
2D-local pattern efficiently).
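
As a sanity check on the throughput figures, the arithmetic can be reproduced with the usual attention FLOP count (two matmuls of 2*seq*seq*dim FLOPs each per head). The ~989 TFLOP/s dense bf16 peak is an assumption, roughly what the quoted 327 TFLOP/s at ~33% implies, and the helper names are illustrative.

```python
def attention_flops(seq, heads, dim, density=1.0):
    """FLOPs for QK^T plus PV, scaled by the fraction of tiles actually computed."""
    return 4 * seq * seq * dim * heads * density


def achieved_tflops(flops, milliseconds):
    """Throughput in TFLOP/s for a kernel that took `milliseconds`."""
    return flops / (milliseconds * 1e-3) / 1e12


ASSUMED_PEAK_TFLOPS = 989.0  # assumption: dense bf16 tensor-core peak on Hopper

sparse = achieved_tflops(attention_flops(25344, 12, 128, density=0.606), 7.3)
dense = achieved_tflops(attention_flops(25344, 12, 128), 6.5)
print(f"block_sparse: {sparse:.0f} TFLOP/s, {sparse / ASSUMED_PEAK_TFLOPS:.0%} of peak")
print(f"cuDNN dense:  {dense:.0f} TFLOP/s, {dense / ASSUMED_PEAK_TFLOPS:.0%} of peak")
```

Under these assumptions the two kernels land within about 1% of the ~327 and ~605 TFLOP/s figures quoted in the commit messages.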

Add a Triton block-sparse FlashAttention kernel that honors the exact
per-(q_block, kv_block) boolean mask and compiles to Hopper WGMMA (verified in
PTX/SASS). It uses a CSR-style per-q-block kept-kv-index list so masked tiles
are skipped entirely.

Isolated (vs the real block_sparse_attn kernel, same mask):
  cos 0.99999, max|diff| 0.0005; 7.36 -> 6.06 ms (1.21x), ~41% peak.

E2E @768x1408 (conv3d=gemm, TCDecoder NHWC, fuse_norm on):
  35.5 -> 38.0 FPS, 2.09 -> 2.23x A100; PSNR(sparse,triton) 49.97 dB (bf16-level).

Opt-in via FLASHVSR_ATTN_BACKEND=triton, guarded to sm_90, silent fallback to
the original block_sparse kernel on non-Hopper / triton-missing / any error.
Default remains 'sparse' (zero regression; Ampere keeps the original path).

Add a Tensor Memory Accelerator (TMA) fast path to the Triton block-sparse
attention kernel. Q/K/V tiles are loaded via device TMA descriptors
(triton TensorDescriptor), overlapping bulk global->shared loads with WGMMA.

Isolated (real shape 25344x12x128, density 0.606): 6.11 -> 5.6 ms, ~40% -> ~44%
of bf16 peak, parity cos 0.99999 (math unchanged; TMA only changes how memory
is fetched).

E2E @768x1408 (conv3d=gemm, TCDecoder NHWC, fuse_norm on):
  2.23 -> 2.29x A100; PSNR(sparse, triton+TMA) 49.97 dB.

TMA is the default when available (Triton TensorDescriptor present); guarded by
FLASHVSR_ATTN_TMA (set 0 to disable) and falls back silently to the non-TMA
kernel on older Triton / SMEM limits / any error. Still gated to sm_90 via the
attention backend; Ampere keeps the original block_sparse path.

Two opt-in, bit-identical caches for per-block elementwise results that are
recomputed every denoise step but depend only on fixed inputs (not q/k/x):

- FLASHVSR_CACHE_MOD (default 0): cache (modulation + t_mod).chunk(...) per
  DiTBlock and Head. t_mod is computed once in init_cross_kv and is constant.
- FLASHVSR_CACHE_MASK_BIAS (default 0): cache the geometry-only 0/-inf additive
  bias in generate_draft_block_mask (repeat + masked_fill), keyed on shape only.

Both default OFF, pure elementwise, silent fallback; verified max|diff|==0 on
all test clips (test_cache_lossless.py). These remove ~2169 small kernels per
denoise.

Also: profile_e2e_bottlenecks.py categorize() now classifies the Triton
_bsfa kernel as attention and xmma_fprop as TCDecoder conv.

Add a "Hopper Acceleration" section to the README so users can discover and
enable the opt-in fast paths: the env var table (all default OFF), the
recommended full-speed one-liner, and notes on sm_90 guarding and parity
(bit-identical vs ~49-50 dB PSNR paths).
@napervays

Dude you’ve really done an amazing job

@FurkanGozukara

amazing would any of these work on other gpus too? like rtx 3000 4000 series? 5000 series?
