Better & Faster LLMs via Multi-token Prediction

Meta FAIR · Gloeckle, Idrissi, Rozière, Lopez-Paz, Synnaeve · ICML 2024 · arXiv:2404.19737
关键词: MTP · multi-token prediction · self-speculative decoding · pretraining loss · sample efficiency

速读卡片 (TL;DR)

一句话:把 LLM 的预训练 loss 从"预测下一个 token"扩展成"用 n 个独立 head 同时预测下 n 个 token,共享一个 transformer trunk",训练时间不变、内存通过精心调度也不增,下游 code 显著变强,且 inference 时几个 head 直接拿来做 self-speculative decoding,3× 免费加速。

+12% / +17%
13B HumanEval / MBPP pass@1
3.0× / 6.4×
code (4-token) / byte (8-byte) 推理加速
0 overhead
训练时间 / wall clock

立场:这是 MTP 谱系的奠基论文 — 后续 DeepSeek-V3、MiMo-V2 的 MTP head、Medusa 等都借鉴这套"shared trunk + n heads + 仔细调度反传"。生效门槛是scale(7B 以上)且偏向 generative/coding 任务。


1 · 动机:NTP 的内在缺陷与"预知未来"的诱惑

1.1 历史脉络:NTP 是 LLM 的全部,但也是它的天花板

从 GPT-2 到 Llama,LLM 的预训练目标几乎只有一个:next-token prediction (NTP),即最小化 −Σ log P(x_{t+1} | x_{1:t})。这套范式简单、可并行(teacher forcing)、跟下游 autoregressive generation 形式上一致。它给我们带来了 GPT-3、Llama、Gemini。

但论文一句话点穿它的结构性缺陷:

"State-of-the-art next-token predictors call for orders of magnitude more data than human children to arrive at the same level of fluency."

翻译:同样达到流利的语言水平,LLM 要烧掉比人类儿童多几个数量级的语料。这意味着 NTP 这把锤子把 sample efficiency 浪费得很厉害。问题在哪?论文给出三层诊断:

  1. Teacher forcing 锁死了"局部模式":训练时模型每步都看到 ground-truth 的过去,只需要根据上一个 token 预测下一个;它学不到"我现在选了 A,几个 token 后会 derail 还是收敛"这种长程依赖的因果。
  2. 所有 token loss 等权:语言里有"句子里的 the"和"代码里的函数名"这种关键 choice point,后者一错满盘皆输,但 cross-entropy 给它和 the 一样的权重。
  3. 训练-推理分布不匹配 (exposure bias):训练时永远看到 gold history,推理时只能看到自己生成的东西。误差累积没人管。

1.2 让模型"预知未来"会强迫它学到什么

设想一下:把训练目标从"预测 t+1"扩展成"同时预测 t+1, t+2, t+3, t+4",会发生什么?

所以 MTP 不是简单的"多预测几个未来 token";它是用 auxiliary task 把"长程一致性"从神经网络里榨出来

1.3 别人的方案为什么不够

替代方案动了什么致命缺陷
UniLM (Dong 2019)BERT 风格 mask + 多种 attention只对 ~15% token 反传 — 大头数据浪费
UL2 / span corruption (Tay 2022)encoder-decoder span 填空同上,15–25% token 实际反传;且不是纯 causal,部署不友好
XLNet (Yang 2019) permuted LM随机 permute 序列置换太难重建,实际只对 15% 反传
Qi 2020 multi-token多 token 预测,但 residual stream 复制 n 份参数 / 计算膨胀 n 倍,无法做 compute-matched 比较
Medusa (Cai 2024)finetune 阶段加多 token head只为 inference 加速服务,不影响 base model 的能力
MTP (本文)pretrain 时 n 个轻量 head 共享 trunk,所有 token 都反传100% token 参与,compute-matched,base 能力提升
关键差别: 之前的非 NTP 工作要么降低 token 利用率(BERT/UL2 等只反传一小部分),要么只为 inference 加速(Medusa 是 finetune 阶段塞进去的)。MTP 是第一个把"多 token 同时预测"作为 pretraining 主目标、且保持 compute-matched 与 100% token 反传的方案。

2 · 背景速查

2.1 关键术语

术语含义
NTPNext-token prediction,标准 LLM 预训练 loss L₁ = −Σ log P(x_{t+1} | x_{1:t})
MTPMulti-token prediction,扩展为 L_n = −Σ Σ_{i=1..n} log P(x_{t+i} | x_{1:t})
Shared trunk f_s主 transformer 主干,从 context 产出隐藏表示 z_t
Output head f_{h_i}第 i 个独立预测头(轻量 transformer 层),负责预测 t+i
Unembedding f_u共享的反嵌入矩阵,把 hidden 投到 vocab logits
Self-speculative decoding用模型自己额外的 head 当 draft,用主 head 当 verifier 的 spec decoding 变体,无需额外 draft model
Compute-matched对比时保持总参数量 / 总 FLOPs 一致 — 加 n−1 个 head 就从 trunk 拿掉 n−1 层
Induction headOlsson 2022 提出的 in-context 复制能力的最小机制
Choice point语言生成中"高熵 / 后果重大"的关键位置(如函数名、推理步骤的论断)

2.2 NTP 的标准回顾

给定 token 序列 x_1, x_2, ..., x_T,模型由 transformer 主干 + 最后一层 unembedding 组成,产出每个位置的 vocab 分布。Loss 是所有位置的 cross-entropy 之和,反传一次,所有 token 都参与。这是 GPT 系训练的全部公式。


3 · 方法:shared trunk + n 独立 heads

架构本身极简:一个 transformer 主干 f_s(占绝大多数参数),n 个并列的小 transformer head f_{h_1}..f_{h_n}(每个 1 层左右),共享 unembedding f_u

P_θ(x_{t+i} | x_{1:t}) = softmax(f_u(f_{h_i}(f_s(x_{1:t}))))

注意三件事:

  1. n 个 head 是并行的,不是串行。第 i 个 head 不依赖第 i−1 个 head 的输出 — 它直接从 trunk 的 hidden state z_t 出发。
  2. head 之间不共享权重(否则它们就退化成同一个东西)。但 unembedding 是共享的(否则 vocab 大词表的 V × d 会复制 n 份,内存爆)。
  3. 推理时默认只用 head 1(next-token head)做 autoregressive,完全跟 NTP 部署兼容;额外的 head 是可选的红利
(a) Next-token prediction Transformer trunk L 层 输入 x_{1:t} → hidden z_t f_u (unembed) P(x_{t+1} | ·) 单 head — 只学局部 (b) Multi-token prediction (n=4) Shared trunk f_s L−(n−1) 层 (compute-matched) 输入 x_{1:t} → z_t ∈ ℝ^d f_h1 f_h2 f_h3 f_h4 独立 head (各 1 层 transformer) shared f_u P(x_{t+1}) P(x_{t+2}) P(x_{t+3}) P(x_{t+4})
左:标准 NTP — 一个 head,学到的 hidden 只为下一 token 服务。右:n=4 MTP — 主干 trunk 把 z_t 同时喂给 4 个并行 head。每个 head 是 1 层 transformer + 共享 unembedding。注意 trunk 减少了 n−1=3 层以保持 compute-matched(对比 baseline 时 NTP 多 3 层 trunk,等价 FLOPs)。

3.1 为什么 head 之间不串行 / 不共享?

论文 Appendix B 比较了几种变体(causal head、anti-causal head、replicated unembed)。结论:本文这种"独立并行 head + shared unembed"在质量与内存间取得最佳平衡。


4 · 内存调度:把 O(nV) 压回 O(V)

n 个 head 同时算 logits 听起来内存就要爆 — vocab size V (~32k 或 100k+) 远大于 hidden d (~4k),所以 logits 张量 (n, batch, seq, V) 是显存大头。Naive 实现内存峰值 = O(nV + d)。

论文的解法是顺序前向 + 顺序反向 + 在 trunk 累加梯度:

Step 0: Trunk forward → z_t (持有, ℝ^d) 仅 d 维 hidden 长期占内存 Step 1: head 1 FW → logits (V) BW → ∂L/∂z 累加 Step 2: head 2 FW → logits (V) BW → 累加 ∂L/∂z Step 3: head 3 FW → logits (V) BW → 累加 ∂L/∂z Step 4: head 4 FW → logits (V) BW → 累加 ∂L/∂z 关键: 每个 step 完成后立即 free 掉 logits / activations,只保留 d 维 trunk 梯度累加 Step 5: 用累加好的 ∂L_n/∂z 一次 trunk backward → 完成
Memory-efficient MTP 训练流程。整个过程峰值显存只比 NTP 多一个 head 的临时 logits(同样 O(V)),因为 head 是一个一个跑的,跑完立刻释放。Trunk 那侧只需要一个 d 维的梯度累加器。最后一次 trunk backward 用累加后的 ∂L_n/∂z 完成。Table S5 显示这跟 naive O(nV) 比,显存从爆显卡降到能跟 NTP 同 batch size 训练。

具体数据 (V=32k, d=4096, n=4)

反向论证: 如果不做这个调度,n=4 MTP 实质上不能训 — 这就是为什么之前 Qi 2020 那种"复制 residual stream"的方案没火,它根本无法做 compute/memory-matched 比较。本文的工程优化是它能 scale 到 13B 的前提。

5 · 公式拆解 + 信息论解释

5.1 总 loss

L_n = −Σ_t Σ_{i=1..n} log P_θ(x_{t+i} | x_{1:t})

注意是对每个位置 t 都算 n 个 head 的 loss,所以训练里反传次数 = T × n,但因为 head 轻量、内存调度好,wall-clock 不变。

5.2 因式分解(为什么不只是"平均 n 个 NTP loss")

论文的关键代数:

P_θ(x_{t+n:t+1} | x_{1:t}) = P_θ(z_{1:t} | x_{1:t}) · Π_{i=1..n} P_θ(x_{t+i} | z_{1:t})

这里有一个独立性假设:给定 trunk hidden z_t 后,n 个未来 token 互相条件独立。这是 MTP 跟"先预测 t+1 再用 t+1 条件预测 t+2"(后者就是 NTP 自回归 unrolling)的本质不同

这个假设强不强?它意味着 trunk 必须把"未来 n 步所有的共有信息"压进 z_t — 因为 head 之间不能再交流。这正是 MTP 给 trunk 的隐式压力,也是为什么 base model 能力会变强。

5.3 信息论解释 (论文 §5.2)

设 X = next token, Y = next-next token,context C 略去:

H(X) = H(X|Y) + I(X; Y)
H(X) + H(Y) = H(X|Y) + 2·I(X; Y) + H(Y|X)

论文的洞察:第二项 H(Y|X) 在下一个位置 t+1 还会出现一次(那时它就是新的 H(X)),所以跨位置看可以"丢掉"。剩下的等式说明:

2-token prediction 把 I(X;Y)(当前 token 与下一 token 的互信息)的权重提高了 2 倍。

物理直觉:I(X;Y) 度量的是"X 这步对 Y 的影响有多大" — 也就是 X 是不是 choice point。MTP 给 choice point 这种位置增加权重,而 stylistic 位置(I 接近 0)的权重不变。这跟 §1.2 直觉一致。

5.4 数值敏感性: choice point 实际拿到多少倍权重?

n普通点权重choice point 权重比值
1 (NTP)11
2231.5×
44102.5×
88364.5×

来自论文 Appendix L.3:choice point 的隐式权重 = n(n+1)/2,普通点 = n。所以 n=4 大概给关键决策位置 2.5× 的相对放大。


6 · Worked Example: n=4 头一个具体位置

设场景:训练 7B code 模型,vocab 32k,hidden d=4096,n=4。当前训练序列片段:

def factorial(n):
    if n == 0:
        return 

当前位置 t 对应到 token "return "(注意空格)。trunk 跑完前向得到 hidden z_t ∈ ℝ⁴⁰⁹⁶。这个 hidden 同时被 4 个 head 消费:

context: ... if n == 0:\n return Trunk → z_t ∈ ℝ⁴⁰⁹⁶ head 1 → t+1 "1" : 0.78 ✓ "n" : 0.09 "None" : 0.05 "0" : 0.04 target: "1" head 2 → t+2 "\n" : 0.62 ✓ " " : 0.18 "#" : 0.07 "\nelse": 0.05 target: "\n" head 3 → t+3 " " : 0.41 "return": 0.22 ✓ "def" : 0.09 "#" : 0.05 target: "return" head 4 → t+4 " n" : 0.27 ✓ " 1" : 0.14 " None" : 0.10 " " : 0.08 target: " n"
n=4 MTP 在 "return " 位置同时预测下 4 个 token。head 1 (next-token) 把 0.78 概率放到 "1" — 容易,只看 if 分支。head 4 (4 步后) 要预测 " n" — 这要求 trunk 在 z_t 里就编码了"我们正在写 return 1\n return n*factorial(n-1)"这种程序级别的预期。head 4 的 loss 才是真正逼着 trunk 学习长程程序结构的力量。target 概率从 0.78 → 0.62 → 0.22 → 0.27 单调下降(预测越远越难),但即便 0.27 也远高于 1/32000 的随机水平,说明 trunk 学到了真东西。

6.1 这个例子里 choice point 的影子

位置 t 是 "return " 后,选哪个 token 是有 if 分支约束的局部决策(选 1 几乎确定)— 不是 choice point。但位置 t+3 才是真正的 choice point:写完第一行 return 后,要不要再写一个 return / if / else 是程序结构的关键岔路。head 4 的 loss 就把这个岔路的判别信号反传到 t 处的 trunk hidden,推动它编码"完整 factorial 函数模板"。

6.2 反向论证:为什么不只用 head 4?

试问:如果只训 head 4 (skip 1/2/3),会怎样?答:你失去了完整的"链条"。模型必须知道"先生成 X 再生成 Y"才能写出连贯输出 — 跳着只学第 4 步的 token 等于让模型只看每隔 4 个 token 的语料,信息量降到 1/4,且没法做 autoregressive inference。所以 1..n 的组合才是 MTP — 它是 NTP 的超集而非替代。


7 · Self-Speculative Decoding: 顺手的 3× 加速

这是论文最"商业气息"的副产品。训完一个 n=4 MTP 模型后,推理时:

  1. 主 head (head 1) 跑 1 步 → 候选 t+1 token
  2. head 2/3/4 在同一次 forward 里也输出了 t+2/t+3/t+4 的 logits — 免费的 3 个 draft token
  3. 把这 4 个 token 喂回 trunk 做一次 forward(并行验证 4 个位置)
  4. 用 rejection sampling 决定接受到第几个

这就是 Stern 2018 的 blockwise parallel decoding,也是 Medusa 的前身。MTP 的关键优势是:额外 head 是预训练时就训好的,acceptance rate 远高于"finetune 阶段塞 head"的 Medusa。

推理时的 self-speculative decoding (n=4) AR (NTP): step1 step2 step3 step4 step5 5 forwards → 5 tokens (1 token / forward) MTP self-spec: forward 1 → 提议 t+1..t+4 forward 2 → 并行验证 4 个 2 forwards → 平均 2.5 接受 → ~3.0× (code) code: 3.0× (2.5/3 接受率) · text: 2.7× · 8-byte 模型: 6.4×
Self-spec 流程:一次主 forward 同时给出 4 个候选 → 一次 trunk forward 并行验证。Speedup ≈ avg accepted tokens / 1 forward。关键:这个 speedup 是"训了 MTP 之后免费送的" — 不需要 draft model,不需要额外 finetune。即便你认为 MTP 对 base 能力提升不大,光这个推理 perk 就值得用。

7.1 跟 Medusa 的对比 — 这是 MTP 最辣的产品角度

MedusaMTP self-spec
额外 head 何时训预训练完之后 finetune预训练时就一起训
head 学习量~10B token finetune,有限1T token,与 trunk 联训
Acceptance rate偏低 (~1.5–2 / 3)2.5 / 3 (code)
对 base 能力无影响(冻结)提升 (§3 的实验)
实施门槛给已有模型加需从头预训

8 · 实验关键结果

8.1 规模决定一切 (Figure 3)

这是论文最核心、最反直觉的结论:MTP 在小模型上甚至略 hurt baseline,只在 6.7B+ 才显著超越

规模MBPP pass@1 (n=4 vs n=1)HumanEval pass@1 (n=4 vs n=1)
0.3B2 vs 42 vs 5
1.3B~ 持平~ 持平
6.7B+1.7%+3.9%
13B+4.5% (24→26)+12% 相对 (相对收益)
解释: 小模型自己 trunk 容量都不够支撑 NTP 流畅度,再分一份给"预测未来"就压力过大。一旦 trunk 容量足够(7B+),它有"余裕"被 MTP 推动学习长程结构,从而 base 能力反弹超越 NTP。这是 emergent behavior 的一种。

8.2 最优 n (Table 1)

nMBPP@1HumanEval@1APPS/Intro@1
1 (baseline)30.022.82.8
230.322.22.1
433.824.01.6
631.920.63.5
830.720.03.5

结论:n=4 (token-level) 几乎是普适甜点。byte-level 模型 n=8 才是甜点(byte 信息密度低,需要看更远)。

8.3 多 epoch 训练依然有效 (1T tokens, 4 epoch)

MBPP@1HumanEval@100
n=1, 1T (4 epoch)40.783.0
n=4, 1T (4 epoch)43.1 (+2.4)86.2 (+3.2)

说明 MTP 不是"早期训练 trick" — 重复看数据时 MTP 依然继续榨数据的价值。

8.4 Natural language: 取决于评测方式 (Figure 5/6)

解读:MTP 在生成质量上有优势,在判别 / 似然类任务上不一定有。这跟"MTP 让模型学到长程结构"的理论是一致的 — 长程结构对生成连贯性重要,对单点 likelihood 不重要。

8.5 Algorithmic reasoning (Figure 8)

多项式算术任务(F7[X]/(X^5),1–10 操作数):n=2/n=4 在 in-domain 与 OOD 都显著超 NTP。论文有一句猛话:

"Tripling the model size has a considerably smaller effect than replacing next-token prediction with multi-token prediction loss."

即:把 30M → 100M 模型的提升 不如 NTP → MTP 的提升。这是 MTP 最猛的卖点之一。


9 · 与同类工作对比

工作核心做法跟本文的关系
NTP (GPT, Llama)预测 t+1,单 head本文的 baseline,被超越在 code/reasoning
Stern 2018 blockwisefinetune 阶段加 linear 多 head本文用 transformer head 替换 + pretrain 期联训
Medusa (Cai 2024)finetune 阶段加 head, tree attentionspec decoding 角度互补;Medusa 不动 base
Qi 2020多 token 预测,但 residual 复制 n 份计算膨胀,无法 compute-matched;本文是"compute-matched 版本"
UL2 / span corruptionspan 填空15–25% token 反传,本文 100%
XLNet permuted LM排列序列实践中只 15% 反传
DeepSeek-V3 MTP串行 head + 给 t+i 看 t+i−1 的 hidden是本文的变体,DeepSeek 选择串行换更高质量
MiMo-V2-Flash (见 04)生产部署 MTP head 用于 spec decoding本文方法的商业落地版;MiMo 直接复用 MTP head 做 inference 加速
TOP (见 10)预测未来 token 集合的排序,不再预测精确序列反命题 — TOP 认为 MTP 让模型学"我说 X 之后会说什么"太死板,改成"未来 N 个 token 的相对排名"反而更鲁棒
位置定位: 本文 (Gloeckle 2024) 是 MTP 这个谱系的开山。后续路线可粗分两支:① 沿用且改进结构(DeepSeek-V3 串行版本、MiMo 工程版本); ② 质疑核心假设(TOP 认为预测精确未来 token 是 NTP 的同样错误,应该转向 ranking/set prediction)。读者读完本文后可以读 TOP 的反命题、再读 MiMo-V2 的产品化,形成完整的 MTP 谱系图。

10 · 局限 / 个人 take / 待验证问题

论文承认或暗示的局限

我的疑问 (offline 阅读后想验证的)

  1. 13B 是论文 ceiling,但 70B / 405B / MoE 上 MTP 收益是继续放大还是饱和?DeepSeek-V3 的成功说明放大,但论文里只跑到 13B。
  2. "choice point 拿到 n(n+1)/2 倍权重"很优美,但有没有反例 — 比如 stylistic 但局部熵高(韵脚押韵之类)的位置会不会被错误升权?
  3. self-spec acceptance 2.5/3 听起来惊人,但和模型 entropy 强相关 — 在 RLHF 后 entropy 显著下降的 model 上,MTP head 的预测分布是否还跟 main head 对齐?(参考 NeMo-RL × EAGLE-3 论文的 verifier-exact 讨论)
  4. 跟 DeepSeek-V3 MTP(串行)比,本文的"并行 head + 条件独立假设"在大模型上是否真的够好,还是 DeepSeek 的串行才是正确路径?这个 ablation 论文里没有。
  5. TOP 的批评 — "预测精确未来 token 跟 NTP 同样有 exposure bias 问题" — 是否成立?MTP 里 t+4 的 head 训练时也是 teacher-forced 看 z_t (z_t 来自 ground-truth context),理论上确实没解决这个根本问题,只是把 bias 推远了 4 步。
  6. n=4 的 4 个 head 总参数 + 1 个共享 unembed,部署时如果只用 head 1 就能扔掉其他 3 个 — 但实际生产中 self-spec 加速这么诱人,大家肯定都留着。这就让"compute-matched"的论断在 inference 端有点取巧:你训练时是 compute-matched,但 inference 时 4 个 head 都在内存里。

记忆点

立场 NTP 是 sample-inefficient 的;MTP 用 n 头同预 + 共享 trunk + 仔细内存调度,把"预知未来"作为 pretrain 主目标,no time/memory overhead
公式 L_n = −Σ_t Σ_{i=1..n} log P(x_{t+i} | x_{1:t});条件独立假设把 trunk 推向学习长程结构
规模门槛 <3B regress, 6.7B+ 起飞, 13B 给出 +12% HumanEval, +17% MBPP
免费红利 n−1 个 head 直接做 self-spec → 3.0× code, 6.4× byte
最优 n token-level 4, byte-level 8;NLP multiple-choice 上 n=4 反而 regress
谱系 DeepSeek-V3 / MiMo-V2 是工程后继;TOP (10) 是反命题;Medusa 是只做 inference 的弱化版
关键洞察 Choice point 隐式权重 n(n+1)/2 vs 普通点 n — MTP 是自动的 token re-weighting

精读笔记 · 配套 PDF: /data/szhang967/papers/paper-notes/models/MTP_Gloeckle_2404.19737.pdf
关联: 04 · MiMo-V2-Flash (MTP 工程化) · 10 · TOP (反命题)