Predicting the Order of Upcoming Tokens Improves Language Modeling
速读卡片 (TL;DR)
一句话:把 MTP 那种"精确预测 t+1…t+k 个 token"的辅助任务松弛为"按 proximity 排序未来 W 个 token"——一个 ListNet 风格的 learning-to-rank loss,只额外加 一层 unembedding,在 340M / 1.8B / 7B 的 pretrain + 续训中全面胜过 NTP / MTP / DS-MTP。
立场:MTP 不灵的根因不是 look-ahead 思路错,而是把 look-ahead 任务定得太死(精确预测某个 offset 上的 token)。改成排序就把分类问题降成相对位置问题,representation 反而学得更好。架构成本几乎为零,值得作为通用 NTP 的低风险替代去试。
1. 动机:为什么 MTP 不够用
1.1 历史脉络:NTP 的批评与 MTP 的兴起
故事要从 NTP 的批评说起。NTP(next-token prediction)是 Shannon 时代就奠定的 objective,简单粗暴地最小化 −Σ log P(x_{t+1} | x_{0:t})。但近年它收到两路质疑:
- LeCun 的批评(2024):NTP 在推理时误差会逐 token 累积,长序列必然崩。
- Bachmann & Nagarajan 的反驳(ICML 2024):问题不在推理时累积,而在训练时 teacher forcing 让模型根本学不会一个准确的 next token predictor。换句话说,模型学到的并不是真正"看清未来"的表征,只是"在给定真实前缀时蒙下一个 token"的近视器。他们设计的 star graph pathfinding 任务把这个缺陷放大成一个二分判别题——后面 §8 会详谈。
第二种批评更棘手,因为它指向 representation 本身。社区于是开始找能逼模型多看几步的辅助 objective。ProphetNet(2020)预测未来 n-gram,MTP(Gloeckle 2024 / Meta)在 transformer trunk 之上叠 N 个并行的 transformer block 头,每头预测 t+n;DS-MTP(DeepSeek-V3)把头改成顺序串联,每一头吃前一头的隐状态加 RMSNorm 后的下一个 ground-truth embedding。
问题是:MTP 的收益不稳定。原 MTP 论文 Appendix G 自承在 standard NLP benchmark 上不一定占优;只有在 coding / math 这类需要 look-ahead 的生成任务上、而且模型 ≥1B–3B 时,它才稳定胜过 NTP。DeepSeek-V3 干脆只用 N=3,这本身就是承认大 N 不 work。
1.2 别的方案为什么不够 — 多 token 预测的"看得越远学得越烂"陷阱
作者用一个 16M 的小 transformer 训了带 16 个 MTP 头的模型,把每个头的训练 loss 单独画出来(Figure 2)。结果非常干净:
这张图是整篇论文的论点骨架:把"准确预测某个远位置的 token"塞给小模型,等于给它布置了一个它本来就做不到的题,梯度大部分时间在告诉它"你又错了",并不能转化成有用的 representation 信号。
对比一下作者列举的几条替代路径:
| 方案 | 辅助信号 | 架构成本 | 痛点 |
|---|---|---|---|
| NTP only | 无 | 0 | teacher forcing 下不学 look-ahead;星图任务直接 fail |
| MTP (Meta) | 每个 offset 一个独立 next-token CE | N−1 个 transformer block | 远 offset loss 几乎不降;小模型反而被拖累 |
| DS-MTP (DeepSeek-V3) | 顺序展开,每头看前一头 + 真实 embedding | ≈ MTP + 2 RMSNorm + 1 linear | 实际只敢用 N=3;在标准 benchmark 上和 MTP 一档 |
| ProphetNet n-gram | 预测未来 n-gram | 多头 | 仍是精确预测,且为 seq2seq 设计 |
| MLM / span corruption (BERT/T5) | 双向重建 | encoder-only | 不能直接做 autoregressive 生成 |
| TOP(本文) | 把未来 W 个 token 排序 | 1 层 unembedding | 只能 greedy(NTP head 仍保留以做采样) |
1.3 为什么"换成排序"这事不平凡
把"预测"换成"排序"听起来像 trivial 改写,但论文的非平凡之处在于三件事的结合:
- 任务难度的显式松弛。"在 vocab V 里精确 argmax 出 t+5 是哪个 token"是一个 1-of-V 的分类题(以 V≈50k 算,信息熵 ≈ 15.6 bits/token);而"按 proximity 给未来 W 个 token 一个排序分布"只要 model 能大致把"应该出现的 token"集中在 softmax 的 head 上即可。后者是软监督,梯度永远指向"再多关注一下这一片词"而不是"你又押错了"。
- 架构成本几乎为零。没有额外 transformer block。因为 NTP head 和 TOP head 的目标是对齐的(都把 next token 排第一),它们可以共享同一个 hidden state htL,只各自带一个线性 unembedding。inference 时直接扔掉 TOP head,模型退化回普通 transformer——这一点对部署友好到几乎没成本。
- ListNet 的 top-one 公式刚好契合 LM。LM 本来就在 vocab 上输出 logits,
softmax(y_t)(target proximity scores)和softmax(U_TOP h_t)(model logits)直接做交叉熵,跟 NTP 的 CE 形式一模一样,只是 target 从 one-hot 变成了带 −∞ 的 ranking 分布。这种"接口一致"的简洁性是它能写进现有 fused linear cross-entropy kernel 不增加 overhead 的关键。
2. 背景速查
| 记号 / 术语 | 含义 |
|---|---|
| T, D, V | 序列长 / hidden dim / vocab size |
| htL | 第 t 个位置最后一层 transformer 的 hidden(NTP / TOP head 都吃它) |
| UNTP, UTOP | 两个并列的 linear unembedding ℝD→ℝV |
| N (MTP / DS-MTP) | 预测的未来 token 数(含 next),典型 2、3、4 |
| W (TOP) | 窗口大小,默认 = 序列长(实验里也试过 4 / 16 / 128 / 1024 / 4096) |
| proximity score y[t,v] | = W − d,d 是 token v 在 t 之后首次出现的距离;d 不在 (0,W] 时为 −∞ |
| ListNet top-one loss | −softmax(y) · log softmax(ŷ),把两个排序分布之间用 cross-entropy 衡量 |
| star graph G(d, l) | 从中心 0 出发分 d 条长度为 l 的链;给 edge list+起点+终点,要求输出从起点到终点的路径 |
| self-speculative decoding | 用模型自己的"未来头"draft、再用同一个模型 verify,无需额外 draft model |
回顾:NTP / MTP / DS-MTP 三个 loss 的精确写法
关键差异:MTP 各头并行、共享 trunk 的倒数第二层输出;DS-MTP 各头串行、每头还吃真值 embedding。两者都需要额外的 transformer block。
3. 方法 · TOP target 是怎么造的
给定输入 token 序列 x(长度 T+W),TOP target y 是一个 (T, V) 的张量。Algorithm 1 的本质是逆序扫一遍序列,顺手记下每个 vocab token "下一次出现在哪个位置":
Algorithm 1 的"逆序扫一遍"为什么对
朴素实现会对每个 t 都跑一遍 W 长的 lookahead,O(TWV)。Algorithm 1 用一个下次出现位置数组 n[v]:
- 初始 y ← −∞,n[v] ← T+W(哨兵)
- 从 t = T+W−1 倒着到 0:
- 先更新 n[x[t]] ← t(把当前 token 的"最近一次出现"标到 t)
- 若 t < T,对每个 v 计算 d = n[v] − t,若 0 < d ≤ W 则 y[t,v] = W−d
这个写法的妙处是"未来"在反向扫的视角下变成了"过去"——因为我们倒着走,先看到的就是更远的时间点,n[v] 总是记着 v 在 t 之后最早出现的位置。论文原话是"in practice, we have an optimized Triton kernel that creates the target sequence on the fly during training and practically incurs no overhead"——所以这一步可以和数据 loader 融合,不预处理整个 dataset 也行。
4. 方法 · 损失函数与架构
4.1 ListNet top-one loss
有了 target y,损失就是 ListNet 的标准形式:
等价于一个"软标签 cross-entropy":把 target ranking 通过 softmax 变成概率分布(注意 −∞ 的位置 softmax 后是 0),把 model output 也 softmax 化,然后两个分布做 CE。最终总损失是平等相加:
论文 §5.5.2 还做了 α 加权扫描,惊人地发现 α=0.9(即 TOP 占 9 成)在大部分 benchmark 上最好——意味着在他们的 setup 下,TOP 信号比 NTP 信号还要有用。这一点反过来支持"NTP 是 TOP 的特例 / 弱化"的直觉(NTP target 等价于一个只标 t+1 的 one-hot,TOP target 是把它扩展成多个 token 上的连续 ranking)。
4.2 架构图:四种 loss 的比较
4.3 inference 行为
训练完后,UTOP 直接扔掉,模型推理时和普通 NTP transformer 完全等价(同样的 sampling、temperature、KV cache、speculative decoding 兼容性都不变)。这是 TOP 相对 MTP 的一个隐藏优势:不需要在部署时改 inference engine。论文也尝试用 TOP head 做 self-speculative decoding(把 TOP 排序前 k 个 token 拼起来当 draft),但效果不如 MTP / DS-MTP——因为 TOP head 学到的是 ranking,不是真正的"按位置预测"。
5. 复杂度对比 · 为什么"只加一层"是对的尺度
| 方法 | 额外 FLOPs (per token) | 额外参数 | D=4096, V=128k, N=4 时 |
|---|---|---|---|
| MTP | (N−1)(24D² + 2DV) | (N−1)(16D² + 2D) | ~3.4 GFLOPs / 805 M params |
| DS-MTP | (N−1)(30D² + 2DV) | (N−1)(16D² + 2D) | ~3.5 GFLOPs / 805 M params |
| TOP | 2DV | DV | ~1.05 GFLOPs / 525 M params (但只算辅助层) |
对于一个 7B 模型来说:
- MTP-4 增加 ≈ 800M 额外参数(主 trunk 还要砍 3 层来"补偿"参数预算)
- TOP 增加 D·V ≈ 0.5B 参数,且不需要砍 trunk——因为 D·V 主要是 unembedding 矩阵,本身可以和 input embedding tied(论文虽然没特别说 tying,但实际部署完全可以)
更重要的是训练时这层只需要做一次 matmul + softmax + 软交叉熵,作者改造了 Yang & Zhang 2024 的 fused linear cross-entropy Triton kernel,把 unembedding 和 loss 算成一块、分块流水,实测和不加 TOP 的训练速度几乎一致。MTP 那 N−1 个完整 transformer block 是无法这样吃掉的,因为每块要算 attention + MLP。
6. 具体例子 · TOP 和 MTP 在哪一步分道扬镳
取一段 token: "never gonna give you up never gonna let you down"(序列长 10,W=4),看 t=0 处 (token 是 never) 三种 loss 在同一个 hidden h0L 下要求模型做什么:
| Loss | 这一步要求模型预测什么 / 怎么算 loss |
|---|---|
| NTP | 看到 never,预测下一个 token = gonna;target one-hot at "gonna" |
| MTP-4 | 4 个独立 head 分别预测 gonna(t+1)、give(t+2)、you(t+3)、up(t+4),各自 CE 求和;t+4 head 的 loss 比 t+1 head 高 2–3 倍(因为多元 ambiguity) |
| DS-MTP-3 | 串行预测 t+1, t+2, t+3,t+2 那一头吃 RMSNorm(gonna) 的真实 embedding,所以"上下文越来越完整"——但反过来也意味着 t+3 那头其实只在做"看到 never+gonna+give 后预测 you",和 NTP 在 t+2 处做的事情高度重叠 |
| TOP | target softmax 后是 {gonna: e³/Z, give: e²/Z, you: e¹/Z, up: e⁰/Z},Z=e³+e²+e¹+e⁰≈30.19。要求 model logits 经 softmax 后整体形状逼近这个分布——不必精确到位置,但要把这 4 个 token 集中起来,且 gonna 占大头 |
关键差异:当 model 把概率质量分给 gonna 和 give 而非全压在 gonna 时,NTP 和 MTP 各 head 的 loss 都会上升,但 TOP 的 loss 反而是合理的——因为 TOP 的 target 本来就允许 give 拿到 e²/Z ≈ 24% 的概率。这正是"软监督"的体现:模型不会因为"略微关心了 t+2 token"而被惩罚,反而因此被奖励。
7. 实验关键结果
7.1 通用 NLP benchmark(9 个任务,3 种规模)
挑出最关键的几个数(Table 2,加粗了胜者):
| SIZE | 方法 | LAMBADA acc↑ | HellaSwag↑ | ARC-C↑ | MMLU↑ | TriviaQA EM↑ |
|---|---|---|---|---|---|---|
| 340M | NTP | 36.35 | 42.53 | 28.84 | 29.81 | 4.93 |
| MTP | 35.32 | 42.73 | 29.86 | 29.08 | 2.55 | |
| DS-MTP | 34.66 | 40.29 | 27.56 | 28.47 | 0.87 | |
| TOP | 37.07 | 43.57 | 29.35 | 30.09 | 4.37 | |
| 1.8B | NTP | 49.58 | 60.05 | 38.65 | 35.34 | 11.85 |
| MTP | 47.93 | 58.29 | 40.61 | 34.76 | 15.98 | |
| DS-MTP | 48.71 | 57.48 | 40.44 | 35.01 | 12.06 | |
| TOP | 50.34 | 60.45 | 42.32 | 36.21 | 18.93 | |
| 7B | NTP | 55.89 | 67.44 | 45.65 | 39.47 | 24.28 |
| MTP | 53.13 | 65.85 | 45.56 | 38.14 | 23.36 | |
| DS-MTP | 55.62 | 66.03 | 44.37 | 38.16 | 18.54 | |
| TOP | 57.03 | 68.73 | 46.42 | 39.65 | 30.90 |
读法:
- 7B 处 TOP 在 LAMBADA / HellaSwag / ARC / MMLU / TriviaQA 全部胜出,差距随规模放大(TriviaQA 7B: 30.90 vs NTP 24.28,相对 +27%)
- MTP 7B 在 LAMBADA / HellaSwag 反而低于 NTP——验证了原 MTP 论文 Appendix G 的"在通用 NLP 上不行"
- DS-MTP 全程没看到比 MTP 显著好的迹象,这一发现和 DeepSeek-V3 大量工程优化(包括只用 N=3)结合后,提示 DS-MTP 更多是 inference 加速 trick 而非 representation 提升器
7.2 续训:Math / Code
| 1.8B 续训 | GSM8K↑ | MATH↑ | HumanEval@64↑ | MBPP@64↑ |
|---|---|---|---|---|
| NTP | 39.20 | 13.34 | 33.53 | 46.70 |
| MTP | 38.59 | 15.00 | 34.76 | 48.20 |
| DS-MTP | 2.65 | 3.66 | 23.78 | 31.90 |
| TOP | 45.64 | 16.66 | 38.41 | 44.80 |
1.8B GSM8K 上 TOP 比 NTP 高 6.4 个点,比 MTP 高 7 个点——这是论文最有说服力的实验之一,因为 math reasoning 是公认需要 look-ahead 的任务,TOP 显著优于专为 look-ahead 设计的 MTP。DS-MTP 的崩溃(2.65)论文怀疑是 overfitting,需要更多分析。
7.3 self-speculative decoding 是 TOP 的弱项
Table 6:用 TOP head 做 self-speculative draft,7B 上每次 forward 平均 accept 1.52 个 token,而 MTP 是 2.49,DS-MTP 是 3.03。这不奇怪——TOP head 学的是 ranking,不是"在准确位置预测准确 token",当 draft sequence 必须严格按 t+1, t+2 …排序时,TOP 的 top-k 排序天生比 MTP 各位置 head 的精确预测要差。
7.4 消融:窗口大小 W 与 loss 权重 α
- W:所有 W ∈ {4, 16, 128, 1024, 4096} 都比 NTP 好,任务最优 W 不一致——LAMBADA 偏好 W=4(短距离),SciQ 偏好 W=1024。这说明 W 是个软超参,不必精调。
- α:α 越大(TOP loss 占比越高)越好,α=0.9 在 6/7 个 benchmark 上最优。这个发现颠覆了"NTP 是主、TOP 是辅"的直觉——TOP 自己就足以驱动训练,NTP 只是为了保留 sampling 用的概率分布。
8. 星图任务 · 最干净的可分性证明
这是论文最让人 convinced 的实验。任务来自 Bachmann & Nagarajan 2024,设计目的就是把 NTP 的"teacher forcing 短视症"暴露出来:
| 模型 | params | G(3,3) | G(3,5) | G(5,3) | G(5,5) |
|---|---|---|---|---|---|
| NTP | 14.2M | 33.8 | 32.5 | 19.5 | 0.1 |
| MTP-2 | 16.0M | 100 | 59.0 | 19.6 | 0.1 |
| MTP-4 | 19.6M | 100 | 100 | 100 | 19.5 |
| DS-MTP-2 | 16.6M | 100 | 32.5 | 100 | 19.2 |
| DS-MTP-4 | 20.7M | 100 | 33.6 | 100 | 19.3 |
| TOP | 14.2M | 100 | 100 | 100 | 100 |
读这张表:
- G(5,5) 上,所有 baseline(包括 MTP-4 和 DS-MTP-4)都 fail;TOP 唯一全对
- TOP 的参数量最少(14.2M,等于 NTP),却胜过 20.7M 的 DS-MTP-4。这直接证明了 TOP 的优势不是参数容量
- 注意 MTP 头数和路径长度的关系:MTP-2 在路径长 5 的图上 fail,MTP-4 又在 G(5,5) fail——头数必须 ≥ 路径长才学得动,这是"精确预测"任务的硬限制
TOP 为什么对所有 G 都 100%?因为 TOP 的 target 在第一步就告诉它"未来 5 步内 goal 7 会出现",而 ranking 损失允许模型先粗略锁定"应该走过 7 的那一支",再让 NTP head 输出第一个 token 2——这本质上是把"先全局规划再生成"显式注入了训练信号。
9. 与同类工作对比
| 方法 | 辅助信号粒度 | 架构成本 | 对小模型友好 | 对长 look-ahead 友好 | inference 改动 |
|---|---|---|---|---|---|
| NTP | none | 0 | — | 否 | — |
| MTP (Meta) | 每 offset 精确分类 | N−1 transformer block | 差(<1B 反而掉点) | 差(t+8 比 t+4 差) | 可做 self-spec(2.5×) |
| DS-MTP | 串行精确分类 + 真值 embedding | 稍贵于 MTP | 中 | 中 | self-spec 极强(3.0×) |
| ProphetNet (2020) | 未来 n-gram | 多头 | — | — | seq2seq only |
| Register tokens for MTP (Gerontopoulos 2025) | 把 NTP 转 MTP 的轻量改造 | + register tokens | — | — | — |
| MLM / span (BERT/T5/UL2) | 双向重建 | encoder | — | — | 不能直接 AR 生成 |
| TOP | 未来 W token 的 排序 | 1 unembedding | 好(340M 已胜) | 好(W=4096 也 work) | 无(扔掉 TOP head) |
另一类相关工作是 TLM-style 的"辅助 prediction loss":比如做一个旁路头预测 sentence-level 属性、文档来源、bag-of-future-words 等。TOP 和它们的本质区别是 TOP 的 target 是从原序列直接计算的,无需额外标注或外部信号,且 ranking loss 的形式让它能和现有 fused CE kernel 几乎无缝结合。这种"零标注、零 inference 改动、零额外网络"的三零定位是它最大的工程吸引力。
10. 局限 / 个人 take / 待验证问题
- 规模上限未触达。7B / 104B tokens 是作者算力天花板。TOP 是否在 70B + 多 trillion tokens 时仍占优,完全未知。MTP 论文的关键发现就是"小模型不行,1B 后才行",TOP 也存在反向风险:大模型可能 saturated 不再需要这种软监督。
- α=0.9 最优很可疑。TOP loss 比 NTP loss 还重要 9 倍?要么 TOP 信号确实更密集,要么 ListNet 的归一化让 loss 数量级差很多导致 α 不可比。这块需要看绝对 loss 值和 grad norm。
- self-speculative decoding 弱。TOP 牺牲了"精确位置预测"换"排序",draft 质量必然弱。对追求 inference 加速的场景,TOP + 外部 EAGLE-style draft 才是正解,而不是用 TOP head 自己做 draft。
- DS-MTP 在续训上崩盘(GSM8K 2.65)论文怀疑是 overfit,但同样的 hyperparam 别的方法都没崩——这暗示 DS-MTP 对 continued training 不稳健,值得另起一篇剖析。
- TOP target 的 −∞ 处理。vocab V≈50k,窗口 W=4096,意味着每个 t 的 target 里 ≥ 92% 的位置是 −∞(softmax 后是 0)。这种极稀疏 target 在 fp16/bf16 下数值稳定吗?Triton kernel 是不是用 logsumexp + masking 来避免 NaN?论文没细讲。
- "TOP head 扔掉"的对偶性。既然 inference 不用 TOP head,理论上 TOP 等价于"用排序信号 regularize 主 trunk 的 hidden 表征"。如果换成更复杂的 ranking head(比如非线性 MLP),会不会拿到更强的 regularization?论文没尝试。
待验证清单
- 把 TOP head 和 NTP head tied weights,会怎么样?如果 tied 后还能 work,说明 TOP 的全部增益来自 target 形状而非额外参数
- 在 SFT / RLHF 阶段加 TOP loss 是否还有用?(论文只做了 pretrain + 续训)
- 把 TOP 和 EAGLE 的 draft head 训练目标融合:draft head 也用 ranking loss,会不会同时改善 representation 和 acceptance length?
- 消融"target 中的 W−d vs binary mask"。如果直接把 target 改成 {appears: 1, not appear: 0} 的 multi-label,是不是还行?这能分离"排序"和"出现"两种信号
- 把 TOP loss 仅施加在最后 K 层 hidden 上(而不是只在最后一层)——会不会让中间层也学到 look-ahead?