1K Views
November 27, 25
スライド概要
DL輪読会資料
DEEP LEARNING JP [DL Papers] LaDiR: Latent Diffusion Enhances LLMs for Text Reasoning Kai Yamashita, Matsuo Lab http://deeplearning.jp/ 1
書誌情報 • Title: LaDiR: Latent Diffusion Enhances LLMs for Text Reasoning • Authors: Haoqiang Kang, Yizhe Zhang, Nikki Lijing Kuang, Nicklas Majamaki, Navdeep Jaitly, Yi-An Ma, Lianhui Qin • Submitting to ICLR2026 TL;DR • CoT(思考過程)を文ごとに VAE で連続潜在トークン列に圧縮し、その潜在空間上で拡散 (flow matching)による「推敲付き推論」を行うフレームワーク LaDiR を提案。 • 潜在ブロックをノイズから徐々に洗練させ、その列を条件に 通常のARデコードで答えの みを生成することで、計算量と精度のトレードオフや多様な推論パス探索を実現。 • 数学・パズルベンチマークで、AR CoT や従来の latent/diffusion 手法より高い精度と多様 性(Pass@k)を達成し、「連続潜在×拡散」での推論強化が有効なことを示した。 2
背景 Chain-of-Thought(CoT) and Reasoning Models •複雑な問題を解くときに,「結論だけ」ではなく 途中の考え方や推論のステップを順番に書き出していく方法 •人間が頭の中で「まずAだからB、だからC…」と考えるのと同じように,理由づけを細かく分解してつなげた思 考の鎖 •大規模言語モデル(LLM)では,このChain of Thoughtを使うことで,数学・論理パズル・長い文章の理解などの 精度を上げており,思考過程もモデル化するものをReasoning Modelと呼称する 3
背景 Test-Time Scaling ある入力に対して,複数の出力を生成してBest-of-NやSearchアルゴリズムを活用することで,計算コストの トレードオフで精度を向上させることができる 4
Problems of CoT in text token • 現在の言語モデルはAutoregressive(AR)形式であり,左から順にTokenを生成し ていく形式 • Inference-Time Scalingなどで複数の推論pathをサンプリングしているものの, ひとつのpathは一方向で,一旦生成したTokenは後戻りできない • Scalingしても同じような推論pathしか出せない • 一方向なので,間違った前提に乗ったまま最後まで推論してしまう • トークンの一部は推論に関係ない • 推論に必要ない自然言語トークンにも推論能力が割かれる 5
Latent Reasoning COCONUT • • • <bot> … <eot> の区間では一切トークンを出さず,直前トークンの隠れ状態ベクトルをそのまま再入力して「continuous thought(潜在思考)」として内部だけで推論させる。 既存のCoT軌跡を使い,最初は全文CoTで学習 → 一部区間をlatentに置き換える → さらに置き換え範囲を広げる,というカリ キュラム学習で「latentで考える」挙動を身につけさせる。 このlatent推論により,複数の分岐をベクトル上で並列に保持・評価しつつ探索でき,CoTよりもトークン数・計算量を抑え 6 つつ,ハルシネーションや誤った分岐を減らせる。
Latent Recurrent Reasoning Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach • デコーダ専用 Transformer を「Prelude–再帰コア–Coda」に分け、同じコアを何度も再帰して“計算深 さ”を伸ばす。 • 学習時に再帰回数 r をランダムに変えながら言語モデリング訓練し、推論時に r を自由に増減できる ようにする。 • truncated BPTT(最後の数ステップだけ逆伝播)+RMSNorm などで、長い再帰でも学習が安定するよ うに設計。 7
Problelms of Previous Latent Reasoning • Latent Stateはどのベクトルがどの推論状態を指しているかわかりに くいという問題 • 性能面で通常のAR CoTの性能を越えられてない • Discrete Latentはトークン空間の延長にすぎない • 推論プロセスそのものの軌道のモデルかができていない 8
LaDiR • CoTを文ごとにVAEで“潜在思考ブロック”へ圧縮し、各ステップの思考を固定長の連続ベクトル列と して表現することで、「意味単位の連続latentな思考空間」を用意するモデル • その潜在ブロック列に対して flow matching による latent diffusionを実行し、「推論=反復的なノイ ズ除去・推敲プロセス」としてモデル化、denoising step数で計算量と精度をトレードオフできる • 生成された潜在思考ブロック列を条件に、最後だけ通常のAR言語モデルでFinal Answerを出力する 構成 9
LaDiR • CoTを文ごとにVAEで“潜在思考ブロック”へ圧縮し、各ステップの思考を固定長の連続ベクトル列と して表現することで、「意味単位の連続latentな思考空間」を用意するモデル • その潜在ブロック列に対して flow matching による latent diffusionを実行し、「推論=反復的なノイ ズ除去・推敲プロセス」としてモデル化、denoising step数で計算量と精度をトレードオフできる • 生成された潜在思考ブロック列を条件に、最後だけ通常のAR言語モデルでFinal Answerを出力する 構成 10
Latent Block Learning • 各文を固定長を潜在変数にLLM Encoderと,Learnable Embeddingを用いてEncode • Decoderを凍結して,Encoderのみをβ-VAE損失で学習(Stage-0) • 潜在変数へのノイズ付加,入力Token置換などでロバスト性を付与 11
Autoregressive Model • • • • 以下の形式でformat • [質問 q] , <BOT>, Z^(1), <EOT>, <BOT>, Z^(2), <EOT>, …, <SOA>, [答えトークン列] ブロック内はBidirectional Attention, ブロック間はCausal Attention 出力Head • ベクトル場 • Answerを出力するHead • 各<EOT>位置で次が<BOT>か<SOA>か判定する2値分類Head ベクトル場予測,Answer Token Cross Entropy, <BOT> or <SOA> binary cross entropy lossで学習 • 最初はVAE Encoderから生成したBlock Latentを使って学習(Stage-1) • 次にDiffusion Modelからdenoising stepを削ってBlock Latentを生成し,勾配を切らずに学習(ドメインシフトの防止) (Stage-2) 12
Diverse Sampling 1. Initial Noiseの分散の強化 Denoising processにおける初期ノイズの分散を大きくとって,同じConditionでも潜在軌道の出発点の多 様性を広げる 1. Denoising Stepにおけるin-batch repulsionの導入による多様性ガイダンス Denoising stepにおける潜在変数において,バッチ内の同時刻の他サンプルとの距離からバンド幅を median距離として計算し,近いサンプル同士を押し離すエネルギーを定義 Repursionの強さのスケールは,denoising stepの最初は強めにし,終盤は弱めにして収束を優先 元のdiffusion modelに加えて,repursionを足したものでdenoising stepを行うことで,バッチ内でカバレ ッジの広いLatent Reasoning軌道が得られる 13
Pass@1 Evaluation • • • • • • Reasoning Taskについて,In-Domainタスクと,OODタスクで検証 ベースラインとして,CoT SFTと,Sol-Only SFT(答えのみSFT), COCONUT, Discrete Latentなどと比較 LaDiRはpass@1評価のaverageで最も高い性能 最良のLatent Reasoning手法のDiscreate Latentよりも良いスコアという主張 COCONUT(Continous Latent Reasoning)手法と比べると大きく改善 Stage-2の学習を抜くと性能は有意にDrop, Stage-2の学習の有効性を示唆 14
Pass@K evaluation • • • • • パズルゲーム Countdownにおける性能評価 CD-4タスクではAR ベースラインと比べて,大きく性能を改善 Pass@100とDiversityも最もよく,多様な回答を出しつつ正解を出せていることの示唆 より難しいCD-5でもARベースラインを上回りつつ,Diversityを改善 タスク特化モデル(MGDM)にも逼迫する性能を出し,pass@kのkを増やすとMGDMも上回る 15
Denoising steps scaling • Denoising stepを増やすと,monotomic に精度が向上 • Denoising stepに関してscale則が現れている 16
Diversity Sampling evaluation • CD-4における,初期ノイズの分散と,多様性ガイダンスの強さの変化による効果を分析 • 初期ノイズの分散と,多様性ガイダンスを強くすると,多様性はMonotomicに上昇 • ただ,pass@100における正答率には最適な点があり,適度な調節が必要 17
Ablations • • Diffusion modelのlossのvariantsのAblation • Target latentの回帰(MSE Loss), x0 prediction, noise-predisction(ε), DDIM-velocity-prediction(v), FMvelocity prediction(u)で比較 • FMのLossで学習するのが最も良さそう ブロックのLatent長のAblation • 長さを増やしていくと,潜在表現の空間が広くなるため再構成誤差は減っていくが,拡散モデルが 探索する空間が増え,精度はmonotomicに上昇しない 18
Qualitative Results Denoisingされるに従って,無関係な文字列からreasoning文に連続的に変化していく 19
Qualitative Results BlockごとにReasoning Sectionの1文に対応している 20
まとめ 結論 • CoT を文ごとに連続潜在に圧縮し、その上で拡散(flow matching)することで、 「意味レベルで自己修正する推論フレームワーク LaDiR」を提案。 • 拡散ステップ数や多様性ガイダンスをいじることで、精度・計算量・多様性をテ スト時に調整できる。 • 数学ベンチマークと Countdown で、AR CoT・既存 latent・既存 diffusion より 高い精度と多様な解法探索を実現。 Limitation • 評価タスクが数学推論+Countdown にほぼ限定されており、一般的な言語タス クへの有効性は未検証。 • 高品質な CoT データと、VAE+拡散+LLM+2段階学習という重めの学習パイプ ラインに依存。 • ブロック分割や潜在次元 L_b、拡散ステップ数など、タスク依存のハイパラ調整 がかなり必要。 21