Better & Faster LLMs via Multi-token Prediction
速读卡片 (TL;DR)
一句话:把 LLM 的预训练 loss 从"预测下一个 token"扩展成"用 n 个独立 head 同时预测下 n 个 token,共享一个 transformer trunk",训练时间不变、内存通过精心调度也不增,下游 code 显著变强,且 inference 时几个 head 直接拿来做 self-speculative decoding,3× 免费加速。
立场:这是 MTP 谱系的奠基论文 — 后续 DeepSeek-V3、MiMo-V2 的 MTP head、Medusa 等都借鉴这套"shared trunk + n heads + 仔细调度反传"。生效门槛是scale(7B 以上)且偏向 generative/coding 任务。
1 · 动机:NTP 的内在缺陷与"预知未来"的诱惑
1.1 历史脉络:NTP 是 LLM 的全部,但也是它的天花板
从 GPT-2 到 Llama,LLM 的预训练目标几乎只有一个:next-token prediction (NTP),即最小化 −Σ log P(x_{t+1} | x_{1:t})。这套范式简单、可并行(teacher forcing)、跟下游 autoregressive generation 形式上一致。它给我们带来了 GPT-3、Llama、Gemini。
但论文一句话点穿它的结构性缺陷:
"State-of-the-art next-token predictors call for orders of magnitude more data than human children to arrive at the same level of fluency."
翻译:同样达到流利的语言水平,LLM 要烧掉比人类儿童多几个数量级的语料。这意味着 NTP 这把锤子把 sample efficiency 浪费得很厉害。问题在哪?论文给出三层诊断:
- Teacher forcing 锁死了"局部模式":训练时模型每步都看到 ground-truth 的过去,只需要根据上一个 token 预测下一个;它学不到"我现在选了 A,几个 token 后会 derail 还是收敛"这种长程依赖的因果。
- 所有 token loss 等权:语言里有"句子里的 the"和"代码里的函数名"这种关键 choice point,后者一错满盘皆输,但 cross-entropy 给它和 the 一样的权重。
- 训练-推理分布不匹配 (exposure bias):训练时永远看到 gold history,推理时只能看到自己生成的东西。误差累积没人管。
1.2 让模型"预知未来"会强迫它学到什么
设想一下:把训练目标从"预测 t+1"扩展成"同时预测 t+1, t+2, t+3, t+4",会发生什么?
- 预测 t+2 必须先在 t 处对 t+1 的选择有把握 — 否则 t+2 的分布就被 t+1 的不确定性彻底淹没。这相当于让模型在内部表征里显式地规划 t+1 是什么。
- 从信息论看,共享 trunk 必须把"对未来 n 步都有用"的信息压进 hidden state,而不只是"刚好能挑出 t+1"的局部特征。这反着推动 trunk 学全局结构。
- 对那些 NTP 觉得"无关紧要、随便挑一个就行"的 stylistic 位置,MTP 也不会浪费容量;但对 choice point,因为它的后果会出现在 t+2..t+n 中,这个位置在 loss 里被多算几次(论文 §5.1 算出来是
n(n+1)/2vs 普通点的n)。
所以 MTP 不是简单的"多预测几个未来 token";它是用 auxiliary task 把"长程一致性"从神经网络里榨出来。
1.3 别人的方案为什么不够
| 替代方案 | 动了什么 | 致命缺陷 |
|---|---|---|
| UniLM (Dong 2019) | BERT 风格 mask + 多种 attention | 只对 ~15% token 反传 — 大头数据浪费 |
| UL2 / span corruption (Tay 2022) | encoder-decoder span 填空 | 同上,15–25% token 实际反传;且不是纯 causal,部署不友好 |
| XLNet (Yang 2019) permuted LM | 随机 permute 序列 | 置换太难重建,实际只对 15% 反传 |
| Qi 2020 multi-token | 多 token 预测,但 residual stream 复制 n 份 | 参数 / 计算膨胀 n 倍,无法做 compute-matched 比较 |
| Medusa (Cai 2024) | finetune 阶段加多 token head | 只为 inference 加速服务,不影响 base model 的能力 |
| MTP (本文) | pretrain 时 n 个轻量 head 共享 trunk,所有 token 都反传 | 100% token 参与,compute-matched,base 能力提升 |
2 · 背景速查
2.1 关键术语
| 术语 | 含义 |
|---|---|
| NTP | Next-token prediction,标准 LLM 预训练 loss L₁ = −Σ log P(x_{t+1} | x_{1:t}) |
| MTP | Multi-token prediction,扩展为 L_n = −Σ Σ_{i=1..n} log P(x_{t+i} | x_{1:t}) |
| Shared trunk f_s | 主 transformer 主干,从 context 产出隐藏表示 z_t |
| Output head f_{h_i} | 第 i 个独立预测头(轻量 transformer 层),负责预测 t+i |
| Unembedding f_u | 共享的反嵌入矩阵,把 hidden 投到 vocab logits |
| Self-speculative decoding | 用模型自己额外的 head 当 draft,用主 head 当 verifier 的 spec decoding 变体,无需额外 draft model |
| Compute-matched | 对比时保持总参数量 / 总 FLOPs 一致 — 加 n−1 个 head 就从 trunk 拿掉 n−1 层 |
| Induction head | Olsson 2022 提出的 in-context 复制能力的最小机制 |
| Choice point | 语言生成中"高熵 / 后果重大"的关键位置(如函数名、推理步骤的论断) |
2.2 NTP 的标准回顾
给定 token 序列 x_1, x_2, ..., x_T,模型由 transformer 主干 + 最后一层 unembedding 组成,产出每个位置的 vocab 分布。Loss 是所有位置的 cross-entropy 之和,反传一次,所有 token 都参与。这是 GPT 系训练的全部公式。
3 · 方法:shared trunk + n 独立 heads
架构本身极简:一个 transformer 主干 f_s(占绝大多数参数),n 个并列的小 transformer head f_{h_1}..f_{h_n}(每个 1 层左右),共享 unembedding f_u。
注意三件事:
- n 个 head 是并行的,不是串行。第 i 个 head 不依赖第 i−1 个 head 的输出 — 它直接从 trunk 的 hidden state z_t 出发。
- head 之间不共享权重(否则它们就退化成同一个东西)。但 unembedding 是共享的(否则 vocab 大词表的 V × d 会复制 n 份,内存爆)。
- 推理时默认只用 head 1(next-token head)做 autoregressive,完全跟 NTP 部署兼容;额外的 head 是可选的红利。
3.1 为什么 head 之间不串行 / 不共享?
论文 Appendix B 比较了几种变体(causal head、anti-causal head、replicated unembed)。结论:本文这种"独立并行 head + shared unembed"在质量与内存间取得最佳平衡。
- 串行 head 会让训练时一个 head 失败影响下一个 head(类似 RNN 的 BPTT 问题),且无法做内存调度优化。
- 共享 head 参数就退化成 NTP — 因为 head_i 必须输出和 head_j 不同的分布(预测 t+i ≠ t+j),共享会让两个目标互相打架。
- 独立 head + 共享 unembed:每个 head 只承担"把 trunk hidden 翻译成第 i 步未来分布"这件事,head 自己的参数小,大头还在 trunk 里 — trunk 才是被 n 个 loss 共同推动的对象。这就是 MTP 改善 base model 能力的核心。
4 · 内存调度:把 O(nV) 压回 O(V)
n 个 head 同时算 logits 听起来内存就要爆 — vocab size V (~32k 或 100k+) 远大于 hidden d (~4k),所以 logits 张量 (n, batch, seq, V) 是显存大头。Naive 实现内存峰值 = O(nV + d)。
论文的解法是顺序前向 + 顺序反向 + 在 trunk 累加梯度:
具体数据 (V=32k, d=4096, n=4)
- Naive 实现: 4 × 32k × 4 bytes × batch × seq = 巨大,batch 16 立即 OOM
- 本文调度: 同时只有 1 × 32k(在内存里),batch 可以保持 NTP 水平
- 论文 Table S5: training time 与 NTP 完全一致(顺序 head 没增加 wall-clock,因为 head 本身就轻量,GPU 算力没空跑别的)
5 · 公式拆解 + 信息论解释
5.1 总 loss
注意是对每个位置 t 都算 n 个 head 的 loss,所以训练里反传次数 = T × n,但因为 head 轻量、内存调度好,wall-clock 不变。
5.2 因式分解(为什么不只是"平均 n 个 NTP loss")
论文的关键代数:
这里有一个独立性假设:给定 trunk hidden z_t 后,n 个未来 token 互相条件独立。这是 MTP 跟"先预测 t+1 再用 t+1 条件预测 t+2"(后者就是 NTP 自回归 unrolling)的本质不同。
这个假设强不强?它意味着 trunk 必须把"未来 n 步所有的共有信息"压进 z_t — 因为 head 之间不能再交流。这正是 MTP 给 trunk 的隐式压力,也是为什么 base model 能力会变强。
5.3 信息论解释 (论文 §5.2)
设 X = next token, Y = next-next token,context C 略去:
H(X) + H(Y) = H(X|Y) + 2·I(X; Y) + H(Y|X)
论文的洞察:第二项 H(Y|X) 在下一个位置 t+1 还会出现一次(那时它就是新的 H(X)),所以跨位置看可以"丢掉"。剩下的等式说明:
2-token prediction 把 I(X;Y)(当前 token 与下一 token 的互信息)的权重提高了 2 倍。
物理直觉:I(X;Y) 度量的是"X 这步对 Y 的影响有多大" — 也就是 X 是不是 choice point。MTP 给 choice point 这种位置增加权重,而 stylistic 位置(I 接近 0)的权重不变。这跟 §1.2 直觉一致。
5.4 数值敏感性: choice point 实际拿到多少倍权重?
| n | 普通点权重 | choice point 权重 | 比值 |
|---|---|---|---|
| 1 (NTP) | 1 | 1 | 1× |
| 2 | 2 | 3 | 1.5× |
| 4 | 4 | 10 | 2.5× |
| 8 | 8 | 36 | 4.5× |
来自论文 Appendix L.3:choice point 的隐式权重 = n(n+1)/2,普通点 = n。所以 n=4 大概给关键决策位置 2.5× 的相对放大。
6 · Worked Example: n=4 头一个具体位置
设场景:训练 7B code 模型,vocab 32k,hidden d=4096,n=4。当前训练序列片段:
def factorial(n):
if n == 0:
return
当前位置 t 对应到 token "return "(注意空格)。trunk 跑完前向得到 hidden z_t ∈ ℝ⁴⁰⁹⁶。这个 hidden 同时被 4 个 head 消费:
"return " 位置同时预测下 4 个 token。head 1 (next-token) 把 0.78 概率放到 "1" — 容易,只看 if 分支。head 4 (4 步后) 要预测 " n" — 这要求 trunk 在 z_t 里就编码了"我们正在写 return 1\n return n*factorial(n-1)"这种程序级别的预期。head 4 的 loss 才是真正逼着 trunk 学习长程程序结构的力量。target 概率从 0.78 → 0.62 → 0.22 → 0.27 单调下降(预测越远越难),但即便 0.27 也远高于 1/32000 的随机水平,说明 trunk 学到了真东西。6.1 这个例子里 choice point 的影子
位置 t 是 "return " 后,选哪个 token 是有 if 分支约束的局部决策(选 1 几乎确定)— 不是 choice point。但位置 t+3 才是真正的 choice point:写完第一行 return 后,要不要再写一个 return / if / else 是程序结构的关键岔路。head 4 的 loss 就把这个岔路的判别信号反传到 t 处的 trunk hidden,推动它编码"完整 factorial 函数模板"。
6.2 反向论证:为什么不只用 head 4?
试问:如果只训 head 4 (skip 1/2/3),会怎样?答:你失去了完整的"链条"。模型必须知道"先生成 X 再生成 Y"才能写出连贯输出 — 跳着只学第 4 步的 token 等于让模型只看每隔 4 个 token 的语料,信息量降到 1/4,且没法做 autoregressive inference。所以 1..n 的组合才是 MTP — 它是 NTP 的超集而非替代。
7 · Self-Speculative Decoding: 顺手的 3× 加速
这是论文最"商业气息"的副产品。训完一个 n=4 MTP 模型后,推理时:
- 主 head (head 1) 跑 1 步 → 候选 t+1 token
- head 2/3/4 在同一次 forward 里也输出了 t+2/t+3/t+4 的 logits — 免费的 3 个 draft token
- 把这 4 个 token 喂回 trunk 做一次 forward(并行验证 4 个位置)
- 用 rejection sampling 决定接受到第几个
这就是 Stern 2018 的 blockwise parallel decoding,也是 Medusa 的前身。MTP 的关键优势是:额外 head 是预训练时就训好的,acceptance rate 远高于"finetune 阶段塞 head"的 Medusa。
7.1 跟 Medusa 的对比 — 这是 MTP 最辣的产品角度
| Medusa | MTP self-spec | |
|---|---|---|
| 额外 head 何时训 | 预训练完之后 finetune | 预训练时就一起训 |
| head 学习量 | ~10B token finetune,有限 | 1T token,与 trunk 联训 |
| Acceptance rate | 偏低 (~1.5–2 / 3) | 2.5 / 3 (code) |
| 对 base 能力 | 无影响(冻结) | 提升 (§3 的实验) |
| 实施门槛 | 给已有模型加 | 需从头预训 |
8 · 实验关键结果
8.1 规模决定一切 (Figure 3)
这是论文最核心、最反直觉的结论:MTP 在小模型上甚至略 hurt baseline,只在 6.7B+ 才显著超越。
| 规模 | MBPP pass@1 (n=4 vs n=1) | HumanEval pass@1 (n=4 vs n=1) |
|---|---|---|
| 0.3B | 2 vs 4 | 2 vs 5 |
| 1.3B | ~ 持平 | ~ 持平 |
| 6.7B | +1.7% | +3.9% |
| 13B | +4.5% (24→26) | +12% 相对 (相对收益) |
8.2 最优 n (Table 1)
| n | MBPP@1 | HumanEval@1 | APPS/Intro@1 |
|---|---|---|---|
| 1 (baseline) | 30.0 | 22.8 | 2.8 |
| 2 | 30.3 | 22.2 | 2.1 |
| 4 | 33.8 | 24.0 | 1.6 |
| 6 | 31.9 | 20.6 | 3.5 |
| 8 | 30.7 | 20.0 | 3.5 |
结论:n=4 (token-level) 几乎是普适甜点。byte-level 模型 n=8 才是甜点(byte 信息密度低,需要看更远)。
8.3 多 epoch 训练依然有效 (1T tokens, 4 epoch)
| MBPP@1 | HumanEval@100 | |
|---|---|---|
| n=1, 1T (4 epoch) | 40.7 | 83.0 |
| n=4, 1T (4 epoch) | 43.1 (+2.4) | 86.2 (+3.2) |
说明 MTP 不是"早期训练 trick" — 重复看数据时 MTP 依然继续榨数据的价值。
8.4 Natural language: 取决于评测方式 (Figure 5/6)
- Multiple-choice / NLL benchmark: n=2 持平 baseline,n=4 略 regress(7B 规模)
- Summarization (ROUGE-L): n=2 / n=4 都比 NTP 好
- GSM8K (200B token): n=2 显著超 NTP;但 500B 后反转 — 数据量大时 NTP 追上来
解读:MTP 在生成质量上有优势,在判别 / 似然类任务上不一定有。这跟"MTP 让模型学到长程结构"的理论是一致的 — 长程结构对生成连贯性重要,对单点 likelihood 不重要。
8.5 Algorithmic reasoning (Figure 8)
多项式算术任务(F7[X]/(X^5),1–10 操作数):n=2/n=4 在 in-domain 与 OOD 都显著超 NTP。论文有一句猛话:
"Tripling the model size has a considerably smaller effect than replacing next-token prediction with multi-token prediction loss."
即:把 30M → 100M 模型的提升 不如 NTP → MTP 的提升。这是 MTP 最猛的卖点之一。
9 · 与同类工作对比
| 工作 | 核心做法 | 跟本文的关系 |
|---|---|---|
| NTP (GPT, Llama) | 预测 t+1,单 head | 本文的 baseline,被超越在 code/reasoning |
| Stern 2018 blockwise | finetune 阶段加 linear 多 head | 本文用 transformer head 替换 + pretrain 期联训 |
| Medusa (Cai 2024) | finetune 阶段加 head, tree attention | spec decoding 角度互补;Medusa 不动 base |
| Qi 2020 | 多 token 预测,但 residual 复制 n 份 | 计算膨胀,无法 compute-matched;本文是"compute-matched 版本" |
| UL2 / span corruption | span 填空 | 15–25% token 反传,本文 100% |
| XLNet permuted LM | 排列序列 | 实践中只 15% 反传 |
| DeepSeek-V3 MTP | 串行 head + 给 t+i 看 t+i−1 的 hidden | 是本文的变体,DeepSeek 选择串行换更高质量 |
| MiMo-V2-Flash (见 04) | 生产部署 MTP head 用于 spec decoding | 本文方法的商业落地版;MiMo 直接复用 MTP head 做 inference 加速 |
| TOP (见 10) | 预测未来 token 集合的排序,不再预测精确序列 | 反命题 — TOP 认为 MTP 让模型学"我说 X 之后会说什么"太死板,改成"未来 N 个 token 的相对排名"反而更鲁棒 |
10 · 局限 / 个人 take / 待验证问题
论文承认或暗示的局限
- 规模阈值高:<3B 模型上 MTP 反而 hurt — 学术研究 / 资源有限的小模型不适用
- NLP multiple-choice benchmark 上 n=4 regress — MTP 偏向生成,不偏向 likelihood-as-eval 任务
- n 是 hyperparameter,不同任务最优值不同(code 4 / byte 8 / MATH-200B 是 2 / MATH-500B 是 1)
- 因式分解假设 head 之间条件独立(给定 z_t)— 实际中 t+1 和 t+2 显然相关,这个假设可能让 head 的预测有点"模糊化"
- 没研究 MTP × instruction tuning / RLHF 的 interaction — 现在所有 SOTA 模型都过 RLHF,MTP 在那个阶段是否依然有用?
我的疑问 (offline 阅读后想验证的)
- 13B 是论文 ceiling,但 70B / 405B / MoE 上 MTP 收益是继续放大还是饱和?DeepSeek-V3 的成功说明放大,但论文里只跑到 13B。
- "choice point 拿到 n(n+1)/2 倍权重"很优美,但有没有反例 — 比如 stylistic 但局部熵高(韵脚押韵之类)的位置会不会被错误升权?
- self-spec acceptance 2.5/3 听起来惊人,但和模型 entropy 强相关 — 在 RLHF 后 entropy 显著下降的 model 上,MTP head 的预测分布是否还跟 main head 对齐?(参考 NeMo-RL × EAGLE-3 论文的 verifier-exact 讨论)
- 跟 DeepSeek-V3 MTP(串行)比,本文的"并行 head + 条件独立假设"在大模型上是否真的够好,还是 DeepSeek 的串行才是正确路径?这个 ablation 论文里没有。
- TOP 的批评 — "预测精确未来 token 跟 NTP 同样有 exposure bias 问题" — 是否成立?MTP 里 t+4 的 head 训练时也是 teacher-forced 看 z_t (z_t 来自 ground-truth context),理论上确实没解决这个根本问题,只是把 bias 推远了 4 步。
- n=4 的 4 个 head 总参数 + 1 个共享 unembed,部署时如果只用 head 1 就能扔掉其他 3 个 — 但实际生产中 self-spec 加速这么诱人,大家肯定都留着。这就让"compute-matched"的论断在 inference 端有点取巧:你训练时是 compute-matched,但 inference 时 4 个 head 都在内存里。
记忆点
公式 L_n = −Σ_t Σ_{i=1..n} log P(x_{t+i} | x_{1:t});条件独立假设把 trunk 推向学习长程结构
规模门槛 <3B regress, 6.7B+ 起飞, 13B 给出 +12% HumanEval, +17% MBPP
免费红利 n−1 个 head 直接做 self-spec → 3.0× code, 6.4× byte
最优 n token-level 4, byte-level 8;NLP multiple-choice 上 n=4 反而 regress
谱系 DeepSeek-V3 / MiMo-V2 是工程后继;TOP (10) 是反命题;Medusa 是只做 inference 的弱化版
关键洞察 Choice point 隐式权重 n(n+1)/2 vs 普通点 n — MTP 是自动的 token re-weighting
精读笔记 · 配套 PDF: /data/szhang967/papers/paper-notes/models/MTP_Gloeckle_2404.19737.pdf
关联: 04 · MiMo-V2-Flash (MTP 工程化) · 10 · TOP (反命题)