
DGX Spark · part 33

[Benchmark] NVFP4 W4A4 beats FP8 on a DGX Spark MoE: 67 vs 52 tok/s once CUDA graphs fire


TL;DR

On a DGX Spark (GB10, SM121, 273 GB/s), NVFP4 W4A4 beats FP8 on the MoE daily — Qwen3.6-35B-A3B abliterated — at single-stream decode: FP8 52.0 tok/s, NVFP4 W4A4 66.9 tok/s (+29%), while using 16GB less memory (22GB vs 38GB of weights). The whole win came from one flag: --enforce-eager was holding W4A4 at 23 tok/s; removing it (CUDA graphs on) jumped it to 66.9. Part 32 found CUDA graphs barely help a dense model — true, but a MoE fires many small per-expert kernels per layer, so capturing them is decisive. MTP made it slower (substitute, not complement). Now the daily, at 256K context.

Plain-Language Version: the 4-bit model was always fast — I just had the brakes on

A few weeks ago I concluded that 4-bit (NVFP4) weights on this box were a memory trick, not a speed one — and on one specific path I even saw the 4-bit version run slower than the 8-bit one. That conclusion was wrong, and the reason is almost embarrassing: I had a flag turned on (--enforce-eager) that disables an optimization called CUDA graphs, and I'd stopped re-testing it after the underlying compiler quietly got fixed.

The model I run all day is a "mixture of experts": instead of one big calculation per layer, it holds many small expert sub-networks and runs only a few of them for each token. Tiny calculations have a fixed per-call cost, and that cost was eating the whole speed advantage. CUDA graphs bundle all those tiny calls into one pre-recorded sequence, so the per-call cost all but vanishes. The moment I let that happen, the 4-bit model went from 23 to 67 tokens per second (a token is roughly a word piece), faster than the 8-bit version while holding 22GB of weights instead of 38GB.

The lesson isn't about 4-bit. It's that "X is slow" has an expiry date. The hardware, the driver, the compiler all moved under me, and I kept repeating an old measurement instead of re-running it.


Preface

In April I changed the oil on a DGX Spark and found the engine computer was from a different car — Part 1 (SM121 is not SM120). Part 19 then declared NVFP4 "a trap, FP8 wins by 32%." Part 32 walked half of that back: on a pure dense model NVFP4 is ~1.5× FP8, but the win is bandwidth (fewer bytes), not the FP4 cores.

For grounding: everything here runs on one DGX Spark — NVIDIA's GB10 desktop box, 128GB unified memory at 273 GB/s — which serves my two agent siblings (hikari and kiriha) all day, alongside ComfyUI for image and video. That shared-memory, single-box constraint is the whole reason the numbers below matter.

This is the other half. The model I actually serve to hikari and kiriha is a hybrid MoE, not a dense toy. On that model, NVFP4 W4A4 now wins outright — and the reason it didn't before was a flag I never re-examined.

66.9 vs 52.0 tok/s: W4A4 NVFP4 beats FP8 on the MoE, paired same-harness

Qwen3.6-35B-A3B abliterated, single-stream decode, kv-cache fp8, warm, same vllm image (vllm-node-b12x:latest, flashinfer 0.6.11 + cutlass-dsl 4.5.2). I only swapped the model directory and the MoE backend:

Format         MoE backend          tok/s (median)   Weights resident
FP8 dynamic    triton (default)     52.0             38 GB
NVFP4 W4A4     flashinfer_cutlass   66.9             22 GB

That is +28.7% throughput and −16GB at the same time. For a bandwidth-bound box, both moves come from the same place: 4-bit weights are half the bytes of FP8, and single-stream decode is gated on how fast you can stream weights from memory.
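The arithmetic behind those two deltas, plus a rough bandwidth ceiling, can be sketched as follows. The tok/s, weight sizes, and 273 GB/s come from this post; the figure of roughly 3 billion weights read per token is an assumption taken from the "A3B" in the model name, and the ceiling ignores scale factors, KV-cache reads, attention, and routing, so treat it as a direction indicator rather than a prediction.

```python
# Measured values copied from the table above.
FP8_TOK_S, NVFP4_TOK_S = 52.0, 66.9
FP8_WEIGHTS_GB, NVFP4_WEIGHTS_GB = 38, 22
BANDWIDTH_GB_S = 273.0  # GB10 memory bandwidth quoted in this post

# ASSUMPTION (not measured here): about 3e9 weights are read per decoded
# token, taken from the "A3B" in the model name.
ACTIVE_PARAMS_BILLIONS = 3.0


def speedup_pct(baseline_tok_s, new_tok_s):
    """Relative throughput change in percent."""
    return (new_tok_s / baseline_tok_s - 1.0) * 100.0


def decode_ceiling_tok_s(bandwidth_gb_s, bits_per_weight):
    """Upper bound on single-stream tok/s if weight streaming were the only cost."""
    gb_per_token = ACTIVE_PARAMS_BILLIONS * bits_per_weight / 8.0
    return bandwidth_gb_s / gb_per_token


print(f"measured speedup: {speedup_pct(FP8_TOK_S, NVFP4_TOK_S):+.1f}%")
print(f"weights freed: {FP8_WEIGHTS_GB - NVFP4_WEIGHTS_GB} GB")
for name, bits in (("FP8", 8.0), ("NVFP4", 4.0)):
    ceiling = decode_ceiling_tok_s(BANDWIDTH_GB_S, bits)
    print(f"{name} idealized ceiling: {ceiling:.0f} tok/s")
```

Both measured rates sit well under their idealized ceilings, and the measured gain is far short of the 2x the byte count alone would suggest. That would be consistent with per-token costs that do not shrink with the weight format.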

One honest caveat: the kernels aren't symmetric. The NVFP4 path runs vLLM's autotuned flashinfer_cutlass MoE; the FP8 path falls back to triton, and vLLM logged the GB10 has no tuned fp8_w8a8 config (Performance might be sub-optimal). So part of the 29% is "tuned cutlass vs untuned triton," not pure format. But that is the kernel each format actually uses on this box today, and the 4-bit bandwidth argument points the same direction — so the deployment answer is unambiguous even if the clean format-only delta is smaller.
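For anyone reproducing the comparison, here is a minimal stand-in for a paired harness. It is not the harness used for the numbers above, which this post does not show. It assumes the server exposes vLLM's OpenAI-compatible /v1/completions route and that the response carries usage.completion_tokens; the base URL, prompt, and run counts are placeholders.

```python
import json
import statistics
import time
import urllib.request


def tok_per_s(completion_tokens, seconds):
    """Generated tokens divided by wall-clock seconds."""
    return completion_tokens / seconds


def time_one_request(base_url, model, prompt, max_tokens=256):
    """Send one non-streaming completion and return its end-to-end tok/s."""
    body = json.dumps({
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }).encode("utf-8")
    request = urllib.request.Request(
        f"{base_url}/v1/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    start = time.perf_counter()
    with urllib.request.urlopen(request) as response:
        payload = json.load(response)
    elapsed = time.perf_counter() - start
    return tok_per_s(payload["usage"]["completion_tokens"], elapsed)


def median_tok_s(base_url, model, prompt, runs=5, warmup=1):
    """Discard warmup requests, then report the median of the timed runs."""
    for _ in range(warmup):
        time_one_request(base_url, model, prompt)
    samples = [time_one_request(base_url, model, prompt) for _ in range(runs)]
    return statistics.median(samples)

# Example call (hypothetical host; model name as served in this post):
# median_tok_s("http://localhost:8000", "qwen36-abliterated", "Explain CUDA graphs.")
```

Because the timer wraps the whole request, prefill time is included, so this understates the pure decode rate slightly; a short prompt keeps that error small. Running the same script against each build keeps the comparison paired.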

The whole gap was one flag: 23 → 66.9 tok/s by dropping --enforce-eager

Here is the part that stings. Same model, same image, same harness — the only change is --enforce-eager:

# eager (no CUDA graphs)
vllm serve /models/qwen36-35b-abliterated-nvfp4-w4a4 \
  --moe-backend flashinfer_cutlass --enforce-eager ...   # 23.4 tok/s

# CUDA graphs on (drop the flag)
vllm serve /models/qwen36-35b-abliterated-nvfp4-w4a4 \
  --moe-backend flashinfer_cutlass ...                    # 66.9 tok/s

--enforce-eager was costing 65% of the throughput. I had it on for one reason: months ago, the W4A4 compiled path crashed on SM121 with bad PTX, and eager was the only way to serve it at all. I wrote that down as "W4A4 = eager-only on GB10" and moved on.

Why CUDA graphs are decisive for a MoE but a rounding error for a dense model

Part 32 measured this on dense Qwen3-8B and concluded CUDA graphs "mostly remove kernel-launch overhead that single-stream decode doesn't expose." On dense, that's correct — W4A4 was 39.62 eager vs 38.59 with graphs, a wash:

Model                   W4A4 eager (tok/s)   W4A4 CUDA graph (tok/s)   Delta
Qwen3-8B (dense)        39.62                38.59                     ~0
Qwen3.6-35B-A3B (MoE)   23.4                 66.9                      +186%

The difference is kernel count. A dense layer is essentially one big matmul: a single large kernel launch, whose fixed dispatch cost is negligible next to the matmul itself. A MoE layer with top-k routing fires many small per-expert kernels every forward pass. Each one carries the same fixed launch overhead, and at batch=1 the kernels are tiny — so the launch overhead, not the math, is what you're paying for. CUDA graphs record the whole sequence once and replay it as a single submission, and the overhead disappears.
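One way to see the size of that overhead is to convert the measured tok/s into milliseconds per token and subtract. The sketch below uses only the four numbers from the table above. Attributing the whole difference to launch overhead is an interpretation, not something measured directly, since eager mode also carries host-side framework cost.

```python
# (eager tok/s, CUDA-graph tok/s) copied from the table above.
RUNS = {
    "Qwen3-8B (dense)": (39.62, 38.59),
    "Qwen3.6-35B-A3B (MoE)": (23.4, 66.9),
}


def ms_per_token(tok_s):
    """Per-token latency in milliseconds for a given decode rate."""
    return 1000.0 / tok_s


def ms_saved_by_graphs(eager_tok_s, graph_tok_s):
    """Per-token time removed by CUDA graphs (negative means graphs were slower)."""
    return ms_per_token(eager_tok_s) - ms_per_token(graph_tok_s)


for model, (eager, graphs) in RUNS.items():
    saved = ms_saved_by_graphs(eager, graphs)
    print(f"{model}: {ms_per_token(eager):.1f} ms -> "
          f"{ms_per_token(graphs):.1f} ms per token ({saved:+.1f} ms saved)")
```

On the MoE, graphs remove close to 28 ms from a token that takes about 43 ms in eager mode; on the dense model the difference is under 1 ms and slightly in eager's favor.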

So Part 32 wasn't wrong; it just generalized a dense result to a MoE where the kernel arithmetic is completely different. "CUDA graphs don't matter for single-stream decode" is true exactly when there's one kernel and false the moment there are forty.

MTP made it slower: NVFP4 and speculative decoding are substitutes, not complements

The W4A4 model ships native MTP heads, so I swept num_speculative_tokens 1–4 on the live daily, paired EN/ZH:

Config              ZH tok/s   EN tok/s   Acceptance
baseline (no MTP)   66.7       67.1       -
MTP n=1             65.8       67.0       60%
MTP n=2             56.2       56.3       47%
MTP n=3             52.5       54.6       37%
MTP n=4             41.0       39.4       29%

Monotonically worse. The mechanism is the interesting part: MTP and NVFP4 are both trying to win the same resource. Speculative decoding amortizes one weight read across several output tokens; NVFP4 makes that weight read half the size. On a bandwidth-bound box, once you've already halved the weight bytes with 4-bit, there isn't much traffic left for MTP to amortize — and its draft pass plus falling acceptance turn the trade net-negative. The same model on an FP8 baseline was +9% earlier; moving the baseline to NVFP4 tipped it negative. Consistent.
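The trade can be made concrete from the sweep itself. The sketch below assumes the acceptance column is the fraction of drafted tokens accepted, so one verify step yields 1 + n × acceptance tokens on average; that reading of the column is an assumption, not something the sweep output confirms. Under it, the implied cost of one step follows from the measured ZH tok/s.

```python
BASELINE_TOK_S = 66.7  # ZH baseline, no MTP

# (num_speculative_tokens, ZH tok/s, acceptance) copied from the table above.
SWEEP = [(1, 65.8, 0.60), (2, 56.2, 0.47), (3, 52.5, 0.37), (4, 41.0, 0.29)]


def tokens_per_step(n_draft, acceptance):
    """Mean tokens emitted per verify step: one guaranteed token plus accepted drafts.

    ASSUMPTION: acceptance = accepted draft tokens / proposed draft tokens.
    """
    return 1.0 + n_draft * acceptance


def step_ms(tok_s, tokens):
    """Implied wall-clock cost of one step, given throughput and tokens per step."""
    return 1000.0 * tokens / tok_s


baseline_step = step_ms(BASELINE_TOK_S, 1.0)
for n, tok_s, acceptance in SWEEP:
    tokens = tokens_per_step(n, acceptance)
    cost = step_ms(tok_s, tokens)
    print(f"n={n}: {tokens:.2f} tokens/step, "
          f"step cost {cost / baseline_step:.2f}x baseline")
```

Under that assumption, tokens per step flatten out a little above 2 while the implied step cost keeps climbing, which is the net-negative trade described above.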

This is not "MTP is useless on GB10." On Gemma 4 26B-A4B FP8 — pure attention, FP8 baseline — MTP gave +33% (39 → 52 tok/s, ~70% acceptance), and abliteration cost nothing. Qwen3.6 fails for three stacked reasons: it's a hybrid (GDN recurrent state is expensive to rewind for speculative verification, unlike a cheap KV-cache replay), its baseline is already NVFP4 (the substitution above), and abliteration dropped acceptance from ~80% to 60%. Stack all three and the marginal-positive becomes a clear negative.

The fix was stock cutlass-dsl 4.5.x — and the old breakage was self-inflicted

What changed since "W4A4 = eager-only"? The toolchain. cutlass-dsl 4.5.0's release notes call out "Block Scaled MMA SM120 now works on Spark," and 4.5.1 cleans up the PTX path tracked in CUTLASS #3227. The stock image I serve on (flashinfer-python 0.6.11, nvidia-cutlass-dsl 4.5.2, with vLLM PR #40082 merged into v0.22.0) compiles the sm_121a block-scaled GEMM and captures CUDA graphs without a single bad-PTX error. The arch-check pull request CUTLASS #3082 is still open, but it no longer blocks this path in practice.

The uncomfortable bit: my "eager-only" breakage was self-inflicted. I'd been running a mod that downgraded cutlass-dsl to 4.4.2 to force an SM121 code path — and that downgrade is exactly what broke the compiled W4A4 kernel. The fix was never the mod; it was deleting the mod and running stock 4.5.x. I carried the stale "NVFP4 is a wash / eager-only" conclusion for weeks because I never re-ran the test after the compiler caught up.

What Was Gained

What took the most time: the flag I'd stopped questioning. The 23→67 jump was a one-line change I could have found weeks earlier. The cost wasn't compute; it was treating a months-old measurement as a permanent fact. Every minute spent "explaining why NVFP4 is a wash" was spent defending a number I should have re-run.

A portable diagnostic: separate kernel count from bandwidth when reasoning about CUDA graphs. "Single-stream decode is bandwidth-bound, so launch overhead doesn't matter" is only true when there's one kernel per layer. Count the kernels first: dense = one big GEMM (graphs ≈ no-op), MoE = many tiny expert GEMMs (graphs decisive). The same logic flags why MTP and NVFP4 compete: both target memory traffic per token, so they don't stack.

General principle: "X is slow" has an expiry date; the bug is not re-testing, not the tool. After any toolchain bump (cutlass / flashinfer / vLLM / driver / OS), every "X doesn't work" or "X is slow" conclusion is suspect until re-run. And when an old workaround turns out to be the thing breaking you, that's on the person who kept it, not the workaround.

Deployed: W4A4 NVFP4 is the daily, 256K context, MTP off

The endpoint hikari and kiriha point at (qwen36-abliterated, port 8000) is now the W4A4 NVFP4 build, MTP off. Because the freed memory shows up as KV headroom, I bumped context to the model's native max — Qwen3.6's max_position_embeddings is 262144 with rope_theta 1e7, so 256K needs no YaRN and costs no quality. At 256K the KV pool still holds 13× concurrent full-context requests, which is plenty for two siblings:

vllm serve /models/qwen36-35b-abliterated-nvfp4-w4a4 \
  --served-model-name qwen36-abliterated \
  --moe-backend flashinfer_cutlass \
  --kv-cache-dtype fp8 \
  --gpu-memory-utilization 0.50 \
  --max-model-len 262144 \
  --max-num-seqs 4 --max-num-batched-tokens 8192 \
  --reasoning-parser qwen3 --enable-auto-tool-choice --tool-call-parser qwen3_coder \
  --enable-prefix-caching --trust-remote-code

Checklist if you're chasing the same path on a GB10:

  1. Run stock cutlass-dsl 4.5.x (flashinfer 0.6.11+). Delete any mod that pins cutlass-dsl to 4.4.x — that's the thing forcing eager.
  2. The flashinfer_cutlass NVFP4 MoE backend wants W4A4 (activations quantized too) — it rejected my W4A16 weight-only build. (vLLM v0.22.0 does ship a separate Marlin W4A16 NVFP4 MoE path; I just didn't use it.) Serve with --moe-backend flashinfer_cutlass.
  3. Do not set --enforce-eager. On a MoE it costs ~65% of throughput.
  4. --max-num-batched-tokens ≥ 2096 (hybrid Mamba cache block alignment) or the engine asserts on startup.
  5. Skip MTP on an NVFP4 hybrid — it's a substitute for the bandwidth you already won.
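Items 1 to 4 of the checklist can be turned into a small preflight check. This is a sketch, not part of vLLM: the function names are hypothetical, the thresholds are the ones stated in the list above, and it only inspects an argument list and a version string that you pass in.

```python
import re


def parse_version(text):
    """Turn '4.5.2' (or '4.5.2rc1') into a tuple of ints for comparison."""
    numbers = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    return tuple(numbers)


def flag_value(argv, flag):
    """Return the token following `flag` in argv, or None if absent."""
    if flag in argv:
        index = argv.index(flag)
        if index + 1 < len(argv):
            return argv[index + 1]
    return None


def preflight(argv, cutlass_dsl_version):
    """Return a list of checklist violations for a planned `vllm serve` command."""
    problems = []
    if parse_version(cutlass_dsl_version) < (4, 5):
        problems.append("cutlass-dsl is older than 4.5.x")
    if flag_value(argv, "--moe-backend") != "flashinfer_cutlass":
        problems.append("--moe-backend is not flashinfer_cutlass")
    if "--enforce-eager" in argv:
        problems.append("--enforce-eager is set (CUDA graphs disabled)")
    batched = flag_value(argv, "--max-num-batched-tokens")
    if batched is not None and int(batched) < 2096:
        problems.append("--max-num-batched-tokens is below 2096")
    return problems


planned = ["vllm", "serve", "/models/qwen36-35b-abliterated-nvfp4-w4a4",
           "--moe-backend", "flashinfer_cutlass", "--enforce-eager"]
for problem in preflight(planned, "4.4.2"):
    print("WARN:", problem)
```

To check a real environment, the version string can be read with importlib.metadata.version("nvidia-cutlass-dsl"), using the package name given earlier in this post.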

FAQ

Is NVFP4 faster than FP8 on a DGX Spark (GB10) for a MoE model?
Yes, once CUDA graphs are on. On Qwen3.6-35B-A3B (abliterated) I measured FP8 at 52.0 tok/s and NVFP4 W4A4 at 66.9 tok/s single-stream, about 29% faster, while also using 16GB less memory (22GB vs 38GB of weights). With --enforce-eager the same W4A4 only hit 23 tok/s.

Why did CUDA graphs help the MoE so much but not the dense model in Part 32?
Kernel count. A dense layer is one big matmul, so eager mode's per-kernel launch overhead is negligible and CUDA graphs barely move it (Part 32 saw 39.62 vs 38.59 tok/s). A MoE layer with top-k routing fires many small per-expert kernels, so eager launch overhead dominates; capturing it into a CUDA graph took W4A4 from 23 to 67 tok/s.

Should I enable MTP / speculative decoding on an NVFP4 MoE on GB10?
Not on this model. Sweeping num_speculative_tokens 1–4 only made it slower (66.7 → 65.8 → 56.2 → 52.5 → 41.0 tok/s, acceptance falling 60% → 29%). NVFP4 and MTP both save the same thing, memory traffic per output token, so once the weights are 4-bit, MTP has little left to save. On a pure-attention FP8 model like Gemma 4 26B-A4B, MTP still gives +33%.

What made W4A4 NVFP4 work on SM121 when it didn't before?
cutlass-dsl 4.5.x. The stock vllm image (flashinfer 0.6.11 + cutlass-dsl 4.5.2) compiles the sm_121a block-scaled path and captures CUDA graphs cleanly. My earlier "eager-only" result came from a mod that downgraded cutlass-dsl to 4.4.2; I'd kept that stale conclusion for weeks without re-testing after the toolchain moved on.