DGX Spark · part 20
[實作] 用 Triton 讓 NVFP4 在 GB10 上快 17%:FP8 Tensor Core 繞路攻略
❯ cat --toc
TL;DR
Part 19 確認 NVFP4 在 GB10 上走 Marlin BF16 fallback,只有 40.8 tok/s。這篇直接寫 Triton kernel 繞路:NVFP4 → FP8 dequant → FP8 tensor core GEMM。結果:47.6 tok/s(+17%)。離 native FP8 的 53.8 只剩 12% 差距。Monkey-patch vLLM 即可使用,附完整程式碼。
白話版:既然硬體不給力,軟體來補
上一篇我們發現 GB10 (SM121) 缺少 FP4 硬體指令,NVFP4 只能走慢速的 BF16 fallback。結論是「用 FP8 就對了」。
但這個結論有個前提:你得有 FP8 版的模型。
如果某天有個模型只 release NVFP4 checkpoint,你就只能吃 Marlin BF16 fallback 的 40.8 tok/s。
所以問題變成:GB10 沒有 FP4 tensor core,但有 FP8 tensor core——能不能把 NVFP4 權重轉成 FP8,用 FP8 的硬體加速來跑?
註(2026-04-30):寫這篇時假設「GB10 沒 FP4 tensor core」, 後來 NVIDIA forum 澄清 GB10 其實有(target 要
sm_121a而非sm_121)— 詳見 Part 19 修正。但這篇實作還是有用:vLLM 0.19 預設 Marlin BF16 fallback(40.8 tok/s)在那時是默認路徑,Triton 繞 FP8 tensor core 仍快 17%。要走真正的 sm_121a NVFP4 native 需要 vLLM PR #40082 + 4 層 patch,見 Part 25。
答案是可以。
思路:Marlin 的浪費
vLLM 在 SM121 上的 NVFP4 路徑是這樣的:
NVFP4 weights → Marlin dequant → BF16 → BF16 GEMM
Marlin 把 FP4 解壓成 BF16,然後跑 BF16 矩陣乘法。但 GB10 的 FP8 tensor core 比 BF16 快得多——我們的 micro-benchmark 測到 1.6x 到 4x 的 GEMM 加速。
Marlin 選 BF16 是因為它要相容所有 GPU(SM75+),而 FP8 需要 SM89+。但 GB10 是 SM121——它有 FP8,只是沒有 FP4。
我們的路徑:
NVFP4 weights → Triton dequant → FP8 → FP8 tensor core GEMM
第一步:搞清楚 NVFP4 的 dequant 公式
NVFP4(compressed-tensors 格式)的權重由三個 tensor 組成:
| Tensor | 型別 | 說明 |
|---|---|---|
weight_packed | uint8 (N, K/2) | 兩個 FP4 值 pack 在一個 byte |
weight_scale | float8_e4m3fn (N, K/16) | 每 16 個元素一個 block scale |
weight_global_scale | float32 (1,) | 整個 tensor 的 global scale |
FP4 E2M1 只有 16 個可能的值(含正負):
0, ±0.5, ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0
Dequant 公式:
value = fp4_lut[nibble] × block_scale ÷ global_scale
注意 global_scale 存的是 quantize scale(不是 dequant scale),所以要除不是乘。我第一次搞錯這個,dequant 出來的值全部爆到 ±1 億,被 FP8 的 ±448 範圍 clamp 成垃圾。
第二步:Triton kernel
FP4 E2M1 → FP8 E4M3 的轉換本質上就是查表 + 乘 scale,非常適合 GPU 平行:
@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
packed_ptr, scale_ptr, global_scale_ptr, out_ptr,
K, K_packed, n_scale_cols, group_size, BLOCK_K: tl.constexpr,
):
row = tl.program_id(0)
col_offsets = tl.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)
global_scale_inv = 1.0 / tl.load(global_scale_ptr)
packed = tl.load(packed_ptr + row * K_packed + col_offsets, ...)
# Unpack two FP4 values per byte
low_nibble = packed & 0x0F
high_nibble = (packed >> 4) & 0x0F
# LUT decode + sign handling
low_val = fp4_lut_decode(low_nibble)
high_val = fp4_lut_decode(high_nibble)
# Apply block scale and global scale
low_val *= block_scale_low * global_scale_inv
high_val *= block_scale_high * global_scale_inv
# Store as FP8 E4M3
tl.store(out_ptr + ..., low_val.to(tl.float8e4nv), ...)
tl.store(out_ptr + ..., high_val.to(tl.float8e4nv), ...)
第三步:Micro-benchmark
在 GB10 上用真實的 Qwen 3.6-35B-A3B NVFP4 權重測試:
Dequant 速度
| 方法 | 耗時 | 倍數 |
|---|---|---|
| Python FP4→BF16(模擬 Marlin) | 0.226 ms | baseline |
| Triton FP4→FP8 | 0.010 ms | 23x |
單層 End-to-End(dequant + GEMM,batch=32)
| 路徑 | 耗時 | 倍數 |
|---|---|---|
| Path A:FP4→BF16 + BF16 GEMM | 0.229 ms | baseline |
| Path B:Triton FP4→FP8 + FP8 GEMM | 0.060 ms | 3.8x |
| Path C:純 FP8 GEMM(天花板) | 0.029 ms | 7.9x |
Triton dequant 只要 0.01ms,幾乎免費。瓶頸已經移到 activation 的 FP8 量化。
第四步:整合進 vLLM
不改 vLLM source code,用 monkey-patch:
-
攔截
process_weights_after_loading:跳過 Marlin repack,改用 Triton 把 FP4 權重一次轉成 FP8,存在layer._fp8_weight。轉完後釋放原始 FP4 權重省 VRAM。 -
攔截
apply_weights:用torch._scaled_mm跑 FP8 GEMM,不走 Marlin。 -
啟動方式:設定
VLLM_NVFP4_GEMM_BACKEND=marlin強制 vLLM 選 Marlin kernel(這樣我們的 patch 才會被觸發),然後用 wrapper script import patch。
兩個踩過的坑:
_scaled_mm的 matrix layout:B 必須是 column-major。存(N, K)contiguous,runtime.t()得到 col-major(K, N)view。torch.compile相容性:activation quant 用純 torch ops(不用 vllm custom op),讓 inductor 能 trace 和 fuse。
結果:47.6 tok/s
Qwen 3.6-35B-A3B NVFP4,DGX Spark,vLLM 0.19.1,driver 580.142。
註(2026-05-06):所有 tok/s 都在 Qwen 3.6-35B-A3B 上測。後續 profile 顯示單流 60% 時間在沒被量化的 BF16 GDN linear projections,本文 +17% speedup(40.8 → 47.6)在 Qwen 3.6 上可重現,但主因可能是 launch overhead reduction 而非「FP8 tensor core 比 BF16 GEMM 快」這條 mechanism 推論。要驗 patch 對純 transformer 的真實效應需用非 hybrid model 重測。
| 方案 | tok/s | vs Marlin |
|---|---|---|
| Marlin BF16 fallback(原始) | 40.8 | baseline |
| FlashInfer CUTLASS(vLLM 0.19 預設) | 42.5 | +4.2% |
| Triton FP8 patch | 47.6 | +16.7% |
| Native FP8(直接跑 FP8 模型) | 53.8 | +31.9% |
NVFP4 模型從 40.8 → 47.6 tok/s,提升 17%。 離 native FP8 只剩 12% 差距。
剩下的 12% 在哪
三個來源:
-
Dequant 精度損失:FP4 → FP8 的 dequant 沒有直接 FP8 quantize 好。原始的 FP8 模型是從 BF16/FP32 精心量化過的,我們的 FP4 → FP8 多了一層轉換。
-
Activation quantize overhead:每次 forward pass 都要把 activation 從 BF16 量化成 FP8。
torch.compile已經幫忙 fuse 了,但還是有開銷。 -
沒有真正的 fused kernel:理想情況是 CUTLASS level 的 fused dequant+GEMM(在 shared memory 裡邊解壓邊算),我們是分兩步。
要吃掉這 12%,得改 CUTLASS C++ template——工作量大概一到兩週,而且需要 CUTLASS 3.x template metaprogramming 經驗。更實際的做法是等 vLLM 或 NVIDIA 官方出 SM121 FP8 fallback。
什麼時候該用這個 patch
用:模型只有 NVFP4 checkpoint,沒有 FP8 版,你在 DGX Spark 上想要比 40.8 tok/s 更快。
不用:模型有 FP8 版。直接跑 FP8 = 53.8 tok/s,比繞路快。
結論
Part 19 的結論是「NVFP4 在 GB10 上沒救,用 FP8」。Part 20 修正為:NVFP4 在 GB10 上可以救,但有上限。
Triton 30 行 kernel + monkey-patch,一個下午搞定,提速 17%。不需要改 vLLM source、不需要寫 CUDA、不需要等官方更新。
如果你也在 DGX Spark 上跑 NVFP4 模型,這個 patch 直接可用。如果你想貢獻 vLLM,這裡有一個明確的方向:在 SM121 上用 FP8 tensor core 取代 BF16 Marlin fallback。
完整程式碼
兩個檔案。nvfp4_fp8_patch.py 是核心 patch,serve_nvfp4_fp8.py 是啟動 wrapper。
nvfp4_fp8_patch.py
"""
Monkey-patch vLLM 的 NVFP4 Marlin fallback,改用 Triton FP4→FP8 dequant + FP8 GEMM。
適用於 NVIDIA GB10 (SM121):有 FP8 tensor core,沒有 FP4 tensor core。
攔截兩個方法:
- process_weights_after_loading: 跳過 Marlin repack,用 Triton 一次轉成 FP8
- apply_weights: 用 torch._scaled_mm 跑 FP8 GEMM
前置條件:
- vLLM 0.19+
- 環境變數 VLLM_NVFP4_GEMM_BACKEND=marlin(強制走 Marlin,才能被 patch 攔截)
- 環境變數 TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas(SM121 Triton 編譯需要)
"""
import torch
import triton
import triton.language as tl
import logging
import sys
logger = logging.getLogger("nvfp4_fp8_patch")
# ============================================================
# Triton kernel: NVFP4 packed uint8 → FP8 E4M3
#
# 每個 byte 包含 2 個 FP4 E2M1 值(低 4 bit + 高 4 bit)。
# FP4 E2M1 只有 16 個可能的值:0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6
# 用 LUT 解碼後,乘上 block scale 再除以 global scale。
#
# Dequant 公式:value = fp4_lut[nibble] × block_scale ÷ global_scale
# 注意:global_scale 存的是 quantize scale,要除不是乘!
# ============================================================
@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
packed_ptr, # 輸入:uint8 packed weights (N, K/2)
scale_ptr, # 輸入:float8_e4m3fn block scales (N, K/group_size)
global_scale_ptr, # 輸入:float32 global scale (1,)
out_ptr, # 輸出:float8_e4m3fn (N, K)
K: tl.constexpr,
K_packed: tl.constexpr, # K // 2
n_scale_cols: tl.constexpr, # K // group_size
group_size: tl.constexpr, # 通常是 16
BLOCK_K: tl.constexpr, # tile 大小
):
row = tl.program_id(0)
col_block = tl.program_id(1)
col_offsets = col_block * BLOCK_K + tl.arange(0, BLOCK_K)
mask = col_offsets < K_packed
# global_scale 存的是 quantize scale,dequant 時要取倒數
global_scale_inv = 1.0 / tl.load(global_scale_ptr)
# 載入 packed bytes
packed = tl.load(packed_ptr + row * K_packed + col_offsets, mask=mask, other=0).to(tl.uint8)
# 拆分低 4 bit 和高 4 bit
low_nibble = packed & 0x0F
high_nibble = (packed >> 4) & 0x0F
# 分離符號位 (bit 3) 和數值位 (bits 0-2)
low_sign = ((low_nibble >> 3) & 1).to(tl.float32)
low_mag = (low_nibble & 0x07).to(tl.uint8)
high_sign = ((high_nibble >> 3) & 1).to(tl.float32)
high_mag = (high_nibble & 0x07).to(tl.uint8)
# FP4 E2M1 LUT:magnitude index → float value
# 0→0, 1→0.5, 2→1.0, 3→1.5, 4→2.0, 5→3.0, 6→4.0, 7→6.0
low_val = tl.where(low_mag == 0, 0.0,
tl.where(low_mag == 1, 0.5,
tl.where(low_mag == 2, 1.0,
tl.where(low_mag == 3, 1.5,
tl.where(low_mag == 4, 2.0,
tl.where(low_mag == 5, 3.0,
tl.where(low_mag == 6, 4.0, 6.0)))))))
low_val = tl.where(low_sign > 0.5, -low_val, low_val)
high_val = tl.where(high_mag == 0, 0.0,
tl.where(high_mag == 1, 0.5,
tl.where(high_mag == 2, 1.0,
tl.where(high_mag == 3, 1.5,
tl.where(high_mag == 4, 2.0,
tl.where(high_mag == 5, 3.0,
tl.where(high_mag == 6, 4.0, 6.0)))))))
high_val = tl.where(high_sign > 0.5, -high_val, high_val)
# 計算 block scale 的索引
actual_col_low = col_offsets * 2
actual_col_high = col_offsets * 2 + 1
block_scale_low = tl.load(
scale_ptr + row * n_scale_cols + actual_col_low // group_size,
mask=actual_col_low < K, other=0.0).to(tl.float32)
block_scale_high = tl.load(
scale_ptr + row * n_scale_cols + actual_col_high // group_size,
mask=actual_col_high < K, other=0.0).to(tl.float32)
# dequant + clamp 到 FP8 E4M3 範圍 (±448)
low_val = tl.minimum(tl.maximum(low_val * block_scale_low * global_scale_inv, -448.0), 448.0)
high_val = tl.minimum(tl.maximum(high_val * block_scale_high * global_scale_inv, -448.0), 448.0)
# 寫出 FP8
tl.store(out_ptr + row * K + actual_col_low, low_val.to(tl.float8e4nv), mask=actual_col_low < K)
tl.store(out_ptr + row * K + actual_col_high, high_val.to(tl.float8e4nv), mask=actual_col_high < K)
def nvfp4_to_fp8(weight_packed, weight_scale, weight_global_scale):
"""呼叫 Triton kernel 把 NVFP4 packed weights 轉成 FP8 E4M3。"""
N, K_packed = weight_packed.shape
K = K_packed * 2
n_scale_cols = weight_scale.shape[1]
group_size = K // n_scale_cols
out = torch.empty(N, K, device=weight_packed.device, dtype=torch.float8_e4m3fn)
BLOCK_K = 256
grid = (N, triton.cdiv(K_packed, BLOCK_K))
_nvfp4_dequant_to_fp8_kernel[grid](
weight_packed, weight_scale, weight_global_scale, out,
K=K, K_packed=K_packed,
n_scale_cols=n_scale_cols, group_size=group_size,
BLOCK_K=BLOCK_K,
)
return out
# ============================================================
# 替換 vLLM 的 Marlin 方法
# ============================================================
def _patched_process_weights(self, layer: torch.nn.Module) -> None:
"""跳過 Marlin repack,改用 Triton 把 NVFP4 一次轉成 FP8。
轉完後釋放原始 FP4 權重以節省 VRAM。
FP8 權重存在 layer._fp8_weight,runtime 直接用。
"""
logger.warning_once(
"[nvfp4_fp8_patch] SM121 detected. Skipping Marlin repack, "
"will use Triton FP4→FP8 dequant + FP8 tensor core GEMM."
)
weight_packed = layer.weight.data # uint8 (N, K/2)
weight_scale = layer.weight_scale.data # float8_e4m3fn (N, K/group)
weight_global_scale = layer.weight_global_scale.data # float32
# 一次性 dequant:NVFP4 → FP8
w_fp8 = nvfp4_to_fp8(weight_packed, weight_scale, weight_global_scale)
# 存成 (N, K) contiguous。runtime 用 .t() 得到 col-major (K, N) view,
# 這是 torch._scaled_mm 對 B matrix 的要求。
layer._fp8_weight = torch.nn.Parameter(w_fp8.contiguous(), requires_grad=False)
layer._fp8_w_scale = torch.nn.Parameter(
torch.ones(1, 1, device=w_fp8.device, dtype=torch.float32),
requires_grad=False,
)
# 釋放原始 FP4 權重,騰出 VRAM
layer.weight = torch.nn.Parameter(
torch.empty(0, dtype=torch.uint8, device=weight_packed.device),
requires_grad=False,
)
layer.weight_scale = torch.nn.Parameter(
torch.empty(0, dtype=torch.float8_e4m3fn, device=weight_scale.device),
requires_grad=False,
)
logger.info(
f"[nvfp4_fp8_patch] Pre-converted layer to FP8: "
f"{weight_packed.shape} → {w_fp8.shape}, "
f"weight range: [{w_fp8.float().min():.4f}, {w_fp8.float().max():.4f}]"
)
def _patched_apply_weights(self, layer, x, bias=None):
"""用預轉好的 FP8 權重跑 FP8 tensor core GEMM。
Activation 用純 torch ops 量化成 FP8,
讓 torch.compile/inductor 能 trace 和 fuse。
"""
orig_shape = x.shape
x_2d = x.reshape(-1, orig_shape[-1])
# Activation → FP8(per-tensor dynamic quantization)
amax = x_2d.float().abs().amax()
a_scale = (amax / 448.0).clamp(min=1e-12)
x_fp8 = (x_2d.float() / a_scale).clamp(-448, 448).to(torch.float8_e4m3fn)
a_scale = a_scale.reshape(1, 1)
# FP8 GEMM:(M, K) @ (K, N) → (M, N)
out = torch._scaled_mm(
x_fp8,
layer._fp8_weight.t(), # (N,K).t() = col-major (K,N)
scale_a=a_scale,
scale_b=layer._fp8_w_scale,
out_dtype=x.dtype,
)
if bias is not None:
out = out + bias
return out.reshape(*orig_shape[:-1], out.shape[-1])
# ============================================================
# Patch 安裝:偵測 SM121 後替換 Marlin 方法
# ============================================================
_patched = False
def patch():
"""安裝 monkey-patch。只在 SM121 (GB10) 上生效。"""
global _patched
if _patched:
return True
if torch.cuda.is_available():
cap = torch.cuda.get_device_capability()
if cap != (12, 1):
logger.info(f"SM{cap[0]}{cap[1]} detected, FP8 patch not needed")
return False
try:
from vllm.model_executor.kernels.linear.nvfp4.marlin import (
MarlinNvFp4LinearKernel,
)
MarlinNvFp4LinearKernel.process_weights_after_loading = _patched_process_weights
MarlinNvFp4LinearKernel.apply_weights = _patched_apply_weights
_patched = True
logger.warning(
"[nvfp4_fp8_patch] Installed! MarlinNvFp4LinearKernel now uses "
"pre-converted FP8 weights + FP8 tensor core GEMM."
)
return True
except ImportError as e:
logger.error(f"Failed to patch: {e}")
return False
serve_nvfp4_fp8.py
"""
啟動 wrapper:先安裝 FP8 patch,再啟動 vllm serve。
patch 必須在 parent 和 spawned worker 都執行(放在 __main__ guard 外面),
但只有 parent 才呼叫 vllm_main()。
環境變數:
MODEL_PATH 模型路徑(預設 /model)
PORT API port(預設 8000)
MAX_MODEL_LEN 最大 context 長度(預設 4096)
GPU_MEM_UTIL GPU 記憶體使用率(預設 0.90)
Docker 用法:
docker run --gpus all \
-e VLLM_NVFP4_GEMM_BACKEND=marlin \
-e TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
-e MODEL_PATH=/model \
-v /path/to/nvfp4/model:/model \
-v /path/to/scripts:/scripts \
--entrypoint python3 <vllm-image> /scripts/serve_nvfp4_fp8.py
"""
import sys
import os
import logging
# patch 在 import 時執行,確保 spawned workers 也會套用
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import nvfp4_fp8_patch
nvfp4_fp8_patch.patch()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO,
format="%(asctime)s %(name)s %(levelname)s %(message)s")
model_path = os.environ.get("MODEL_PATH", "/model")
port = os.environ.get("PORT", "8000")
max_model_len = os.environ.get("MAX_MODEL_LEN", "4096")
gpu_mem = os.environ.get("GPU_MEM_UTIL", "0.90")
sys.argv = [
"vllm", "serve", model_path,
"--port", port,
"--max-model-len", max_model_len,
"--gpu-memory-utilization", gpu_mem,
"--trust-remote-code",
]
from vllm.entrypoints.cli.main import main as vllm_main
vllm_main()
同系列文章:Part 19 — NVFP4 是陷阱 · Part 14 — Gemma 4 全家桶 · Part 8 — vLLM vs Ollama
常見問題
- NVFP4 在 DGX Spark 上真的沒救了嗎?
- 沒有原生硬體支援,但可以繞路。用 Triton 把 NVFP4 權重 dequant 成 FP8,再走 FP8 tensor core,從 40.8 提升到 47.6 tok/s(+17%)。離 native FP8 的 53.8 只差 12%。
- 這個 patch 能直接用在 vLLM 上嗎?
- 可以。monkey-patch vLLM 的 MarlinNvFp4LinearKernel,load time 一次轉好 FP8,runtime 走 torch._scaled_mm。需要設定 VLLM_NVFP4_GEMM_BACKEND=marlin。完整程式碼在文中。
- 為什麼不直接跑 FP8 模型就好?
- 如果有 FP8 版,直接跑 FP8 確實更快(53.8 tok/s)。這個 patch 的價值是:當模型只有 NVFP4 release、沒有 FP8 版本時,你不用再忍受 Marlin BF16 fallback 的 40.8 tok/s。