大家的AI
机器学习AI论文
加载中…

学习

🏅我的成就

Ch.06

Swin Transformer:从全局注意力到分层窗口结构

Swin Transformer

Patch 合并

patch 是把图像切成的小格;窗口 是一次看多大范围。W 只在窗内互看;SW 把窗平移一点,让邻格短暂相连。合并 让高宽减半、把特征通道堆得更深。

把 4×44\times44×4 网格按棋盘拆成四个 2×22\times22×2 组后,将相同位置先拼接,临时得到 C→4CC\rightarrow4CC→4C(先保留信息)。随后经 LayerNorm→Linear,把关键语义压缩为 4C→2C4C\rightarrow2C4C→2C。即:空间变小,但语义通道以“先堆叠后压缩”的方式被保留。
ConcatLayerNormLinear
H4×W4×C\frac{H}{4} \times \frac{W}{4} \times C4H​×4W​×C
H2×W2×4C\frac{H}{2} \times \frac{W}{2} \times 4C2H​×2W​×4C
H2×W2×2C\frac{H}{2} \times \frac{W}{2} \times 2C2H​×2W​×2C
在第5章学到的 ViT(Vision Transformer) 中,整张图像像是在一个超大会议厅里同时讨论(全局注意力)。效果强,但参与者(分辨率)一多,内存与计算会按平方爆炸(O(N2)\mathcal{O}(N^2)O(N2)),这是致命痛点。
Swin Transformer(Shifted Window Transformer) 的解决方式非常巧妙:先只在小会议室(Window)里讨论;到下一层,把隔板稍微错开(Shift),让邻近小组交换信息;再通过 Patch Merging 把4个细粒度单元压成1个更高语义单元。这个“分而治之”的直觉如何对应到显著降耗的数学机制,就是本章重点。

公式怎么读(Swin 要点)

1) W-MSA(小会议室法则): 若窗口边长为 MMM,单窗口内相关性规模约为 M×MM\times MM×M。也就是把“全员大会”拆成多个小房间,控制计算量。
PatchW-MSAlocalShiftSW-MSA+ maskmerge ↓N
示意图展示 patch → 窗口注意力(W/SW) → patch 合并 的流程。
2) SW-MSA(错位隔板): 下一块中将窗口平移 ⌊M/2⌋\lfloor M/2 \rfloor⌊M/2⌋,并用注意力掩码只保留合法 Q–K 连接。这样既能跨窗口融合信息,又能避免无效连接。
3) Patch Merging(信息保真视角):
① 先把 2×22\times22×2 的4个 patch(各 CCC 通道)在同一位置 concat,临时变成 4C4C4C(先堆叠,不先丢信息)。
② 再用 Linear 压到 2C2C2C。
因此空间变为 (H2,W2)\left(\frac{H}{2},\frac{W}{2}\right)(2H​,2W​),同时通过增通道缓冲信息损失。
4) 为什么是“分辨率↓、通道↑”?(更深入)
① 信息容量视角
可粗看为 (H×W)×C(H\times W)\times C(H×W)×C。若空间缩到 1/4 而通道不变,表达容量会骤降,细节易丢失;所以用 4C→2C4C\rightarrow2C4C→2C 做缓冲。
② 算力与表达力平衡
维持高分辨率很贵,只增通道也很贵。Swin 先降 token 数(省显存/算力),再用通道增加每个 token 的语义密度。
③ 上下文/感受野视角
深层需要更多全局结构信息。分辨率下降带来更大上下文,通道上升提供更丰富语义编码空间。
④ 工程一句话
只降分辨率会“失忆”,只升通道会“算力爆炸”。Swin 的分辨率↓+通道↑是两者之间的工程折中。

Swin Transformer:直觉与公式的交汇点

1. 窗口注意力(W-MSA):在小会议室里给计算量“减脂”
比喻: 如果所有人同时发言(ViT),计算会非常重。Swin 把图像切成 7x7 级别的小会议室(窗口),只允许同一窗口内部做注意力。
核心公式: 总 patch 数记为 NNN,窗口边长记为 MMM。ViT 复杂度是 O(N2)\mathcal{O}(N^2)O(N2),W-MSA 变为 O(N⋅M2)\mathcal{O}(N\cdot M^2)O(N⋅M2)。因为 MMM 通常是固定小常数(如7),整体可视为接近 O(N)\mathcal{O}(N)O(N) 的线性扩展。
2. 移位窗口(SW-MSA):错位隔板实现跨组沟通
比喻: 永远只和本房间交流会形成信息孤岛。下一层把隔板向右下平移半窗后,原本隔壁的 patch 会进入同一窗口,从而发生信息交换。
核心公式: 窗口通常平移 ⌊M2⌋\lfloor \frac{M}{2} \rfloor⌊2M​⌋ 个 patch。边界碎片通过 cyclic shift 与 attention mask 优雅处理,避免无效连接而不显著增加额外内存。
3. Patch Merging:4个员工压成1个主管
比喻: 越往深层,越要把细碎局部意见汇总成全局语义。于是把相邻 2×22\times22×2(共4个)patch 合并。
核心公式: 空间分辨率降到 (H2,W2)\left(\frac{H}{2},\frac{W}{2}\right)(2H​,2W​),通道先涨到 4C4C4C,再经 Linear 压到 2C2C2C。于是层级越深,空间越小但语义越强,形成多尺度金字塔。
4. 双连续块(Two-Successive Blocks):总是成对出现
概念: Swin 总是以(普通窗口 → 移位窗口)的二连击工作。
* zl=W-MSA(… )z^l = \text{W-MSA}(\dots)zl=W-MSA(…):第1步,先在本窗口内交流
* zl+1=SW-MSA(… )z^{l+1} = \text{SW-MSA}(\dots)zl+1=SW-MSA(…):第2步,错位后跨邻域交流

为何重要

绕开沉重 O(N2)O(N^2)O(N2) 的工程胜利
ViT 在分辨率上升后会出现二次方增长,普通 GPU 很快吃不消。Swin 将其压到近线性扩展,让 4K 自动驾驶视频、病理全切片(WSI)等超高分辨率数据训练变得可落地,显著降低 OOM 风险。
密集预测(Dense Prediction)的范式转变
分类任务 ViT 已可胜任,但检测与语义分割要求多尺度表达,因为目标尺寸差异巨大。Swin 通过 Patch Merging 将金字塔结构原生引入 Transformer,使其在大量场景中替代传统 CNN 主干。
统一视觉主干(Backbone)的可能
过去常见“分类用 ViT、检测用 ResNet”的割裂方案。Swin 之后,分类/检测/分割可共享同类主干,显著降低工程维护与研究成本。

如何使用

1. 调试第一条:输入分辨率尽量是 32 的倍数
Swin 内部多次执行 2×22\times22×2 合并,分辨率会不断对半:1/2→1/4→1/8→1/16→1/321/2\rightarrow1/4\rightarrow1/8\rightarrow1/16\rightarrow1/321/2→1/4→1/8→1/16→1/32。因此输入宽高最好能被 32 整除(如 224、256、512、1024)。否则应先补零(Zero Padding)或重采样。
2. 实务几乎都从 Fine-tuning 开始
用少量业务数据从零训练 Swin 往往成本高且易过拟合。行业标准是直接加载 Hugging Face / MMDetection 的预训练权重(如 Swin-T、Swin-B),替换任务头后做微调。
3. OOM 生存指南
即使是近线性复杂度,Swin 仍然不轻。若显存溢出,按顺序处理:
1) 减小 batch size;
2) 降低 输入分辨率(如 1024→512);
3) 缩小 窗口大小 MMM(如 7→5)。
额外技巧:开启 `gradient_checkpointing=True` 可进一步省显存(会牺牲一些速度)。

小结

💡【核心速记卡】
* 运行哲学: 能拆就拆(W-MSA),隔墙就错位再连(SW-MSA),层数加深就四合一压缩(Patch Merging)。
* 在视觉AI中的意义: 既缓解了高分辨率下 Transformer 的 O(N2)O(N^2)O(N2) 致命瓶颈,又吸收了 CNN 的多尺度金字塔优势。
* 工程心法: “输入尽量用 32 倍数,训练先用预训练权重,显存不足时优先调 batch 和窗口大小 MMM。”

解题提示

Swin 解题通用框架(先看这3行)
1) 先判断题目在问 W-MSA / SW-MSA / Patch Merging 哪一块。
2) 若是计算题,先整理单位:token 数 NNN、窗口大小 MMM、通道 CCC、分辨率 (H,W)(H,W)(H,W)。
3) 只要出现 merging,立刻套用基础变换:
* 空间:(H,W)→(H2,W2)(H,W)\rightarrow(\frac{H}{2},\frac{W}{2})(H,W)→(2H​,2W​)
* 通道:C→4C→2CC\rightarrow4C\rightarrow2CC→4C→2C

示例(概念)
“哪个不是 Swin 核心?”→ 每层固定全局注意力不是核心。

示例(判断)
“Patch Merging 会提升空间分辨率。”→ 错。答案 0

示例(场景)
“高分辨率 OOM,先做什么?”→ 按 batch↓、分辨率↓、窗口大小(MMM)↓顺序排查。
按题型快速示例(版式与 Ch05 相同,仅内容换成 Swin)
示例(选择计算)
“N=20N=20N=20 时全局注意力的 N2N^2N2 是多少?”→ 202=40020^2=400202=400。

示例(累积合并)
“Patch Merging 做3次后分辨率是多少?”→ 每次边长减半,3次后为 1/81/81/8,即 (H8,W8)(\frac{H}{8},\frac{W}{8})(8H​,8W​)。

示例(配置/计数)
“正方网格每边 8 个窗口,总窗口数?”→ 82=648^2=6482=64。

示例(综合推理)
“输入 224×224224\times224224×224,patch 4×44\times44×4,做1次 merging 后 token 数?”
1) 初始 token:(224/4)2=562=3136(224/4)^2=56^2=3136(224/4)2=562=3136
2) 一次 merging:3136/4=7843136/4=7843136/4=784
→ 答案 784

实战一句话
计算题先分清“谁变成 1/2、谁变成 2 倍”就能快速落地。(空间变小,语义通道变深)