みんなのAI
機械学習AI論文
読み込み中…

学ぶ

🏅マイ実績

Ch.06

Swin Transformer:階層型ウィンドウと大域文脈

Swin Transformer

パッチマージ

パッチ は画像の細かいマス、ウィンドウ は一度に見る範囲。W は窓の内側だけ、SW は窓を少しずらして隣とも短くつなぐ。マージ は縦横を半分にしつつ、特徴チャネルを深く積む段階。

4×44\times44×4 をチェッカー状の 2×22\times22×2 四群に分け、同じ位置同士を連結して一時的に C→4CC\rightarrow4CC→4C(先に情報を保持)にします。続いて LayerNorm→Linear で重要情報だけを 4C→2C4C\rightarrow2C4C→2C に圧縮。つまり空間は縮み、意味チャネルは要約しながら保つ段階です。
ConcatLayerNormLinear
H4×W4×C\frac{H}{4} \times \frac{W}{4} \times C4H​×4W​×C
H2×W2×4C\frac{H}{2} \times \frac{W}{2} \times 4C2H​×2W​×4C
H2×W2×2C\frac{H}{2} \times \frac{W}{2} \times 2C2H​×2W​×2C
第5章で学んだ ViT(Vision Transformer) は、画像全体のピクセルが一度に議論する「超大型タウンホール(グローバル注意)」でした。精度は高い一方、参加者(解像度)が増えるとメモリ計算量が二乗で爆発(O(N2)\mathcal{O}(N^2)O(N2))するという致命的な弱点がありました。
これを解決する Swinトランスフォーマー(Shifted Window Transformer) は非常に巧妙です。まずは小さな会議室(Window)の中だけで会話し、次の層では仕切りを少しずらして(Shift)隣室と情報を混ぜます。さらに Patch Merging で4人の担当者を1人のリーダーに要約し、より大局的な理解を作ります。本章では、この分割統治の直感が数式としてどう計算量を削減するかを見ます。

式の読み方(Swin 要点)

1) W-MSA(小会議室ルール): ウィンドウ一辺が MMM のとき、1ウィンドウ内の関連は概ね M×MM\times MM×M 規模。全員同時会議ではなく、小部屋単位に分割して計算量を制御します。
PatchW-MSAlocalShiftSW-MSA+ maskmerge ↓N
図は パッチ → ウィンドウ注意(W/SW) → パッチマージ の流れです。
2) SW-MSA(仕切り移動): 次ブロックでウィンドウを ⌊M/2⌋\lfloor M/2 \rfloor⌊M/2⌋ だけシフトし、アテンションマスクで有効な Q–K 接続のみ残します。隣接情報を混ぜつつ、無効接続は遮断できます。
3) パッチマージ(情報量保存の視点):
① 2×22\times22×2 の4パッチ(各 CCC チャネル)を1位置で Concat し、一時的に 4C4C4C にする(まずは捨てずに積む)。
② その後 Linear で重要情報を 2C2C2C に圧縮。
つまり空間は (H2,W2)\left(\frac{H}{2},\frac{W}{2}\right)(2H​,2W​) へ縮む一方、チャネルを増やして情報損失を緩和します。
4) なぜ解像度↓・チャネル↑なのか(少し深く):
① 情報容量の観点
容量の直感指標は (H×W)×C(H\times W)\times C(H×W)×C。空間が1/4になってチャネル据え置きだと表現力が急減し、細部が失われやすい。そこで 4C→2C4C\rightarrow2C4C→2C で損失を緩衝します。
② 計算量と表現力のバランス
高解像度維持は高コスト、チャネルだけ無限増加も高コスト。Swin は先にトークン数を減らし、チャネルで意味密度を高めるバランスを取ります。
③ 文脈・受容野の観点
深層では細部より大域構造を扱う必要があります。解像度低下は広域文脈を、チャネル増加は多様な意味表現の器を提供します。
④ 実務一言
解像度だけ下げると"記憶喪失"、チャネルだけ増やすと"計算爆発"。Swin の解像度↓+チャネル↑はその中間の工学的妥協点です。

Swinトランスフォーマー:直感と数式が交わる地点

1. ウィンドウ注意(W-MSA):小会議室で計算量をダイエット
比喩: 大人数が同時に話す(ViT)と計算が重くなります。そこで画像を 7x7 程度の小会議室(ウィンドウ)に分け、同じ部屋の中だけで注意計算を行います。
核心式: パッチ総数を NNN、ウィンドウ一辺を MMM とすると、ViT は O(N2)\mathcal{O}(N^2)O(N2)。W-MSA は O(N⋅M2)\mathcal{O}(N\cdot M^2)O(N⋅M2)。MMM は固定小定数(例 7)なので、実運用ではほぼ O(N)\mathcal{O}(N)O(N) 的に扱えます。
2. シフトウィンドウ(SW-MSA):仕切りをずらして連携する
比喩: ずっと同室メンバーだけでは情報が閉じます。次の層で仕切りを半窓分ずらすと、昨日は隣室だったトークンが今日同室になり情報交換できます。
核心式: 数学的には ⌊M2⌋\lfloor \frac{M}{2} \rfloor⌊2M​⌋ だけシフト。端で生じる断片は cyclic shift と attention mask で整合的に処理し、不要な接続を防ぎます。
3. パッチマージ(Patch Merging):4人の担当を1人の管理へ
比喩: 深い層では細かな意見を集約し、より大局的な判断が必要です。隣接する 2×22\times22×2(計4個)を1単位にまとめます。
核心式: 解像度は (H2,W2)\left(\frac{H}{2},\frac{W}{2}\right)(2H​,2W​) に減少。チャネルは一時的に 4C4C4C へ増え、Linear で 2C2C2C に圧縮。結果として空間は縮み、意味表現は深まるマルチスケール構造になります。
4. 2連続ブロック(Two-Successive Blocks):常にセットで動く
概念: Swin は(通常ウィンドウ → シフトウィンドウ)の2段コンボで動作します。
* zl=W-MSA(… )z^l = \text{W-MSA}(\dots)zl=W-MSA(…):第1段、同室内で対話
* zl+1=SW-MSA(… )z^{l+1} = \text{SW-MSA}(\dots)zl+1=SW-MSA(…):第2段、仕切りをずらして隣接情報を混合

重要性

重い O(N2)O(N^2)O(N2) を避ける工学的勝利
ViT は解像度が上がると計算量が二乗で増え、一般GPUでは学習が厳しくなります。Swin はこれを実質線形に近い形へ抑え、4K映像やギガピクセル病理画像(WSI)でも OOM を起こしにくい実用路線を開きました。
密な予測(Dense Prediction)でのパラダイム転換
分類だけなら ViT でも十分ですが、検出・セグメンテーションでは物体サイズが多様なため マルチスケール・ピラミッド が不可欠です。Swin は Patch Merging でこの構造をTransformerに自然に組み込み、多くの場面でCNNバックボーンを置き換えました。
視覚バックボーンの統一化
以前は「分類はViT、検出はResNet」のように分かれていましたが、Swin 以降は分類・検出・分割を同系統バックボーンで統一しやすくなり、開発保守コストを大きく下げられます。

使われ方

1. デバッグ最優先:入力解像度は 32 の倍数に合わせる
Swin では 2×22\times22×2 マージを繰り返すため、解像度が段階的に半減します(1/2→1/4→1/8→1/16→1/321/2\rightarrow1/4\rightarrow1/8\rightarrow1/16\rightarrow1/321/2→1/4→1/8→1/16→1/32)。入力の縦横は 32 で割り切れる値(224, 256, 512, 1024 など)が安全です。合わない場合はゼロパディングやリサイズを行います。
2. 実務はファインチューニング開始が基本
小規模社内データでゼロから学習するのは非効率で過学習しやすいです。一般には Hugging Face / MMDetection の事前学習済み Swin-T / Swin-B を読み込み、ヘッドだけ差し替えて微調整します。
3. OOM 時の実戦サバイバル手順
線形化してもモデルは重いので、OOM なら順に
1) Batch Size を下げる
2) 入力解像度を下げる(例 1024→512)
3) ウィンドウサイズ MMM を下げる(例 7→5)
を実施。補助として `gradient_checkpointing=True` も有効です。

まとめ

💡 [要点チートシート]
* 動作哲学: 分割して計算(W-MSA)、ずらして再接続(SW-MSA)、4つを1つへ要約(Patch Merging)。
* 意義: 高解像度で爆発しやすい O(N2)O(N^2)O(N2) 問題を実務可能なスケールへ抑え、CNN 的マルチスケール構造まで取り込んだハイブリッド視覚モデル。
* 実務マインド: 「入力は 32 の倍数、学習は事前学習重みから、メモリ不足時は MMM と batch を調整」。

解法のヒント

Swin 問題の共通フレーム(まずこの3行)
1) 問題が W-MSA / SW-MSA / Patch Merging のどれを聞いているかを特定する。
2) 計算問題なら単位を先に整理する:トークン数 NNN、ウィンドウサイズ MMM、チャネル CCC、解像度 (H,W)(H,W)(H,W)。
3) マージが出たら基本変換を即適用する:
* 空間:(H,W)→(H2,W2)(H,W)\rightarrow(\frac{H}{2},\frac{W}{2})(H,W)→(2H​,2W​)
* チャネル:C→4C→2CC\rightarrow4C\rightarrow2CC→4C→2C

例(概念)
「Swin の核心でないものは?」→ 毎層固定の全域注意は核心ではない。

例(○/×)
「パッチマージは空間解像度を上げる」→ 誤り。答え 0

例(シナリオ)
「高解像度で OOM。最初に何を試す?」→ バッチ↓、解像度↓、ウィンドウサイズ(MMM)↓の順。
問題タイプ別クイック例(Ch05 レイアウト同一、内容だけ Swin 用)
例(選択式・計算)
「N=20N=20N=20 のとき全域注意の N2N^2N2 は?」→ 202=40020^2=400202=400。

例(累積マージ計算)
「パッチマージ3回後の解像度は?」→ 1回ごとに辺が1/2なので 1/81/81/8、つまり (H8,W8)(\frac{H}{8},\frac{W}{8})(8H​,8W​)。

例(構成・個数)
「1辺ウィンドウ8個の正方格子の総数は?」→ 82=648^2=6482=64。

例(総合推論)
「入力 224×224224\times224224×224、パッチ 4×44\times44×4、マージ1回後トークン数は?」
1) 初期トークン:(224/4)2=562=3136(224/4)^2=56^2=3136(224/4)2=562=3136
2) 1回マージ:3136/4=7843136/4=7843136/4=784
→ 答え 784

実戦ワンポイント
計算問題は「何が1/2になり、何が2倍になるか」を先に分けると速く解けます。(空間は縮小、意味チャネルは深化)