~/blog/dgx-spark-nemotron-3-nano-w4a16-74-toks

DGX Spark · part 25

[vLLM] DGX Spark 跑 Nemotron 3 Nano NVFP4:74.75 tok/s,比公開值快 11.5%

cat --toc

TL;DR

Nemotron 3 Nano 30B-A3B 用 W4A16 NVFP4 跑 single-stream 74.75 tok/s,比 NVIDIA 論壇公開的 67 tok/s(eugr 量的)快 11.5%,比我自己 Part 20 的 47.6 tok/s FP8 hack 快 57%。能再推上去靠兩件事:4 層軟體 patch 把 b12x 在 SM121 跑通,再來是挑對 model — Nemotron 3 雖然也是 hybrid,但不像 Qwen 3.6 在 GDN 偷藏一堆 BF16 把單流壓在 44 tok/s。現在卡的不是 kernel,是 273 GB/s LPDDR5X 配 3.5B active 的頻寬天花板。

白話版:4-bit 數字怎麼能跑得比 8-bit 快

DGX Spark 是 NVIDIA 賣的小型 AI 桌機,台幣 9 萬出頭。它用一顆叫 GB10 的晶片,最大瓶頸是記憶體頻寬只有 273 GB/s — 大概是 H100 的 1/12。本機跑 chat model 時,頻寬就是一切:weight 多快搬出記憶體,每秒就能吐多少 token。

NVFP4 把每個 weight 壓到 4 bits、剛好是 FP8 的一半。理論上速度該翻倍。但 DGX Spark 是 SM121 這顆特殊晶片,早期 vLLM 的 FP4 路徑根本沒做好 — 我十天前 Part 19 才寫過:NVFP4 比 FP8 慢 32%。

今天同一台機器,30B 參數的 model 用 NVFP4 拉到 74.75 tok/s。比我自己之前 FP8 hack 的紀錄快、比網路上看得到的最佳值快、離 273 GB/s 頻寬搬 4-bit weight 的物理上限只差一點。這篇就是收據:補了什麼 patch、挑了哪顆 model、跑了什麼 bench script、撞了哪些牆。


前言

之前那個坑是真的,但不是死路 — 是我該停止怪硬體、回去讀 vLLM 原始碼。

Part 19 我宣告 NVFP4 是個坑。Part 20 拿 47.6 tok/s 的 Triton FP8 dequant hack 當 workaround。這篇 Part 25 是大翻盤。問題的本體不是 NVFP4,是三件事疊起來:(a) hybrid SSM/Mamba 架構的 model 在 decoder stack 偷塞一堆 BF16 layer;(b) vLLM 跑起來預期 cutlass-dsl 已經 patch 過 sm_121a,但根本沒;(c) 單流場景挑錯 quant variant。

整篇就在講怎麼從 44 tok/s 一路推到 74.75。Cherry-pick vLLM PR #40082 是第一步,改 cutlass-dsl 跟 FlashInfer 是第二步,挑 Nemotron 3 Nano + W4A16 是第三步。Bench script 跟踩雷時間軸放後面。


公開最佳值是 67 tok/s — 同硬體不同 image,我們拉到 74.75

來源ImageModelSingle-stream tok/s
eugr (論壇,Dec 2025)avarok/vllm-dgx-spark:v11cybermotaz/nemotron3-nano-nvfp4-w4a1667
Spark Arena 排行榜(各家)Nemotron-3-Nano-30B-A3B56.11
這次(May 2026)vllm-node-tf6-b12x(自建)跟 eugr 同一顆74.75

11.5% 是個值得拆開來講的差距。avarok image 是 Jan 2026 上架,當時 vLLM 還在 0.14 — b12x backend 那條 PR #40082 到 2026-05-01 都還掛在 open。我自己 build 的 image 走 vLLM 0.20.1,PR 直接 cherry-pick 進 local fork,再補 4 層 SM121 patch(upstream 還沒收,要自己 patch)。


4 層 patch 把 b12x 在 SM121 跑起來

CUTLASS PR #3082 到 2026 年 5 月還掛在 open。CUTLASS DSL 的 warp/mma.py 寫死只認 sm_120a,但 FP4 MMA 要的硬體 feature sm_121a 也有。沒 patch 的話,b12x kernel 不是拒收 dispatch、就是吐出壞掉的 PTX。

第 1 層 — 裝 CUDA 13 的 runtime lib(pip 預設只給 libs-base,CUDA 13 那包要另外裝):

pip install --no-deps nvidia-cutlass-dsl-libs-cu13==4.4.2

第 2 層 — 改 cutlass-dsl 的 warp/mma.py,admissible list 跟 base equality check 兩邊都要改:

# /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/nvgpu/warp/mma.py

# 原本:
admissible_archs = ["sm_120a"]
if not arch == Arch.sm_120a:
    raise OpError(...)

# 改成:
admissible_archs = ["sm_120a", "sm_121a"]
if arch not in (Arch.sm_120a, Arch.sm_121a):
    raise OpError(...)

第 3 層 — 改 FlashInfer 的 dense_blockscaled_gemm_sm120.py。kernel 裡兩個地方寫死 sm_version="sm_120",b12x dispatcher 一看到 SM121 就拒收:

# /usr/local/lib/python3.12/dist-packages/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm120.py

# runtime check 放寬到 sm_121 系列
if sm_version not in ("sm_120", "sm_120a", "sm_121", "sm_121a"):
    raise ValueError(...)

# 不寫死 sm_120,改讀 env
sm_version=__import__("os").environ.get("CUTE_DSL_ARCH", "sm_120"),

第 4 層 — recipe 補 env var:

env:
  CUTE_DSL_ARCH: sm_121a
  VLLM_NVFP4_GEMM_BACKEND: flashinfer-b12x
  VLLM_FLASHINFER_MOE_BACKEND: latency

四層都到位,vLLM log 才會吐這行:

Using FlashInferB12xNvFp4LinearKernel for NVFP4 GEMM

完整 mod 我包成 mods/sm121-b12x-full-enable/run.sh,每次起 container 跑一次就好。


Qwen 3.6 卡在 44 tok/s — 不是 kernel 慢,是模型本身

4 層 patch 都打完,b12x 跑得乾乾淨淨。但 Qwen 3.6-35B-A3B-NVFP4 single-stream 還是只到 44 tok/s。我寫了個 static audit — model load 完之後走一次 model.named_modules(),每層印出 type(quant_method).__name__,答案一秒就出來。

Layer 類型數量Quant method總 weight
linear_attn.in_proj_qkvz30UnquantizedLinearMethod1.44 GB BF16
linear_attn.out_proj30UnquantizedLinearMethod480 MB BF16
mlp.gate(router)40UnquantizedLinearMethod40 MB BF16
lm_head1UnquantizedEmbeddingMethod970 MB BF16

換句話說,每跑一個 single-stream token,光 BF16 weight 就要從記憶體讀 2.9 GB,再加 active 3B MoE expert 的 1.7 GB NVFP4。273 GB/s 算下來,這個 hybrid Qwen 3.5-MoE 架構的理論天花板大概 60 tok/s — 實測 44 tok/s 落在 cuBLAS GEMV batch=1 慢的合理範圍。

linear_attnQwen3-Next 架構裡的 GDN/SSM mixer。每層 in_proj_qkvz 是個 [12288, 2048] 投影,LLM Compressor 的 quant recipe 為了 SSM 精度故意 skip 不量化。在頻寬綁住的機器上,這個 skip 直接決定 single-stream throughput。b12x kernel 從頭到尾沒做錯。


換 Nemotron 3 Nano:同 A3B 大小、沒 GDN tax

Nemotron 3 Nano-30B-A3B 也是 hybrid SSM/Mamba(Nemotron-H base),但 quant 排除清單短很多。cybermotaz/nemotron3-nano-nvfp4-w4a16 這個社群 quant 用 W4A16 — weight FP4、activation 留 BF16。少了每 step 的 activation FP4 量化開銷,single-stream 跑起來就快。

Bench 結果:

ModelVariantSingle-streamc=16 aggregate
Qwen 3.6-35B-A3B-NVFP4hybrid44.32363
Gemma 4-31B-IT-NVFP4dense6.85105
Nemotron 3 NanoW4A1674.75400

(多模態的 Nemotron 3 Nano Omni 沒放進來比 — b12x patched stack 餵給它 modelopt_mixed dispatch 會吐 NaN logits,要換 image 才能跑。多模態那條線在 Part 26 另開講。)

兩個場景挑不一樣的 quant:

  • 單流 chat 選 W4A16 — 比 W4A4 快 28%。沒有 activation quant 開銷、每 token 算的東西少。
  • 多人 serving 選 W4A4 — c=16 aggregate 比 W4A16 快 96%。activation 也壓掉、cache 用得兇,多並發吃到的 batch 大。

兩個目標不一樣。

Bench script:

# 200 token deterministic decode、threading 控並發
import json, time, urllib.request, threading

def req(out):
    data = json.dumps({
        "model": "nemotron-w4a16",
        "messages": [{"role": "user", "content": PROMPT}],
        "max_tokens": 200, "temperature": 0.0,
    }).encode()
    t0 = time.perf_counter()
    r = urllib.request.urlopen(urllib.request.Request(
        URL, data=data, headers={"Content-Type": "application/json"}
    ), timeout=600).read()
    out.append((time.perf_counter() - t0,
                json.loads(r)["usage"]["completion_tokens"]))

# warmup 一次再 sweep concurrency 1-16

輸出品質我隨便丟了三題:「台灣首都」、「英翻中」、「Python Fibonacci memoization」。沒 garbage、token 沒掉。<think> reasoning block 直接寫在 content 裡(我沒裝專屬 reasoning parser),實用上 OK。


MTP 走的死路:5 次失敗、0 次成功

Multi-Token Prediction(MTP)是突破頻寬天花板的標準招 — 一次 forward 預測 2-3 個 token,對的接、錯的重算。NVIDIA 自己的 Spark Deployment Guide 就明寫 Nemotron 3 該用 --speculative_config '{"method":"mtp","num_speculative_tokens":3}'

試了 5 次,全掛:

  1. Qwen 3.6 + b12x MoE backendb12x 不接 unquantized MoE,draft model 預設 BF16,dispatcher 一看就拒。啟動就死。
  2. Qwen 3.6 + flashinfer_cutlass + gpu_memory_utilization=0.85。vLLM 的 CUDA-graph mem profiler 回傳 -35.63 GiB(負值 bug,沒 clamp),KV cache budget 被算成 112 GB,128 GB 統一記憶體直接爆,cudagraph capture 中 SIGKILL。
  3. 同上、util 降到 0.70。一樣 OOM、一樣 SIGKILL。
  4. 同上、util 0.65 + max_num_seqs 16。還是死。
  5. Nemotron 3 Nano W4A16 + mtp method。vLLM 0.20 要 model config 有 num_nextn_predict_layers > 0 才把 nemotron_h 升 mtp。cybermotaz 那個社群 quant 把欄位砍了,config 階段就吐 NotImplementedError: Unsupported speculative method: 'mtp'

退一步用 ngram speculative,跑 essay 內容反而慢 25%(74.75 → 55.70)。文章沒重複 pattern → prompt-lookup 幾乎不命中 → 算了 draft token 全被 reject、白花時間。

我請 Codex 做了一輪深度 review,結論是要再快只剩兩條路:(a) 再 cherry-pick vLLM PR #35947(E2M1 軟體 fallback,throughput 提升不確定),或 (b) 找 Nemotron-H 專屬的 EAGLE3 head(公開資源目前沒有)。兩條都不是 2 小時搞得定的。

「MTP 是下一個 speedup」變成「MTP 是下一個 6 個月研究案」。


收穫

最花時間的地方

MTP 的死路。5 種失敗長得幾乎一樣,看起來都像上一次的鬼。OOM 把系統整個凍住那次最痛 — 我得自己去按重開機按鈕才救得回來。vLLM 的 CUDA-graph mem profiler 沒 clamp 負值,回傳 -35.63 GiB 配高 gpu_memory_utilization 會默默分配超過統一記憶體上限,連 kernel 一起拖死。

真正推進 throughput 的是一連串 static introspection:dump model.named_modules() 看哪些投影層是 UnquantizedLinearMethod、讀 vllm/config/speculative.py 才知道 mtp 卡在 model config 某個欄位(社群 quant 常砍掉)、grep FlashInfer 原始碼看到 b12x kernel 把 sm_120 寫死三處。

可搬走的診斷方法

1. Static module audit 比 forward hook 好用。 model 跑得慢、不知道哪層拖累,load_model() 之後走 model.named_modules(),按 type(m.quant_method).__name__ 分類,BF16 fallback 的層一秒就看出來。不用 NVTX、不用 profiler trace、不用改框架。

2. Dispatch log 寫的是 intent,不是 behavior。 Using FlashInferB12xNvFp4LinearKernel for NVFP4 GEMM 這行從第一天就有,但 kernel 裡 sm_version="sm_120" 寫死三處,要一個一個 monkey-patch。Log 確認意圖、原始碼才確認行為。

3. CUDA-graph 記憶體 bug 會偽裝成 OOM。 vLLM gpu_worker.py:448 如果 log 跑出 Estimated CUDA graph memory: -X GiB,profiler 已經壞了、KV-cache budget 也算錯。直接 pin --kv-cache-memory-bytes,別靠 --gpu-memory-utilization 自動算。

通用原則

記憶體頻寬是物理事實。273 GB/s LPDDR5X 是每個 token 要讀的所有東西的總預算 — 包括量化後的 weight、加 lm_head、加 checkpoint 沒量化掉的 BF16 殘留層。確切上限要看那些非量化部分多大(Nemotron 3 W4A16 我筆算下來大概落在 80 出頭 tok/s),74.75 已經爬到差不多盡頭了,再往上必須改 bytes-per-token 那條曲線 — speculative decoding、active params 更少、或者更快的記憶體。Kernel 怎麼調都已經沒用。


收尾

想在 DGX Spark 上把 NVFP4 single-stream 拉到上限,今天的配方:

  1. Cherry-pick vLLM PR #40082 進 vLLM main。
  2. 補 4 層 SM121 patch(libs-cu13cutlass-dsl mma.py、FlashInfer dense kernel、env vars)。
  3. 挑非 hybrid model。cybermotaz/nemotron3-nano-nvfp4-w4a16 是我目前測到 single-stream 最快的。
  4. Chat 用 W4A16、serving 用 W4A4。兩個目標不一樣。
  5. MTP 先不要碰。vLLM 跟 checkpoint 的相容矩陣還沒打通。

完整 mod 跟 recipe 在 spark-vllm-docker fork 的 mods/sm121-b12x-full-enable/run.shrecipes/nemotron-w4a16.yaml。從乾淨 image 重 build,30 分鐘可以複現 74.75。

系列其他文章:

常見問題

Nemotron 3 Nano NVFP4 在 DGX Spark 多快?
single-stream 74.75 tok/s。model 是 cybermotaz/nemotron3-nano-nvfp4-w4a16(W4A16),跑在我自己 build 的 vLLM image 上。比 NVIDIA 論壇公開的 67 tok/s 快 11.5%。這個數字已經卡在 GB10 的頻寬天花板區段了(3.5B active 配 273 GB/s LPDDR5X),再去調 kernel 也擠不太出來。
DGX Spark 上 W4A16 跟 W4A4 該選哪個?
看你要單流還是多並發。單人 chat → W4A16(74.75 tok/s vs W4A4 的 58.27)。Server 多人 → W4A4(c=16 aggregate 786 vs W4A16 的 400 tok/s)。W4A4 把 activation 也壓 FP4,多並發的 cache 利用率高;但每 step 多一個 activation quant overhead,單流被拖。
NVFP4 終於在 GB10 贏 FP8 了嗎?
贏了。之前 part 20 寫的 47.6 tok/s Triton FP8 hack 被打贏 57%。當初 part 19 說「NVFP4 是個坑」對 Qwen 3.6 hybrid 是對的 — 真正的瓶頸是 BF16 的 GDN linear_attn 投影層。換掉 hybrid model(Nemotron 3 Nano)就解了。
4 層 patch 是哪 4 層?
1) pip 裝 nvidia-cutlass-dsl-libs-cu13==4.4.2 補 CUDA 13 的 runtime lib。2) 改 cutlass-dsl warp/mma.py,把 sm_121a 加進 admissible_archs 跟 base equality check。3) 改 FlashInfer dense_blockscaled_gemm_sm120.py 的 sm_version 檢查。4) Recipe 加 CUTE_DSL_ARCH=sm_121a + VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x。最後 cherry-pick vLLM PR #40082 才有 b12x dispatcher 可用。