LK Losses: Direct Acceptance Rate Optimization for Speculative Decoding

Nebius · Samarin, Krutikov, Shevtsov, Skvortsov, Fisin, Golubev · 2026-03-02 · arXiv:2602.23881
关键词: speculative decoding · draft training · KL divergence · TV distance · acceptance rate · EAGLE-3 · MEDUSA · MTP

速读卡片 (TL;DR)

一句话:SD draft 训练用了多年 KL/CE,其实是个 proxy 目标——只在 q=p 全局最优时才等价于 acceptance rate;在容量受限的 draft 上,KL 最小≠α 最大。LK losses 直接以 α 为目标,通过 TV(p,q) 与 −log α 两条路径绕过 TV 梯度消失这道坎,作为已有 SD 训练框架的 drop-in replacement,在 8B–685B 六个 target × 四种 draft 架构上一律提升 τ,平均 acceptance length 涨幅达 8–10%。

+8.2%
Qwen3-235B EAGLE-3, T=1
+7.7%
GPT-OSS 120B, T=1
0 FLOPs
额外训练开销

立场:这是把 SD draft 训练的"目标函数"层面拨正——比 MARS 那种基于 acceptance threshold 的方法更"根级",因为它直接换 loss、不需要新增推理逻辑。最大启发:KL 在小容量下不是"无害的好代理",它在权衡哪里牺牲这件事上和 α 不一致。


1 · 动机:proxy 之坑

1.1 历史脉络:为什么大家一直用 KL/CE

从 Leviathan 2023 把 SD 落地开始,draft model 的训练几乎一律是一个 KD(knowledge distillation)流程:把 target 当 teacher,用 KL(p∥q) 或等价的 CE 让 draft 学习 target 的输出分布。MEDUSA(2024)、EAGLE 系列(2024–2025)、DeepSeek-V3 MTP(2024)、FR-Spec(2025)——架构在变,但 loss 的骨架没变。EAGLE 在此基础上加了 hidden-state 的 MSE regression,MEDUSA 加了 LM 目标,但都是额外补丁,核心仍然是 KL。

这种共识基于一个事实:KL 和 acceptance rate α 共享同一个全局最优点。如果 q 能完美匹配 p,KL=0、α=1。所以"训 KL ≈ 训 acceptance"在理论上是没毛病的——前提是你能到达那个全局最优。

但 draft model 通常只有 target 1–5% 的参数(EAGLE-3 一层 transformer 对 70 层的 Qwen3-235B),它注定到不了 q=p。问题就是:在你必然停留的这个 suboptimal 点附近,KL 最小化和 α 最大化是两件不同的事。这就是这篇 paper 的核心攻击面。

1.2 别的方案为什么不够

方案核心思路遗留问题
KL(p∥q) — mode covering避免 q 在 p 有概率的地方崩到 0把概率铺得太散,小容量下 α 反而低(toy 例子里 α=50.2%)
Reverse KL(q∥p) — mode seekingq 必须把质量集中在自己 commit 的 mode 上容易 collapse 到单峰,multi-mode target 上同样亏 α(50.8%)
MSE on hidden(EAGLE)追 hidden state 让 inference/training 一致仍是间接,且 hidden 距离 ≠ 输出分布距离
EAGLE-3 multi-layer KD融合 target 多层 hidden 提高表达力架构层面的努力,loss 仍然是 KL
MARS 类阈值方法在 inference 时丢掉低 acceptance 的 head/位置事后挽救,不是从训练目标上改
DistillSpec(尝试 TV)意识到 TV 才直接对 α仅在已预训练的 LM 上 fine-tune;从 random init 训不动 — 见 §3

注意最后一行:DistillSpec 早就指出 TV 是对的目标。但他们的实验是从一个已经会说话的 LM 上微调(初始 q 已经接近 p),所以 TV 的梯度问题被掩盖。在从头训 EAGLE/MEDUSA 这种小 head 的真实场景下,纯 TV 训练会立刻撞墙——这是 LK 这篇要解决的工程关键。

1.3 为什么这事不平凡

看上去很直白:既然 α=1−TV(p,q),那就直接最小化 TV 不行吗?或者更直接,把 α 写出来当 loss?三个障碍让这件事过去三年没人做成:

  1. min(p,q) 不光滑。α=Σ min(pi,qi) 在 pi=qi 处不可微,对应 TV 中的 sign 跳变。直接梯度只携带方向(过预测 vs 欠预测),不携带幅度——欠预测一点点和欠预测很多得到一样大小的 sign 信号。
  2. TV 梯度幅度消失。论文 §A.5 给出在"q 几乎均匀(随机初始化) + p 集中在 k 个 token"的 regime 下:‖∇zTV‖ = O(√k / V)。对 V=128k 的现代 vocabulary,这个梯度比 KL 的 O(1/√k) 小好几个数量级。换句话说,从随机初始化开始,TV 推不动 logits。
  3. −log α 看似优雅,但 α=Σ min(p,q) 不是 softmax 的 likelihood。它没有 NLL 那种"目标 token 一个 slot"的好分解(虽然在 p 是 point mass 时退化为 NLL,见 §B)。它的梯度推不出标准 cross-entropy 那种 q−p 形式;LK 推到最后是 (1/α)·∇TV,等于"TV 加自适应 1/α 放大"——这层自适应是关键,把 TV 的梯度消失问题在 α→0(随机初始化)时自动恢复 O(1/√k)。

所以这篇 paper 的"非平凡"含金量,不在 idea(idea 是 obvious 的:就直接优化 α 嘛),而在把 TV 梯度的病理诊断清楚 + 给出两条都能绕过的路径(混合调度 / 自适应放大)。这是工程师视角下的"算法障碍解决"。

target p (multi-modal) KL fit (mode covering) α ≈ 50% TV fit (overlap maximizing) α ≈ 60% 同样的 single-Gaussian 容量 同样的 target 不同 loss → 不同 α KL: 把质量铺到所有有 p 的位置(包括小峰/尾巴) → 主峰位质量被稀释 TV: 只关心 overlap 面积 敢放弃小峰,集中主峰 → α 直接最大
同一个单 Gaussian 容量去拟合 multi-modal target。KL 把质量铺得均匀以避免 log(0) 惩罚,主峰处反而薄;TV 不在乎"覆盖率",只在乎"重合面积"=acceptance rate,因此敢于放弃小峰把质量压到主峰。toy 例子已经能看到 α 从 ~50%→60% 的差距,这正是 LK 想要复刻到 draft 训练里的现象。

2 · 背景速查

符号 / 术语含义
p, qtarget / draft 在某 context 下的 next-token 分布
β(x) = min(1, p(x)/q(x))给定 q 抽出 x,target 接受它的概率(rejection sampling)
α = Σ min(pi, qi) = 1 − TV(p,q)per-position acceptance rate
τ = K · (#accepted/#drafted) + 1每轮 SD 期望生成 token 数 (含 bonus token);primary metric
Kdraft 序列长度(EAGLE-3/MTP 用 7,MEDUSA/MLP 用 6)
zqdraft logits;q = softmax(zq)
LλLKhybrid loss = λ·KL + (1−λ)·TV,λ 由当前 α 自适应
LαLKlikelihood loss = −log α
ηλ 调度器的衰减强度,λ = exp(−η·sg[α])
γ = 0.8多 head loss 的 exponential decay,前面的 head 权重大
速记 §3.1 的 SD 验证规则:把 K 个 draft token 一次性 forward target,从位置 1 起逐个判 accept(prob = min(1, p/q));遇到第一个 reject 就停,并从 (p−q)+ 中重新采样 1 个 token 替换它(这就是 bonus token 的来源)。所以 τ 至少为 1。

3 · 梯度解剖:KL vs TV(为什么直接做 TV 撞墙)

论文最 load-bearing 的不是 idea 而是这一节的诊断,值得逐字啃。

3.1 KL 的梯度:漂亮但代理

zq KL(p∥q) = q − p

这是个被反复用过的好朋友:每个 logit 收到的力 = "我现在的概率"减"目标概率"。方向、大小都自洽。在随机初始化 + p 集中在 k token 的 regime 下,‖q−p‖ = O(1/√k) — 假设 k=20,梯度幅度 ~0.22,稳稳能动。

3.2 TV 的梯度:方向对了但小到看不见

zq TV(p, q) = ½ q ⊙ (s − Eq[s]), si = sign(qi − pi)

问题三连:

梯度信号在 logit 空间的"力"对比 Token A: 欠预测 0.001 (轻微) KL 力 TV 力 ←sign 信号无关大小, 欠 0.001 也是一格 Token B: 欠预测 0.5 (严重) KL 力 TV 力 ↑ KL 力按比例放大 TV 不变(只看符号) 随机初始化 + V=128k: ‖∇KL‖ ~ O(1/√k) ≈ 0.22 ‖∇TV‖ ~ O(√k/V) ≈ 3.5×10⁻⁵ (差 ~7000×) 非光滑点:qi = pi sign 在等值面上跳变, 优化器附近震荡无法稳定
TV 梯度的三重病理:① 信号只有 sign;② 被 q 整体缩小,在 random init + 大词表下幅度爆炸性消失;③ 等值面上不可微。这就是从随机初始化开始训纯 TV 不行的原因——也是为什么 DistillSpec 用 TV 必须从 pretrained LM 接着调。

3.3 关键数值表(论文 Table 3 的复述)

LossGradient on S (the k support tokens)Gradient off S (the V−k 噪声 token)
KL−1/k+1/V
TV−1/V≈ 0
LαLK = −log α−1/k+1/V

读法:LαLK 在主梯度成分上和 KL 同量级(支持 token 上 −1/k),但方向和 TV 一致(对 α 直接求导)。这就是它在工程上"既能从 random init 起步,又把 α 当真目标"的来源。

4 · LλLK:KL+TV 自适应混合(trust region 的味道)

4.1 公式

LλLK(p, q) = λ · KL(p∥q) + (1 − λ) · TV(p, q)
λ = exp(−η · sg[α]), η > 0

sg[·] 是 stop-gradient,防止 λ 自己被反向传播改写。α 用当前 batch / sequence / position 聚合后的真实 acceptance rate 估计。

4.2 直觉:课程学习而非线性混合

关键点不是"加权平均",而是自适应调度:

这和 trust region(TRPO)的精神一致:KL 当 soft constraint 把 q 拉进 p 附近,然后在 trust region 内对真目标 TV 做下降。论文里直接给了这个对偶解读:

minq TV(p,q) s.t. KL(p∥q) ≤ δ
1.0 0 α 0 0.5 1.0 λ η=3 (default) η=1 (slow decay) Phase 1:KL 主导 (光滑,有幅度) Phase 2:TV 主导 (对 α 直接做下降)
λ 调度曲线。η=3 默认配置下,α 从 0 涨到 1 时 λ 从 1 衰减到 0.05;意味着训练前期几乎纯 KL,后期几乎纯 TV。η 越大切换越早,论文对 MEDUSA 用 η=10 因其 α 涨得慢,需要更激进地切到 TV 才能见效。这是一条课程学习曲线,不是固定加权——后者(λ=0.5 fixed)在实验里输给了所有自适应配置。

4.3 Worked example: 一个 token 走一遍

设词表 V=1024(简化),target p 集中在 token x=42,p42=0.7,p11=0.2,p99=0.1,其余为 0。draft 早期 q 接近均匀,qi≈1/1024≈9.8e-4,设当前 batch 平均 α ≈ 0.05(等价于很差的对齐)。

训了几千步后 α=0.6,q42=0.55,p42=0.7:

5 · LαLK:−log α 一条龙

5.1 直觉

把 α 写成 marginal probability of acceptance:α = Σ q(x)·β(x)。这是"draft 抽到 x 且被接受"的边缘概率。最大化 α 的最大似然写法就是最小化 −log α:

LαLK = −log α = −log Σx∈V min(p(x), q(x))

比 hybrid 更简洁:不需要混合权重、不需要 schedule。但能不能从随机初始化训得动?这就要看它的梯度。

5.2 关键关系(§A.4)

zq LαLK = (1/α) · ∇zq TV(p, q)

这是整篇 paper 最漂亮的一行。意思是:−log α 的梯度 = TV 梯度 × 1/α。1/α 是自适应 boost:

LαLK = (1/α)·∇TV — 一图看明白自适应放大 vocab token i α = Σ min(p,q) = 重叠面积(灰) 早期: α=0.05 ‖∇Lα‖ = (1/0.05) × O(√k/V) = 20× boost → 量级回到 O(1/√k),与 KL 同档 后期: α=0.7 ‖∇Lα‖ ≈ 1.43× ∇TV → 几乎纯 TV,直接对 α 做精修
α=Σmin(p,q) 的几何理解(灰色重叠区即 acceptance rate)+ 1/α 自适应放大的双阶段示意。这条单一 loss 不需要 schedule,自带"早期靠放大、后期回归 TV"的 curriculum——这就是为什么作者在论文里把它推荐给 simplicity-first 的实现者。但实验上它通常略输 hybrid,因为 hybrid 的 KL 项在早期同时提供了"方向更自洽"的梯度。

5.3 Worked example:同 CE 不同 α

这是一个最能击穿"KL 训练 = α 训练"幻觉的对比。设 V=4,target p = (0.50, 0.30, 0.15, 0.05)。

方案qCE = −Σ p log qα = Σ min(p,q)
qA(KL 偏好)(0.40, 0.30, 0.20, 0.10)≈ 1.3430.40+0.30+0.15+0.05 = 0.90
qB(TV 偏好)(0.55, 0.30, 0.10, 0.05)≈ 1.3550.50+0.30+0.10+0.05 = 0.95
qC(spread)(0.35, 0.30, 0.20, 0.15)≈ 1.3470.35+0.30+0.15+0.05 = 0.85

读这个表:qA 的 CE 最低(KL 也最低,因 KL=CE−H(p)),但α 反而比 qB 低 0.05。qB 牺牲了一点 CE,换来 acceptance rate 提升。qC 最铺得均匀,CE 居中但 α 最差。这就是"KL 最小 ≠ α 最大"的具体数值证据。LK 的两个 loss 都会在 qA 和 qB 之间偏向后者。

6 · Vocabulary truncation:KL 多了一层 proxy,LK 没这个问题

EAGLE-3 + FR-Spec 配置下,draft 的 LM head 只输出一个截断词表(比如 32k,target 词表 128k),其它 token 的 qi=0。这对 KL 是致命的:pi>0 而 qi=0 → KL=∞。常规做法是把 target 也截:p̃ = softmax(m⊙zp),把截断外的 logit 设 −∞,再去做 KL(p̃∥q)。

这就是论文一句犀利的吐槽:"makes KL a proxy of a proxy"(代理的代理)。一层是 KL≠α,另一层是 p̃≠p。

LK 怎么处理?从 α=Σmin(pi,qi) 看,被截掉的 token 上 qi=0,贡献 min(pi, 0)=0,自然被忽略 — 不需要修改 p。LαLK 和 hybrid 的 TV 部分都直接对原始 p 做优化,不引入额外近似。这是 LK 落地 EAGLE-3+FR-Spec 时的一个免费红利。

为什么这点重要:在生产部署里 truncated vocab 是越来越流行(FR-Spec 在 EAGLE-3 上节省 LM head 计算)。如果你的训练 loss 已经被迫从 KL(p∥q) 退到 KL(p̃∥q),目标和真实 α 之间有两道距离,LK 直接抹掉一道。

7 · 公式串与数值直觉

7.1 推导链条(§A 复述)

  1. softmax Jacobian:∂qi/∂zq,j = qiij − qj)。
  2. KL:∂KL/∂zq,j = −Σi (pi/qi)·qiij − qj) = −pj + qj·Σpi = qj − pj
  3. TV:链式 sign 经过 softmax Jacobian 得到 ½qj(sj − Eq[s])。
  4. −log α:∂(−log α)/∂z = (1/α)·(−∂α/∂z) = (1/α)·∂TV/∂z(因 α=1−TV)。

7.2 在 p 是 point mass 的退化情况

若 p(x*)=1,其它为 0,则 α = Σ min(pi, qi) = min(1, q(x*)) = q(x*)(因 q 是概率)。所以:

LαLK | p=δx* = −log q(x*) = NLL

这是个不错的 sanity check:当 target 退化为硬标签,LK 等于普通的 cross-entropy 训练。它是 NLL 在软标签(distillation)情况下的"正确" generalization——比直接做 soft-label CE 更直接对 α。

7.3 数值敏感性表(η 选择)

α 当前值η=1, λη=3, λη=10, λ
0.01.001.001.00
0.20.820.550.135
0.50.610.220.0067
0.80.450.0913.4e-4

读法:η 越大,KL 越早退场。EAGLE-3 / MTP 这种α 涨得快的架构默认 η=3 即可;MEDUSA 因为头之间独立、α 涨得慢,论文用 η=10 强行推快 TV 接管。这是个有点反常识但可解释的调参逻辑——你想让 TV 在 α 还没涨太高的时候就接管,因为再不接管它就不打磨了。

8 · 实验关键结果

8.1 跨架构(LLaMA-3.1-8B,Table 1 节选)

DraftLossMT-bench τHumanEval τGSM8K τ
EAGLE-3 (T=1)KL3.394.313.88
TV (纯)2.673.253.12
LλLK η=33.484.524.02
MEDUSA (T=1)KL1.722.021.81
LλLK η=101.852.221.92
MLP (T=1)KL2.132.162.16
LλLK η=32.192.622.18

关键观察:

8.2 跨 target scale(Table 2)

TargetMean τ KL (T=1)Mean τ LλLK (T=1)Δ%
LLaMA-3.1-8B3.864.01+3.9
LLaMA-3.3-70B4.504.66+3.5
GPT-OSS-20B3.173.29+3.8
GPT-OSS-120B2.462.65+7.7
Qwen3-235B3.774.08+8.2
DeepSeek-V3 685B (MTP)4.434.68+5.6

大模型(尤其 MoE)+小 dense draft 的容量差越悬殊,LK 的相对提升越大 — Qwen3-235B 的 +8.2% 是论文最亮的数字。DeepSeek-V3 是 fine-tune MTP(不是从头训),也提升 5.6%,印证 LK "在 q,p 距离大时都能赢" 的普适性。

8.3 一些容易被略过的细节

9 · 与同类工作对比

Loss / 方法对 α 的关系从随机初始化能训吗?truncated vocab 兼容?额外开销
CE / KL(p∥q)proxy(同 global opt)需把 p 也截0
Reverse KL(q∥p)proxy,mode seeking兼容0
MSE on hidden(EAGLE)很间接看实现+ regression head
EAGLE-3 multi-layer KD架构层面,loss 仍是 KL同 KL+ feature fusion
DistillSpec(reverse KL/TV)TV 直接对 α需 pretrained LM 起步取决于 divergence0
MARS(threshold)inference-time filter+ inference 逻辑
LαLK直接(−log α)能(1/α 自适应放大)天然兼容0
LλLK直接(后期)能(KL 起步)天然兼容(TV 部分)0

核心差异化:

10 · 局限 / 个人 take / 待验证

需要验证的几个问题:

  1. LαLK更大 V(比如 200k+)和更深 draft(2-3 层 transformer)上,1/α 放大是否足够?会不会触发数值不稳定(α 接近 0 时 1/α 爆炸)?
  2. 把 LK 嫁接到 RL post-training 里的 draft model(参考 ReSpec / NeMo-RL),acceptance length 抖动场景下 sg[α] 估计的方差会不会让 λ 跑飞?
  3. 对 reverse KL(mode seeking)直接也做一次 hybrid:λ·rKL + (1−λ)·TV 是否优于 λ·KL + (1−λ)·TV?论文没尝试。
  4. p 是 point mass 时退化为 NLL,中间区间(p 是 top-p=0.9 截断的)是否退化为某个加权 NLL?是否解释了 LK 在 GSM8K(高 confidence target)上提升相对小?
  5. Loss landscape 在 p≈q 附近 KL 与 TV 的 Hessian 性质对比 — 是否能解释为什么 hybrid 略胜 likelihood?

11 · Memory points

立场 KL/CE 是 SD draft 训练的proxy 目标;在容量受限的 draft 上 KL 最小 ≠ α 最大。LK losses 把 loss 直接换成 acceptance rate 的可微表示。
公式 1 LαLK = −log Σ min(pi, qi) ;其梯度 = (1/α)·∇TV — 自适应放大,从随机初始化训得动。
公式 2 LλLK = λ·KL + (1−λ)·TV,λ = exp(−η·sg[α]) ;trust-region 视角下 KL 是 soft constraint,TV 是真目标。
障碍 直接做 TV 不行:‖∇TV‖=O(√k/V),V=128k 时小到推不动 logits;且 sign-only,landscape 不光滑。
数据 6 个 target(8B–685B)× 4 个 draft 架构,τ 一律提升;最大增益 Qwen3-235B +8.2%,GPT-OSS-120B +7.7%(均 T=1)。
规律 容量越小 LK 增益越大 — MEDUSA 7.8%,MLP 8.3%,EAGLE-3 3.8%(同 8B target)。
免费红利 truncated vocabulary 下 KL 要给 p 也截(proxy of proxy);LK 不需要(min(p,0)=0 自然忽略)。
退化 p 是 point mass 时 LαLK = −log q(x*) = NLL;LK 是 distillation 软标签下的"正确"NLL 推广。
η 调参 EAGLE-3/MTP 用 η=3 即可;MEDUSA 用 η=10(α 涨得慢需提早切到 TV)。
实现 真 drop-in:0 额外 FLOPs、0 推理改动、和 EAGLE-3/FR-Spec 等已有架构正交叠加。注意 vLLM 需要 patch 让 draft 在 T>0 时做真正 rejection sampling 而不是 greedy。
未解 top-k/top-p 部署、learnable head 聚合、对 K(draft 长度)扫描的曲线;τ 提升到 wall-clock 的换算。