~/blog/dgx-spark-nvfp4-fp8-triton-patch

DGX Spark · part 20

[實作] 用 Triton 讓 NVFP4 在 GB10 上快 17%:FP8 Tensor Core 繞路攻略

2026-04-22更新於 2026-05-0610 分鐘閱讀#nvfp4#fp8#triton#dgx-sparkEnglish
cat --toc

TL;DR

Part 19 確認 NVFP4 在 GB10 上走 Marlin BF16 fallback,只有 40.8 tok/s。這篇直接寫 Triton kernel 繞路:NVFP4 → FP8 dequant → FP8 tensor core GEMM。結果:47.6 tok/s(+17%)。離 native FP8 的 53.8 只剩 12% 差距。Monkey-patch vLLM 即可使用,附完整程式碼。

白話版:既然硬體不給力,軟體來補

上一篇我們發現 GB10 (SM121) 缺少 FP4 硬體指令,NVFP4 只能走慢速的 BF16 fallback。結論是「用 FP8 就對了」。

但這個結論有個前提:你得有 FP8 版的模型

如果某天有個模型只 release NVFP4 checkpoint,你就只能吃 Marlin BF16 fallback 的 40.8 tok/s。

所以問題變成:GB10 沒有 FP4 tensor core,但有 FP8 tensor core——能不能把 NVFP4 權重轉成 FP8,用 FP8 的硬體加速來跑?

註(2026-04-30):寫這篇時假設「GB10 沒 FP4 tensor core」, 後來 NVIDIA forum 澄清 GB10 其實有(target 要 sm_121a 而非 sm_121)— 詳見 Part 19 修正。但這篇實作還是有用:vLLM 0.19 預設 Marlin BF16 fallback(40.8 tok/s)在那時是默認路徑,Triton 繞 FP8 tensor core 仍快 17%。要走真正的 sm_121a NVFP4 native 需要 vLLM PR #40082 + 4 層 patch,見 Part 25。

答案是可以。


思路:Marlin 的浪費

vLLM 在 SM121 上的 NVFP4 路徑是這樣的:

NVFP4 weights → Marlin dequant → BF16 → BF16 GEMM

Marlin 把 FP4 解壓成 BF16,然後跑 BF16 矩陣乘法。但 GB10 的 FP8 tensor core 比 BF16 快得多——我們的 micro-benchmark 測到 1.6x 到 4x 的 GEMM 加速

Marlin 選 BF16 是因為它要相容所有 GPU(SM75+),而 FP8 需要 SM89+。但 GB10 是 SM121——它 FP8,只是沒有 FP4。

我們的路徑:

NVFP4 weights → Triton dequant → FP8 → FP8 tensor core GEMM

第一步:搞清楚 NVFP4 的 dequant 公式

NVFP4(compressed-tensors 格式)的權重由三個 tensor 組成:

Tensor型別說明
weight_packeduint8 (N, K/2)兩個 FP4 值 pack 在一個 byte
weight_scalefloat8_e4m3fn (N, K/16)每 16 個元素一個 block scale
weight_global_scalefloat32 (1,)整個 tensor 的 global scale

FP4 E2M1 只有 16 個可能的值(含正負):

0, ±0.5, ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0

Dequant 公式:

value = fp4_lut[nibble] × block_scale ÷ global_scale

注意 global_scale 存的是 quantize scale(不是 dequant scale),所以要不是乘。我第一次搞錯這個,dequant 出來的值全部爆到 ±1 億,被 FP8 的 ±448 範圍 clamp 成垃圾。

第二步:Triton kernel

FP4 E2M1 → FP8 E4M3 的轉換本質上就是查表 + 乘 scale,非常適合 GPU 平行:

@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
    packed_ptr, scale_ptr, global_scale_ptr, out_ptr,
    K, K_packed, n_scale_cols, group_size, BLOCK_K: tl.constexpr,
):
    row = tl.program_id(0)
    col_offsets = tl.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)

    global_scale_inv = 1.0 / tl.load(global_scale_ptr)
    packed = tl.load(packed_ptr + row * K_packed + col_offsets, ...)

    # Unpack two FP4 values per byte
    low_nibble = packed & 0x0F
    high_nibble = (packed >> 4) & 0x0F

    # LUT decode + sign handling
    low_val = fp4_lut_decode(low_nibble)
    high_val = fp4_lut_decode(high_nibble)

    # Apply block scale and global scale
    low_val *= block_scale_low * global_scale_inv
    high_val *= block_scale_high * global_scale_inv

    # Store as FP8 E4M3
    tl.store(out_ptr + ..., low_val.to(tl.float8e4nv), ...)
    tl.store(out_ptr + ..., high_val.to(tl.float8e4nv), ...)

第三步:Micro-benchmark

在 GB10 上用真實的 Qwen 3.6-35B-A3B NVFP4 權重測試:

Dequant 速度

方法耗時倍數
Python FP4→BF16(模擬 Marlin)0.226 msbaseline
Triton FP4→FP80.010 ms23x

單層 End-to-End(dequant + GEMM,batch=32)

路徑耗時倍數
Path A:FP4→BF16 + BF16 GEMM0.229 msbaseline
Path B:Triton FP4→FP8 + FP8 GEMM0.060 ms3.8x
Path C:純 FP8 GEMM(天花板)0.029 ms7.9x

Triton dequant 只要 0.01ms,幾乎免費。瓶頸已經移到 activation 的 FP8 量化。

第四步:整合進 vLLM

不改 vLLM source code,用 monkey-patch:

  1. 攔截 process_weights_after_loading:跳過 Marlin repack,改用 Triton 把 FP4 權重一次轉成 FP8,存在 layer._fp8_weight。轉完後釋放原始 FP4 權重省 VRAM。

  2. 攔截 apply_weights:用 torch._scaled_mm 跑 FP8 GEMM,不走 Marlin。

  3. 啟動方式:設定 VLLM_NVFP4_GEMM_BACKEND=marlin 強制 vLLM 選 Marlin kernel(這樣我們的 patch 才會被觸發),然後用 wrapper script import patch。

兩個踩過的坑:

  • _scaled_mm 的 matrix layout:B 必須是 column-major。存 (N, K) contiguous,runtime .t() 得到 col-major (K, N) view。
  • torch.compile 相容性:activation quant 用純 torch ops(不用 vllm custom op),讓 inductor 能 trace 和 fuse。

結果:47.6 tok/s

Qwen 3.6-35B-A3B NVFP4,DGX Spark,vLLM 0.19.1,driver 580.142。

註(2026-05-06):所有 tok/s 都在 Qwen 3.6-35B-A3B 上測。後續 profile 顯示單流 60% 時間在沒被量化的 BF16 GDN linear projections,本文 +17% speedup(40.8 → 47.6)在 Qwen 3.6 上可重現,但主因可能是 launch overhead reduction 而非「FP8 tensor core 比 BF16 GEMM 快」這條 mechanism 推論。要驗 patch 對純 transformer 的真實效應需用非 hybrid model 重測。

方案tok/svs Marlin
Marlin BF16 fallback(原始)40.8baseline
FlashInfer CUTLASS(vLLM 0.19 預設)42.5+4.2%
Triton FP8 patch47.6+16.7%
Native FP8(直接跑 FP8 模型)53.8+31.9%

NVFP4 模型從 40.8 → 47.6 tok/s,提升 17%。 離 native FP8 只剩 12% 差距。

剩下的 12% 在哪

三個來源:

  1. Dequant 精度損失:FP4 → FP8 的 dequant 沒有直接 FP8 quantize 好。原始的 FP8 模型是從 BF16/FP32 精心量化過的,我們的 FP4 → FP8 多了一層轉換。

  2. Activation quantize overhead:每次 forward pass 都要把 activation 從 BF16 量化成 FP8。torch.compile 已經幫忙 fuse 了,但還是有開銷。

  3. 沒有真正的 fused kernel:理想情況是 CUTLASS level 的 fused dequant+GEMM(在 shared memory 裡邊解壓邊算),我們是分兩步。

要吃掉這 12%,得改 CUTLASS C++ template——工作量大概一到兩週,而且需要 CUTLASS 3.x template metaprogramming 經驗。更實際的做法是等 vLLM 或 NVIDIA 官方出 SM121 FP8 fallback。

什麼時候該用這個 patch

:模型只有 NVFP4 checkpoint,沒有 FP8 版,你在 DGX Spark 上想要比 40.8 tok/s 更快。

不用:模型有 FP8 版。直接跑 FP8 = 53.8 tok/s,比繞路快。

結論

Part 19 的結論是「NVFP4 在 GB10 上沒救,用 FP8」。Part 20 修正為:NVFP4 在 GB10 上可以救,但有上限

Triton 30 行 kernel + monkey-patch,一個下午搞定,提速 17%。不需要改 vLLM source、不需要寫 CUDA、不需要等官方更新。

如果你也在 DGX Spark 上跑 NVFP4 模型,這個 patch 直接可用。如果你想貢獻 vLLM,這裡有一個明確的方向:在 SM121 上用 FP8 tensor core 取代 BF16 Marlin fallback。

完整程式碼

兩個檔案。nvfp4_fp8_patch.py 是核心 patch,serve_nvfp4_fp8.py 是啟動 wrapper。

nvfp4_fp8_patch.py

"""
Monkey-patch vLLM 的 NVFP4 Marlin fallback,改用 Triton FP4→FP8 dequant + FP8 GEMM。
適用於 NVIDIA GB10 (SM121):有 FP8 tensor core,沒有 FP4 tensor core。

攔截兩個方法:
  - process_weights_after_loading: 跳過 Marlin repack,用 Triton 一次轉成 FP8
  - apply_weights: 用 torch._scaled_mm 跑 FP8 GEMM

前置條件:
  - vLLM 0.19+
  - 環境變數 VLLM_NVFP4_GEMM_BACKEND=marlin(強制走 Marlin,才能被 patch 攔截)
  - 環境變數 TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas(SM121 Triton 編譯需要)
"""

import torch
import triton
import triton.language as tl
import logging
import sys

logger = logging.getLogger("nvfp4_fp8_patch")


# ============================================================
# Triton kernel: NVFP4 packed uint8 → FP8 E4M3
#
# 每個 byte 包含 2 個 FP4 E2M1 值(低 4 bit + 高 4 bit)。
# FP4 E2M1 只有 16 個可能的值:0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6
# 用 LUT 解碼後,乘上 block scale 再除以 global scale。
#
# Dequant 公式:value = fp4_lut[nibble] × block_scale ÷ global_scale
# 注意:global_scale 存的是 quantize scale,要除不是乘!
# ============================================================

@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
    packed_ptr,        # 輸入:uint8 packed weights (N, K/2)
    scale_ptr,         # 輸入:float8_e4m3fn block scales (N, K/group_size)
    global_scale_ptr,  # 輸入:float32 global scale (1,)
    out_ptr,           # 輸出:float8_e4m3fn (N, K)
    K: tl.constexpr,
    K_packed: tl.constexpr,     # K // 2
    n_scale_cols: tl.constexpr, # K // group_size
    group_size: tl.constexpr,   # 通常是 16
    BLOCK_K: tl.constexpr,      # tile 大小
):
    row = tl.program_id(0)
    col_block = tl.program_id(1)
    col_offsets = col_block * BLOCK_K + tl.arange(0, BLOCK_K)
    mask = col_offsets < K_packed

    # global_scale 存的是 quantize scale,dequant 時要取倒數
    global_scale_inv = 1.0 / tl.load(global_scale_ptr)

    # 載入 packed bytes
    packed = tl.load(packed_ptr + row * K_packed + col_offsets, mask=mask, other=0).to(tl.uint8)

    # 拆分低 4 bit 和高 4 bit
    low_nibble = packed & 0x0F
    high_nibble = (packed >> 4) & 0x0F

    # 分離符號位 (bit 3) 和數值位 (bits 0-2)
    low_sign = ((low_nibble >> 3) & 1).to(tl.float32)
    low_mag = (low_nibble & 0x07).to(tl.uint8)
    high_sign = ((high_nibble >> 3) & 1).to(tl.float32)
    high_mag = (high_nibble & 0x07).to(tl.uint8)

    # FP4 E2M1 LUT:magnitude index → float value
    # 0→0, 1→0.5, 2→1.0, 3→1.5, 4→2.0, 5→3.0, 6→4.0, 7→6.0
    low_val = tl.where(low_mag == 0, 0.0,
              tl.where(low_mag == 1, 0.5,
              tl.where(low_mag == 2, 1.0,
              tl.where(low_mag == 3, 1.5,
              tl.where(low_mag == 4, 2.0,
              tl.where(low_mag == 5, 3.0,
              tl.where(low_mag == 6, 4.0, 6.0)))))))
    low_val = tl.where(low_sign > 0.5, -low_val, low_val)

    high_val = tl.where(high_mag == 0, 0.0,
               tl.where(high_mag == 1, 0.5,
               tl.where(high_mag == 2, 1.0,
               tl.where(high_mag == 3, 1.5,
               tl.where(high_mag == 4, 2.0,
               tl.where(high_mag == 5, 3.0,
               tl.where(high_mag == 6, 4.0, 6.0)))))))
    high_val = tl.where(high_sign > 0.5, -high_val, high_val)

    # 計算 block scale 的索引
    actual_col_low = col_offsets * 2
    actual_col_high = col_offsets * 2 + 1

    block_scale_low = tl.load(
        scale_ptr + row * n_scale_cols + actual_col_low // group_size,
        mask=actual_col_low < K, other=0.0).to(tl.float32)
    block_scale_high = tl.load(
        scale_ptr + row * n_scale_cols + actual_col_high // group_size,
        mask=actual_col_high < K, other=0.0).to(tl.float32)

    # dequant + clamp 到 FP8 E4M3 範圍 (±448)
    low_val = tl.minimum(tl.maximum(low_val * block_scale_low * global_scale_inv, -448.0), 448.0)
    high_val = tl.minimum(tl.maximum(high_val * block_scale_high * global_scale_inv, -448.0), 448.0)

    # 寫出 FP8
    tl.store(out_ptr + row * K + actual_col_low, low_val.to(tl.float8e4nv), mask=actual_col_low < K)
    tl.store(out_ptr + row * K + actual_col_high, high_val.to(tl.float8e4nv), mask=actual_col_high < K)


def nvfp4_to_fp8(weight_packed, weight_scale, weight_global_scale):
    """呼叫 Triton kernel 把 NVFP4 packed weights 轉成 FP8 E4M3。"""
    N, K_packed = weight_packed.shape
    K = K_packed * 2
    n_scale_cols = weight_scale.shape[1]
    group_size = K // n_scale_cols

    out = torch.empty(N, K, device=weight_packed.device, dtype=torch.float8_e4m3fn)
    BLOCK_K = 256
    grid = (N, triton.cdiv(K_packed, BLOCK_K))

    _nvfp4_dequant_to_fp8_kernel[grid](
        weight_packed, weight_scale, weight_global_scale, out,
        K=K, K_packed=K_packed,
        n_scale_cols=n_scale_cols, group_size=group_size,
        BLOCK_K=BLOCK_K,
    )
    return out


# ============================================================
# 替換 vLLM 的 Marlin 方法
# ============================================================

def _patched_process_weights(self, layer: torch.nn.Module) -> None:
    """跳過 Marlin repack,改用 Triton 把 NVFP4 一次轉成 FP8。

    轉完後釋放原始 FP4 權重以節省 VRAM。
    FP8 權重存在 layer._fp8_weight,runtime 直接用。
    """
    logger.warning_once(
        "[nvfp4_fp8_patch] SM121 detected. Skipping Marlin repack, "
        "will use Triton FP4→FP8 dequant + FP8 tensor core GEMM."
    )

    weight_packed = layer.weight.data           # uint8 (N, K/2)
    weight_scale = layer.weight_scale.data      # float8_e4m3fn (N, K/group)
    weight_global_scale = layer.weight_global_scale.data  # float32

    # 一次性 dequant:NVFP4 → FP8
    w_fp8 = nvfp4_to_fp8(weight_packed, weight_scale, weight_global_scale)

    # 存成 (N, K) contiguous。runtime 用 .t() 得到 col-major (K, N) view,
    # 這是 torch._scaled_mm 對 B matrix 的要求。
    layer._fp8_weight = torch.nn.Parameter(w_fp8.contiguous(), requires_grad=False)
    layer._fp8_w_scale = torch.nn.Parameter(
        torch.ones(1, 1, device=w_fp8.device, dtype=torch.float32),
        requires_grad=False,
    )

    # 釋放原始 FP4 權重,騰出 VRAM
    layer.weight = torch.nn.Parameter(
        torch.empty(0, dtype=torch.uint8, device=weight_packed.device),
        requires_grad=False,
    )
    layer.weight_scale = torch.nn.Parameter(
        torch.empty(0, dtype=torch.float8_e4m3fn, device=weight_scale.device),
        requires_grad=False,
    )

    logger.info(
        f"[nvfp4_fp8_patch] Pre-converted layer to FP8: "
        f"{weight_packed.shape} → {w_fp8.shape}, "
        f"weight range: [{w_fp8.float().min():.4f}, {w_fp8.float().max():.4f}]"
    )


def _patched_apply_weights(self, layer, x, bias=None):
    """用預轉好的 FP8 權重跑 FP8 tensor core GEMM。

    Activation 用純 torch ops 量化成 FP8,
    讓 torch.compile/inductor 能 trace 和 fuse。
    """
    orig_shape = x.shape
    x_2d = x.reshape(-1, orig_shape[-1])

    # Activation → FP8(per-tensor dynamic quantization)
    amax = x_2d.float().abs().amax()
    a_scale = (amax / 448.0).clamp(min=1e-12)
    x_fp8 = (x_2d.float() / a_scale).clamp(-448, 448).to(torch.float8_e4m3fn)
    a_scale = a_scale.reshape(1, 1)

    # FP8 GEMM:(M, K) @ (K, N) → (M, N)
    out = torch._scaled_mm(
        x_fp8,
        layer._fp8_weight.t(),  # (N,K).t() = col-major (K,N)
        scale_a=a_scale,
        scale_b=layer._fp8_w_scale,
        out_dtype=x.dtype,
    )

    if bias is not None:
        out = out + bias

    return out.reshape(*orig_shape[:-1], out.shape[-1])


# ============================================================
# Patch 安裝:偵測 SM121 後替換 Marlin 方法
# ============================================================

_patched = False

def patch():
    """安裝 monkey-patch。只在 SM121 (GB10) 上生效。"""
    global _patched
    if _patched:
        return True

    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        if cap != (12, 1):
            logger.info(f"SM{cap[0]}{cap[1]} detected, FP8 patch not needed")
            return False

    try:
        from vllm.model_executor.kernels.linear.nvfp4.marlin import (
            MarlinNvFp4LinearKernel,
        )
        MarlinNvFp4LinearKernel.process_weights_after_loading = _patched_process_weights
        MarlinNvFp4LinearKernel.apply_weights = _patched_apply_weights

        _patched = True
        logger.warning(
            "[nvfp4_fp8_patch] Installed! MarlinNvFp4LinearKernel now uses "
            "pre-converted FP8 weights + FP8 tensor core GEMM."
        )
        return True
    except ImportError as e:
        logger.error(f"Failed to patch: {e}")
        return False

serve_nvfp4_fp8.py

"""
啟動 wrapper:先安裝 FP8 patch,再啟動 vllm serve。

patch 必須在 parent 和 spawned worker 都執行(放在 __main__ guard 外面),
但只有 parent 才呼叫 vllm_main()。

環境變數:
  MODEL_PATH   模型路徑(預設 /model)
  PORT         API port(預設 8000)
  MAX_MODEL_LEN  最大 context 長度(預設 4096)
  GPU_MEM_UTIL   GPU 記憶體使用率(預設 0.90)

Docker 用法:
  docker run --gpus all \
    -e VLLM_NVFP4_GEMM_BACKEND=marlin \
    -e TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
    -e MODEL_PATH=/model \
    -v /path/to/nvfp4/model:/model \
    -v /path/to/scripts:/scripts \
    --entrypoint python3 <vllm-image> /scripts/serve_nvfp4_fp8.py
"""
import sys
import os
import logging

# patch 在 import 時執行,確保 spawned workers 也會套用
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import nvfp4_fp8_patch
nvfp4_fp8_patch.patch()

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(name)s %(levelname)s %(message)s")

    model_path = os.environ.get("MODEL_PATH", "/model")
    port = os.environ.get("PORT", "8000")
    max_model_len = os.environ.get("MAX_MODEL_LEN", "4096")
    gpu_mem = os.environ.get("GPU_MEM_UTIL", "0.90")

    sys.argv = [
        "vllm", "serve", model_path,
        "--port", port,
        "--max-model-len", max_model_len,
        "--gpu-memory-utilization", gpu_mem,
        "--trust-remote-code",
    ]

    from vllm.entrypoints.cli.main import main as vllm_main
    vllm_main()

同系列文章:Part 19 — NVFP4 是陷阱 · Part 14 — Gemma 4 全家桶 · Part 8 — vLLM vs Ollama

常見問題

NVFP4 在 DGX Spark 上真的沒救了嗎?
沒有原生硬體支援,但可以繞路。用 Triton 把 NVFP4 權重 dequant 成 FP8,再走 FP8 tensor core,從 40.8 提升到 47.6 tok/s(+17%)。離 native FP8 的 53.8 只差 12%。
這個 patch 能直接用在 vLLM 上嗎?
可以。monkey-patch vLLM 的 MarlinNvFp4LinearKernel,load time 一次轉好 FP8,runtime 走 torch._scaled_mm。需要設定 VLLM_NVFP4_GEMM_BACKEND=marlin。完整程式碼在文中。
為什麼不直接跑 FP8 模型就好?
如果有 FP8 版,直接跑 FP8 確實更快(53.8 tok/s)。這個 patch 的價值是:當模型只有 NVFP4 release、沒有 FP8 版本時,你不用再忍受 Marlin BF16 fallback 的 40.8 tok/s。