Multi-Token Prediction Needs Registers (MuToR)
速读卡片 (TL;DR)
一句话:把 MTP 的"加 n 个 output head"换成"在输入序列里穿插 learnable register token,每个负责预测一个未来 offset"——零架构改动、≈ 2K 额外参数、horizon 任意,SFT 上首次稳定打过 next-token baseline。
立场:这是把 ViT registers 的 idea 反向迁移到 LM 训练时辅助信号的论文。核心 contribution 不是新损失函数,而是用 attention mask 完美隔离register 的训练信号 — 推理时一字不变,SFT pipeline 全兼容。
1 · 动机:为什么 Gloeckle 的 MTP 在 SFT 上不灵
1.1 历史脉络:从 Prophet 到 Gloeckle 再到这里
多 token 预测(MTP)不是新事:
- 2020 ProphetNet (Qi et al.)首次把"预测未来 n-gram"作为 seq2seq 预训练目标。但其 multi-stream attention 随预测深度计算量爆炸,scale 不到 decoder-only 大模型。
- 2024 Gloeckle (Better & faster LLMs via MTP) 提出最简形式: 在主干 transformer 之上加 n 个并行 output head(每个一层 transformer + LM head),分别预测 t+1, t+2, ..., t+n。在 pretraining 阶段(代码任务尤其)显著加速、提质,被 DeepSeek-V3 采纳为 sequential head 变体。
- 但 SFT 场景下 Gloeckle 几乎没增益。Table 11 (论文复现): Gemma-2B 在 GSM8K SFT 上 Gloeckle MTP 40.66 vs 38.87,加倍训练后反而掉到 39.98。这是 MuToR 想攻击的痛点。
另一条平行线: 2024 年 ICLR Darcet et al. "Vision Transformers Need Registers"。他们观察到 ViT 的 attention map 在大图上出现莫名的 high-norm "artifact patches" — 模型把背景的某些 token 偷偷当成"草稿纸"做全局聚合。解法很简单: 在输入 patch 序列前面 prepend 几个 learnable register token,显式给模型一块"不被任何下游任务读取的草稿纸",artifact 立刻消失,attention map 变干净,DINOv2 性能上升。
MuToR 的智识来源就是把这两条线交起来:"register = 不污染下游的额外 attention 槽位"这个抽象,恰好是 SFT 阶段做 MTP 所需要的——既要给模型多 token 的监督信号,又不能动正常的 next-token 通路(否则 inference 完全坏了)。
1.2 别的方案为什么不够
| 方案 | 怎么注入未来 token 信号 | SFT 上的硬伤 |
|---|---|---|
| Gloeckle 并行 heads | n 个 transformer layer + n 个 LM head,共享 trunk | 每个 head ~110M 新参数 (Gemma-2B)从 0 训,SFT 数据量小学不动;且 base trunk 必须同时服务 n 个 head 的 loss → 表征压力大 |
| DeepSeek-V3 sequential heads | n 个 head 串成短链 (用 head_i 的 hidden 做 head_{i+1} 的输入) | 同样需要预训练阶段才能学到合理的 hidden,SFT 直接接的话信号噪 |
| ProphetNet multi-stream | 每个未来 offset 一个独立 attention stream | O(n × seq²) 计算,decoder-only 大模型扛不住 |
| Pause / Think tokens (Goyal 2024) | 在 prompt 后追加 dummy token,增加推理时计算 | 改变 inference 行为,违背"零开销 SFT"诉求 |
| PASS / lookahead tokens (Monea 2023) | append lookahead token,冻结 base 训它们做 spec sampling | 目标是加速 inference,不是给 base 增强训练信号 |
| MuToR (本文) 👈 | interleave register token 进 input,每个预测某 future offset | 2K 参数,base 完全可用 (frozen 兼容 LoRA),inference 拔掉 register 后字面上不变 |
1.3 为什么这事不平凡
表面看 MuToR 很 trivial: 不就是塞几个 token 进去给 loss 吗?三件事让它 non-obvious。
(a) 未来信息泄露问题。如果 register r₁(放在 x_t 后面,要预测 x_{t+1})被 normal token x_{t+1}, x_{t+2}, ... 看见,那就是数据泄露 — 训练时监督信号是污染的,推理时这部分信号又消失,SFT 完直接退化。所以必须 regular token 完全看不到 register。
(b) 同时 register 又得"看到正确的过去"。r_d 要预测 x_{t+d},它必须能看到 x_1, ..., x_t(不能多也不能少)。如果它能看到其他 register 怎么办?那就违背独立性,而且其他 register 的 hidden 是"未来 target"的预测,看了等于又泄露。所以 register 之间也要互相不可见。
(c) Position id 怎么设。RoPE 是通过相对位置索引算 attention 的;如果 r_d 和 x_t 共享同一个位置 t,模型就分不清"我现在是在做 next-token 还是 d-token 预测"。Gloeckle 用不同 head 来区分,而 MuToR 没有 head — 必须把"我要预测 d 步以后"这个信息从位置编码灌进去。论文的解法: 把 r_d 的 pos id 设成 t+d−1,正好等于"在标准 NTP 设置下,谁会预测 x_{t+d}"那个位置。这样 RoPE 自然把 r_d 的 query 投到一个"假装自己已经在未来 t+d−1 位置"的方向,attention 行为与 NTP 一致。
2 · 背景速查
2.1 关键术语
| 术语 | 含义 |
|---|---|
| NTP (next-token prediction) | 给定 x_{≤t} 预测 x_{t+1},decoder-only LM 的标准目标 |
| MTP (multi-token prediction) | 同时预测 x_{t+1:t+d_max},Gloeckle 2024 提出的训练增强 |
| Teacher forcing | 训练时用 ground-truth 历史 token 作为 context,允许 mask 化并行 |
| Register token | 本文定义: 插入在 input sequence 中的 learnable embedding,只用作辅助 loss 的"载体" |
| Offset d | register 要预测的未来步数,从 {1, ..., d_max} 均匀采样 |
| d_max | prediction horizon 的上界 (语言任务最优 d_max=4,2D 图像 d_max_2D=4) |
| a (loss weight) | L = (1−a) L_ntp + a L_reg,语言任务最佳 a∈[0.1, 0.5] |
| RoPE | Rotary Position Embedding,Su et al. 2024,相对位置 encoding |
| Causal mask | 下三角 attention mask,token i 只能看 j≤i |
| ViT register (Darcet 2024) | 在 ViT 输入前 prepend learnable token,作为"全局草稿纸"消除 artifact |
2.2 Gloeckle MTP 速回顾(便于对比)
预测 t+1, t+2, ..., t+n 用 n 个独立的"transformer layer + LM head" 接在 base 主干输出上,主干 hidden h_t 被 n 个 head 共享:
每个 head 独立产 logits,反向传梯度回主干。每个 head 是一整层 transformer (~110M for Gemma-2B 的 hidden size),这就是论文 Table 3 里 dmax=2 的 Gloeckle 要多 110M 参数、dmax=4 要 330M 的来源。
3 · 方法:Register 是怎么工作的
3.1 一句话定义
把 sequence 每两个相邻 regular token 之间塞一个 register token r_d。所有 r_d 共享同一个 learnable embedding(2K 参数 = hidden_dim × 1),区分"我要预测谁"的信息纯靠 position id。同一个 sequence 里 d 是固定的(从 {1,...,d_max} 均匀采样一个);不同 sequence 之间 d 不同。
3.2 序列布局示意
3.3 训练 loss
L_ntp 走 regular token 通路,完全不变;L_reg 是 register 位置上的预测 loss:
注意 cond 上是 x_{≤t} 而不是 x_{≤t+d−1} — 即便 r_d 的 position id 是 t+d−1,attention mask 也强制它只能 attend 到 x_{≤t}。位置在未来,看见的过去仍然停在 t。这是设计的精髓。
3.4 反向论证: 不这样设计会怎样?
| 偏离设计 | 会发生什么 |
|---|---|
| regular token 能看见 register | 训练时 x_{t+1} 看到 r_d (它已经在"猜" x_{t+d}),信息泄露 → 推理时 register 不在,分布漂移 → SFT 后性能崩 |
| register 之间能互相 attend | r_d@pos 4 attend 到 r_d@pos 3,后者承载着对 x_5 的预测信息 → 间接看到了未来 token 的猜测,信号被噪声污染 |
| register 用 position t (和它前面的 regular token 同位) | RoPE 给的 query 角度与 NTP 不一致,模型分不清"我现在到底要预测什么距离" |
| 每个 d 用独立 embedding | 论文 Table 5 实测: shared 比 different 略好 (42.10 vs 41.85)。因为 position id 已经蕴含 d 信息,再多一组 embedding 反而切碎数据 — 每个 embedding 只看到 1/d_max 比例的样本 |
4 · Attention mask:三类 token 的连接矩阵
这是全文最 load-bearing 的图。用 d=2 的 5-token 输入做例子: 序列 augment 后是 (x₁, r₂, x₂, r₂, x₃, r₂, x₄, r₂, x₅),我们关心一个 9×9 的 mask。
5 · Position embedding 的精巧设计
这一节是论文最微妙的地方。RoPE 通过 (q_pos − k_pos) 的相对位置算 attention,所以 query 和 key 的 position id 共同决定了"我在以什么角度看历史"。
5.1 设计选择
| Token | Position id | 解释 |
|---|---|---|
| x_t (regular) | t | 原序列里它本来的位置,完全不动 |
| r_d 插在 x_t 后,要预测 x_{t+d} | t + d − 1 | "如果按 NTP 规则,谁会预测 x_{t+d}?" 答: x_{t+d−1}。让 r_d 站在那个位置上,query 角度就和 NTP 一致 |
5.2 物理直觉:让 r_d 假装自己是"那个未来位置"
NTP 时,x_{t+d−1} 的 hidden 通过 query Q(pos = t+d−1) 与 keys K(pos = 1..t+d−1) 做 attention,然后预测 x_{t+d}。MuToR 时 r_d 也想预测 x_{t+d},于是论文让它在 RoPE 视角下与 x_{t+d−1} 长得像: 同一个 query position id。但 attention mask 限制了 key 视野只到 x_t — 位置在未来,看到的过去更短。
这听起来像作弊,但其实就是教模型: "假如我现在站在 t+d−1 这个未来位置,但只能看到 t 之前的信息,我能猜出 t+d 是什么吗?" 这正是希望模型学会的"前瞻 planning"能力。
5.3 worked numerics
设 hidden_dim=4096,RoPE 的 base θ=10000,某个 head 的频率 ω_k = θ^(−2k/d_head)。
- 普通 NTP: 给定 x₃ (pos=3) 预测 x₄,query 旋转角 = 3·ω_k
- MuToR d=2: r₂ 插在 x₂ 后预测 x₄,position 设 2+2−1 = 3,query 旋转角 = 3·ω_k 完全一致
- 但 r₂ 看到的 keys 是 x₁ (旋转 1·ω_k)、x₂ (旋转 2·ω_k) — 比 NTP 少了 x₃
所以 RoPE 几何上 r₂ 是"缺一只眼睛的 x₃"。模型必须学会用更短的历史做更远的预测,这恰好是 MTP 想要的"planning 信号"。
5.4 反向论证: position 设成 t (跟 x_t 同位)会怎样?
那 r_d 的 query 角度 = t·ω_k,与 NTP 里 x_t 的 query 一致 — 但 x_t 是预测 x_{t+1} 的!模型现在面对一个矛盾输入: "query 角度让我预测 x_{t+1},但 loss 让我预测 x_{t+d}",梯度方向冲突,训练不收敛。论文 Table 5 暗示了这一点 — 共享 embedding + 区别 position 是必须的组合。
6 · Worked example:三个 register 串起一帧
这是把 §3-§5 全部组合起来的具象案例。考虑一句中文 SFT 数据 "一加一等于二",分词后 x = (一, 加, 一, 等, 于, 二),T=6。SFT 时设 d_max=4,每个 sequence 采样 d∈{1,2,3,4}。本帧采到 d=3。
6.1 augment 后的输入
6.2 一个 register 的信息流
7 · 2D 扩展(图像生成)
论文展示了 MuToR 不止于语言:LlamaGen-B (111M) 在 ImageNet 256×256 上,把 d 推广到 2D。这是对 Gloeckle n-head 范式的暴击: Gloeckle 要预测 (d_h, d_w) 范围内的所有未来 token,需要 d_max_2D² − 1 个独立 head。MuToR 一个 register embedding 通吃。
7.1 2D offset 采样
w 是图像宽度(token grid 的宽)。每个 register 预测一个 2D 邻域中的 token,d_max_2D=4 时有 15 个候选 target。
7.2 数据点:dmax_2D 比 1D 强很多
| 方法 (200K iters) | FID ↓ | IS ↑ |
|---|---|---|
| Next-Token baseline | 6.83 | 158.4 |
| MuToR-1D (d=4) | 6.43 | 163.0 |
| MuToR-2D (d_max_2D=4, 15 targets) | 5.65 | 183.5 |
关键观察:"MuToR-2D 100K iter" 已经超越 "Next-Token 200K iter" — 即等计算下加速收敛 2×。这是 SFT 之外另一个稳定胜利的证据。同时 sparse register (256 → 80,只增 30% 序列长度)几乎不掉点 — 提示 register 摆放有"潜在的 budget-vs-quality 自由度",论文留为 future work。
8 · 实验关键结果
8.1 SFT 数学推理 (核心战场)
| Model | Method | GSM8K (1M-GSM) | MATH500 (1M-MATH) |
|---|---|---|---|
| Gemma 2B | Next-Token | 66.09 | 26.73 |
| Multi-Token (Gloeckle) | 66.69 (+0.6) | 26.87 (+0.1) | |
| MuToR | 68.33 (+2.2) | 28.13 (+1.4) | |
| Llama3 8B | Next-Token | 85.74 | 41.4 |
| Multi-Token | 85.67 (−0.1) | 42.6 (+1.2) | |
| MuToR | 87.03 (+1.3) | 43.2 (+1.8) |
读法:Gloeckle 在 8B + GSM8K 上甚至倒挂 next-token,MuToR 全表正向。注意 Llama3-8B GSM8K 已经 85+,在这个高位再吃 1.3 pt 是不容易的。
8.2 增加 d_max,Gloeckle 越涨越拖后腿
| d_max | Gloeckle 新参数 | Gloeckle GSM8K (1M-GSM) | MuToR 新参数 | MuToR GSM8K (1M-GSM) |
|---|---|---|---|---|
| 2 | 110M | 66.69 | 2K | 67.15 |
| 3 | 220M | 66.36 | 2K | 68.01 |
| 4 | 330M | 65.53 | 2K | 68.33 |
| 6 | 550M | 65.55 | 2K | 68.16 |
Gloeckle 在 d_max=4/6 时连 next-token (66.09) 都打不过 — 加的 head 学不动。MuToR 一直稳稳挂在 +2pt。
8.3 LoRA + MuToR (PEFT 兼容)
| 方法 | GSM8K | 1M-GSM |
|---|---|---|
| Full Next-Token | 38.87 | 66.09 |
| LoRA Next-Token | 36.34 | 66.11 |
| LoRA + MuToR | 38.59 (≈ full) | 68.11 (> full) |
PEFT 上 MuToR 让 LoRA 追平甚至反超 full-finetune。Gloeckle 的"加一整层 transformer"在 LoRA 下根本没法用 — head 是 frozen base 之外新加的,要从 0 学。
8.4 Star-Graph 路径任务
Bachmann & Nagarajan 2024 的合成任务,标准 teacher forcing 由于 shortcut learning 完全失败 (solve rate ≈ 0%)。MuToR 直接 solve,提供了 MTP 信号能"破除 shortcut"的最 clean 证据。
9 · 与同类工作对比
| 工作 | 核心机制 | 额外参数 | Inference 影响 | SFT 友好? |
|---|---|---|---|---|
| Gloeckle 2024 (MTP heads) | n 个并行 transformer head | n × layer (~110M each) | 不变 (扔掉 heads) | 弱 (head 从 0 学,数据少不动) |
| DeepSeek-V3 sequential heads | n 个 head 串成短链 | n × layer | 不变 | 同上 |
| L-MTP (leap MTP, 2024) | 跳跃式 MTP,head 预测 t+k 而非连续 | n × layer | 不变 | 同上 |
| FSP (future supervised prediction) | 用 future hidden 做 contrastive 监督 | contrastive head | 不变 | 中等 |
| ProphetNet (2020) | Multi-stream attention | O(n) 计算 | 不变 | scale 差 |
| Pause/Think tokens (Goyal 2024) | append dummy → 增加推理 compute | ≈0 | 改变 inference | 中等 |
| PASS (Monea 2023) | append lookahead, frozen base, 训自己 | ≈0 | 改 spec sampling | 不适用 |
| ViT registers (Darcet 2024) | prepend register,消除 attention artifact | k × hidden | 不变 (扔掉) | 不同领域,但智识来源 |
| MuToR (本文) | interleave register,共享 embedding,position id 编码 d | 2K (一个 embedding) | 完全不变 | 优秀 |
10 · 局限 / 个人 take / 待验证问题
论文承认的局限
- Register 的位置是均匀插或随机插,没有 task-aware 的 placement 策略 — 留给 future work
- 同一个 sequence 里 d 是固定的(单个 d),没有混合 d 的尝试 — 简化处理
- Pretraining 上只在图像上验证,语言 pretraining 的对比实验缺席(只做了 SFT)
- Inference 完全不用 register — 没有像 EAGLE/Medusa 那样把 register 复用到 spec decoding 的尝试
个人 take
- 这是"register 抽象"在 LM 训练侧第一次落地。更深远的意义可能是: 任何"我想要给模型加辅助监督但不想动 inference"的 case,register 都是合适的载体。比如 process reward 的过程监督、retrieval 信号、structure prediction……
- 论文的对比偏弱:没有跟 DeepSeek-V3 sequential head 在 SFT 上对比(因为 V3 是 MoE 难复现,但应该提)。
- 2D 扩展非常 underrated。视觉 AR 模型 (LlamaGen, MAR) 几乎没人用 MTP — Gloeckle 在 2D 上扩展需要 d² 个 head,prohibitive。MuToR 给了一个简单可行的入口。
- 与 Gloeckle 在 pretraining 上的优势会不会消失? 值得跑一次 1B 级别的 from-scratch 对比。
待验证问题
- 序列长度倍增 (d=4 时输入变 ~2×) → FlashAttention / memory footprint 翻倍。在 32k+ long-CoT SFT 上是否仍可行?
- register 之间互不可见的设计是否过严? 如果允许 r 看更早的 r (类似 MoE 的 expert 间通信),信号会不会更丰富?
- position id = t+d−1 的设定在extrapolation(序列长度超出训练)上行为如何? RoPE 已知有外推问题,加 register 会不会进一步打破?
- 能否把 register loss 改成 distillation (从 teacher 模型蒸馏未来 token 分布) → 类似 EAGLE-3 的训练目标但作为辅助而非主任务?
- Loss weight a 在不同 d_max 下的最优值规律是什么? 论文 Table 9 显示 a∈[0.1, 0.5] 跳变,没有 trend 解释。
- 训练时 d 的分布(目前 uniform)能否换成 task-aware (e.g. 数学推理偏长 d, 短文本偏短 d)? 论文没探。
记忆点
机制 interleave learnable token,共享 embedding (2K 参数),靠 position id = t+d−1 区分目标 offset
三铁律 x 不见 r;r 不见 r;r 看 x_{≤t} (强因果到 t 而非 t+d−1)
对比 Gloeckle 加 110M+ 参数 SFT 学不动;MuToR 加 2K 稳定胜
扩展 2D 版本一个 embedding 通吃 d² 个邻域 target,prior 方法需 d² 个 head
陷阱 position id 设错 (用 t 而非 t+d−1) 会让 query 角度与 NTP 矛盾,训练发散
见也 see 14_MTP_Gloeckle_2404.19737 (foundational MTP); ViT registers (Darcet 2024) 智识来源
精读笔记 v1 · 2026-05-07 · 配套论文 PDF 在 /data/szhang967/papers/paper-notes/models/MuToR_2505.10518.pdf