gpt-oss-120B on DGX Spark · part 1
[vLLM] gpt-oss-120B 跑到 59 tok/s:六個坑與一份可用的 Serve Script
前言
同一台硬體,不同的模型,不同的問題集。SM121 NVFP4 搞定之後,下一個自然的問題是:這台機器還能跑什麼?gpt-oss-120B 是 Microsoft 的開源 120B 模型——不同的量化格式、不同的 tokenizer、以及它自己的一類靜默失敗。
這篇是完整的除錯記錄:從空白的 shell 到一份可用的 serve script,中間六個 bug,最後附上腳本。
用個比喻:就像搬進一間新公寓,看起來一切都是標準規格,但沒有一樣完全吻合——插座是另一個標準、冰箱用錯電壓、大樓水壓以沒有人警告過你的方式影響水管。每個問題都能修,但沒有一個是你事先知道會踩到的。
第一篇講的是讓任何 NVFP4 模型在 SM121 上不輸出 !!!!! 的四個 bug。先把那些修完。這篇從那之後開始,講試著跑 gpt-oss-120B 的過程——一個完全不同的模型格式,有它自己的一套地雷。
目標:59 tok/s,131K context,單顆 GB10(128 GB 統一記憶體)。實際路徑:六個獨立的 bug,其中一個比其他所有合起來還難找。
這個模型是什麼
gpt-oss-120B 是 Microsoft 的開源 120B 模型,MXFP4 量化,格式叫 harmony。它用 openai_harmony tokenizer(依賴 tiktoken,不是 HuggingFace tokenizers)。它有一個專門設計給它輸出格式的 --reasoning-parser。這三個事實各自貢獻了這個故事裡的一個 bug。
支援 SM121 的 vLLM fork 是 eugr 的分支,把 MXFP4 支援 patch 成能在 GB10 上跑。
注意: 這篇文章寫完之後,eugr 發布了 spark-vllm-docker——一個 Docker 化的安裝方式,有預編譯的 nightly wheel、model recipe、以及多機支援。裡面有
run-recipe.sh openai-gpt-oss-120b可以自動處理大部分設定。如果你要手動 patch stock vLLM,下面的步驟還是適用;但大多數情況下,Docker repo 是更簡單的路徑。
坑 1:Import 路徑不對
eugr 的 fork 是針對他自己的 vLLM 樹寫的。把這些 patch 打在標準 vLLM 0.17.1 上,一行 import 馬上壞掉:
# eugr fork 的路徑(在標準 vLLM 上是錯的)
from vllm.model_executor.layers.quantization.quant_utils import cutlass_fp4_supported
# 標準 vLLM 0.17.1(正確)
from vllm.model_executor.layers.quantization.nvfp4_utils import cutlass_fp4_supported
這在 mxfp4.py 裡。伺服器啟動時直接 crash,error message 沒什麼幫助。改一行解決。
坑 2:--enforce-eager 讓速度砍半
第一次成功跑起來。速度:26 tok/s。預期:約 59 tok/s。
Serve script 裡有 --enforce-eager。這個 flag 會關掉 CUDAGraph 和 torch.compile——它是個 debug flag,強制走 eager execution mode。某個時間點有人加進去 debug 用,然後忘了拔掉。
拿掉它。26 tok/s → 59 tok/s。
--enforce-eager 永遠不應該出現在 production 的 serve script 裡。如果有,在 debug 其他任何東西之前先拔掉它。
坑 3:內網機器的 tiktoken Vocab 下載失敗
gpt-oss 用 openai_harmony tokenizer,底層用 tiktoken。啟動時,tiktoken 會去 OpenAI CDN 下載 vocab 檔:
https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken
GX10 在內網。下載失敗,要麼靜默失敗,要麼報一個不直觀的 error:
HarmonyError: error downloading or loading vocab file
繞過這個問題需要翻 tiktoken 原始碼。tiktoken 用檔案的 SHA1 hash 當作快取檔名。這個機制沒有任何文件。o200k_base.tiktoken 的 hash 是:
fb374d419588a4632f3f557e76b4b70aebbca790
修法:
# 在有網路的機器上
mkdir -p ~/models/tiktoken_cache
wget "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken" \
-O ~/models/tiktoken_cache/fb374d419588a4632f3f557e76b4b70aebbca790
# 在 serve script 裡
export TIKTOKEN_RS_CACHE_DIR=/home/username/models/tiktoken_cache
目錄和檔名必須完全對應。快取檔案沒有副檔名——就是原始的 SHA1 hash。
坑 4:content: null——Reasoning Parser 的陷阱
伺服器跑起來了。發出第一個測試 request。GX10 的 log 顯示確實在生成 token。但沒有任何回覆收到。查 log:
content: None
reasoning_len: 431
tokens: {'prompt_tokens': 68, 'completion_tokens': 100}
--reasoning-parser openai_gptoss 這個 flag 會把所有輸出都放進 reasoning 欄位,讓 content 變成 null。gpt-oss 不是 thinking model。這個 parser 是為 gpt-oss 特定的輸出格式設計的——但如果你的 client 只讀 content(任何標準 OpenAI 相容 client 都這樣),它什麼都收不到。
初步修法:拿掉 --reasoning-parser openai_gptoss。content 就出現了。
(這個診斷是對的,但坑 6 同時也在作怪。拿掉 parser 有幫助,但沒修坑 6 的話,稍長一點的輸出還是會陷入重複迴圈。)
坑 5:System Message 沒有正確 Encode 成 Harmony 格式
gpt-oss 使用 harmony 訊息格式。vLLM 處理 chat completion request 的大致邏輯是:
- 用
get_system_message()建立一個正確的 harmony system message - 遍歷
request.messages,把每一條訊息加進去
Bug(追蹤中的 vLLM PR #31607,撰文時尚未 merge):如果 client 的 messages array 裡有 role: "system" 的訊息,它會被序列化成 raw harmony message,而不是走 get_system_message() 的正確路徑。模型看到格式錯誤的 token sequence 就開始胡言亂語。
任何帶 system prompt 的 client 都會踩到這個——幾乎所有 client 都這樣。
手動 fix:patch vllm/entrypoints/openai/serving_chat.py,把 request.messages 裡的 system role 抽出來,透過 get_system_message(instructions=...) 正確注入。
同時需要:
export VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1
這個 env var 預設是 0。不開的話,get_system_message() 的 instructions 參數被靜默忽略,system prompt 完全不會送達模型。
坑 6:用錯環境變數(靜默失敗)
這是最後一個、也是花最久時間找的坑。
修完坑 1–5 之後,簡單的 request 可以跑了。但稍微長一點的回覆就陷入重複迴圈:
The user is asking about... The user is asking about... The user is asking about...
Temperature、top_p、repetition_penalty——全沒用。迴圈必勝。
已知原因:在 SM121 上,CUTLASS_FP4 會產生 first-token logit 錯誤(參見第一篇和 vLLM issue #37030)。修法是強制所有 MXFP4 GEMM 走 Marlin。
注意: 這個修法適用於 stock vLLM。spark-vllm-docker 使用另一個有不同 patch 的 vLLM build,設定
--mxfp4-backend CUTLASS加上VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1——他們的 patch 讓 CUTLASS 能在 SM121 上正確運作。如果你用的是那個 Docker image,不需要 Marlin workaround。如果你用的是 stock vLLM 或手動打 patch,則需要。
Serve script 裡已經有這個了:
export VLLM_NVFP4_GEMM_BACKEND=marlin # ← 看起來對
這個環境變數在 vLLM 0.17.1 不存在。
vLLM 讀到它,找不到對應的 env var,靜默忽略。沒有任何警告或報錯。它 fallback 到 auto-selection。SM12x 上 auto 選的是 CUTLASS_FP4。啟動 log 顯示:
[MXFP4] Auto-selected: CUTLASS_FP4 (vLLM native SM120 FP4 grouped GEMM for SM12x)
正確的 env var 是:
export VLLM_MXFP4_BACKEND=marlin # ← 正確
設對之後,啟動 log 出現:
[MXFP4] Using backend: marlin (VLLM_MXFP4_BACKEND=marlin)
修完之後:hi → 正常回覆。稍長的 story request → 3,970 個字,無迴圈,finish_reason: stop。
可用的 Serve Script
#!/bin/bash
source /home/username/.python-vllm-eugr/bin/activate
# SM121 後端:強制全部走 Marlin(CUTLASS_FP4 在 SM121 是壞的)
# 注意:使用 eugr/spark-vllm-docker 的話,他們的 patch 讓 CUTLASS 可以跑——
# 改用 --mxfp4-backend CUTLASS + VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1
export VLLM_MXFP4_BACKEND=marlin # ← 不是 VLLM_NVFP4_GEMM_BACKEND
export VLLM_MARLIN_USE_ATOMIC_ADD=1 # SM121 Marlin atomic race fix
export FLASHINFER_DISABLE_VERSION_CHECK=1
# 離線 tokenizer cache
export TIKTOKEN_RS_CACHE_DIR=/home/username/models/tiktoken_cache
# gpt-oss harmony system message fix
export VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1
# 清除每次跑留下的 compile cache
rm -rf ~/.cache/flashinfer/ ~/.cache/vllm/torch_compile_cache/ 2>/dev/null || true
exec vllm serve /home/username/models/gpt-oss-120b \
--served-model-name gpt-oss-120b \
--host 0.0.0.0 --port 8001 \
--quantization mxfp4 \
--mxfp4-layers moe,qkv,o,lm_head \
--kv-cache-dtype fp8 \
--max-model-len 131072 \
--max-num-batched-tokens 8192 \
--gpu-memory-utilization 0.90 \
--attention-backend FLASHINFER \
--moe-backend marlin
原始 script 的兩個 flag 根據 eugr/spark-vllm-docker recipe 和 NVIDIA Developer Forum 的 GB10 benchmark 討論 做了更新:
--attention-backend FLASHINFER取代TRITON_ATTN。TRITON 是 fallback 路徑;FLASHINFER 是完整實作,實測更快。--mxfp4-layers moe,qkv,o,lm_head明確量化 projection layers。不加的話,qkv/o/lm_head 跑在 BF16——白白浪費效能。
注意:--reasoning-parser openai_gptoss 不在這份 script 裡。除非你需要 reasoning channel 分離,否則不要加。
效能數字
| 指標 | 數值 | |------|------| | Decode 速度 | ~59 tok/s | | 後端 | Marlin W4A16(weight 解壓縮) | | KV cache(fp8,0.90 使用率) | 約 580K tokens 容量 | | 最大 context | 131K tokens | | CUDAGraph | ✅ 已捕捉 |
數學算法:GB10 有 273 GB/s 記憶體頻寬。120B × 4-bit ≈ 60 GB。在 273 GB/s 下,理論 decode 速度約 ~60 tok/s。實測 59 tok/s 幾乎完全貼到頻寬上限。這就是 GB10 上正常運作的樣子。
得到了什麼
六個坑全部修完。59 tok/s 是結果——幾乎貼到頻寬上限,代表整個 stack 正確運作。但這些教訓值得明確寫出來,因為每一個都不直觀,而且都會再出現。
花最多時間的坑:
- 坑 6(錯誤環境變數) 最難找,因為它完全隱形。沒有 error,沒有 warning。啟動 log 是唯一的訊號。在除錯任何模型行為之前,先讀啟動 log。
- 坑 5(harmony system message) 沒有辦法靠設定繞過——需要改原始碼。任何帶 system prompt 的 client 都會靜默踩到,如果沒有打 patch 的話。
可轉移的診斷模式:
- gpt-oss 出現重複迴圈 → 看啟動 log 有沒有
Auto-selected: CUTLASS_FP4。如果有,你的 Marlin env var 沒有被讀到。 - response 裡
content: null→ 要麼 reasoning parser 把輸出路由到錯誤的欄位,要麼模型在生成 reasoning token。先查 parser flag。 - 內網機器 tokenizer 下載失敗 → tiktoken 用 SHA1 hash 當快取檔名。這個機制沒有任何文件。繞過方法是預先下載並把
TIKTOKEN_RS_CACHE_DIR指向該目錄。
到處都適用的模式:
--enforce-eager 永遠不應該出現在 production serve script 裡。如果你接手了一份 script 而且速度不對,先查這個 flag。它很容易引入 2× 的速度退化,然後被忘掉。
總結:有順序的檢查清單
在 SM121 上跑 gpt-oss-120B:
- 先做第一篇的 SM121 修法。
- 修
mxfp4.py的 import 路徑(eugr fork 路徑 → 標準 vLLM 路徑)。 - 從 serve script 拔掉
--enforce-eager。確認。再確認一次。 - 預先下載 tiktoken vocab;設定
TIKTOKEN_RS_CACHE_DIR。 - 用
VLLM_MXFP4_BACKEND=marlin——不是VLLM_NVFP4_GEMM_BACKEND。錯誤的變數靜默失效。 - 確認啟動 log 顯示
Using backend: marlin,不是Auto-selected: CUTLASS_FP4。 - 如果有送 system prompt:patch
serving_chat.py(PR #31607),設VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1。
坑 6 的診斷方法:如果你在遇到重複迴圈,而你以為已經強制走 Marlin,看啟動 log。如果顯示 Auto-selected,你的 env var 沒有被讀到。