~/blog/dgx-spark-nvfp4-fp8-triton-patch

DGX Spark · part 20

[實作] 用 Triton 讓 NVFP4 在 GB10 上快 17%:FP8 Tensor Core 繞路攻略

2026-04-224 分鐘閱讀#nvfp4#fp8#triton#dgx-sparkEnglish
cat --toc

TL;DR

Part 19 確認 NVFP4 在 GB10 上走 Marlin BF16 fallback,只有 40.8 tok/s。這篇直接寫 Triton kernel 繞路:NVFP4 → FP8 dequant → FP8 tensor core GEMM。結果:47.6 tok/s(+17%)。離 native FP8 的 53.8 只剩 12% 差距。Monkey-patch vLLM 即可使用,附完整程式碼。

白話版:既然硬體不給力,軟體來補

上一篇我們發現 GB10 (SM121) 缺少 FP4 硬體指令,NVFP4 只能走慢速的 BF16 fallback。結論是「用 FP8 就對了」。

但這個結論有個前提:你得有 FP8 版的模型

如果某天有個模型只 release NVFP4 checkpoint,你就只能吃 Marlin BF16 fallback 的 40.8 tok/s。

所以問題變成:GB10 沒有 FP4 tensor core,但有 FP8 tensor core——能不能把 NVFP4 權重轉成 FP8,用 FP8 的硬體加速來跑?

答案是可以。


思路:Marlin 的浪費

vLLM 在 SM121 上的 NVFP4 路徑是這樣的:

NVFP4 weights → Marlin dequant → BF16 → BF16 GEMM

Marlin 把 FP4 解壓成 BF16,然後跑 BF16 矩陣乘法。但 GB10 的 FP8 tensor core 比 BF16 快得多——我們的 micro-benchmark 測到 1.6x 到 4x 的 GEMM 加速

Marlin 選 BF16 是因為它要相容所有 GPU(SM75+),而 FP8 需要 SM89+。但 GB10 是 SM121——它 FP8,只是沒有 FP4。

我們的路徑:

NVFP4 weights → Triton dequant → FP8 → FP8 tensor core GEMM

第一步:搞清楚 NVFP4 的 dequant 公式

NVFP4(compressed-tensors 格式)的權重由三個 tensor 組成:

Tensor型別說明
weight_packeduint8 (N, K/2)兩個 FP4 值 pack 在一個 byte
weight_scalefloat8_e4m3fn (N, K/16)每 16 個元素一個 block scale
weight_global_scalefloat32 (1,)整個 tensor 的 global scale

FP4 E2M1 只有 16 個可能的值(含正負):

0, ±0.5, ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0

Dequant 公式:

value = fp4_lut[nibble] × block_scale ÷ global_scale

注意 global_scale 存的是 quantize scale(不是 dequant scale),所以要不是乘。我第一次搞錯這個,dequant 出來的值全部爆到 ±1 億,被 FP8 的 ±448 範圍 clamp 成垃圾。

第二步:Triton kernel

FP4 E2M1 → FP8 E4M3 的轉換本質上就是查表 + 乘 scale,非常適合 GPU 平行:

@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
    packed_ptr, scale_ptr, global_scale_ptr, out_ptr,
    K, K_packed, n_scale_cols, group_size, BLOCK_K: tl.constexpr,
):
    row = tl.program_id(0)
    col_offsets = tl.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)

    global_scale_inv = 1.0 / tl.load(global_scale_ptr)
    packed = tl.load(packed_ptr + row * K_packed + col_offsets, ...)

    # Unpack two FP4 values per byte
    low_nibble = packed & 0x0F
    high_nibble = (packed >> 4) & 0x0F

    # LUT decode + sign handling
    low_val = fp4_lut_decode(low_nibble)
    high_val = fp4_lut_decode(high_nibble)

    # Apply block scale and global scale
    low_val *= block_scale_low * global_scale_inv
    high_val *= block_scale_high * global_scale_inv

    # Store as FP8 E4M3
    tl.store(out_ptr + ..., low_val.to(tl.float8e4nv), ...)
    tl.store(out_ptr + ..., high_val.to(tl.float8e4nv), ...)

第三步:Micro-benchmark

在 GB10 上用真實的 Qwen 3.6-35B-A3B NVFP4 權重測試:

Dequant 速度

方法耗時倍數
Python FP4→BF16(模擬 Marlin)0.226 msbaseline
Triton FP4→FP80.010 ms23x

單層 End-to-End(dequant + GEMM,batch=32)

路徑耗時倍數
Path A:FP4→BF16 + BF16 GEMM0.229 msbaseline
Path B:Triton FP4→FP8 + FP8 GEMM0.060 ms3.8x
Path C:純 FP8 GEMM(天花板)0.029 ms7.9x

Triton dequant 只要 0.01ms,幾乎免費。瓶頸已經移到 activation 的 FP8 量化。

第四步:整合進 vLLM

不改 vLLM source code,用 monkey-patch:

  1. 攔截 process_weights_after_loading:跳過 Marlin repack,改用 Triton 把 FP4 權重一次轉成 FP8,存在 layer._fp8_weight。轉完後釋放原始 FP4 權重省 VRAM。

  2. 攔截 apply_weights:用 torch._scaled_mm 跑 FP8 GEMM,不走 Marlin。

  3. 啟動方式:設定 VLLM_NVFP4_GEMM_BACKEND=marlin 強制 vLLM 選 Marlin kernel(這樣我們的 patch 才會被觸發),然後用 wrapper script import patch。

兩個踩過的坑:

  • _scaled_mm 的 matrix layout:B 必須是 column-major。存 (N, K) contiguous,runtime .t() 得到 col-major (K, N) view。
  • torch.compile 相容性:activation quant 用純 torch ops(不用 vllm custom op),讓 inductor 能 trace 和 fuse。

結果:47.6 tok/s

Qwen 3.6-35B-A3B NVFP4,DGX Spark,vLLM 0.19.1,driver 580.142。

方案tok/svs Marlin
Marlin BF16 fallback(原始)40.8baseline
FlashInfer CUTLASS(vLLM 0.19 預設)42.5+4.2%
Triton FP8 patch47.6+16.7%
Native FP8(直接跑 FP8 模型)53.8+31.9%

NVFP4 模型從 40.8 → 47.6 tok/s,提升 17%。 離 native FP8 只剩 12% 差距。

剩下的 12% 在哪

三個來源:

  1. Dequant 精度損失:FP4 → FP8 的 dequant 沒有直接 FP8 quantize 好。原始的 FP8 模型是從 BF16/FP32 精心量化過的,我們的 FP4 → FP8 多了一層轉換。

  2. Activation quantize overhead:每次 forward pass 都要把 activation 從 BF16 量化成 FP8。torch.compile 已經幫忙 fuse 了,但還是有開銷。

  3. 沒有真正的 fused kernel:理想情況是 CUTLASS level 的 fused dequant+GEMM(在 shared memory 裡邊解壓邊算),我們是分兩步。

要吃掉這 12%,得改 CUTLASS C++ template——工作量大概一到兩週,而且需要 CUTLASS 3.x template metaprogramming 經驗。更實際的做法是等 vLLM 或 NVIDIA 官方出 SM121 FP8 fallback。

什麼時候該用這個 patch

:模型只有 NVFP4 checkpoint,沒有 FP8 版,你在 DGX Spark 上想要比 40.8 tok/s 更快。

不用:模型有 FP8 版。直接跑 FP8 = 53.8 tok/s,比繞路快。

結論

Part 19 的結論是「NVFP4 在 GB10 上沒救,用 FP8」。Part 20 修正為:NVFP4 在 GB10 上可以救,但有上限

Triton 30 行 kernel + monkey-patch,一個下午搞定,提速 17%。不需要改 vLLM source、不需要寫 CUDA、不需要等官方更新。

如果你也在 DGX Spark 上跑 NVFP4 模型,這個 patch 直接可用。如果你想貢獻 vLLM,這裡有一個明確的方向:在 SM121 上用 FP8 tensor core 取代 BF16 Marlin fallback。


同系列文章:Part 19 — NVFP4 是陷阱 · Part 14 — Gemma 4 全家桶 · Part 8 — vLLM vs Ollama

常見問題

NVFP4 在 DGX Spark 上真的沒救了嗎?
沒有原生硬體支援,但可以繞路。用 Triton 把 NVFP4 權重 dequant 成 FP8,再走 FP8 tensor core,從 40.8 提升到 47.6 tok/s(+17%)。離 native FP8 的 53.8 只差 12%。
這個 patch 能直接用在 vLLM 上嗎?
可以。monkey-patch vLLM 的 MarlinNvFp4LinearKernel,load time 一次轉好 FP8,runtime 走 torch._scaled_mm。需要設定 VLLM_NVFP4_GEMM_BACKEND=marlin。完整程式碼在文中。
為什麼不直接跑 FP8 模型就好?
如果有 FP8 版,直接跑 FP8 確實更快(53.8 tok/s)。這個 patch 的價值是:當模型只有 NVFP4 release、沒有 FP8 版本時,你不用再忍受 Marlin BF16 fallback 的 40.8 tok/s。