659 Views
October 08, 25
スライド概要
DL輪読会資料
DEEP LEARNING JP [DL Papers] Diffusion for World Modeling: Visual Details Matter in Atari Ku Onoda, Matsuo Lab http://deeplearning.jp/ 1
書誌情報 • Diffusion for World Modeling: Visual Details Matter in Atari – NeurIPS 2024 Spotlight • 概要 – 拡散モデルを世界モデル(World Model)として初めて採用 • DIAMONDを提案 – 想像上のピクセル空間上でエージェントを学習, Atari環境でSOTA – プレイ可能なゲーム生成(Atari, CS:GO) • プロジェクトページ – https://diamond-wm.github.io/ 2
背景 • 世界モデルでのRLエージェントの学習は安全で、サンプル効率の良い アプローチとして期待される – 従来のWorld Model(Dreamer, IRIS, ...)は離散潜在空間に圧縮してモデリン グ →視覚ディティールが失われることが性能低下につながる • 拡散モデルは画像/動画生成分野で卓越した性能 – 連続ピクセル空間の高品質生成が可能 →拡散モデルを用いて効率的な世界モデルは実現できるのか? 3
世界モデル • 拡散モデルを利用した世界モデル – 観測と、アクションを与える→diffusion modelが次のobservationを出力 • 世界モデルからアクションをサンプリングし軌道の生成を可能に 4
提案手法 • DIAMOND (DIffusion As a Model Of eNvironment Dreams) – 世界モデルに必要な3つの特性を持つ • 環境ダイナミクスの正確なモデリング – 拡散モデルで次観測をピクセルレベルで生成 • RLにおける高速な推論 – 3ステップの逆拡散で高速にロールアウト可能 • Long trajectoriesでの安定性 – EDMを使用→ドリフトや崩壊を抑制 – 主な工夫 • DenoisingにDDPMではなくEDMを使用 • 直近4フレームを入力し行動をU-Netの各層に条件付け 5
DDPMが有用か • RLの学習で高速な推論が要求される→denoising stepを小さくしたい • DDPM→少ないステップではdrifting(誤差)が大きい 6
DDPM, EDMの比較 • DDPM – Objective:ノイズ予測 – 低ノイズ→ノイズ予測 – 高ノイズ→恒等写像を学ぶ(構造的な情報を含まない) • EDM(Elucidated Diffusion Model) – ノイズの強さに応じて、targetを切り替えるように設計されている – 低ノイズ(c_skip→1)→ノイズ予測 高ノイズ(c_skip→0)→画像を予測 7
Denoising Stepsの選択 • EDM – 1ステップでも破綻しにくい • 環境の不確実性が高いと平均的なボヤけた生成となる – 複数ステップにすることで、確率的なモード( 敵が右/左に動くなど)に収束 できる – 3ステップが良かった →多様な想像(multi-modal)をうまく生成できる – それ以上のステップ(4~10) • 1ステップの結果にも及ばない 8
世界モデルの学習 • Loss • アーキテクチャ – 2D Unet – 行動条件付け • 入力 – 直近4フレームをチャンネル方向にスタック(Flame Stacking) • 出力 – 次のフレームを生成 9
RLエージェントの学習 • 仮想環境内でのエージェント学習 • 拡散モデルを環境として使用 – 別ネットワークで報酬と終端を予測(CNN+LSTM) • replay buffer内の実データからサンプリングし、それぞれクロスエントロピー誤差を最小 化 • REINFORCEでポリシー勾配を更新 • 学習済みポリシーを実環境で再評価→データの再収集 10
実験 Atari • Atari – 100k ベンチマーク(26ゲーム) – HNS(人間との相対スコア)を用いた評価でSOTA • 平均HNS 1.46 11
実験 再構成 • 再構成の比較 – DIAMOND • 不整合が発生しづらい – IRIS • 離散トークンに圧縮し、自己回帰Transformerでトークン生成 • 細部の再構成品質が落ち、視覚的な不整合が発生 – (敵と報酬の混同, スコアの描画が不正確, …) 12
実験 ablation study • Denoising step – 1ステップ、3ステップの定量評価 13
実験 CS:GO • CS:GO(シューティングFPS) • 87時間の人間のプレイ動画から学習 • アクション – 移動 (WASD), ジャンプ、視点操作 • 0.1秒ごと(10Hz)に行動をサンプリング • 高解像度(512*512)・滑らかなカメラ挙動を保持可能 14
議論 • 課題 – 学習外のアクションは崩壊 • ジャンプ連打、下を向くなどOOMによる脆弱性 – 長期的なメモリ・一貫性の欠如 • 現状は4フレームのみ参照 – Spatial Drift(誤差の蓄積) – 視覚外の状態遷移の履歴がない • 環境の隠れ状態は明示的にモデリングしていないので、観測のみ生成 – 計算コスト • 低ステップ数(3denoising step)でも高品質になるがDreamer系よりも重い • 報酬/終了は別のモデルで別に学習 • 長期ロールアウトはメモリ、計算コストが線形に増加 15
まとめ • 初めて拡散モデルを用いて世界モデルを構築した – 世界モデルに必要な構造を検証 • 直近4フレームのスタッキング, 行動の条件付け • EDM > DDPM – 有効性を示した • Atari100kでSOTA, 3Dゲーム(CS:GO)でも生成可能 – 同時に複数の課題も上げた • 長期メモリ欠如 • OOD脆弱性 • 潜在状態のモデリング – 後続の研究につながる研究としての位置付け 16