~/blog/dgx-spark-nvfp4-fp8-triton-patch

DGX Spark · part 20

[Hands-On] Making NVFP4 17% Faster on GB10 with a Triton FP8 Bypass

cat --toc

TL;DR

Part 19 showed NVFP4 hits a Marlin BF16 fallback on GB10, capping at 40.8 tok/s. This post fights back: a Triton kernel dequants NVFP4 → FP8, then runs on FP8 tensor cores. Result: 47.6 tok/s (+17%). Only 12% behind native FP8's 53.8. Monkey-patches vLLM with no source changes. Full code included.

The Setup: Hardware Won't Help, So Software Must

Last time we discovered that GB10 (SM121) lacks the FP4 hardware instruction cvt.rn.satfinite.e2m1x2.f32. NVFP4 falls back to Marlin's BF16 dequant path — 40.8 tok/s, 32% slower than FP8.

The conclusion was "just use FP8." But that assumes an FP8 checkpoint exists.

Here's the key insight: GB10 has no FP4 tensor cores, but it does have FP8 tensor cores. Can we dequant NVFP4 weights to FP8 and use the FP8 hardware instead?

Yes.


What Marlin Wastes

vLLM's NVFP4 path on SM121:

NVFP4 weights → Marlin dequant → BF16 → BF16 GEMM

Marlin targets SM75+ compatibility, so it dequants to BF16. But on GB10, FP8 tensor cores are 1.6–4x faster than BF16 for the same matrix sizes. Marlin leaves that performance on the table.

Our path:

NVFP4 weights → Triton dequant → FP8 → FP8 tensor core GEMM

The NVFP4 Dequant Formula

NVFP4 (compressed-tensors format) stores three tensors per layer:

| Tensor | Type | Shape |
|---|---|---|
| weight_packed | uint8 | (N, K/2) — two FP4 values per byte |
| weight_scale | float8_e4m3fn | (N, K/16) — per-group block scale |
| weight_global_scale | float32 | (1,) — per-tensor global scale |

FP4 E2M1 has exactly 16 possible values:

0, ±0.5, ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0
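The 4-bit encoding is one sign bit plus a 3-bit index into those eight magnitudes. A minimal pure-Python decoder makes the lookup explicit (a reference for the kernel's LUT, not the kernel itself):

```python
# FP4 E2M1: bit 3 = sign, bits 0-2 = index into the 8 magnitudes
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_fp4(nibble: int) -> float:
    """Decode one FP4 E2M1 nibble (0..15) to a float."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude

# All 16 bit patterns; +0 and -0 collapse to one value
values = sorted({decode_fp4(n) for n in range(16)})
```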

The dequant formula:

value = fp4_lut[nibble] × block_scale ÷ global_scale

Gotcha: global_scale stores the quantization scale (not the dequant scale), so you divide by it. I got this wrong on the first attempt — dequanted values hit ±100 million, got clamped to FP8's ±448 range, and produced garbage.
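A worked example shows why the direction matters. The scales here are made-up numbers, not values from the model:

```python
fp4_lut = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
           -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

nibble = 0x5          # encodes 3.0
block_scale = 1.25    # hypothetical per-group block scale
global_scale = 400.0  # hypothetical quantization scale

# Correct: global_scale is the quantization scale, so dequant divides by it
value = fp4_lut[nibble] * block_scale / global_scale

# Wrong (the first attempt): multiplying inflates values by ~global_scale^2,
# which then clamp to FP8's ±448 range and produce garbage
wrong = fp4_lut[nibble] * block_scale * global_scale
```

Here `value` is 0.009375, while the wrong path yields 1500.0, well past FP8's ±448.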

The Triton Kernel

FP4 E2M1 → FP8 E4M3 is a lookup table + scale multiply — ideal for GPU parallelism:

@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
    packed_ptr, scale_ptr, global_scale_ptr, out_ptr,
    K, K_packed, n_scale_cols, group_size, BLOCK_K: tl.constexpr,
):
    row = tl.program_id(0)
    col_offsets = tl.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)
    mask = col_offsets < K_packed

    # global_scale holds the quantization scale, so dequant divides by it
    global_scale_inv = 1.0 / tl.load(global_scale_ptr)
    packed = tl.load(packed_ptr + row * K_packed + col_offsets, mask=mask, other=0)

    # Unpack two FP4 values per byte
    low_nibble = packed & 0x0F
    high_nibble = (packed >> 4) & 0x0F

    # LUT decode + sign handling (FP4 E2M1 -> float32)
    low_val = fp4_lut_decode(low_nibble)
    high_val = fp4_lut_decode(high_nibble)

    # One block scale per group_size (16) unpacked values; both nibbles
    # of a byte always land in the same group
    scale_col = (col_offsets * 2) // group_size
    block_scale = tl.load(scale_ptr + row * n_scale_cols + scale_col,
                          mask=mask, other=0).to(tl.float32)

    # Apply per-group block scale and per-tensor global scale
    low_val = low_val * block_scale * global_scale_inv
    high_val = high_val * block_scale * global_scale_inv

    # Store as FP8 E4M3: packed byte i produces output columns 2i and 2i+1
    out_offsets = out_ptr + row * K + col_offsets * 2
    tl.store(out_offsets, low_val.to(tl.float8e4nv), mask=mask)
    tl.store(out_offsets + 1, high_val.to(tl.float8e4nv), mask=mask)

The full kernel is ~80 lines of Triton. No CUDA, no CUTLASS templates.

Micro-Benchmarks

Tested with real Qwen 3.6-35B-A3B NVFP4 weights on GB10.

Dequant Speed

| Method | Time | Speedup |
|---|---|---|
| Python FP4→BF16 (simulating Marlin) | 0.226 ms | baseline |
| Triton FP4→FP8 | 0.010 ms | 23x |

Single-Layer End-to-End (dequant + GEMM, batch=32)

| Path | Time | Speedup |
|---|---|---|
| A: FP4→BF16 + BF16 GEMM | 0.229 ms | baseline |
| B: Triton FP4→FP8 + FP8 GEMM | 0.060 ms | 3.8x |
| C: Pure FP8 GEMM (ceiling) | 0.029 ms | 7.9x |

The Triton dequant itself takes 0.010 ms, essentially free. The remaining gap to the ceiling is the overhead of quantizing activations to FP8 on every forward pass.

Integrating with vLLM

No source modifications. A monkey-patch replaces two methods on MarlinNvFp4LinearKernel:

process_weights_after_loading: Skip Marlin's repack. Instead, run the Triton kernel once to convert all FP4 weights to FP8. Store the result as layer._fp8_weight. Free the original FP4 tensors to reclaim VRAM.

apply_weights: Quantize activations to FP8, then call torch._scaled_mm for the FP8 GEMM. No Marlin involved.
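The patch mechanics are plain attribute assignment on the class. Sketched here on a stand-in class rather than vLLM's actual MarlinNvFp4LinearKernel, with illustrative return values in place of the real work:

```python
class StandInKernel:
    """Stand-in for MarlinNvFp4LinearKernel (names are illustrative)."""
    def process_weights_after_loading(self, layer):
        return "marlin repack"
    def apply_weights(self, layer, x):
        return "marlin bf16 gemm"

def patched_process_weights_after_loading(self, layer):
    # Real patch: run the Triton dequant once, stash layer._fp8_weight,
    # free the FP4 tensors
    return "triton fp4->fp8 convert"

def patched_apply_weights(self, layer, x):
    # Real patch: quantize activations to FP8, call torch._scaled_mm
    return "fp8 scaled_mm"

# The monkey-patch: rebind the two methods on the class, no source edits
StandInKernel.process_weights_after_loading = patched_process_weights_after_loading
StandInKernel.apply_weights = patched_apply_weights

kernel = StandInKernel()
```

Every instance vLLM constructs afterward picks up the new methods, which is why patching must happen before model loading.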

Launch: Set VLLM_NVFP4_GEMM_BACKEND=marlin to force vLLM to select the Marlin kernel (which our patch intercepts). Use a wrapper script to apply the patch before model loading.
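A wrapper along these lines does both steps; `serve_patched.py` is an illustrative name for a script that imports the patch module before starting the vLLM server:

```shell
#!/usr/bin/env bash
# Force vLLM to select the Marlin NVFP4 kernel, which the patch intercepts
export VLLM_NVFP4_GEMM_BACKEND=marlin

# serve_patched.py (illustrative) imports the patch module, applying the
# monkey-patch, and then launches the vLLM server with your model
python serve_patched.py --model <your-nvfp4-model> --port 8000
```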

Two pitfalls we hit:

  • _scaled_mm matrix layout: The B matrix must be column-major. Store the weight contiguous as (N, K); calling .t() at runtime yields the column-major (K, N) view for free.
  • torch.compile compatibility: Use pure torch ops for activation quantization (not vLLM custom ops) so the inductor can trace and fuse the graph.

Result: 47.6 tok/s

Qwen 3.6-35B-A3B NVFP4, DGX Spark, vLLM 0.19.1, driver 580.142.

| Approach | tok/s | vs Marlin |
|---|---|---|
| Marlin BF16 fallback (stock) | 40.8 | baseline |
| FlashInfer CUTLASS (vLLM 0.19 default) | 42.5 | +4.2% |
| Triton FP8 patch | 47.6 | +16.7% |
| Native FP8 (FP8 checkpoint) | 53.8 | +31.9% |

NVFP4 goes from 40.8 → 47.6 tok/s, a 17% speedup. The gap to native FP8 shrinks from 32% to 12%.

Where the Last 12% Lives

Three sources:

  1. Dequant precision loss: FP4 → FP8 conversion adds a rounding step on top of the original FP4 quantization, so outputs drift slightly from a checkpoint quantized to FP8 directly (a quality cost rather than a raw speed cost, but it rules out closing the gap by cutting corners elsewhere).

  2. Activation quantization overhead: Every forward pass quantizes activations from BF16 to FP8. torch.compile fuses this, but it's not zero-cost.

  3. No fused kernel: A CUTLASS-level fused dequant+GEMM would dequant in shared memory while computing — eliminating the intermediate FP8 tensor entirely. That's a 1–2 week engineering effort requiring CUTLASS 3.x template metaprogramming.

When to Use This

Use it: Model only ships NVFP4, no FP8 checkpoint exists, you're on DGX Spark and want more than 40.8 tok/s.

Skip it: FP8 checkpoint available. Native FP8 = 53.8 tok/s, always faster than the bypass.

Conclusion

Part 19 said "NVFP4 is a trap on GB10, use FP8." Part 20 amends that: NVFP4 is recoverable on GB10, with a ceiling.

An ~80-line Triton kernel plus a monkey-patch, built in an afternoon, delivers 17% more throughput. No vLLM source changes, no CUDA, no waiting for upstream fixes.

If you're running NVFP4 models on DGX Spark, the patch works today. If you want to contribute to vLLM, here's a clear direction: replace the BF16 Marlin fallback with FP8 tensor cores on SM121.


Series: Part 19 — NVFP4 Is a Trap · Part 14 — Gemma 4 Complete Guide · Part 8 — vLLM vs Ollama

FAQ

Is NVFP4 truly hopeless on DGX Spark?
No native hardware support, but you can work around it. A Triton kernel dequants NVFP4 weights to FP8, then runs on FP8 tensor cores — 40.8 → 47.6 tok/s (+17%). Only 12% behind native FP8.
Can I use this patch with vLLM today?
Yes. It monkey-patches vLLM's MarlinNvFp4LinearKernel, converts weights to FP8 at load time, and runs torch._scaled_mm at inference. Set VLLM_NVFP4_GEMM_BACKEND=marlin. Full code in the article.
Why not just run the FP8 model directly?
If an FP8 checkpoint exists, use it — 53.8 tok/s beats everything. This patch matters when a model ships only as NVFP4 and no FP8 version exists.