【DL輪読会】The Clock and the Pizza: Two Stories in Mechanistic Explanation of Neural Networks

1.6K Views

November 17, 23

スライド概要

シェア

またはPlayer版

埋め込む »CMSなどでJSが使えない場合

関連スライド

各ページのテキスト
1.

DEEP LEARNING JP [DL Papers] The Clock and the Pizza: Two Stories in Mechanistic Explanation of Neural Networks Gouki Minegishi, Matsuo Lab http://deeplearning.jp/ 1

2.

書誌情報 • 概要 – ニューラルネットワークで算術アルゴリズムがどのように発⾒されるかを検証 – アーキテクチャによって異なるアルゴリズムを発⾒することを検証 • 選定理由 – Neurips 2023のoral(top5%) 2

3.

Mechanistically understanding(メカニズム解明) • 深層ニューラルネットワークをリバースエンジニアリング(Interpretability) – 学習したアルゴリズム、どのような表現が獲得されているか解明[5] • Mathematical Tasks(算術タスク) – Taskを解くアルゴリズムがよくわかっている[6] • Phase Transitions (相転移) – 訓練データを与えると、勝⼿にアルゴリズムを獲得する[7] • Emergent ability – モデルの特定の能⼒の発現とアルゴリズムの関係[8] ニューラルネットワークはどのような条件下でどのようなアルゴリズムを獲得するのか? 3

4.

Modular Arithmetic and Grokking • Modular Arithmetic(合同算術) – 𝑎 + 𝑏 = 𝑐 mod 𝑝 e.g., 9+4 = 1 (mod 12) – 時計の計算など⽇常で無意識に⾏われている [3] • Grokking – 合同算術タスクで検証されることが多い – 過学習し続けると急に汎化する現象[1] – Embedding空間をt-SNEすると円状になる[2] 4

5.

先⾏研究[4] • • Transformerは⼊⼒(𝑎, 𝑏)を円状にマッピングし、円状で⾜し算をしていることを発⾒。 ネットワークは以下のアルゴリズムを学習しているはずだ。 1. One-hotベクトルの(𝑎, 𝑏)を sin 𝜔! 𝑎 , cos 𝜔! 𝑎 と sin 𝜔! 𝑏 , cos 𝜔! 𝑏 にマッピング。𝜔! = 2. cos(𝜔! 𝑎 + 𝑏 )とsin(𝜔! 𝑎 + 𝑏 )を計算。e. g. , cos 𝜔! 𝑎 + 𝑏 3. Logitを計算。cos 𝜔! 𝑎 + 𝑏 − 𝑐 • • = cos 𝜔! 𝑎 + 𝑏 "!# ,𝑘 $ ∈ ℕ = cos 𝜔! 𝑎 cos 𝜔! 𝑏 − sin 𝜔! 𝑎 sin(𝜔! 𝑏) cos 𝜔! 𝑐 + sin 𝜔! 𝑎 + 𝑏 sin(𝜔! 𝑐) 1はEmbed, 2はAttention+MLP, 3はUnembedで⾏われている。 最終的には様々なkでのLogitが⾜し合わされている。 5

6.

先⾏研究[4] 先のアルゴリズムを学習している証拠 1. Embeddingの周期性 1. 2. 𝑊% を⼊⼒⽅向でフーリエ変換 6つの周期{14,35,41,42,52}に分解できる 2. AttentionとMLPの周期性 1. 2. 3. (𝑎, 𝑏)全ての⼊⼒パターンのattention scoreを計算 MLPのあるニューロンの発⽕をplot いずれも周期性を確認 6

7.

本研究のメインの主張 • Nanda[4]が⽰してるのはmodular additionを解くアルゴリズムの⼀つの側⾯ (clock Algorithm)にすぎない • Clock, Pizza, Non-circularのアルゴリズムがある • これらのアルゴリズムがどのような条件下で獲得されるのか検証 7

8.

実験 • Task – 𝑎 + 𝑏 = 𝑐 mod 𝑝 e.g., 9+4 = 2 (mod 12) • Model – One-layer Transformer with constant attention (ModelA) • 𝐸! = 𝑊",$! + 𝑊%&',! • E() = 𝐸* + ∑+ 𝑊,+ 𝑊-+ (𝐸. + 𝐸* ) (Constant attention) • • 1 + 𝑊 1 𝑅𝑒𝐿𝑈(𝑏 0 + 𝑊 0 𝐸 ) 𝐻.* = 𝐸.* + 𝑏&/0 &/0 !2 !2 .* 𝑂 = 𝑊3 𝐻.* – Standard One-layer Transformer (ModelB) • モデルと学習するアルゴリズムの関係を調査 8

9.

Clock Algorithm (Transformer=modelB) • 10 Embedding(𝐸! , 𝐸" ) – 𝑎 → 𝐸5 = 𝐸5,7 , 𝐸5,8 = (cos 𝜔! 𝑎 , sin(𝜔! 𝑎)), 𝜔! = "!# ,𝑘 $ ∈ ℕ – 𝑏 → 𝐸9 = 𝐸9,7 , 𝐸9,8 = (cos 𝜔! 𝑏 , sin(𝜔! 𝑏)) • 11 12 𝐸5,7 𝐸9,7 − 𝐸5,8 𝐸9,8 𝐸59,7 cos(𝜔! 𝑎 + 𝑏 ) = = 𝐸59,8 𝐸5,7 𝐸9,7 + 𝐸5,8 𝐸9,8 sin(𝜔! 𝑎 + 𝑏 ) = 𝐸59 3 𝜃! 𝐸! 8 4 5 7 加法定理(𝐸!" , 𝐻!" ) 2 𝐸" 𝜃" 9 1 6 – 𝐸59 = – 𝐻59 • Logit計算(logit = 𝑄!"# ) – 𝑄59: = 𝑈: 𝐻59 , 𝑈: ≡ 𝐸:,7 , 𝐸:,8 = (cos 𝜔! 𝑐 , sin 𝜔! 𝑐 ) • cos 𝜔4 𝑎 + 𝑏 − 𝑐 = cos 𝜔4 𝑎 + 𝑏 cos 𝜔4 𝑐 + sin 𝜔4 𝑎 + 𝑏 sin(𝜔4 𝑐) – 𝑄59: = 𝐸5,7 𝐸9,7 − 𝐸5,8 𝐸9,8 𝐸:,7 + 𝐸5,7 𝐸9,7 + 𝐸5,8 𝐸9,8 𝐸:,8 • First Evidence – 勾配の対称性: ∇%! 𝑄59: ≠ ∇%" 𝑄59: • Second Evidence – Logitが(a-b)の値に依存するか – Clockはしない,Pizzaはする 10

10.

Pizza Algorithm (MLP=modelA) • Embedding(𝐸! , 𝐸" ) – 𝑎 → 𝐸5 = 𝐸5,7 , 𝐸5,8 = (cos 𝜔! 𝑎 , sin(𝜔! 𝑎)) – 𝑏 → 𝐸9 = 𝐸9,7 , 𝐸9,8 = (cos 𝜔! 𝑏 , sin(𝜔! 𝑏)) • ベクトルの平均と和積(𝐸!" , 𝐻!" ) D D – 𝐸59 = " 𝐸5 + 𝐸9 = " (cos 𝜔! 𝑎 + cos(𝜔! 𝑏), sin 𝜔! 𝑎 + sin(𝜔! 𝑏)) • 𝐸.* から𝐻.* は⾮線形な変換 E# 5F9 " – 𝐻59 = |cos( • )|(cos 𝜔! 𝑎 + 𝑏 , sin 𝜔! 𝑎 + 𝑏 ) Logit計算 E# 5F9 " – 𝑄59: = |cos( • )| cos 𝜔! 𝑎 + 𝑏 − 𝑐 First Evidence – a-bに依存 e.g., (a,b)=(7,7),(10,4)だと答えは同じなのにlogitは変わっちゃう • Second Evidence – – – – EmbeddingのPCAは円になる(どっちも) Pizza(MLP)は2PCで 31%,6PCで91% Clock(Transfomer)は6PC消しても100% ClockはCircle Isolation? 11

11.

Interpolate between MLP and Transformer • Attention rate ≔ 𝛼 – 𝑀G = 𝑀𝛼 + 𝐼(1 − 𝛼) • Gradient-symmetricity – 勾配の対称性 – 𝑠H ≡ • D I JK!"$ JK!"$ , J% ) J%! " ∑𝑠𝑖𝑚( Distance irrelevance – (a-b)にどれくらい依存するか – 𝑞≡ • • % ∑LMN & 𝐿O,OPN 𝑖 ∈ ℤ"Q LMN(R'(|O,T∈ℤ)&) Attention rateが⼤きいと attentionメカニズムがNN内で ⽀配的になりclock Algorithmを 学習。逆も然り。 Widthが広くなるとMLPがより capableになるattentionの benefitが⼩さい 12

12.

Pizzas Come in Pairs • Pizzaアルゴリズムは完璧でない? – Antipodalな位置にあるペア(3,9),(2,8),(1,7)は平均が原点になる – Pが奇数でもめちゃノルムの⼩さいベクトルになってしまう – これらの区別ができない? • 付け合わせのピザ(accompanying pizzas)がある – 異なる⾓度(間隔)ピザがあればお互いに補完できる – #1,#2,#3の付け合わせとして#4,#5,#6のピザがある – #1,#2,#3のピザだけで99.7%, #4,#5,#6だけだと16.7% • • #4,#5,#6は最終的にはあんま意味ない? 仮説 – – – – #1,#2,#3は3つのlottery ticketsに対応している 学習初期は#1,#2,#3の弱点を#4,#5,#6が補う 訓練が進むと#4,#5,#6がpruneされる 確かにtransformerの⽅がgrokking ticketsの効果が⼩さい • MLPだとめちゃ無駄があってTransformerだとない? – 実験検証はない 13

13.

まとめ • まとめ – ロジット可視化、埋め込み空間における主成分の分離、モデルの対称性の勾配に基づく判定など、様々な新しい解釈 可能な技術を提案し、学習アルゴリズムを判定する – Transformer, MLPと学習アルゴリズムとの関係を分析 • Limitation – Modular additionだけ – より複雑で実世界のタスクで考える必要である • Broader Impact – Safe AIに解釈性は役⽴つ • 感想 – Neural Circuitとの関係(あくまでもlogitの勾配とかマクロな分析) – Transformerが強い理由との関係 • 今回だとclockの⽅がa-bに依存しないから良いアルゴリズムな気がする 14

14.

参考⽂献 [1] Power, Alethea, et al. "Grokking: Generalization beyond overfitting on small algorithmic datasets." arXiv preprint arXiv:2201.02177 (2022). [2] Liu, Ziming, et al. "Towards understanding grokking: An effective theory of representation learning." Advances in Neural Information Processing Systems 35 (2022): 34651-34663. [3] https://ja.wikipedia.org/wiki/%E5%90%88%E5%90%8C%E7%AE%97%E8%A1%93 [4] Nanda, Neel, et al. "Progress measures for grokking via mechanistic interpretability." arXiv preprint arXiv:2301.05217 (2023). [5] Elhage, Nelson, et al. "Toy models of superposition." arXiv preprint arXiv:2209.10652 (2022). [6] Hoshen, Yedid, and Shmuel Peleg. "Visual learning of arithmetic operation." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 30. No. 1. 2016. [7] Ziyin, Liu, and Masahito Ueda. "Exact phase transitions in deep learning." arXiv preprint arXiv:2205.12510 (2022). [8] Michaud, Eric J., et al. "The quantization model of neural scaling." arXiv preprint arXiv:2303.13506 (2023).. 15