【DL輪読会】”OMNIGROK: GROKKING BEYOND ALGORITHMIC DATA” ICRL2023

414 Views

February 09, 24

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP ”OMNIGROK: GROKKING BEYOND ALGORITHMIC DATA” ICRL2023 [DL Papers] Kensuke Wakasugi, Panasonic Holdings Corporation. http://deeplearning.jp/ 1

2.

書誌情報 ◼ タイトル: OMNIGROK: GROKKING BEYOND ALGORITHMIC DATA ◼ 著者・所属: • Ziming Liu, Eric J. Michaud & Max Tegmark • Department of Physics, Institute for AI and Fundamental Interactions, Massachusetts Institute of Technology ◼ その他情報: • ICRL2023 • Omnigrok: Grokking Beyond Algorithmic Data | OpenReview ◼ 選書理由 • grokking現象に興味。過去からの流れも含めて調査 ※特に記載しない限り、本資料p18迄の図表は上記論文からの引用です。 2

3.

3 関連研究 Grokking:過学習後の遅延した急激な汎化現象 [1]Power et al., 2022 • • OpenAI、Google Workshop, ICLR 2021 [2]Liu et al., 2023 MIT ICRL2023 notable top 25% • • アルゴリズミックに生成されたデータにおいて、 過学習後、急激に汎化性能が向上する現象を 報告。Transformerベース。 現象発現には、weight decayが最も効果的 [3] Levi et al., 2024 weight normに着目 初期normが大きい場合に、 grokkingが発生 Tel-Aviv University ICLR 2024 poster • 1. 2. 3. [2201.02177] Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets (arxiv.org) Omnigrok: Grokking Beyond Algorithmic Data | OpenReview Grokking in Linear Estimators -- A Solvable Model that Groks without Understanding | OpenReview grokking発現にかかる時間を 理論的に解析

4.

4 関連研究 多様体上での遷移として解釈 [4] Millidge, 2022. • • • 4. 5. Oxford Blog 個人Blog 精度100%の多様体の形状が重要 パラメタの過不足、NTKなども絡めて 全体像をとらえやすい [5] Thilak, et al. 2022 • • • 最適化アルゴに着目 損失局面からはじき出され、相転移する現象として理解 Lossのスパイクに対応。最終層のnormが段階的に増加 Grokking 'grokking' (beren.io) [2206.04817] The Slingshot Mechanism: An Empirical Study of Adaptive Optimizers and the Grokking Phenomenon (arxiv.org)

5.

5 先行研究:[1]Power et al., 2022 tranformerでのgrokking現象を報告。weight decayが最も効果的 [1]Power et al., 2022より引用 ◼ タスク例 ◼ transformer • • 2 layers, width 128, and 4 attention heads 4・105 non-embedding parameters

6.

6 先行研究:[1]Power et al., 2022 特徴空間において、良好な構造を学習 • • 左図:対称群に関する特徴空間 右図:モジュラー加算に関する特徴空間 • 適当なサブセット(同色)が近接している

7.

書誌情報 再掲 ◼ タイトル: OMNIGROK: GROKKING BEYOND ALGORITHMIC DATA ◼ 著者・所属: • Ziming Liu, Eric J. Michaud & Max Tegmark • Department of Physics, Institute for AI and Fundamental Interactions, Massachusetts Institute of Technology ◼ その他情報: • ICRL2023 • Omnigrok: Grokking Beyond Algorithmic Data | OpenReview ◼ 選書理由 • grokking現象に興味。過去からの流れも含めて調査 ※特に記載しない限り、本資料p18迄の図表は上記論文からの引用です。 7

8.

8 イントロ 様々なデータでgrokking現象を検証。weight normで統一的に説明 Q1 Grokkingの起源は何か:過学習後に遅れて汎化するのはなぜか? A1 訓練損失とテスト損失の振る舞いの違い。“LUメカニズム“で説明 Q2 Grokkingの普遍性:アルゴリズミックなデータ以外でも生じるのか? A2 生じる。急峻さは減少するが、画像分類、感情分析、分子特性予測でも確認。

9.

9 LUメカニズム weight direction/normに分けて考える。normについて、訓練/テスト損失の形状がLU • weight norm一定の下で、 最適なパラメータをw*とおく • wに関して、 訓練損失の形状が「L」 テスト損失の形状が「U」 • • wc<wの時、勾配が小さく最適化が遅延 訓練/テスト双方の最適化は一致しており、 特定の条件下で最適化が遅いため grokkingが生じていると主張

10.

10 TOYモデル TEACHER-STUDENT で検証 ◼ アーキテクチャ • 5-100-100-5 MLP、tanh activation • 教師、生徒双方同じ。初期化乱数は別 教師 ガウス乱数 ◼ データ • input:ガウス乱数 • output:教師NNの出力 ◼ 学習 • Adam(learning rate 3*10-4) • 105Step 出力結果 生徒 ガウス乱数 出力結果

11.

11 TOYモデル実験結果 weight normによって、grokkingの有無、発生時間が変化 • • w0=テストNNの初期ノルム αでスケーリング、教師NNはα=1 • weight normに対し、 LU形状を確認 testのgrokking時間がweight decayの逆数に比例 •

12.

12 MNIST データ数が少ない場合はgrokkingが発生 • • depth-3 width-200 MLP、ReLU AdamW、MSE loss、one-hot targets • • データ数Nも振ったうえで検証 データ数増加でgrokking解消 (初めから性能がよいため、発生しない)

13.

13 感情分析 少データ、大きなweight normでわずかにgrokkingを確認 • • • • LSTM model、classification padding to length 500 two layers, embedding dim 64, hidden dim 128 Adam、learning rate 0.001、binary cross entropy loss • • • はっきりとしたgrokkingは確認できず weight decayなしでも訓練でU形状になる 暗黙的な正則化を含んでいるか

14.

14 分子、グラフNN 感情分析(LSTM)と同様の結果 • 少データ、大きなweight normの条件でgrokkingを確認

15.

アルゴリズミックなタスク 整数の足し算を検証。乱雑さを表すパラメタmを導入 ◼ セットアップ • 整数の足し算:特徴表現としては、ベクトルの足し算+decoder(MLP) • ガウス乱数とone-hot表現を用意。表現学習を内挿領域に限定。m:messiness • MLPのwと、乱雑さmを学習 15

16.

16 アルゴリズミックなタスク 乱雑さmの最適化が遅い。データ増で勾配増加 • • 一般的な最適化はA~Eの順に進行 D→E工程では、勾配が小さく進行が遅い

17.

17 アルゴリズミックなタスク Transformerのweight normでも、整合する傾向を確認 • • 訓練データの過学習時にnorm増大 grokking時にnorm減少 • normを制約すると、grokkingが生じない

18.

18 MNISTとの対比 特徴表現に依存しないため、grokkingが生じない ◼ セットアップ • • N=60000使用 生データと線形表現(全ピクセル=ラベル)を用意。表現学習を内挿領域に限定。m:messiness • • 乱雑さに関係なく汎化性能が高い grokkingが生じない

19.

Discussion、Conclusion ◼ Discussion • 適切な表現学習が必要な場合にgrokkingが生じている しかし、言語モデルでは見受けられない。その理由は? (1)最適な表現であっても、乱雑な可能性 (2)事前学習でよい表現が学習され、回避されている ◼ Conclusion • grokkingは、訓練/テスト損失の振る舞いの違い(“LUメカニズム“)によって生じる • 多様なデータでgrokkingを確認したが、最も顕著なのはアルゴリズミックなタスク • grokkingの程度は、そのタスクの表現学習の重要性に依る 19

20.

書誌情報2 ◼ タイトル: Grokking in Linear Estimators – A Solvable Model that Groks without Understanding ◼ 著者・所属: • Noam Levi, Alon Beck & Yohai Bar Sinai • Raymond and Beverly Sackler School of Physics and Astronomy Tel-Aviv University ◼ その他情報: • ICLR 2024 poster • Grokking in Linear Estimators -- A Solvable Model that Groks without Understanding | OpenReview ※特に記載しない限り、本資料p34迄の図表は上記論文からの引用です。 20

21.

関連研究 21 再掲 Grokking:過学習後の遅延した急激な汎化現象 [1]Power et al., 2022 • • OpenAI、Google Workshop, ICLR 2021 [2]Liu et al., 2023 MIT ICRL2023 notable top 25% • • アルゴリズミックに生成されたデータにおいて、 過学習後、急激に汎化性能が向上する現象を 報告。Transformerベース。 現象発現には、weight decayが最も効果的 [3] Levi et al., 2024 weight normに着目 初期normが大きい場合に、 grokkingが発生 Tel-Aviv University ICLR 2024 poster • 1. 2. 3. [2201.02177] Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets (arxiv.org) Omnigrok: Grokking Beyond Algorithmic Data | OpenReview Grokking in Linear Estimators -- A Solvable Model that Groks without Understanding | OpenReview grokking発現にかかる時間を 理論的に解析

22.

22 概要 Grokkingをランダム行列理論を用いて解析。遅延時間を解析的に導出 • 教師-生徒モデルにおいて、ランダム行列理論を用いて解析 線形ネットワークで解析。非線形NNにも一部拡張。 • grokkingは、 “interesting”な汎化ではなく、訓練損失の低速な減衰によって生じる。 その遅延時間は、入力の次元とデータ数の比によって決まる • 初期weightのスケールは、汎化損失減少に至る必要精度を増加し、 grokking発生時間を遅らせる • 出力次元の増加とともに、grokkingまでの時間が増加する • L2正則化は、grokkingを抑制する

23.

訓練損失と汎化損失 訓練損失と汎化損失の減少速度を解析 ■入力データ ■訓練損失 ■汎化損失 • • 線形変換関数を考える。Tx、Sxがそれぞれ教師・生徒の出力 → Tを固定して、SをTに近づけるように学習。 ランダム行列理論を用いて、損失の減衰の仕方を算出。 → それぞれの収束時間、精度を解析的に算出 ※ 訓練損失→経験分布、汎化損失→理論分布 という関係? 23

24.

訓練損失と汎化損失の時間発展 訓練損失についての微分方程式を解き、時間発展を導出 ■訓練損失の微分 ■微分方程式として解く ■訓練損失と汎化損失の時間発展導出 24

25.

Marchenko-Pastur (MP) distribution 訓練データサンプルへの依存性をMP分布に代替し、消去 ■訓練データのグラム行列の固有値が、MP分布に従う(らしい) ■ MP分布に関する期待値計算で、各種近似を使い訓練/汎化損失が求まる 25

26.

26 訓練/汎化損失の差 精度95%到達時間の差を解析的に導出 ■回帰モデルにおける識別精度の定義 ■訓練/汎化損失の差 ■精度95%到達時間の差 ※この他、weight normのスケーリングや ラベルノイズの影響も解析的に導出可能

27.

27 導出結果の妥当性 解析結果と実験結果が整合 • λ=0.1 din<Ntr grokking無し(汎化〇) • λ=0.9 din≒Ntr grokking有り(汎化〇) • λ=1.5 din>Ntr grokking無し(汎化×) • grokking発生までの遅延時間の λ依存性も解析結果と一致

28.

28 doutの影響 遅延時間が最大化するdoutが存在

29.

weight decayの影響 γが小さいとき、 grokking発生までの遅延時間は一定 ■D=(S-T)の更新式 ■MP分布による期待値で記述 ■最終的なパラメータ依存性 29

30.

30 weight decayの影響 γの大きさは、grokkingの有無に影響 • γ=10-2 grokking無し(汎化×) • γ=10-3 grokking有り • γ=10-5 grokking有り

31.

31 パラメータ依存性の総括 λ≒0、すなわち、データ数が多い場合にgrokking発生までの遅延時間が短い • 空白領域は、γによってgrokking の有無が変わるとみなせる

32.

32 一般化 2層NNに拡張。線形とtanhの二通りで解析 ■これまでと同様の流れ • 活性化関数σには、線形とtanhを採用

33.

33 一般化 一般化しても解析と実験がよく一致 • • 線形とtanh双方ともに解析と実験が一致 非線形モデルへ拡張できることを示した。

34.

まとめ ◼ Discussion • 非線形関数を含むアーキテクチャにおいて、訓練損失/汎化損失のダイナミクスに基づき、 grokkingの挙動を解析的に求めた • grokkingは訓練損失と汎化損失が依存する共分散行列の差によるもので、 暗記から理解へ、という質的な変化を意味していない • 今後、入力データの変更や、オプティマイザーや損失関数などの変更によって、 汎化に関する理解が深まることを期待 34

35.

35 総括・所感 活用方法はまだ不明瞭だが、汎化性向上に寄与することを期待 • 全体通して、weight norm、weight decayが 重要な役割を担っていることは共通 • 質的な変化というよりも、勾配消失問題に近い印象 → だとしたらこれまでのオプティマイザー関連の研究と の関連は? 関連研究[4][5] • weight空間で、原点付近が汎化性能がよい、 という仮説に基づく • アルゴリズミックなタスクは大なり小なり部分的に含ん でいることが多い気がするので、 部分的にはgrokking促進は有効ではないか?