引言

​ 本论文介绍了一种U型网络结构,用于语义分割。它其实基于一种编码与解码的思想,可以有效的结合低分辨率的信息和高分辨率的信息,能够更好的分割图像边缘。它与FCN同一年提出,在思想上上也类似,但是u-net用了完全对称的结构,以及在拼接图像时用的不是像素的相加,而是通道的叠加,在我实际使用过程中发现比FCN-8有更好的分割精度,这得益于它的对称结构连接了更多的图像语义信息,而FCN-8则相对较少。本论文作者为Olaf Ronneberger, Philipp Fischer, Thomas Brox

原论文

网络结构

​ 本篇虽然给出了具体的网络结构,但这篇论文对我们更重要的影响是它使用的这种U型编码解码结构,至今的许多结构都是基于这种结构进行的完善和改进。原文中给出的网络结构如下JZikzF.png图例中也很明确的写出了每一步的操作是什么(其中卷积层的激活函数都使用的事ReLU),最终的输出就是基于每个像素的softmax分类。可以看到整个网络结构就是个U型的,左半部分相当于是编码部分,把原尺寸的图像进行特征提取压缩形成一个热图,然后再对其进行上采样的同时与原先的图像层进行拼接,有助于帮助恢复上采样中丢失的细节信息。

训练

权重图

​ 论文中给出了一个weight map的计算方法,它在训练之前就预先算好。它的作用就在于对于相连的物体有很好的分割作用,它给予相连物体之间的边界背景在损失函数中有极大的权重。下面是它的计算公式$$w(x) = w_c(x)+w_0.exp(-\frac{(d_z(x)+d_2(x))^2}{2\sigma^2})$$

$w_c$:是用于平衡类频率的权重图

$d_1$:表示当前像素到达最近边界单元格的距离

$d_2$:表示当前像素到达第二近边界单元格的距离

$w_0$和$\sigma$都是超参数。

(实际上我对这个权重图的方法的计算细节并不是特别清楚,网上搜寻后也没有人对这个方法有解答,在实际的实现中并没有使用它的这个方法。)

overlap-tile

这是作者论文中有一个令人有点困惑的地方。我的理解是,作者原文卷积操作并没有使用same卷积而是用了valid卷积,导致图像越卷越小并且丢失了一些边缘信息,那为何不用padding填充。作者的意思可能是输入图片非常的大,由于显存限制需要把一张图片分割后再进行拼接,使用overlap-tile的方法就能使得合并后的边缘更加合理。那么overlap-tile具体是怎么做,例如是左上角的边缘信息,那么我们就把它右边和下面的一部分图像做镜像填充到上面和左边再进行卷积,如下图JZiEM4.png黄色是卷积核卷积的部分,它的左边和上边都被右边和下边的镜像填充了(由于图像不大,这个方法我也没有进行尝试)。

数据增强

由于训练集不大以及增强网络的不变性和鲁棒性,需要使用一些增强数据的方式,文章是对细胞图像进行的分割,所以使用了弹性形变增强数据,这也符合细胞具有的生物学特性。实际使用过程中也可以根据实际情况使用其它的如平移、旋转等方式对进行图像增强。

实验

实验使用了tensorflow2中keras实现,利用了在imagenet上已经训练好的vgg-16网络中的前14层并设置不再更新这些层的参数。数据集利用了这里的数据并且卷积层都使用了same卷积,下面就是网络主体的框架。

vgg16_model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=False, input_tensor=keras.Input(shape=(320, 320, 3)))
vgg16_model.trainable = False
vgg16_model.summary()
class unet_model(tf.keras.Model):
    def __init__(self,n_class):
      super().__init__()
      self.n_class = n_class
      self.vgg16_model = vgg16_model
      self.conv1_1 = vgg16_model.layers[1]
      self.conv1_2 = vgg16_model.layers[2]
      self.pool1 = vgg16_model.layers[3]
     
      self.conv2_1 = vgg16_model.layers[4]
      self.conv2_2 = vgg16_model.layers[5]
      self.pool2 = vgg16_model.layers[6]
        
      self.conv3_1 = vgg16_model.layers[7]
      self.conv3_2 = vgg16_model.layers[8]
      self.conv3_3 = vgg16_model.layers[9]
      self.pool3 =  vgg16_model.layers[10]
       
      self.conv4_1 = vgg16_model.layers[11]
      self.conv4_2 = vgg16_model.layers[12]
      self.conv4_3 = vgg16_model.layers[13]
      self.pool4 = vgg16_model.layers[14]
        
      self.conv6 = Conv2D(1024,(3,3),(1,1),padding="same",activation="relu")
      self.conv7 = Conv2D(512,(3,3),(1,1),padding="same",activation="relu")
      self.conv_t1 = Conv2DTranspose(512,(2,2),(2,2),padding="same")
      self.fuse_1 = Concatenate()
      self.conv8 = Conv2D(512,(3,3),(1,1),padding="same",activation="relu")
      self.conv9 = Conv2D(256,(3,3),(1,1),padding="same",activation="relu")
      self.conv_t2 = Conv2DTranspose(256,(2,2),(2,2),padding="same")
      self.fuse_2 = Concatenate()
      self.conv10 = Conv2D(256,(3,3),(1,1),padding="same",activation="relu")
      self.conv11 = Conv2D(128,(3,3),(1,1),padding="same",activation="relu")
      self.conv_t3 = Conv2DTranspose(128,(2,2),(2,2),padding="same")
      self.fuse_3 = Concatenate()
      self.conv12 = Conv2D(128,(3,3),(1,1),padding="same",activation="relu")
      self.conv13 = Conv2D(64,(3,3),(1,1),padding="same",activation="relu")
      self.conv_t4 = Conv2DTranspose(64,(2,2),(2,2),padding="same")
      self.fuse_4 = Concatenate()
      self.conv14 = Conv2D(64,(3,3),(1,1),padding="same",activation="relu")
      self.conv15 = Conv2D(64,(3,3),(1,1),padding="same",activation="relu")
      self.conv16 = Conv2D(n_class,(1,1),(1,1),padding='same',activation='softmax')
    
        
    def call(self,input):

      x = self.conv1_1(input)
      x_1 = self.conv1_2(x)
      x = self.pool1(x_1)
      x = self.conv2_1(x)
      x_2 = self.conv2_2(x)
      x = self.pool2(x_2)
      x = self.conv3_1(x)
      x = self.conv3_2(x)
      x_3 = self.conv3_3(x)
      x = self.pool3(x_3)
      x = self.conv4_1(x)
      x = self.conv4_2(x)
      x_4 = self.conv4_3(x)
      x = self.pool4(x_4)

      x = self.conv6(x)
      x = self.conv7(x)
      x = self.conv_t1(x)
      x = self.fuse_1([x,x_4])

      x = self.conv8(x)
      x = self.conv9(x)
      x = self.conv_t2(x)
      x = self.fuse_2([x,x_3])

      x = self.conv10(x)
      x = self.conv11(x)
      x = self.conv_t3(x)
      x = self.fuse_3([x,x_2])

      x = self.conv12(x)
      x = self.conv13(x)
      x = self.conv_t4(x)
      x = self.fuse_4([x,x_1])

      x = self.conv14(x)
      x = self.conv15(x)
      x = self.conv16(x)

      return x

经过100轮训练,最终能在训练集上达到95%精度,如下图是在训练集上的测试。JZPZUP.png

​ 但如果在测试集上就只有89%左右。由于训练集中人出现的少,测试集大多为人,可能在测试集效果不佳,可以把训练集和数据集加一起打乱后再分割成训练集和数据集再训练,以及利用数据增强的方法,应该会让模型的泛化能力更好。


MatthewY
6 声望2 粉丝