LK Losses: Direct Acceptance Rate Optimization for Speculative Decoding
速读卡片 (TL;DR)
一句话:SD draft 训练用了多年 KL/CE,其实是个 proxy 目标——只在 q=p 全局最优时才等价于 acceptance rate;在容量受限的 draft 上,KL 最小≠α 最大。LK losses 直接以 α 为目标,通过 TV(p,q) 与 −log α 两条路径绕过 TV 梯度消失这道坎,作为已有 SD 训练框架的 drop-in replacement,在 8B–685B 六个 target × 四种 draft 架构上一律提升 τ,平均 acceptance length 涨幅达 8–10%。
立场:这是把 SD draft 训练的"目标函数"层面拨正——比 MARS 那种基于 acceptance threshold 的方法更"根级",因为它直接换 loss、不需要新增推理逻辑。最大启发:KL 在小容量下不是"无害的好代理",它在权衡哪里牺牲这件事上和 α 不一致。
1 · 动机:proxy 之坑
1.1 历史脉络:为什么大家一直用 KL/CE
从 Leviathan 2023 把 SD 落地开始,draft model 的训练几乎一律是一个 KD(knowledge distillation)流程:把 target 当 teacher,用 KL(p∥q) 或等价的 CE 让 draft 学习 target 的输出分布。MEDUSA(2024)、EAGLE 系列(2024–2025)、DeepSeek-V3 MTP(2024)、FR-Spec(2025)——架构在变,但 loss 的骨架没变。EAGLE 在此基础上加了 hidden-state 的 MSE regression,MEDUSA 加了 LM 目标,但都是额外补丁,核心仍然是 KL。
这种共识基于一个事实:KL 和 acceptance rate α 共享同一个全局最优点。如果 q 能完美匹配 p,KL=0、α=1。所以"训 KL ≈ 训 acceptance"在理论上是没毛病的——前提是你能到达那个全局最优。
但 draft model 通常只有 target 1–5% 的参数(EAGLE-3 一层 transformer 对 70 层的 Qwen3-235B),它注定到不了 q=p。问题就是:在你必然停留的这个 suboptimal 点附近,KL 最小化和 α 最大化是两件不同的事。这就是这篇 paper 的核心攻击面。
1.2 别的方案为什么不够
| 方案 | 核心思路 | 遗留问题 |
|---|---|---|
| KL(p∥q) — mode covering | 避免 q 在 p 有概率的地方崩到 0 | 把概率铺得太散,小容量下 α 反而低(toy 例子里 α=50.2%) |
| Reverse KL(q∥p) — mode seeking | q 必须把质量集中在自己 commit 的 mode 上 | 容易 collapse 到单峰,multi-mode target 上同样亏 α(50.8%) |
| MSE on hidden(EAGLE) | 追 hidden state 让 inference/training 一致 | 仍是间接,且 hidden 距离 ≠ 输出分布距离 |
| EAGLE-3 multi-layer KD | 融合 target 多层 hidden 提高表达力 | 架构层面的努力,loss 仍然是 KL |
| MARS 类阈值方法 | 在 inference 时丢掉低 acceptance 的 head/位置 | 事后挽救,不是从训练目标上改 |
| DistillSpec(尝试 TV) | 意识到 TV 才直接对 α | 仅在已预训练的 LM 上 fine-tune;从 random init 训不动 — 见 §3 |
注意最后一行:DistillSpec 早就指出 TV 是对的目标。但他们的实验是从一个已经会说话的 LM 上微调(初始 q 已经接近 p),所以 TV 的梯度问题被掩盖。在从头训 EAGLE/MEDUSA 这种小 head 的真实场景下,纯 TV 训练会立刻撞墙——这是 LK 这篇要解决的工程关键。
1.3 为什么这事不平凡
看上去很直白:既然 α=1−TV(p,q),那就直接最小化 TV 不行吗?或者更直接,把 α 写出来当 loss?三个障碍让这件事过去三年没人做成:
- min(p,q) 不光滑。α=Σ min(pi,qi) 在 pi=qi 处不可微,对应 TV 中的 sign 跳变。直接梯度只携带方向(过预测 vs 欠预测),不携带幅度——欠预测一点点和欠预测很多得到一样大小的 sign 信号。
- TV 梯度幅度消失。论文 §A.5 给出在"q 几乎均匀(随机初始化) + p 集中在 k 个 token"的 regime 下:‖∇zTV‖ = O(√k / V)。对 V=128k 的现代 vocabulary,这个梯度比 KL 的 O(1/√k) 小好几个数量级。换句话说,从随机初始化开始,TV 推不动 logits。
- −log α 看似优雅,但 α=Σ min(p,q) 不是 softmax 的 likelihood。它没有 NLL 那种"目标 token 一个 slot"的好分解(虽然在 p 是 point mass 时退化为 NLL,见 §B)。它的梯度推不出标准 cross-entropy 那种 q−p 形式;LK 推到最后是 (1/α)·∇TV,等于"TV 加自适应 1/α 放大"——这层自适应是关键,把 TV 的梯度消失问题在 α→0(随机初始化)时自动恢复 O(1/√k)。
所以这篇 paper 的"非平凡"含金量,不在 idea(idea 是 obvious 的:就直接优化 α 嘛),而在把 TV 梯度的病理诊断清楚 + 给出两条都能绕过的路径(混合调度 / 自适应放大)。这是工程师视角下的"算法障碍解决"。
2 · 背景速查
| 符号 / 术语 | 含义 |
|---|---|
| p, q | target / draft 在某 context 下的 next-token 分布 |
| β(x) = min(1, p(x)/q(x)) | 给定 q 抽出 x,target 接受它的概率(rejection sampling) |
| α = Σ min(pi, qi) = 1 − TV(p,q) | per-position acceptance rate |
| τ = K · (#accepted/#drafted) + 1 | 每轮 SD 期望生成 token 数 (含 bonus token);primary metric |
| K | draft 序列长度(EAGLE-3/MTP 用 7,MEDUSA/MLP 用 6) |
| zq | draft logits;q = softmax(zq) |
| LλLK | hybrid loss = λ·KL + (1−λ)·TV,λ 由当前 α 自适应 |
| LαLK | likelihood loss = −log α |
| η | λ 调度器的衰减强度,λ = exp(−η·sg[α]) |
| γ = 0.8 | 多 head loss 的 exponential decay,前面的 head 权重大 |
3 · 梯度解剖:KL vs TV(为什么直接做 TV 撞墙)
论文最 load-bearing 的不是 idea 而是这一节的诊断,值得逐字啃。
3.1 KL 的梯度:漂亮但代理
这是个被反复用过的好朋友:每个 logit 收到的力 = "我现在的概率"减"目标概率"。方向、大小都自洽。在随机初始化 + p 集中在 k token 的 regime 下,‖q−p‖ = O(1/√k) — 假设 k=20,梯度幅度 ~0.22,稳稳能动。
3.2 TV 的梯度:方向对了但小到看不见
问题三连:
- 方向对、幅度无关:si 只是 ±1。qi 比 pi 低 0.001 还是低 0.5,得到一样的 sign。
- 有 q 的尺度因子:整体被 ⊙q 调制。随机初始化时 qi ≈ 1/V,V=128k → qi ≈ 8e-6。‖∇TV‖ = O(√k/V) ≈ √20/128000 ≈ 3.5e-5。比 KL 小约 7000 倍。
- landscape 不光滑:在 qi=pi 的流形上 sign 跳变。
3.3 关键数值表(论文 Table 3 的复述)
| Loss | Gradient on S (the k support tokens) | Gradient off S (the V−k 噪声 token) |
|---|---|---|
| KL | −1/k | +1/V |
| TV | −1/V | ≈ 0 |
| LαLK = −log α | −1/k | +1/V |
读法:LαLK 在主梯度成分上和 KL 同量级(支持 token 上 −1/k),但方向和 TV 一致(对 α 直接求导)。这就是它在工程上"既能从 random init 起步,又把 α 当真目标"的来源。
4 · LλLK:KL+TV 自适应混合(trust region 的味道)
4.1 公式
sg[·] 是 stop-gradient,防止 λ 自己被反向传播改写。α 用当前 batch / sequence / position 聚合后的真实 acceptance rate 估计。
4.2 直觉:课程学习而非线性混合
关键点不是"加权平均",而是自适应调度:
- 训练初期 α≈0:λ = exp(0) = 1 → 等价于 KL,享受光滑、有幅度的梯度。
- 训练后期 α 接近高水位(η=3 时,α=1 → λ=exp(−3)≈0.05):TV 接管,直接打磨 α。
这和 trust region(TRPO)的精神一致:KL 当 soft constraint 把 q 拉进 p 附近,然后在 trust region 内对真目标 TV 做下降。论文里直接给了这个对偶解读:
4.3 Worked example: 一个 token 走一遍
设词表 V=1024(简化),target p 集中在 token x=42,p42=0.7,p11=0.2,p99=0.1,其余为 0。draft 早期 q 接近均匀,qi≈1/1024≈9.8e-4,设当前 batch 平均 α ≈ 0.05(等价于很差的对齐)。
- 当下 λ = exp(−3 × 0.05) = exp(−0.15) ≈ 0.86 → KL 占大头(86% KL + 14% TV)。
- KL 在 z42 的梯度 ≈ q42 − p42 ≈ −0.699:把 logit 42 往上推。
- TV 在 z42:s42=sign(q−p)=−1, Eq[s] ≈ +1−2k/V ≈ 0.994, q42·(s−Eq[s])/2 ≈ 9.8e-4 × (−1.994)/2 ≈ −9.8e-4。比 KL 小 700×。
- 合成:λ·∇KL + (1−λ)·∇TV ≈ 0.86×(−0.699) + 0.14×(−9.8e-4) ≈ −0.601。KL 项主导,稳稳推。
训了几千步后 α=0.6,q42=0.55,p42=0.7:
- λ = exp(−1.8) ≈ 0.165 → TV 接管(83.5%)。
- KL 在 z42:0.55−0.7=−0.15。
- TV 在 z42:q42 大了,梯度幅度也跟上来了,这时 sign-only 信号反而是优点 — 不会被 outlier 大误差带跑。
5 · LαLK:−log α 一条龙
5.1 直觉
把 α 写成 marginal probability of acceptance:α = Σ q(x)·β(x)。这是"draft 抽到 x 且被接受"的边缘概率。最大化 α 的最大似然写法就是最小化 −log α:
比 hybrid 更简洁:不需要混合权重、不需要 schedule。但能不能从随机初始化训得动?这就要看它的梯度。
5.2 关键关系(§A.4)
这是整篇 paper 最漂亮的一行。意思是:−log α 的梯度 = TV 梯度 × 1/α。1/α 是自适应 boost:
- α 小(随机初始化、训练早期):1/α 大,把 TV 那个消失的梯度强行放大回 O(1/√k) — 量级追上 KL。
- α 大(训练后期):1/α 接近 1,梯度恢复成纯 TV 的尺度,正好这时 q 已经接近 p,TV 本身的梯度也不再病态。
5.3 Worked example:同 CE 不同 α
这是一个最能击穿"KL 训练 = α 训练"幻觉的对比。设 V=4,target p = (0.50, 0.30, 0.15, 0.05)。
| 方案 | q | CE = −Σ p log q | α = Σ min(p,q) |
|---|---|---|---|
| qA(KL 偏好) | (0.40, 0.30, 0.20, 0.10) | ≈ 1.343 | 0.40+0.30+0.15+0.05 = 0.90 |
| qB(TV 偏好) | (0.55, 0.30, 0.10, 0.05) | ≈ 1.355 | 0.50+0.30+0.10+0.05 = 0.95 |
| qC(spread) | (0.35, 0.30, 0.20, 0.15) | ≈ 1.347 | 0.35+0.30+0.15+0.05 = 0.85 |
读这个表:qA 的 CE 最低(KL 也最低,因 KL=CE−H(p)),但α 反而比 qB 低 0.05。qB 牺牲了一点 CE,换来 acceptance rate 提升。qC 最铺得均匀,CE 居中但 α 最差。这就是"KL 最小 ≠ α 最大"的具体数值证据。LK 的两个 loss 都会在 qA 和 qB 之间偏向后者。
6 · Vocabulary truncation:KL 多了一层 proxy,LK 没这个问题
EAGLE-3 + FR-Spec 配置下,draft 的 LM head 只输出一个截断词表(比如 32k,target 词表 128k),其它 token 的 qi=0。这对 KL 是致命的:pi>0 而 qi=0 → KL=∞。常规做法是把 target 也截:p̃ = softmax(m⊙zp),把截断外的 logit 设 −∞,再去做 KL(p̃∥q)。
这就是论文一句犀利的吐槽:"makes KL a proxy of a proxy"(代理的代理)。一层是 KL≠α,另一层是 p̃≠p。
LK 怎么处理?从 α=Σmin(pi,qi) 看,被截掉的 token 上 qi=0,贡献 min(pi, 0)=0,自然被忽略 — 不需要修改 p。LαLK 和 hybrid 的 TV 部分都直接对原始 p 做优化,不引入额外近似。这是 LK 落地 EAGLE-3+FR-Spec 时的一个免费红利。
7 · 公式串与数值直觉
7.1 推导链条(§A 复述)
- softmax Jacobian:∂qi/∂zq,j = qi(δij − qj)。
- KL:∂KL/∂zq,j = −Σi (pi/qi)·qi(δij − qj) = −pj + qj·Σpi = qj − pj。
- TV:链式 sign 经过 softmax Jacobian 得到 ½qj(sj − Eq[s])。
- −log α:∂(−log α)/∂z = (1/α)·(−∂α/∂z) = (1/α)·∂TV/∂z(因 α=1−TV)。
7.2 在 p 是 point mass 的退化情况
若 p(x*)=1,其它为 0,则 α = Σ min(pi, qi) = min(1, q(x*)) = q(x*)(因 q 是概率)。所以:
这是个不错的 sanity check:当 target 退化为硬标签,LK 等于普通的 cross-entropy 训练。它是 NLL 在软标签(distillation)情况下的"正确" generalization——比直接做 soft-label CE 更直接对 α。
7.3 数值敏感性表(η 选择)
| α 当前值 | η=1, λ | η=3, λ | η=10, λ |
|---|---|---|---|
| 0.0 | 1.00 | 1.00 | 1.00 |
| 0.2 | 0.82 | 0.55 | 0.135 |
| 0.5 | 0.61 | 0.22 | 0.0067 |
| 0.8 | 0.45 | 0.091 | 3.4e-4 |
读法:η 越大,KL 越早退场。EAGLE-3 / MTP 这种α 涨得快的架构默认 η=3 即可;MEDUSA 因为头之间独立、α 涨得慢,论文用 η=10 强行推快 TV 接管。这是个有点反常识但可解释的调参逻辑——你想让 TV 在 α 还没涨太高的时候就接管,因为再不接管它就不打磨了。
8 · 实验关键结果
8.1 跨架构(LLaMA-3.1-8B,Table 1 节选)
| Draft | Loss | MT-bench τ | HumanEval τ | GSM8K τ |
|---|---|---|---|---|
| EAGLE-3 (T=1) | KL | 3.39 | 4.31 | 3.88 |
| TV (纯) | 2.67 | 3.25 | 3.12 | |
| LλLK η=3 | 3.48 | 4.52 | 4.02 | |
| MEDUSA (T=1) | KL | 1.72 | 2.02 | 1.81 |
| LλLK η=10 | 1.85 | 2.22 | 1.92 | |
| MLP (T=1) | KL | 2.13 | 2.16 | 2.16 |
| LλLK η=3 | 2.19 | 2.62 | 2.18 |
关键观察:
- 纯 TV 全面输给 KL —— 印证 §3 的 TV 梯度消失分析,这不是抽象数学,是真训练曲线。
- 容量越小,LK 增益越大:MEDUSA + 7.8%, MLP + 8.3%, EAGLE-3 + 3.8% — 这与"limited capacity → proxy gap 越严重"的理论一致。
- 固定 λ=0.5 几乎没提升 — 自适应 schedule 是非平凡的。
8.2 跨 target scale(Table 2)
| Target | Mean τ KL (T=1) | Mean τ LλLK (T=1) | Δ% |
|---|---|---|---|
| LLaMA-3.1-8B | 3.86 | 4.01 | +3.9 |
| LLaMA-3.3-70B | 4.50 | 4.66 | +3.5 |
| GPT-OSS-20B | 3.17 | 3.29 | +3.8 |
| GPT-OSS-120B | 2.46 | 2.65 | +7.7 |
| Qwen3-235B | 3.77 | 4.08 | +8.2 |
| DeepSeek-V3 685B (MTP) | 4.43 | 4.68 | +5.6 |
大模型(尤其 MoE)+小 dense draft 的容量差越悬殊,LK 的相对提升越大 — Qwen3-235B 的 +8.2% 是论文最亮的数字。DeepSeek-V3 是 fine-tune MTP(不是从头训),也提升 5.6%,印证 LK "在 q,p 距离大时都能赢" 的普适性。
8.3 一些容易被略过的细节
- vLLM 默认实现是greedy 抽 draft token(即使 T>0),这违反了 SD 的理论假设。作者打了 patch 实现真正的 rejection sampling(基于 vllm-project/vllm#20459)。这点足够独立成 paper—— evaluation 设置下 T=1 的 acceptance 数字,在 unpatched vLLM 上是被低估的(见 §D)。
- Loss 跨 head 用 γ=0.8 exponential decay,前面的 head 权重大。理由:earlier head 的 acceptance rate 决定整段是否能继续,影响 τ 的杠杆远大于后头的。
- MTP 微调里 hybrid loss 自带"为后头 head 补 KL"的副作用:后头 α 低 → λ 大 → KL 主导 → 拉回到 trust region;前头 α 已高 → λ 小 → TV 精修。同一条公式同时处理两类 head。
9 · 与同类工作对比
| Loss / 方法 | 对 α 的关系 | 从随机初始化能训吗? | truncated vocab 兼容? | 额外开销 |
|---|---|---|---|---|
| CE / KL(p∥q) | proxy(同 global opt) | 能 | 需把 p 也截 | 0 |
| Reverse KL(q∥p) | proxy,mode seeking | 能 | 兼容 | 0 |
| MSE on hidden(EAGLE) | 很间接 | 能 | 看实现 | + regression head |
| EAGLE-3 multi-layer KD | 架构层面,loss 仍是 KL | 能 | 同 KL | + feature fusion |
| DistillSpec(reverse KL/TV) | TV 直接对 α | 需 pretrained LM 起步 | 取决于 divergence | 0 |
| MARS(threshold) | inference-time filter | — | — | + inference 逻辑 |
| LαLK | 直接(−log α) | 能(1/α 自适应放大) | 天然兼容 | 0 |
| LλLK | 直接(后期) | 能(KL 起步) | 天然兼容(TV 部分) | 0 |
核心差异化:
- vs CE/KL distillation:同样易实现、同样无开销,但目标对齐。这是真正的 drop-in replacement。
- vs MSE on hidden / EAGLE-3 multi-layer KD:这些是架构 / 特征层面的努力,正交于 LK。论文显示的实验是直接套到 EAGLE-3 上,所以这两条是叠加的关系而非互斥。
- vs DistillSpec:DistillSpec 第一个意识到 TV 才是对的目标,但他们没解决 from-scratch 训练的梯度消失问题。LK 的两条公式是工程层面的"修复方案"。
- vs MARS-类阈值方法:那些是事后挽救(reject 掉低 α head),LK 是事前(从训练目标就对齐)。两者甚至可以叠加。
10 · 局限 / 个人 take / 待验证
- "τ" 不等于 wall-clock speedup。论文只报 τ,没给端到端 latency 数字。τ 涨 8% 在 LM head 不主导的场景下大致 = 8% throughput,但在 FR-Spec 截断词表的场景下要看 head 计算占比。
- η 的调参依然是 manual(EAGLE-3 用 3,MEDUSA 用 10)。这暗示 LK 的"自适应"并不全自动,而是把"调温度"的负担从 schedule 函数式上转移到了 η 这个常数。Future work 里作者也提到 learnable per-head loss aggregation 会更好。
- α 是 batch 聚合估计,带方差。早期 batch 小时 sg[α] 噪声会让 λ 抖动,论文没讨论 batch size 的稳健性。
- top-k / top-p 部署没纳入。生产里目标分布常带 top-p 截断,这会让 α 的实际分布跟训练时定义有 gap。这一条作者在 future work 里直接承认。
- K=6/7 之外的延展性。Figure 1 显示 K 越大 LK 优势越明显,但论文没系统扫描 K∈{2,3,…,15};EAGLE-3 论文里 K 上去后 τ 边际下降,LK 是否真改这条曲线值得验证。
- p 的来源:训练用 target sample 自己生成的 660k Infinity-Instruct,这是 on-policy。如果换成静态语料(如 ShareGPT),LK vs KL 的差距会怎么变?未做对照。
需要验证的几个问题:
- LαLK 在更大 V(比如 200k+)和更深 draft(2-3 层 transformer)上,1/α 放大是否足够?会不会触发数值不稳定(α 接近 0 时 1/α 爆炸)?
- 把 LK 嫁接到 RL post-training 里的 draft model(参考 ReSpec / NeMo-RL),acceptance length 抖动场景下 sg[α] 估计的方差会不会让 λ 跑飞?
- 对 reverse KL(mode seeking)直接也做一次 hybrid:λ·rKL + (1−λ)·TV 是否优于 λ·KL + (1−λ)·TV?论文没尝试。
- p 是 point mass 时退化为 NLL,中间区间(p 是 top-p=0.9 截断的)是否退化为某个加权 NLL?是否解释了 LK 在 GSM8K(高 confidence target)上提升相对小?
- Loss landscape 在 p≈q 附近 KL 与 TV 的 Hessian 性质对比 — 是否能解释为什么 hybrid 略胜 likelihood?