17K Views
January 18, 24
スライド概要
AI・機械学習を勉強したい学生たちが集まる、京都大学の自主ゼミサークルです。私たちのサークルに興味のある方はX(Twitter)をご覧ください!
2023年度論文読み会#6 Google DeepMindの気象予報モデル「GraphCast」 京都大学理学部 3回生 松田拓巳 0
⚫ ⚫ 質問があるときは • ミュート外して質問(解説途中でもOK!) • メッセージで質問 備考 • 気象の基本知識が必要ですが、できるだけ補足説明します • ただし、説明すっ飛ばしてしまうかもしれないので、そのときは遠慮なく 聞いてください! • 今回の論文は(補足資料含めると)102ページあるので、超細かい部分は 解説を割愛します 1
Google DeepMindの気象予報モデル「GraphCast」 目次 1. ざっくりまとめ 2. 先行研究 3. 入出力データ 4. モデル 5. モデルの評価 2
1. ざっくりまとめ 1.1 どんなもの? ⚫ グラフニューラルネットワーク(GNN)を使った全球中期気象予報モデル。 • 全球予報=地球全体の気象状態の予測 • 中期予報=10日間先までの予報 初期時刻 2018-05-05 12:00(UTC)から予報した、6日先の地上気温 3
1. ざっくりまとめ 1.2 モデル a. Input weather state =現在&6時間前の気象状態 b. Predict the next state =現在から6時間先の気象状態 c. Roll out a forecast =b.で得られた予測値を入力として 10日先*まで繰り返す *もちろん原理的には10日より先の予測もできる 4
1. ざっくりまとめ 1.2 モデル d. Encoder 格子点(緯度経度グリッド) → ノード(グラフ) e. Processor ノード間で情報集約・伝達 f. Decoder ノード(グラフ) → 格子点(緯度経度グリッド) 5
1. ざっくりまとめ 1.3 現業数値予報モデルとの精度比較 ⚫ 1380個(69変数×予報時間20step)のうち1246個(90.3%)で GraphCastの方が数値予報モデル*よりも有意にRMSEが低かった(p<0.05) *ECMWF(ヨーロッパ中期予報センター)の全球数値予報モデルHRESと比較 6
1. ざっくりまとめ 1.4 台風、大気の川、異常気温も予報 台風 大気の川 異常気温 ⚫ 実際の進路とのズレ ⚫ 大雨や大雪をもたらす ⚫ ○○年に一度の高温/低温 ⚫ 台風の発生検知精度 ⚫ 鉛直積算水蒸気輸送(IVT) の検知精度 7
2. 先行研究 8
2. 先行研究 2.1 数値予報(Numerical Weather Prediction) 数値予報モデルの特徴 ⚫ ⚫ 大気現象を支配する微分方程式を 問題点 ⚫ 過去の膨大な気象データを直接使っ 数値計算で解き、未来の大気状態 て精度を改善する実用的な手法がな を予測する い モデルの精度を上げる方法 • 格子点を細かく取る • 積分時間間隔を短くする • アルゴリズム、近似手法を開発する ⚫ スパコンをぶん回して数時間かかる 9
2. 先行研究 2.2 機械学習ベース予報(ML-based Weather Prediction) 機械学習ベースモデルの特徴 ⚫ 過去の気象データを直接学習し、 先行研究の例 ⚫ ⚫ Can Machines Learn to Predict Weather? … ⚫ Improving data-driven global weather… ⚫ … a Resnet pretrained on climate simulations 予測を行う ⚫ メリット • • ⚫ Transformer ⚫ ClimaX ⚫ Pangu-Weather 方程式では表現できないようなパター ンやスケールを捉えられる CNN ⚫ 数分で計算可能(GraphCastは1分) GNN ⚫ Forecasting Global Weather with GNN ⚫ Fourier Neural Operators ⚫ FourCastNet: Accelerating Global … ⚫ FourCastNet: A Global Data-driven … 10
3. 入出力データ 11
3. 入出力データ 3.1 大気再解析データ「ERA5」を使用 大気状態を表すデータセット ■補足|再解析とは ⚫ ECMWFが作成。精度に定評あり ⚫ 過去の大気状態を再現すること ⚫ 水平格子:0.25°×0.25° ⚫ 数値予報モデルの計算結果を観測 ⚫ 鉛直層数:37等圧面 ⚫ 1979年~2017年を学習データ* データで補正する(データ同化) *train=1979~2015年,valid=2016~2017年 12
3. 入出力データ 3.2 大気状態とは ⚫ 気温 ⚫ 風速の東西/南北/鉛直成分 ⚫ 比湿 6変数×37等圧面 =222個 高度 約9km 300hPa 高度 約5km 500hPa … Surface variables ⚫ 気温 高度 約1.5km 850hPa … ⚫ 風速の東西/南北成分 1000hPa 地上 ⚫ 海面気圧 ⚫ 降水量 100hPa … ⚫ ジオポテンシャル高度 1hPa … … Atmospheric variables 5個 地表面 13
4. モデル 14
4. モデル 4.1 アーキテクチャ概観 Encoder Processor Decoder ⚫ 緯度経度グリッド(入力) ⚫ ノード間で情報伝達 ⚫ グラフのノード → グラフのノードへencode → 緯度経度グリッド(出力) 15
4. モデル 4.2 グラフ構造「multi-mesh」 ⚫ 正20面体 (M0 ) の各面を4分割した80面体をM1とする → これをM6まで繰り返す(=81920面体,40962頂点) ⚫ ノードはM6の頂点、エッジはM0~M6の辺を使う 16
4. モデル 4.2 グラフ構造「multi-mesh」 multi-meshの利点① ⚫ M0の長いエッジで長距離相互作用 を捉え、 M6の短いエッジでは局所 的な相互作用を捉えることができる multi-meshの利点② ⚫ 緯度経度グリッドのように、緯度に よる解像度の違いがない 高緯度:グリッドが細かい 低緯度:グリッドが粗い 緯度経度グリッドの場合 17
4. モデル 4.3 学習方法 3段階で学習 1. 普通に学習(1step=6h後の予測値 vs 真値 の誤差を逆伝播),学習率は線形増加 2. 普通に学習,学習率はcos減少 3. 𝑖 step(𝑖 = 2~12)予測してBPTT*で逆伝播,学習率は 3×10-7 で一定 *時間方向の誤差逆伝播法。「自然言語処理の基礎」p.89や「ゼロつく❷」5.2.3節を参照 BackPropagetion Through Time 18
4. モデル 4.3 学習方法 損失関数 最適化手法 ⚫ 重み付き二乗和誤差 • 低高度ほど線形に重みをつける(高度について ⚫ AdamW(𝛽1 = 0.9, 𝛽2 = 0.95, 𝜆 = 0.1) 和を取ると1になるように) • 地上風速,海面気圧,降水量の重みは0.1 学習上の工夫 ⚫ 勾配クリッピング(閾値:32) ⚫ TPU使用(学習時はbfloat16で計算) ⚫ Gradient Checkpointing(論文) 学習コスト ⚫ Google Cloud TPU v4 (32GB)を32個 ⚫ 4週間 ⚫ モデルのパラメータ数:3670万* *参考:ResNet=1000万~6000万,GPT-3=1750億,Pangu-Weather=2.56億 19
4. モデル 4.4 モデルの詳細:入力(Encoderの前) ❶ ノード/エッジの値を初期化 mesh 𝑡−1 𝑡 𝑡−1 𝑡 𝒗G,feat = 𝒙 𝑖 , 𝒙𝑖 , 𝒇𝑖 , 𝒇𝑖 , 𝒄𝑖 𝑖 edge特徴量 𝒗M,feat = [cos 𝑙𝑎𝑡 , sin 𝑙𝑜𝑛 , cos(𝑙𝑜𝑛)] 𝑖 𝒆𝑀,feat 𝑣𝑠M →𝑣𝑟M mesh特徴量 = 𝑙𝑣𝑠M→𝑣𝑟M , 𝒅𝑣𝑠M →𝑣𝑟M mesh (=M) 𝒆G2M,feat = 𝑙𝑣𝑠G →𝑣𝑟M , 𝒅𝑣𝑠G →𝑣𝑟M 𝑣 G →𝑣 M 𝑠 𝑟 grid (=G) 𝒆M2G,feat = 𝑙𝑣𝑠M→𝑣𝑟G , 𝒅𝑣𝑠M →𝑣𝑟G 𝑣 M →𝑣 G 𝑠 edge特徴量 𝑟 入力特徴量 入力特徴量 ❷ 特徴量を埋め込み表現に変換 grid 𝒗M,feat = MLP 𝒗M,feat 𝑖 𝑖 G,feat 𝒗G 𝑖 = MLP 𝒗𝑖 𝑀,feat 𝒆M = MLP 𝒆 M M 𝑣 →𝑣 𝑣 M →𝑣 M 𝑠 𝑟 𝑠 𝑟 G2M,feat 𝒆G2M = MLP 𝒆 G M 𝑣 →𝑣 𝑣 G →𝑣 M 𝑠 𝑟 𝑠 𝑟 M2G,feat 𝒆M2G = MLP 𝒆 G M 𝑣 →𝑣 𝑣 M →𝑣 G 𝑠 𝑟 𝑠 𝑟 20
4. モデル 4.5 モデルの詳細: Encoderの処理 ❶ エッジ重み,その両端ノードを使ってエッジ重みの更新量を計算 ′ 𝒆G2M = MLP 𝑣 G →𝑣 M 𝑠 𝑟 G , 𝒗M 𝒆G2M , 𝒗 G M 𝑠 𝑟 𝑣 →𝑣 𝑠 𝑟 𝒆G2M 𝑣 G →𝑣 𝐌 𝑠 𝑣𝑟M 𝒓 mesh (=M) grid (=G) 𝑣𝑠G ❷ meshノードの値と,そのノードに向かうエッジの重み和を使ってmeshノードの更新量を計算 ′ 𝒗M 𝑖 = MLP G2M 𝒗M 𝑖 , agg 𝒆𝑣 G →𝑣 M 𝑠 ′ 𝑣𝑖M 𝑖 ❸ gridノードの値をMLPに通してgridノードの更新量を計算 G′ 𝒗𝑖 mesh (=M) grid (=G) = MLP 𝒗G𝑖 ❹ ノード/エッジの値を更新する G G 𝒗G 𝑖 ← 𝒗𝑖 + 𝒗𝑖 ′ M M 𝒗M 𝑖 ← 𝒗𝑖 + 𝒗𝑖 ′ 𝒆G2M 𝑣 G →𝑣 M 𝑠 ← 𝑟 𝒆G2M 𝑣 G →𝑣 M + G2M ′ 𝒆𝑣 G →𝑣 M 21
4. モデル 4.6 モデルの詳細: Processorの処理 ❶ エッジ重み,その両端ノードを使ってエッジ重みの更新量を計算 ′ 𝒆M = MLP 𝑣 M →𝑣 M 𝑠 𝑟 M , 𝒗M 𝒆M , 𝒗 M M 𝑠 𝑟 𝑣 →𝑣 𝑠 mesh (=M) 𝒆M 𝑣𝑠M →𝑣𝒓𝐌 𝑟 𝑣𝑠M 𝑣𝑟M ❷ meshノードの値と,そのノードに向かうエッジの重み和を使ってmeshノードの更新量を計算 ′ 𝒗M 𝑖 = MLP M 𝒗M , agg 𝒆 𝑖 𝑣 M →𝑣 M 𝑠 ′ mesh (=M) 𝑣𝑖M 𝑖 ❸ ノード/エッジの値を更新する 𝒗M 𝑖 ← 𝒗M 𝑖 + ′ 𝒗M 𝑖 これを16回繰り返す M M 𝒆M ← 𝒆 + 𝒆 M M M M 𝑣 →𝑣 𝑣 →𝑣 𝑣 M →𝑣 M 𝑠 𝑟 𝑠 𝑟 𝑠 ′ 𝑟 22
4. モデル 4.7 モデルの詳細: Decoderの処理 ❶ エッジ重み,その両端ノードを使ってエッジ重みの更新量を計算 𝑠 ′ 𝒆M2G = MLP 𝑣 M →𝑣 G 𝑠 𝒆M2G 𝑣 M →𝑣 𝐆 𝑟 𝑣𝑠M 𝒓 M , 𝒗G 𝒆M2G , 𝒗 G M 𝑠 𝑟 𝑣 →𝑣 𝑠 𝑟 𝑣𝑟G mesh (=M) grid (=G) ❷ gridノードの値と,そのノードに向かうエッジの重み和を使ってgridノードの更新量を計算 G′ 𝒗𝑖 = MLP 𝒗G 𝑖 , agg M2G ′ 𝒆𝑣 M→𝑣 G 𝑠 𝑖 ❸ gridノードの値を更新する G G 𝒗G ← 𝒗 + 𝒗 𝑖 𝑖 𝑖 mesh (=M) 𝑣𝑖G grid (=G) ′ 23
4. モデル 4.8 モデルの詳細: 出力(Decoderの後) ❶ gridノードをMLPに通し,現在時刻の値を加えたものが予測値 G ෝG 𝒚 = MLP 𝒗 𝑖 𝑖 𝑋 𝑡+1 = 𝐺𝑟𝑎𝑝ℎ𝐶𝑎𝑠𝑡 𝑋 𝑡 , 𝑋 𝑡−1 = 𝑋 𝑡 + 𝑌 𝑡 24
4. モデル ■特徴量やモデルについての補足 スライド4.4節の変数 *詳細は論文p.26を参照 ⚫ 𝒙𝑡𝑖 :時刻𝑡における,グリッド𝑖での大気状態ベクトル(スライド3.2節参照) ⚫ 𝒇𝑡𝑖 : 〃 での時間に関するベクトル = 大気上端における1時間当たりの総入射日射量,現地日時のsin,cos ⚫ 𝒄𝑖 :グリッド𝑖の位置に関するベクトル = 陸か海か,地表のジオポテンシャル,緯度経度のsin,cos ▼ 関数の形 ▼ 導関数 モデル ⚫ MLPすべて隠れ層1層 ⚫ 活性化関数はswish ⚫ Layer Normalization Swishについて:https://atmarkit.itmedia.co.jp/ait/articles/2004/15/news018.html 25
5. モデルの評価 26
5. モデルの評価 5.1 評価方法 評価対象の期間 主な評価指標 ⚫ 2018年1月1日~2018年12月31日 ⚫ RMSE(Root Mean Square Error) 比較対象 ⚫ GraphCast ⚫ Pangu-Weather(Transformerモデル) ⚫ HRES(数値予報モデル) 評価に用いた変数 二乗平均誤差 ⚫ ACC(Anomaly Correlation Coefficient) 「真値ー平均」と「予測値ー平均」の相関係数 ⚫ normalized difference RMSEA − RMSEB , RMSEB ACCA − ACCB 1 − ACCB ⚫ 高度:WeatherBenchに則る ⚫ 変数:ECMWF Scores cardsに則る ※ERA5のバイアス特性を考慮し降水量も除いている 値が小さいほどモデルBよりも モデルAのほうが精度が良い 値が大きいほどモデルBよりも モデルAのほうが精度が良い 27
5. モデルの評価 ■補足|「評価に用いた変数」の詳細 ⚫ 高度 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000 (hPa) ⚫ 変数 ECMWF Scores cardsには「風速の鉛直成分」はないので除外 ERA5の「降水量」の精度が怪しいので除外(論文) ⚫ 評価に用いた変数の個数は… 地上変数4個+大気変数5個×13層=69個 28
5. モデルの評価 5.3 概観 ⚫ 500hPa高度の予測精度 *500hPa等高線≒上空5500m付近の等圧線 ⚫ GraphCastの精度はHRESを上回る 29
5. モデルの評価 5.3 概観 ⚫ 赤い領域はHRESの方が高精度 ⚫ 青い領域はGraphCastの方が高精度 ⚫ 成層圏ではGraphCastの精度が悪化する 30
5. モデルの評価 5.4 GraphCast vs. MLベースモデル RMSEで比較 ⚫ 値が小さいほど精度が良いことを表す ⚫ 99.2%の変数で、GraphCastのRMSEは Pangu-WeatherのRMSEより低かった 31
5. モデルの評価 5.5 GraphCast vs. 数値予報モデル RMSEで比較 ⚫ 「5.3 概観」のグラフ a), b) を、他の変数について描画 32
5. モデルの評価 5.5 GraphCast vs. 数値予報モデル ACCで比較 ⚫ 「5.3 概観」のグラフ c) を、他の変数について描画 33
5. モデルの評価 5.6 最新のデータを学習に取り込むと精度はどうなる? データの最新性がもたらす影響 ⚫ 2021年のデータで精度を算出 ⚫ 新しいデータを取り入れるほどRMSEは 低くなっていく傾向 *他変数については論文補足資料7.1.3節のFigure15を参照 34
5. モデルの評価 5.7 地域によって精度はどうなる? ⚫ ⚫ 緯度20°を境に「北半球」「熱帯域」 「南半球」と分け、各地域で精度を算出 「ヨーロッパ」「東アジア」などの区分 でも精度を検証している *区分方法はECMWF scorecardsに基づく *他地域のグラフはFigure17,18を参照 35
5. モデルの評価 5.8 緯度/高度によって精度はどうなる? ⚫ 横軸:緯度,縦軸:気圧(≒高度) ⚫ 赤い領域はHRESの方が高精度 ⚫ 青い領域はGraphCastの方が高精度 ⚫ 成層圏でGraphCastの精度が悪化 … データセット(ERA5,HRES)間で成層圏の変数 の予測可能性が異なる可能性 36
5. モデルの評価 5.9 予測のズレを地図で見る ⚫ 「予測値ー真値」の分布をプロット 37
5. モデルの評価 5.10 GraphCastが得意な領域・苦手な領域を地図で見る ⚫ 赤い領域はHRESの方が高精度 ⚫ 青い領域はGraphCastの方が高精度 ⚫ 標高が高い地域は苦手? 38
5. モデルの評価 5.11 地表高度によって精度はどうなる? ⚫ 横軸:地表高度,縦軸:GraphCastとHRESのRMSE差(負=GraphCastが高精度) ⚫ 高高度ほどGraphCastの精度が悪化 ⚫ 予報時間が長くなると,HRESとの精度差はマシになる傾向 39
5. モデルの評価 5.11 地表高度によって精度はどうなる? ⚫ GraphCastが標高と気温の関係をうまく学習 できていなかった可能性 ⚫ GraphCast精度 悪 考えられる対策 ⚫ 「ハイブリッド座標系」を採用する ⚫ 高度重み*を標高を考慮したものにする *本スライド4.3節を参照 地上気温 5℃ 地上気温 30℃ 地表高度 高 5000m 40
5. モデルの評価 5.12 台風進路予測の精度は? 実際の進路とのズレ ⚫ 期間:2018~2021年 ⚫ 台風追跡アルゴリズムはECMWF trackerに基づくものを使用 外れ値=追跡アルゴリズムのエラーによるもの GraphCastでもHRESでも 台風の見逃し率は同じくらい であることを確認* *GraphCastはMSEで学習しており、ぼかしの効果によって台風を見落としている可能性がある。見落としが多いと、HRESとの比較 においてGraphCast有利になる可能性があるため、この検証を行ったそう。 41
5. モデルの評価 5.13 大気の川の予測精度は? 鉛直積算水蒸気輸送(IVT) ⚫ モデルで予測した比湿Qと風速U,VからIVTを求める (g=重力加速度,pb=1000hPa,pt=300hPa) ⚫ 北米沿岸/東太平洋で大気の川が頻繁に発生する時期 (1月~4月,10月~12月)を対象 画像引用:https://ja.wikipedia.org/wiki/大気の川 42
5. モデルの評価 5.14 異常気温の検出精度は? ⚫ ⚫ 1993~2016年の気温記録の上位(下位) ○%に入る気温を「異常気温」と定義 PR曲線は次式のgainを動かして得られる* gain×(モデル予測値ー中央値)+中央値 gain=0 → どんな予測をしても「異常高温ではない」判定となる → 偽陽性ゼロなので適合率100%,再現率低い gain→∞ → どんな予測をしても「異常気温」判定となる → 偽陰性ゼロなので再現率100%,適合率低い *グラフではgain=0.8~4.5の間で動かしている 43
44