
gpt-oss-120B on DGX Spark · part 2

[vLLM] gpt-oss-120B at 59 tok/s: 6 Pitfalls and a Working Serve Script

2026-03-19 · 10 min read · #dgx-spark #sm121 #vllm #gpt-oss

Preface

Same hardware, different model, different problem set. After getting SM121 NVFP4 working, the natural next question is: what else can this machine run? gpt-oss-120B is OpenAI's open-weight 120B model — a different quantization format, a different tokenizer, and its own category of silent failures.

This article is the full debugging record: six bugs between a blank shell and a working server, with the complete serve script at the end.

The analogy: it's like moving into a new apartment where everything looks standard but nothing quite fits — the outlets are a different standard, the fridge runs on the wrong voltage, and the building's water pressure affects the pipes in ways no one warned you about. Each problem is fixable. None of them are obvious until you hit them.


Part 1 covered the four bugs that cause any NVFP4 model to output !!!!! on SM121. Fix those first. This post picks up after that — and covers what happens when you try to serve gpt-oss-120B, a completely different model format with its own set of landmines.

Target: 59 tok/s at 131K context on a single GB10 (128 GB unified memory). Actual path to get there: six separate bugs, one of which took longer to find than all the others combined.


The Model

gpt-oss-120B is OpenAI's open-weight 120B model, MXFP4-quantized, which speaks a chat format called harmony. It uses the openai_harmony tokenizer (which depends on tiktoken, not HuggingFace tokenizers), and vLLM ships a --reasoning-parser designed specifically for its output format. These three facts each contribute exactly one bug to this story.

The SM121-compatible vLLM fork is eugr's branch, which patches MXFP4 support to work on GB10.

Note: Since this article was written, eugr has published spark-vllm-docker — a Docker-based setup with prebuilt nightly wheels, model recipes, and multi-node support. It includes a run-recipe.sh openai-gpt-oss-120b that handles most of the configuration below automatically. The manual approach documented here still applies if you're patching stock vLLM directly, but the Docker repo is the easier path for most setups.


Bug 1: Import Path Mismatch

The eugr fork was written against his own vLLM tree. When you apply the patches to stock vLLM 0.17.1, one import breaks immediately:

# eugr fork path (wrong on stock vLLM)
from vllm.model_executor.layers.quantization.quant_utils import cutlass_fp4_supported

# stock vLLM 0.17.1 (correct)
from vllm.model_executor.layers.quantization.nvfp4_utils import cutlass_fp4_supported

This is in mxfp4.py. The server crashes on startup with an unhelpful error message. One-line fix.
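The fix can be scripted as a one-line sed substitution. This sketch runs against a throwaway copy so it is self-contained; in practice you would point MXFP4 at the mxfp4.py inside your vLLM install (the exact path depends on your environment):

```shell
# Demo on a throwaway copy; point MXFP4 at your installed
# vllm/model_executor/layers/quantization/mxfp4.py instead.
MXFP4=/tmp/mxfp4_demo.py
cat > "$MXFP4" <<'EOF'
from vllm.model_executor.layers.quantization.quant_utils import cutlass_fp4_supported
EOF

# Rewrite the eugr-fork import path to the stock vLLM 0.17.1 location.
sed -i 's/quantization\.quant_utils import cutlass_fp4_supported/quantization.nvfp4_utils import cutlass_fp4_supported/' "$MXFP4"
grep nvfp4_utils "$MXFP4"
```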


Bug 2: --enforce-eager Cuts Speed in Half

First successful run. Speed: 26 tok/s. Expected: ~59 tok/s.

The serve script had --enforce-eager. This flag disables CUDAGraph and torch.compile — it's a debugging flag that forces eager execution mode. Someone added it for debugging at some point and forgot to remove it.

Remove it. 26 tok/s → 59 tok/s.

--enforce-eager should never appear in a production serve script. If it's there, remove it before debugging anything else.
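A quick audit for inherited scripts. The stand-in script below exists only to make the demo self-contained; run the same grep against your real serve script:

```shell
# Stand-in script for the demo; grep your real serve script instead.
cat > /tmp/serve-demo.sh <<'EOF'
vllm serve /models/gpt-oss-120b --enforce-eager --port 8001
EOF

if grep -q -- '--enforce-eager' /tmp/serve-demo.sh; then
  echo 'WARNING: --enforce-eager present (CUDAGraph and torch.compile disabled, expect ~2x slower decode)'
fi
```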


Bug 3: tiktoken Vocab Download Fails on Air-Gapped Machines

gpt-oss uses the openai_harmony tokenizer, which uses tiktoken under the hood. On startup, tiktoken downloads its vocab file from OpenAI's CDN:

https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken

The GX10 is on an internal network. Download fails silently or with a cryptic error:

HarmonyError: error downloading or loading vocab file

The workaround requires reading the tiktoken source code. tiktoken caches vocab files using their SHA1 hash as the filename. This is not documented anywhere. The hash for o200k_base.tiktoken is:

fb374d419588a4632f3f557e76b4b70aebbca790

Fix:

# on a machine with internet access
mkdir -p ~/models/tiktoken_cache
wget "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken" \
  -O ~/models/tiktoken_cache/fb374d419588a4632f3f557e76b4b70aebbca790

# in your serve script
export TIKTOKEN_RS_CACHE_DIR=/home/username/models/tiktoken_cache

The directory and filename must match exactly. No extension on the cached file — just the raw SHA1 hash.
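If you ever need the cache filename for a different vocab file, it can be derived rather than hunted down. A minimal sketch, assuming the cache key is the hex SHA-1 of the download URL string — that is Python tiktoken's caching scheme; verify it holds for the tiktoken-rs cache that openai_harmony uses before relying on it:

```python
import hashlib

def tiktoken_cache_name(url: str) -> str:
    """Cache filename for a vocab URL: hex SHA-1 of the URL string, no
    extension. Assumption: matches Python tiktoken's read_file_cached;
    check your tokenizer's source if the server still fails to find it."""
    return hashlib.sha1(url.encode()).hexdigest()

url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
print(tiktoken_cache_name(url))
```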


Bug 4: content: null — The Reasoning Parser Trap

Server running. First test request. The GX10 logs show tokens being generated, but the response content comes back empty. Check the logs:

content: None
reasoning_len: 431
tokens: {'prompt_tokens': 68, 'completion_tokens': 100}

The --reasoning-parser openai_gptoss flag routes all output into the reasoning field and sets content to null. The parser is designed to split gpt-oss's harmony output into separate reasoning and final channels, but here everything landed on the reasoning side. If your client only reads content (as any standard OpenAI-compatible client does), it gets nothing.

Initial fix: remove --reasoning-parser openai_gptoss. Content appears.

(This diagnosis was correct, but bug 6 was also active. Removing the parser helped, but without fixing bug 6, outputs still degraded into repetition loops for anything longer than a few sentences.)
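If you do keep the parser because you want the reasoning channel separated, the client side has to read both fields. A minimal sketch — the content and reasoning_content field names follow vLLM's OpenAI-compatible response shape, but verify them against your server's actual JSON:

```python
def extract_text(message: dict) -> str:
    """Prefer the final answer; fall back to the reasoning channel if the
    parser routed everything there (the bug-4 symptom)."""
    content = message.get("content")
    if content:
        return content
    return message.get("reasoning_content") or ""

# Simulated bug-4 response: content is null, everything in reasoning.
msg = {"role": "assistant", "content": None,
       "reasoning_content": "The user is asking about GB10 bandwidth..."}
print(extract_text(msg))
```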


Bug 5: System Messages Bypass Harmony Encoding

gpt-oss uses the harmony message format. vLLM processes chat completion requests roughly like this:

  1. Build a system message using get_system_message()
  2. Iterate over request.messages and append each one

The bug (tracked as vLLM PR #31607, unmerged at time of writing): if the client sends a message with role: "system" in the messages array, it gets serialized as a raw harmony message rather than going through get_system_message(). The model sees a malformed token sequence and starts producing garbage.

Any client that sends a system prompt hits this — which is most of them.

Manual fix: patch vllm/entrypoints/openai/serving_chat.py to extract system-role messages from request.messages and route them through get_system_message(instructions=...).

Also required:

export VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1

This env var defaults to 0. Without it, the instructions parameter to get_system_message() is silently ignored, and the system prompt never reaches the model regardless.
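The shape of the patch, as a standalone sketch. The function below is illustrative only — the real change lives in vllm/entrypoints/openai/serving_chat.py and hands the extracted text to vLLM's get_system_message(instructions=...) helper:

```python
def split_system_messages(messages):
    """Separate system-role messages (to be routed through
    get_system_message(instructions=...)) from the rest, which are
    appended as ordinary harmony messages. Without this split, a raw
    system message reaches the model as a malformed token sequence."""
    instructions = "\n".join(
        m["content"] for m in messages if m["role"] == "system")
    rest = [m for m in messages if m["role"] != "system"]
    return instructions, rest

msgs = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "hi"},
]
instructions, rest = split_system_messages(msgs)
print(instructions)  # hand to get_system_message(instructions=...)
print(rest)          # serialize as normal harmony messages
```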


Bug 6: The Wrong Environment Variable (Silent Failure)

This is the one that takes the longest to find.

After bugs 1–5 are fixed, simple requests work. But anything longer falls into a repetition loop:

The user is asking about... The user is asking about... The user is asking about...

Temperature, top_p, repetition_penalty — none of it matters. The loop always wins.

The known cause: on SM121, CUTLASS_FP4 produces first-token logit corruption (see Part 1 and vLLM issue #37030). Fix is to force Marlin for all MXFP4 GEMMs.

Caveat: This applies to stock vLLM. The spark-vllm-docker build uses a different patched vLLM with --mxfp4-backend CUTLASS and VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1 — their patches make CUTLASS work on SM121. If you're using that Docker image, the Marlin workaround is not required. If you're on stock vLLM or applying patches manually, it is.

The serve script already had this:

export VLLM_NVFP4_GEMM_BACKEND=marlin   # ← looks right

This environment variable does not exist in vLLM 0.17.1.

vLLM never reads it. Unrecognized environment variables are ignored with no warning and no error, so the backend falls back to auto-selection, and auto-selection on SM12x picks CUTLASS_FP4. The startup log shows:

[MXFP4] Auto-selected: CUTLASS_FP4 (vLLM native SM120 FP4 grouped GEMM for SM12x)

The correct env var is:

export VLLM_MXFP4_BACKEND=marlin   # ← correct

When this is right, the startup log shows:

[MXFP4] Using backend: marlin (VLLM_MXFP4_BACKEND=marlin)

After this fix: hi → normal response. Longer story request → 3,970 chars, no loop, finish_reason: stop.
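Since the startup log is the only signal, the check is worth automating. A sketch that scans log text for the broken auto-selection; the function is illustrative, the two demo lines are the log messages quoted above:

```shell
# Fail fast if the startup log shows the broken auto-selection.
check_mxfp4_backend() {
  # Reads vllm startup log text on stdin.
  if grep -q 'Auto-selected: CUTLASS_FP4' -; then
    echo 'BAD: CUTLASS_FP4 auto-selected, VLLM_MXFP4_BACKEND=marlin was not read'
    return 1
  fi
  echo 'OK: backend override in effect'
}

# Demo against the two log lines from this post:
echo '[MXFP4] Using backend: marlin (VLLM_MXFP4_BACKEND=marlin)' | check_mxfp4_backend
echo '[MXFP4] Auto-selected: CUTLASS_FP4 (vLLM native SM120 FP4 grouped GEMM for SM12x)' | check_mxfp4_backend || true
```

In a serve wrapper you would pipe the first few hundred lines of the server log through this and abort before accepting traffic.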


The Working Serve Script

#!/bin/bash
source /home/username/.python-vllm-eugr/bin/activate

# SM121 backend: force Marlin everywhere (CUTLASS_FP4 is broken on SM121)
# NOTE: if using eugr/spark-vllm-docker, their patches make CUTLASS work —
#       use --mxfp4-backend CUTLASS + VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1 instead
export VLLM_MXFP4_BACKEND=marlin               # ← NOT VLLM_NVFP4_GEMM_BACKEND
export VLLM_MARLIN_USE_ATOMIC_ADD=1            # SM121 Marlin atomic race fix
export FLASHINFER_DISABLE_VERSION_CHECK=1

# Offline tokenizer cache
export TIKTOKEN_RS_CACHE_DIR=/home/username/models/tiktoken_cache

# gpt-oss harmony system message fix
export VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1

# Clear compile caches between runs
rm -rf ~/.cache/flashinfer/ ~/.cache/vllm/torch_compile_cache/ 2>/dev/null || true

exec vllm serve /home/username/models/gpt-oss-120b \
  --served-model-name gpt-oss-120b \
  --host 0.0.0.0 --port 8001 \
  --quantization mxfp4 \
  --mxfp4-layers moe,qkv,o,lm_head \
  --kv-cache-dtype fp8 \
  --max-model-len 131072 \
  --max-num-batched-tokens 8192 \
  --gpu-memory-utilization 0.90 \
  --attention-backend FLASHINFER \
  --moe-backend marlin

Two flags updated from the original script, based on eugr/spark-vllm-docker recipe and a NVIDIA Developer Forum thread benchmarking gpt-oss-120B on GB10:

  • --attention-backend FLASHINFER replaces TRITON_ATTN. Triton attention is a fallback path; FlashInfer is the fully optimized path and benchmarks faster.
  • --mxfp4-layers moe,qkv,o,lm_head explicitly quantizes projection layers. Without it, qkv/o/lm_head run in BF16 — leaving performance on the table.

Note: --reasoning-parser openai_gptoss is not included. Remove it unless you specifically need the reasoning channel separated.


Performance

| Metric | Value |
|--------|-------|
| Decode speed | ~59 tok/s |
| Backend | Marlin W4A16 (weights dequantized at inference) |
| KV cache (fp8, 0.90 utilization) | ~580K tokens capacity |
| Max context | 131K tokens |
| CUDAGraph | ✅ captured |

The arithmetic: GB10 has 273 GB/s of memory bandwidth. gpt-oss-120B is a MoE model: ~120B parameters total (~60 GB at 4-bit), but only ~5B are active per token, so each decode step reads roughly 2.6 GB of expert weights. That gives an upper bound around 105 tok/s; attention weights, KV-cache reads, and kernel overhead bring the practical ceiling down toward 60. The measured 59 tok/s sits right at that ceiling. This is what working looks like on GB10.
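The back-of-envelope in code. The parameter counts are approximate (gpt-oss-120B activates roughly 5.1B of its ~117B parameters per token), and the calculation counts expert-weight reads only, so treat the result as an upper bound:

```python
# Decode-speed ceiling from memory bandwidth (expert weights only; KV-cache
# and attention traffic push the real number lower).
BW_BYTES_PER_S = 273e9     # GB10 unified memory bandwidth
ACTIVE_PARAMS = 5.1e9      # approx. active params per token (MoE routing)
BYTES_PER_PARAM = 0.5      # 4-bit MXFP4 weights

bytes_per_token = ACTIVE_PARAMS * BYTES_PER_PARAM
ceiling_toks = BW_BYTES_PER_S / bytes_per_token
print(f"{ceiling_toks:.0f} tok/s upper bound")
```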


What Was Gained

All six bugs were fixed. 59 tok/s is the result — close to the bandwidth ceiling, which means the stack is working correctly. But the lessons are worth stating explicitly, because each one is non-obvious and will recur.

The bugs that cost the most time:

  • Bug 6 (wrong env var) was the hardest because it was invisible. No error, no warning. The startup log is the only signal. Always read the startup log before debugging model behavior.
  • Bug 5 (harmony system message) has no workaround via config — it requires a source patch. Any client sending system prompts will hit this silently if unpatched.

Transferable diagnostics:

  • Repetition loops on gpt-oss → check startup log for Auto-selected: CUTLASS_FP4. If present, your Marlin env var isn't being read.
  • content: null in responses → either reasoning parser is routing output to the wrong field, or the model is generating into reasoning tokens. Check the parser flag first.
  • Tokenizer download failures on internal networks → tiktoken uses SHA1-hashed filenames for its cache. There is no documentation for this. The workaround is pre-downloading and pointing TIKTOKEN_RS_CACHE_DIR at the directory.

The pattern that applies everywhere: --enforce-eager should never be in a production serve script. If you inherited a script and the speed is wrong, check for this flag first. It's an easy 2× regression to introduce and forget.


Summary: The Ordered Checklist

If you're serving gpt-oss-120B on SM121:

  1. Apply the SM121 fixes from Part 1 first.
  2. Fix the mxfp4.py import path (eugr fork path → stock vLLM path).
  3. Remove --enforce-eager from your serve script. Check. Then check again.
  4. Pre-download the tiktoken vocab; set TIKTOKEN_RS_CACHE_DIR.
  5. Use VLLM_MXFP4_BACKEND=marlin — not VLLM_NVFP4_GEMM_BACKEND. The wrong variable silently does nothing.
  6. Confirm the startup log says Using backend: marlin, not Auto-selected: CUTLASS_FP4.
  7. If sending system prompts: patch serving_chat.py for PR #31607, and set VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS=1.

The diagnostic for bug 6: if you're getting repetition loops and you think you've forced Marlin, check your startup log. If it says Auto-selected, your env var isn't being read.


Also in this series: Qwen3.5-122B Runs. But at 14 tok/s — the GDN Kernel Gap