Predicting the Order of Upcoming Tokens Improves Language Modeling

MBZUAI · Zayd M. K. Zuhri, Erland Hilman Fuadi, Alham Fikri Aji · 2026-02-17 · arXiv:2508.19228v2
关键词: token order prediction · learning-to-rank · ListNet · MTP · DS-MTP · auxiliary loss · pretraining

速读卡片 (TL;DR)

一句话:把 MTP 那种"精确预测 t+1…t+k 个 token"的辅助任务松弛为"按 proximity 排序未来 W 个 token"——一个 ListNet 风格的 learning-to-rank loss,只额外加 一层 unembedding,在 340M / 1.8B / 7B 的 pretrain + 续训中全面胜过 NTP / MTP / DS-MTP。

2DV
每 token 训练 FLOPs(MTP 是 (N−1)·24D²)
G(5,5) 100%
星图任务 TOP 唯一全对
α=0.9
TOP loss 权重越大越好

立场:MTP 不灵的根因不是 look-ahead 思路错,而是把 look-ahead 任务定得太死(精确预测某个 offset 上的 token)。改成排序就把分类问题降成相对位置问题,representation 反而学得更好。架构成本几乎为零,值得作为通用 NTP 的低风险替代去试。


1. 动机:为什么 MTP 不够用

1.1 历史脉络:NTP 的批评与 MTP 的兴起

故事要从 NTP 的批评说起。NTP(next-token prediction)是 Shannon 时代就奠定的 objective,简单粗暴地最小化 −Σ log P(x_{t+1} | x_{0:t})。但近年它收到两路质疑:

第二种批评更棘手,因为它指向 representation 本身。社区于是开始找能逼模型多看几步的辅助 objective。ProphetNet(2020)预测未来 n-gram,MTP(Gloeckle 2024 / Meta)在 transformer trunk 之上叠 N 个并行的 transformer block 头,每头预测 t+n;DS-MTP(DeepSeek-V3)把头改成顺序串联,每一头吃前一头的隐状态加 RMSNorm 后的下一个 ground-truth embedding。

问题是:MTP 的收益不稳定。原 MTP 论文 Appendix G 自承在 standard NLP benchmark 上不一定占优;只有在 coding / math 这类需要 look-ahead 的生成任务上、而且模型 ≥1B–3B 时,它才稳定胜过 NTP。DeepSeek-V3 干脆只用 N=3,这本身就是承认大 N 不 work。

1.2 别的方案为什么不够 — 多 token 预测的"看得越远学得越烂"陷阱

作者用一个 16M 的小 transformer 训了带 16 个 MTP 头的模型,把每个头的训练 loss 单独画出来(Figure 2)。结果非常干净:

Training Steps → Loss t+1 t+16 TOP loss MTP per-head losses 排成"楼梯"·越远越难学 楼梯越往上 loss 下降越慢 → 远 token 几乎学不动
图 1. 复刻自原文 Figure 2 的示意。MTP 16 个头的训练 loss 自上而下按 offset t+1…t+16 等距排列,远的头 loss 高且下降慢。同尺寸 TOP 模型只有一条 loss,且明显更低。

这张图是整篇论文的论点骨架:把"准确预测某个远位置的 token"塞给小模型,等于给它布置了一个它本来就做不到的题,梯度大部分时间在告诉它"你又错了",并不能转化成有用的 representation 信号。

对比一下作者列举的几条替代路径:

方案辅助信号架构成本痛点
NTP only0teacher forcing 下不学 look-ahead;星图任务直接 fail
MTP (Meta)每个 offset 一个独立 next-token CEN−1 个 transformer block远 offset loss 几乎不降;小模型反而被拖累
DS-MTP (DeepSeek-V3)顺序展开,每头看前一头 + 真实 embedding≈ MTP + 2 RMSNorm + 1 linear实际只敢用 N=3;在标准 benchmark 上和 MTP 一档
ProphetNet n-gram预测未来 n-gram多头仍是精确预测,且为 seq2seq 设计
MLM / span corruption (BERT/T5)双向重建encoder-only不能直接做 autoregressive 生成
TOP(本文)把未来 W 个 token 排序1 层 unembedding只能 greedy(NTP head 仍保留以做采样)

1.3 为什么"换成排序"这事不平凡

把"预测"换成"排序"听起来像 trivial 改写,但论文的非平凡之处在于三件事的结合:

  1. 任务难度的显式松弛。"在 vocab V 里精确 argmax 出 t+5 是哪个 token"是一个 1-of-V 的分类题(以 V≈50k 算,信息熵 ≈ 15.6 bits/token);而"按 proximity 给未来 W 个 token 一个排序分布"只要 model 能大致把"应该出现的 token"集中在 softmax 的 head 上即可。后者是软监督,梯度永远指向"再多关注一下这一片词"而不是"你又押错了"。
  2. 架构成本几乎为零。没有额外 transformer block。因为 NTP head 和 TOP head 的目标是对齐的(都把 next token 排第一),它们可以共享同一个 hidden state htL,只各自带一个线性 unembedding。inference 时直接扔掉 TOP head,模型退化回普通 transformer——这一点对部署友好到几乎没成本。
  3. ListNet 的 top-one 公式刚好契合 LM。LM 本来就在 vocab 上输出 logits,softmax(y_t)(target proximity scores)和 softmax(U_TOP h_t)(model logits)直接做交叉熵,跟 NTP 的 CE 形式一模一样,只是 target 从 one-hot 变成了带 −∞ 的 ranking 分布。这种"接口一致"的简洁性是它能写进现有 fused linear cross-entropy kernel 不增加 overhead 的关键。
反向论证:为什么不直接预测一个"未来 W 个 token 的集合"(multi-label,如 Ahn et al. 2025)?因为集合丢了顺序信息,t+1 和 t+W 同等重要,模型没动力优先关心近的 token,容易和 NTP head 抢梯度方向。TOP 的 proximity score W−d 给近 token 更高 target 分,先验地把 ranking 稳在"近 → 远"。

2. 背景速查

记号 / 术语含义
T, D, V序列长 / hidden dim / vocab size
htL第 t 个位置最后一层 transformer 的 hidden(NTP / TOP head 都吃它)
UNTP, UTOP两个并列的 linear unembedding ℝD→ℝV
N (MTP / DS-MTP)预测的未来 token 数(含 next),典型 2、3、4
W (TOP)窗口大小,默认 = 序列长(实验里也试过 4 / 16 / 128 / 1024 / 4096)
proximity score y[t,v]= W − d,d 是 token v 在 t 之后首次出现的距离;d 不在 (0,W] 时为 −∞
ListNet top-one loss−softmax(y) · log softmax(ŷ),把两个排序分布之间用 cross-entropy 衡量
star graph G(d, l)从中心 0 出发分 d 条长度为 l 的链;给 edge list+起点+终点,要求输出从起点到终点的路径
self-speculative decoding用模型自己的"未来头"draft、再用同一个模型 verify,无需额外 draft model
回顾:NTP / MTP / DS-MTP 三个 loss 的精确写法
LNTP = −Σt log softmax(UNTP(htL))[xt+1]
LMTP = −Σt Σn=1..N log softmax(UNTP(Fn(htL−1)))[xt+n]
LDS-MTP: 各头串行,htL+n−1 = Fn([RMSN(htL+n−2); RMSN(E(xt+n−1))]),N=3 已是上限

关键差异:MTP 各头并行、共享 trunk 的倒数第二层输出;DS-MTP 各头串行、每头还吃真值 embedding。两者都需要额外的 transformer block。

3. 方法 · TOP target 是怎么造的

给定输入 token 序列 x(长度 T+W),TOP target y 是一个 (T, V) 的张量。Algorithm 1 的本质是逆序扫一遍序列,顺手记下每个 vocab token "下一次出现在哪个位置":

never gonna give you up never gonna let t=0 1 2 3 4 5 6 7 在 t=0 (never) 处看未来 W=4 个 token: gonna, give, you, up gonna: d=1, W−d=3 give: d=2, W−d=2 you: d=3, W−d=1 up: d=4, W−d=0 其他 v ∈ V: y[0,v] = −∞ → y[0] 在 vocab 维度上是: {gonna:3, give:2, you:1, up:0, 其余: −∞} softmax(y[0]) 把概率质量集中在这 4 个 token,且 gonna 概率最大
图 2. TOP target 构造示例。窗口 W=4 时,t=0 处看到的"未来 4 token"被赋分 3,2,1,0,其它 vocab 全 −∞,经 softmax 得到一个稀疏的 ranking 分布。

Algorithm 1 的"逆序扫一遍"为什么对

朴素实现会对每个 t 都跑一遍 W 长的 lookahead,O(TWV)。Algorithm 1 用一个下次出现位置数组 n[v]:

  1. 初始 y ← −∞,n[v] ← T+W(哨兵)
  2. 从 t = T+W−1 倒着到 0:
    • 先更新 n[x[t]] ← t(把当前 token 的"最近一次出现"标到 t)
    • 若 t < T,对每个 v 计算 d = n[v] − t,若 0 < d ≤ W 则 y[t,v] = W−d

这个写法的妙处是"未来"在反向扫的视角下变成了"过去"——因为我们倒着走,先看到的就是更远的时间点,n[v] 总是记着 v 在 t 之后最早出现的位置。论文原话是"in practice, we have an optimized Triton kernel that creates the target sequence on the fly during training and practically incurs no overhead"——所以这一步可以和数据 loader 融合,不预处理整个 dataset 也行。

注意:同一个 token 在窗口内多次出现时只算最近一次的距离,远处的重复出现被覆盖掉了。这是有意的——proximity 看的是"下一次什么时候来",不是"在窗口里出现几次"。

4. 方法 · 损失函数与架构

4.1 ListNet top-one loss

有了 target y,损失就是 ListNet 的标准形式:

LTOP = −Σt softmax(yt) · log softmax(UTOP htL)

等价于一个"软标签 cross-entropy":把 target ranking 通过 softmax 变成概率分布(注意 −∞ 的位置 softmax 后是 0),把 model output 也 softmax 化,然后两个分布做 CE。最终总损失是平等相加:

L = LNTP + LTOP

论文 §5.5.2 还做了 α 加权扫描,惊人地发现 α=0.9(即 TOP 占 9 成)在大部分 benchmark 上最好——意味着在他们的 setup 下,TOP 信号比 NTP 信号还要有用。这一点反过来支持"NTP 是 TOP 的特例 / 弱化"的直觉(NTP target 等价于一个只标 t+1 的 one-hot,TOP target 是把它扩展成多个 token 上的连续 ranking)。

4.2 架构图:四种 loss 的比较

NTP Trunk hL U_NTP L_NTP MTP (Meta) Trunk hL−1 F1 F2 F3 F4 +1 transformer block each U_NTP (shared) Σ_n CE(t+n) DS-MTP (DeepSeek) Trunk hL−1 F1 F2 F3 + ē_{t+1} + ē_{t+2} L_DS-MTP TOP (本文) Trunk hL U_NTP U_TOP L_NTP L_TOP 仅多一层线性层 inference 时丢掉 U_TOP, 退回普通 transformer
图 3. NTP / MTP / DS-MTP / TOP 四种 loss 的 head 结构对比。MTP 在 trunk 后并联 N−1 个 transformer block,DS-MTP 串联且每头吃下一个真值 embedding;TOP 只在最后一层 hidden 上并联一个线性 unembedding。

4.3 inference 行为

训练完后,UTOP 直接扔掉,模型推理时和普通 NTP transformer 完全等价(同样的 sampling、temperature、KV cache、speculative decoding 兼容性都不变)。这是 TOP 相对 MTP 的一个隐藏优势:不需要在部署时改 inference engine。论文也尝试用 TOP head 做 self-speculative decoding(把 TOP 排序前 k 个 token 拼起来当 draft),但效果不如 MTP / DS-MTP——因为 TOP head 学到的是 ranking,不是真正的"按位置预测"。

5. 复杂度对比 · 为什么"只加一层"是对的尺度

方法额外 FLOPs (per token)额外参数D=4096, V=128k, N=4 时
MTP(N−1)(24D² + 2DV)(N−1)(16D² + 2D)~3.4 GFLOPs / 805 M params
DS-MTP(N−1)(30D² + 2DV)(N−1)(16D² + 2D)~3.5 GFLOPs / 805 M params
TOP2DVDV~1.05 GFLOPs / 525 M params (但只算辅助层)

对于一个 7B 模型来说:

更重要的是训练时这层只需要做一次 matmul + softmax + 软交叉熵,作者改造了 Yang & Zhang 2024 的 fused linear cross-entropy Triton kernel,把 unembedding 和 loss 算成一块、分块流水,实测和不加 TOP 的训练速度几乎一致。MTP 那 N−1 个完整 transformer block 是无法这样吃掉的,因为每块要算 attention + MLP。

这就是"对的尺度"的含义:辅助 objective 的目的是给 trunk 加一点训练时的引导信号,而不是引入一整个并行的子网络。MTP 把"辅助"做成了"加一个 mini-model",成本被 trunk 的容量稀释,实验也确实需要 ≥1B 才看到稳定收益。TOP 把辅助回归到它该有的轻量级形态。

6. 具体例子 · TOP 和 MTP 在哪一步分道扬镳

取一段 token: "never gonna give you up never gonna let you down"(序列长 10,W=4),看 t=0 处 (token 是 never) 三种 loss 在同一个 hidden h0L 下要求模型做什么:

Loss这一步要求模型预测什么 / 怎么算 loss
NTP看到 never,预测下一个 token = gonna;target one-hot at "gonna"
MTP-44 个独立 head 分别预测 gonna(t+1)、give(t+2)、you(t+3)、up(t+4),各自 CE 求和;t+4 head 的 loss 比 t+1 head 高 2–3 倍(因为多元 ambiguity)
DS-MTP-3串行预测 t+1, t+2, t+3,t+2 那一头吃 RMSNorm(gonna) 的真实 embedding,所以"上下文越来越完整"——但反过来也意味着 t+3 那头其实只在做"看到 never+gonna+give 后预测 you",和 NTP 在 t+2 处做的事情高度重叠
TOPtarget softmax 后是 {gonna: e³/Z, give: e²/Z, you: e¹/Z, up: e⁰/Z},Z=e³+e²+e¹+e⁰≈30.19。要求 model logits 经 softmax 后整体形状逼近这个分布——不必精确到位置,但要把这 4 个 token 集中起来,且 gonna 占大头

关键差异:当 model 把概率质量分给 gonnagive 而非全压在 gonna 时,NTP 和 MTP 各 head 的 loss 都会上升,但 TOP 的 loss 反而是合理的——因为 TOP 的 target 本来就允许 give 拿到 e²/Z ≈ 24% 的概率。这正是"软监督"的体现:模型不会因为"略微关心了 t+2 token"而被惩罚,反而因此被奖励。

同一 hidden h₀ᴸ 下,model 输出 softmax 与 target 比较 TOP target gonna .66 give .24 you .09 up .03 理想 model softmax gonna give you up TOP loss 鼓励 model 把质量"分梯度"给这 4 个 token NTP / MTP head 1 只奖励 100% 押中 gonna,惩罚分给 give 的概率
图 4. 同一个 hidden 在 NTP / MTP 视角下"专一" vs 在 TOP 视角下"展开"。这种 target 形状的差异,是 TOP 让 representation 学得更全面的关键。

7. 实验关键结果

7.1 通用 NLP benchmark(9 个任务,3 种规模)

挑出最关键的几个数(Table 2,加粗了胜者):

SIZE方法LAMBADA acc↑HellaSwag↑ARC-C↑MMLU↑TriviaQA EM↑
340MNTP36.3542.5328.8429.814.93
MTP35.3242.7329.8629.082.55
DS-MTP34.6640.2927.5628.470.87
TOP37.0743.5729.3530.094.37
1.8BNTP49.5860.0538.6535.3411.85
MTP47.9358.2940.6134.7615.98
DS-MTP48.7157.4840.4435.0112.06
TOP50.3460.4542.3236.2118.93
7BNTP55.8967.4445.6539.4724.28
MTP53.1365.8545.5638.1423.36
DS-MTP55.6266.0344.3738.1618.54
TOP57.0368.7346.4239.6530.90

读法:

7.2 续训:Math / Code

1.8B 续训GSM8K↑MATH↑HumanEval@64↑MBPP@64↑
NTP39.2013.3433.5346.70
MTP38.5915.0034.7648.20
DS-MTP2.653.6623.7831.90
TOP45.6416.6638.4144.80

1.8B GSM8K 上 TOP 比 NTP 高 6.4 个点,比 MTP 高 7 个点——这是论文最有说服力的实验之一,因为 math reasoning 是公认需要 look-ahead 的任务,TOP 显著优于专为 look-ahead 设计的 MTP。DS-MTP 的崩溃(2.65)论文怀疑是 overfitting,需要更多分析。

7.3 self-speculative decoding 是 TOP 的弱项

Table 6:用 TOP head 做 self-speculative draft,7B 上每次 forward 平均 accept 1.52 个 token,而 MTP 是 2.49,DS-MTP 是 3.03。这不奇怪——TOP head 学的是 ranking,不是"在准确位置预测准确 token",当 draft sequence 必须严格按 t+1, t+2 …排序时,TOP 的 top-k 排序天生比 MTP 各位置 head 的精确预测要差。

7.4 消融:窗口大小 W 与 loss 权重 α

8. 星图任务 · 最干净的可分性证明

这是论文最让人 convinced 的实验。任务来自 Bachmann & Nagarajan 2024,设计目的就是把 NTP 的"teacher forcing 短视症"暴露出来:

4 1 9 2 7 3 5 训练样本(单行 token 序列): edge list = 2,7|4,1|4,3|1,9|3,5|4,2 /4,7 ↑ start, goal = 4,2,7 ← 模型要输出的 path 关键: 给定 "4,7"(start, goal), 下一 token 必须是 "2"——但只看前缀, "1"、"2"、"3" 等价 (都是 4 的邻居), NTP teacher-forcing 学不到要"先看 goal"。
图 5. 星图任务 G(3, 2):中心节点 4,3 条长 2 的分支。给定 edge list + 起点 + 终点,要求输出从起点到终点的路径。NTP 在第一步就要决定走哪条分支,但前缀里看不到 goal 7 的影响——必须真正"看到未来"才能学对。
模型paramsG(3,3)G(3,5)G(5,3)G(5,5)
NTP14.2M33.832.519.50.1
MTP-216.0M10059.019.60.1
MTP-419.6M10010010019.5
DS-MTP-216.6M10032.510019.2
DS-MTP-420.7M10033.610019.3
TOP14.2M100100100100

读这张表:

TOP 为什么对所有 G 都 100%?因为 TOP 的 target 在第一步就告诉它"未来 5 步内 goal 7 会出现",而 ranking 损失允许模型先粗略锁定"应该走过 7 的那一支",再让 NTP head 输出第一个 token 2——这本质上是把"先全局规划再生成"显式注入了训练信号。

9. 与同类工作对比

方法辅助信号粒度架构成本对小模型友好对长 look-ahead 友好inference 改动
NTPnone0
MTP (Meta)每 offset 精确分类N−1 transformer block(<1B 反而掉点)(t+8 比 t+4 差)可做 self-spec(2.5×)
DS-MTP串行精确分类 + 真值 embedding稍贵于 MTPself-spec 极强(3.0×)
ProphetNet (2020)未来 n-gram多头seq2seq only
Register tokens for MTP (Gerontopoulos 2025)把 NTP 转 MTP 的轻量改造+ register tokens
MLM / span (BERT/T5/UL2)双向重建encoder不能直接 AR 生成
TOP未来 W token 的 排序1 unembedding(340M 已胜)(W=4096 也 work)(扔掉 TOP head)

另一类相关工作是 TLM-style 的"辅助 prediction loss":比如做一个旁路头预测 sentence-level 属性、文档来源、bag-of-future-words 等。TOP 和它们的本质区别是 TOP 的 target 是从原序列直接计算的,无需额外标注或外部信号,且 ranking loss 的形式让它能和现有 fused CE kernel 几乎无缝结合。这种"零标注、零 inference 改动、零额外网络"的三零定位是它最大的工程吸引力。

10. 局限 / 个人 take / 待验证问题

待验证清单

  1. 把 TOP head 和 NTP head tied weights,会怎么样?如果 tied 后还能 work,说明 TOP 的全部增益来自 target 形状而非额外参数
  2. 在 SFT / RLHF 阶段加 TOP loss 是否还有用?(论文只做了 pretrain + 续训)
  3. 把 TOP 和 EAGLE 的 draft head 训练目标融合:draft head 也用 ranking loss,会不会同时改善 representation 和 acceptance length?
  4. 消融"target 中的 W−d vs binary mask"。如果直接把 target 改成 {appears: 1, not appear: 0} 的 multi-label,是不是还行?这能分离"排序"和"出现"两种信号
  5. 把 TOP loss 仅施加在最后 K 层 hidden 上(而不是只在最后一层)——会不会让中间层也学到 look-ahead?

11. Memory Points(冷回忆)

立场 MTP 失灵不是因为 look-ahead 错,而是因为"精确预测远 token"太难。松弛成排序,信号变软,gradient 永远有用。
尺度 辅助 objective 的正确尺度是多一层 unembedding,不是多一个 mini-model。MTP 把辅助做成了子网络,代价被 trunk 容量稀释,小模型反而吃亏。
数学 Loss = NTP CE + ListNet top-one CE,target = softmax(W − d_v),不在窗口内的 v 设 −∞。形式上 = 软标签 cross-entropy。
α 反直觉 α=0.9(TOP 占 9 成)在大多数 benchmark 上最优,意味着 TOP 信号本身就足以指导主任务,NTP 头主要是为采样保留概率分布。
星图 G(5,5):TOP 是唯一能 100% 解的方法(NTP 0.1%,MTP-4 19.5%,DS-MTP-4 19.3%),且参数最少(14.2M)。这是"排序信号 ≠ 多预测一个头"最干净的证明。
部署 TOP head 推理时直接扔掉,模型退回普通 transformer,inference engine 零改动。但 self-spec 弱(7B accept 1.52 vs MTP 2.49),想要加速还得叠 EAGLE-style 外部 draft。
隐患 70B+ 规模未验;DS-MTP 续训崩盘原因不明;target 的 −∞ 极度稀疏在 fp16 下的数值稳定性论文没细讲。