DGX Spark · part 20
[Hands-On] Making NVFP4 17% Faster on GB10 with a Triton FP8 Bypass
❯ cat --toc
- The Setup: Hardware Won't Help, So Software Must
- What Marlin Wastes
- The NVFP4 Dequant Formula
- The Triton Kernel
- Micro-Benchmarks
- Dequant Speed
- Single-Layer End-to-End (dequant + GEMM, batch=32)
- Integrating with vLLM
- Result: 47.6 tok/s
- Where the Last 12% Lives
- When to Use This
- Conclusion
- Full Code
- nvfp4_fp8_patch.py
- serve_nvfp4_fp8.py
TL;DR
Part 19 showed NVFP4 hits a Marlin BF16 fallback on GB10, capping at 40.8 tok/s. This post fights back: a Triton kernel dequants NVFP4 → FP8, then runs on FP8 tensor cores. Result: 47.6 tok/s (+17%). Only 12% behind native FP8's 53.8. Monkey-patches vLLM with no source changes. Full code included.
The Setup: Hardware Won't Help, So Software Must
Last time we discovered that GB10 (SM121) lacks the FP4 hardware instruction cvt.rn.satfinite.e2m1x2.f32. NVFP4 falls back to Marlin's BF16 dequant path at 40.8 tok/s, 24% slower than native FP8's 53.8 tok/s (put the other way, FP8 is 32% faster).
The conclusion was "just use FP8." But that assumes an FP8 checkpoint exists.
Here's the key insight: GB10 has no FP4 tensor cores, but it does have FP8 tensor cores. Can we dequant NVFP4 weights to FP8 and use the FP8 hardware instead?
Yes.
Note (2026-04-30): The "GB10 has no FP4 tensor cores" premise turns out to be wrong. NVIDIA later clarified on the developer forum that GB10 does have hardware NVFP4 support; it just requires the sm_121a target instead of sm_121. See the Part 19 update. The Triton patch below is still useful in practice, though: vLLM 0.19's default Marlin BF16 fallback (40.8 tok/s) was the real-world ceiling at the time, and routing through FP8 tensor cores still beats it by 17%. Going through native NVFP4 on sm_121a requires vLLM PR #40082 plus a 4-layer software patch stack; see Part 25.
What Marlin Wastes
vLLM's NVFP4 path on SM121:
NVFP4 weights → Marlin dequant → BF16 → BF16 GEMM
Marlin targets SM75+ compatibility, so it dequants to BF16. But on GB10, FP8 tensor cores are 1.6–4x faster than BF16 for the same matrix sizes. Marlin leaves that performance on the table.
Our path:
NVFP4 weights → Triton dequant → FP8 → FP8 tensor core GEMM
The NVFP4 Dequant Formula
NVFP4 (compressed-tensors format) stores three tensors per layer:
| Tensor | Type | Shape |
|---|---|---|
| weight_packed | uint8 | (N, K/2), two FP4 values per byte |
| weight_scale | float8_e4m3fn | (N, K/16), per-group block scale |
| weight_global_scale | float32 | (1,), per-tensor global scale |
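To make the layout concrete, here is a small sketch that derives the tensor shapes and storage cost from the table above. The function name `nvfp4_layout` and the 4096×4096 layer size are illustrative assumptions, not taken from the model.

```python
def nvfp4_layout(n: int, k: int, group_size: int = 16) -> dict:
    """Shapes and byte counts for one NVFP4 linear layer (illustrative helper)."""
    assert k % 2 == 0 and k % group_size == 0
    packed_shape = (n, k // 2)           # uint8, two FP4 codes per byte
    scale_shape = (n, k // group_size)   # float8_e4m3fn, one byte per group
    packed_bytes = packed_shape[0] * packed_shape[1]
    scale_bytes = scale_shape[0] * scale_shape[1]
    total_bytes = packed_bytes + scale_bytes + 4   # plus one float32 global scale
    return {
        "packed_shape": packed_shape,
        "scale_shape": scale_shape,
        "total_bytes": total_bytes,
        "bits_per_weight": 8 * total_bytes / (n * k),
        "fp8_bytes": n * k,              # one byte per weight after FP8 conversion
    }


info = nvfp4_layout(n=4096, k=4096)      # hypothetical layer size
print(info["packed_shape"], info["scale_shape"])
print(f"{info['bits_per_weight']:.2f} bits/weight as NVFP4")
print(f"{info['fp8_bytes'] / info['total_bytes']:.2f}x larger once converted to FP8")
```

With a group size of 16, the block scales add half a bit per weight on top of the 4-bit codes. Converting to FP8 trades that compactness for one byte per weight, which is worth keeping in mind on a memory-constrained box.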
FP4 E2M1 has exactly 16 encodings, which decode to 15 distinct values because +0 and −0 coincide:
0, ±0.5, ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0
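The value list above follows directly from the bit layout: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit. A minimal sketch that rebuilds the table from the bits; `decode_e2m1` and `FP4_LUT` are names chosen for this example.

```python
def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code: sign (bit 3), exponent (bits 2-1), mantissa (bit 0)."""
    sign = -1.0 if (code >> 3) & 1 else 1.0
    exp = (code >> 1) & 0b11
    man = code & 0b1
    if exp == 0:
        magnitude = 0.5 * man                              # subnormal: 0 or 0.5
    else:
        magnitude = (1.0 + 0.5 * man) * 2.0 ** (exp - 1)   # normal, exponent bias 1
    return sign * magnitude


FP4_LUT = [decode_e2m1(code) for code in range(16)]
print(FP4_LUT[:8])   # magnitudes for codes 0..7
print(FP4_LUT[8:])   # same magnitudes with the sign bit set
```

The low three bits index the magnitude and bit 3 flips the sign, which is why a kernel can treat the decode as an 8-entry lookup plus a sign select.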
The dequant formula:
value = fp4_lut[nibble] × block_scale ÷ global_scale
Gotcha: global_scale stores the quantization scale (not the dequant scale), so you divide by it. I got this wrong on the first attempt — dequanted values hit ±100 million, got clamped to FP8's ±448 range, and produced garbage.
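A pure-Python sketch of the formula for a single packed byte shows how badly the wrong direction blows up. The scale values (96.0 and 2048.0) are made-up numbers for illustration, and `dequant_byte` is a hypothetical helper, not part of vLLM.

```python
FP4_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_nibble(nibble: int) -> float:
    """Look up the magnitude from bits 0-2, negate if bit 3 is set."""
    magnitude = FP4_MAGNITUDES[nibble & 0x07]
    return -magnitude if nibble & 0x08 else magnitude


def dequant_byte(packed: int, block_scale: float, global_scale: float):
    """Dequantize the two FP4 values in one byte (low nibble = even column)."""
    low = decode_nibble(packed & 0x0F)
    high = decode_nibble((packed >> 4) & 0x0F)
    factor = block_scale / global_scale      # divide by the global scale
    return low * factor, high * factor


packed = 0xD3                                # low nibble 0x3 -> 1.5, high nibble 0xD -> -3.0
block_scale, global_scale = 96.0, 2048.0     # illustrative values only

correct = dequant_byte(packed, block_scale, global_scale)
wrong_low = decode_nibble(packed & 0x0F) * block_scale * global_scale

print(correct)      # small weights, comfortably inside the FP8 range
print(wrong_low)    # multiplying instead overshoots +-448 by orders of magnitude
```

With these numbers the correct path yields values well below 1, while the multiply variant lands in the hundreds of thousands, which is exactly the kind of value that gets clamped to ±448 and turns the layer into noise.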
The Triton Kernel
FP4 E2M1 → FP8 E4M3 is a lookup table + scale multiply — ideal for GPU parallelism:
```python
@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
    packed_ptr, scale_ptr, global_scale_ptr, out_ptr,
    K, K_packed, n_scale_cols, group_size, BLOCK_K: tl.constexpr,
):
    row = tl.program_id(0)
    col_offsets = tl.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)
    global_scale_inv = 1.0 / tl.load(global_scale_ptr)
    packed = tl.load(packed_ptr + row * K_packed + col_offsets, ...)
    # Unpack two FP4 values per byte
    low_nibble = packed & 0x0F
    high_nibble = (packed >> 4) & 0x0F
    # LUT decode + sign handling
    low_val = fp4_lut_decode(low_nibble)
    high_val = fp4_lut_decode(high_nibble)
    # Apply scales
    low_val *= block_scale * global_scale_inv
    high_val *= block_scale * global_scale_inv
    # Store as FP8 E4M3
    tl.store(out_ptr + ..., low_val.to(tl.float8e4nv), ...)
```
The full kernel is ~80 lines of Triton. No CUDA, no CUTLASS templates.
Micro-Benchmarks
Tested with real Qwen 3.6-35B-A3B NVFP4 weights on GB10.
Dequant Speed
| Method | Time | Speedup |
|---|---|---|
| Python FP4→BF16 (simulating Marlin) | 0.226 ms | baseline |
| Triton FP4→FP8 | 0.010 ms | 23x |
Single-Layer End-to-End (dequant + GEMM, batch=32)
| Path | Time | Speedup |
|---|---|---|
| A: FP4→BF16 + BF16 GEMM | 0.229 ms | baseline |
| B: Triton FP4→FP8 + FP8 GEMM | 0.060 ms | 3.8x |
| C: Pure FP8 GEMM (ceiling) | 0.029 ms | 7.9x |
The Triton dequant itself is 0.01ms — essentially free. The gap to ceiling is activation FP8 quantization overhead.
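As a sanity check on the tables, the speedups and the gap to the ceiling can be recomputed from the measured times. The decomposition below assumes path B is simply path C plus dequant plus everything else, which is my reading of the benchmark rather than something it measures directly.

```python
# Measured single-layer times from the tables above (milliseconds)
baseline_ms = 0.229   # path A: FP4 -> BF16 + BF16 GEMM
bypass_ms = 0.060     # path B: Triton FP4 -> FP8 + FP8 GEMM
ceiling_ms = 0.029    # path C: pure FP8 GEMM
dequant_ms = 0.010    # Triton dequant alone

speedup_bypass = baseline_ms / bypass_ms
speedup_ceiling = baseline_ms / ceiling_ms

# Assumption: B = C + dequant + remainder (activation quantization and launches)
gap_ms = bypass_ms - ceiling_ms
residual_ms = gap_ms - dequant_ms

print(f"bypass speedup:  {speedup_bypass:.1f}x")
print(f"ceiling speedup: {speedup_ceiling:.1f}x")
print(f"gap to ceiling:  {gap_ms:.3f} ms, of which {residual_ms:.3f} ms is not dequant")
```

Under that assumption roughly two thirds of the gap sits outside the dequant kernel. In the vLLM integration the dequant runs once at load time, so only the remainder is paid per forward pass.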
Integrating with vLLM
No source modifications. A monkey-patch replaces two methods on MarlinNvFp4LinearKernel:
process_weights_after_loading: Skip Marlin's repack. Instead, run the Triton kernel once to convert all FP4 weights to FP8. Store the result as layer._fp8_weight. Free the original FP4 tensors to reclaim VRAM.
apply_weights: Quantize activations to FP8, then call torch._scaled_mm for the FP8 GEMM. No Marlin involved.
Launch: Set VLLM_NVFP4_GEMM_BACKEND=marlin to force vLLM to select the Marlin kernel (which our patch intercepts). Use a wrapper script to apply the patch before model loading.
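The monkey-patch mechanism itself is plain Python: assigning functions to class attributes changes behaviour for every instance, including ones created before the patch. A minimal sketch with a stand-in class; `StandInKernel` and its string results are hypothetical and only mimic the two method names used above.

```python
class StandInKernel:
    """Hypothetical stand-in for a kernel class; not the real vLLM API."""

    def process_weights_after_loading(self, layer):
        layer["format"] = "marlin-bf16"

    def apply_weights(self, layer, x):
        return f"bf16_gemm({x})"


def patched_process_weights(self, layer):
    # Replacement: pretend we converted the weights to FP8 at load time
    layer["format"] = "fp8"


def patched_apply_weights(self, layer, x):
    # Replacement: pretend we dispatch to an FP8 GEMM
    return f"fp8_gemm({x})"


def install_patch(cls):
    """Swap both methods on the class object itself."""
    cls.process_weights_after_loading = patched_process_weights
    cls.apply_weights = patched_apply_weights


existing = StandInKernel()      # instance created before the patch
install_patch(StandInKernel)

layer = {}
existing.process_weights_after_loading(layer)
result = existing.apply_weights(layer, "x")
print(layer["format"], result)
```

Because the swap happens on the class, the only ordering requirement is that the patch is installed before the methods are called, which is why the wrapper script applies it before model loading.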
Two pitfalls we hit:
- _scaled_mm matrix layout: The B matrix must be column-major. Store (N, K) contiguous; .t() at runtime gives the col-major (K, N) view.
- torch.compile compatibility: Use pure torch ops for activation quantization (not vLLM custom ops) so the inductor can trace and fuse the graph.
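The activation side is per-tensor dynamic quantization: one scale derived from the largest absolute value, chosen so that value maps to FP8 E4M3's maximum of 448. A pure-Python sketch of the scale arithmetic follows; it deliberately omits the final cast to FP8, so the rounding error of the real path is not modelled, and `quantize_per_tensor` is a name invented for this example.

```python
FP8_E4M3_MAX = 448.0


def quantize_per_tensor(values):
    """Return (scaled values, scale) so that max |value| maps to +-448."""
    amax = max(abs(v) for v in values)
    scale = max(amax / FP8_E4M3_MAX, 1e-12)    # floor avoids division by zero
    scaled = [min(max(v / scale, -FP8_E4M3_MAX), FP8_E4M3_MAX) for v in values]
    return scaled, scale


activations = [0.5, -2.0, 7.0]                 # toy activation values
scaled, scale = quantize_per_tensor(activations)
restored = [s * scale for s in scaled]         # what the GEMM's scale_a undoes

print(scale)      # amax / 448
print(scaled)     # values now span the FP8 range
print(restored)   # matches the inputs because no FP8 rounding was applied
```

The scale is passed to the GEMM alongside the quantized tensor, so the matmul output comes back in the original units. Since the weights were already dequantized to real values, their scale is simply 1.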
Result: 47.6 tok/s
Qwen 3.6-35B-A3B NVFP4, DGX Spark, vLLM 0.19.1, driver 580.142.
Note (2026-05-06): All tok/s numbers are measured on Qwen 3.6-35B-A3B. A later profile showed 60% of single-stream time is spent in non-quantized BF16 GDN linear projections, so the +17% speedup (40.8 → 47.6) is reproducible on Qwen 3.6 but the dominant cause may be launch-overhead reduction rather than the "FP8 tensor core beats BF16 GEMM" mechanism reasoned about in this article. Validating the patch's real effect on pure-transformer models requires a retest with non-hybrid architectures.
| Approach | tok/s | vs Marlin |
|---|---|---|
| Marlin BF16 fallback (stock) | 40.8 | baseline |
| FlashInfer CUTLASS (vLLM 0.19 default) | 42.5 | +4.2% |
| Triton FP8 patch | 47.6 | +16.7% |
| Native FP8 (FP8 checkpoint) | 53.8 | +31.9% |
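NVFP4 goes from 40.8 → 47.6 tok/s, a 17% speedup. Measured against native FP8's 53.8 tok/s, the deficit shrinks from 24% to 12% (the +31.9% in the table is the same starting gap expressed as FP8's lead over Marlin).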
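Percentages like these depend on which number is the denominator, so it helps to compute them both ways from the table. The variable names are mine; the three throughput figures are the measured ones.

```python
marlin, patched, native = 40.8, 47.6, 53.8     # tok/s from the table above

# Relative to the Marlin baseline ("how much faster")
patch_speedup = patched / marlin - 1
native_lead = native / marlin - 1

# Relative to native FP8 ("how far behind")
deficit_before = 1 - marlin / native
deficit_after = 1 - patched / native

print(f"patch vs Marlin:        +{patch_speedup:.1%}")
print(f"native FP8 vs Marlin:   +{native_lead:.1%}")
print(f"Marlin behind native:    {deficit_before:.1%}")
print(f"patch behind native:     {deficit_after:.1%}")
```

The "12%" used throughout this post is the last figure: how far the patched path trails native FP8, with native FP8 as the reference.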
Where the Last 12% Lives
Three sources:
- Dequant precision loss: FP4 → FP8 conversion isn't as precise as direct FP8 quantization from BF16/FP32. The original FP8 model was carefully quantized; our path adds an extra conversion step.
- Activation quantization overhead: Every forward pass quantizes activations from BF16 to FP8. torch.compile fuses this, but it's not zero-cost.
- No fused kernel: A CUTLASS-level fused dequant+GEMM would dequant in shared memory while computing, eliminating the intermediate FP8 tensor entirely. That's a 1–2 week engineering effort requiring CUTLASS 3.x template metaprogramming.
When to Use This
Use it: Model only ships NVFP4, no FP8 checkpoint exists, you're on DGX Spark and want more than 40.8 tok/s.
Skip it: FP8 checkpoint available. Native FP8 = 53.8 tok/s, always faster than the bypass.
Conclusion
Part 19 said "NVFP4 is a trap on GB10, use FP8." Part 20 amends that: NVFP4 is recoverable on GB10, with a ceiling.
An ~80-line Triton kernel plus a monkey-patch, built in an afternoon, delivers 17% more throughput. No vLLM source changes, no CUDA, no waiting for upstream fixes.
If you're running NVFP4 models on DGX Spark, the patch works today. If you want to contribute to vLLM, here's a clear direction: replace the BF16 Marlin fallback with FP8 tensor cores on SM121.
Full Code
Two files. nvfp4_fp8_patch.py is the core patch, serve_nvfp4_fp8.py is the launch wrapper.
nvfp4_fp8_patch.py
```python
"""
Monkey-patch vLLM's NVFP4 Marlin fallback to use Triton FP4→FP8 dequant + FP8 GEMM.
For NVIDIA GB10 (SM121) which has FP8 tensor cores but no FP4 tensor cores.

Intercepts two methods:
- process_weights_after_loading: skip Marlin repack, convert to FP8 via Triton
- apply_weights: FP8 GEMM via torch._scaled_mm

Prerequisites:
- vLLM 0.19+
- VLLM_NVFP4_GEMM_BACKEND=marlin (force Marlin so our patch intercepts it)
- TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas (Triton compilation on SM121)
"""
import torch
import triton
import triton.language as tl
import logging
import sys

logger = logging.getLogger("nvfp4_fp8_patch")

# ============================================================
# Triton kernel: NVFP4 packed uint8 → FP8 E4M3
#
# Each byte holds 2 FP4 E2M1 values (low 4 bits + high 4 bits).
# FP4 E2M1 has only 16 encodings: 0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6
# We decode via LUT, multiply by block scale, divide by global scale.
#
# Dequant formula: value = fp4_lut[nibble] × block_scale ÷ global_scale
# Note: global_scale stores the quantization scale; divide, don't multiply!
# ============================================================
@triton.jit
def _nvfp4_dequant_to_fp8_kernel(
    packed_ptr,                    # input: uint8 packed weights (N, K/2)
    scale_ptr,                     # input: float8_e4m3fn block scales (N, K/group_size)
    global_scale_ptr,              # input: float32 global scale (1,)
    out_ptr,                       # output: float8_e4m3fn (N, K)
    K: tl.constexpr,
    K_packed: tl.constexpr,        # K // 2
    n_scale_cols: tl.constexpr,    # K // group_size
    group_size: tl.constexpr,      # typically 16
    BLOCK_K: tl.constexpr,         # tile size
):
    row = tl.program_id(0)
    col_block = tl.program_id(1)
    col_offsets = col_block * BLOCK_K + tl.arange(0, BLOCK_K)
    mask = col_offsets < K_packed

    # global_scale stores quantize scale; invert for dequant
    global_scale_inv = 1.0 / tl.load(global_scale_ptr)

    # Load packed bytes
    packed = tl.load(packed_ptr + row * K_packed + col_offsets, mask=mask, other=0).to(tl.uint8)

    # Split low and high nibbles
    low_nibble = packed & 0x0F
    high_nibble = (packed >> 4) & 0x0F

    # Separate sign bit (bit 3) and magnitude bits (bits 0-2)
    low_sign = ((low_nibble >> 3) & 1).to(tl.float32)
    low_mag = (low_nibble & 0x07).to(tl.uint8)
    high_sign = ((high_nibble >> 3) & 1).to(tl.float32)
    high_mag = (high_nibble & 0x07).to(tl.uint8)

    # FP4 E2M1 LUT: magnitude index → float value
    # 0→0, 1→0.5, 2→1.0, 3→1.5, 4→2.0, 5→3.0, 6→4.0, 7→6.0
    low_val = tl.where(low_mag == 0, 0.0,
              tl.where(low_mag == 1, 0.5,
              tl.where(low_mag == 2, 1.0,
              tl.where(low_mag == 3, 1.5,
              tl.where(low_mag == 4, 2.0,
              tl.where(low_mag == 5, 3.0,
              tl.where(low_mag == 6, 4.0, 6.0)))))))
    low_val = tl.where(low_sign > 0.5, -low_val, low_val)

    high_val = tl.where(high_mag == 0, 0.0,
               tl.where(high_mag == 1, 0.5,
               tl.where(high_mag == 2, 1.0,
               tl.where(high_mag == 3, 1.5,
               tl.where(high_mag == 4, 2.0,
               tl.where(high_mag == 5, 3.0,
               tl.where(high_mag == 6, 4.0, 6.0)))))))
    high_val = tl.where(high_sign > 0.5, -high_val, high_val)

    # Compute block scale indices
    actual_col_low = col_offsets * 2
    actual_col_high = col_offsets * 2 + 1
    block_scale_low = tl.load(
        scale_ptr + row * n_scale_cols + actual_col_low // group_size,
        mask=actual_col_low < K, other=0.0).to(tl.float32)
    block_scale_high = tl.load(
        scale_ptr + row * n_scale_cols + actual_col_high // group_size,
        mask=actual_col_high < K, other=0.0).to(tl.float32)

    # Dequant + clamp to FP8 E4M3 range (±448)
    low_val = tl.minimum(tl.maximum(low_val * block_scale_low * global_scale_inv, -448.0), 448.0)
    high_val = tl.minimum(tl.maximum(high_val * block_scale_high * global_scale_inv, -448.0), 448.0)

    # Write FP8 output
    tl.store(out_ptr + row * K + actual_col_low, low_val.to(tl.float8e4nv), mask=actual_col_low < K)
    tl.store(out_ptr + row * K + actual_col_high, high_val.to(tl.float8e4nv), mask=actual_col_high < K)


def nvfp4_to_fp8(weight_packed, weight_scale, weight_global_scale):
    """Convert NVFP4 packed weights to FP8 E4M3 via Triton kernel."""
    N, K_packed = weight_packed.shape
    K = K_packed * 2
    n_scale_cols = weight_scale.shape[1]
    group_size = K // n_scale_cols
    out = torch.empty(N, K, device=weight_packed.device, dtype=torch.float8_e4m3fn)
    BLOCK_K = 256
    grid = (N, triton.cdiv(K_packed, BLOCK_K))
    _nvfp4_dequant_to_fp8_kernel[grid](
        weight_packed, weight_scale, weight_global_scale, out,
        K=K, K_packed=K_packed,
        n_scale_cols=n_scale_cols, group_size=group_size,
        BLOCK_K=BLOCK_K,
    )
    return out


# ============================================================
# Replacement methods for vLLM's MarlinNvFp4LinearKernel
# ============================================================
# The stdlib logger has no warning_once() (that helper is only attached to
# loggers created through vLLM's init_logger), so deduplicate with a flag.
_warned = False


def _patched_process_weights(self, layer: torch.nn.Module) -> None:
    """Skip Marlin repack. Convert NVFP4 → FP8 once at load time.

    Frees original FP4 weights after conversion to save VRAM.
    Stores FP8 weights in layer._fp8_weight for runtime use.
    """
    global _warned
    if not _warned:
        logger.warning(
            "[nvfp4_fp8_patch] SM121 detected. Skipping Marlin repack, "
            "will use Triton FP4→FP8 dequant + FP8 tensor core GEMM."
        )
        _warned = True

    weight_packed = layer.weight.data                      # uint8 (N, K/2)
    weight_scale = layer.weight_scale.data                 # float8_e4m3fn (N, K/group)
    weight_global_scale = layer.weight_global_scale.data   # float32

    # One-time dequant: NVFP4 → FP8
    w_fp8 = nvfp4_to_fp8(weight_packed, weight_scale, weight_global_scale)

    # Store as (N, K) contiguous. At runtime, .t() produces a col-major
    # (K, N) view, which torch._scaled_mm requires for the B matrix.
    layer._fp8_weight = torch.nn.Parameter(w_fp8.contiguous(), requires_grad=False)
    # Weights are already dequantized to real values, so their scale is 1.
    layer._fp8_w_scale = torch.nn.Parameter(
        torch.ones(1, 1, device=w_fp8.device, dtype=torch.float32),
        requires_grad=False,
    )

    # Free original FP4 weights to reclaim VRAM
    layer.weight = torch.nn.Parameter(
        torch.empty(0, dtype=torch.uint8, device=weight_packed.device),
        requires_grad=False,
    )
    layer.weight_scale = torch.nn.Parameter(
        torch.empty(0, dtype=torch.float8_e4m3fn, device=weight_scale.device),
        requires_grad=False,
    )

    logger.info(
        f"[nvfp4_fp8_patch] Pre-converted layer to FP8: "
        f"{weight_packed.shape} → {w_fp8.shape}, "
        f"weight range: [{w_fp8.float().min():.4f}, {w_fp8.float().max():.4f}]"
    )


def _patched_apply_weights(self, layer, x, bias=None):
    """FP8 tensor core GEMM using pre-converted weights.

    Activation quantization uses pure torch ops so torch.compile/inductor
    can trace and fuse the graph (vLLM custom ops break inductor here).
    """
    orig_shape = x.shape
    x_2d = x.reshape(-1, orig_shape[-1])

    # Activation → FP8 (per-tensor dynamic quantization)
    amax = x_2d.float().abs().amax()
    a_scale = (amax / 448.0).clamp(min=1e-12)
    x_fp8 = (x_2d.float() / a_scale).clamp(-448, 448).to(torch.float8_e4m3fn)
    a_scale = a_scale.reshape(1, 1)

    # FP8 GEMM: (M, K) @ (K, N) → (M, N)
    out = torch._scaled_mm(
        x_fp8,
        layer._fp8_weight.t(),   # (N,K).t() = col-major (K,N)
        scale_a=a_scale,
        scale_b=layer._fp8_w_scale,
        out_dtype=x.dtype,
    )
    if bias is not None:
        out = out + bias
    return out.reshape(*orig_shape[:-1], out.shape[-1])


# ============================================================
# Patch installation: detect SM121, replace Marlin methods
# ============================================================
_patched = False


def patch():
    """Install the monkey-patch. Only activates on SM121 (GB10)."""
    global _patched
    if _patched:
        return True
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        if cap != (12, 1):
            logger.info(f"SM{cap[0]}{cap[1]} detected, FP8 patch not needed")
            return False
    try:
        from vllm.model_executor.kernels.linear.nvfp4.marlin import (
            MarlinNvFp4LinearKernel,
        )
        MarlinNvFp4LinearKernel.process_weights_after_loading = _patched_process_weights
        MarlinNvFp4LinearKernel.apply_weights = _patched_apply_weights
        _patched = True
        logger.warning(
            "[nvfp4_fp8_patch] Installed! MarlinNvFp4LinearKernel now uses "
            "pre-converted FP8 weights + FP8 tensor core GEMM."
        )
        return True
    except ImportError as e:
        logger.error(f"Failed to patch: {e}")
        return False
```
serve_nvfp4_fp8.py
```python
"""
Launch wrapper: install FP8 patch, then start vllm serve.

The patch runs at import time (outside __main__ guard) so spawned workers
also get patched. Only the parent process calls vllm_main().

Environment variables:
    MODEL_PATH      Model path (default: /model)
    PORT            API port (default: 8000)
    MAX_MODEL_LEN   Max context length (default: 4096)
    GPU_MEM_UTIL    GPU memory utilization (default: 0.90)

Docker usage:
    docker run --gpus all \
        -e VLLM_NVFP4_GEMM_BACKEND=marlin \
        -e TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
        -e MODEL_PATH=/model \
        -v /path/to/nvfp4/model:/model \
        -v /path/to/scripts:/scripts \
        --entrypoint python3 <vllm-image> /scripts/serve_nvfp4_fp8.py
"""
import sys
import os
import logging

# Patch at import time so spawned workers inherit it
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import nvfp4_fp8_patch
nvfp4_fp8_patch.patch()

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(name)s %(levelname)s %(message)s")
    model_path = os.environ.get("MODEL_PATH", "/model")
    port = os.environ.get("PORT", "8000")
    max_model_len = os.environ.get("MAX_MODEL_LEN", "4096")
    gpu_mem = os.environ.get("GPU_MEM_UTIL", "0.90")
    sys.argv = [
        "vllm", "serve", model_path,
        "--port", port,
        "--max-model-len", max_model_len,
        "--gpu-memory-utilization", gpu_mem,
        "--trust-remote-code",
    ]
    from vllm.entrypoints.cli.main import main as vllm_main
    vllm_main()
```
Series: Part 19 — NVFP4 Is a Trap · Part 14 — Gemma 4 Complete Guide · Part 8 — vLLM vs Ollama
FAQ
- Is NVFP4 truly hopeless on DGX Spark?
- No native hardware support, but you can work around it. A Triton kernel dequants NVFP4 weights to FP8, then runs on FP8 tensor cores — 40.8 → 47.6 tok/s (+17%). Only 12% behind native FP8.
- Can I use this patch with vLLM today?
- Yes. It monkey-patches vLLM's MarlinNvFp4LinearKernel, converts weights to FP8 at load time, and runs torch._scaled_mm at inference. Set VLLM_NVFP4_GEMM_BACKEND=marlin. Full code in the article.
- Why not just run the FP8 model directly?
- If an FP8 checkpoint exists, use it — 53.8 tok/s beats everything. This patch matters when a model ships only as NVFP4 and no FP8 version exists.