ResNet

ResNet

引言

之前对ResNet网络架构有了解一些,大概知道它能在深度网络训练中解决梯度爆炸和消失的问题,至于为什么work,只明白了一个shotcut connection起的作用,关于数学背后的意义其实不知道。这次因为要复现Pix2Pix和CycleGAN,急急忙忙看了一些ResNet的文章,论文瞄了一哈,先给自己挖一个小坑,最后大喊一声,凯神威武!

ResNet和之前的CNN网络不同

背景

从经验上来看,CNN网络越深越能提取一些复杂的特征。在之前CNN网络架构的尝试上也确实如此,越深的网络取得的特征后利用分类算法得到的识别率更高。但是深的网络会出现Degradation problem(退化问题),主要表现在网络越深,参数过多,梯度爆炸和梯度消失问题使得很难训练。而ResNet的提出,在不改变网络深度的情况下,利用trick解决了训练问题。

初窥

刚才提出了退化问题,推测凯神的想法:如何保证前一层到后一层不会出现导致学习变差的情况。是否有一种方法可以使得前层到后层最坏的结果控制在恒等映射附近一点点,也就是在极端的情况下,后层什么都没学习,复制前层学习到的特征。那在一般的情况下,就是在前一层的特征基础上,训练新的特征。

短路连接

分析

假设刚开始输入$x$,原始学习到的是$H(x)$,ResNet的思想是希望网络学习的是一个残差值即$F(x)=H(x)-x$,学习到的特征变成$F(x)+x$,在前一层的基础上多学了一个残差。当残差为0时,前层与后层为恒等映射。

shortcut connection

贴上论文里面的Residual learning 结构图

有了这样示意图,可以简单从数学分析角度看看为什么ResNet能解决梯度爆炸梯度消失问题。

那么关于$x_l$的梯度:

公式中的1的存在保证了梯度不会消失,意味着残差学习更加容易。另外关于防止梯度爆炸,一般随着迭代,会不断decay学习率。

不同层数的ResNet

小trick

  1. 关于residual block的设计一般有两种,分别对应着浅层和深层。
  2. 当输入输出维度不一样的时候,两种方法:第一种1x1的卷积核,第二种,采用下采样。但是pytorch代码中用的是3x3卷积加一个BatchNorm。

    ResNet block code

    Residual block 给出pytorch版本Residual block,有了这个搭18层的ResNet。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
    super(ResidualBlock, self).__init__()
    self.conv1 = conv3x3(in_channels, out_channels, stride)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = conv3x3(out_channels, out_channels)
    self.bn2 = nn.BatchNorm2d(out_channels)
    self.downsample = downsample

    def forward(self, x):
    residual = x
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    if self.downsample:
    residual = self.downsample(x)
    out += residual
    out = self.relu(out)
    return out

decay 学习率

1
2
3
def update_lr(optimizer, lr):    
for param_group in optimizer.param_groups:
param_group['lr'] = lr

reference

Deep Residual Learning for Image Recogniton
Identity Mappings in Deep Residual Networks