Enable fused SDPA vector kernel for asymmetric Q/V head dims (192, 128) #3637
yohann-bearzi wants to merge 2 commits into
The sdpa_vector kernel template already supported a separate value head_dim (template <typename T, int D, int V = D>), but no (D, V) pairs with D != V were instantiated, and use_fallback required query_head_dim == value_head_dim. MiMo-V2.5 uses head_dim=192 for Q/K and v_head_dim=128 for V, falling through to a compiled-graph decomposition (multiple GatherAxis dispatches per attention layer) instead of the fused kernel. Adds instantiate_sdpa_vector(type, 192, 128) and relaxes use_fallback to allow this specific asymmetric case. Other head dims remain unchanged. Verified on MiMo-V2.5: decode 28.7 -> 29.8 tok/s (+4%). Top-1 stable. The remaining decode bottleneck is MoE routing, not attention.
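The relaxed guard can be modelled as a small predicate. The sketch below is a Python model of the dispatch decision, not the C++ code in this PR. The function name `vector_kernel_supported` and the set of symmetric head dims are assumptions for illustration; only the (192, 128) pair comes from this change.

```python
# Hypothetical model of the head-dim check behind use_fallback.
# SYMMETRIC_DIMS is an assumed placeholder set, not taken from the PR.
SYMMETRIC_DIMS = frozenset({64, 96, 128, 256})

# The single asymmetric (query_head_dim, value_head_dim) pair this PR enables.
ASYMMETRIC_PAIRS = frozenset({(192, 128)})


def vector_kernel_supported(query_head_dim: int, value_head_dim: int) -> bool:
    """Return True if the fused vector kernel may be used for these dims."""
    if query_head_dim == value_head_dim:
        return query_head_dim in SYMMETRIC_DIMS
    # Asymmetric dims are allowed only for explicitly instantiated pairs.
    return (query_head_dim, value_head_dim) in ASYMMETRIC_PAIRS


if __name__ == "__main__":
    for dims in [(128, 128), (192, 128), (128, 192), (192, 192)]:
        print(dims, "fused" if vector_kernel_supported(*dims) else "fallback")
```

Keeping the asymmetric case as an explicit allow-list mirrors the intent stated above: every new (D, V) pair needs its own kernel instantiation, so the guard should only admit pairs that have one.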
For models using block_fp8 quantization (e.g. MiMo-V2.5, which motivated this asymmetric-head-dim fix), the block_fp8 matmul and MoE kernels are available as a standalone MLX extension: https://github.com/yohann-bearzi/mlx-block-fp8 (builds against stock upstream MLX; kernels vendored, no fork). It pairs with this SDPA change for full MiMo-V2.5 decode throughput. The MiMo-V2.5 MLX weights (block_fp8) are on Hugging Face: https://huggingface.co/bearzi/MiMo-V2.5-MLX
@zcbenz would you please be able to review?
zcbenz left a comment:
Can you run some benchmarks? benchmarks/python/sdpa_bench.py can be easily modified for this. The fused kernel is not guaranteed to be faster than the unfused one.
Benchmarked with a MiMo-shaped variant of [script name not preserved].
Full-attention layers (64 Q / 4 KV): [results table not preserved]
SWA layers (64 Q / 8 KV): [results table not preserved]
The fused kernel is faster across the full sweep: roughly 1.2-1.4x at short context, growing to ~2x past 16K KV (the fallback's separate softmax + gather dispatches scale worse with sequence length). It was faster at every shape tested, with no regressions, and the result is stable across repeated runs. Output matches the reference decomposition (max abs diff ~1e-4 in bf16; top-1 generation unchanged). The end-to-end MiMo decode gain is smaller (~3-4%) because attention is only a fraction of decode time; MoE routing dominates, as noted in the commit.
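The gap between the kernel-level speedup and the end-to-end gain can be sanity-checked with Amdahl's law. The sketch below is a rough estimate, not a measurement: it assumes decode time splits cleanly into attention and everything else, and that only attention got faster. The helper name `implied_attention_fraction` is made up for this example; the throughput numbers are the ones reported in this PR.

```python
def implied_attention_fraction(tok_s_before, tok_s_after, kernel_speedup):
    """Solve Amdahl's law for the fraction f of decode time spent in attention.

    Amdahl: t_after / t_before = (1 - f) + f / kernel_speedup
    """
    # Time per token is the inverse of throughput.
    time_ratio = tok_s_before / tok_s_after
    return (1.0 - time_ratio) / (1.0 - 1.0 / kernel_speedup)


if __name__ == "__main__":
    # 28.7 -> 29.8 tok/s is the reported MiMo-V2.5 decode change.
    for s in (1.2, 1.4, 2.0):
        f = implied_attention_fraction(28.7, 29.8, s)
        print(f"kernel speedup {s:.1f}x -> attention ~{f:.1%} of decode time")
```

Under these assumptions, the reported numbers are consistent with attention taking somewhere between roughly 7% and 22% of decode time, depending on which kernel speedup applies at the context length used.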
The `sdpa_vector` and `sdpa_vector_2pass_1` kernels are already templated on a separate value head dim (`template <typename T, int D, int V = D>`), but no `(D, V)` pairs with `D != V` were instantiated, and `use_fallback` required `query_head_dim == value_head_dim`. Models with asymmetric head dims therefore fall back to a compiled-graph attention decomposition (multiple `GatherAxis` dispatches per layer) instead of the fused kernel.

This adds `instantiate_sdpa_vector(type, 192, 128)` and relaxes `use_fallback` to allow this specific asymmetric case. All other head dims are unchanged.

Motivation: MiMo-V2.5 uses `head_dim=192` for Q/K and `v_head_dim=128` for V. On current `main` it hits the fallback path.

Testing: Verified on MiMo-V2.5 (Metal, M3 Ultra): decode throughput 28.7 → 29.8 tok/s (+4%), generated output (top-1) unchanged. The kernel templates already supported `V != D`; this only instantiates and enables the existing code path. Incremental build on current `main` is clean.

Scope: Minimal. 2 files, +5/-3. No new kernel code, just one instantiation + the fallback guard.
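To see why `V != D` is well-defined for attention, note that the scores depend only on the Q/K dim and the output depends only on the value dim. The following is a dependency-free Python reference for one head and one query token, written as an illustration under that framing; it is not the Metal kernel or the MLX fallback graph, and the function name `sdpa_decode_reference` is hypothetical.

```python
import math


def sdpa_decode_reference(q, k, v, scale=None):
    """Single-query attention for one head.

    q: list of D floats; k: L rows of D floats; v: L rows of V floats.
    Returns a list of V floats. D and V need not be equal.
    """
    d = len(q)
    scale = 1.0 / math.sqrt(d) if scale is None else scale
    # Scores only involve the Q/K dimension D.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, row)) for row in k]
    # Numerically stable softmax over the L keys.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # The weighted sum only involves the value dimension V.
    v_dim = len(v[0])
    return [sum(w * row[j] for w, row in zip(weights, v)) for j in range(v_dim)]


if __name__ == "__main__":
    D, V, L = 192, 128, 4  # MiMo-V2.5 head dims from this PR; L is arbitrary
    q = [0.01 * i for i in range(D)]
    k = [[0.001 * (i + t) for i in range(D)] for t in range(L)]
    v = [[float(t)] * V for t in range(L)]
    out = sdpa_decode_reference(q, k, v)
    print(len(out))  # length equals the value head dim V
```

A reference like this is also a convenient oracle when comparing fused and fallback outputs for a new (D, V) pair.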