1

U-net介绍

今天来介绍一个经典的语义分割网络U-net, 它于2015年提出,最初应用在医疗影像分割任务上,由于效果很好,之后被广泛应用在各种分割任务中。至今已衍生出许多基于U-net的分割模型。
U-net是典型的Encoder-Decoder结构,encoder进行特征提取,decoder
进行上采样。由于数据的限制,U-net在训练阶段使用了大量的数据增强操作,最后得到了不错的效果。

U-net网络结构

U-net的网络结构如下所示。左边为encoder部分,对输入进行下采样,下采样通过最大池化实现;右边为decoder部分,对encoder的输出进行上采样,恢复分辨率,上采样通过Upsample实现;中间为跳跃连接(Skip-connect),进行特征融合。由于整个网络形似一个"U",所以称为U-net。
网络中除了最后的输出层,其余所有卷积层均为3 * 3卷积。
未命名图片.png

U-net代码实现

import torch as t
import torch.nn as nn

class  DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.dconv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            
            # inplace设为True可以节省显存/内存
            nn.ReLU(inplace=True),
            
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, img):
        return self.dconv(img)
        
# 下采样
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2, 2),
            DoubleConv(in_channels, out_channels)
        )
    def forward(self, img):
        return self.down(img)

# 上采样
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        # ConvTranspose2D 有可学习的参数, 会在训练过程中不断调整参数。会增加模型的复杂度,可能会造成过拟合
        # Upsample 没有可学习的参数
        # 和Conv2d和MaxPooling2d的区别一样
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # pading 保证x1和x2的大小一样
        dx = x2.shape[3] - x1.shape[3]
        dy = x2.shape[2] - x1.shape[2]
        x1 = nn.functional.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # 通道合并
        x = t.cat([x1, x2], dim=1)
        return self.conv(x)


# 主网络
class CrackUnet(nn.Module):
    def __init__(self, channels, classes, bilinear=True):
        super(CrackUnet, self).__init__()
        self.channels = channels
        self.classes = classes
        self.bilinear = bilinear
        # 
        self.inconv = DoubleConv(self.channels, 64)

        # 4个下采样层
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # 4个上采样层, 采用双线性采样
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outconv = nn.Conv2d(64, channels, 1)

    def forward(self, img):
        img = self.inconv(img)
        down1 = self.down1(img)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        x = self.up1(down4, down3)
        del down4
        del down3
        x = self.up2(x, down2)
        del down2
        x = self.up3(x, down1)
        del down1
        x = self.up5(x, img)
        del img
        return self.outconv(x)

总结

U-net结构简单稳定,是典型的下采样+上采样的分割网络结构。尤其在数据集较小的时候,推荐使用。


mhxin
84 声望15 粉丝