[DL輪読会]Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies

>100 Views

August 06, 21

スライド概要

2021/08/06
Deep Learning JP:
http://deeplearning.jp/seminar-2/

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies [DL Papers] Hiroki Furuta http://deeplearning.jp/

2.

書誌情報 • タイトル: Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies • 著者・所属: Paul Vicol12, Luke Metz2, Jascha Sohl-Dickstein2 ➢ 1University of Tronto, 2Google Brain • URL: http://proceedings.mlr.press/v139/vicol21a.html • 概要: RNNやメタ学習に現れるunrolledな計算グラフは勾配の計算に 課題が多かった。Persistent Evolution Strategiesでこれらを解決した。 ICML2021 Outstanding paper 2

3.

Unrolled Computational Graph • RNNやメタ学習などに見られる • (一口に言うと)計算グラフにループ構造が入る場合に、それらを無理 やりforward pathのみに直す [解説記事(英語)] • forward path: ある関数 𝑓で𝑠𝑡 を𝑠𝑡+1 に変換する ➢ 𝑓の例: RNNのforward path, メタ学習のoptimizer (SGD, RMSprop, etc.) 3

4.

事前知識 • 目的関数: ➢ 各ステップのLossが前のstepのstate 𝑠𝑡−1 に依存している • BPTTなど ➢ 逆伝播の計算を𝐾 ≪ 𝑇までで打ち切ることで計算量を減らすアプローチがある (Truncation) • 進化計算(Evolution Strategies; ES) ➢ 有限差分を用いてblack-boxな関数の勾配を推定する 4

5.

Unrolled Computational Graphの課題 • メモリ使用量がunroll length(𝑇)に対して線形に増えてしまう • 𝑇stepのunrollそれぞれに対してパラメータを1回しか更新できないた め、計算量が多い。またパラメータの更新が遅くなる • unroll length(𝑇)が長いほど勾配爆発や勾配消失が起こりやすくなる ➢ 特にメタ学習で顕著 • 計算量やメモリ使用量の課題に対してはTruncationで対処することが 多い ➢ Truncated Backprop Through Time (TBPTT) [Tallec & Ollivier 2017] 5

6.

Unrolled Computational Graphの課題 • Truncationの問題は、逆伝播の計算をsub-sequenceに分割すること によって、biasedな勾配(真の勾配と一致しない)になること ➢ 例えば短期間の依存性しか捉えられなくなるなどの影響が出る • またUnrolled Computational Graphの損失関数のlandscapeは非連続 やカオスな構造になっていることが多く、最適化が難しい [Metz et al. 2019] ➢ ESにより、そのような損失関数のunbiasedな勾配の推定量を計算することは できるが、計算時間の観点から結局Truncationを用いることが多い 6

7.

関連研究 • Real-time recurrent learning (RTRL) [Williams & Zipser 1989] ➢ BPTTと異なりオンラインでパラメータ更新ができる ➢ truncation biasはないが、メモリ&計算量、勾配の分散が大きい、実装が複雑、 限られたclassのモデルにしか適用できないなどの課題がある ➢ Tallec & Ollivier 2017, Mujika et al. 2018, Benzing et al. 2019, Cooijmans & Martens 2019 • Hyperparameter optimization (HO) ➢ (1) black-box, (2) grey-box, (3) gradient-based なアプローチがある ➢ (1) Bergstra & Bengio 2012, Snoek et al. 2012, Salimans et al. 2018, (2) Swersky et al. 2014, Jamieson & Talwalkar 2016, Li et al. 2017, Jaderberg et al. 2017, Lorraine & Duvenaud 2018, MacKay et al. 2019, (3) Lorraine et al., 2020, Maclaurin et al. 2015 7

8.

Persistent Evolution Strategies • Notation: • ESで • を近似する; 𝑃 パラメータ数 平均0分散𝜎 2 のGaussianからサンプルされる摂動 ※ Supplementary Material に具体例が 載っていてわかりや すい 8

9.

Persistent Evolution Strategies • 𝑁個のサンプルによるMonte-Carlo推定で勾配を求める • 各stepでparticleを加える & perturbationのaccumulationがポイント 9

10.

ES vs PES • アルゴリズム 10

11.

Persistent Evolution Strategies • ESはtruncated unrollsごとに𝑠を初期化する • PESはtruncated unrolls間で同じ𝑠を使う ➢ 直感的な解釈としては過去の履歴情報が𝑠に保存される 11

12.

PES is Unbiased for Quadratic Losses • 損失関数が二次形式(quadratic)であるとPESの勾配はunbiasedになる 12

13.

PES is Unbiased for Quadratic Losses • 続き 13

14.

実験 • 勾配の分散 ➢ particleの数(𝑁): 大 → 勾配の分散: 小 • Meta-objective surface ➢ 2D regressionのトイタスク 14

15.

実験 • Influence Balancing task ➢ 最適化の際にshort-horizon biasに敏感なタスク ➢ PESはtruncation biasを解消できている 15

16.

実験 • Learned Optimizer Meta-Optimization ➢ MLP-based learned optimizerをCIFAR-10の分類問題について最適化 ➢ inner-stepは𝑇 = 1000, truncation lengthは𝐾 = 4 MNISTの分類タスクでlearning rateとDecayのmetaoptimization; 右: training loss, 左: validation accuracy 16

17.

実験 • Continuous Control ➢ RLアルゴリズム (SACなど)でも最大で150ぐらい 17

18.

まとめ • Unrolledな計算グラフにおいて、unbiasedな勾配の推定を可能にした Persistent Evolution Strategies (PES)を提案した。 • これまではbiasedな勾配しか得られないことが課題だったtruncated unrolls (例: RNNのtruncated BPTTなど)からでも、unbiasedな勾配 を高速に求めることができる • 実験でRNNに近いタスク、ハイパーパラメータ最適化、強化学習、メ タ学習など幅広いタスクに適用可能であることを実証。 18