Multi-Token Prediction Needs Registers (MuToR)

Anastasios Gerontopoulos · Spyros Gidaris · Nikos Komodakis · 2025-05-15 · arXiv:2505.10518
关键词: multi-token prediction · register tokens · auxiliary supervision · SFT · PEFT · autoregressive image generation

速读卡片 (TL;DR)

一句话:把 MTP 的"加 n 个 output head"换成"在输入序列里穿插 learnable register token,每个负责预测一个未来 offset"——零架构改动、≈ 2K 额外参数、horizon 任意,SFT 上首次稳定打过 next-token baseline。

≈ 2K
新增可训练参数 (vs Gloeckle 110M–550M)
+1.4 ~ +3 pt
数学/总结 SFT 提升 (Gemma-2B / Llama3-8B)
d ≤ 6
prediction horizon, register 数与 d 无关

立场:这是把 ViT registers 的 idea 反向迁移到 LM 训练时辅助信号的论文。核心 contribution 不是新损失函数,而是用 attention mask 完美隔离register 的训练信号 — 推理时一字不变,SFT pipeline 全兼容。


1 · 动机:为什么 Gloeckle 的 MTP 在 SFT 上不灵

1.1 历史脉络:从 Prophet 到 Gloeckle 再到这里

多 token 预测(MTP)不是新事:

另一条平行线: 2024 年 ICLR Darcet et al. "Vision Transformers Need Registers"。他们观察到 ViT 的 attention map 在大图上出现莫名的 high-norm "artifact patches" — 模型把背景的某些 token 偷偷当成"草稿纸"做全局聚合。解法很简单: 在输入 patch 序列前面 prepend 几个 learnable register token,显式给模型一块"不被任何下游任务读取的草稿纸",artifact 立刻消失,attention map 变干净,DINOv2 性能上升。

MuToR 的智识来源就是把这两条线交起来:"register = 不污染下游的额外 attention 槽位"这个抽象,恰好是 SFT 阶段做 MTP 所需要的——既要给模型多 token 的监督信号,又不能动正常的 next-token 通路(否则 inference 完全坏了)。

1.2 别的方案为什么不够

方案怎么注入未来 token 信号SFT 上的硬伤
Gloeckle 并行 headsn 个 transformer layer + n 个 LM head,共享 trunk每个 head ~110M 新参数 (Gemma-2B)从 0 训,SFT 数据量小学不动;且 base trunk 必须同时服务 n 个 head 的 loss → 表征压力大
DeepSeek-V3 sequential headsn 个 head 串成短链 (用 head_i 的 hidden 做 head_{i+1} 的输入)同样需要预训练阶段才能学到合理的 hidden,SFT 直接接的话信号噪
ProphetNet multi-stream每个未来 offset 一个独立 attention streamO(n × seq²) 计算,decoder-only 大模型扛不住
Pause / Think tokens (Goyal 2024)在 prompt 后追加 dummy token,增加推理时计算改变 inference 行为,违背"零开销 SFT"诉求
PASS / lookahead tokens (Monea 2023)append lookahead token,冻结 base 训它们做 spec sampling目标是加速 inference,不是给 base 增强训练信号
MuToR (本文) 👈interleave register token 进 input,每个预测某 future offset2K 参数,base 完全可用 (frozen 兼容 LoRA),inference 拔掉 register 后字面上不变

1.3 为什么这事不平凡

表面看 MuToR 很 trivial: 不就是塞几个 token 进去给 loss 吗?三件事让它 non-obvious。

(a) 未来信息泄露问题。如果 register r₁(放在 x_t 后面,要预测 x_{t+1})被 normal token x_{t+1}, x_{t+2}, ... 看见,那就是数据泄露 — 训练时监督信号是污染的,推理时这部分信号又消失,SFT 完直接退化。所以必须 regular token 完全看不到 register

(b) 同时 register 又得"看到正确的过去"。r_d 要预测 x_{t+d},它必须能看到 x_1, ..., x_t(不能多也不能少)。如果它能看到其他 register 怎么办?那就违背独立性,而且其他 register 的 hidden 是"未来 target"的预测,看了等于又泄露。所以 register 之间也要互相不可见

(c) Position id 怎么设。RoPE 是通过相对位置索引算 attention 的;如果 r_d 和 x_t 共享同一个位置 t,模型就分不清"我现在是在做 next-token 还是 d-token 预测"。Gloeckle 用不同 head 来区分,而 MuToR 没有 head — 必须把"我要预测 d 步以后"这个信息从位置编码灌进去。论文的解法: 把 r_d 的 pos id 设成 t+d−1,正好等于"在标准 NTP 设置下,谁会预测 x_{t+d}"那个位置。这样 RoPE 自然把 r_d 的 query 投到一个"假装自己已经在未来 t+d−1 位置"的方向,attention 行为与 NTP 一致。

核心洞察:register 的本质是"future-target attention slot" — 给模型提供一个独立的、不污染主干信息流的容器,把多 token 监督的梯度从这里反传回主干。它和 ViT registers (草稿纸) 的功能一致: 提供 un-tokenized 的 attention 槽位,不污染 normal flow

2 · 背景速查

2.1 关键术语

术语含义
NTP (next-token prediction)给定 x_{≤t} 预测 x_{t+1},decoder-only LM 的标准目标
MTP (multi-token prediction)同时预测 x_{t+1:t+d_max},Gloeckle 2024 提出的训练增强
Teacher forcing训练时用 ground-truth 历史 token 作为 context,允许 mask 化并行
Register token本文定义: 插入在 input sequence 中的 learnable embedding,只用作辅助 loss 的"载体"
Offset dregister 要预测的未来步数,从 {1, ..., d_max} 均匀采样
d_maxprediction horizon 的上界 (语言任务最优 d_max=4,2D 图像 d_max_2D=4)
a (loss weight)L = (1−a) L_ntp + a L_reg,语言任务最佳 a∈[0.1, 0.5]
RoPERotary Position Embedding,Su et al. 2024,相对位置 encoding
Causal mask下三角 attention mask,token i 只能看 j≤i
ViT register (Darcet 2024)在 ViT 输入前 prepend learnable token,作为"全局草稿纸"消除 artifact

2.2 Gloeckle MTP 速回顾(便于对比)

预测 t+1, t+2, ..., t+n 用 n 个独立的"transformer layer + LM head" 接在 base 主干输出上,主干 hidden h_t 被 n 个 head 共享:

L_mtp = ED [ −Σ_t Σ_{i=1}^{n} log P_θ(x_{t+i} | x_{≤t}) ]

每个 head 独立产 logits,反向传梯度回主干。每个 head 是一整层 transformer (~110M for Gemma-2B 的 hidden size),这就是论文 Table 3 里 dmax=2 的 Gloeckle 要多 110M 参数、dmax=4 要 330M 的来源。


3 · 方法:Register 是怎么工作的

3.1 一句话定义

x' = ( x₁ , r_d , x₂ , r_d , ... , x_{T−1} , r_d , x_T )

把 sequence 每两个相邻 regular token 之间塞一个 register token r_d。所有 r_d 共享同一个 learnable embedding(2K 参数 = hidden_dim × 1),区分"我要预测谁"的信息纯靠 position id。同一个 sequence 里 d 是固定的(从 {1,...,d_max} 均匀采样一个);不同 sequence 之间 d 不同。

3.2 序列布局示意

x₁ pos 1 x₂ pos 2 x₃ pos 3 x₄ pos 4 x₅ pos 5 x₆ pos 6 r₃ →x₄ pos 1+3−1=3 r₃ →x₅ pos 2+3−1=4 r₃ →x₆ pos 3+3−1=5 r₃ →x₇ pos 4+3−1=6 r₃ →x₈ pos 5+3−1=7 regular token (NTP) register r_d (predicts x_{t+d}) d=3 时, sequence 实际长这样 (interleave 后): x₁, r₃, x₂, r₃, x₃, r₃, x₄, r₃, x₅, r₃, x₆ — 共 11 个 token,所有 r₃ 共享 embedding,只是 position id 不同
关键点: 所有 r₃ 共用一个 learnable embedding,但 position id 各不相同 (3, 4, 5, 6, 7)。RoPE 通过这个 position id 让每个 r₃ 的 query 对齐到"假装自己在 t+d−1 位置发起 next-token 查询"的方向 — 完全模仿 NTP。

3.3 训练 loss

L = (1−a) · L_ntp + a · L_reg

L_ntp 走 regular token 通路,完全不变;L_reg 是 register 位置上的预测 loss:

L_reg = ED [ −Σ_t log P_θ(x_{t+d} | x_{≤t}, r_d) ]

注意 cond 上是 x_{≤t} 而不是 x_{≤t+d−1} — 即便 r_d 的 position id 是 t+d−1,attention mask 也强制它只能 attend 到 x_{≤t}。位置在未来,看见的过去仍然停在 t。这是设计的精髓。

3.4 反向论证: 不这样设计会怎样?

偏离设计会发生什么
regular token 能看见 register训练时 x_{t+1} 看到 r_d (它已经在"猜" x_{t+d}),信息泄露 → 推理时 register 不在,分布漂移 → SFT 后性能崩
register 之间能互相 attendr_d@pos 4 attend 到 r_d@pos 3,后者承载着对 x_5 的预测信息 → 间接看到了未来 token 的猜测,信号被噪声污染
register 用 position t (和它前面的 regular token 同位)RoPE 给的 query 角度与 NTP 不一致,模型分不清"我现在到底要预测什么距离"
每个 d 用独立 embedding论文 Table 5 实测: shared 比 different 略好 (42.10 vs 41.85)。因为 position id 已经蕴含 d 信息,再多一组 embedding 反而切碎数据 — 每个 embedding 只看到 1/d_max 比例的样本

4 · Attention mask:三类 token 的连接矩阵

这是全文最 load-bearing 的图。用 d=2 的 5-token 输入做例子: 序列 augment 后是 (x₁, r₂, x₂, r₂, x₃, r₂, x₄, r₂, x₅),我们关心一个 9×9 的 mask。

MuToR Attention Mask (深 = 可 attend, 浅 = 不可) x₁ r₂ x₂ r₂ x₃ r₂ x₄ r₂ x₅ x₁ (q) r₂ (q) x₂ (q) r₂ (q) x₃ (q) r₂ (q) x₄ (q) r₂ (q) x₅ (q) x → x (NTP 通路, 标准 causal) r → x (register 看过去 regular) 禁止 三条铁律: ① x 行的 r 列 = 全空 → 推理时拔掉 register, x 行为不变 ② r 行的 r 列 = 全空 → register 互不可见 ③ r 行的 x 列 = 严格因果到 t → r_d (插在 x_t 后) 看 x_{≤t}, 不看 x_{t+1..t+d-1} 蓝色子矩阵恰好是原始 NTP 的 5×5 下三角 mask — 拔掉 r 行 r 列后字面相等 这就是 "inference time identical to NTP" 的来源
整个矩阵可以拆成两个独立信道: 蓝色子矩阵(x↔x)是未触动的 NTP 计算图;红色子矩阵(r→x)是辅助信号通路。两者通过 attention 在同一次 forward里并行算完,但梯度永远只能从 r 单向流向 x — 这是设计的核心几何结构。

5 · Position embedding 的精巧设计

这一节是论文最微妙的地方。RoPE 通过 (q_pos − k_pos) 的相对位置算 attention,所以 query 和 key 的 position id 共同决定了"我在以什么角度看历史"。

5.1 设计选择

TokenPosition id解释
x_t (regular)t原序列里它本来的位置,完全不动
r_d 插在 x_t 后,要预测 x_{t+d}t + d − 1"如果按 NTP 规则,谁会预测 x_{t+d}?" 答: x_{t+d−1}。让 r_d 站在那个位置上,query 角度就和 NTP 一致

5.2 物理直觉:让 r_d 假装自己是"那个未来位置"

NTP 时,x_{t+d−1} 的 hidden 通过 query Q(pos = t+d−1) 与 keys K(pos = 1..t+d−1) 做 attention,然后预测 x_{t+d}。MuToR 时 r_d 也想预测 x_{t+d},于是论文让它在 RoPE 视角下与 x_{t+d−1} 长得像: 同一个 query position id。但 attention mask 限制了 key 视野只到 x_t — 位置在未来,看到的过去更短

这听起来像作弊,但其实就是教模型: "假如我现在站在 t+d−1 这个未来位置,但只能看到 t 之前的信息,我能猜出 t+d 是什么吗?" 这正是希望模型学会的"前瞻 planning"能力。

5.3 worked numerics

设 hidden_dim=4096,RoPE 的 base θ=10000,某个 head 的频率 ω_k = θ^(−2k/d_head)。

所以 RoPE 几何上 r₂ 是"缺一只眼睛的 x₃"。模型必须学会用更短的历史做更远的预测,这恰好是 MTP 想要的"planning 信号"。

5.4 反向论证: position 设成 t (跟 x_t 同位)会怎样?

那 r_d 的 query 角度 = t·ω_k,与 NTP 里 x_t 的 query 一致 — 但 x_t 是预测 x_{t+1} 的!模型现在面对一个矛盾输入: "query 角度让我预测 x_{t+1},但 loss 让我预测 x_{t+d}",梯度方向冲突,训练不收敛。论文 Table 5 暗示了这一点 — 共享 embedding + 区别 position 是必须的组合。


6 · Worked example:三个 register 串起一帧

这是把 §3-§5 全部组合起来的具象案例。考虑一句中文 SFT 数据 "一加一等于二",分词后 x = (一, 加, 一, 等, 于, 二),T=6。SFT 时设 d_max=4,每个 sequence 采样 d∈{1,2,3,4}。本帧采到 d=3。

6.1 augment 后的输入

d=3 case: 在每个 x_t (t≤T−d) 后插一个 r₃,目标 x_{t+3} pos=1 → 加 pos=2 → 一 pos=3 → 等 pos=4 → 于 pos=5 pos=6 r₃ pos=3 → 等 (t+3=4) r₃ pos=4 → 于 (t+3=5) r₃ pos=5 → 二 (t+3=6) 这一帧产生的总 loss: L_ntp = −[log p(加|一) + log p(一|一加) + log p(等|一加一) + log p(于|一加一等) + log p(二|一加一等于)] L_reg = −[log p(等 | 一, r₃@pos3) + log p(于 | 一加, r₃@pos4) + log p(二 | 一加一, r₃@pos5)]
注意: r₃@pos3 通过 attention mask 只看见 "一",但 RoPE 让它的 query 角度 = 标准 NTP 中 pos=3 的 query 角度;它要从极少的 context 跳过两个 token 直接预测"等"。这是相当难的任务 — 但梯度反传会逼 base trunk 把 hidden(一) 编码得比 NTP 时更"信息密集",迫使表征前瞻。这就是 MTP 信号的本质。

6.2 一个 register 的信息流

Layer L (输出层) Layer 0 (embedding) e(一) e_reg e(加) e_reg e(一) e_reg pos=1 pos=3 pos=2 pos=4 pos=3 pos=5 query × × h_r shared LM head target = "二" r₃@pos5 想预测 "二": query 角度 = 5·ω, 但只能 attend 到 (一@1, 加@2, 一@3) 三个 key. 难!但梯度反传逼迫 base 让前 3 个 token 的 hidden 蕴含"5 步后是结尾"的语义.
追踪一个 register 的信息流。注意 LM head 是共享的(NTP 和 register 共用同一个 vocab projection),所以 register 不引入额外的 output head 参数 — 这是为什么"only 2K extra params"成立的关键。

7 · 2D 扩展(图像生成)

论文展示了 MuToR 不止于语言:LlamaGen-B (111M) 在 ImageNet 256×256 上,把 d 推广到 2D。这是对 Gloeckle n-head 范式的暴击: Gloeckle 要预测 (d_h, d_w) 范围内的所有未来 token,需要 d_max_2D² − 1 个独立 head。MuToR 一个 register embedding 通吃。

7.1 2D offset 采样

(d_h, d_w) ~ Uniform({1,...,d_max_2D}² \ {(1,1)}), d = (d_h − 1)·w + d_w − 1

w 是图像宽度(token grid 的宽)。每个 register 预测一个 2D 邻域中的 token,d_max_2D=4 时有 15 个候选 target。

7.2 数据点:dmax_2D 比 1D 强很多

方法 (200K iters)FID ↓IS ↑
Next-Token baseline6.83158.4
MuToR-1D (d=4)6.43163.0
MuToR-2D (d_max_2D=4, 15 targets)5.65183.5

关键观察:"MuToR-2D 100K iter" 已经超越 "Next-Token 200K iter" — 即等计算下加速收敛 2×。这是 SFT 之外另一个稳定胜利的证据。同时 sparse register (256 → 80,只增 30% 序列长度)几乎不掉点 — 提示 register 摆放有"潜在的 budget-vs-quality 自由度",论文留为 future work。


8 · 实验关键结果

8.1 SFT 数学推理 (核心战场)

ModelMethodGSM8K (1M-GSM)MATH500 (1M-MATH)
Gemma 2BNext-Token66.0926.73
Multi-Token (Gloeckle)66.69 (+0.6)26.87 (+0.1)
MuToR68.33 (+2.2)28.13 (+1.4)
Llama3 8BNext-Token85.7441.4
Multi-Token85.67 (−0.1)42.6 (+1.2)
MuToR87.03 (+1.3)43.2 (+1.8)

读法:Gloeckle 在 8B + GSM8K 上甚至倒挂 next-token,MuToR 全表正向。注意 Llama3-8B GSM8K 已经 85+,在这个高位再吃 1.3 pt 是不容易的

8.2 增加 d_max,Gloeckle 越涨越拖后腿

d_maxGloeckle 新参数Gloeckle GSM8K (1M-GSM)MuToR 新参数MuToR GSM8K (1M-GSM)
2110M66.692K67.15
3220M66.362K68.01
4330M65.532K68.33
6550M65.552K68.16

Gloeckle 在 d_max=4/6 时连 next-token (66.09) 都打不过 — 加的 head 学不动。MuToR 一直稳稳挂在 +2pt。

8.3 LoRA + MuToR (PEFT 兼容)

方法GSM8K1M-GSM
Full Next-Token38.8766.09
LoRA Next-Token36.3466.11
LoRA + MuToR38.59 (≈ full)68.11 (> full)

PEFT 上 MuToR 让 LoRA 追平甚至反超 full-finetune。Gloeckle 的"加一整层 transformer"在 LoRA 下根本没法用 — head 是 frozen base 之外新加的,要从 0 学。

8.4 Star-Graph 路径任务

Bachmann & Nagarajan 2024 的合成任务,标准 teacher forcing 由于 shortcut learning 完全失败 (solve rate ≈ 0%)。MuToR 直接 solve,提供了 MTP 信号能"破除 shortcut"的最 clean 证据。


9 · 与同类工作对比

工作核心机制额外参数Inference 影响SFT 友好?
Gloeckle 2024 (MTP heads)n 个并行 transformer headn × layer (~110M each)不变 (扔掉 heads)弱 (head 从 0 学,数据少不动)
DeepSeek-V3 sequential headsn 个 head 串成短链n × layer不变同上
L-MTP (leap MTP, 2024)跳跃式 MTP,head 预测 t+k 而非连续n × layer不变同上
FSP (future supervised prediction)用 future hidden 做 contrastive 监督contrastive head不变中等
ProphetNet (2020)Multi-stream attentionO(n) 计算不变scale 差
Pause/Think tokens (Goyal 2024)append dummy → 增加推理 compute≈0改变 inference中等
PASS (Monea 2023)append lookahead, frozen base, 训自己≈0改 spec sampling不适用
ViT registers (Darcet 2024)prepend register,消除 attention artifactk × hidden不变 (扔掉)不同领域,但智识来源
MuToR (本文)interleave register,共享 embedding,position id 编码 d2K (一个 embedding)完全不变优秀
关键差别:MuToR 与 ViT registers 的形态最像(都是"插额外 token + attention mask 隔离"),但语义不同 — ViT register 是"聚合工具"(让背景信息有处可去),MuToR register 是"前瞻容器"(把未来监督信号作为辅助 loss 灌进去)。两者共享的抽象是: 给 transformer 提供一类 attention 槽位,既参与计算又不污染主流输出。这就是为什么作者借用了"register"这个词。

10 · 局限 / 个人 take / 待验证问题

论文承认的局限

个人 take

待验证问题

  1. 序列长度倍增 (d=4 时输入变 ~2×) → FlashAttention / memory footprint 翻倍。在 32k+ long-CoT SFT 上是否仍可行?
  2. register 之间互不可见的设计是否过严? 如果允许 r 看更早的 r (类似 MoE 的 expert 间通信),信号会不会更丰富?
  3. position id = t+d−1 的设定在extrapolation(序列长度超出训练)上行为如何? RoPE 已知有外推问题,加 register 会不会进一步打破?
  4. 能否把 register loss 改成 distillation (从 teacher 模型蒸馏未来 token 分布) → 类似 EAGLE-3 的训练目标但作为辅助而非主任务?
  5. Loss weight a 在不同 d_max 下的最优值规律是什么? 论文 Table 9 显示 a∈[0.1, 0.5] 跳变,没有 trend 解释。
  6. 训练时 d 的分布(目前 uniform)能否换成 task-aware (e.g. 数学推理偏长 d, 短文本偏短 d)? 论文没探。

记忆点

立场 Register = 不污染 inference 的辅助 loss 容器,把 ViT registers 的抽象搬到 LM 训练侧
机制 interleave learnable token,共享 embedding (2K 参数),靠 position id = t+d−1 区分目标 offset
三铁律 x 不见 r;r 不见 r;r 看 x_{≤t} (强因果到 t 而非 t+d−1)
对比 Gloeckle 加 110M+ 参数 SFT 学不动;MuToR 加 2K 稳定胜
扩展 2D 版本一个 embedding 通吃 d² 个邻域 target,prior 方法需 d² 个 head
陷阱 position id 设错 (用 t 而非 t+d−1) 会让 query 角度与 NTP 矛盾,训练发散
见也 see 14_MTP_Gloeckle_2404.19737 (foundational MTP); ViT registers (Darcet 2024) 智识来源

精读笔记 v1 · 2026-05-07 · 配套论文 PDF 在 /data/szhang967/papers/paper-notes/models/MuToR_2505.10518.pdf