>100 Views
May 07, 26
スライド概要
AI・機械学習を勉強したい学生たちが集まる、京都大学の自主ゼミサークルです。私たちのサークルに興味のある方はX(Twitter)をご覧ください!
2026年前期輪読会 #3 2026/5/7 深層学習による画像認識の基礎 誤差逆伝播法・実装 2.6~2.7 工学部情報学科4回生 宮前明生 0
アジェンダ ◼ 誤差逆伝播法 ◼ 画像分類の実装 1
アジェンダ ◼ 誤差逆伝播法 ◼ 画像分類の実装 2
誤差逆伝播法:前回の復習 ⚫ 勾配降下法で重み𝐰を更新することで損失ℒを最小化できる 𝑤 𝑡+1 =𝑤 𝑡 𝜕ℒ −𝛼 𝜕𝑤 𝑡 𝜕ℒ ⚫ 勾配 𝜕𝑤 𝑡 を直接それぞれの重みについて計算するより、後ろの重 みの勾配を保存しながら計算することで効率よく計算できる 誤差逆伝播法 ⇒誤差逆伝播法 ⚫ 例として以下のようなモデルを考える • • • • 入力層: 𝒙 中間層1: 𝒉1 = 𝜎 𝑾1 𝒙 中間層2: 𝒉2 = 𝜎 𝑾2 𝒉1 出力層:ෝ 𝒚 = 𝑾3 𝒉 2 • 損失関数: ℒ(𝒘) = 1 2 𝑦ො − 𝑦 2 ※2次の回帰問題、活性化関数𝜎(𝑥) = 1/(1 + 𝑒 −𝑥 ) 𝜎はベクトルのそれぞれの要素に適応する 3
誤差逆伝播法:連鎖律 ⚫ 連鎖律とは、合成関数の微分で現れるルール d 𝑓 𝑔 𝑥 d𝑥 d𝑓 d𝑔 = d𝑔 d𝑥 ⚫ 多変数についての微分(偏微分)についての連鎖律 𝜕 𝑓 𝑔1 𝑥 , 𝑔2 𝑥 𝜕𝑥 ⚫ 誤差逆伝播法 𝜕𝑓 𝜕𝑔1 𝜕𝑓 𝜕𝑔2 = + 𝜕𝑔1 𝜕𝑥 𝜕𝑔2 𝜕𝑥 𝜕ℒ を計算する 𝜕𝑦ො 1 • ℒ 𝒘 = ⚫ 重み𝑤3 11 1 2 • 𝜕ℒ 𝜕𝑦ො 1 = 𝑦ො 1 − 𝑦 1 に関する勾配を計算する • 𝑦ො 1 = 𝑤3 ⇒ 𝑦ො − 𝑦 2 ⇒ 𝜕𝑦ො 1 11 1 ℎ2 + 𝑤3 13 3 ℎ2 前の計算結果 𝜕ℒ 𝜕𝑦ො 1 11 = 𝜕𝑦 ො 1 𝜕𝑤 11 𝜕𝑤3 3 𝜕ℒ 2 ℎ2 + 𝑤3 1 11 = ℎ2 𝜕𝑤3 12 1 =(𝑦ො 1 − 𝑦 1 )ℎ2 4
誤差逆伝播法:連鎖律 ⚫ 𝜕ℒ 1 𝜕ℎ2 を計算する • ⚫ 重み𝑤2 前の計算結果 𝜕ℒ 𝜕𝑦ො 𝑘 1 = σ𝑘 𝜕𝑦 ො 𝑘 𝜕ℎ 1 𝜕ℎ2 2 𝜕ℒ 11 に関する勾配 1 • ℎ2 = 𝜎(𝑤2 1 ⇒ 𝜕ℎ2 11 𝜕𝑤2 11 𝜕ℒ 11 𝜕𝑤2 = σ𝑘 (𝑦ො 𝑘 − 𝑦 𝑘 )𝑤3 𝑘1 誤差逆伝播法 を計算する 1 ℎ1 + 𝑤2 12 2 ℎ1 + 𝑤2 13 3 ℎ1 ) =ℎ11 𝜎′(𝑧21 ) (𝑧21 = 𝑤211 ℎ11 + 𝑤212 ℎ12 + 𝑤213 ℎ13 , 𝜎′(・)は𝜎 (・)の微分) 𝑤2 前の計算結果 • 𝜕ℒ 𝜕ℒ 𝜕𝑤2 1 𝜕ℎ2 11 = 1 𝜕ℎ2 𝜕ℒ 𝜕𝑤2 1 𝜕ℎ2 11 = 1 11 1 ℎ1 𝜎′(𝑧2 ) ⚫ 誤差逆伝播法はこのように前の計算結果を利用して、後ろの重みか ら求めていく 𝜕ℒ 𝜕ℒ 𝜕ℒ 𝜕ℒ 𝜕ℒ 𝜕ℒ ⇒ ⇒ ⇒ ⇒ ⇒ 𝑖𝑗 𝑘 𝑖𝑗 𝑘 𝑖𝑗 𝜕𝑦ො 𝑘 𝜕ℎ 𝜕ℎ 𝜕𝑤 𝜕𝑤 𝜕𝑤 3 2 2 1 1 5
誤差逆伝播法:連鎖律 ⚫ 重み𝑤3 11 𝑤2 11 求めたい。 ( に関する勾配のみを計算したが、重み行列𝑾3 , 𝑾2 に関する勾配を 𝜕ℒ 𝜕ℒ は成分が 𝑖𝑗 の行列) 𝜕𝑾3 𝜕𝑤 3 • 𝜕ℒ 𝜕𝑦ො 1 11 = 𝜕 𝑦 ො 1 𝜕𝑤 11 𝜕𝑤3 3 𝜕ℒ ⇒ • 1 = 𝑦ො 1 − 𝑦 1 ℎ2 𝜕ℒ 𝜕ℒ 𝜕ෝ 𝒚 = = (ෝ 𝒚 − 𝒚)𝒉𝑇2 𝜕𝑾3 𝜕ෝ 𝒚 𝜕𝑾3 1 𝜕ℒ 𝜕ℒ 𝜕ℎ2 𝜕𝑤2 𝜕ℎ2 1 𝜕𝑤2 11 = σ𝑘 𝑤3 11 = 𝑘1 1 1 (𝑦ො 𝑘 − 𝑦 𝑘 ) ℎ1 𝜎 ′ 𝑧2 𝜕ℒ 𝜕ℒ 𝜕𝒉2 ⇒ = = ((𝑾𝑇2 (ෝ 𝒚 − 𝒚)) ⊙ 𝜎′(𝑾2 𝒉1 ))𝒉1𝑇 𝜕𝑾2 𝜕𝒉2 𝜕𝑾2 ※⊙は要素積 6
誤差逆伝播法:計算グラフ ⚫ 連鎖律と誤差逆伝播法を計算グラフで表現することで理 解を深める ⚫ 計算グラフとは、グラフ理論のようにノード(頂点)と 枝からなるグラフで計算を表現する (a)計算グラフにおける連鎖律 𝑥 ⇒ 𝑦の関数𝑓の計算ノードに対して、逆方向には勾配を乗算 するものとする 流れてきた信号ℒに 𝜕𝑦 𝜕𝑦 をかけてℒ を得る 𝜕𝑥 𝜕𝑥 (b)加算ノードにおける逆伝播 𝑥 + 𝑦 = 𝑧の計算ノードは、 各勾配が 流れてきた信号 𝜕𝑧 𝜕𝑧 = = 1なので 𝜕𝑥 𝜕𝑦 𝜕ℒ 𝜕ℒ に対して、𝑥方向、𝑦方向で 1を得る 𝜕𝑧 𝜕𝑧 7
誤差逆伝播法:計算グラフ (c)乗算ノードにおける逆伝播 𝑥𝑦 = 𝑧の計算ノードは、 各勾配が 𝜕ℒ 𝜕𝑧 𝜕𝑧 = 𝑦, = 𝑥なので 𝜕𝑥 𝜕𝑦 𝜕ℒ 𝜕ℒ 流れてきた信号 に対して、𝑥方向で 𝑦、𝑦方向で 𝑥 𝜕𝑧 𝜕𝑧 𝜕𝑧 を得る (d)𝐿 = 𝑥𝑦 2 における逆伝播 𝐿 = 𝑥𝑦 2 の計算ノードは、𝑥𝑦 = 𝑧と𝑧 2 = 𝐿の合成と考 えられるので、 𝜕ℒ 𝜕ℒ = 1に対して、𝑧方向で1 = 2𝑥𝑦、 𝜕ℒ 𝜕𝑧 𝜕ℒ 𝜕ℒ 𝑥方向で2𝑥𝑦 = (2𝑥𝑦)𝑦、𝑦方向で2𝑥𝑦 = (2𝑥𝑦)𝑥を得 𝜕𝑥 𝜕𝑦 流れてきた信号 (d)𝐿 = 𝑥𝑦 2 における逆伝播 る 8
計算ノード 誤差逆伝播法:計算グラフ (e) 2層ニューラルネットワークの計算ノードにおける逆伝播 MMは行列の積、ReLUは活性化関数を表す。逆伝播の計算過程は、連鎖律で重み行列𝑾3 , 𝑾2 に関する勾配を求めるときと同じ。(スライド6) 結果として、 𝑾3 方向で 𝜕ℒ 𝜕ℒ = (ෝ 𝒚 − 𝒚)𝒉𝑇2 、 𝑾2 方向で = ((𝑾𝑇2 (ෝ 𝒚 − 𝒚)) ⊙ 𝜎′(𝑾2 𝒉1 ))𝒉1𝑇 𝜕𝑾3 𝜕𝑾2 このように誤差逆伝播法の後ろ から勾配を求めていく過程を可 視化できる。 𝜕ℒ 𝜕ℒ 𝜕ℒ 𝜕ℒ ⇒ ⇒ ⇒ 𝜕ෝ 𝒚 𝜕𝑾3 𝜕𝒉2 𝜕𝑾2 (e) 2層ニューラルネットワークの計算ノードにおける逆伝播 9
アジェンダ ◼ 誤差逆伝播法 ◼ 画像分類の実装 10
画像分類の実装:MNIST(手書き文字の分類問題) ⚫ 3層ニューラルネットワークでMNISTデータセット (手書き文字の分類問題)を学習する。 ⚫ MNISTデータセットは、0~9までの手書き文字画像 とその正解ラベルを持つ。(1ピクセルは0~255を 取り、0は黒、255は白) ⚫ 実装コード 11