
DGX Spark · part 20

[Hands-On] Making NVFP4 17% Faster on GB10 with a Triton FP8 Bypass

2026-04-22 · updated 2026-05-06 · 12 min read · #nvfp4 #fp8 #triton #dgx-spark

TL;DR

Part 19 showed NVFP4 hits a Marlin BF16 fallback on GB10, capping at 40.8 tok/s. This post fights back: a Triton kernel dequants NVFP4 → FP8, then runs on FP8 tensor cores. Result: 47.6 tok/s (+17%). Only 12% behind native FP8's 53.8. Monkey-patches vLLM with no source changes. Full code included.

The Setup: Hardware Won't Help, So Software Must

Last time we discovered that GB10 (SM121) lacks the FP4 hardware instruction cvt.rn.satfinite.e2m1x2.f32. NVFP4 falls back to Marlin's BF16 dequant path and reaches 40.8 tok/s, while FP8 reaches 53.8 tok/s (32% faster).

The conclusion was "just use FP8." But that assumes an FP8 checkpoint exists.

Here's the key insight: GB10 has no FP4 tensor cores, but it does have FP8 tensor cores. Can we dequant NVFP4 weights to FP8 and use the FP8 hardware instead?

Yes.

Note (2026-04-30): The "GB10 has no FP4 tensor cores" premise turns out to be wrong. NVIDIA later clarified on the developer forum that GB10 does have hardware NVFP4 support; it just requires the sm_121a target instead of sm_121. See the Part 19 update. The Triton patch below is still useful in practice, though: vLLM 0.19's default Marlin BF16 fallback (40.8 tok/s) was the real-world ceiling at the time, and routing through FP8 tensor cores still beats it by 17%. Going through native NVFP4 on sm_121a requires vLLM PR #40082 plus a 4-layer software patch stack; see Part 25.


What Marlin Wastes

vLLM's NVFP4 path on SM121:

NVFP4 weights → Marlin dequant → BF16 → BF16 GEMM

Marlin targets SM75+ compatibility, so it dequants to BF16. But on GB10, FP8 tensor cores are 1.6–4x faster than BF16 for the same matrix sizes. Marlin leaves that performance on the table.

Our path:

NVFP4 weights → Triton dequant → FP8 → FP8 tensor core GEMM

The NVFP4 Dequant Formula

NVFP4 (compressed-tensors format) stores three tensors per layer:

| Tensor | Type | Shape | Role |
|---|---|---|---|
| weight_packed | uint8 | (N, K/2) | two FP4 values per byte |
| weight_scale | float8_e4m3fn | (N, K/16) | per-group block scale (one per 16 weights) |
| weight_global_scale | float32 | (1,) | per-tensor global scale |
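Those shapes also show what the FP8 bypass costs in memory: it replaces the packed nibbles and block scales with one byte per weight. A back-of-the-envelope helper (hypothetical, and it ignores the 4-byte global scale and allocator overhead) makes the ratio concrete:

```python
def linear_weight_bytes(n: int, k: int, group_size: int = 16) -> dict:
    """Approximate weight storage for one (N, K) linear layer, in bytes."""
    nvfp4 = n * (k // 2) + n * (k // group_size)  # packed nibbles + one E4M3 scale per group
    fp8 = n * k                                    # one E4M3 byte per weight
    return {"nvfp4": nvfp4, "fp8": fp8, "ratio": fp8 / nvfp4}


if __name__ == "__main__":
    print(linear_weight_bytes(4096, 4096))
```

For a 4096 × 4096 projection that is 9.0 MiB of NVFP4 versus 16 MiB of FP8, a 16/9 ≈ 1.78x increase, so freeing the FP4 originals after conversion matters.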

FP4 E2M1 has 16 bit patterns, which encode 15 distinct values (+0 and -0 both exist):

0, ±0.5, ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0
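Those values are not arbitrary: an E2M1 code is one sign bit, two exponent bits (bias 1) and one mantissa bit, and the table falls out of the usual floating-point rule. A pure-Python sketch (`decode_e2m1` is our own name, not a vLLM or Triton function) that rebuilds it from the bit fields:

```python
def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit FP4 E2M1 code laid out as [sign | exp1 exp0 | mantissa]."""
    sign = -1.0 if (nibble >> 3) & 1 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 0b1
    if exp == 0:
        # Subnormal: no implicit leading 1, scale 2^(1 - bias) = 2^0
        mag = man * 0.5
    else:
        # Normal: implicit leading 1, exponent bias 1
        mag = (1.0 + man * 0.5) * 2.0 ** (exp - 1)
    return sign * mag


if __name__ == "__main__":
    print([decode_e2m1(code) for code in range(8)])  # → [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```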

The dequant formula:

value = fp4_lut[nibble] × block_scale ÷ global_scale

Gotcha: global_scale stores the quantization scale (not the dequant scale), so you divide by it. I got this wrong on the first attempt — dequanted values hit ±100 million, got clamped to FP8's ±448 range, and produced garbage.
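To make the division concrete, here is a pure-Python reference of the same formula for one row, a sketch that could double as a correctness oracle for the Triton kernel. It assumes the layout in the table above (low nibble holds the even column, high nibble the odd one, one block scale per 16 columns) and takes block scales as already-decoded floats; `dequant_row` and the example numbers are illustrative, not taken from a real checkpoint.

```python
# Magnitudes for E2M1 codes 0-7; bit 3 of the nibble is the sign
FP4_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def fp4_value(nibble: int) -> float:
    mag = FP4_LUT[nibble & 0x7]
    return -mag if nibble & 0x8 else mag


def dequant_row(packed: bytes, block_scales: list, global_scale: float,
                group_size: int = 16) -> list:
    """value = fp4_lut[nibble] * block_scale / global_scale, low nibble first."""
    out = []
    for i, byte in enumerate(packed):
        for col, nibble in ((2 * i, byte & 0x0F), (2 * i + 1, byte >> 4)):
            scale = block_scales[col // group_size]
            # global_scale is the *quantization* scale, so divide by it
            out.append(fp4_value(nibble) * scale / global_scale)
    return out


if __name__ == "__main__":
    print(dequant_row(bytes([0x2A] * 8), [2.0], 4.0)[:4])  # → [-0.5, 0.5, -0.5, 0.5]
    # Assuming the common convention global_scale = 448 * 6 / amax (here amax = 0.25),
    # the largest code with the largest block scale lands back on amax.
    print(dequant_row(bytes([0x77]), [448.0], 448.0 * 6.0 / 0.25))  # → [0.25, 0.25]
```

Multiplying by global_scale instead would push that last example to roughly 2.9 × 10^7, far outside FP8's ±448, which is the failure described above.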

The Triton Kernel

FP4 E2M1 → FP8 E4M3 is a table lookup plus a scale multiply, an embarrassingly parallel job that maps well onto the GPU. A condensed version of the kernel:

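import triton
import triton.language as tl


@triton.jit
def fp4_lut_decode(nibble):
    # Bit 3 is the sign; bits 0-2 index the E2M1 magnitude table
    mag = nibble & 0x07
    val = tl.where(mag == 0, 0.0, tl.where(mag == 1, 0.5,
          tl.where(mag == 2, 1.0, tl.where(mag == 3, 1.5,
          tl.where(mag == 4, 2.0, tl.where(mag == 5, 3.0,
          tl.where(mag == 6, 4.0, 6.0)))))))
    return tl.where(((nibble >> 3) & 1) == 1, -val, val)


@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
    packed_ptr, scale_ptr, global_scale_ptr, out_ptr,
    K, K_packed, n_scale_cols, group_size, BLOCK_K: tl.constexpr,
):
    row = tl.program_id(0)
    col_offsets = tl.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)
    mask = col_offsets < K_packed

    global_scale_inv = 1.0 / tl.load(global_scale_ptr)
    packed = tl.load(packed_ptr + row * K_packed + col_offsets, mask=mask, other=0)

    # Unpack two FP4 values per byte: low nibble = even column, high = odd
    low_val = fp4_lut_decode(packed & 0x0F)
    high_val = fp4_lut_decode((packed >> 4) & 0x0F)

    # Both nibbles of a byte fall in the same 16-wide group, so they share a scale
    col = col_offsets * 2
    block_scale = tl.load(scale_ptr + row * n_scale_cols + col // group_size,
                          mask=mask, other=0.0).to(tl.float32)

    # Apply scales (divide by global scale) and clamp to the FP8 E4M3 range
    low_val = tl.minimum(tl.maximum(low_val * block_scale * global_scale_inv, -448.0), 448.0)
    high_val = tl.minimum(tl.maximum(high_val * block_scale * global_scale_inv, -448.0), 448.0)

    # Store as FP8 E4M3, interleaving the two nibbles back into the (N, K) row
    tl.store(out_ptr + row * K + col, low_val.to(tl.float8e4nv), mask=mask)
    tl.store(out_ptr + row * K + col + 1, high_val.to(tl.float8e4nv), mask=mask)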

The full kernel is ~80 lines of Triton. No CUDA, no CUTLASS templates.

Micro-Benchmarks

Tested with real Qwen 3.6-35B-A3B NVFP4 weights on GB10.

Dequant Speed

| Method | Time | Speedup |
|---|---|---|
| Python FP4→BF16 (simulating Marlin) | 0.226 ms | baseline |
| Triton FP4→FP8 | 0.010 ms | 23x |

Single-Layer End-to-End (dequant + GEMM, batch=32)

| Path | Time | Speedup |
|---|---|---|
| A: FP4→BF16 + BF16 GEMM | 0.229 ms | baseline |
| B: Triton FP4→FP8 + FP8 GEMM | 0.060 ms | 3.8x |
| C: Pure FP8 GEMM (ceiling) | 0.029 ms | 7.9x |

The Triton dequant itself takes 0.010 ms, essentially free, and in the vLLM integration it runs only once, at load time. Most of the remaining gap to the ceiling is activation FP8 quantization overhead.

Integrating with vLLM

No source modifications. A monkey-patch replaces two methods on MarlinNvFp4LinearKernel:

process_weights_after_loading: Skip Marlin's repack. Instead, run the Triton kernel once to convert all FP4 weights to FP8. Store the result as layer._fp8_weight. Free the original FP4 tensors to reclaim VRAM.

apply_weights: Quantize activations to FP8, then call torch._scaled_mm for the FP8 GEMM. No Marlin involved.

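Launch: Set VLLM_NVFP4_GEMM_BACKEND=marlin to force vLLM to select the Marlin kernel (which our patch intercepts), and set TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas so Triton can compile for SM121. Use a wrapper script to apply the patch before model loading, including in vLLM's spawned worker processes.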
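The activation step in apply_weights is per-tensor dynamic quantization: one scale per call, chosen so the largest |x| lands on E4M3's maximum of 448. The pure-Python sketch below mirrors that arithmetic; `round_to_e4m3` is our own approximation of E4M3 (fn variant) round-to-nearest-even, written for illustration and not a PyTorch API.

```python
import math

E4M3_MAX = 448.0


def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 (fn) value: 3 mantissa bits, min normal 2^-6."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)               # saturate, like the .clamp() in the patch
    e = max(math.floor(math.log2(a)), -6)   # below 2^-6, keep the subnormal spacing
    step = 2.0 ** (e - 3)                   # spacing of representable values here
    q = min(round(a / step) * step, E4M3_MAX)  # Python's round() is half-to-even
    return math.copysign(q, x)


def quantize_per_tensor(xs: list) -> tuple:
    """Return (fp8 values, scale) so that x ≈ fp8 * scale, as in apply_weights."""
    amax = max(abs(v) for v in xs)
    scale = max(amax / E4M3_MAX, 1e-12)     # avoid dividing by zero on all-zero input
    return [round_to_e4m3(v / scale) for v in xs], scale


if __name__ == "__main__":
    acts = [0.02, -1.3, 7.5, 0.0004]
    q, s = quantize_per_tensor(acts)
    print([v * s for v in q])  # dequantized approximation of acts
```

The single scale is what makes this cheap enough to fuse, but it also means one outlier activation coarsens the grid for every other value in the call.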

Two pitfalls we hit:

  • _scaled_mm matrix layout: The B matrix must be column-major. Store the weight as a contiguous (N, K) tensor; at runtime, .t() returns a column-major (K, N) view without copying.
  • torch.compile compatibility: Use pure torch ops for activation quantization (not vLLM custom ops) so the inductor can trace and fuse the graph.
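Why does `.t()` satisfy the column-major requirement? In PyTorch a transpose is a view that swaps shape and strides without copying data, so a contiguous (N, K) tensor with strides (K, 1) becomes a (K, N) view with strides (1, K), which is column-major. The stdlib-only sketch below models that stride bookkeeping with hypothetical helper names; on a real tensor, `w.t().stride()` shows the same thing.

```python
def contiguous_strides(shape: tuple) -> tuple:
    """Row-major (C-contiguous) strides, counted in elements."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


def transpose_2d(shape: tuple, strides: tuple) -> tuple:
    """A .t()-style view: swap both shape and strides, no data moved."""
    return shape[::-1], strides[::-1]


def is_column_major(shape: tuple, strides: tuple) -> bool:
    rows, _ = shape
    return strides == (1, rows)


if __name__ == "__main__":
    N, K = 4096, 2048
    w_shape, w_strides = (N, K), contiguous_strides((N, K))  # strides (2048, 1)
    b_shape, b_strides = transpose_2d(w_shape, w_strides)    # (2048, 4096), (1, 2048)
    print(b_shape, b_strides, is_column_major(b_shape, b_strides))
```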

Result: 47.6 tok/s

Qwen 3.6-35B-A3B NVFP4, DGX Spark, vLLM 0.19.1, driver 580.142.

Note (2026-05-06): All tok/s numbers are measured on Qwen 3.6-35B-A3B. A later profile showed 60% of single-stream time is spent in non-quantized BF16 GDN linear projections, so the +17% speedup (40.8 → 47.6) is reproducible on Qwen 3.6 but the dominant cause may be launch-overhead reduction rather than the "FP8 tensor core beats BF16 GEMM" mechanism reasoned about in this article. Validating the patch's real effect on pure-transformer models requires a retest with non-hybrid architectures.

| Approach | tok/s | vs Marlin |
|---|---|---|
| Marlin BF16 fallback (stock) | 40.8 | baseline |
| FlashInfer CUTLASS (vLLM 0.19 default) | 42.5 | +4.2% |
| Triton FP8 patch | 47.6 | +16.7% |
| Native FP8 (FP8 checkpoint) | 53.8 | +31.9% |

NVFP4 goes from 40.8 → 47.6 tok/s, a 17% speedup. Native FP8's lead shrinks from 32% to 13%; measured the other way, the bypass now trails native FP8 by 12% instead of 24%.
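Percentages change with the denominator, which is how the same pair of throughputs can read as "32% faster" or "24% slower". A quick check using the numbers from the table:

```python
TOKS = {"marlin": 40.8, "flashinfer": 42.5, "triton_fp8": 47.6, "native_fp8": 53.8}


def pct_faster(fast: float, slow: float) -> float:
    """How much faster `fast` is, relative to `slow`."""
    return (fast / slow - 1.0) * 100.0


def pct_slower(slow: float, fast: float) -> float:
    """How much slower `slow` is, relative to `fast`."""
    return (1.0 - slow / fast) * 100.0


if __name__ == "__main__":
    for name in ("marlin", "triton_fp8"):
        print(name,
              f"FP8 faster by {pct_faster(TOKS['native_fp8'], TOKS[name]):.1f}%,",
              f"{name} slower by {pct_slower(TOKS[name], TOKS['native_fp8']):.1f}%")
```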

Where the Last 12% Lives

Three sources:

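  1. Dequant precision loss: FP4 → FP8 conversion is less precise than quantizing directly from BF16/FP32 to FP8. The FP8 checkpoint was quantized once from the original weights; our path re-rounds values that have already been through FP4. This shows up as an accuracy cost rather than a throughput cost.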

  2. Activation quantization overhead: Every forward pass quantizes activations from BF16 to FP8. torch.compile fuses this, but it's not zero-cost.

  3. No fused kernel: A CUTLASS-level fused dequant+GEMM would dequant in shared memory while computing — eliminating the intermediate FP8 tensor entirely. That's a 1–2 week engineering effort requiring CUTLASS 3.x template metaprogramming.
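The fused idea in point 3 can be sketched on the CPU for the matrix-vector case that single-stream decoding hits: walk K one tile at a time, decode only that tile of FP4 weights into a small scratch buffer (standing in for shared memory), and accumulate immediately, so a full dequantized weight matrix never exists. This is a toy model of the dataflow under our own simplifications (unpacked codes, decoded floats instead of FP8, hypothetical function names), not CUTLASS code.

```python
FP4_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def fp4_value(code: int) -> float:
    """Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index FP4_LUT."""
    mag = FP4_LUT[code & 0x7]
    return -mag if code & 0x8 else mag


def fused_dequant_matvec(x, codes, block_scales, global_scale, group_size=16):
    """y[n] = sum_k W[n][k] * x[k], decoding W one K-tile at a time.

    codes[n][k] holds 4-bit E2M1 codes (unpacked for readability);
    block_scales[n][g] is the block scale of group g in row n.
    """
    y = [0.0] * len(codes)
    for k0 in range(0, len(x), group_size):       # one tile = one scale group
        x_tile = x[k0:k0 + group_size]
        for n, row in enumerate(codes):
            s = block_scales[n][k0 // group_size] / global_scale
            # Scratch tile standing in for shared memory: group_size weights only
            w_tile = [fp4_value(c) * s for c in row[k0:k0 + group_size]]
            y[n] += sum(xv * wv for xv, wv in zip(x_tile, w_tile))
    return y


def unfused_matvec(x, codes, block_scales, global_scale, group_size=16):
    """Reference path: materialize the whole dequantized W, then multiply."""
    w = [[fp4_value(c) * block_scales[n][k // group_size] / global_scale
          for k, c in enumerate(row)] for n, row in enumerate(codes)]
    return [sum(xv * wv for xv, wv in zip(x, row)) for row in w]
```

Both functions return the same vector up to floating-point rounding; the fused one simply never holds more than group_size decoded weights at a time.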

When to Use This

Use it: Model only ships NVFP4, no FP8 checkpoint exists, you're on DGX Spark and want more than 40.8 tok/s.

Skip it: an FP8 checkpoint is available. Native FP8 runs at 53.8 tok/s, faster than the bypass.

Conclusion

Part 19 said "NVFP4 is a trap on GB10, use FP8." Part 20 amends that: NVFP4 is recoverable on GB10, with a ceiling.

A ~80-line Triton kernel plus a monkey-patch, built in an afternoon, delivers 17% more throughput. No vLLM source changes, no CUDA, no waiting for upstream fixes.

If you're running NVFP4 models on DGX Spark, the patch works today. If you want to contribute to vLLM, here's a clear direction: replace the BF16 Marlin fallback with FP8 tensor cores on SM121.

Full Code

Two files: nvfp4_fp8_patch.py is the core patch, and serve_nvfp4_fp8.py is the launch wrapper.

nvfp4_fp8_patch.py

"""
Monkey-patch vLLM's NVFP4 Marlin fallback to use Triton FP4→FP8 dequant + FP8 GEMM.
For NVIDIA GB10 (SM121) which has FP8 tensor cores but no FP4 tensor cores.

Intercepts two methods:
  - process_weights_after_loading: skip Marlin repack, convert to FP8 via Triton
  - apply_weights: FP8 GEMM via torch._scaled_mm

Prerequisites:
  - vLLM 0.19+
  - VLLM_NVFP4_GEMM_BACKEND=marlin (force Marlin so our patch intercepts it)
  - TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas (Triton compilation on SM121)
"""

import torch
import triton
import triton.language as tl
import logging
import sys

logger = logging.getLogger("nvfp4_fp8_patch")


# ============================================================
# Triton kernel: NVFP4 packed uint8 → FP8 E4M3
#
# Each byte holds 2 FP4 E2M1 values (low 4 bits + high 4 bits).
# FP4 E2M1 has only 16 possible values: 0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6
# We decode via LUT, multiply by block scale, divide by global scale.
#
# Dequant formula: value = fp4_lut[nibble] × block_scale ÷ global_scale
# Note: global_scale stores the quantization scale — divide, don't multiply!
# ============================================================

@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
    packed_ptr,        # input: uint8 packed weights (N, K/2)
    scale_ptr,         # input: float8_e4m3fn block scales (N, K/group_size)
    global_scale_ptr,  # input: float32 global scale (1,)
    out_ptr,           # output: float8_e4m3fn (N, K)
    K: tl.constexpr,
    K_packed: tl.constexpr,     # K // 2
    n_scale_cols: tl.constexpr, # K // group_size
    group_size: tl.constexpr,   # typically 16
    BLOCK_K: tl.constexpr,      # tile size
):
    row = tl.program_id(0)
    col_block = tl.program_id(1)
    col_offsets = col_block * BLOCK_K + tl.arange(0, BLOCK_K)
    mask = col_offsets < K_packed

    # global_scale stores quantize scale; invert for dequant
    global_scale_inv = 1.0 / tl.load(global_scale_ptr)

    # Load packed bytes
    packed = tl.load(packed_ptr + row * K_packed + col_offsets, mask=mask, other=0).to(tl.uint8)

    # Split low and high nibbles
    low_nibble = packed & 0x0F
    high_nibble = (packed >> 4) & 0x0F

    # Separate sign bit (bit 3) and magnitude bits (bits 0-2)
    low_sign = ((low_nibble >> 3) & 1).to(tl.float32)
    low_mag = (low_nibble & 0x07).to(tl.uint8)
    high_sign = ((high_nibble >> 3) & 1).to(tl.float32)
    high_mag = (high_nibble & 0x07).to(tl.uint8)

    # FP4 E2M1 LUT: magnitude index → float value
    # 0→0, 1→0.5, 2→1.0, 3→1.5, 4→2.0, 5→3.0, 6→4.0, 7→6.0
    low_val = tl.where(low_mag == 0, 0.0,
              tl.where(low_mag == 1, 0.5,
              tl.where(low_mag == 2, 1.0,
              tl.where(low_mag == 3, 1.5,
              tl.where(low_mag == 4, 2.0,
              tl.where(low_mag == 5, 3.0,
              tl.where(low_mag == 6, 4.0, 6.0)))))))
    low_val = tl.where(low_sign > 0.5, -low_val, low_val)

    high_val = tl.where(high_mag == 0, 0.0,
               tl.where(high_mag == 1, 0.5,
               tl.where(high_mag == 2, 1.0,
               tl.where(high_mag == 3, 1.5,
               tl.where(high_mag == 4, 2.0,
               tl.where(high_mag == 5, 3.0,
               tl.where(high_mag == 6, 4.0, 6.0)))))))
    high_val = tl.where(high_sign > 0.5, -high_val, high_val)

    # Compute block scale indices
    actual_col_low = col_offsets * 2
    actual_col_high = col_offsets * 2 + 1

    block_scale_low = tl.load(
        scale_ptr + row * n_scale_cols + actual_col_low // group_size,
        mask=actual_col_low < K, other=0.0).to(tl.float32)
    block_scale_high = tl.load(
        scale_ptr + row * n_scale_cols + actual_col_high // group_size,
        mask=actual_col_high < K, other=0.0).to(tl.float32)

    # Dequant + clamp to FP8 E4M3 range (±448)
    low_val = tl.minimum(tl.maximum(low_val * block_scale_low * global_scale_inv, -448.0), 448.0)
    high_val = tl.minimum(tl.maximum(high_val * block_scale_high * global_scale_inv, -448.0), 448.0)

    # Write FP8 output
    tl.store(out_ptr + row * K + actual_col_low, low_val.to(tl.float8e4nv), mask=actual_col_low < K)
    tl.store(out_ptr + row * K + actual_col_high, high_val.to(tl.float8e4nv), mask=actual_col_high < K)


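def nvfp4_to_fp8(weight_packed, weight_scale, weight_global_scale):
    """Convert NVFP4 packed weights to FP8 E4M3 via Triton kernel."""
    # The kernel indexes rows as row * row_length, so inputs must be contiguous
    weight_packed = weight_packed.contiguous()
    weight_scale = weight_scale.contiguous()
    N, K_packed = weight_packed.shape
    K = K_packed * 2
    n_scale_cols = weight_scale.shape[1]
    group_size = K // n_scale_cols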

    out = torch.empty(N, K, device=weight_packed.device, dtype=torch.float8_e4m3fn)
    BLOCK_K = 256
    grid = (N, triton.cdiv(K_packed, BLOCK_K))

    _nvfp4_dequant_to_fp8_kernel[grid](
        weight_packed, weight_scale, weight_global_scale, out,
        K=K, K_packed=K_packed,
        n_scale_cols=n_scale_cols, group_size=group_size,
        BLOCK_K=BLOCK_K,
    )
    return out


# ============================================================
# Replacement methods for vLLM's MarlinNvFp4LinearKernel
# ============================================================

def _patched_process_weights(self, layer: torch.nn.Module) -> None:
    """Skip Marlin repack. Convert NVFP4 → FP8 once at load time.

    Frees original FP4 weights after conversion to save VRAM.
    Stores FP8 weights in layer._fp8_weight for runtime use.
    """
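    # Announce once per process (a stdlib logging.Logger has no warning_once)
    if not getattr(_patched_process_weights, "_announced", False):
        _patched_process_weights._announced = True
        logger.warning(
            "[nvfp4_fp8_patch] SM121 detected. Skipping Marlin repack, "
            "will use Triton FP4→FP8 dequant + FP8 tensor core GEMM."
        )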

    weight_packed = layer.weight.data           # uint8 (N, K/2)
    weight_scale = layer.weight_scale.data      # float8_e4m3fn (N, K/group)
    weight_global_scale = layer.weight_global_scale.data  # float32

    # One-time dequant: NVFP4 → FP8
    w_fp8 = nvfp4_to_fp8(weight_packed, weight_scale, weight_global_scale)

    # Store as (N, K) contiguous. At runtime, .t() produces a col-major
    # (K, N) view — required by torch._scaled_mm for the B matrix.
    layer._fp8_weight = torch.nn.Parameter(w_fp8.contiguous(), requires_grad=False)
    layer._fp8_w_scale = torch.nn.Parameter(
        torch.ones(1, 1, device=w_fp8.device, dtype=torch.float32),
        requires_grad=False,
    )

    # Free original FP4 weights to reclaim VRAM
    layer.weight = torch.nn.Parameter(
        torch.empty(0, dtype=torch.uint8, device=weight_packed.device),
        requires_grad=False,
    )
    layer.weight_scale = torch.nn.Parameter(
        torch.empty(0, dtype=torch.float8_e4m3fn, device=weight_scale.device),
        requires_grad=False,
    )

    logger.info(
        f"[nvfp4_fp8_patch] Pre-converted layer to FP8: "
        f"{weight_packed.shape} → {w_fp8.shape}, "
        f"weight range: [{w_fp8.float().min():.4f}, {w_fp8.float().max():.4f}]"
    )


def _patched_apply_weights(self, layer, x, bias=None):
    """FP8 tensor core GEMM using pre-converted weights.

    Activation quantization uses pure torch ops so torch.compile/inductor
    can trace and fuse the graph (vLLM custom ops break inductor here).
    """
    orig_shape = x.shape
    x_2d = x.reshape(-1, orig_shape[-1])

    # Activation → FP8 (per-tensor dynamic quantization)
    amax = x_2d.float().abs().amax()
    a_scale = (amax / 448.0).clamp(min=1e-12)
    x_fp8 = (x_2d.float() / a_scale).clamp(-448, 448).to(torch.float8_e4m3fn)
    a_scale = a_scale.reshape(1, 1)

    # FP8 GEMM: (M, K) @ (K, N) → (M, N)
    out = torch._scaled_mm(
        x_fp8,
        layer._fp8_weight.t(),  # (N,K).t() = col-major (K,N)
        scale_a=a_scale,
        scale_b=layer._fp8_w_scale,
        out_dtype=x.dtype,
    )

    if bias is not None:
        out = out + bias

    return out.reshape(*orig_shape[:-1], out.shape[-1])


# ============================================================
# Patch installation: detect SM121, replace Marlin methods
# ============================================================

_patched = False

def patch():
    """Install the monkey-patch. Only activates on SM121 (GB10)."""
    global _patched
    if _patched:
        return True

    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        if cap != (12, 1):
            logger.info(f"SM{cap[0]}{cap[1]} detected, FP8 patch not needed")
            return False

    try:
        from vllm.model_executor.kernels.linear.nvfp4.marlin import (
            MarlinNvFp4LinearKernel,
        )
        MarlinNvFp4LinearKernel.process_weights_after_loading = _patched_process_weights
        MarlinNvFp4LinearKernel.apply_weights = _patched_apply_weights

        _patched = True
        logger.warning(
            "[nvfp4_fp8_patch] Installed! MarlinNvFp4LinearKernel now uses "
            "pre-converted FP8 weights + FP8 tensor core GEMM."
        )
        return True
    except ImportError as e:
        logger.error(f"Failed to patch: {e}")
        return False

serve_nvfp4_fp8.py

"""
Launch wrapper: install FP8 patch, then start vllm serve.

The patch runs at import time (outside __main__ guard) so spawned workers
also get patched. Only the parent process calls vllm_main().

Environment variables:
  MODEL_PATH     Model path (default: /model)
  PORT           API port (default: 8000)
  MAX_MODEL_LEN  Max context length (default: 4096)
  GPU_MEM_UTIL   GPU memory utilization (default: 0.90)

Docker usage:
  docker run --gpus all \
    -e VLLM_NVFP4_GEMM_BACKEND=marlin \
    -e TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
    -e MODEL_PATH=/model \
    -v /path/to/nvfp4/model:/model \
    -v /path/to/scripts:/scripts \
    --entrypoint python3 <vllm-image> /scripts/serve_nvfp4_fp8.py
"""
import sys
import os
import logging

# Patch at import time so spawned workers inherit it
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import nvfp4_fp8_patch
nvfp4_fp8_patch.patch()

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(name)s %(levelname)s %(message)s")

    model_path = os.environ.get("MODEL_PATH", "/model")
    port = os.environ.get("PORT", "8000")
    max_model_len = os.environ.get("MAX_MODEL_LEN", "4096")
    gpu_mem = os.environ.get("GPU_MEM_UTIL", "0.90")

    sys.argv = [
        "vllm", "serve", model_path,
        "--port", port,
        "--max-model-len", max_model_len,
        "--gpu-memory-utilization", gpu_mem,
        "--trust-remote-code",
    ]

    from vllm.entrypoints.cli.main import main as vllm_main
    vllm_main()

Series: Part 19 — NVFP4 Is a Trap · Part 14 — Gemma 4 Complete Guide · Part 8 — vLLM vs Ollama

FAQ

Is NVFP4 truly hopeless on DGX Spark?
No. GB10 does have NVFP4 hardware support (it needs the sm_121a target), but vLLM 0.19's default path falls back to Marlin BF16. As a workaround, a Triton kernel dequants NVFP4 weights to FP8 and runs them on FP8 tensor cores: 40.8 → 47.6 tok/s (+17%), only 12% behind native FP8.
Can I use this patch with vLLM today?
Yes. It monkey-patches vLLM's MarlinNvFp4LinearKernel, converts weights to FP8 at load time, and runs torch._scaled_mm at inference. Set VLLM_NVFP4_GEMM_BACKEND=marlin. Full code in the article.
Why not just run the FP8 model directly?
If an FP8 checkpoint exists, use it — 53.8 tok/s beats everything. This patch matters when a model ships only as NVFP4 and no FP8 version exists.