【DL輪読会】GIT RE-BASIN: MERGING MODELS MODULO PERMU- TATION SYMMETRIES

268 Views

November 11, 22

スライド概要

2022/11/11
Deep Learning JP
http://deeplearning.jp/seminar-2/

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP “GIT RE-BASIN: MERGING MODELS [DL Papers] PERMU- TATION SYMMETRIES” MODULO 発表者:岩澤有祐 http://deeplearning.jp/

2.

書誌情報 “Git Re-Basin: Merging Models Modulo Permutation Symmetries” , ICLR 2023 Under review • 著者:Samuel K. Ainsworth, Jonathan Hayase, Siddhartha Srinivasa (University of Washington) • 概要 – なぜSGDが毎回同じような性能を達成するのか? – SGDで到達されるほとんどの解はPermutationを除いて Linear Mode Connectedであるから(右図) • SGDとLMCに関連する論文として下記も簡単に紹介 – “Unmasking the Lottery Ticket Hypothesis: What's Encoded in a Winning Ticket's Mask?” ※他に断りがない限り本資料の図表は当該論文より抜粋 2

3.

Outline • 前提知識:Linear Mode Connectivity – “Linear Mode Connectivity and the Lottery Ticket Hypothesis”,ICML2020 • “Git Re-Basin: Merging Models Modulo Permutation Symmetries” • “Unmasking the Lottery Ticket Hypothesis: What's Encoded in a Winning Ticket's Mask?” 3

4.

Instability, Error Barrier • ある初期値W0から開始 • 異なるノイズ(サンプルの順序など)を 加えてW0から2つの重みを作る • W1とW2の間を線形補間したときの 性能の劣化がError Barrierと呼ぶ Linear Mode Connectivity and the Lottery Ticket Hypothesisより抜粋 4

5.

Barrierの実際の例 • 別のデータセット,別のアーキテクチャを初期値から初めて学習した ときのBarrierの可視化(左はBarrierなし,右はBarrierあり) Linear Mode Connectivity and the Lottery Ticket Hypothesisより抜粋 5

6.

Barrierの図示 W0 W3 W2 W1 • W1とW2は間を補完しても性能が下がる点がない(同じ局所解周辺) • W1とW3は間を保管すると性能が下がる(異なる局所解周辺) => Barrierがない状況はSGDがノイズに対して頑健(同じ解周辺に到達)を意味 6

7.

Linear Mode Connectivity Definition: ε-Linear Mode Connected (LMC) ある2つの重み𝒘𝟏 ,𝒘𝟐 が次の性質を満たすときLMCと呼ぶ. 𝐿 𝛼𝒘𝟏 + 1 − 𝛼 𝒘𝟐 ≤ 𝛼𝐿 𝒘𝟏 + 1 − 𝛼 𝐿 𝒘𝟏 Definition: Error Barrier 上記を満たす最小のεを𝒘𝟏 ,𝒘𝟐 のError Barrierと呼ぶ 7

8.

様々なアーキテクチャにおけるError Barrier • 別のデータセット,別のアーキテクチャを初期値から初めて学習したときのBarrierの可視化 • 簡単なタスク,単純なモデルではError Barrierがほぼゼロ • ResNet等では初期値から始めるとBarrierが存在 Linear Mode Connectivity and the Lottery Ticket Hypothesisより抜粋 8

9.

RewindingとLMC • 学習を最初からではなく途中からやりなおすことをRewindingと呼ぶ • 大きめのモデルでも学習の途中でError Barrierがなくなる => SGDは学習途中からは安定に同じ局所解にたどり着いている Linear Mode Connectivity and the Lottery Ticket Hypothesisより抜粋 9

10.

Outline • 前提知識:Linear Mode Connectivity – “Linear Mode Connectivity and the Lottery Ticket Hypothesis”,ICML2020 • “Git Re-Basin: Merging Models Modulo Permutation Symmetries” • “Unmasking the Lottery Ticket Hypothesis: What's Encoded in a Winning Ticket's Mask?” 10

11.

Permutation symmetries of Neural Networks • NNの重みは入れ替えても機能的には不変 𝑧𝑙+1 = 𝑃 𝑇 𝑃𝑧𝑙+1 = 𝑃 𝑇 𝑃𝜎 𝑊𝑙 𝑧𝑙 + 𝑏𝑙 = 𝑃 𝑇 𝜎 𝑃𝑊𝑙 𝑧𝑙 + 𝑃𝑏𝑙 • σ:活性化関数 • P:Permutation Matrix 11

12.

“The Role of Permutation Invariance in Linear Mode Connectivity of Neural Networks”, arxiv, 2021 Conjecture “Most SGD solution belong to a set whose elements can be permuted so that no barrier exists on linear interpolation between any two permuted elements” 図は“The Role of Permutation Invariance in Linear Mode Connectivity of Neural Networks”より抜粋 12

13.

参考:# Permutation Symmetries • 取りうるPermutation Symmetriesは膨大(前述の予測の厳密な検証は困難) “Git Re-Basin: Merging Models Modulo Permutation Symmetries”より抜粋 13

14.

Permutation Selection Method • 方法1:Matching Activations – データが必要だがSolverがある • 方法2:Matching Weights – データは必要ないが素朴には解けないので層ごとに行う • 方法3:Straight Through Estimator Matching 14

15.

Permutation後のError Barrier • 4つのデータ,モデルでの検証(ザックリ右に行くほど難しい) – 右2つは効率性の観点からWeight Matching (緑)のみを検証 • いずれもPermutationの修正のみでBarrierがあった2つのモデルのBarrierが大幅に減少 – ※ MNIST, MLPとかは既存ではそもそもBarrierないことになっている気がするがそれは不明 • Permutation方法はざっくりSTE >= Weight > Activation • Weight Matchingはmere secondくらいで発見できるらしい 15

16.

NNの幅とPermuted Error Barrier • 幅(フィルタ数)を変更したときのPermutation後のError Barrier • 幅を大きくすることが重要 16

17.

参考:Permutation前のBarrierと幅と深さの関係 • 幅(フィルタ数)を大きくするとそもそもError Barrierは減る (SGDが同じ解に到達しやすくなる) • 深さは増やすとError Barrierは大きくなる 図は“The Role of Permutation Invariance in Linear Mode Connectivity of Neural Networks”より抜粋 17

18.

LMCは何によって生じるのか • MLPをMNIST(左)とCIFAR-10(右)で学習した際のBarrierの推移 • Loss Barrierは学習が進むに連れて小さくなる ※ 厳密な記載がないが多分Weight MatchingでPermutationを戻している • モデルアーキテクチャ自体によって引き起こされているのではなくSGDによるバイアス 18

19.

別データの重みをPermutation後にモデルをマージ • 普通に別のデータで学習した重みを平均化すると性能は劣化する • Weight Matching後の重みは平均化する事により性能が若干向上する • ただし,普通にアンサンブルしたり全データで訓練する場合よりは性能落ちる 19

20.

まとめ • SGDが学習する解はPermutation Symmetryを除き同じ局所解 とLocally Connectedであるという仮説 [Entezari+2021] • 本論文は膨大なPermutation空間を効率的に探索する方法を提 案し,上記仮説を検証 • クラス分類タスクにおいてはある程度妥当性があることを検証 20

21.

議論 • クラス分類以外での不安定性との関連 • 実際には学習はLMCが発生するより幅が狭いネットワークでも起こり, かつ深いネットワークでも起こっている? – Deepがよくうまくいくことの説明にはあまりなっていない • Permutation Symmetry以外のInvarianceが学習に与える 影響 – 層を跨いだマッチング [Nguyen+2021] – Re-scaling Invariance [Ainsworth+2018] 21

22.

Outline • 前提知識:Linear Mode Connectivity – “Linear Mode Connectivity and the Lottery Ticket Hypothesis”,ICML2020 • “Git Re-Basin: Merging Models Modulo Permutation Symmetries” • “Unmasking the Lottery Ticket Hypothesis: What's Encoded in a Winning Ticket's Mask?” 22

23.

書誌情報2 “Unmasking the Lottery Ticket Hypothesis” , ICLR 2023 Under review • 著者 – Mansheej Paul, Feng Chen, Brett W. Larsen, Jonathan Frankle, Surya Ganguli, Gintare Karolina Dziugaite – Stanford, Metaなど • 概要 – 宝くじ仮説の実験では,もとよりかなり小さいパラメータで同等の性能を達成するサブ ネット(Matching Networks)が存在することが示唆されている – ただし,小さなサブネットの発見はOne-Shotでは行えず,IterativeなPruningが必要 – かつ,Iterative Pruningの際に重みを初期値に戻す必要がある(Rewinding) – これらがなぜ必要なのかについてLMCの観点から考察 23

24.

Iterative Magnitude Pruning 1. NNをランダムに初期化(𝜽𝟎 ). 𝒎𝟎 = 𝟏 2. NNを一定イテレーション訓練(𝜽𝒋 ) 3. For i in 0…L 1. 𝒎𝒊+𝟏 ⊙ 𝜽𝒋を訓練 2. 重みの大きさ下位α%を刈り取るマスク𝒎𝒊+𝟏を作成 4. 最終的な𝒎𝑳 ⊙ 𝜽𝒋 を訓練する ※上記の手続きで訓練したサブネットが普通に訓練したNNと同程度の 正解率を達成する場合Matching Networksと呼ばれる 24

25.

IMP from LMC • 各Levelでαだけ重みを残す場合のIterative Pruningの模式図 • 本論文では,各レベルでのMatching Networksがその前のNetworksとLinearly ConnectedであることがIMPの成功に重要であることを検証 25

26.

実験結果の抜粋 • 左:各レベル間でのLoss Barrier.マッチングに成功している場合(緑)はLevel間 でLoss Barrierがない!中央は拡大図. • 右:すべてのLevelでのLoss Barrier.全ペアでLoss Barrierがないわけではない 26

27.

なぜこのようになるのか? • IMPで得られる摂動と同程度のランダムな摂動を加えた際のError Barrierの比較 • 重みのPruneだけではなくランダムな摂動に対しても同様にLMC • SGDの頑健性がLMCを引き起こしている 27

28.

どの程度刈り取っていいいのか • ざっくりいうとパラメータの曲率とProjectionによって発生する距離に依存して最大 Prune Rateが決まる • 完全にではないがMatchingの成否をある程度予測できる 28

29.

なぜRewindが必要なのか • Fine-Tuning:各レベルで重みと学習率を引き継ぐ • Learning Rate Rewinding:各レベルで重みのみ引き継ぎ学習率は戻す • Weight Rewinding:各レベルで重みも学習率も引き継がない • Fine-Tuningだけ小さい値の重みが少ない => 刈り取ったときの影響(曲率)が大 きくなり,Pruningに失敗する 29

30.

まとめ • Winning Ticketの発見に使われるIMPがなぜ必要なのかについて LMCの観点から分析 • (1)Pruningの各レベルで得られる解は前のレベルとLMC.ただし すべてのペアがつながっている訳では無い. • (2)これはSGDの頑健性により起きている. • (3)Rewindが必要なのはRewindをしないと値が小さなパラメータ がいなくなり削ることにより距離が大きく離れてしまうから 30