Teaching Pretrained LMs to Think Deeper with Retrofitted Recurrence

McLeish, Li, Kirchenbauer, …, Geiping, Goldstein, Goldblum · UMD / NYU / LLNL / UNC / ELLIS Tübingen / Columbia · 2025-11 · arXiv:2511.07384
关键词: depth-recurrence · latent reasoning · model surgery · curriculum · test-time compute · Muon

速读卡片 (TL;DR)

一句话:不要再从零训练 depth-recurrent transformer(Geiping 等花了 800B tokens),直接把已有的 TinyLlama / OLMo-2 / Llama-3.2 切开,挑早期层做 prelude、晚期层做 recurrent block + coda,中间层丢掉,再加 input injection 和小规模 continued pretraining,就能把一个 1B 固定深度模型改造成一个"用 32 次循环替代深度"的 latent reasoning 模型,GSM8K / MATH 在更少参数 + 同样 FLOPs 下击败原模型。

~50B
retrofit 训练 token 数(对比 800B from scratch)
72.7%
recurrent 模型保留的参数比例 (TinyLlama, 6/22 层被丢)
+10% MMLU
+10% GSM8K
vs. Huginn-0125 (3.5B from-scratch)

立场:这是 latent reasoning 路线最实用的一篇 — 不要求重训,等于把"recurrent depth"做成了一个 post-training stage,跟扩 context length 同等地位。配方上有三个关键发现:pretrained init 远胜 randomcurriculum 调度循环次数Muon 比 AdamW 更稳;再加一个"healing 阶段"修复切层带来的 distribution shift。


1 · 动机:为什么 latent recurrence 重要,以及为什么从零训太贵

1.1 历史脉络:test-time compute scaling 的两条路

2024 下半年以来,大家公认要把"算力花在推理而不是参数"。这条路上有两个主流派系:

Huginn-0125 证明了 latent recurrence 路线本身可行:他们从零预训练一个 3.5B 的 depth-recurrent 模型,跑 800B tokens,在 ARC / MMLU / GSM8K 上确实能越循环越准。但这篇 paper 的核心追问是:

我们已经有了 Llama-3.2-1B、OLMo-2-1B、TinyLlama-1.1B 这些跑了 3T~9T tokens 的开源模型 —— 能不能不要再从零训,直接把它们改造成 recurrent 的?

1.2 别的方案为什么不够

这个问题其实之前有人做过,但每条路都有缺口:

已有方案做法缺口
Huginn (Geiping 2025)从零训 3.5B,800B tokens,平均 r=32太贵;参数不能复用现有 base model
Bae 2024 (Relaxed Recursive)把 pretrained 模型循环 2~3 次,加 LoRA 修复r 增大反而掉点,没有 test-time scaling
Koishekenov 2025 (encode-think-decode)OLMo → P/R/C 结构,但保留全部层,固定 r没 input injection;没用 curriculum;不报 FLOPs 难对比
Li 2025 (cyclic refinement)循环 GPT-2/OPT 模型只在 multiple choice 上有微弱增益

所以本文的位置是:把 Huginn 的"会 scale 的 latent reasoning"嫁接到现有 pretrained 模型上,且证明嫁接成本远小于从零训。这件事不平凡,原因有三:

1.3 为什么这事不平凡

  1. 参数既加又减:从 pretrained 模型出发要做"加" — 新增 input-injection 的 linear adapter (2h → h),还要适配 Geiping 的随机噪声 s₀;同时要做"减" — 丢中间几层(否则 r=32 时 FLOPs 爆炸)。怎么保证知识依然能 transfer 过来? 这没先验保证。
  2. "循环这件事"原模型从没见过:pretrained transformer 的每一层是独立功能单元,从来没被要求把"自己的输出再喂自己一次"。Bae 2024 已经证明,如果不重新训练,直接循环会掉点。
  3. 训练 r=32 的模型 forward 成本爆炸:每一步前向都要跑 32 次 recurrent block,如果一开始就用大 r,基本没法做 ablation。需要 curriculum:从小 r 起步,慢慢爬到大 r。但这又引入新风险 — 早期低 r 学到的 representation 可能跟后期高 r 不兼容。

论文的贡献正好对这三个难点各开一刀:

Pretrained TinyLlama (22层) L0 L1 L2 L3 (prelude) L4–L9 丢弃 L10–L17 (R) L18–L21 (coda) surgery Recurrent (4,8,4) model P (4 层) x → e R (8 层) + adapter [e, s_{i-1}] × r s₀ ~ 𝒩(0, σ²) C (4 层) → p
图 1:模型外科手术。从 22 层 TinyLlama 取早 4 层做 prelude (P),取晚 8 层做 recurrent block (R),取最后 4 层做 coda (C),中间 6 层(L4–L9)丢弃。每次循环 R 都把 prelude 输出 e 拼回当前隐藏态 s_{i-1},经 linear adapter 压回宽度 h。s₀ 来自高斯噪声。

2 · 背景速查

术语含义
depth-recurrent transformer同一组 transformer 层反复跑 r 次,r 可以训练时 sample、推理时调
P / R / CPrelude (含 embedding) / Recurrent block / Coda (含 unembedding) 的三段式 — 来自 Geiping 2025
input injection每次循环 R 时,把 prelude 的输出 e 与上一次循环输出 s_{i-1} concat 后过 adapter,避免遗忘原输入
truncated BP through time只对最后 8 次循环回传梯度,前面的循环只作 forward 不存激活 — 大幅省内存
Poisson-Lognormal r训练每一步从 Poisson-Lognormal 分布采一个 r,均值控制平均深度,长尾偶尔给到 r=100+
Muon2024 新优化器,对 2D 权重做 Newton-Schulz 正交化,在 narrow / 深网络上比 AdamW 更稳
scalable init (Takase 2023)跟 model shape 解耦的初始化,允许 prelude/coda 维度不一致也稳
healing period切层之后先用普通 web 数据(FineWeb-Edu)训一段,让模型恢复基础语言能力,再切到 SFT / math 数据

回顾:Huginn 的训练分布(便于和本文对比)

Geiping 2025 用 Poisson-Lognormal,mean=32,truncated BP through time = 8。每一步训练采一个 r,前向 r 次 R,只对后 8 次回传。本文继承所有这些设置,但加上 r 的均值随训练 step 增大(curriculum)和 r 一开始就放在小值(warm-up 阶段)。


3 · 方法 1:Model Surgery —— 怎么挑层最划算

3.1 三段切法 (P, R, C) 的层选择

给定一个 22 层的 TinyLlama,要切成 (4, 8, 4) 的 P/R/C — 总共 16 层,丢 6 层。哪 6 层?论文的消融(Figure 12/13)给出了一个反直觉的答案:

挑两端,丢中间。 也就是 prelude 取 L0–L3,coda 取 L18–L21,recurrent 取 L10–L17,把 L4–L9 全扔掉。

为什么不挑"看起来均匀分布的层"?论文给的解释偏经验:

论文对比了 ShortGPT 的"按层相似度自动选"方法,发现 固定挑中间 比 ShortGPT 还略好(Figure 13)。猜测原因:ShortGPT 是在评估 next-token loss 上挑层,而循环训练要的是"能在 R 内部产生迭代式 progress 的层",这两个目标不完全一致。

3.2 Input injection 是必需品

没有 input injection 时,R 的输入只有 s_{i-1} —— 模型循环几十次后早就忘了原 token 长什么样。论文沿用 Geiping 的做法:

s = R( concat[ e, sᵢ₋₁ ] · Wadapter ),   Wadapter : ℝ2h → ℝh

这个 adapter 是 新增的(原模型没有),所以 retrofit 需要给它一个 sensible 初始化。本文用 Takase 2023 的 scalable init,大致是 (1/√2) · I_h 拼 (1/√2) · I_h 的方块,这样初始时 adapter 输出 ≈ (e + s_{i-1}) / √2,等价于加法 residual,不破坏 R 在 pretrained 模型里学到的功能。

3.3 反向论证:如果不丢中间层会怎样?

论文做了 (7, 8, 7) 的 TinyLlama 模型 (Figure 15) — 留下所有层,只是把中间 8 层拿来循环。结果是:per-FLOP 效率显著下降。原因很简单 — 22 层都用,但循环 32 次,等于做了 22 + 31×8 = 270 层深的 forward,绝大多数 FLOPs 花在 7+7=14 层的非循环部分,但这些部分本身已经训得很好,几乎没新信息可学。所以"挑层是为了让 FLOPs 都花在 R 上"。


4 · 方法 2:Recurrence Curriculum

4.1 直接训 r=32 的问题

Geiping 2025 全程用 mean=32。本文复现这个设置作为 baseline,但发现:

4.2 Curriculum:让 mean(r) 从小到大

论文设计两种 curriculum:

training step mean r 32 2 constant 32 (baseline) linear 1-sqrt(更激进) 75% T
图 2:Recurrence curriculum 示意。蓝色 linear 在前 75% 步线性升到 mean=32,绿色 1-sqrt 一开始更陡再趋稳,红色虚线是 Huginn 的 constant=32 baseline。

4.3 Curriculum 的效果(Figure 3)

关键观察来自 Figure 3 的两张图:

直觉:这跟渐进式 batch size渐进式序列长度是同源思想 — 模型早期还学不到 r=32 才能用上的"深度推理",此时把这部分 forward FLOPs 节省下来,等模型 representation 长好了再投入更深的循环,效率更高。

5 · 方法 3:Muon + Healing 阶段

5.1 Muon 比 AdamW 更稳(Figure 4)

recurrent 模型对优化器极其敏感 — 同一组权重反复跑 r 次,梯度通过 BP-through-time 累计,放大了任何不稳定因素。论文实测:

optimizerLR训练 loss 行为
AdamW(标准)5e-5多次 loss spike,后期 NaN
AdamW*(Geiping 2025 用的变体,去 ε / update clipping)比标准 AdamW 稳但仍偶有 spike
Muon1e-3(注意比 AdamW 大 20×)完全平稳,最终 loss 最低

对非循环 TinyLlama,Muon 和 AdamW 差别很小;是"循环"放大了优化器的差距。这是一个有意思的实证 — 之前 Muon 只在 narrow 模型上验证过优势,本文等于发现 recurrent 也是 Muon 的优势场景。

5.2 Healing:切层后先恢复一段

切掉 6 层后,模型在普通文本上的 perplexity 会跳一下。如果立刻用 high-quality math + SFT 数据训,模型会同时面对两个 distribution shift(架构变化 + 数据领域变化),容易把基础语言能力给"忘了"。

论文设计 two-phase training:

消融(Figure 8):非循环模型用不用 healing 几乎无差(它没切层,本来就没"伤"要愈合);循环模型用 healing 在 Arc-C 上 +5%。


6 · Worked Example: TinyLlama (4,8,4) 一次前向

设输入 token x = "What is 17+25?" — 5 个 token,h = 2048。我们追踪一个 r=8 的前向。

步骤形状说明
x: token ids(5,)5 个 BPE id
e = P(x): 过 L0–L3(5, 2048)prelude 输出, 含 token-level feature
s₀ ~ 𝒩(0, σ²)(5, 2048)每次 forward 重采样的噪声
i=1: concat[e, s₀] · W_a(5, 2048)2h → h 的 linear adapter
i=1: 过 R(L10–L17)s₁ (5, 2048)第一次"思考"
… i=2,…,8s₂, …, s₈每次都把 e 拼回当前 s_{i-1}
p = C(s₈): 过 L18–L21(5, V=32000)得到每个位置的下个 token 分布

训练时的反传(truncated BP)

设 mean training r = 16,Poisson-Lognormal 这步采到 r = 12。前向跑 12 次 R,但 只对 i=5..12 这 8 次回传梯度 — i=1..4 的前向当 "stop_grad" 算,不存激活。这把内存压住了:

Training: r = 12 采样, truncated BP through last 8 x P e R₁ R₂ R₃ R₄ R₅ R₆ R₇ R₈ R₉ R₁₀ R₁₁ R₁₂ C p loss e 注入每一次 R(input injection) → forward(全部 12 次) gradient 跳过 R₁..R₄,经 skip 连接到达 P stop_grad ← backward(只回 8 次:i=5..12) 前向但不存激活 (i ≤ r − 8) 前向 + 反传 (last 8) gradient 流向 e 注入 每步: r ~ Poisson-Lognormal(mean = curriculum(t)) 本步采到 r = 12 → 12 次 forward, 后 8 次带梯度
图 3 (训练):每步采一个 r(本例 r=12),前向跑 12 次 R,但只对最后 8 次回传梯度(绿色)。前 4 次(灰色)只前向,不存激活 — 把内存控制在"8 倍 R 块"的尺度,跟训练一个 64 层非循环模型相当。这是 truncated backprop through time 在 depth-recurrence 上的对应物。

推理时

没有 BP,r 可以拉到 32, 64, 128。论文实测 r 从 1 → 32,GSM8K acc 从 26% → 52%,等价于一个"用算力换准确率"的旋钮。

Inference: r 是用户旋钮,无 BP,无内存增长 r = 1 P R₁ C → "the" GSM8K 26% r = 8 P R₁ R₂ R₃ R₄ R₅ R₆ R₇ R₈ C → "answer is" GSM8K 41% r = 32 P R₁ R₂ R₃ … R₃₁ R₃₂ (32 次循环, KV cache 不增长) C → "42" 52% r=1 旋钮 r=32 旋钮 同一份权重 / 不重训 / 上下文长度不变 / KV cache 不增 更多算力 → 更深的 latent reasoning → 更准的答案 (CoT 是把算力花在 output token,这里是花在 hidden)
图 4 (推理):r 是一个推理时旋钮,从 1 拨到 32,等于在 hidden state 上多迭代 32 次。同一份权重、不重训、序列长度不变、KV cache 不增长 — 算力换准确率。这跟 chain-of-thought 把算力花在 output token 是互补关系:一个对 latent,一个对 verbal。
反向论证:如果不做 input injection 会怎样?

没有 e 的注入,R 只能从 s_{i-1} 拿信息。s₀ 是高斯噪声,r=1 时 s₁ 只含一点 token 信息 (来自 W_a 的耦合,但若 W_a 退化为 identity,s₁ = R(s₀) = R(noise)),整个轨迹会发散。论文验证了 input injection 是性能必需。也可以反过来理解:e 充当一个固定的"题目陈述",每次循环模型都重新看一遍题目,再基于当前思路 s_{i-1} 更新 — 这跟人解数学题时反复读题再继续是同构的。


7 · 公式: FLOPs 计算

7.1 模型核心方程(Geiping 2025 沿用)

e = P(x)
s ~ 𝒩(0, σ²)n × h
sᵢ = R(e, sᵢ₋₁),   i ∈ {1, …, r}
p = C(sᵣ)

7.2 训练 FLOPs 估算

标准 Kaplan 估算 FLOPs = 6ND(forward 2ND,backward 4ND)。但 recurrent 模型只有最后 k_BP=8 次循环带梯度,其余只 forward 不 backward:

FLOPs = 6 · N₁ · D + 2 · N₂ · D

其中 N₁ = 带梯度的有效参数数 = (P + 8 × R + C) 的参数计数,N₂ = 仅 forward 的剩余参数数 = (mean_r − 8) × R 的参数。

7.3 数值敏感性

train r (mean)每 step FLOPs(相对值)训练时间(相对)
41.01.0
81.41.5
162.12.4
323.44.0

物理直觉:前 r-8 次循环是"白嫖知识、不学新东西"。它们消耗 FLOPs 但贡献为 给后 8 次 BP 提供一个好起点 —— 类似 inference-time 多想几步,只是这几步发生在训练时。这就是为什么 curriculum 有效:训练早期模型还学不到"r=32 才暴露的现象",这部分 forward FLOPs 就是浪费,后期再投入。


8 · 实验关键结果

8.1 Pretrained init 远胜 random init (Figure 2)

同样训 120B FineWeb-Edu tokens,两个 (2,4,2) 模型:

初始化final lossHellaSwag (r=32)
Llama-3.2-1B init~50%
random (Takase 2023)~33% (接近 chance)

外推 log-linear 曲线,random init 大约要 950B tokens 才能追上 pretrained init。这是论文最重要的一个证据 — knowledge transfer 在这种"参数加加减减"的 surgery 下依然成立。

8.2 Recurrent 在 GSM8K / MATH 上击败 base (Figure 5, 6, 7)

baseparamsGSM8K (r=32)MATH (r=32)
TinyLlama-1.1B 原模型1.1B~45%~14%
TinyLlama → (4,8,4) recurrent0.7B (-30%)52%~14.5%
OLMo-2-1B 原模型1.0B~25%~12%
OLMo-2 → (4,6,4) recurrent0.9B~32%~22%
Llama-3.2-1B 原模型1.2B~35%~10%
Llama-3.2 → (4,6,4) recurrent1.0B~50%~20%

读法:三个家族都呈现 "参数少 30% 但 GSM8K +7% ~ +15%"。MATH 更难,gain 更明显。base model 越强,recurrent 改造后 ceiling 越高(Llama > OLMo > TinyLlama,论文称这是 "stronger base transfers more knowledge")。

8.3 推理时 r 是有效旋钮 (Figure 5/20 right + Appendix Table 3)

同一份 (4,8,4) TinyLlama 权重,推理时把 r 从 1 拨到 32,GSM8K acc 单调上升,在 r=4~8 处越过非循环 base 的 26.6%。下图是论文 Figure 5/20 right panel 的数值复刻 (math-only training,数据来自 Appendix Table 3):

0 10 20 30 40 50 1 2 4 8 16 32 test recurrence r (log₂) GSM8K accuracy (%) non-rec 26.6% train r = 4 train r = 16 train r = 32 r=4 处超过 non-rec
图 5:GSM8K 上的 test-time scaling 曲线。三条线是 train recurrence ∈ {4, 16, 32} 的 (4,8,4) TinyLlama,横轴是推理时 r(log₂),纵轴 GSM8K 准确率。灰色虚线是非循环 TinyLlama (26.6%)。三条曲线都在 r=4 处越过 baseline;train r 越大,小 r 处性能越差但大 r 处天花板越高 — train r=32 在 r=8 时达到 45.3% 的峰值。数据来自 Appendix Table 3。

四个关键观察:

MATH 上的曲线形状一致(详见 Appendix Table 3 的 MATH 列):train r=16 在 test r=4 已达 28.9% (vs base 24.0%),且 r=8 后基本饱和。越难的任务、train r 越大、test r 越深,gain 越明显 — 这是 latent reasoning 文献一致的现象。

论文为什么不画跟 CoT 的 FLOPs 对比?

论文里 CoT 只在 intro 提了一句 motivation,没有任何 head-to-head 实验。这是这条 line 最大的缺口 — 1B 级模型本身做不好 CoT,放在一起比也意义不大;真正可比的是 8B+ scale,但本文没做到。最公平的对比应该是: "同一个 1B base,做 CoT-SFT vs 做 retrofit-recurrence,各花多少推理 FLOPs 才能在 GSM8K 达到 50%?" — 这是后续工作可以填的坑。

8.4 比 Huginn-0125 强(Table 1)

本文最强的 (4,8,4) TinyLlama recurrent 模型(~0.7B 参数,52B tokens 训练)在 MMLU 上比 Huginn-0125 (3.5B 参数,800B tokens) 高 12 个点,GSM8K 高 10 个点。这是 retrofit 策略最直接的胜利证据。


9 · 与同类工作对比

本文 (McLeish 2025)Huginn (Geiping 2025)Bae 2024 (Relaxed Recursive)Koishekenov 2025
起点pretrainedfrom scratchpretrainedpretrained (OLMo)
训练 tokens~50B800B~少量 finetune~少量
层选择挑两端丢中间—(scratch)保留全部保留全部
Input injection
训练 r 分布Poisson-LogN + curriculumPoisson-LogN 固定 mean=32固定 r=2,3固定 r
Test-time scaling单调增,r=1→32 大幅涨同上r 升反而掉点少量增益
开源权重HF: tomg-group-umd已开

差异本质:Bae 2024 是"用循环近似原 forward,LoRA 修补差异" — 等价于压缩,没拿到 test-time scaling;本文是"用循环替代深度,layer drop 让 FLOPs 都花在 R 上" — 等价于把 base model 升级成 latent reasoner。两者都从 pretrained 出发,但目标完全不同。


10 · 局限 / 个人 take / 待验证问题

个人 take

这篇是 latent reasoning 通向"工程化"的关键节点:把它从一个昂贵的预训练实验,变成一个 固定深度模型的 post-training 阶段 — 类似扩 context length。如果在 7B+ scale 上也能复现,latent reasoning 就有可能跟 chain-of-thought 同等重要 — 一个把算力花在 hidden,一个把算力花在 token,两者完全可叠加。

同时它也指明了一个让我有点意外的事实:"挑两端丢中间"这件事的存在,跟过去几年关于"transformer 中间层冗余"(ShortGPT、LACO)的发现一致 — 也许 LLM 的算力分配本就极度不均,本文相当于把这一冗余重新利用起来:不光删掉它,而是用一个 R 块循环替代它的功能。

5 个待验证问题

  1. (4,8,4) → (8,16,8) 在 8B 模型上是否保持 per-FLOP 优势?
  2. 在 RLHF / GRPO 训练时,latent recurrence 能否替代 chain-of-thought tokens?(回答"我不确定"问题时 — 多 r 是否提供更好的 calibration?)
  3. 把 input injection 升级为 cross-attention(让 e 在 R 内部可被随时查询),还能涨多少?
  4. 是否能在 inference 时用 KV cache 加速 r-scaling — R 内部 attention 的 KV 在不同 r 间能不能共享?
  5. "挑早期 + 晚期、丢中间"在 instruct-tuned 模型上仍最优吗?还是 instruct 模型的 task-level abstraction 已在中间层?

Memory points

立场 latent recurrence 不必从零训 — pretrained 模型切两端 + 加 input injection + curriculum + Muon,50B tokens 就 retrofit 成功。
surgery 早 L 做 P、晚 L 做 R+C、丢中间 L4–L9。中间层冗余度最高,丢掉对原模型损失最小,且最适合被循环替代。
curriculum mean(r) 从 2 线性升到 32(占前 75% 步),per-step loss 不变但 per-FLOPs loss 降。1-sqrt 调度更激进。
optimizer Muon 在 recurrent 训练上完胜 AdamW;在非循环模型上差距很小。Recurrent 放大了优化器差异。
healing 切层后先用 FineWeb-Edu 训 26B tokens 恢复语言能力,再切到 math/SFT 数据。+5% Arc-C。
结果 0.7B (4,8,4) TinyLlama recurrent 在 MMLU 比 3.5B Huginn-0125 高 12 点,GSM8K 高 10 点。
test-time 旋钮 一次训练,推理 r ∈ {1, 2, 4, 8, 16, 32} 均可,acc 单调递增。
开源 HF: tomg-group-umd/retrofitting-recurrence;Code: github.com/mcleish7/retrofitting-recurrence