
DGX Spark · part 4

[vLLM] Why Your DGX Spark Only Says "!!!!!": Debugging NVFP4 on SM121

2026-03-17 · 9 min read · #dgx-spark #sm121 #vllm #nvfp4

Preface

You bought enterprise GPU hardware. You downloaded a model that says it supports the hardware. You ran the standard command. The server started with no errors. The model output !!!!!!!!!!!!!! for every single prompt.

This article is a full diagnostic record of what happened and why: four independent bugs, each of which had to be fixed to get from garbage output to coherent text.

The analogy: imagine a new car where the engine computer is programmed for the European version of the model, but you bought the North American one. The engine runs. It starts fine. It just doesn't fire on the right fuel map, so performance is subtly wrong in a way that looks like the car works until you actually drive it. Same hardware, different microcode — the software was never told they're not the same.

That's SM121 versus SM120 in one sentence.


You got a DGX Spark. You downloaded Qwen3.5-122B-A12B-NVFP4. You fired up vLLM:

vllm serve Qwen/Qwen3.5-122B-A12B-NVFP4 --quantization mxfp4

The server starts. No errors. You send a prompt. You get:

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

Every prompt. Every temperature. Every model. !!!!!!!!!!.

This is the debugging story.


What ! Actually Means

In the Qwen3.5 tokenizer, ! is token ID 0. Repeated ! output means your model is emitting near-zero logits — argmax of near-zero values = 0 = !. The model isn't confused or undertrained. Something in the computation pipeline is producing wrong values that drive all logits toward zero.
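The collapse is mechanical, and easy to reproduce. When the logit vector is (near-)constant, argmax breaks ties by returning the first index. A sketch (the vocab size below is illustrative, not the real tokenizer's):

```python
import numpy as np

# When an upstream GEMM broadcast-fills its rows, the downstream logits
# collapse toward a near-constant vector. np.argmax breaks ties by
# returning the first index, so the decoded token is ID 0: "!".
vocab_size = 150_000  # illustrative size, not the actual Qwen3.5 figure
logits = np.zeros(vocab_size, dtype=np.float32)
print(int(np.argmax(logits)))  # 0 -> "!"
```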

The first instinct is "bad weights" or "wrong quantization format." It's neither. The weights are fine. The format is correct. The problem is the kernel.


The Hardware: SM121 ≠ SM120

The DGX Spark has an NVIDIA GB10 GPU — compute capability 12.1, NVIDIA internal designation SM121.

Consumer Blackwell GPUs (the RTX 50-series) are SM120; the data-center GB200 is SM100.

Despite both being "Blackwell SM12x," they have different microarchitectures. CUTLASS FP4 GEMM kernels are compiled targeting the SM120 ISA. On SM121, they execute without any CUDA error — they just produce wrong answers. Silently. Every time.

The vLLM code doesn't know this. It sees "SM12x" and thinks everything is fine.


The Four Bugs

Getting from !!!!! to coherent output required fixing four independent issues. Two were already patched in the community fork being used (namake-taro's vLLM branch); two required manual intervention.

Bug #1: PTX Instruction Missing on SM121

The Triton downcast kernel in _downcast_to_mxfp.py contains:

elif cuda_capability_geq(10, 0):
    # use hardware FP4 downcast instruction
    return tl.inline_ptx_asm("cvt.rn.satfinite.e2m1x2.f32 ...")

SM121 has compute capability 12.1 (≥ 10.0), so it takes this branch. But cvt.rn.satfinite.e2m1x2.f32 doesn't exist on SM121 hardware. The PTX assembles without error, but the instruction produces undefined behavior at runtime.

Already fixed in namake-taro's fork: the SM12x family is explicitly excluded from the hardware PTX path and falls back to software emulation.
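The shape of that fix is plain capability gating. A hypothetical sketch (the helper name here is invented, not the fork's actual function) of excluding the SM12x family even though 12.1 ≥ 10.0:

```python
def use_hw_fp4_downcast(major: int, minor: int) -> bool:
    """Hypothetical gate modeled on the fork's fix: SM12x lacks the
    cvt.rn.satfinite.e2m1x2.f32 instruction, so force the software path."""
    if major == 12:
        return False  # SM120/SM121: software emulation only
    return (major, minor) >= (10, 0)  # SM100+: hardware downcast

assert use_hw_fp4_downcast(10, 0)      # SM100: hardware path
assert not use_hw_fp4_downcast(12, 1)  # SM121: excluded despite >= 10.0
```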

Bug #2: Marlin 256-Thread Race on SM121

The Marlin MoE kernel in fused_marlin_moe.py defaults to 256 threads per block. On SM121, this configuration triggers a race condition for large N (the typical decode batch size range), producing non-deterministic garbage.

Already fixed in namake-taro's fork: forces 128 threads for N ≥ 2048.
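The workaround reduces to one conditional on the problem size. A minimal sketch (function name invented here) of that thread-count selection:

```python
def marlin_moe_block_threads(n: int) -> int:
    # Hypothetical model of the fork's workaround: drop from 256 to 128
    # threads per block for large N to avoid the SM121 race condition.
    return 128 if n >= 2048 else 256

assert marlin_moe_block_threads(4096) == 128  # large N: safe config
assert marlin_moe_block_threads(512) == 256   # small N: default config
```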

Bug #3: SupportsQuant Missing — GDN Layers Get Quantized

Qwen3.5-122B is a hybrid architecture mixing standard Transformer attention layers with GDN (Gated Delta Network) SSM layers. The SSM layers must stay in BF16. Quantizing them to NVFP4 corrupts the recurrent hidden states.

vLLM handles this via an exclude_modules list in the quantization config. The list specifies which layers to skip by weight name. But there's a catch: the exclusion list uses HuggingFace-format weight names, while vLLM uses internally remapped names. The SupportsQuant mixin calls apply_vllm_mapper() to translate between the two formats.

# qwen3_5.py — before the fix
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
    #                                                                   ^^^ No SupportsQuant

Without SupportsQuant, apply_vllm_mapper() is never called. The exclusion list keeps the HF-format prefixes. is_layer_excluded() returns False for every GDN layer. Every SSM layer gets NVFP4-quantized. The GDN hidden states become garbage, and the model's internal state degrades across the sequence.

Fix:

class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid, SupportsQuant):

Qwen3_5MoeForConditionalGeneration (the 122B class) inherits from this, so it gets the fix transitively.
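The failure mode is just a string-prefix mismatch. A toy illustration with invented layer names (not the real HF or vLLM weight paths):

```python
# Exclusion list written against HF-format names; the model is registered
# under remapped vLLM names. Prefix matching silently finds nothing.
hf_exclude = ["model.layers.0.linear_attn"]        # HF-style prefix
vllm_name = "language_model.layers.0.linear_attn"  # remapped vLLM name

def is_layer_excluded(name: str, exclude: list[str]) -> bool:
    return any(name.startswith(prefix) for prefix in exclude)

# Without apply_vllm_mapper() translating the list, the GDN layer
# slips through the exclusion check and gets quantized.
assert not is_layer_excluded(vllm_name, hf_exclude)
```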

Bug #4: CUTLASS FP4 Falsely Selected for SM121 (The Killer)

This is the root cause of the !!!!! output.

4a. Linear layers: cutlass_fp4_supported() in nvfp4_utils.py:

def cutlass_fp4_supported() -> bool:
    capability = current_platform.get_device_capability()
    capability_int = capability.major * 10 + capability.minor  # = 121 for SM12.1
    return cutlass_scaled_mm_supports_fp4(capability_int)       # 121 > threshold → True

The function passes capability_int = 121 as if it were a device ID. Since 121 is numerically larger than any reasonable SM capability threshold, it returns True. But calling cutlass_scaled_mm_supports_fp4(device_id=0) with the actual device ID correctly returns False for GB10.

One function. Two calling conventions. Two different answers. CUTLASS FP4 gets selected for every linear layer.

4b. MoE layers: CutlassExpertsFp4._supports_current_device() in cutlass_moe.py:

@staticmethod
def _supports_current_device() -> bool:
    p = current_platform
    return p.is_cuda() and (
        p.is_device_capability_family(100)
        or p.is_device_capability_family(110)
        or p.is_device_capability_family(120)   # ← matches SM120 AND SM121
    )

is_device_capability_family(120) returns True for any SM12x device, including SM121. The broken CUTLASS FP4 kernel gets selected for all MoE expert GEMMs as well.
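A toy model of why the family check is too coarse (the function below is a stand-in, not vLLM's implementation):

```python
def is_device_capability_family(family: int, major: int, minor: int) -> bool:
    # Stand-in for the real check: "family 120" means major == 12, ANY minor.
    # The minor version is never consulted, so SM121 rides along with SM120.
    return major * 10 == family

assert is_device_capability_family(120, 12, 0)  # SM120: intended match
assert is_device_capability_family(120, 12, 1)  # SM121: unintended match
```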

Note: setting VLLM_NVFP4_GEMM_BACKEND=marlin only affects the linear layer path (Bug 4a). The MoE path (Bug 4b) ignores this environment variable entirely. You need both fixes.


The Diagnostic Signature

How do you know it's CUTLASS FP4 and not something else? Hook into the GEMM output before the lm_head and inspect the tensor:

row 0:  [-28.625, -28.625, -28.625, -28.625, ...]
row 1:  [-12.500, -12.500, -12.500, -12.500, ...]
row 2:  [ -4.250,  -4.250,  -4.250,  -4.250, ...]

Every element in a row has the same value. That's the CUTLASS FP4 garbage signature on SM121 — the kernel computes one value and broadcast-fills the row. Not zeros, not NaN, not random noise. The same wrong number, repeated across the entire row width.

Legitimate computation errors produce varied wrong values. Row-identical output means the GEMM kernel itself is structurally miscomputing — it's computing a scalar and broadcasting it instead of computing a proper matrix multiply.

The lm_head then takes this row-identical garbage as input and produces near-zero logits everywhere, which argmax collapses to token ID 0 (!).
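The signature is easy to test for mechanically. A small detector, sketched with NumPy (where you hook the tensor out of the model is up to you):

```python
import numpy as np

def row_identical(x: np.ndarray) -> bool:
    """True when every row is a single value broadcast across all columns:
    the CUTLASS-FP4-on-SM121 garbage signature described above."""
    return bool(np.all(x == x[:, :1]))

garbage = np.array([[-28.625] * 4, [-12.5] * 4, [-4.25] * 4])
healthy = np.array([[0.1, -0.3, 0.7, 0.2], [0.5, 0.5, -0.1, 0.9]])
assert row_identical(garbage)       # broadcast-filled rows: kernel bug
assert not row_identical(healthy)   # varied values: computation is real
```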


The Fixes

Environment variables

export VLLM_NVFP4_GEMM_BACKEND=marlin   # linear layers: Marlin W4A16
export VLLM_MXFP4_USE_MARLIN=1          # MoE linear: bypass CUTLASS_FP4 branch
export VLLM_USE_FLASHINFER_MOE_FP4=0
export VLLM_MARLIN_USE_ATOMIC_ADD=1     # SM121 Marlin atomic race fix
export CUDA_DEVICE_MAX_CONNECTIONS=1

Code patches (2 files)

cutlass_moe.py — exclude SM121 from CutlassExpertsFp4:

@staticmethod
def _supports_current_device() -> bool:
    p = current_platform
    cap = p.get_device_capability()
    if cap is not None and cap.major == 12 and cap.minor >= 1:
        return False  # SM121 produces garbage with CUTLASS FP4
    return p.is_cuda() and (
        p.is_device_capability_family(100)
        or p.is_device_capability_family(110)
        or p.is_device_capability_family(120)
    )

mxfp4.py — allow VLLM_MXFP4_USE_MARLIN to bypass the SM12x CUTLASS_FP4 branch:

# Before:
if is_sm12x:
# After:
if is_sm12x and not envs.VLLM_MXFP4_USE_MARLIN:

qwen3_5.py — add SupportsQuant to fix GDN exclusion:

class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid, SupportsQuant):

Performance After Fix

The good news: it works. The caveat: Marlin W4A16 is not W4A4.

| Metric | Value |
|--------|-------|
| Decode speed | ~15 tok/s |
| Backend | Marlin W4A16 (weight decompression) |
| KV cache (fp8, 0.90 util) | 26.7 GiB → 580K tokens |
| Max context | 200K tokens |
| CUDAGraph | ✅ PIECEWISE + FULL captured |

Marlin decompresses FP4 weights to BF16 at inference time; activations stay BF16, not FP4. Decode speed is memory-bandwidth bound: GB10 offers 273 GB/s, the full 4-bit weight set is 122B × 0.5 bytes ≈ 61 GB, and each decode token streams only the active experts' share of it. After Marlin's decompression and scale-reading overhead, ~15 tok/s sits close to the practical ceiling for this W4A16 path.

True W4A4 CUTLASS (compute-bound) would be roughly 2–3× faster, but there's no working SM121 kernel for it yet. The GB10's compute capability is there; the software stack hasn't caught up.


What Needs Upstreaming

Three issues worth filing against vllm-project/vllm:

1. cutlass_fp4_supported() false positive — The function in nvfp4_utils.py passes capability_int (an integer like 121) to cutlass_scaled_mm_supports_fp4() which was designed to receive a device ID. The fix is to call cutlass_scaled_mm_supports_fp4(device_id=current_device_id) instead.

2. CutlassExpertsFp4._supports_current_device() includes SM121is_device_capability_family(120) returns True for any SM12x. Needs an explicit minor-version check to exclude SM121.

3. Qwen3_5ForConditionalGeneration missing SupportsQuant — GDN (linear_attn) layers in the hybrid 122B model are incorrectly quantized when SupportsQuant is absent.


What Was Gained

Even with all four bugs fixed, Qwen3.5-122B lands at ~15 tok/s — not the theoretical W4A4 ceiling. That's a separate story covered in Part 2. But the work here wasn't wasted.

Confirmed:

  • The row-identical failure signature is a reliable diagnostic for CUTLASS FP4 on SM121. Not zeros, not random noise — same value repeated across every column in a row. If you see this, you know exactly what to look for.
  • All four bugs are independent. Each one needs to be fixed separately. Fixing three of four doesn't give you 75% working output — it gives you garbage with a different root cause.
  • SupportsQuant is a non-obvious requirement for hybrid Qwen architectures. The GDN layers look like any other layer in the model's config — there's no warning when they get quantized incorrectly.

Established:

  • SM121 (GB10, DGX Spark) and SM120 (RTX 50-series Blackwell) are not interchangeable software targets, despite both being labeled "Blackwell SM12x." The difference is invisible to vLLM's capability detection without explicit patches.
  • A misnamed environment variable is silently ignored, never rejected. The pair that actually takes effect here is VLLM_NVFP4_GEMM_BACKEND=marlin plus VLLM_MXFP4_USE_MARLIN=1. Always confirm the selected backend in the startup log.

TL;DR

If you have a DGX Spark and your NVFP4 model only outputs !!!!!:

  1. SM121 (GB10) ≠ SM120. CUTLASS FP4 kernels target the SM120 ISA and silently produce garbage on SM121.
  2. Set VLLM_NVFP4_GEMM_BACKEND=marlin and VLLM_MXFP4_USE_MARLIN=1 — these cover different code paths. You need both.
  3. Patch CutlassExpertsFp4._supports_current_device() to return False for SM12.1+ (MoE layer path ignores env vars).
  4. If running Qwen3.5-122B: add SupportsQuant to Qwen3_5ForConditionalGeneration so GDN layers stay BF16.
  5. The row-identical output signature is your diagnostic confirmation — not zeros, not random noise, every column in a row has the same value.

You'll land at ~15 tok/s via Marlin W4A16. Not the theoretical W4A4 throughput, but coherent output is a prerequisite for everything else.


Next in this series: Qwen3.5-122B Runs. But at 14 tok/s.