Skip to content

Enable fused SDPA vector kernel for asymmetric Q/V head dims (192, 128)#3637

Open
yohann-bearzi wants to merge 2 commits into
ml-explore:mainfrom
yohann-bearzi:sdpa-asym-headdim
Open

Enable fused SDPA vector kernel for asymmetric Q/V head dims (192, 128)#3637
yohann-bearzi wants to merge 2 commits into
ml-explore:mainfrom
yohann-bearzi:sdpa-asym-headdim

Conversation

@yohann-bearzi

Copy link
Copy Markdown

The sdpa_vector and sdpa_vector_2pass_1 kernels are already templated on a separate value head dim (template <typename T, int D, int V = D>), but no (D, V) pairs with D != V were instantiated, and use_fallback required query_head_dim == value_head_dim. Models with asymmetric head dims therefore fall back to a compiled-graph attention decomposition (multiple GatherAxis dispatches per layer) instead of the fused kernel.

This adds instantiate_sdpa_vector(type, 192, 128) and relaxes use_fallback to allow this specific asymmetric case. All other head dims are unchanged.

Motivation: MiMo-V2.5 uses head_dim=192 for Q/K and v_head_dim=128 for V. On current main it hits the fallback path.

Testing: Verified on MiMo-V2.5 (Metal, M3 Ultra): decode throughput 28.7 → 29.8 tok/s (+4%), generated output (top-1) unchanged. The kernel templates already supported V != D; this only instantiates and enables the existing code path. Incremental build on current main is clean.

Scope: Minimal — 2 files, +5/-3. No new kernel code, just one instantiation + the fallback guard.

The sdpa_vector kernel template already supported a separate value
head_dim (template <typename T, int D, int V = D>), but no (D, V) pairs
with D != V were instantiated, and use_fallback required
query_head_dim == value_head_dim.

MiMo-V2.5 uses head_dim=192 for Q/K and v_head_dim=128 for V, falling
through to a compiled-graph decomposition (multiple GatherAxis dispatches
per attention layer) instead of the fused kernel.

Adds instantiate_sdpa_vector(type, 192, 128) and relaxes use_fallback to
allow this specific asymmetric case. Other head dims remain unchanged.

Verified on MiMo-V2.5: decode 28.7 -> 29.8 tok/s (+4%). Top-1 stable.
The remaining decode bottleneck is MoE routing, not attention.
@yohann-bearzi

Copy link
Copy Markdown
Author

For models using block_fp8 quantization (e.g. MiMo-V2.5, which motivated this asymmetric-head-dim fix), the block_fp8 matmul and MoE kernels are available as a standalone MLX extension: https://github.com/yohann-bearzi/mlx-block-fp8 — it builds against stock upstream MLX (kernels vendored, no fork) and pairs with this SDPA change for full MiMo-V2.5 decode throughput.

The MiMo-V2.5 MLX weights (block_fp8) are on Hugging Face: https://huggingface.co/bearzi/MiMo-V2.5-MLX

@yohann-bearzi

Copy link
Copy Markdown
Author

@zcbenz would you please be able to review?

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you run some benchmarks? The benchmarks/python/sdpa_bench.py can be simply modified. The fused kernel is not guaranteed to be faster than unused one.

@yohann-bearzi

Copy link
Copy Markdown
Author

Benchmarked with a MiMo-shaped variant of benchmarks/python/sdpa_bench.py (asymmetric Dqk=192, Dv=128), added as benchmarks/python/sdpa_bench_mimo.py. It compares the fused vector kernel against the stock fallback by running the same mx.fast.scaled_dot_product_attention call on this branch vs current main — on main the (192, 128) case takes the fallback decomposition. Apple M3 Ultra, bf16, decode (query len 1), MiMo's head config (64 Q heads; 4 KV full-attention / 8 KV SWA), causal mask, 50 iters:

Full-attention layers (64 Q / 4 KV):

KV len fallback (us) fused (us) speedup
256 306 218 1.41x
1024 281 217 1.29x
4096 370 255 1.45x
8192 497 305 1.63x
16384 817 382 2.14x
32768 1392 599 2.32x

SWA layers (64 Q / 8 KV):

KV len fallback (us) fused (us) speedup
256 257 207 1.24x
1024 280 222 1.27x
4096 373 288 1.30x
8192 506 358 1.41x
16384 841 444 1.89x
32768 1455 728 2.00x

The fused kernel is faster across the full sweep — roughly 1.2-1.4x at short context, growing to ~2x past 16K KV (the fallback's separate softmax + gather dispatches scale worse with sequence length). Faster at every shape tested, no regressions, and the result is stable across repeated runs. Output matches the reference decomposition (max abs diff ~1e-4 in bf16; top-1 generation unchanged).

The end-to-end MiMo decode gain is smaller (~3-4%), since attention is only a fraction of decode time — MoE routing dominates, as noted in the commit.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants