[DL輪読会]Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets

>100 Views

April 05, 22

スライド概要

2022/03/25
Deep Learning JP:
http://deeplearning.jp/seminar-2/

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets” (ICLR 2021 workshop) Okimura Itsuki, Matsuo Lab, B4 http://deeplearning.jp/ 1

2.

アジェンダ 1. 2. 3. 4. 書誌情報 概要 背景 実験 2

3.

1 書誌情報 タイトル: Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets 出典: ICLR 2021 (1st Mathematical Reasoning in General Artificial Intelligence Workshop) https://nips.cc/Conferences/2021/ScheduleMultitrack?event=25970 著者: Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin & Vedant Misra (OpenAI) 選んだ理由:現象が面白い 3

4.

2 概要 • 一般的にニューラルネットワークはある一定の学習ステップ数を経ると 以降の検証データで損失が減少しない過学習に陥るとされる. • しかし、数式の答えを求めるタスクにおいて 過学習するステップ数を遥かに超えて学習を続けると, ニューラルネットワークがデータ中のパターンを 「理解する(Grokking)」プロセスを通じて学習し, 急激な汎化性能の向上が見られる場合があることを示した. • この現象において,より小さなデータセットを用いた場合, 汎化のためのステップ数の必要量が増加することも示した. 4

5.

3 背景 Bias-Variance Trade-off 統計的機械学習における学習理論 汎化誤差はバイアスとバリアンスと ノイズに分解され, モデルの複雑度低すぎる場合: 予測値と真の値との差(バイアス)が高くなり、 データの特徴を捉えられず、 学習不足に陥る モデルの複雑度高すぎる場合: 予測値の分散(バリアンス)が高くなり、 データのノイズまで学習してしまい、 過学習に陥る 出典: https://towardsdatascience.com/bias-andvariance-but-what-are-they-really-ac539817e171 5

6.

3 背景 近年サイズの大きいモデルも成果を出す モデルの複雑度高いのに 過学習しないの? Bias-Variance Trade-offは? 出典: https://www.microsoft.com/en-us/research/blog/using-deepspeed-and-megatron-to-trainmegatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/ 6

7.

3 背景 二重降下(Double Descent) モデルの複雑度(モデル構造や学習量など)を 増やしていくと, 一度テストエラーの減少が止まったのち, 再度テストエラーの減少が起こる 二重降下が報告される CNN,ResNet,およびTransformerなどの 最新モデルでも報告 条件によってはより多くのデータを使っても 性能が悪化する場合もあった 出典: https://arxiv.org/pdf/1912.02292.pdf 7

8.

3 背景 GPT-3はパラメータ数が13Bを超えると足し算、引き算が解け始める Few-shotの設定において パラメータ数が6.7B以下のモデル ほとんど四則演算が解けない パラメータ数が13Bのモデル 二桁の足し算、引き算を50%前後解ける パラメータ数が175Bのモデル 二桁の足し算、引き算は98%以上, 三桁の足し算、引き算は80%以上解ける 掛け算や複合演算なども20%前後解ける →解けるかどうかはパラメータ数依存? 出典: https://arxiv.org/pdf/2005.14165.pdf 8

9.

4 実験 二項演算を小規模なニューラルネットにより汎化できるか検証 𝑎 ∘ 𝑏 = 𝑐の形の等式のデータセット (0 ≤ 𝑎, 𝑏 < 97)に対して サイズの小さい(400Kパラメーター) Transformerで学習を行う すべての可能な等式の適切な部分集合に ついてニューラルネットワークを訓練することは、 右図の二項演算表の空白を埋めることに等しい トークン化されるため内部的な情報 (ex. 10進数での表記)にはアクセスできず 離散的な記号の他の要素との相関から 二項演算を解くことができるか? 9

10.

4 実験 Grokking 102step程度から訓練データでの損失と 検証データでの損失に乖離が見え始め, 103step程度で訓練データでの損失が0近くになる その後105stepから検証データでの損失が 下がり始める. そして106stepで検証データでも損失が0近くになる →過学習し始してからその1000倍以上の 最適化stepを経ることで急激に検証データでの 精度が急激に向上する。(Grokking) *grok【他動】〈米俗〉完全に[しっかり・ 心底から]理解[把握]する 出典: https://eow.alc.co.jp/search?q=grok 10

11.

4 実験 データ数の影響 異なったデータ数の条件においてのGrokkingを比較 一般的な教師あり学習の設定において データが少ないほど学習データ量を減らすと モデルの収束する汎化性能が低下する 一方,Grokkingが観測される場合, 学習データ量の範囲内では 収束性能が100%で一定であるのに対し, データ量の減少に伴って 収束性能を達成するために 必要な最適化時間が急速に増加する 11

12.

4 実験 二項演算の種類による影響 二項演算の種類ごとに生じる Grokkingの様子を比較 オペランドの順序に関して 対称なもの(ex. 𝑥 + 𝑦, 𝑥 2 + 𝑦 2 )は 非対称な演算(ex. 𝑥Τ𝑦, 𝑥 3 + 𝑥𝑦)よりも 汎化のためのデータが少なくて済む 傾向がある →位置埋め込みを無視するだけで良く 学習しやすかった可能性 一部の直観的に困難な数式 (ex. 𝑥 3 + 𝑥𝑦 2 + 𝑦)では 予算内では汎化に至らず 12

13.

4 実験 正則化の影響 異なった正則化アルゴリズムでの 105step学習した後の性能を比較 重み減衰がデータ効率に対し, 最も効果的で収束に必要となる サンプル数を大幅に削減 →重み減衰が重要な役割を担う? ミニバッチの使用による勾配ノイズ, ガウスノイズといったものも 学習に一定の効果 最適でないハイパーパラメータだと 収束率は著しく低下する 13

14.

4 実験 出力層の可視化 𝑥 + 𝑦で学習したTransformerの 埋め込み層をt-SNEを用いて可視化する 各要素に8を加えた数列が示される →埋め込み層が数学的な特徴を捉えている 14

15.

まとめ • 一般的にニューラルはある一定の学習ステップ数を経ると 以降の検証データで損失が減少しない過学習に陥るとされる. • しかし、数式の答えを求めるタスクにおいて 過学習するステップ数を遥かに超えて学習を続けると, ニューラルネットワークがデータ中のパターンを 「理解する(Grokking)」プロセスを通じて学習し, 急激な汎化性能の向上が見られる場合があることを示した. • この現象において,より小さなデータセットを用いた場合, 汎化のためのステップ数の必要量が増加することも示した. 15

16.

感想 (確かに“完全に理解した曲線”っぽい) パラメーター数の小さなモデルでも汎化に到達するのが意外だった。 パラメータ数が増えた場合にGrokkingに到達するまでの比較は見たかった。 今回の実験では二項演算を97×97の数字の表の空きスロットを埋める計算として定義。 実際の計算よりはかなり小規模なスケール 各桁数字ごとにトークン化すればより大きな桁でのGrokkingに到達できる? 小さいモデルでも実現可能な一方で,Grokkingは学習率を比較的狭い範囲(1桁以内)で 調整する必要があるらしく,やはりモデルの複雑度を上げるのにもパラメータを増やすのが有力? 16

17.

DEEP LEARNING JP [DL Papers] “Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets” (ICLR 2021 workshop) Okimura Itsuki, Matsuo Lab, B4 http://deeplearning.jp/