L-MTP: Leap Multi-Token Prediction Beyond Adjacent Context

Xiaohao Liu, Xiaobo Xia 等 · NUS / HIT / Tsinghua / CAS / CSU · NeurIPS 2025 · arXiv:2505.17505
关键词: multi-token prediction · leap mechanism · self-speculative decoding · tree attention · long-range dependency

速读卡片 (TL;DR)

一句话:把标准 MTP 的 adjacent 多 token 预测 (t+1, t+2, ..., t+n) 改成leap (t+1, t+3, t+5, ..., t+k(n−1)+1),逼模型跳过短程相关、看更远;配合 "looking backward" 的 self-speculative decoding 把跳过的位置从历史预测里捡回来。

k=2, n=4
默认 leap stride / heads
+22%
在 Medusa 上换 L-MTP decoding 的相对加速
~4×
配合 spec decoding 端到端最高加速

立场:论文卖点是 "broader + faster";真正新意在 训练目标的 leap 化looking backward 解码 两个组件。性能提升弱(常常和 MTP 持平 / 互有胜负),工程上更像"换一个 supervision pattern 顺带加 22% 速度"的小补丁,不是革命。也可参见同系列:Gloeckle 2024 标准 MTP 精读


1 · 动机:adjacent MTP 不够,为什么要 leap

1.1 历史脉络:从 NTP 到 MTP,再到 leap

语言模型主流目标一直是 NTP (next-token prediction):每个 t 位置只学 p(xt+1|x≤t)。简单、一致,但有个老问题:它把所有"未来"压缩到下一个 token 上,模型自然倾向于最容易预测的局部相关(prefix → 紧邻 suffix),长程结构反而被边缘化。

2020 年 ProphetNet (Qi 等) 提出 n-step-ahead prediction;2024 年 Gloeckle (本论文最直接的前作) 把它推到 LLM 预训练:

三种 prediction paradigm 在位置 t 同时预测什么 位置: t t+1 t+2 t+3 t+4 t+5 t+6 t+7 NTP MTP 1 2 3 4 L-MTP k=2 1 skip 2 skip 3 skip 4 1 head 4 heads, adjacent 4 heads, leap
给定 4 个 prediction head, MTP 监督的是 [t+1, t+2, t+3, t+4] —— 全部 adjacent;L-MTP (k=2) 监督的是 [t+1, t+3, t+5, t+7] —— 跳过中间。两者的 head 数和 forward 成本相同,差别只在每个 head 学的是哪个位置

1.2 别的方案为什么不够

"看远一点"在 NLP 里早就被翻来覆去做过。L-MTP 的位置感得放在这张表里看才清楚:

路径怎么"看远"代价 / 局限
Long context (RoPE / ALiBi 扩展)把窗口拉长到 1M token不改变训练目标,模型仍然只学 t+1,远端用得上但不一定预测得准
n-gram / ProphetNet预测下 n 个连续 n-gram仍是 adjacent;短程相关被反复重学
标准 MTP (Gloeckle 2024)n 个 head 各预测 t+1..t+n第 1 head 学到的还是 next-token,长程能力靠"附加监督"渗透到 backbone
DeepSeek-V3 sequential MTPn 个 head 链式预测,后一个 head 看前一个 head 的输出每个 head 仍只跨 1 步;靠 chain depth 累积长程
MuToR (register tokens)插入特殊 register token 让模型规划改 input,不改 supervision pattern
FSP (Future Summary Prediction)预测未来一段的 summary embedding不是 token 级监督,需要额外 summary 模型
L-MTP (本文)n 个 head 各预测 t+1, t+1+k, t+1+2k, ...跳过最强信号(adjacent),强迫每个 head 学跨 stride 的依赖

本质区分:绝大多数加速 / 长程方案都不动 supervision pattern。L-MTP 改的是每个 head 学什么位置这一最基础的设定。

1.3 为什么 "leap" 这事不平凡

"加大 stride 让 head 学远点"听上去 1 行代码就完事,但仔细想三个非平凡处:

  1. Combinatorial choice of leap pattern.固定 stride k? 学习的 stride? 随机 stride? 多 stride 混合?如果 head i 应该看 t+i 还是 t+f(i)? 论文给的答案是固定整数 stride k=2,但这是 design 选择,不是定理 — 后面 §3 会说为什么还是固定的合理。
  2. Training signal sparsity.原来 MTP 在每个 t 位置都监督 [t+1, t+2, t+3, t+4],4 个连续 ground truth;L-MTP 监督 [t+1, t+3, t+5, t+7],中间 t+2, t+4, t+6 没有显式监督。这些位置的 token 在哪儿被学? 答案:它们在 t−1, t−3 等位置上以"head 1, head 2..." 角色被监督过。即跨 t 滑动窗口拼起来才覆盖整个序列,这意味着每个具体 token 的训练信号密度比 adjacent MTP 稀疏了 1/k。
  3. Inference 时的"洞"必须填上.预测出 [t+1, t+3, t+5, t+7] 是好的,但要输出连续序列还得有 t+2, t+4, t+6。论文的解法:这些洞在上一个 decoding step (从 t−1 出发预测 t, t+2, t+4, t+6) 已经被预测出来了 — 所以 looking backward 重用历史预测。这个机制把"稀疏 supervision"和"密集输出"勉强对齐,但需要 tree attention 配合 verification 才能不破坏 lossless 性质。
关键洞察:L-MTP 的 supervision 在位置维度是稀疏的,但在时间维度(滑动 t)是密集的。看似简单的 stride,实际上把时间-位置 2D 监督平面从"对角带状"变成了"平行斜线"。这是论文真正改动的东西。

2 · 背景速查

2.1 关键术语

术语含义
NTPnext-token prediction:p(xt+1|x≤t),单 head
MTPmulti-token prediction:n 个 head,每个预测 xt+i,i ∈ {1,...,n} adjacent
L-MTP本文 leap MTP:每个 head 预测 xt+k(i−1)+1,跨度 k 的 leap
k (leap stride)每两个相邻预测位置之间跳过的 token 数 + 1;k=1 退化为 MTP,k=2 默认
n (heads)总 head 数;最远预测位置为 t + k(n−1) + 1
Self-speculative decoding用模型自己的多 head 当 draft,再用同一模型的 next-token head 验证(无外部 draft)
Tree attention把多个候选 draft path 组织成树,用 mask 让每个位置只 attend 自己祖先,一次前向并行验证多条 path
Looking backwardL-MTP 推理时,把当前 step 缺失的中间位置用上一步已预测过的同位置补上
Acceptance length L每个 speculation step 平均产出几个 token;越大越快
Attenuation论文定义的性质:p(xt+i|x≤t) 随 i 单调递减(预测越远越不确定)

2.2 标准 MTP + self-speculative decoding 复习

给定 backbone θ′ 出 hidden z,n 个 head θi 各自独立投影到 vocab:

p(xt+1...t+n|x≤t) = ∏i=1..n p(xt+i|z≤t; θi) · p(z≤t|x≤t; θ′)

推理流程 (lossless 版):

  1. Prediction:一次 forward 得到 n 个候选 token
  2. Verification:把这 n 个候选当作 prefix 再喂回 backbone,并行得到每个位置的 next-token 分布
  3. Acceptance:rejection sampling,直到第一个被拒绝;接受的 token 写入 KV cache,继续下一轮

这套和 Leviathan 2023 的 spec decoding 等价 — 输出分布严格等同于 vanilla AR。


3 · L-MTP 训练目标:把 supervision 打散到 leap 位置

3.1 目标函数

L-MTP 用两阶段训练:

Stage 1: Head warm-up.backbone + head 1 (=NTP head) 全冻结,只训新增的 head 2..n。监督来自self-distillation:用 base LLM 自己的输出作为 target distribution。

L(1)L-MTP = − ΣT log p(x[t+k(n−1)+1, ..., t+k+1] | z≤t; {θi}i>1)

Stage 2: Full model tuning.用 LoRA (rank 32) 解冻 backbone 一起训:

L(2)L-MTP = − ΣT [ log p(xt+1|x≤t; θ′,θ1) + β · log p(x[t+k(n−1)+1, ..., t+k+1]|x≤t; θ′,{θi}i>1) ]

β 控制 leap heads 的权重。注意 head 1 永远是 NTP — 这一条必须保留,因为推理时它是 verifier。

3.2 一个具体的 Worked Example

设输入序列(token id)是:

x₁  x₂  x₃  x₄  x₅  x₆  x₇  x₈  x₉  x₁₀
 ↑
 t=4 (假设当前位置)

n=4, k=2 时,在位置 t=4,各 head 监督的目标分别是:

Head目标位置 t+k(i−1)+1具体 token跨度 (距 t)
Head 1 (NTP)t+1 = 5x₅+1
Head 2t+k+1 = 7x₇+3
Head 3t+2k+1 = 9x₉+5
Head 4t+3k+1 = 11x₁₁+7

对比标准 MTP (k=1):head 1..4 监督 x₅, x₆, x₇, x₈,跨度 +1..+4。

关键观察:L-MTP 在 t=4 时监督 x₆, x₈, x₁₀。但是在 t=3 时,它会监督 x₄, x₆, x₈, x₁₀ — 所以 x₆ 仍然在某个 t 位置被学到了,只是从更远的距离(+3 而不是 +2)。这就是"稀疏 supervision 在时间上密集化"的具体含义。

3.3 为什么这样设计 (反向思考)

Leap pattern 的设计选择空间 A. 固定 stride (本文) positions = {t+1, t+3, t+5, t+7} + 简单 / 推理时易做 looking backward + tree attention mask 容易构造 B. 学习 stride (per-position) positions = f(context) → 学出来的 set − 推理时 head 不知道学到了哪里 − verification 路径不固定 → 难重用 C. 随机 stride (训练时采样) 每 batch 采一个 k ~ Uniform{1..K} − 推理时仍要选定 k ~ 论文未尝试,但 future work 提到 entropy-aware D. 多 stride 混合 heads 同时学 k=1, k=2, k=3 − head 数翻倍, 训练开销大 ~ 退化情况:k=1 时变 Hydra/Medusa
L-MTP 选 A 而非 B/C/D 的核心原因是推理可逆:固定 k 让每一步缺失的中间位置必然落在固定的"前 k−1 步"里,可以确定性地从历史预测里取回(§4.1)。学习 stride 会破坏这个对齐。

反向论证:如果 L-MTP 不保留 head 1 = NTP,而是把 4 个 head 都做 leap (比如 t+2, t+4, t+6, t+8),会怎样?

所以 head 1 = NTP 是架构上的硬约束。论文称为 "we keep the original NTP head for verification"。

3.4 训练目标可视化:位置 × 时间二维图

"哪个 t 在监督哪个目标位置" 二维图 (n=4) MTP (k=1) — 对角带状 target position → t (input) L-MTP (k=2) — 平行斜线 target position → t (input)
左:MTP 的 (t, target position) 监督在每行上是4 格连续;右:L-MTP 是 4 格隔 1 个。横向看任何一个目标位置都被多个 t 监督到 — 这就是为什么 token 不会"漏学"。绿框是 §3.2 的 worked example (t=4 监督 x₉)。

4 · Looking backward 解码 + tree attention

4.1 推理时如何把"洞"补上

L-MTP 在 t 输出 [t+1, t+3, t+5, t+7]。要得到连续序列还需要 [t+2, t+4, t+6]。直接重新跑 forward 就回到 NTP 速度了。论文的关键技巧是looking backward:

这些位置在 t−1 时已经被 head 们预测过了:从 t−1 出发,heads 监督的是 [t, t+2, t+4, t+6]。所以 [t+2, t+4, t+6] 的候选是上一步的"副产品",直接读 cache 即可。

Looking backward: 把缺失中间位置从历史步取回 位置 t−1 t t+1 t+2 t+3 t+4 t+5 t+6 t+7 step (t−1): a b c d step (t): A B C D 合并: a A b B c C d D 每个 step 输出 4 个 leap token, 与上一 step 的 4 个交错合并 → 长度 7 的连续候选 [t, t+1, ..., t+6] 再经 verifier 验证一次决定接受到哪
step (t−1) 已经预测了偶数步 [t, t+2, t+4, t+6] (蓝);step (t) 预测奇数步 [t+1, t+3, t+5, t+7] (橙)。两组在位置维度无重叠,直接交错就得到 8 个连续位置的候选。注:第一步开始时不存在"上一步",所以 t=1 仍要从空起步。

4.2 Tree attention:让多个候选并行验证

但仅仅交错还不够。L-MTP 让每个 head 的输出可能是 top-k (比如每个 head 取 top 2 个候选 token),这就形成了一棵候选树。论文复用 Medusa / EAGLE 的 tree-attention mask:每个候选 token 只能 attend 到自己 path 上的祖先节点,这样 verifier 一次 forward 同时验证多条 path。

Tree attention 验证 leap 候选 (示意 n=3, top-2 per head) x_t A A' B₁ B₂ B'₁ B'₂ C C C C C C C C head 1 (t+1) head 2 (t+3) head 3 (t+5)
n=3, k=2,每个 head 取 top-2,生成 2³=8 条 path。Tree attention mask 让每个 path 上的 head-2 candidate 只看 head-1 自己 path 的 ancestor (而不是兄弟 path 的 candidate)。Verifier 一次 forward 算所有 path 的 next-token 概率,再做 rejection,选出最长被接受的 path。注意:head 2 监督的是 t+3 而不是 t+2,所以 verifier 验证时也要按 leap 位置查表。

4.3 直接拿 L-MTP decoding 套到现成 Medusa 模型

论文一个有意思的实验:Vicuna 7B/13B + Medusa(本身是 adjacent MTP)的模型,只把解码策略换成 L-MTP 的 looking backward,不重新训练 — 仍然能拿到 22% 相对加速 (1.83× → 2.32× on GSM8K, Vicuna 7B)。

这说明 leap 解码本身就是一个独立的工程优化:即使 head 当初是按 adjacent 训的,把"look backward"的复用机制套上去也能省时间。也就是 paper 的"训练侧"和"推理侧"两个 contribution 在某种程度上可以解耦。


5 · 理论:attenuation × consistency × 加速上界

5.1 两个性质

Definition 1 (Attenuation). 模型预测越远的 token,边际概率越低:p(xt+1|x≤t) > p(xt+2|x≤t) > ... — 因为 context x≤t 距离 xt+i 越远,信息量越少。

Assumption 2 (Consistency). 假设 Ex≤t∼D[p(xt+i|x≤t)] = f(i),即在数据分布上平均后,这个边际概率只依赖于 horizon i,与具体 input 无关。论文假设 f(i) = exp(−γ(i−1)),γ ≥ 0 为衰减系数。

5.2 Acceptance length 期望

E[L]s = Σm=1..ni=1..m f(i)
E[L]l = Σm=1..k(n−1)+1i=1..m f(i + (i−1) mod k)

下标 s 是标准 (sequential) MTP, l 是 leap。差别在于 leap 用了 t−1 步的 hidden state 来"补"中间位置(因此每隔 k 个位置 horizon 重置 → mod k)。

5.3 Theorem 3 — Less attenuation, more speed-up

存在常数 C > 0,当 γn² ≤ C (即 γ = O(1/n²)) 时,渐进有 E[L]l > E[L]s。也就是 当模型 attenuation 不太严重,leap 总比 adjacent 期望长度更大。

5.4 数值直觉

γnE[L]_s (MTP)E[L]_l (L-MTP, k=2)L-MTP 优势
0.014~3.94~6.85远胜
0.054~3.50~5.50明显
0.104~2.85~3.40略优
0.304~1.85~1.55劣 (γ 太大)

物理直觉:γ 大 = 模型对未来很不确定 → 即使 leap 把 horizon 延长,接受概率也撑不住;γ 小 = 模型本身就有强长程能力 → leap 把"反正能预测"的远端位置直接拿来用,白赚。

反向警示:这条定理也意味着一个反直觉的事实 — backbone 越强,L-MTP 相对 MTP 的优势越大;backbone 越弱,L-MTP 反而吃亏。这与表 1 中 Gemma-12B (强) 上 L-MTP 大幅胜过 MTP、Qwen2.5-7B (中等) 上互有胜负的实验现象基本吻合。

6 · 实验:性能 / 加速 / acceptance / 数据规模

6.1 性能 (Table 1, 选载)

ModelStrategyGSM8KHumanEval+IFEvalAvg
Llama3.2-3BNTP3.7117.6820.7425.52
MTP3.8718.2918.5925.46
L-MTP5.9120.7320.3826.68
Qwen2.5-7BNTP52.9969.5143.4163.86
MTP52.6269.5141.4963.34
L-MTP56.0371.9544.1264.16
Gemma3-12BNTP13.4256.1029.3846.87
MTP5.6154.2730.4645.17
L-MTP26.3855.4933.0949.58

解读:
L-MTP 在数学任务 (GSM8K) 上提升最显著:Llama3.2 3.87→5.91, Gemma3-12B 5.61→26.38(后者 4.7×)。这与"长程依赖" 直觉相符 — 数学需要规划。
L-MTP 经常优于 NTP,但 MTP 不一定。论文坦白 NTP 自己有时劣于 base(因为是用通用 instruction 数据微调,不是预训练),L-MTP 能"捞"回来 — 但 base 仍是天花板。
③ Avg 提升幅度大多 < 2 pt,在噪声范围内。性能不是论文最强的卖点,加速才是。

6.2 加速 (Figure 6 + Table 2)

self-speculative decoding 配合 L-MTP looking backward,在 GSM8K / MBPP / IFEval 上对 NTP 的加速比一般在 2–4×,且常常比 MTP 同设置高。

ModelStrategyGSM8KMBPP
Vicuna 7BMTP (Medusa)1.83×1.97×
L-MTP (套用 leap decoding)2.32× (+27%)2.01×
Vicuna 13BMTP (Medusa)2.24×1.98×
L-MTP (套用)2.43× (+8%)2.02×

关键观察:解码侧 L-MTP 单独提供加速,而且不需要重新训练。"22% relative boosting" 这个论文摘要数字就是 GSM8K Vicuna 7B 的相对增长。

6.3 Per-position accuracy (RQ3, Figure 7)

论文测每个 head 的预测准确率。结果:

6.4 Myopic generation (Figure 8)

论文挑了一个有点反 NTP 教科书的发现:scale 越大,head 2..n 的 prediction accuracy 越低。即"模型越大,越只关心紧邻"。论文管这叫 myopia of NTP pre-training

这是支持 leap 训练的重要证据 — adjacent MTP 不能纠正,而 L-MTP 强迫 head 看更远,应当能缓解。但论文没直接验证 "L-MTP 训出来的模型 myopia 减少了多少" 这个最该测的问题(只在 leap 设定下测到了 acc 模式与理论一致)。

6.5 Data scale (Figure 9)


7 · 与同类工作对比

工作核心机制与 L-MTP 差别
Gloeckle 2024 (标准 MTP) [精读]n 个 head 各预测 t+1..t+n adjacentL-MTP 改 supervision pattern,head 数 / 架构相同
DeepSeek-V3 sequential MTPn 个 head 链式;head i+1 看 head i 输出L-MTP heads 并行不串行;但每 head 跨距更大
Medusa (Cai 2024)n 个 FFN head + tree attention 验证仍 adjacent;L-MTP 把它的解码改成 leap 加 +22%
EAGLE (Li 2024)复用 hidden + 自回归 draft headExternal-style,与 L-MTP 正交可结合
Hydra (Ankner 2024)sequentially-dependent draft headshead 之间有依赖,改善 acceptance;L-MTP 只改位置
MuToR (register tokens)插入 [REG] token 让模型规划动 input,L-MTP 动 supervision
FSP (Future Summary Prediction)预测一段未来的 summary embedding非 token 级目标,不易做 spec decoding
ProphetNet (Qi 2020)n-stream decoder 预测 future n-gram重在 seq2seq 预训练,无 spec decoding 视角
定位:L-MTP 是对 Gloeckle MTP 的最小可行改动 — 不动 backbone, 不动 head 架构, 只改"head 2..n 学的目标位置 + 推理时的 token 拼接方式"。所以它和 EAGLE / Medusa / Hydra 这些 "改 head 架构以提高 acceptance" 的工作几乎完全正交,可以叠加。

8 · 局限 / 个人 take / 待验证问题

论文承认 / 暴露的局限

我的个人 take

  1. 真正的贡献是解码侧,不是训练侧。Vicuna-Medusa 直接套 leap decoding +22% 那个数字含金量最高 — 它说明 leap 的"复用上一步预测"思路是个独立有效的工程优化。
  2. 训练侧的故事不太自洽。论文用 "long-range dependency" 做卖点,但每个具体 head 学的还是固定 horizon 的 adjacent 关系(只是远了点),并没有真正让模型"跳过中间步推理"。要做到后者得是 GSM8K 风格的 step-skipping,而不是 token-level 跳。
  3. k=2 + n=4 这个组合让最远 horizon 是 t+7,本质上和 "n=8 adjacent MTP" 在覆盖范围上类似,但是训练成本相同 (n=4 个 head)。这是个实在的 trade-off:同样 head 数下覆盖更广,代价是局部密度降低。

想验证的问题

  1. 如果换成 k=2 但 n=8 (覆盖到 t+15) vs k=4 n=4 (覆盖到 t+13),哪种更优?论文没扫 k 和 n 的联合空间。
  2. Theorem 3 假设 f(i) = exp(−γ(i−1))。真实 f 不一定指数衰减(Figure 7 看起来更像 sigmoid),这对定理结论有多大影响?
  3. L-MTP 训完的模型,即使用 NTP 解码 (单 head),性能和 NTP 训的差距如何? — 即"leap 监督"是否真的让 backbone 学到更长程结构,还是只是 head 2..n 的事。
  4. 把 L-MTP 的 looking backward 套到 EAGLE-3 (而不是 Medusa) 上,加速能进一步加吗?EAGLE 本身 acceptance 已经很高,leap 的相对收益可能挤压。
  5. RL post-training 场景下:论文 future work 提到结合 RL,但没有数据。leap supervision 对 reasoning trace 的连贯性 (尤其 chain-of-thought 中间步) 是否会有破坏作用?
  6. 最远 head 的 entropy 是否变高?若变高,sampling 阶段是否需要不同温度策略?

记忆点

立场 把 MTP 的 head 监督位置从 [t+1..t+n] 改成 [t+1, t+1+k, t+1+2k, ...],其余基本不动
公式 监督位置 = t + k(i−1) + 1, i ∈ [n];默认 k=2, n=4 → 最远 t+7
推理 Looking backward 把当前 step 缺失的中间位置从上一步预测里取回 → 拼成连续候选 → tree attention 验证
理论 当 γ = O(1/n²) (低 attenuation) 时 L-MTP 期望 acceptance 长度 > MTP
数字 Medusa 上换 leap decoding +22% (1.83→2.32×); 训练版 GSM8K Gemma-12B 5.61→26.38
陷阱 性能提升大多在噪声内;k 固定为 2;所有实验都是 LoRA 微调,不是预训练

精读笔记 v1 · 配套论文 PDF 在 /data/szhang967/papers/paper-notes/models/LMTP_2505.17505.pdf
See also: 14_MTP_Gloeckle_2404.19737.html (foundational MTP)