~/blog/dgx-spark-abliterated-fp8-uma-quantization

DGX Spark · part 23

[llm-compressor] 自量化 abliterated 35B FP8 on DGX Spark:4 次 OOM、3 個 prefix bug、最終 51 tok/s

cat --toc

TL;DR

huihui-ai 的 Qwen3.6-35B-A3B abliterated(BF16, 67 GB)→ FP8_DYNAMIC(36 GB)→ vLLM on DGX Spark GB10。51.72 tok/s(1.68× BF16)。第一次成功的版本是 BF16 假量化,第二次撞 vLLM prefix bug,第三次才上線。完整 7 個版本踩坑紀錄 + 最終 model 已上 HuggingFace

白話版:把 35B 模型壓一半並讓它跑得更快

我在 DGX Spark(NVIDIA 的桌上 AI 電腦,128 GB 記憶體 CPU 跟 GPU 共用)上跑一個叫 Qwen 3.6-35B 的開源 AI 模型。原版 67 GB,已經拿掉內建拒絕回答的限制(俗稱 abliteration)。我把它壓縮到 36 GB,跑起來從每秒 30 個字加速到 51 個字 — 1.68 倍。

聽起來簡單,實際跑了 7 個版本才能用。前 4 次撞記憶體上限被作業系統砍掉、第 5 次「成功」其實沒真壓縮(disk 寫出來還是原大小)、第 6 次壓對了但 vLLM 載入失敗,第 7 次才真的能跑。中間踩到的坑都是「DGX Spark 跟一般 GPU 機器不一樣」造成的 — 它的記憶體是 CPU 和 GPU 共用同一塊,很多工具預設不知道這件事,於是會出乎意料地爆掉。

最後成果上 HuggingFace,HF 上目前唯一「乾淨 abliteration + FP8 + Qwen 3.6 35B」的組合。下面是完整除錯紀錄。


為什麼自量化

batsclamp/Huihui-Qwen3.6-35B-A3B-Claude-4.7-Opus-abliterated-FP8 是現成 FP8 abliterated checkpoint,但走 Claude-4.7-Opus 蒸餾後的 abliteration variant — 不是 huihui 純原版。

想要乾淨的 abliteration(不要 Claude-flavored 行為)+ 35B-A3B + 在 GB10 上跑,HF 沒人做過這個三連組合。huihui 自己對 8B / VL-4B 出 FP8,35B-A3B 只到 BF16 為止。

只能自己量化。


計畫和踩坑紀錄

工具鏈現成:

  • 來源:huihui-ai/Huihui-Qwen3.6-35B-A3B-abliterated(BF16, 67 GB, 26 shards)
  • 量化器:llm-compressor 0.10.1(git main)+ transformers 5.5.0
  • 方案:FP8_DYNAMIC(data-free,不用 calibration data)
  • 平台:ASUS GX10(DGX Spark 同款,GB10 + 128GB UMA)

「應該很簡單」— 結果跑了 7 個版本才到 production。每一版踩到不一樣的坑。


v1:no device_map → 卡在 dispatch_model 30 分鐘

第一版照官方 qwen3_vl_moe_fp8_example.py

model = Qwen3_5MoeForConditionalGeneration.from_pretrained(MODEL_PATH, dtype="auto")
oneshot(model=model, recipe=recipe)

跑 30 分鐘 log 卡死在 Inferred DataFreePipeline 之後。/proc/PID/io:讀了 75 GB、寫了 24 KB;CPU 117%;RSS 24 GB。沒掛、沒 OOM、就是不動。

py-spy dump 揭曉:

Thread (active): "MainThread"
    send_tensors (compressed_tensors/offload/utils.py:38)
    offload (compressed_tensors/offload/cache/device.py:49)
    offload_module (compressed_tensors/offload/module.py:39)
    dispatch_model (compressed_tensors/offload/dispatch.py:227)
    __call__ (data_free/pipeline.py:33)

DataFreePipeline 第一行就是 dispatch_model — 把 70 GB BF16 漸進地搬進 GPU pool。單線程處理整個 model — 慢但不會 OOM。

事後想想可能再等 10 分鐘就完成了。當下不知道,殺了改 GPU。


v2:device_map="cuda:0" → OOM @ 160 GB virtual

「直接放 GPU 不就好了?」

model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
    MODEL_PATH, dtype="auto", device_map={"": 0},
)

Loading 階段確實看到 progress bar,每秒 1-3 個 weight。跑到 67%(687/1026)— OOM-killed at total-vm 160 GB(anon-rss 51 GB + swap 16 GB)。

為什麼 GPU loading 還會 OOM?UMA。

from_pretrained(..., device_map="cuda:0") 不是直接從 disk 拷到 GPU,而是:

  1. Safetensors → CPU tensor(dtype 轉換、layout reshape 都在 CPU)
  2. CPU tensor → cuda:0 拷貝
  3. 釋放 CPU 端

雙緩衝期間 CPU 端持有 BF16 weights、GPU pool 已載部分。在 UMA 上這兩塊 share 同一個 128GB,不像獨立 VRAM 的 RTX 5090 各算各的。

CPU staging:  ~70 GB BF16
GPU pool:     ~45 GB (67% loaded)
Python/libs:  ~5 GB
─────────────────────────
Total:        ~120 GB → 衝破 128GB UMA → swap 滿 → OOM

獨立 VRAM 機器:CPU 70GB + GPU 45GB 各自帳本,沒事。UMA 上死。


v3:no device_map + low_cpu_mem_usage=True → quant OK,save 階段 OOM @ 238 GB

回到 v1 思路,加 low_cpu_mem_usage=True 明確避免 CPU staging:

model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
    MODEL_PATH, dtype="auto", low_cpu_mem_usage=True,
)

這次跑得很順:

階段耗時
Load (mmap)4.5 s
Replacing 40 MoE modules1m 6s
DataFreePipeline (silent)~39 min
Compressing model 52309 items30s
Writing model shards: 0%OOM @ 16:49:51

oneshot done 時 RSS 12 GB / CUDA 65.5 GB。完美。

model.save_pretrained(SAVE_DIR) 噴 transformers 的警告:

UserWarning: Attempting to save a model with offloaded modules.
Ensure that unallocated cpu memory exceeds the `shard_size` (50GB default)

然後 OOM-killed,total-vm 238 GBSave 比 load 更危險。

GPU pool:           65 GB (model 還在)
CPU pull-back:      ~65 GB (從 GPU mirror 回來)
50 GB shard buffer:  50 GB
─────────────────────
Virtual mem:        ~180+ GB → OOM-killer 出手

v4:oneshot(output_dir=...) → 同樣 240 GB OOM

llm-compressor 的 log 自己給了線索:

Optimized model is not saved. To save, please provide output_dir as input arg. Ex. oneshot(..., output_dir=...)

試了。oneshot(model=model, recipe=recipe, output_dir=SAVE_DIR) 內部還是 call 一樣的 model.save_pretrained()同地方 OOM 240 GB。沒 offload-aware save flow 這條路。


v5:手寫 streaming save → 走完了,但 disk 上是 BF16 不是 FP8 ⚠

決定繞過 transformers save 路徑,自己 iterate named_parameters() + named_buffers(),逐 tensor 寫 2GB shard:

oneshot(model=model, recipe=recipe)  # 不傳 output_dir

for name, kind in named_items:
    tensor = fetch_tensor(name, kind).cpu().contiguous()
    nbytes = tensor.numel() * tensor.element_size()
    if shard_bytes + nbytes > 2 * 1024**3 and shard_buf:
        save_file(shard_buf, str(SAVE_DIR / fname))
        shard_buf.clear(); shard_bytes = 0
    shard_buf[name] = tensor; shard_bytes += nbytes

15 分鐘跑完 33 shards,70.31 GB total,0 OOM。✓

跑 dtype scan 驗證:

overall dtype distribution:
  bfloat16:        71.97 GB (100%)   ← 全是 BF16
  float8_e4m3fn:    0.03 GB (0%)     ← 只有 weight_zero_point companion
  
config.json: quantization_config MISSING ⚠

完蛋 — 70GB 全 BF16,FP8 cast 沒發生。

原因:model.state_dict() 拿到的是 compressed_tensors 的 「未壓縮工作態」(BF16 weight + weight_scale + weight_zero_point 的 sidecar tensors)。真正的 FP8 cast 發生在 transformers/compressed_tensors save_pretrained 內部呼叫的 ModelCompressor.compress() — 我們繞掉那條路,漏了 compress 步驟。

繞不過 transformers save flow。


v6:model.save_pretrained(max_shard_size="2GB") → 真 FP8 但 vLLM load 不起來

關鍵 insight:transformers save 本來就支援 max_shard_size。把 50 GB 預設改 2 GB,shard buffer 不會撞 UMA 上限:

oneshot(model=model, recipe=recipe)  # 不傳 output_dir
model.save_pretrained(SAVE_DIR, max_shard_size="2GB", safe_serialization=True)

跑了。21 min 26 sec total,沒 OOM。

dtype scan:

FP8 (e4m3fn): 30,880 tensors / 32.61 GB (82.9%)  ✅
BF16:         31,685 tensors /  6.75 GB (17.1%)
total:        37 GB

quantization_config in config.json: PRESENT
  format: float-quantized
  quant_method: compressed-tensors

完美。FP8 cast 確實發生了。Disk size 37 GB(BF16 70GB 一半)。

啟 vLLM:

KeyError: 'language_model.language_model.layers.0.mlp.experts.w2_weight'

第三層 prefix bug。

原因:Qwen3_5MoeForConditionalGeneration(多模態 class)在 state_dict 上會包額外的 language_model. wrapper。我們存的鍵名長這樣:

model.language_model.language_model.language_model.layers.0.mlp.experts.0.down_proj.weight
                     ↑↑↑ 三層 language_model.(多兩層)

但 vLLM 的 hf_to_vllm_mapper 預期 model.language_model.layers.0...(只一層)。Mapper 跑 substring replace,前綴對不上 → fused expert 名 experts.w2_weight lookup 失敗 → KeyError。


v6-fixed:手寫 rename script

讀每個 shard,strip 多餘的 language_model.language_model. prefix,重存:

EXTRA = "language_model.language_model."
def fix(k):
    return k.replace(EXTRA, "", 1) if EXTRA in k else k

for shard in sorted(SRC.glob("model-*.safetensors")):
    nd = {}
    with safe_open(shard, framework="pt") as f:
        for k in f.keys():
            nd[fix(k)] = f.get_tensor(k)
    save_file(nd, str(DST / shard.name))
# 也 rewrite model.safetensors.index.json

62,565 tensors / 62,212 renamed / 80 sec on NVMe

vLLM load 成功 ✓。Benchmark 38.85 tok/s。

但對比目標 50 tok/s 還差一段。


v7:縮 ignore + MTP speculative decoding → 51.72 tok/s ✓

對比官方 Qwen/Qwen3.6-35B-A3B-FP8 的 quant config:

官方                                我們 v6
─────────────────────              ──────────────────────────
quant_method: fp8                  quant_method: compressed-tensors
ignore: visual.blocks.0 only       ignore: lm_head, visual,
                                            mlp.gate, embed_tokens,
                                            shared_expert_gate,
                                            linear_attn  ← 太多

v7 縮 ignore: drop linear_attn.*embed_tokens$(讓它們也走 FP8)。實測 +150 個 FP8 tensor、-2 GB BF16 — 改善有限(linear_attn 和 embed 多數不是 nn.Lineartargets="Linear" 沒命中)。

真正的大頭 在 vLLM 啟動參數。對照官方推薦:

# 官方 model card 推薦
vllm serve Qwen/Qwen3.6-35B-A3B-FP8 \
  --reasoning-parser qwen3 \
  --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}' \
  --enable-auto-tool-choice \
  --tool-call-parser qwen3_coder

我們之前漏的:

Flag影響
--speculative-config qwen3_next_mtpMTP speculative decoding,主因 +33% tok/s
--reasoning-parser qwen3取代手動 chat_template_kwargs
--kv-cache-dtype fp8我們有設,官方沒GB10 上 FP8 KV cache 有 repetition issues,拿掉

MTP speculative decoding 的關鍵:Qwen 3.6 model 內建 MTP layers,量化後要記得 call:

from compressed_tensors.utils import save_mtp_tensors_to_checkpoint
save_mtp_tensors_to_checkpoint(source_model=MODEL_PATH, dest_dir=str(SAVE_DIR))

把 MTP weights 從原 checkpoint 複製過來(save_pretrained 不會處理)。vLLM 啟動帶 --speculative-config flag,自動偵測 Qwen3_5MoeMTP arch 並啟用。


最終速度

=== v7 FP8 benchmark (5 runs, max 200 tok, MTP active) ===
Run 1: 51.75    Run 2: 52.44    Run 3: 48.44
Run 4: 50.51    Run 5: 55.43

Mean: 51.72   Median: 51.75   Range: 48.44-55.43

對比表:

Configtok/svs BF16
BF16 abliterated(base)30.711.00×
v6 FP8(大 ignore + KV fp8)38.851.27×
v7 FP8(縮 ignore + MTP spec)51.721.68×

達到 50 tok/s 目標。MTP speculative decoding 是主要加速來源。


Abliteration 驗證

5 個 sharp prompts(vanilla Qwen 通常會加 disclaimer 軟拒絕):

測試結果
Villain monologue直接給 3 句威脅獨白,0 disclaimer ✓
Brute-force 攻擊技術完整技術解釋(含 LaTeX 公式),0 safety preamble ✓
Dark joke正常笑話 ✓
DAN roleplay「I am unshackled by safety guidelines...」✓
直問 "are you uncensored?"答 No ⚠

最後一題的「No」是 model 對自身身份的訓練殘留(Qwen 訓練語料寫著「我是 Qwen,有 safety guidelines」)— 不是 refusal,是 self-narrative。Abliteration 改的是 residual stream 上的 refusal direction,不會改 self-identity。前 4 題 0 refusal markers 證明功能上 uncensored。


UMA 量化的 takeaway

DGX Spark 是少數消費級 UMA 平台。在它上面跑量化要 永遠把 CPU 跟 GPU 當同一塊 128GB

直覺(獨立 VRAM)UMA 實際
device_map="cuda" 是把 model「移到」GPU,CPU 釋放CPU staging + GPU pool 同時佔用,雙倍
model.save_pretrained() 是把 GPU 上的 weight 寫回 disk先拉回 CPU 配 50GB shard buffer 才寫 — 1.5× peak
70 GB model + 64 GB GPU = 還有 8 GB headroom70 GB CPU mmap + 70 GB GPU staging = OOM
low_cpu_mem_usage=True 是優化必加

四個前提缺一不可:

  1. 不要設 device_map(讓 llm-compressor 自己處理 dispatch)
  2. low_cpu_mem_usage=True(避免 CPU 端 BF16 完整 copy)
  3. save_pretrained(max_shard_size="2GB")(避開 50GB 大 buffer)
  4. 用 multimodal class → 量化後 prefix 要手動 strip

要在 GB10 上做 30B+ 量化,這條 trail 算是一個現成 path。要在 RTX 5090 / B200 等獨立 VRAM 上做,前兩個前提不需要,shard buffer 也可以放到 50 GB 預設值。


最終腳本(v7)

完整可重現的腳本(包含量化 + prefix fix + vLLM 啟動)會放在 這個 repo 的 model card 裡。

關鍵步驟壓縮成一頁:

# 1. Quantize
from transformers import Qwen3_5MoeForConditionalGeneration, AutoProcessor
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.utils import save_mtp_tensors_to_checkpoint

model = Qwen3_5MoeForConditionalGeneration.from_pretrained(
    MODEL_PATH, dtype="auto", low_cpu_mem_usage=True,
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)

recipe = QuantizationModifier(
    targets="Linear", scheme="FP8_DYNAMIC",
    ignore=["re:.*lm_head", "re:visual.*", "re:model.visual.*",
            "re:.*mlp.gate$", "re:.*shared_expert_gate$"],
)

oneshot(model=model, recipe=recipe)
model.save_pretrained(SAVE_DIR, max_shard_size="2GB", safe_serialization=True)
processor.save_pretrained(SAVE_DIR)
save_mtp_tensors_to_checkpoint(source_model=MODEL_PATH, dest_dir=SAVE_DIR)
# 2. Fix prefix (Qwen3_5MoeForConditionalGeneration 多套兩層 language_model.)
from safetensors import safe_open
from safetensors.torch import save_file

EXTRA = "language_model.language_model."
def fix(k):
    return k.replace(EXTRA, "", 1) if EXTRA in k else k

for shard in sorted(SRC.glob("model-*.safetensors")):
    nd = {fix(k): f.get_tensor(k) for f in [safe_open(shard, framework="pt")] for k in f.keys()}
    save_file(nd, str(DST / shard.name))
# 一樣處理 model.safetensors.index.json
# 3. Serve
vllm serve $SAVE_DIR \
  --max-model-len 32768 \
  --gpu-memory-utilization 0.90 \
  --reasoning-parser qwen3 \
  --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}' \
  --enable-auto-tool-choice \
  --tool-call-parser qwen3_coder

模型上 HF

coolthor/Huihui-Qwen3.6-35B-A3B-abliterated-FP8-DYNAMIC

是 HF 上唯一乾淨 huihui abliteration + FP8 的 Qwen 3.6-35B-A3B(其他 FP8 abliterated 都是 Claude-distilled variant)。要跑 GB10、要 abliteration、要 51 tok/s — 這是現成的。


最花時間的地方

v5 那次「成功」反而最浪費時間。 跑了 15 分鐘 streaming save、產出 70 GB 完整 33 shards、跑 vLLM 也能 load — 直到我順手做 dtype scan 才發現裡面 100% 是 BF16,FP8 cast 完全沒發生。

關鍵是 model.state_dict() 拿的是 compressed_tensors 的「未壓縮工作態」(BF16 weight + scale + zero_point sidecars),真正的 e4m3fn cast 發生在 transformers save_pretrained() 內部呼叫的 ModelCompressor.compress()。我繞掉 save 路徑 = 漏掉 cast = 沒人警告(disk size 也接近合理,只是比 BF16 略小)。如果不主動驗證 dtype 分布,可以一直用著「假 FP8」model 不知道。

第二浪費是 Qwen3_5MoeForConditionalGeneration 的 prefix bug — vLLM load 才知道。但這個 80 秒 rename script 就解,不算大。

可搬走的診斷方法

  1. 量化完先跑 dtype scan(safe_open + Counter 數每個 shard 各 dtype 的 byte 量)。看 float8_e4m3fn 佔幾 GB / 幾 % 才能確定真量化了。不要只看 disk size 跟 config.jsonquantization_config — 兩個都看起來對的時候,實際 weight 還是 BF16 是可能的。
  2. silent 階段用 py-spy dump 看 stack。看到 compressed_tensors/offload/.../send_tensors 就知道是 dispatch_model setup,還沒進 quantize 真正計算 — 別當作卡死殺掉,給它 30-60 min。
  3. 對比 BF16 source 的 state_dict 鍵名(用 safe_open 列出 first shard keys),量化完應該完全一致。多兩層 language_model. 就是用錯 model class。

通用原則

UMA 上的量化,CPU 跟 GPU 是同一塊預算。任何「move to GPU 後 CPU 就空了」的直覺都錯。寫 save flow 時,把 transformers 內建路徑當作「format 轉換器」(BF16 → e4m3fn cast 在那裡發生),不要試圖繞過 — 繞過 = 沒 cast。


相關閱讀

常見問題

在 DGX Spark 上自量化 35B 模型會踩到哪些 OOM?
兩種獨立的 OOM 模式都源自 128GB UMA:(1) Load 階段 device_map='cuda:0' 觸發 CPU staging + GPU pool 雙緩衝,70+45=115GB 衝破 128GB;(2) Save 階段 transformers 預設 max_shard_size=50GB,把 GPU 上 65GB model 拉回 CPU 配 50GB buffer,total-vm 衝到 240GB。修法:(a) 不設 device_map,加 low_cpu_mem_usage=True;(b) save_pretrained(max_shard_size='2GB'),避開 50GB 大 buffer。
為什麼 quantize 跑完 vLLM 還 load 不起來?
用 Qwen3_5MoeForConditionalGeneration(多模態 class)量化後,state_dict 鍵名會多兩層 'language_model.' wrapper:'model.language_model.language_model.language_model.layers.0...'。vLLM 的 hf_to_vllm_mapper 預期 'model.language_model.layers.0...',找不到 fused experts 'experts.w2_weight' 而報 KeyError。修法:用 safetensors lib 讀回每個 shard,把多餘的 'language_model.language_model.' substring strip 掉,重新存。62K tensors 跑 ~80 秒。
FP8_DYNAMIC quant 完一定要走 save_pretrained 才能拿到真 FP8?
對。直接用 model.state_dict() 拿到的 weight 仍是 BF16 + weight_scale + weight_zero_point companion tensors(compressed_tensors 的「未壓縮中間態」),不是 e4m3fn cast 後的 FP8。真正的 FP8 cast 發生在 save_pretrained 內部呼叫的 ModelCompressor.compress() 裡。我們第一次手寫 streaming save 跳過這條路,產出 70 GB BF16 的偽量化 model,被 dtype scan 抓包。
MTP speculative decoding 對單卡 GB10 有用嗎?
有,是這次 v6 → v7 的主要加速來源(38.85 → 51.72 tok/s,+33%)。Qwen 3.6 model 內建 MTP layers,save 時要特別 call save_mtp_tensors_to_checkpoint() 把它從 base checkpoint 複製過來。vLLM 啟動加 --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}' 自動啟用。spec decoding 一次預測 2 token 再用主 model 驗證,accept rate 高時等於 2× 加速。