38.2K Views
March 23, 23
スライド概要
Chat-GPTをはじめ、昨今の大規模言語モデル(LLM)の礎となった機械翻訳モデルTransformerの解説資料です。GPTもBERTも、基本構造はTransformerとほぼ変わりません。近年のLLMの理解には不可欠なTransformerの構造をできるだけ詳細に書き下してみました
Transformer解説 ~Chat-GPTの源流~ 1
Chat-GPTを理解したい Chat-GPTすごい APIが公開され、活用アプリ&怪しい 記事が爆増 CSer, ISerとして、 根底から理解しよう 2
あくまで私は計算機屋さん 細かい理論についてはわからん ✓大規模言語モデルのお気持ちには触れつつ、 あくまで、その計算フローに焦点を当てる ✓本資料では学習方法は深く取り扱わない 3
目次 • ざっくり深層学習 • 自然言語処理モデル • RNN • Transformer 4
深層学習が何をやってるか 深層学習 …複雑な関数を、単純な線形変換()を大量に重ねて近似すること function(DNN) Rabbit Kawauso Cat 5
単純な線形変換() 𝑦 = 𝐴𝑐𝑡𝑖𝑣𝑎𝑡𝑒(𝑥𝑊 + 𝑏) (以降こいつをLinearって呼びます) ➢こいつの重ね方に工夫が生まれる ➢𝑊, 𝑏の値をよしなに調整するのが学習 Activate Func * 非線形性を生み、 表現力が向上 etc. 6
自然言語処理 人間の言語の解釈を要するタスクを機械に解かせたい ➢曖昧性の高さから、深層学習によるアプローチが主流 • 文章要約 •Q&A • 翻訳 7
Chat-GPTへの道のり GPT GPT-2 GPT-3 Chat-GPT Transformer BERT 全てはTransformerから始まった ➢まずはコイツから始めましょう! 8
ざっくりTransformer “Attention Is All You Need” (Ashish Vaswani @Google Brain et al. ) 機械翻訳用の自然言語モデル 従来のRNNベースの手法から大幅に性能改善 ➢自然言語処理のbreak throughを作った革命的なモデル 9
翻訳の主流:Encoder-Decoderモデル Encoder Decoder I am a man . 私は人だ。 単語ベクトル群 単語ベクトル群 DNN DNN 文の意味っぽいベクトル 文の意味っぽいベクトル 10
機械翻訳の祖:RNN Recurrent Neural Network … 入力長分、共通の線形変換()を繰り返し適用するモデル 可変長の入力に対応可能、系列データ全般に強い y0 y1 s0 y2 s1 y3 s2 y4 s3 s4 Linear Linear Linear Linear Linear x0 x1 x2 x3 x4 11
RNNで機械翻訳 - Encoder部分 Encoder s0 s1 s2 s3 Linear Linear Linear Linear Linear x0 x1 x2 x3 x4 文 全 体 の 意 味 Word Embedding(実はこいつもDNN) I have a pen . 12
RNNで機械翻訳 - Decoder部分 Decoder 文 全 体 の 意 味 t0 t1 t2 t3 Linear Linear Linear Linear Linear y0 y1 y2 y3 y4 Word Embedding(実はこいつもDNN) 私は ペンを 持って いる 。 13
RNNの問題点 • 計算フローのクリティカルパスが文の長さに比例 ➢GPU等の並列計算で高速化できない Encoder Decoder s0 s1 s2 s3 Linear Linear Linear Linear Linear x0 x1 x2 x3 x4 文 全 体 の 意 味 t0 have a pen t2 t3 Linear Linear Linear Linear Linear y0 y1 y2 y3 y4 Word Embedding(実はこいつもDNN) I t1 Word Embedding(実はこいつもDNN) . 14 私は ペンを 持って いる 。 14
Output Probabilities Transformer softmax Linear Layer Norm 並列性の高い計算フローを持つ Encoder-Decoder型DNN 主要なパーツ • Positional Encoding • Feed-Forward Network • Layer Normalization • Multi-Head Attention Nx + Feed Forward Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention 〜 + 〜 xN + Input Embedding Output Embedding Inputs Outputs 15
80% 男の子 まずは超ざっくり見る 10% 6% 男性 女の子 4% 犬 Output Probabilities Layer Norm Decoder 1. 入力文をEncode + 入力文の意味 Feed Forward 2. 出力済の文と1の結果から、 Encoder 次単語の確率分布を生成 Layer Norm 3. ビームサーチで次単語確定、 出力済の文に追加 Layer Norm + Multi-Head Attention + Masked Multi-Head Attention 〜 4. 2に戻る I am a boy . 私は + 16
Output Probabilities Transformer 主要なパーツ • Positional Encoding • Feed-Forward Network • Layer Normalization • Multi-Head Attention softmax Linear Layer Norm + Feed Forward Nx Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention 〜 + 〜 xN + Input Embedding Output Embedding Inputs Outputs 17
Positional Encoding 文の意味解釈で、各単語の位置情報は重要 Linear層は単語の順序を考慮しない ➢入力時点で、単語自体に位置情報を明示的に埋め込む必要性 𝑑 pos 単 語 ベ ク ト ル i 𝑃𝐸 𝑝𝑜𝑠, 2𝑖 = sin 𝑝𝑜𝑠 2𝑖 10000 𝑑 𝑃𝐸 𝑝𝑜𝑠, 2𝑖 + 1 = cos( 𝑝𝑜𝑠 2𝑖 10000 𝑑 ) Word Embedding I am a boy . 18
Output Probabilities Transformer 主要なパーツ • Positional Encoding • Feed-Forward Network • Layer Normalization • Multi-Head Attention softmax Linear Layer Norm + Feed Forward Nx Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention 〜 + 〜 xN + Input Embedding Output Embedding Inputs Outputs 19
Feed-Forward Network 𝑧 = 𝑅𝑒𝐿𝑈 𝑥𝑊1 + 𝑏1 𝑦 = 𝑧𝑊2 + 𝑏2 ➢Linear x2。それだけ ※bは省略 20
Output Probabilities Transformer 主要なパーツ • Positional Encoding • Feed-Forward Network • Layer Normalization • Multi-Head Attention softmax Linear Layer Norm + Feed Forward Nx Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention 〜 + 〜 xN + Input Embedding Output Embedding Inputs Outputs 21
Layer Normalization Layers… I am a boy 行単位で適用 22 … … … ただの正規化もどき ➢学習の高速化や過学習の抑制に寄与 -0.9 0 -0.51.4 … LN … LN … … 1 3 2 6 … 𝑥𝑖 − 𝜇 𝐿𝑁 𝑥𝑖 = 𝛾+𝛽 𝜎 𝜇, 𝜎: 𝑥𝑖 の平均, 標準偏差 𝛾, 𝛽: パラメタ(スカラ値)
Output Probabilities Transformer 主要なパーツ • Positional Encoding • Feed-Forward Network • Layer Normalization • Multi-Head Attention softmax Linear Layer Norm + Feed Forward Nx Positional Encoding Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention 〜 + Positional Encoding 〜 xN + Input Embedding Output Embedding Inputs Outputs 23
Multi-Head Attention 𝑴𝒖𝒍𝒕𝒊𝑯𝒆𝒂𝒅(𝑄, 𝐾, 𝑉) = 𝒄𝒐𝒏𝒄𝒂𝒕 ℎ𝑒𝑎𝑑𝑖 𝑊 𝑜 𝑄 ℎ𝑒𝑎𝑑𝑖 = 𝑺𝑫𝑷𝑨𝒕𝒕𝒆𝒏𝒕𝒊𝒐𝒏 𝑄𝑊𝑖 , 𝐾𝑊𝑖𝐾 , 𝑉𝑊𝑖𝑉 お気持ち • 𝑉には、整理されていない有益情報がたくさん • 𝐾は𝑉に紐づく情報がたくさん • 𝑄に近い情報がKにあれば、対応する有益情報を𝑉から抽出 24
Scaled Dot Product Attention お気持ち • 𝑘𝑖 (𝑘𝑒𝑦), 𝑣𝑖 (𝑣𝑎𝑙𝑢𝑒)という対を為すベクトルが沢山 • 各入力ベクトル𝑞𝑗 と似ているkeyを集める • keyに対応するvalueたちを混ぜて出力 𝑞1 𝑞2 𝑞3 𝑞4 𝑘1 𝑘2 𝑘3 𝑣1 𝑣2 𝑣3 25
Scaled Dot Product Attention① 𝑄𝐾 Τ の各要素は𝑞𝑖 と𝑘𝑗 の内積 ➢𝑞𝑖 , 𝑘𝑗 の向きが近いほど値が大きいため、類似度の指標に (内積はベクトル長に比例してしまうため、 𝑑で割る) 𝑑 𝑞1 𝑞1 ∗ 𝑘1𝑞1 ∗ 𝑘2𝑞1 ∗ 𝑘3 𝑞2 𝑘1 𝑘2 𝑘3 𝑘4 𝑘5𝑞2 ∗ 𝑘1𝑞2 ∗ 𝑘2𝑞2 ∗ 𝑘3 𝑞3 𝑞4 * 𝑞1と各keyとの 類似度ベクトル 𝑞3 ∗ 𝑘1𝑞3 ∗ 𝑘2𝑞3 ∗ 𝑘3 𝑞4 ∗ 𝑘1𝑞4 ∗ 𝑘2𝑞4 ∗ 𝑘3 26
Scaled Dot Product Attention② • 𝒔𝒐𝒇𝒕𝒎𝒂𝒙 𝒙 = 𝒆𝒙 𝟏 𝒆𝒙 𝟐 𝒆𝒙 𝟑 [ σ 𝒙𝒊 , σ 𝒙𝒊 , σ 𝒙𝒊 𝒆 𝒆 𝒆 , … ] ➢ベクトルを少し過激に確率分布に変換する関数 ex.) 𝑠𝑜𝑓𝑡𝑚𝑎𝑥([2,3,5]) = [0.4, 0.11, 0.85] 𝑞1 ∗ 𝑘1𝑞1 ∗ 𝑘2𝑞1 ∗ 𝑘3 softmax 𝑞1 ~𝑘1 𝑞1 ~𝑘2 𝑞1 ~𝑘3 𝑞2 ∗ 𝑘1𝑞2 ∗ 𝑘2𝑞2 ∗ 𝑘3 softmax 𝑞2 ~𝑘1 𝑞2 ~𝑘2 𝑞2 ~𝑘3 𝑞3 ∗ 𝑘1𝑞3 ∗ 𝑘2𝑞3 ∗ 𝑘3 softmax 𝑞3 ~𝑘1 𝑞3 ~𝑘2 𝑞3 ~𝑘3 𝑞4 ∗ 𝑘1𝑞4 ∗ 𝑘2𝑞4 ∗ 𝑘3 softmax 𝑞4 ~𝑘1 𝑞4 ~𝑘2 𝑞4 ~𝑘3 𝑞1と各keyとの 類似性の確率分布 27
Scaled Dot Product Attention③ 前stepで求めた確率分布を重みと捉え、valuesを加重平均 𝑞1 ~𝑘1 𝑞1 ~𝑘2 𝑞1 ~𝑘3 𝑞2 ~𝑘1 𝑞2 ~𝑘2 𝑞2 ~𝑘3 𝑞3 ~𝑘1 𝑞3 ~𝑘2 𝑞3 ~𝑘3 * 𝑞4 ~𝑘1 𝑞4 ~𝑘2 𝑞4 ~𝑘3 𝑣1 𝑞1~𝑘𝑖 ∗ 𝑣𝑖 𝑣2 𝑞2 ~𝑘𝑖 ∗ 𝑣𝑖 𝑣3 𝑞3 ~𝑘𝑖 ∗ 𝑣𝑖 𝑞4 ~𝑘𝑖 ∗ 𝑣𝑖 [0.4, 0.11, 0.85] 28
Multi-Head Attention 𝑀𝑢𝑙𝑡𝑖𝐻𝑒𝑎𝑑 𝑄, 𝐾, 𝑉 = 𝑐𝑜𝑛𝑐𝑎𝑡 ℎ𝑒𝑎𝑑𝑖 𝑊 𝑜 𝑄 ℎ𝑒𝑎𝑑𝑖 = 𝑆𝐷𝑃𝐴𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 𝑄𝑊𝑖 , 𝐾𝑊𝑖𝐾 , 𝑉𝑊𝑖𝑉 ➢𝑄, 𝐾, 𝑉に様々な変換を加え、組合わせに多様性を持たせている 𝑄 𝐾 𝑉 * 𝑄𝑄 𝑊 𝑄 3 𝑊𝑊 12 ′𝑄 ′𝑄3′ 𝑄1 2 * 𝐾 𝐾 𝑊 𝐾 3 𝑊𝑊 12 ′ ′ 𝐾 𝐾1𝐾 2 3 * 𝑉3𝑉 𝑊 𝑉 𝑊 𝑊1 2 ′ ′ 𝑉1′𝑉2𝑉3 ′ * 𝑊𝑂 SDP Attention 29
Multi-Head Attentionの使われ方① Output Probabilities Q,K,Vが全て同じ入力(文) ➢入力(文)を様々な角度で切り出した物同士 を見比べ、注目すべき箇所を決めて出力 勝手なイメージ softmax Linear Layer Norm + the violin was too big Feed Forward SDP Attention Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention Nx it was too big K not because put in the bag 𝑊1𝑄 𝑊1𝐾 Q V violin bag the violin 𝑊1𝑉 I could not put the violin in the bag because it was too big. 〜 + Input Embedding Inputs 〜 + Output Embedding Outputs 30 xN
Multi-Head Attentionの使われ方② Output Probabilities 𝑄: 加工済み出力文 𝐾, 𝑉: encoderの出力 softmax Linear Layer Norm + Feed Forward ➢出力文から、入力文のどの意味がまだ不足 しているか等を判断している? Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention Nx 〜 + Input Embedding Inputs 〜 + Output Embedding Outputs 31 xN
Masked Multi-Head Attention Output Probabilities 主に学習のための機構 softmax Linear Layer Norm 学習時は入力文と出力文の模範解答を流す + Feed Forward 次単語予測の正解がわからないように、 出力文を一部maskするだけ Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention Nx 〜 + Input Embedding Inputs 〜 + Output Embedding Outputs 32 xN
パーツ理解完了 最後に流れを再確認して 締めましょう 33
Output Probabilities Transformer softmax Linear Layer Norm + Feed Forward Nx Positional Encoding Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention 〜 + 〜 Positional Encoding xN + Input Embedding Output Embedding Inputs Outputs 34
Output Probabilities 次回予告 softmax Linear Layer Norm transformerは本来翻訳家 + Feed Forward だが、意味解釈能力が超凄い これ、何にでも応用できる? N x ➢GPTs, BERT Layer Norm Layer Norm + + Feed Forward Multi-Head Attention Layer Norm Layer Norm + + Multi-Head Attention Masked Multi-Head Attention 〜 + 〜 xN + Input Embedding Output Embedding Inputs Outputs 35