Skip to content

perf(moe): fused decode-MoE kernel foundation (#268)#274

Merged
inureyes merged 1 commit into
mainfrom
perf/issue-268-fused-decode-moe
Jun 14, 2026
Merged

perf(moe): fused decode-MoE kernel foundation (#268)#274
inureyes merged 1 commit into
mainfrom
perf/issue-268-fused-decode-moe

Conversation

@inureyes

Copy link
Copy Markdown
Member

Summary

Starts the fused decode-MoE kernel effort (option A from the #268 investigation), as a new PR off latest main. Foundation only: design + trace harness, no model behavior change.

Why not the cheap fusions

The MoE decode gap is GPU-bound idle between small kernels on the expert path (Step 0/1 of the investigation: graph build 0.8 ms/tok vs 20.3 ms GPU, ~16-20% bandwidth). Fusing the combine (moe_weighted_sum) or the router post-processing saves ~1 dispatch/layer ≈ 48/token ≈ 38 µs against a 21 ms token (~0.18%). Negligible. The win has to come from a single Metal kernel over the expert path (gather + 4/6-bit dequant + gate/up/down + swiglu + weighted-sum) so the GPU stops idling between the gather_qmm calls.

What's here

  • docs/benchmark_results/fused-moe-decode-kernel-design.md — the kernel design built on the fast::metal_kernel JIT path (same as ssm_update_kernel): inputs/outputs/template args, the seq_len == 1 dispatch guard with SwitchGLU fallback, the in-kernel affine dequant-GEMV as the hard part, risks (4/6-bit + mixed bits), the validation gate (RMS < 5e-3, greedy parity, decode bench, trace check), and a one-PR-per-step roadmap.
  • scripts/capture_moe_decode_trace.sh — capture one warm MoE decode token as a Metal trace (gputrace or xctrace) to localize the expert-path GPU idle before/after.
  • Linked both from docs/benchmarks.md.

Scope

Documentation + tooling. Issue #268 stays open; the fused kernel lands next, trace-directed per the roadmap. Builds on the merged investigation report.

Starts the fused decode-MoE kernel work (option A from the #268 investigation). The MoE decode gap is GPU-bound idle between small kernels on the expert path, so the real win needs a single Metal kernel for the single-token expert computation (gather + 4/6-bit dequant + gate/up/down + swiglu + weighted-sum). The cheap small fusions do not help: fusing the combine or router post-processing saves ~1 dispatch/layer (~0.18% by the dispatch-count math), and per-dispatch overhead is not the bottleneck (graph build is 0.8 ms/tok vs 20.3 ms GPU).

This PR lays the groundwork the kernel needs:
- docs/benchmark_results/fused-moe-decode-kernel-design.md: the kernel design (fast::metal_kernel based, same JIT path as ssm_update_kernel), the seq_len==1 dispatch guard, the per-step validation gate (RMS < 5e-3, greedy parity, decode bench, trace check), and a one-PR-per-step roadmap.
- scripts/capture_moe_decode_trace.sh: capture one warm MoE decode token as a Metal trace (gputrace or xctrace) to localize the inter-kernel GPU idle before and after the kernel lands.

No model behavior change. The fused kernel itself lands in the next PR, trace-directed.
@inureyes inureyes added type:performance Performance improvements area:core mlxcel-core: MLX FFI, primitives, KV cache, layers area:models Model architectures, weights, loading, metadata status:done Completed labels Jun 13, 2026
@inureyes inureyes merged commit ca3b00e into main Jun 14, 2026
5 checks passed
@inureyes inureyes deleted the perf/issue-268-fused-decode-moe branch June 14, 2026 01:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:core mlxcel-core: MLX FFI, primitives, KV cache, layers area:models Model architectures, weights, loading, metadata status:done Completed type:performance Performance improvements

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant