Teaching Pretrained LMs to Think Deeper with Retrofitted Recurrence
速读卡片 (TL;DR)
一句话:不要再从零训练 depth-recurrent transformer(Geiping 等花了 800B tokens),直接把已有的 TinyLlama / OLMo-2 / Llama-3.2 切开,挑早期层做 prelude、晚期层做 recurrent block + coda,中间层丢掉,再加 input injection 和小规模 continued pretraining,就能把一个 1B 固定深度模型改造成一个"用 32 次循环替代深度"的 latent reasoning 模型,GSM8K / MATH 在更少参数 + 同样 FLOPs 下击败原模型。
+10% GSM8K
立场:这是 latent reasoning 路线最实用的一篇 — 不要求重训,等于把"recurrent depth"做成了一个 post-training stage,跟扩 context length 同等地位。配方上有三个关键发现:pretrained init 远胜 random、curriculum 调度循环次数、Muon 比 AdamW 更稳;再加一个"healing 阶段"修复切层带来的 distribution shift。
1 · 动机:为什么 latent recurrence 重要,以及为什么从零训太贵
1.1 历史脉络:test-time compute scaling 的两条路
2024 下半年以来,大家公认要把"算力花在推理而不是参数"。这条路上有两个主流派系:
- Token-level scaling:让模型在 output space 写更长的 chain-of-thought / 多 candidate + best-of-N。代表是 OpenAI o1、DeepSeek-R1。优点是直接、可解释、能 RL;缺点是 每多想一步就吃一段 context、占 KV cache,而且"想"必须以 token 为载体,带宽极低(每 token 只能携带 ~16 bits)。
- Depth-level / latent scaling:让模型在 hidden state 里"绕圈",同一段权重反复跑 r 次,把更多 FLOPs 花在不增长序列的潜在思考上。代表是 Universal Transformers、Schwarzschild 的 deep thinking、HRM,以及最关键的 Geiping et al. 2025 (Huginn-0125)。优点是 KV cache 不长、信息带宽是整个 hidden state(d_model × n_token bits 量级);缺点是训练贵 —— 一次 r=32 的前向相当于 32 倍深度的网络。
Huginn-0125 证明了 latent recurrence 路线本身可行:他们从零预训练一个 3.5B 的 depth-recurrent 模型,跑 800B tokens,在 ARC / MMLU / GSM8K 上确实能越循环越准。但这篇 paper 的核心追问是:
我们已经有了 Llama-3.2-1B、OLMo-2-1B、TinyLlama-1.1B 这些跑了 3T~9T tokens 的开源模型 —— 能不能不要再从零训,直接把它们改造成 recurrent 的?
1.2 别的方案为什么不够
这个问题其实之前有人做过,但每条路都有缺口:
| 已有方案 | 做法 | 缺口 |
|---|---|---|
| Huginn (Geiping 2025) | 从零训 3.5B,800B tokens,平均 r=32 | 太贵;参数不能复用现有 base model |
| Bae 2024 (Relaxed Recursive) | 把 pretrained 模型循环 2~3 次,加 LoRA 修复 | r 增大反而掉点,没有 test-time scaling |
| Koishekenov 2025 (encode-think-decode) | OLMo → P/R/C 结构,但保留全部层,固定 r | 没 input injection;没用 curriculum;不报 FLOPs 难对比 |
| Li 2025 (cyclic refinement) | 循环 GPT-2/OPT 模型 | 只在 multiple choice 上有微弱增益 |
所以本文的位置是:把 Huginn 的"会 scale 的 latent reasoning"嫁接到现有 pretrained 模型上,且证明嫁接成本远小于从零训。这件事不平凡,原因有三:
1.3 为什么这事不平凡
- 参数既加又减:从 pretrained 模型出发要做"加" — 新增 input-injection 的 linear adapter (2h → h),还要适配 Geiping 的随机噪声 s₀;同时要做"减" — 丢中间几层(否则 r=32 时 FLOPs 爆炸)。怎么保证知识依然能 transfer 过来? 这没先验保证。
- "循环这件事"原模型从没见过:pretrained transformer 的每一层是独立功能单元,从来没被要求把"自己的输出再喂自己一次"。Bae 2024 已经证明,如果不重新训练,直接循环会掉点。
- 训练 r=32 的模型 forward 成本爆炸:每一步前向都要跑 32 次 recurrent block,如果一开始就用大 r,基本没法做 ablation。需要 curriculum:从小 r 起步,慢慢爬到大 r。但这又引入新风险 — 早期低 r 学到的 representation 可能跟后期高 r 不兼容。
论文的贡献正好对这三个难点各开一刀:
- 对 (1) 做层选择消融,发现"早 prelude / 晚 R+coda / 丢中间"最优,并设计 healing 阶段修复 distribution shift。
- 对 (2) 用 input injection(每次循环把 prelude 的 e 再拼回去)+ Muon 优化器,保证训练稳定。
- 对 (3) 设计 Poisson-Lognormal mean 的 linear / 1-sqrt 调度,把训练 FLOPs 砍半,且 loss 反而更低。
2 · 背景速查
| 术语 | 含义 |
|---|---|
| depth-recurrent transformer | 同一组 transformer 层反复跑 r 次,r 可以训练时 sample、推理时调 |
| P / R / C | Prelude (含 embedding) / Recurrent block / Coda (含 unembedding) 的三段式 — 来自 Geiping 2025 |
| input injection | 每次循环 R 时,把 prelude 的输出 e 与上一次循环输出 s_{i-1} concat 后过 adapter,避免遗忘原输入 |
| truncated BP through time | 只对最后 8 次循环回传梯度,前面的循环只作 forward 不存激活 — 大幅省内存 |
| Poisson-Lognormal r | 训练每一步从 Poisson-Lognormal 分布采一个 r,均值控制平均深度,长尾偶尔给到 r=100+ |
| Muon | 2024 新优化器,对 2D 权重做 Newton-Schulz 正交化,在 narrow / 深网络上比 AdamW 更稳 |
| scalable init (Takase 2023) | 跟 model shape 解耦的初始化,允许 prelude/coda 维度不一致也稳 |
| healing period | 切层之后先用普通 web 数据(FineWeb-Edu)训一段,让模型恢复基础语言能力,再切到 SFT / math 数据 |
回顾:Huginn 的训练分布(便于和本文对比)
Geiping 2025 用 Poisson-Lognormal,mean=32,truncated BP through time = 8。每一步训练采一个 r,前向 r 次 R,只对后 8 次回传。本文继承所有这些设置,但加上 r 的均值随训练 step 增大(curriculum)和 r 一开始就放在小值(warm-up 阶段)。
3 · 方法 1:Model Surgery —— 怎么挑层最划算
3.1 三段切法 (P, R, C) 的层选择
给定一个 22 层的 TinyLlama,要切成 (4, 8, 4) 的 P/R/C — 总共 16 层,丢 6 层。哪 6 层?论文的消融(Figure 12/13)给出了一个反直觉的答案:
挑两端,丢中间。 也就是 prelude 取 L0–L3,coda 取 L18–L21,recurrent 取 L10–L17,把 L4–L9 全扔掉。
为什么不挑"看起来均匀分布的层"?论文给的解释偏经验:
- 早期层是 token-level feature:做 detokenization 和粗糙语法,这部分必须留 — 因为 prelude 直接吃 raw token,且每次循环都要把 e 注回来,prelude 的工作不可被循环替代。
- 晚期层是 task-level abstraction:做 vocabulary projection 前的 high-level 推理,这部分作为循环单元 R + 输出头 C 是最自然的。
- 中间层最冗余:已有 ShortGPT (Men 2024) 的发现:transformer 中间层的输入输出相似度极高,本身就是"残差小的 noop",删掉对原模型损失最小;循环也最容易让 R "学会" 补这部分。
论文对比了 ShortGPT 的"按层相似度自动选"方法,发现 固定挑中间 比 ShortGPT 还略好(Figure 13)。猜测原因:ShortGPT 是在评估 next-token loss 上挑层,而循环训练要的是"能在 R 内部产生迭代式 progress 的层",这两个目标不完全一致。
3.2 Input injection 是必需品
没有 input injection 时,R 的输入只有 s_{i-1} —— 模型循环几十次后早就忘了原 token 长什么样。论文沿用 Geiping 的做法:
这个 adapter 是 新增的(原模型没有),所以 retrofit 需要给它一个 sensible 初始化。本文用 Takase 2023 的 scalable init,大致是 (1/√2) · I_h 拼 (1/√2) · I_h 的方块,这样初始时 adapter 输出 ≈ (e + s_{i-1}) / √2,等价于加法 residual,不破坏 R 在 pretrained 模型里学到的功能。
3.3 反向论证:如果不丢中间层会怎样?
论文做了 (7, 8, 7) 的 TinyLlama 模型 (Figure 15) — 留下所有层,只是把中间 8 层拿来循环。结果是:per-FLOP 效率显著下降。原因很简单 — 22 层都用,但循环 32 次,等于做了 22 + 31×8 = 270 层深的 forward,绝大多数 FLOPs 花在 7+7=14 层的非循环部分,但这些部分本身已经训得很好,几乎没新信息可学。所以"挑层是为了让 FLOPs 都花在 R 上"。
4 · 方法 2:Recurrence Curriculum
4.1 直接训 r=32 的问题
Geiping 2025 全程用 mean=32。本文复现这个设置作为 baseline,但发现:
- 每一步 forward 跑 32 次 R,training step 比同 FLOPs 的非循环模型慢得多。
- 训练初期模型还没学会"用循环",过深的 forward chain 反而把梯度搅成噪声。
4.2 Curriculum:让 mean(r) 从小到大
论文设计两种 curriculum:
- Linear:从 mean=2 线性升到 mean=32,占总训练的前 75%,然后保持 32。
- 1-sqrt:更激进 — mean(t) = 32 · (1 - √(1 - t/T_warmup)),即一开始非常陡地涨,后半段慢慢逼近 32。
4.3 Curriculum 的效果(Figure 3)
关键观察来自 Figure 3 的两张图:
- 左图(loss vs step):curriculum 跟 constant=32 几乎重合,说明 curriculum 不损失 每 step 的学习能力。
- 右图(loss vs FLOPs):curriculum 显著更低,因为前期 r 小,每 step FLOPs 少。换算下来 同 FLOPs 下 loss 低 ~0.05,等价于 ~1.3× 训练加速。
5 · 方法 3:Muon + Healing 阶段
5.1 Muon 比 AdamW 更稳(Figure 4)
recurrent 模型对优化器极其敏感 — 同一组权重反复跑 r 次,梯度通过 BP-through-time 累计,放大了任何不稳定因素。论文实测:
| optimizer | LR | 训练 loss 行为 |
|---|---|---|
| AdamW(标准) | 5e-5 | 多次 loss spike,后期 NaN |
| AdamW*(Geiping 2025 用的变体,去 ε / update clipping) | — | 比标准 AdamW 稳但仍偶有 spike |
| Muon | 1e-3(注意比 AdamW 大 20×) | 完全平稳,最终 loss 最低 |
对非循环 TinyLlama,Muon 和 AdamW 差别很小;是"循环"放大了优化器的差距。这是一个有意思的实证 — 之前 Muon 只在 narrow 模型上验证过优势,本文等于发现 recurrent 也是 Muon 的优势场景。
5.2 Healing:切层后先恢复一段
切掉 6 层后,模型在普通文本上的 perplexity 会跳一下。如果立刻用 high-quality math + SFT 数据训,模型会同时面对两个 distribution shift(架构变化 + 数据领域变化),容易把基础语言能力给"忘了"。
论文设计 two-phase training:
- Phase 1 (healing):26B tokens 的 FineWeb-Edu(普通教科书风格 web 数据)。让模型先在"原始分布"上把缺的层补回来。
- Phase 2 (specialize):再用 26B tokens 的 (1/3 FineWeb-Edu + 1/3 Nemotron-General + 1/3 Nemotron-Math) 做后训练。
消融(Figure 8):非循环模型用不用 healing 几乎无差(它没切层,本来就没"伤"要愈合);循环模型用 healing 在 Arc-C 上 +5%。
6 · Worked Example: TinyLlama (4,8,4) 一次前向
设输入 token x = "What is 17+25?" — 5 个 token,h = 2048。我们追踪一个 r=8 的前向。
| 步骤 | 形状 | 说明 |
|---|---|---|
| x: token ids | (5,) | 5 个 BPE id |
| e = P(x): 过 L0–L3 | (5, 2048) | prelude 输出, 含 token-level feature |
| s₀ ~ 𝒩(0, σ²) | (5, 2048) | 每次 forward 重采样的噪声 |
| i=1: concat[e, s₀] · W_a | (5, 2048) | 2h → h 的 linear adapter |
| i=1: 过 R(L10–L17) | s₁ (5, 2048) | 第一次"思考" |
| … i=2,…,8 | s₂, …, s₈ | 每次都把 e 拼回当前 s_{i-1} |
| p = C(s₈): 过 L18–L21 | (5, V=32000) | 得到每个位置的下个 token 分布 |
训练时的反传(truncated BP)
设 mean training r = 16,Poisson-Lognormal 这步采到 r = 12。前向跑 12 次 R,但 只对 i=5..12 这 8 次回传梯度 — i=1..4 的前向当 "stop_grad" 算,不存激活。这把内存压住了:
- 不 truncate:激活内存 ∝ 12 × 8 层 × seq × h ≈ 96 个 transformer 层的激活。
- truncate=8:激活内存 ∝ 8 × 8 = 64 层 — 跟一个 64 层非循环模型差不多。
推理时
没有 BP,r 可以拉到 32, 64, 128。论文实测 r 从 1 → 32,GSM8K acc 从 26% → 52%,等价于一个"用算力换准确率"的旋钮。
反向论证:如果不做 input injection 会怎样?
没有 e 的注入,R 只能从 s_{i-1} 拿信息。s₀ 是高斯噪声,r=1 时 s₁ 只含一点 token 信息 (来自 W_a 的耦合,但若 W_a 退化为 identity,s₁ = R(s₀) = R(noise)),整个轨迹会发散。论文验证了 input injection 是性能必需。也可以反过来理解:e 充当一个固定的"题目陈述",每次循环模型都重新看一遍题目,再基于当前思路 s_{i-1} 更新 — 这跟人解数学题时反复读题再继续是同构的。
7 · 公式: FLOPs 计算
7.1 模型核心方程(Geiping 2025 沿用)
s₀ ~ 𝒩(0, σ²)n × h
sᵢ = R(e, sᵢ₋₁), i ∈ {1, …, r}
p = C(sᵣ)
7.2 训练 FLOPs 估算
标准 Kaplan 估算 FLOPs = 6ND(forward 2ND,backward 4ND)。但 recurrent 模型只有最后 k_BP=8 次循环带梯度,其余只 forward 不 backward:
其中 N₁ = 带梯度的有效参数数 = (P + 8 × R + C) 的参数计数,N₂ = 仅 forward 的剩余参数数 = (mean_r − 8) × R 的参数。
7.3 数值敏感性
| train r (mean) | 每 step FLOPs(相对值) | 训练时间(相对) |
|---|---|---|
| 4 | 1.0 | 1.0 |
| 8 | 1.4 | 1.5 |
| 16 | 2.1 | 2.4 |
| 32 | 3.4 | 4.0 |
物理直觉:前 r-8 次循环是"白嫖知识、不学新东西"。它们消耗 FLOPs 但贡献为 给后 8 次 BP 提供一个好起点 —— 类似 inference-time 多想几步,只是这几步发生在训练时。这就是为什么 curriculum 有效:训练早期模型还学不到"r=32 才暴露的现象",这部分 forward FLOPs 就是浪费,后期再投入。
8 · 实验关键结果
8.1 Pretrained init 远胜 random init (Figure 2)
同样训 120B FineWeb-Edu tokens,两个 (2,4,2) 模型:
| 初始化 | final loss | HellaSwag (r=32) |
|---|---|---|
| Llama-3.2-1B init | 低 | ~50% |
| random (Takase 2023) | 高 | ~33% (接近 chance) |
外推 log-linear 曲线,random init 大约要 950B tokens 才能追上 pretrained init。这是论文最重要的一个证据 — knowledge transfer 在这种"参数加加减减"的 surgery 下依然成立。
8.2 Recurrent 在 GSM8K / MATH 上击败 base (Figure 5, 6, 7)
| base | params | GSM8K (r=32) | MATH (r=32) |
|---|---|---|---|
| TinyLlama-1.1B 原模型 | 1.1B | ~45% | ~14% |
| TinyLlama → (4,8,4) recurrent | 0.7B (-30%) | 52% | ~14.5% |
| OLMo-2-1B 原模型 | 1.0B | ~25% | ~12% |
| OLMo-2 → (4,6,4) recurrent | 0.9B | ~32% | ~22% |
| Llama-3.2-1B 原模型 | 1.2B | ~35% | ~10% |
| Llama-3.2 → (4,6,4) recurrent | 1.0B | ~50% | ~20% |
读法:三个家族都呈现 "参数少 30% 但 GSM8K +7% ~ +15%"。MATH 更难,gain 更明显。base model 越强,recurrent 改造后 ceiling 越高(Llama > OLMo > TinyLlama,论文称这是 "stronger base transfers more knowledge")。
8.3 推理时 r 是有效旋钮 (Figure 5/20 right + Appendix Table 3)
同一份 (4,8,4) TinyLlama 权重,推理时把 r 从 1 拨到 32,GSM8K acc 单调上升,在 r=4~8 处越过非循环 base 的 26.6%。下图是论文 Figure 5/20 right panel 的数值复刻 (math-only training,数据来自 Appendix Table 3):
四个关键观察:
- r=1 退化严重:模型从来不在 r=1 训练,直接 r=1 推理就是 OOD,准确率比非循环 base 低得多(train r=32 时 r=1 仅 5.6%)。
- r=4 处全部反超:三条线都在 r=4 跨越 26.6% 的 base 线,这是 latent reasoning 真正起作用的"启动门槛"。
- train r 大 → 天花板高:train r=32 最终在 r=8 达到 45.3% (best),而 train r=4 卡在 ~38%。训练时给得起多深,推理时就能爬多高。
- r > train r 后饱和:train r=4 的模型在 test r=4 后就持平,说明它没学到怎么用更多循环。这跟 Bae 2024 的 "r 升反而掉点" 不同 — 这里是平台期而非崩溃。
MATH 上的曲线形状一致(详见 Appendix Table 3 的 MATH 列):train r=16 在 test r=4 已达 28.9% (vs base 24.0%),且 r=8 后基本饱和。越难的任务、train r 越大、test r 越深,gain 越明显 — 这是 latent reasoning 文献一致的现象。
论文为什么不画跟 CoT 的 FLOPs 对比?
论文里 CoT 只在 intro 提了一句 motivation,没有任何 head-to-head 实验。这是这条 line 最大的缺口 — 1B 级模型本身做不好 CoT,放在一起比也意义不大;真正可比的是 8B+ scale,但本文没做到。最公平的对比应该是: "同一个 1B base,做 CoT-SFT vs 做 retrofit-recurrence,各花多少推理 FLOPs 才能在 GSM8K 达到 50%?" — 这是后续工作可以填的坑。
8.4 比 Huginn-0125 强(Table 1)
本文最强的 (4,8,4) TinyLlama recurrent 模型(~0.7B 参数,52B tokens 训练)在 MMLU 上比 Huginn-0125 (3.5B 参数,800B tokens) 高 12 个点,GSM8K 高 10 个点。这是 retrofit 策略最直接的胜利证据。
9 · 与同类工作对比
| 本文 (McLeish 2025) | Huginn (Geiping 2025) | Bae 2024 (Relaxed Recursive) | Koishekenov 2025 | |
|---|---|---|---|---|
| 起点 | pretrained | from scratch | pretrained | pretrained (OLMo) |
| 训练 tokens | ~50B | 800B | ~少量 finetune | ~少量 |
| 层选择 | 挑两端丢中间 | —(scratch) | 保留全部 | 保留全部 |
| Input injection | ✓ | ✓ | ✗ | ✗ |
| 训练 r 分布 | Poisson-LogN + curriculum | Poisson-LogN 固定 mean=32 | 固定 r=2,3 | 固定 r |
| Test-time scaling | 单调增,r=1→32 大幅涨 | 同上 | r 升反而掉点 | 少量增益 |
| 开源权重 | HF: tomg-group-umd | 已开 | — | — |
差异本质:Bae 2024 是"用循环近似原 forward,LoRA 修补差异" — 等价于压缩,没拿到 test-time scaling;本文是"用循环替代深度,layer drop 让 FLOPs 都花在 R 上" — 等价于把 base model 升级成 latent reasoner。两者都从 pretrained 出发,但目标完全不同。
10 · 局限 / 个人 take / 待验证问题
- 规模仅到 1B / 50B tokens:论文自承没在 7B+ 验证。Mid-layer 是不是仍然最冗余、curriculum 的最优形状会不会变,都不清楚。
- 仅在 math 上看到明确增益:这跟 latent reasoning 文献的共识一致 — depth 帮"算法式"任务最大。开放式 RLHF、长文本理解上是否同样?未知。
- 没有 adaptive recurrence:r 在推理时由用户给。理想情况是模型自己根据题目难度决定 stop — 这条路在 Geiping 2025、Bae 2025(MoR)、Inner Thinking Transformer (Chen 2025) 都有探索,本文还没接上。
- 训练 FLOPs 仍然偏高:同样 50B tokens 训练 (4,8,4) recurrent ≈ 训练一个 1.4B 非循环模型的 cost。Cost-benefit 真正成立的前提是 推理端能反复利用这次训练换来的 r-scaling,所以更适合长期 serve 的产品场景。
- healing 期长度还没系统化:论文用 26B 是经验值,跟 model surgery 的"剧烈程度"有没有可预测的关系?目前是个开放问题。
- Muon 在更大模型上是否仍稳:论文在 1B 用 LR=1e-3。Muon 在 70B 级别的稳定性是 2026 才会逐渐清楚的事情。
个人 take
这篇是 latent reasoning 通向"工程化"的关键节点:把它从一个昂贵的预训练实验,变成一个 固定深度模型的 post-training 阶段 — 类似扩 context length。如果在 7B+ scale 上也能复现,latent reasoning 就有可能跟 chain-of-thought 同等重要 — 一个把算力花在 hidden,一个把算力花在 token,两者完全可叠加。
同时它也指明了一个让我有点意外的事实:"挑两端丢中间"这件事的存在,跟过去几年关于"transformer 中间层冗余"(ShortGPT、LACO)的发现一致 — 也许 LLM 的算力分配本就极度不均,本文相当于把这一冗余重新利用起来:不光删掉它,而是用一个 R 块循环替代它的功能。
5 个待验证问题
- (4,8,4) → (8,16,8) 在 8B 模型上是否保持 per-FLOP 优势?
- 在 RLHF / GRPO 训练时,latent recurrence 能否替代 chain-of-thought tokens?(回答"我不确定"问题时 — 多 r 是否提供更好的 calibration?)
- 把 input injection 升级为 cross-attention(让 e 在 R 内部可被随时查询),还能涨多少?
- 是否能在 inference 时用 KV cache 加速 r-scaling — R 内部 attention 的 KV 在不同 r 间能不能共享?
- "挑早期 + 晚期、丢中间"在 instruct-tuned 模型上仍最优吗?还是 instruct 模型的 task-level abstraction 已在中间层?