SWinIR概述
1.概要
图像恢复是一个长期存在的低水平视觉问题,它旨在从低质量的图像(例如,缩小比例、有噪声和压缩的图像)中恢复高质量的图像。该文提出了一种基于Swin变换的强基线模型SwinIR。SwinIR由浅层特征提取、深度特征提取和高质量的图像重建三部分组成。特别地,深度特征提取模块由几个残余的Swin变压器块(RSTB)组成,每个块都有几个Swin变压器层和一个残余连接。该文对三个具有代表性的任务进行了实验:图像超分辨率(包括经典、轻量级和真实世界的图像超分辨率)、图像去噪(包括灰度和彩色图像去噪)和JPEG压缩伪影减少。实验结果表明,SwinIR在不同任务上的性能比最先进的方法高达0.14~0.45dB,而参数总数可减少67%。
参数量和PSNR关系:
作为CNN的替代品,Transformer设计了一种自我注意机制来捕获上下文之间的全局交互,并在几个视觉问题中显示出了良好的性能。然而,用于图像恢复的Vision Transformers通常将输入图像分割成固定大小的patch(如48×48),并独立处理每个patch。这种策略不可避免地会带来两个缺点。首先,边界像素不能利用补丁之外的邻近像素进行图像恢复。其次,恢复后的图像可能会在每个补丁周围引入边界伪影。虽然这个问题可以通过patch重叠来缓解,但它会带来额外的计算负担。
最近,Swin Transformer集成了CNN和Transformer的优势,显示出了巨大的前景。一方面,由于局部注意力机制,具有CNN处理大尺寸图像的优势。另一方面,它具有对移位窗口方案进行远程依赖建模的优势。
2. 模型介绍
更具体地说,SwinIR由三个模块组成:浅层特征提取、深度特征提取和高质量的图像重建模块。浅层特征提取模块采用卷积层提取浅层特征,并直接传输到重建模块,以保留低频信息。深度特征提取模块主要由剩余的Swin Transformer块(RSTB)组成,每个Transformer块利用多个Swin Transformer进行局部注意和跨窗口交互。此外,我们在块的末尾添加了一个卷积层,用于特征增强,并使用了一个残差连接,为特征聚合提供了一个快捷方式。最后,在重建模块中融合浅层和深度特征,进行高质量的图像重建。
2.1 网络结构
如下图所示,SwinIR由三个模块组成:浅层特征提取、深度特征提取和高质量(HQ)图像重建模块。我们对所有的恢复任务使用相同的特征提取模块,但对不同的任务使用不同的重构模块。
2.1.1 Shallow and deep feature extraction
给定一个低质量(LQ)输入\(I_LQ∈R^{H×W×C_{in}}\)(H,W和\(C_{in}\)分别是图像的高度、宽度和输入通道数),我们使用一个3×3卷积\(H_{SF}(\cdot)\)来提取浅层特征\(F_0∈R^{H\times W\times C_{in}}\):
\[F_0=H_{SF}(I_{LQ}) \tag{1} \]其中,C为特征通道号。卷积层擅长于较浅的视觉处理,导致更稳定的优化和更好的结果。它还提供了一种简单的方法来映射输入的图像空间到一个更高维的特征空间。然后,我们从\(F_0\)中提取深度特征\(F_{DF}\in R^{H\times W\times C}\):
\[F_{DF}=H_{DF}(F_0) \tag{2} \]其中,\(H_{DF}(\cdot)\)为深度特征提取模块,其中包含K个残余的Swin变压器块(RSTB)和一个3×3的卷积层。更具体地说,中间特征\(F_1、F_2、……、F_K\)和输出的深度特征\(F_{DF}\)被逐块提取为:
\[F_i=H_{RSTB_i}(F_{i-1}),\ i=1,2,...,K \\ F_{DF}=H_{CONV}(F_K) \]其中\(H_{RSTB_i}(\cdot)\)表示第i个RSTB,\(H_{CONV}\)为最后一个卷积层。在特征提取的最后使用卷积层可以将卷积操作的归纳偏置引入到基于Transformer的网络中,为后期浅、深特征的聚合奠定了更好的基础。
2.1.2 Image reconstrcution
以图像SR为例,将浅层和深层特征聚合为特征,重建高质量的图像\(I_{RHQ}\):
\[I_{RHQ}=H_{REC}(F_0+F_{DF}) \tag{4} \]其中,\(H_{REC}(\cdot)\)为重构模块的函数。浅层特征主要包含低频,而深层特征则侧重于恢复丢失的高频。SwinIR通过shortcut过连接,可以将低频信息直接传输到重构模块,帮助深度特征提取模块专注于高频信息,稳定训练。对于重构模块的实现,我们使用亚像素卷积层来对特征进行上采样。
对于不需要上采样的任务,如图像去噪和\(JPEG\)压缩伪影减少,使用单一的卷积层进行重建。此外,我们利用残差学习来重建LQ和HQ图像之间的残差,而不是HQ图像。这被表述为:
\[I_{RHQ}=H_{SwinIR}(I_{LQ})+I_{LQ}\tag{5} \]2.2 Residual Swin Transformer Block
剩余Swin变压器块(RSTB)是具有Swin变压器层(STL)和卷积层的剩余块。给定第\(i\)个RSTB的输入特征\(F_{i,0}\),我们首先提取中间特征\(F_{i,1},F_{i,2},…,F_{i,L}\):
\[F_{i,j}=H_{STL_{i,j}}(F_{i,j-1}), j=1,2,...,L \tag{8} \]其中,\(H_{STL_{i,j}}(\cdot)\)是第\(i\)个RSTB中的第\(j\)个Swin变压器层。然后,我们在剩余连接之前添加一个卷积层。RSTB的输出公式为:
\[F_{i,out}=H_{CONV_i}(F_{i,L})+F_{i,0} \tag{9} \]其中,\(H_{CONV_i(\cdot)}\)是第i个RSTB中的卷积层。这种设计有两个好处。首先,虽然Transformer可以看作是空间变化卷积的一个特定实例,但具有空间不变滤波器的卷积层可以增强SwinIR的平移等方差。其次,残差连接提供了从不同块到重建模块的基于身份的连接,允许聚合不同级别的特征。
2.2.1 Swin Transformer layer
Swin变压器层(STL)是基于原变压器层的标准多头自关注来实现的。主要区别在于局部注意和窗口机制。如上图所示,给定大小为H×W×C的输入,Swin变压器首先通过将输入划分为不重叠的\(M×M\)局部窗口,来重构\(\frac{HW}{M^2}×M^2×C\)特征的输入,其中\(\frac{HW}{M^2}\)为窗口总数。然后,它分别计算每个窗口的标准自注意(即局部注意)。对于局部窗口特征\(X\in R^{M^2\times C}\),查询、键和值矩阵Q、K和V计算为:
\[Q=XP_Q,K+XP_K,V=XP_V \tag{10} \]其中,\(P_Q、P_K\)和\(P_V\)是跨不同窗口共享的投影矩阵。一般来说,我们有\(Q,K,V∈R^{M^2×d}\)。因此,利用局部窗口中的自注意机制计算出的注意矩阵为:
\[Attention(Q,K,V)=SoftMax(QK^T/\sqrt{d}+B)V, \tag{11} \]其中B是可学习的相对位置编码,注意这里和绝对位置编码是不同的。在实践中,在[76]之后,我们并行执行h次的注意函数,并将多头自注意(MSA)的结果连接起来
绝对位置编码体现在:
相对位置编码在WindowAttention中:
relative position最终加在了attention map上面。
q其中attention map是由于Shifted Window操作导致右下方的某些window像素不连续,需要对权重进行重新赋值。
具体参考:SWin Transformer
接下来,使用一个多层感知器(MLP),它有两个完全连接的层,它们之间具有GELU非线性,以进行进一步的特征变换。在MSA和MLP之前都添加了LayerNorm(LN)层,并且对这两个模块都采用了残差连接。整个过程被表述为:
\[X=MSA(LN(X))+X \\ X=MLP(LN(X))+X \tag{12} \]2.3 Loss function
对于图像SR,我们通过最小化\(L1\)像素损失来优化SwinIR的参数:
\[L=||I_{RHQ}-I_{HQ}||_1 \tag{6} \]其中,\(I_{RHQ}\)以\(I_{LQ}\)为SwinIR的输入,\(I_{HQ}\)为对应的Ground Truth。对于真实世界的图像SR,我们使用像素损失、GAN损失和感知损失的组合来提高视觉质量。
3. 实验结果
4. 总结
(1)内容的交互作用,可以解释为空间变化的卷积。
(2)的远程依赖建模可以通过移位的窗口机制来启用。
(3)用更少的参数具有更好的性能。例如,如图1所示,与现有的图像SR方法相比,SwinIR以更少的参数实现了更好的PSNR。
5. 测试
我用的预训练模型:001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth分别在Set5,Set14,RealSRSet+5上做测试:
PSNR | SSIM | PSNR-Y | SSIM-Y | |
---|---|---|---|---|
set5 | 30.80 | 0.8728 | 32.72 | 0.9021 |
set14 | 18.91 | 0.5483 | 20.41 | 0.5890 |
RealSRSet+5 | 16.78 | 0.5004 | 18.23 | 0.5379 |
发现在Set5上结果与论文一致,但在其他数据集有着较大差距。
源代码某些疑惑
1. PatchMerging函数
该函数应用在:
\[RSTB->BasicLayer(downsample参数) \]再RSTB的forward函数中:
\[self.patch\_unembed(self.residual\_group(x, x\_size), x\_size)) \]如果BasicLayer采用downsample(PatchMerging),则最后输出tensor再通过patch_unembed是有问题的。
我们首先来看DownSample函数,也就是PatchMerging函数。
Downsample的输出tensor维度:$$size: [B,H/2W/2,2C]$$
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
# x size: [B H/2*W/2 2*C]
return x
PatchUnEmbed函数定义如下:
class PatchUnEmbed(nn.Module):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
return x
def flops(self):
flops = 0
return flops
注意改行:x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])
此处数据输入维度应为:$$[B,C,H,W]$$
由前述分析可知,经过DownSample的BasicLayer的输出数据维度为:$$[B,H/2W/2,2C]$$
是无法转换为$$[B,C,H,W]$$的(差两倍),会报错,如下图。
修改方法:
小编认为:
2. Patch_size
Patch_size的设置在初始化SwinIR的时候,默认参数为1,为1的时候是不会报错的。
patch_size的应用主要在PatchEmbed和Patchunembed两个函数中。
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
PatchEmbed函数定义如下:
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
可见,在forward函数中,对于输入\(x\)的处理和patch_size没有任何关系。
我们再来看SwinIR中的forward函数,其中forward_features函数用到了上述函数:
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
# torch.zeros(1, num_patches, embed_dim)
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
如果我们的patch_size不为1,比如为2,那么经过上述patch_embed函数,输出宽高维度依然不变(因为forward函数中与patch_size没有关系);但是该函数内的num_patches会发生变化,缩减为原来的四倍,此时若在进行加法:$$x = x + self.absolute_pos_embed$$会报错。
那为什么要加上patch_size参数呢?
小编猜测,SWinIR借用了Swin Transformer的S-MSA结构,每个window都会做self-Attention操作,但是每个window里面的token数目就是像素数目,加上patch_size是希望通过缩小图片来减小参数量和计算量。