DGX Spark · part 4
[vLLM] 為什麼你的 DGX Spark 只會輸出「!!!!!」:SM121 上的 NVFP4 除錯記錄
前言
你買了企業級 GPU 硬體。下載了宣稱支援這個硬體的模型。跑了標準指令。伺服器正常啟動,沒有任何 error。然後每一個 prompt 都只回傳 !!!!!!!!!!!!!!。
這篇是完整的診斷記錄——四個獨立的 bug,每一個都是必要條件,缺一不可地把模型從垃圾輸出修到正常文字。
用個比喻:想像你買了一台新車,但引擎電腦燒錄的是歐規版本的程式,你買的卻是北美版。引擎會啟動,開起來也沒有明顯報錯,只是在燃油噴射的某個細節上默默地用錯了設定,導致實際行為跟預期不符。同樣的硬體,不同的微碼——軟體從來沒有被告知它們不是同一個東西。
SM121 和 SM120 的關係,就是這樣。
你拿到了一台 DGX Spark。下載了 Qwen3.5-122B-A12B-NVFP4。啟動 vLLM:
vllm serve Qwen/Qwen3.5-122B-A12B-NVFP4 --quantization mxfp4
伺服器啟動。沒有任何錯誤。你送出一個 prompt。你得到:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
每個 prompt 都這樣。每個 temperature 都這樣。每個模型都這樣。!!!!!!!!!!。
這是那次除錯的完整記錄。
! 的真正含義
在 Qwen3.5 的 tokenizer 裡,! 是 token ID 0。持續輸出 ! 代表你的模型輸出了接近零的 logits——near-zero 的 argmax = 0 = !。模型沒有問題,權重也沒有問題。是計算過程中某個環節產生了錯誤的數值,把所有 logits 都壓向零。
直覺反應是「壞掉的權重」或「錯誤的量化格式」。兩個都不是。問題在 kernel 本身。
硬體差異:SM121 ≠ SM120
DGX Spark 搭載的是 NVIDIA GB10 GPU——compute capability 12.1,NVIDIA 內部代號 SM121。
資料中心的 GB200(NVL72 機架)是 SM120。
儘管兩者都是「Blackwell SM12x」,它們有不同的微架構。CUTLASS FP4 GEMM kernel 是針對 SM120 ISA 編譯的。在 SM121 上,它們會正常執行,不報任何 CUDA 錯誤——只是靜默地產生錯誤答案。
vLLM 不知道這件事。它看到「SM12x」就以為一切正常。
四個 Bug
從 !!!!! 到正常輸出,需要修掉四個獨立的問題。其中兩個在社群 fork(namake-taro 的 vLLM 分支)裡已經有補丁;另外兩個需要手動處理。
Bug #1:SM121 缺少 PTX 指令
_downcast_to_mxfp.py 裡的 Triton downcast kernel 有這段:
elif cuda_capability_geq(10, 0):
# 使用硬體 FP4 downcast 指令
return tl.inline_ptx_asm("cvt.rn.satfinite.e2m1x2.f32 ...")
SM121 的 compute capability 是 12.1(≥ 10.0),所以會走這個分支。但 cvt.rn.satfinite.e2m1x2.f32 這條 PTX 指令在 SM121 硬體上不存在。PTX 組譯不報錯,但這條指令在執行時產生未定義行為。
namake-taro fork 已修正:SM12x 系列被明確排除在硬體 PTX 路徑之外,改走軟體模擬。
Bug #2:Marlin 256-Thread Race 條件
fused_marlin_moe.py 裡的 Marlin MoE kernel 預設使用每個 block 256 個 thread。在 SM121 上,這個配置在較大的 N(典型的 decode batch size 範圍)下會觸發 race condition,產生不確定性的垃圾輸出。
namake-taro fork 已修正:N ≥ 2048 時強制使用 128 個 thread。
Bug #3:SupportsQuant 缺失——GDN 層被量化
Qwen3.5-122B 是混合架構,結合了標準 Transformer attention 層和 GDN(Gated Delta Network)SSM 層。SSM 層必須保持 BF16;把它們量化成 NVFP4 會破壞遞迴 hidden state。
vLLM 透過量化設定裡的 exclude_modules 清單來跳過特定層。這個清單指定要排除的層的 weight 名稱。但有個陷阱:清單使用 HuggingFace 格式的 weight 名稱,而 vLLM 內部使用重新映射過的名稱。SupportsQuant mixin 會呼叫 apply_vllm_mapper() 來在兩種格式間轉換。
# qwen3_5.py — 修正前
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
# ^^^ 沒有 SupportsQuant
沒有 SupportsQuant,apply_vllm_mapper() 就不會被呼叫。排除清單保留了 HF 格式的前綴。is_layer_excluded() 對每一個 GDN 層都回傳 False。所有 SSM 層都被 NVFP4 量化。GDN hidden state 變成垃圾,模型的內部狀態在整個序列中逐漸崩潰。
修正:
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid, SupportsQuant):
Qwen3_5MoeForConditionalGeneration(122B 的類別)繼承自這個類別,所以也會繼承到這個修正。
Bug #4:CUTLASS FP4 被錯誤地選擇用於 SM121(元兇)
這是 !!!!! 輸出的根本原因。
4a. Linear 層 — nvfp4_utils.py 裡的 cutlass_fp4_supported():
def cutlass_fp4_supported() -> bool:
capability = current_platform.get_device_capability()
capability_int = capability.major * 10 + capability.minor # SM12.1 = 121
return cutlass_scaled_mm_supports_fp4(capability_int) # 121 > 閾值 → True
這個函式把 capability_int = 121 當作 device ID 傳入。由於 121 在數值上大於任何合理的 SM capability 閾值,回傳 True。但用實際 device ID 呼叫 cutlass_scaled_mm_supports_fp4(device_id=0) 在 GB10 上會正確回傳 False。
一個函式,兩種呼叫慣例,兩個不同答案。結果:CUTLASS FP4 被選用於所有 linear 層。
4b. MoE 層 — cutlass_moe.py 裡的 CutlassExpertsFp4._supports_current_device():
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
return p.is_cuda() and (
p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120) # ← 匹配 SM120 和 SM121
)
is_device_capability_family(120) 對所有 SM12x 裝置都回傳 True,包含 SM121。損壞的 CUTLASS FP4 kernel 被選用於所有 MoE expert GEMM。
注意:設定 VLLM_NVFP4_GEMM_BACKEND=marlin 只影響 linear 層路徑(Bug 4a)。MoE 路徑(Bug 4b)完全無視這個環境變數。兩個修正都需要。
診斷特徵
怎麼確認是 CUTLASS FP4 的問題,而不是其他原因?在 lm_head 前加一個 debug hook,直接檢查 GEMM 輸出 tensor:
row 0: [-28.625, -28.625, -28.625, -28.625, ...]
row 1: [-12.500, -12.500, -12.500, -12.500, ...]
row 2: [ -4.250, -4.250, -4.250, -4.250, ...]
每一行的每個元素都是相同的值。 這就是 SM121 上 CUTLASS FP4 垃圾輸出的特徵——kernel 計算出一個值然後廣播填滿整行。不是零,不是 NaN,不是隨機雜訊。同一個錯誤數字,重複填滿整行。
正常的計算錯誤會產生各種不同的錯誤數值。Row-identical 輸出代表 GEMM kernel 本身在結構上計算錯誤——它算出一個純量然後廣播,而不是做正確的矩陣乘法。
lm_head 接收這個 row-identical 的垃圾作為輸入,產生幾乎全零的 logits,argmax 全部塌縮到 token ID 0(!)。
修正方案
環境變數
export VLLM_NVFP4_GEMM_BACKEND=marlin # linear 層:Marlin W4A16
export VLLM_MXFP4_USE_MARLIN=1 # MoE linear:繞過 CUTLASS_FP4 分支
export VLLM_USE_FLASHINFER_MOE_FP4=0
export VLLM_MARLIN_USE_ATOMIC_ADD=1 # SM121 Marlin atomic race 修正
export CUDA_DEVICE_MAX_CONNECTIONS=1
程式碼修改(2 個檔案)
cutlass_moe.py — 從 CutlassExpertsFp4 排除 SM121:
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
cap = p.get_device_capability()
if cap is not None and cap.major == 12 and cap.minor >= 1:
return False # SM121 使用 CUTLASS FP4 會產生垃圾輸出
return p.is_cuda() and (
p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120)
)
mxfp4.py — 允許 VLLM_MXFP4_USE_MARLIN 繞過 SM12x CUTLASS_FP4 分支:
# 修改前:
if is_sm12x:
# 修改後:
if is_sm12x and not envs.VLLM_MXFP4_USE_MARLIN:
qwen3_5.py — 加入 SupportsQuant 修正 GDN 排除問題:
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid, SupportsQuant):
修正後的效能
好消息:可以正常運作了。需要注意的是:Marlin W4A16 不是 W4A4。
| 指標 | 數值 | |------|------| | Decode 速度 | ~15 tok/s | | 後端 | Marlin W4A16(weight 解壓縮) | | KV cache(fp8,0.90 使用率) | 26.7 GiB → 580K tokens | | 最大 context | 200K tokens | | CUDAGraph | ✅ PIECEWISE + FULL 已捕捉 |
Marlin 在推理時把 FP4 weight 解壓縮成 BF16——activation 保持 BF16,不是 FP4。速度受記憶體頻寬限制:GB10 有 273 GB/s,122B × 4-bit ≈ 61 GB ≈ 理論 ~15 tok/s,實際數字非常接近。
真正的 W4A4 CUTLASS(計算量受限)大約快 2–3 倍,但目前還沒有可用的 SM121 kernel。GB10 的計算能力是足夠的;軟體堆疊還沒跟上。
值得回報給上游的問題
三個值得對 vllm-project/vllm 提 issue 的問題:
1. cutlass_fp4_supported() 誤報 — nvfp4_utils.py 裡的函式把 capability_int(像 121 這樣的整數)傳給設計用來接收 device ID 的 cutlass_scaled_mm_supports_fp4()。修法是改用 cutlass_scaled_mm_supports_fp4(device_id=current_device_id)。
2. CutlassExpertsFp4._supports_current_device() 包含了 SM121 — is_device_capability_family(120) 對所有 SM12x 都回傳 True。需要明確的 minor version 檢查來排除 SM121。
3. Qwen3_5ForConditionalGeneration 缺少 SupportsQuant — 混合 122B 模型裡的 GDN(linear_attn)層在沒有 SupportsQuant 時會被錯誤量化。
得到了什麼
修完四個 bug 之後,Qwen3.5-122B 落在 ~15 tok/s——不是理論上的 W4A4 上限。那是另一個故事,在第二篇裡講。但這裡的工作沒有白費。
確認了:
- Row-identical 失敗特徵是 SM121 上 CUTLASS FP4 問題的可靠診斷依據。不是零,不是隨機雜訊——每一行裡所有元素是同一個值。看到這個,就知道往哪裡找。
- 四個 bug 是獨立的。每一個都要單獨修。修了三個不等於 75% 正常——是帶著不同根本原因的垃圾。
SupportsQuant對混合 Qwen 架構是非顯然的必要條件。GDN 層在模型 config 裡看起來跟其他層沒有兩樣——被量化錯了也不會有任何警告。
確立了:
- SM121(GB10,DGX Spark)和 SM120(GB200)不是可以互換的軟體目標,儘管兩者都被標記為「Blackwell SM12x」。沒有明確的 patch,vLLM 的 capability 偵測看不出差異。
VLLM_NVFP4_GEMM_BACKEND(錯的)vsVLLM_MXFP4_BACKEND(對的)——前者被靜默忽略。一定要用啟動 log 確認。
總結
如果你有一台 DGX Spark,而你的 NVFP4 模型只輸出 !!!!!:
- SM121(GB10)≠ SM120(GB200)。CUTLASS FP4 kernel 針對 SM120 ISA,在 SM121 上靜默產生垃圾。
- 設定
VLLM_NVFP4_GEMM_BACKEND=marlin和VLLM_MXFP4_USE_MARLIN=1——這兩個覆蓋不同的程式碼路徑,兩個都要設。 - Patch
CutlassExpertsFp4._supports_current_device()讓 SM12.1+ 回傳 False(MoE 層路徑無視環境變數)。 - 如果跑 Qwen3.5-122B:在
Qwen3_5ForConditionalGeneration加入SupportsQuant,讓 GDN 層保持 BF16。 - Row-identical 輸出特徵是你的診斷確認——不是零,不是隨機雜訊,而是一行裡所有列都有相同的值。
最終你會得到 ~15 tok/s(Marlin W4A16)。不是理論上的 W4A4 吞吐量,但輸出正常是一切的前提。