Beyond Multi-Token Prediction: Pretraining LLMs with Future Summaries
速读卡片 (TL;DR)
一句话: 预测具体的远期 token 是错的抽象层级 —— 远期未来该用一个紧凑的summary 向量来监督;FSP 在 NTP 旁挂一个辅助 head,要么去拟合未来窗口的 bag-of-words(FSP-BoW),要么去 distill 一个反向语言模型(RevLM, 右到左训)的隐藏状态(FSP-RevLM)。
立场: 这是一篇 objective design 的论文,不是新架构。核心贡献是把"未来"从 token 列表抽象成可学习的向量,从而绕开 MTP 加 head 不可扩展的瓶颈。RevLM 翻倍了 pretraining FLOPs,但作者赌的是"compute 多 / data 少"这个时代;summary 是 train-time only,推理时 head 直接丢掉。
1. 动机:为什么 MTP 不够,且 token 不是对的目标
1.1 历史脉络:从 NTP 的 exposure bias 到 MTP 的短视
整个故事的起点是teacher forcing 这个老问题。NTP 训练时永远喂 ground-truth 前缀去预测 xt+1;推理时模型却只能基于自己 sample 出的 token 继续走。这个 train-inference mismatch(Bengio 2015 起就在讲)在长输出场景里会复合:错误一旦累积,模型就 drift 出训练分布。
更阴险的是 Bachmann & Nagarajan 2024 指出的 "Clever Hans cheat":就算训练时不 drift,teacher forcing 也会让模型走捷径 —— 直接从 prefix 里某个局部线索抄答案,绕过它本应学的长程结构。结果是梯度饥饿(gradient starvation):该被训出来的 long-horizon plan 因为 loss 已经被简单捷径吃饱,就再也得不到信号了。
Gloeckle 等人的 MTP (2024) 是对此的第一记重击:在 backbone 顶上挂 k 个辅助 head,每个 head 单独负责预测 xt+2, xt+3, …, xt+k。这等价于"减弱 teacher forcing":要预测 xt+3 但只给到 x≤t,模型必须把 xt+1, xt+2, xt+3 的边际都隐式建出来。DeepSeek-V3、Qwen-3、MiMo、Nemotron-3 都已经在用 MTP 变体了 —— 它确实 work。
1.2 别的方案为什么不够 —— "exact token" 不是对的抽象
但 MTP 有两个根本性的限制,这正是 FSP 要打的靶子:
| 方案 | 怎么减弱 teacher forcing | 瓶颈 |
|---|---|---|
| NTP | 不减弱 | shortcut + exposure bias |
| MTP (Gloeckle) | 每个 future token 一个 head | head 数 ∝ 视野;实际只能 k=2~4 |
| DS-MTP | head 之间递归喂中间表示 | 同上,head 数仍受限 |
| L-MTP / MuToR | 跳跃 / 随机偏移采样未来 token | 启发式抽样,可能漏掉关键信号 |
| Joint-token / Next-Latent | 联合分布 / 表示空间对齐 | 仍是 token 粒度,head 无法长程 |
| FSP (本文) | 1 个 head 预测整段未来的 summary | 怎么定义 summary 是新问题 |
关键 insight:具体哪个 token 在哪个位置出现 —— 这件事在 100 个 token 之外 本质上是不可预测的(natural language 的天花板就是 entropy 噪声)。把 cross-entropy 算在它们头上,要么把模型逼疯,要么模型干脆放弃学习,只优化前几位的边际分布。这就是 paper 反复强调的"不是所有未来 token 都同等重要"。
而 summary —— 比如"未来这一段大概会出现哪些 vocabulary 项"或"未来这一段对应一个怎样的语义状态" —— 反而是可学的、可压缩的、对当前预测有用的。
1.3 为什么这事不平凡 —— 怎么定义和算这个 summary?
三个非平凡的地方:
- 形式选择:summary 是 multi-hot 词表向量?池化的 hidden state?VQ code?clustering id?每种都对应不同的 inductive bias 和 loss(BCE / ℓ2 / softmax)。本文实验了两种:BoW + RevLM 的 hidden state。
- 计算成本:如果 summary 需要训一个 teacher,等于 pretraining 成本翻倍。RevLM 跟 forward model 同等大小、同等步数训 —— 这是论文最大的"代价"。但作者论证的是 distillation 经典不算 teacher cost。
- "adaptive vs static":hand-crafted summary 把所有未来 token 等权对待,在 sibling discovery 这类"未来里只有一小段相关" 的场景下会被噪声淹没。learned summary 才能自己挑出"对当前预测有用的"那一块。
2. 背景速查
| 缩写 / 术语 | 含义 |
|---|---|
| NTP | Next-Token Prediction,标准 LM loss |
| MTP | Multi-Token Prediction (Gloeckle 2024),N 个 head 同时预测 t+1..t+N |
| DS-MTP | DeepSeek 变体,head 递归依赖 + 接收 ground-truth 中间 token |
| teacher forcing | 训练时喂 ground-truth 前缀给模型;严格定义"模型每接收 1 个 GT token 要预测多少 unseen 的信息" |
| FSP | 本文提出的 Future Summary Prediction,1 个 head 预测一个 future summary 向量 |
| FSP-BoW | summary = future window 的 multi-hot bag-of-words(可加 tf-idf 权重) |
| FSP-RevLM | summary = 反向 LM 在对应位置的 hidden state(ℓ2 distill) |
| RevLM | Reverse LM,在右到左序列上训的一个 standard transformer LM |
| τ | 未来窗口长度(BoW 收集 xt+2..xt+τ;RevLM 用全部 x≥t+2) |
| path-star / sibling discovery | 两个 synthetic 任务,分别测 long-horizon plan 与 adaptive future |
NTP / MTP 公式刷新
注意 MTP 是 marginal 假设(各 head 条件独立),近似真正的联合分布 P(xt+1..t+τ|x≤t)。
3. FSP 框架:统一所有 future-aware 目标
FSP 的 loss 形式特别干净:
其中:
Aφ(x≤t) = fha′ ∘ fs(x≤t)—— 共享 backbonefs,加一个独立的 auxiliary headfha′;a(t, τ)—— "未来 summary",这是 FSP 框架的设计自由度;ℓa—— 取决于 summary 形式:BCE(BoW) 或 ℓ2(hidden state)。
整个 FSP 设计的优雅之处在于:无论 summary 选什么,辅助 head 永远只有一个。这跟 MTP "每加一个未来位置就加一个 head" 是结构上的本质区别。
4. FSP-BoW: hand-crafted summary
FSP-BoW 是最简单的实现:在位置 t,统计未来窗口 {xt+2, …, xt+τ} 中出现的所有 token id,做成一个 V 维 multi-hot 向量。aux head 输出 V 维 logits,跟它做 reweighted BCE:
权重 w(i) 用 tf-idf:罕见的内容词权重大,常见的虚词权重小 —— 不然 "the / , / 's" 这种 token 把 BCE 的信号埋了。
Worked example:τ = 12 的某段代码
假设当前位置 t,未来 12 个 token 是:
x_{t+2..t+13} = ["def", " sum", "(", "a", ",", " b", ")", ":", " return", " a", " +", " b"]
BoW 向量 a ∈ {0,1}V:在 vocab id 对应这 12 个 unique token 的位置上是 1,其余是 0(注意 a, b 各只出现 2 次,但 multi-hot 不计数,只标 0/1)。tf-idf 后,def, return, + 这种"语义重"的 token 拿大权重,(, ,, :拿小权重。
aux head 看到 x≤t 输出 V 维 logit,要在所有这 ~9 个位置上 sigmoid 接近 1、其余位置 sigmoid 接近 0。
反向论证:为什么不直接预测排序的 token 序列?
因为位置在远期是不可预测的 —— 同样语义的代码可以是 def add(x,y): return x+y,排序错了 cross-entropy 直接爆炸。BoW 把位置抽掉,只保留"集合",这才让远期 supervision 变 trainable。τ 可以拉到 100(table 3),信号还在。
5. FSP-RevLM: learned summary via reverse LM
RevLM 是个独立训练的 transformer LM,但它喂的是 右到左 的序列。也就是说它在学:
训完之后,RevLM 在位置 t+2 输出的 hidden state gh ∘ gs(x≥t+2) 自然就编码了"从 t+2 到结束这一整段后缀的语义压缩"。这就是 forward model 的 aux head 要拟合的目标:
注意它做的是 representation distillation —— 老师不是另一个 forward LM,而是结构对偶的 reverse LM。这跟 next-latent prediction (Teoh 2025) 有相似 flavor,但 next-latent 拿的是 forward 自己的未来 hidden state(self-distillation),FSP-RevLM 拿的是显式的 right-to-left teacher。
Worked example:跟踪一个具体位置
令一段序列长 T = 8192,选 t = 3000:
- Forward 路径:x≤3000 → backbone fs(8B Llama3,d = 4096)→ 主 head 输出 logit∈ℝV 预测 x3001;aux head
fha′输出 ∈ ℝ4096。 - RevLM 路径:对同一序列做 reverse,在 reverse 序列里 x3002 是某个 prefix 的开头,RevLM 在那个位置的最后一层 hidden state 就是要拟合的 target。形状 ∈ ℝ4096。
- Loss:
‖aux − target‖²加到 NTP loss 上,coef 是固定的 1。
反向论证:为什么 reverse 训而不是 forward 自蒸馏?
如果用 forward 自己生成 target(像 next-latent prediction),target 还是带着 NTP 的 teacher-forcing 偏差;它习得的"未来"是被 ground-truth 引导出来的,可能仍然是 shortcut。RevLM 反向训完全独立 —— 它在 t+2 位置的 hidden state 已经"看过"完整 suffix,encode 的是真实 suffix 的语义,不带 forward 方向的捷径。
6. 同一位置,四种目标 worked example
把同一句拿出来对比,最能看清四种 objective 的差异。
假设序列是 (Python 注释):
# compute the sum of the list and return it def total(lst): return sum(lst)
取 t 在 def total(lst): 的冒号位置(即 xt = :)。
| 方法 | target | aux head 数 | loss |
|---|---|---|---|
| NTP | xt+1 = " return" | 0 | CE on V |
| MTP (k=4) | xt+1..t+4 = " return", " sum", "(", "lst" | 4 | 4× CE on V |
| FSP-BoW (τ=12) | multi-hot 集合 {return, sum, (, lst, )}(忽略位置) | 1 | BCE on V |
| FSP-RevLM | RevLM 在位置 t+2 的 hidden state ∈ ℝ4096,大约编码"这段是个返回 lst 求和的函数体" | 1 | ℓ2 on ℝd |
注意从上到下,target 越来越脱离"具体哪个 token 出现在哪",越来越靠近"这一段的语义角色"。FSP-RevLM 那一行的 target 描述本身就有信息论意义:它压缩了 dim < V × τ 的 token 序列到 d=4096 维。
7. 合成实验:验证两个 thesis
7.1 Path-Star Graph: 长视野 summary 重要
任务:输入是个 DAG 的邻接表 + start 节点 + end 节点;模型生成从 start 到 end 的完整路径。Bachmann & Nagarajan 的经典发现:NTP 学不会全程 plan,只学到"看 prefix 里 vi 的下一跳"的捷径。
7.2 Sibling Discovery: adaptive summary 重要
任务被改造成:序列由多个独立 component 拼接,每个 component 是 [S1i, S2i, S3i, Pi]。给定当前 component 的部分 prefix,模型要预测同 component 的 sibling —— 但未来里有大段无关 component 的 token。
结果(收敛速度,越低越快):
- FSP-BoW:components 数 ≤ 6 时优于 NTP,>6 时优势消失甚至变差 —— 因为 BoW 把无关 component 的 token 当 supervision,引入噪声。
- FSP-RevLM:全程优于 NTP,因为 RevLM 自己学会 attend 到当前 component 内的 token(论文 figure 7 的 attention 可视化证明了这点 —— heads 表现出 intra-component 模式)。
这是 "adaptive vs static summary" 的 cleanest 实验。
8. 真实 pretraining 结果
8.1 8B / 1T tokens 主表
| Task | NTP | MTP | DS-MTP | FSP-RevLM |
|---|---|---|---|---|
| ARC-Easy | 0.718 | 0.736 | 0.617 | 0.766 |
| ARC-Challenge | 0.531 | 0.552 | 0.426 | 0.559 |
| GSM8K | 0.716 | 0.678 | 0.704 | 0.705 |
| MATH | 0.342 | 0.309 | 0.335 | 0.351 |
| MBPP | 0.657 | 0.672 | 0.678 | 0.683 |
| HumanEval+ | 0.478 | 0.541 | 0.526 | 0.541 |
解读:
- FSP-RevLM 在 6 个 task 里 4 个最强,1 个并列。GSM8K 是唯一掉给 NTP 的(差 1.1pp,在 std error 内)。
- DS-MTP 在 8B 上 反而崩了 —— ARC-Easy 0.617 比 NTP 还差 10pp。论文没正面解释,推测跟 DS-MTP 把 ground-truth 中间 token 喂给 aux head 有关,在小窗口 + 大模型上可能反而引入了和 main head 冲突的信号。
- MTP 在 GSM8K 上明显倒退(0.678 vs 0.716),说明非"未来感知"目标并非总能赢 NTP;FSP-RevLM 至少没退步。
8.2 future summary 的形式 ablation (8B, table 3)
| Method | MATH | GSM8K | ARC-Easy |
|---|---|---|---|
| MTP (baseline) | 0.309 | 0.678 | 0.736 |
| MTP-Skip τ=4 (random sampling) | 0.277 | 0.639 | 0.722 |
| MTP-Skip τ=32 | 0.271 | 0.598 | 0.564 |
| FSP-BoW τ=12 | 0.331 | 0.699 | 0.737 |
| FSP-BoW τ=100 | 0.331 | 0.714 | 0.662 |
| FSP-RevLM | 0.351 | 0.705 | 0.766 |
解读:
- "随机偏移采样未来"是死路 —— MTP-Skip 全线劣于 MTP,τ 越大越差(可能因为 head 被噪声目标拉偏)。这是对 MuToR / L-MTP / TRELAWNEY 这类工作的直接打击。
- BoW 真的 work。τ=12 → 100 在 GSM8K 上还涨,但 ARC-Easy 在 τ=100 时反而降到 0.662 —— BoW 在长窗口下开始把无关词背景当 signal 了,这正好印证 sibling discovery 那个 thesis。
- FSP-RevLM 全面最强,印证 adaptive summary 的优势。
8.3 Pass@K diversity
论文的一个迷人副产物:在 GSM8K / MATH 上,FSP-RevLM 的 pass@K 曲线随 K 上升幅度比 MTP 更陡 —— 也就是说 base model 的解的多样性更高。
8.4 数据受限场景:FSP 优势放大
同 1T compute,但只用 50B unique tokens × 20 epochs。
- 1 epoch 设定下:FSP-RevLM 略优,DS-MTP 落后。
- 20 epoch (data-constrained) 设定下:所有 future-aware 方法 (MTP/DS-MTP/FSP) 全部 > NTP —— NTP 是 overfit 最快的那个。这跟 diffusion-LM 在 data-constrained 场景的表现是同方向的:更"约束式"的 supervision 在重复数据上 generalize 更好。
9. 与同类工作对比
| 方法 | 未来 supervision 形式 | aux head 数 | 关键差异 |
|---|---|---|---|
| MTP (Gloeckle 2024) [note] | k 个固定偏移 token | k | FSP 的起点,但短视野 |
| DS-MTP (DeepSeek-V3) | k 个 token + 中间 GT 反馈 | k | head 递归依赖,在 8B 上反而崩 |
| L-MTP (Liu 2025) | 跳跃式非相邻 future token | k | 启发式抽样,FSP 表 3 的 MTP-Skip 已证明这条路次于 BoW |
| MuToR (Gerontopoulos 2025) | register tokens 预测随机偏移 token | 无新 head(用 register) | 同上,heuristic 抽样 |
| TRELAWNEY (Thankaraj 2025) | 训练数据里插入 future window 块 | 无新 head | 用数据增强代替 head;仍是 token 粒度 |
| NCP / Next-latent (Teoh 2025) [NCP note] | forward 自己未来位置的 hidden state | 1 | self-distill;FSP 用 reverse teacher 而非 forward |
| SemFormer (Yin 2024) | planning token 预测 future autoencoder embedding | 1 个,但仅在 planning 位 | 位置稀疏 + 用 autoencoder(L-to-R)而非 reverse LM |
| BST (Hu 2024) / Twin Networks (Serdyuk 2017) | 对齐 forward/reverse hidden state | — | FSP 是 distill,不是 weight-sharing 或 dual encoder |
| Meet-in-the-Middle (Nguyen 2023) | forward/reverse 输出分布对齐 | — | 需要两个分布严格匹配;FSP 不要求 |
| ProphetNet (Qi 2020) | n-gram via shared self-attn | — | 仍逐 token 监督 + 保留位置 |
| FSP (本文) | 1 个 vector summary(BoW 或 RevLM hidden) | 1 | 统一框架,可换 summary 形式;long horizon + adaptive |
10. 局限 / 个人 take / 待验证问题
- RevLM 是 2× 训练成本。论文用"compute-rich, data-limited"justify 这件事,但实际部署里 doubling pretraining FLOPs 是巨大的代价。是否能用更小的 RevLM teacher(比如 forward 1/4 大小)拿到 80% 的收益?这是最显然的 follow-up。
- 3B vs 8B scaling:3B 上 DS-MTP 反而比 FSP-RevLM 强 —— 在 ARC、HumanEval+、MBPP 上 DS-MTP 都更高。FSP-RevLM 的优势是跟着 scale 涨。8B 之外会怎样?70B / 405B 是否会更显眼,还是会被新的 bottleneck 卡住?
- BoW 的 τ 选择敏感:τ=12 在 ARC-Easy 是 0.737,τ=100 跌到 0.662。说明 hand-crafted summary 真的是要看 task / data 调的;FSP-RevLM 优势之一就是不用挑 τ(用整段 suffix)。
- RevLM 学到了什么? Figure 7 给了 attention 可视化(intra-component subgroup),但没有 mechanistic 验证 hidden state 真的编码了"未来语义"而非"reverse direction 的 syntactic 残影"。linear probe / activation patching 是 paper 自己点名的 follow-up。
- 跟 RL post-training 的衔接:论文猜测 base model coverage 高 → RL 上限高,但没做实验证明。这是非常有 leverage 的实验,如果在 GSM8K 上 RL on FSP-base 真的拿到比 RL-on-NTP-base 更高的天花板,FSP 的实用价值就坐实了。
- summary 形式还有什么没被尝试的? VQ code / clustering id / contrastive embedding(SimCLR-style 学未来一段的 representation)都可以塞进 FSP 框架里。论文只测了 BoW + RevLM 两点。
follow-up 实验清单
- 用 forward 1/4 size 的 RevLM 重做 FSP-RevLM,看 8B 上能保留多少 gain。
- 对 RevLM hidden state 做 linear probe:能预测多少 future bag-of-words / future entity?
- RL on FSP-base vs RL on NTP-base,直接对比 post-training 上限。
- 把 BoW 替换成 contrastive learning 的 future embedding,看是否能拿到 BoW 的鲁棒性 + RevLM 的 adaptive。
- 组合 FSP + DS-MTP 的"slight teacher forcing":aux head 既看中间 GT 又预测 summary。
- 在 long-context (32K+) 上的表现 —— RevLM 的优势在长依赖上应该更明显。