【DL輪読会】Controlling Large Language Model with Latent Actions

>100 Views

January 15, 26

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

ダウンロード

関連スライド

各ページのテキスト
1.

Controlling Large Language Model with Latent Action Yuya IMAI, Matsuo Iwasawa Lab | 1 of 17

2.

書誌情報 タイトル: Controlling Large Language Model with Latent Action 会議: ICML 2025(Poster) 著者: Chengxing Jia, Ziniu Li, Pengyuan Wang, Yi-Chen Li, Zhenyu Hou, Yuxiao Dong, Yang Yu TL;DR: LLMの行動をトークンではなく少数の離散潜在アクションで制御するCoLAを提案し、探索空間を圧縮して RLを効率化する。数学推論やエージェントタスクで性能向上し、Reward Hackingにも比較的頑健。 リンク: https://openreview.net/forum?id=cEKrGCFXPA https://proceedings.mlr.press/v267/jia25e.html https://arxiv.org/abs/2503.21383 https://github.com/LAMDA-RL/CoLA https://huggingface.co/LAMDA-RL/Llama-3.1-CoLA-10B | 2 of 17

3.

背景と課題 LLMを特定のタスクに適応させるために、強化学習(RLHF / RLAIF など)を使う手法は一 般的だが、従来の定式化には課題がある 課題1: 行動空間が巨大すぎる(探索が非効率) LLMが次に出力する各トークンをそのまま行動として扱う 近年のモデルは語彙数が非常に大きい(例: Llama-3 系は 12万語彙以上) → 1ステップの分岐が大きく、探索・クレジット割当が難しいため サンプル効率が悪い 課題2: 構造の欠如 LLMは本来「次トークン予測(next-token prediction)」の生成モデルであり、RLエージェントとして設計されてい ない そのため、報酬に合わせて望ましい振る舞いを安定に制御しづらい | 3 of 17

4.

提案手法: CoLA (Controlling LLMs with Latent Actions) 離散潜在変数を用いた階層的な条件付き生成モデルに拡張し、強化学習の探索空間を削減 主な構成要素: 1. Language World Model (fworld ) 2. Policy Model (π) 3. Inverse Dynamics Model (finverse ) ​ ​ Naive Decoder-only Pipeline CoLA Pipeline | 4 of 17

5.

フレームワーク詳細 1. Language World Model (fworld ) ​ 入力: 過去コンテキスト x1:t , 潜在アクション at Merge moduleを使って、事前学習済みLLMの埋め込みに、潜 在アクションを埋め込みとして注入 出力: 次トークン xt+1 の分布 ​ ​ ​ 2. Policy Model (π) 入力: 過去コンテキスト x1:t 出力: 潜在アクション at の分布 ここをRLで学習することで、LLM本体を大きく変更せずに制御 ​ ​ | 5 of 17

6.

フレームワーク詳細(続き) 3. Inverse Dynamics Model (finverse ) ​ 入力: 過去コンテキスト x1:t と次トークン xt+1 出力: (その遷移を説明する)潜在アクション at の分布 教師なし学習で潜在アクションを抽出するために使用 サイズN のコードブック C (実験ではN = 64)を使用 ​ ​ ​ 全体の推論プロセス Step 1: at ∼ π(⋅∣x1:t ) Step 2: xt+1 = fworld (x1:t , at ) ​ ​ ​ ​ ​ ​ | 6 of 17

7.

学習プロセス 更新対象の整理 (再掲) fworld : (x1:t , at ) ↦ p(xt+1 ) π : x1:t ↦ p(at ) finverse : (x1:t , xt+1 ) ↦ p(at ) ​ Inverse Dynamics Model: θinverse Language World Model: θworld = (θbase , θmerge ) Policy Model: θpolicy ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ 全体像 段階1: 潜在アクション制御の構築(更新: θinverse , θmerge , θpolicy ) 段階2: 指示追従のための世界モデル調整(更新: θbase ) 段階3: 潜在アクション空間でのRL(更新: θpolicy ) ​ ​ ​ ​ ​ | 7 of 17

8.

学習プロセス: 段階1 Constructing Latent Action Control (再掲) ​ 1-1. Inverse Dynamics Model(+Merge module)の学習 Inverse Dynamics Model(finverse )で潜在アクション at を推定 x1:t と at から Language World Model(fworld ) で xt+1 を予測し、予測誤差を最小化 更新: θinverse , θmerge / 凍結: θ^base ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ fworld : (x1:t , at ) ↦ p(xt+1 ) π : x1:t ↦ p(at ) finverse : (x1:t , xt+1 ) ↦ p(at ) ​ ​ 1-2. Policy ModelをBehavior Cloning(BC)で初期化 目的: RL前に安定な初期Policy Modelを作る Policy Model(π)で潜在アクション at を推定 Inverse Dynamics Model(finverse )の at を擬似ラベルにして予測誤差を最小化 ​ ​ 更新: θpolicy / 凍結: θ^inverse ​ ​ ​ | 8 of 17

9.

学習プロセス: 段階2 Fine-Tuning under Action Guidance (再掲) ​ 目的: Language World Model(fworld )を指示追従データに適応させる。 Inverse Dynamics Model(finverse )を固定して、出力 at を使用する。 x1:t と at から Language World Model(fworld ) で xt+1 を予測し、予測誤差を最小化 ​ ​ ​ fworld : (x1:t , at ) ↦ p(xt+1 ) π : x1:t ↦ p(at ) finverse : (x1:t , xt+1 ) ↦ p(at ) ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ 更新: θbase / 凍結: θ^inverse , θ^merge base更新後にPolicy Modelを再度Behavior Cloningで更新 baseによる埋め込みが変わるため ​ ​ ​ | 9 of 17

10.

学習プロセス: 段階3 Latent Action Reinforcement Learning(RL) (再掲) fworld : (x1:t , at ) ↦ p(xt+1 ) π : x1:t ↦ p(at ) finverse : (x1:t , xt+1 ) ↦ p(at ) ​ 与えられるもの: prompt-only データ Drl = {x1:p } と報酬モデル R(x1:T ) 固定: θ^world , θ^inverse (生成の言語能力は世界モデル側に保持) 更新: θpolicy (潜在アクション選択のみ学習) Roll-out(生成) ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ ​ xt+1 ∼ pworld (⋅∣x1:t , at , θ^world ) at ∼ πθpolicy (⋅∣x1:t ), ​ ​ ​ ​ ​ ​ ​ ​ ​ 目的関数 max E[R(x1:T )] ​ θpolicy ​ ​ 実装上は 初期Policyを参照モデルとして、潜在アクション空間でKLを計算して制約・正則化 RL更新はPPO以外にも、GRPO / RLOO / ReMax / REINFORCE++ などのLLM向け手法が選択肢 | 10 of 17

11.

実験結果: 意味的多様性の向上 検証内容: 潜在アクションによる制御が生成テキストの多様性に どう影響するか 検証データから 複数の prefix(過去コンテキスト)をランダ ムに選び、各prefixに対して複数の生成結果を得る BGE-M3埋め込みの cos類似度の総和の逆数 結果 (Fig 2): Latent Action Sampling (青) が高い多様性を示す Random action sampling:latent action をランダムにサ ンプルして world model で生成 Base model sampling:ベースLLM(Llama-3.1-8B)で 通常生成 Random token sampling:トークンをランダムにサンプル して生成 事前学習トークン数が増えるほど多様性は向上 (赤) | 11 of 17

12.

実験結果: 数学推論タスク 検証内容: 数学推論タスクでの性能 結果 (Fig 3): CoLA は Baseline (Llama-3.1-8B SFT) を上回る性能 特にPass@Kにおいて探索能力の高さが示された。RL後、Math500で 42.4 (Baseline 38.2) を達成 Benchmarks Pass@K on Math500 | 12 of 17

13.

効率的な探索: Action-level MCTS MCTS-Q: 潜在アクション空間上でのモンテカルロ木探索 (MCTS) 潜在アクション空間が小さいため、トークン単位よりも探索が 効率的 Q関数(Qwen-Math-2.5-72B reward model)に基づく枝刈り を導入 結果: Math500において、MCTS-Q (CoLA) は 68.2 を達成 Baseline + MCTS 63.2(Baseline + MCTS-Q 63.0、 CoLA + MCTS 65.4)を上回る Math500 Score Comparison Baseline (SFT): 38.2 CoLA (RL): 42.4 Baseline + MCTS: 63.2 Baseline + MCTS-Q: 63.0 CoLA + MCTS: 65.4 CoLA + MCTS-Q: 68.2 | 13 of 17

14.

エージェントタスク: Countdown Game タスク: 与えられた数を使って目標値を計算する。思考過程( <think> )と回答( <answer> )のフォーマット厳守。 結果 (Fig 4): CoLAはBaselineよりも早くフォーマット報酬を獲得(約2倍速) ただしこの設定では正答率は10–15%程度と限定的で、正しく解くのは難しい Reward Curve Response Length | 14 of 17

15.

エージェントタスク: Alfworld & Scienceworld マルチターンRLタスクでの性能検証 結果 (Table 1): CoLA-RLはBaseline-RLと比較して、Seen/Unseenタスクの両方で大幅な性能向上 複雑な環境での探索と適応において優位性を示す BENCHMARK ALFWORLD (Seen) BASE-SFT BASE-RL CoLA-FTA 68.6 68.6 (+0.0) 75.7 ALFWORLD (Unseen) 67.9 71.6 (+3.7) 70.9 CoLA-RL 77.9 (+2.2) 74.6 (+3.7) SCIENCEWORLD (Seen) 17.0 18.0 (+1.0) 24.7 SCIENCEWORLD (Unseen) 17.5 15.6 (-1.9) 20.4 28.4 (+3.7) 21.8 (+1.4) | 15 of 17

16.

Reward Hackingへの頑健性 検証内容: 不完全な報酬モデルを用いたRLHFにおけるReward Hackingの影響 結果 (Fig 5): CoLAはKL制約が弱い場合(KL = 0.00)でも、Baselineに比べてReward Hackingに強い Baselineは意味のない質問を繰り返すなどの縮退が見られたが、CoLAは回答能力を維持 Policy Modelのみを学習し、Language World Model (言語能力) を固定しているため頑健 Win rate vs Baseline Win rate (KL=0 vs KL=0.01) | 16 of 17

17.

まとめ CoLA (Controlling Large Language Models with Latent Actions) LLMを「Policy Model」と「Language World Model」に分離 巨大なトークン空間ではなく、コンパクトな潜在アクション空間でRLを行う 利点 探索効率の向上: Math500やエージェントタスクで高性能 高い制御性: MCTSなどの探索アルゴリズムとの親和性が高い 頑健性: Reward Hackingに対して強く、言語能力を維持しやすい 今後の展望 より多様なBase Modelでの検証 さらに複雑なタスクへの応用 | 17 of 17