头图

这篇文章主要作为【深度学习个人笔记】从实际深度学习项目来学习深度学习1——pix2pixHD开源项目学习:训练部分中的一部分,对训练部分即train.py的第三部分——模型与优化器相关定义进行讲解,这一部分因为内容较多,因此在这篇文章中分成两个部分:模型定义与优化器定义

模型定义

train.py中的模型定义部分代码如下:

model = create_model(opt)

可见train.py中的这一部分非常简短,这是因为很多定义都放在了别的.py文件里,这里model是create_model函数返回值,create_model函数来自项目的models文件夹中的models.py文件,函数的代码如下所示:

def create_model(opt):
    if opt.model == 'pix2pixHD':
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        if opt.isTrain:
            model = Pix2PixHDModel()
        else:
            model = InferenceModel()
    else:
        from .ui_model import UIModel
        model = UIModel()
    model.initialize(opt)
    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    if opt.isTrain and len(opt.gpu_ids) and not opt.fp16:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

    return model

 可见create_model函数主要完成生成模型的任务,先来看一下函数的前半部分,即通过命令行参数的opt.model来判断生成哪种模型,当opt.model为'pix2pixHD'时显然就是生成pix2pixHD的模型,因此先来看一下这部分(至于另一部分中的UIModel是调用了pix2pixhd的ui界面,这里就不考虑了)
生成pix2pixHD的模型时,函数引入了models文件夹下的pix2pixHD_model.py中的Pix2PixHDModel类InferenceModel类分别对应训练阶段和测试阶段

而在create_model函数的后半部分,则首先调用了返回的类中的.initialize()方法;然后根据opt.verbose参数(这个参数在项目中用于表示一些信息是否通过print打印出来,默认False)打印一下.name()方法的结果;最后根据是否是训练过程+GPU index的数量+opt.fp16参数(这个参数表示是否使用自动混合精度(AMP)进行训练)决定是否使用torch.nn.DataParallel()函数来进行多GPU并行训练,因此我们将注意力集中在上面说到的Pix2PixHDModel类和InferenceModel类即可,这两个类就包含了模型定义的关键代码,同时在pix2pixHD_model.py文件中可以注意到,Pix2PixHDModel类实际上是BaseModel类的子类,而InferenceModel类实际上是Pix2PixHDModel类的子类,因此后续分节来介绍这三个类

BaseModel类

BaseModel类的代码如下所示:

class BaseModel(torch.nn.Module):
    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, save_dir=''):        
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)        
        if not os.path.isfile(save_path):
            print('%s not exists yet!' % save_path)
            if network_label == 'G':
                raise('Generator must exist!')
        else:
            #network.load_state_dict(torch.load(save_path))
            try:
                network.load_state_dict(torch.load(save_path))
            except:   
                pretrained_dict = torch.load(save_path)                
                model_dict = network.state_dict()
                try:
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}                    
                    network.load_state_dict(pretrained_dict)
                    if self.opt.verbose:
                        print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
                except:
                    print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                    for k, v in pretrained_dict.items():                      
                        if v.size() == model_dict[k].size():
                            model_dict[k] = v

                    if sys.version_info >= (3,0):
                        not_initialized = set()
                    else:
                        from sets import Set
                        not_initialized = Set()                    

                    for k, v in model_dict.items():
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                            not_initialized.add(k.split('.')[0])
                    
                    print(sorted(not_initialized))
                    network.load_state_dict(model_dict)                  

    def update_learning_rate():
        pass

可见BaseModel类是torch.nn.Module的子类,在使用Pytorch编写深度学习代码时如果需要自己创建网络模型则需要继承这个类,在BaseModel类中很多只是定义了一个方法,然后用pass占个位,因此这里就主要关注一些有内容的方法:

  • initialize()方法:初始化了一些类的属性,其中opt、gpu_ids和isTrain在之前都介绍过,因此这里就主要介绍Tensor和save_dir
    save_dir属性显然是存储的路径,通过os.path.join()函数将opt.checkpoints_dir(默认是项目下的checkpoints文件夹)和opt.name拼接起来作为存储路径

Tensor属性则是根据是否使用了GPU来定义张量的数据类型,Pytorch中的tensor包括CPU上的数据类型和GPU上的数据类型,GPU上的张量一般由CPU上的张量加.cuda()方法得到

  • save_network()方法:根据传入的network_label和epoch_label参数确定模型文件.pth的文件名,然后根据save_dir属性确定模型文件的保存路径,最后用torch.save()函数保存模型参数(.state_dict()方法返回一个字典对象,将每一层与它的对应参数建立映射关系,同时只有可以训练的layer才会被保存)
  • load_network()方法:根据传入的network_label和epoch_label参数确定模型文件.pth的文件名,然后根据save_dir属性确定模型文件的保存路径,然后使用os.path.isfile()函数判断模型文件是否存在,如果存在则使用一系列的try/except语句来捕捉异常
    首先直接.load_state_dict()torch.load()函数加载模型参数,state_dict是一个字典,该字典中包含了模型各层和其参数tensor的对应关系,有关Pytorch获取模型的参数,可以参考下面的博客:
    【PyTorch技巧1】详解pytorch中的state_dict
    Pytorch保存和加载模型(load和load_state_dict)
    这里贴出一个state_dict的示例,有助于理解:
    state_dict的示例

如果引发了异常则说明传入的network参数,即需要加载模型参数的网络有些参数无法加载,故在调用一次try/except语句,分别定义pretrained_dict(来自保存的模型参数)和model_dict(来自传入的网络参数),需要把pretrained_dict中不属于model_dict的键剔除掉,代码中通过.items()方法将字典中的每对key和value组成一个元组,并且只保留在model_dict中存在的键,然后再通过.load_state_dict()方法加载参数,也就是下面这两行代码:

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}                
network.load_state_dict(pretrained_dict)

如果还引发了异常,则说明虽然参数字典中的键一样,但是值的尺寸可能不一样,因此通过.size()方法判断字典中的值的尺寸是否一样,如果一样则改变model_dict字典的键对应的值

 最后在load_network()函数中还打印出了没有进行初始化的网络层,保存到Python中的集合(set)中,首先程序中根据sys.version_info判断Python的版本,version_info是sys模块中的一个函数,主要用于返回你当前所使用的Python版本号,是一个包含了版本号5个组成部分的元祖,这5个部分分别是主要版本号(major)、次要版本号(minor)、微型版本号(micro)、发布级别(releaselevel)和序列号(serial),可见如果Python是3.0版本往上(现在应该都是用的是3.0以上的Python版本),用set()函数初始化not_initialized为空集合

Python中的set和dict一样,只是没有value,相当于dict的key集合。由于dict中的key不能重复,所以在set中没有重复的元素,故集合(set)是一个无序的不重复元素序列,然后根据model_dict中的键在不在pretrained_dict中以及元素的尺寸是否不匹配将相应的键添加到集合中并打印出来,然后仍然是通过.load_state_dict()方法加载模型参数

Pix2PixHDModel类

Pix2PixHDModel类的代码方法很多,比较长,因此这里根据不同方法分小节来进行讲述,从初始化,即.initialize()方法开始

initialize()方法

单独initialize()方法的代码如下所示:

def initialize(self, opt):
    BaseModel.initialize(self, opt)
    if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
        torch.backends.cudnn.benchmark = True
    self.isTrain = opt.isTrain
    self.use_features = opt.instance_feat or opt.label_feat
    self.gen_features = self.use_features and not self.opt.load_features
    input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

    ##### define networks        
    # Generator network
    netG_input_nc = input_nc        
    if not opt.no_instance:
        netG_input_nc += 1
    if self.use_features:
        netG_input_nc += opt.feat_num                  
    self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                  opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                  opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        

    # Discriminator network
    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        netD_input_nc = input_nc + opt.output_nc
        if not opt.no_instance:
            netD_input_nc += 1
        self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                      opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

    ### Encoder network
    if self.gen_features:          
        self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 
                                      opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)  
    if self.opt.verbose:
            print('---------- Networks initialized -------------')

    # load networks
    if not self.isTrain or opt.continue_train or opt.load_pretrain:
        pretrained_path = '' if not self.isTrain else opt.load_pretrain
        self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
        if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)  
        if self.gen_features:
            self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)              

    # set loss functions and optimizers
    if self.isTrain:
        if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
            raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
        self.fake_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr

        # define loss functions
        self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
        
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)   
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:             
            self.criterionVGG = networks.VGGLoss(self.gpu_ids)
            
    
        # Names so we can breakout loss
        self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')

        # initialize optimizers
        # optimizer G
        if opt.niter_fix_global > 0:                
            import sys
            if sys.version_info >= (3,0):
                finetune_list = set()
            else:
                from sets import Set
                finetune_list = Set()

            params_dict = dict(self.netG.named_parameters())
            params = []
            for key, value in params_dict.items():       
                if key.startswith('model' + str(opt.n_local_enhancers)):                    
                    params += [value]
                    finetune_list.add(key.split('.')[0])  
            print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
            print('The layers that are finetuned are ', sorted(finetune_list))                         
        else:
            params = list(self.netG.parameters())
        if self.gen_features:              
            params += list(self.netE.parameters())         
        self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

        # optimizer D                        
        params = list(self.netD.parameters())    
        self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

可见initialize()方法先调用了父类的initialize()方法,这个在上面BaseModel类中已经讲过,然后设置了几个属性:isTrain、use_features和gen_features,其中use_features根据opt.instance_feat(表示是否将编码后的instance feature作为输入,默认为False)和opt.label_feat(比较是否将编码后的label feature作为输入,默认为False)来定义,这里的label和instance的区别可参见pix2pixHD论文中的一张图片,其中左边是label,右边的instance的边缘图:
label和instance的区别

而gen_features属性则是根据use_features属性和opt.load_features(表示是否加载预先计算出的特征,默认为False)进行定义,在代码用主要用于编码器网络的相关配置

input_nc参数表示输入的通道数量,代码中表示为opt.label_nc当opt.label_nc不为0时,当opt.label_nc=0时,则表示使用自己的数据集,设置为opt.input_nc(opt.label_nc和opt.input_nc都表示输入图像的通道,前者默认为35,后者默认为3)

initialize()方法在定义了一些参数后,后面的内容可以分为定义生成器网络、定义判别器网络、定义编码器网络、加载网络、设置loss函数和优化器几个部分,关于生成器和编码器网络之间的关系可以参考pix2pixHD中的图会比较好理解,如下所示:
生成器和编码器网络
下面就根据每个部分分小节来进行描述:

initialize()方法中的生成器网络

这一部分首先定义了netG_input_nc,等于前面定义的input_nc,表示生成器网络输入通道数,然后根据opt.no_instance(表示是否将instance map加入到输入中,默认为False),因此如果需要加入instance map则输入的通道数加1,;同时根据use_features属性,表示是否需要在输入添加feature map,如果需要加入则输入的通道数加opt.feat_num(表示经过编码器网络后的特征数量,默认为3)

然后定义了netG属性,即生成器网络,这个属性是define_G函数函数的返回值,函数来自同一文件夹下的networks.py文件,函数传入的参数如下所示:

  • netG_input_nc
  • opt.output_nc:表示输出图像的通道数,默认为3
  • opt.ngf:在第一个卷积层的输出通道数(默认为64)
  • opt.netG:选择生成器网络的模型,默认为'global'
  • opt.n_downsample_global:生成器网络中的下采样层数量,默认为4
  • opt.n_blocks_global:在全局生成器网络中残差块的数量,默认为9
  • opt.n_local_enhancers:使用的局部增强器的个数,默认为1
  • opt.n_blocks_local:在局部增强器中残差块的数量,默认为3
  • opt.norm:标准化方法的选择,instance normalization还是batch normalization,默认为'instance'
  • gpu_ids=self.gpu_ids:GPU index

在正式介绍项目中如何构建生成器网络之前,先对生成器网络的基本结构有所了解,pix2pixHD中的生成器被拆分成两个子网络:G1(全局生成器网络,global generator)和G2(局部增强器网络,local enhancer),它们之间的关系如下图所示:
pix2pixHD的生成器结构
论文中给出了生成器的具体网络结构,如下图所示:
pix2pixHD的生成器具体网络结构
其中c7s1-k表示7×7的卷积-实例标准化-ReLU层,其中有k个滤波器,步长为1;dk表示3×3的卷积-实例标准化-ReLU层,其中有k个滤波器,步长为2;Rk表示一个包含2个有同样滤波器个数的3×3卷积层的残差块;uk表示一个3×3的分数步长卷积-实例标准化-ReLU层,其中有k个滤波器,步长为1/2

了解了生成器的网络结构后,就可以回到define_G()函数看看生成器网络如何进行构建,首先根据传入的norm参数,即标准化层的类型定义norm_layer,使用了get_norm_layer()函数,其代码如下所示:

def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer

其中functools.partial()的作用是减少一部分参数,作用就是少传参数,更加简洁,可见'batch'对应的是nn.BatchNorm2d,'instance'对应的是nn.InstanceNorm2d

然后define_G()函数就根据传入的netG参数,即生成器网络的模型来定义网络,分为三个类型:'global'对应全局生成器'local'对应局部增强器'encoder'对应编码器,分别对应GlobalGenerator类LocalEnhancer类以及Encoder类:

  • 全局生成器
    GlobalGenerator类对应代码如下所示:

    class GlobalGenerator(nn.Module):
      def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, 
                   padding_type='reflect'):
          assert(n_blocks >= 0)
          super(GlobalGenerator, self).__init__()        
          activation = nn.ReLU(True)        
    
          model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
          ### downsample
          for i in range(n_downsampling):
              mult = 2**i
              model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                        norm_layer(ngf * mult * 2), activation]
    
          ### resnet blocks
          mult = 2**n_downsampling
          for i in range(n_blocks):
              model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
          
          ### upsample         
          for i in range(n_downsampling):
              mult = 2**(n_downsampling - i)
              model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                         norm_layer(int(ngf * mult / 2)), activation]
          model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]        
          self.model = nn.Sequential(*model)
              
      def forward(self, input):
          return self.model(input)  

    传入类的参数为:input_nc(输入通道数)、output_nc(输出通道数)、ngf(第一个卷积层的输出通道数)、n_downsample_global(生成器网络中的下采样层数量)、n_blocks_global(全局生成器网络中残差块的数量)以及norm_layer

GlobalGenerator类中首先定义了activation,即激活函数,采用了nn.ReLU()函数,即使用ReLU激活函数,然后定义了model,model是一个列表,里面的元素是一个个神经网络层,分别是reflection padding层、卷积层、实例标准化层、ReLU激活函数,对应论文的global network结构中的c7s1-64

然后就是一系列的下采样层,根据n_downsampling在model列表后面继续加神经网络层,在下采样的同时通道数加倍,因此代码中用到了mult,每次循环都乘以2,列表中每次循环都加入卷积层、批标准化层、ReLU激活函数,对应论文的global network结构中的d128,d256,d512,d1024

然后就是一系列的残差模块,经过一系列的下采样层后,特征图的通道数为2的下采样层个数次方,即2**n_downsampling,然后根据n_blocks,即全局生成器网络中残差块的数量,然后在model列表中加入残差模块,残差模块由ResnetBlock类定义,其代码如下所示:

class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

可见ResnetBlock类定义了残差模块,其中conv_block属性就是残差模块中的卷积块,而前向传播的结果就是输入+输入卷积后的结果

conv_block属性为类的build_conv_block()方法的返回值,在build_conv_block()方法中同样定义了一个列表,将神经网络的每个层传入到列表中,首先是一个padding层,由于这里使用的是reflection padding,因此为nn.ReflectionPad2d(1),然后就是3×3的卷积层、标准化层、ReLU激活函数,这里的卷积层的输入输出通道数是相同的,然后根据use_dropout参数决定是否使用nn.Dropout(0.5),这里的dropout应该是用于模拟生成对抗网络的随机噪声;然后重复一样的过程,即又添加了nn.ReflectionPad2d(1)、卷积层、标准化层、ReLU激活函数;最后通过nn.Sequential(*conv_block)
残差模块对应pix2pixHD论文中给出的global network结构中一连串的R1024

回到GlobalGenerator类中的初始化方法,经过下采样和残差模块后就要定义一系列的上采样层,与之前的下采样层对应,因此同样根据n_downsampling在model列表后面继续加神经网络层,这里加入的是转置卷积层、实例标准化层以及ReLU激活函数,这里的上采样层对应pix2pixHD论文中给出的global network结构中的u512、u256、u128、u64

最后生成器需要生成一张图像,因此再添加nn.ReflectionPad2d(3)、一个卷积层和激活函数,其中卷积输出通道数为3,激活函数为nn.Tanh(),然后设置类的model属性为nn.Sequential(*model),对于前向传播过程,就直接将input传输model属性,即self.model(input)

  • 局部增强器
    LocalEnhancer类的对应代码如下:

    class LocalEnhancer(nn.Module):
      def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9, 
                   n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):        
          super(LocalEnhancer, self).__init__()
          self.n_local_enhancers = n_local_enhancers
          
          ###### global generator model #####           
          ngf_global = ngf * (2**n_local_enhancers)
          model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model        
          model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers        
          self.model = nn.Sequential(*model_global)                
    
          ###### local enhancer layers #####
          for n in range(1, n_local_enhancers+1):
              ### downsample            
              ngf_global = ngf * (2**(n_local_enhancers-n))
              model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0), 
                                  norm_layer(ngf_global), nn.ReLU(True),
                                  nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1), 
                                  norm_layer(ngf_global * 2), nn.ReLU(True)]
              ### residual blocks
              model_upsample = []
              for i in range(n_blocks_local):
                  model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)]
    
              ### upsample
              model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1), 
                                 norm_layer(ngf_global), nn.ReLU(True)]      
    
              ### final convolution
              if n == n_local_enhancers:                
                  model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]                       
              
              setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
              setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))                  
          
          self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
    
      def forward(self, input): 
          ### create input pyramid
          input_downsampled = [input]
          for i in range(self.n_local_enhancers):
              input_downsampled.append(self.downsample(input_downsampled[-1]))
    
          ### output at coarest level
          output_prev = self.model(input_downsampled[-1])        
          ### build up one layer at a time
          for n_local_enhancers in range(1, self.n_local_enhancers+1):
              model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
              model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')            
              input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]            
              output_prev = model_upsample(model_downsample(input_i) + output_prev)
          return output_prev

    在之前的全局生成器和局部增强器之间的关系图中可以发现全局生成器的输入通道数和局部增强器有关,因此LocalEnhancer类中先根据n_local_enhancers,即局部增强器的数量来计算ngf_global,并生成相应的全局生成器模型,需要注意的是,这里的全局生成器模型需要剔除最后的三层,这是因为有局部增强器时,全局生成器不需要生成图像,而是将生成图像前的feature map和局部增强器的一部分的输出feature map进行element-wise sum,这一部分对应下面的几行代码:

    ngf_global = ngf * (2**n_local_enhancers)
    model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model        
    model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers        
    self.model = nn.Sequential(*model_global)   

然后LocalEnhancer类就生成了对应局部增强器的一系列神经网络层,对于有多个局部增强器的情况需要定义多个对应不同局部增强器的神经网络层,这里以1个为例,分析一下局部增强器由什么构成
 首先定义了model_downsample,对应pix2pixHD论文中local enhancer的结构的c7s1-32,d64,即Padding层、卷积层、标准化层、ReLU激活函数、卷积层、标准化层、ReLU激活函数
 然后定义了model_upsample,对应pix2pixHD论文中local enhancer的结构的一系列R64u32,c7s1-3,因此根据n_blocks_local,即局部增强器中残差块的数量,在model_upsample列表中添加残差块,然后添加转置卷积、标准化和ReLU激活函数,最后添加Padding层、卷积层和Tanh激活函数以生成最终的图像

对应每一个局部增强器,代码中都设置了对应的model_downsample和model_upsample属性,使用了setattr()函数来设置属性值,最后还设置一个downsample属性,即全局平均池化(nn.AvgPool2d()函数)

在LocalEnhancer类的前向传播部分,首先根据局部增强器的数量计算出一系列input_downsampled,即用downsample属性进行下采样,用于后面模型的输出,其中最后一级的输出就是全局生成器的输入,得到的结果为output_prev,然后就先用model_downsample部分对图像进行下采样,然后和output_prev进行element-wise sum,迭代生成最后的图像

  • 编码器
    编码器用于在生成器输入前预先提取特征,Encoder类的代码如下所示:

    class Encoder(nn.Module):
      def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
          super(Encoder, self).__init__()        
          self.output_nc = output_nc        
    
          model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), 
                   norm_layer(ngf), nn.ReLU(True)]             
          ### downsample
          for i in range(n_downsampling):
              mult = 2**i
              model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                        norm_layer(ngf * mult * 2), nn.ReLU(True)]
    
          ### upsample         
          for i in range(n_downsampling):
              mult = 2**(n_downsampling - i)
              model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                         norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]        
    
          model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
          self.model = nn.Sequential(*model) 
    
      def forward(self, input, inst):
          outputs = self.model(input)
    
          # instance-wise average pooling
          outputs_mean = outputs.clone()
          inst_list = np.unique(inst.cpu().numpy().astype(int))        
          for i in inst_list:
              for b in range(input.size()[0]):
                  indices = (inst[b:b+1] == int(i)).nonzero() # n x 4            
                  for j in range(self.output_nc):
                      output_ins = outputs[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]]                    
                      mean_feat = torch.mean(output_ins).expand_as(output_ins)                                        
                      outputs_mean[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat                       
          return outputs_mean

    传入类的参数为:opt.output_nc(输出通道数)、opt.feat_num(编码后的特征的向量长度)、opt.nef(第一个卷积层的输出通道数)、opt.n_downsample_E(编码器中的下采样层数)以及norm_layer

编码器的网络非常的直接,先进行下采样然后再进行上采样,两个过程是对称的,对应到代码中同样是定义了model列表,然后先加入卷积层再加入转置卷积层,最后通过nn.Sequential(*model)定义model属性

在编码器的前向传播中,先使用model属性生成网络的输出,然后再通过instance-wise average pooling计算对应特征实例的平均特征,对于这个instance-wise average pooling的理解,以下的代码非常重要:

outputs_mean = outputs.clone()
        inst_list = np.unique(inst.cpu().numpy().astype(int))           # one dimension
        for i in inst_list:
            for b in range(input.size()[0]):
                indices = (inst[b:b+1] == int(i)).nonzero() # n x 4            
                for j in range(self.output_nc):
                    output_ins = outputs[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]]                    
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)                                        
                    outputs_mean[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat

首先inst_list就是根据输入的instance map使用np.unique()函数返回一个不重复的实例元组,然后通过三层循环找出每个实例对应位置的特征,求平均后作为最后特征图的输出

回到define_G函数,在获得了netG,也即网络模型后先直接打印看看网络的结构,然后通过判断len(gpu_ids)来决定是否使用GPU,在检查了torch.cuda.is_available()即GPU是否可用后,用.cuda()方法将网络加载到GPU上

最后define_G函数还用到了torch.nn.Module.apply()方法,这个方法的作用是递归地将函数应用于模型的每个子模块(包括当前模块),并返回应用后的模型,这里使用它是为了权重初始化,使用了weights_init()函数,其代码如下所示:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

可见weights_init()函数就是使用了.__class__.__name__属性,即实例对应的类的名称,对卷积层(在类名称中找'Conv',有就是卷积层,没有返回-1),权重初始化为均值为0,标准差为0.02的高斯分布;对于batch normalization层,权重初始化为均值为1,标准差为0.02的高斯分布,偏置初始化为0

initialize()方法中的判别器网络

这一部分需要根据isTrain属性决定是否需要定义判别器网络,因为对于生成对抗网络而言只有在训练的时候判别器才能起作用
如果此时是训练阶段,则需要定义判别器网络,首先定义了use_sigmoid为opt.no_lsgan(表示是否使用LSGAN,默认为False,即使用LSGAN);然后定义了netD_input_nc,表示判别器输入的通道数,这里为input_nc+opt.output_nc(表示输出的图像通道数,默认为3),这里之所以要加是因为使用的是条件GAN,因此输入输出要进行concatenate操作;然后同样根据opt.no_instance决定netD_input_nc是否要加1

在研究判别器网络如何构建之前,需要对pix2pixHD采用的判别器有所了解,pix2pixHD的判别器采用多尺度判别器,即判别器网络的结构是相同的但是在不同的图像尺寸上进行判别,除此之外,判别器采用了70×70的Patch-GAN结构,论文中给出了其具体的网络如下所示:
判别器的网络结构
其中Ck表示4×4的卷积层、实例标准化层以及LeakyReLU层,其中卷积的卷积核个数为k,步长为2,C64中没有使用实例标准化层,LeakyReLU的slope设置为0.2,尺度选择为3,即有3个判别器,这3个判别器有相同的网络结构

回到这一部分最后定义了netD属性,即判别器网络,这个属性是define_D()函数函数的返回值,函数同样来自同一文件夹下的networks.py文件,函数的代码如下所示:

def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):        
    norm_layer = get_norm_layer(norm_type=norm)   
    netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)   
    print(netD)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        netD.cuda(gpu_ids[0])
    netD.apply(weights_init)
    return netD

函数的输入参数为:netD_input_nc、opt.ndf(表示判别器的第一个卷积层的输出通道数)、opt.n_layers_D、opt.norm、use_sigmoid、opt.num_D(表示使用的判别器数量,默认为3)、opt.no_ganFeat_loss(表示是否使用判别器特征匹配loss,默认为False,表示使用)

define_D()函数中也是首先获取norm_layer,然后定义了netD,即判别器网络模型,后面的操作和之前的类似,即先打印网络模型、将模型加载到GPU上、权重初始化。netD是MultiscaleDiscriminator类的实例,同样来自项目中models文件夹中的networks.py文件中,其代码如下所示:

class MultiscaleDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
     
        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:                                
                for j in range(n_layers+2):
                    setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))                                   
            else:
                setattr(self, 'layer'+str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        if self.getIntermFeat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, input):        
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
            else:
                model = getattr(self, 'layer'+str(num_D-1-i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D-1):
                input_downsampled = self.downsample(input_downsampled)
        return result

MultiscaleDiscriminator类中主要根据num_D,即多尺度判别器的数量定义多个判别器,并用setattr()函数设置为类的属性,其中getIntermFeat是为了获取中间层的feature map,这与pix2pixHD的loss函数定义有关,单尺寸的判别器通过NLayerDiscriminator类来定义,其代码如下所示:

class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers+2):
                model = getattr(self, 'model'+str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(input)  

NLayerDiscriminator类就是定义了一系列卷积层,需要注意的是卷积层的参数,其中kw就是卷积核的大小,padw是卷积的填充,可见NLayerDiscriminator类的初始化函数中sequence就是一系列神经网络存放的地方,sequence中首先是卷积层、Lealy ReLU激活函数,这对应的是pix2pixHD论文的判别器结构的C64,而且C64中没有标准化层,然后根据n_layers定义n_layers-1个卷积层、标准化层和Lealy ReLU激活函数,这里n_layers=3,因此对应的是pix2pixHD论文的判别器结构的C128、C256,最后再定义一个卷积层、标准化层和Lealy ReLU激活函数,对应pix2pixHD论文的判别器结构的C512,最后还需要定义一个卷积层使输出通道数为1

最后如果use_sigmoid为True,则还需要在sequence后加一个nn.Sigmoid(),然后根据getIntermFeat决定是否要根据每一个卷积层设置类的参数,这是因为要获得中间卷积层的特征图,而如果不需要获取则只定义一个model属性

而在NLayerDiscriminator类的前向传播中也与上面的属性定义类似,首先是根据getIntermFeat来判断,如果不需要获得中间卷积层的输出特征,则直接将input作为model的参数;而如果需要获得,则循环获取每一层的输出,需要循环n_layers+2次,这是因为卷积层除了n_layers层还有一层输出通道为512的卷积层以及一层输出通道为1的卷积层

回到MultiscaleDiscriminator类,定义了netD属性后同样根据getIntermFeat判断是否将判别器模型的每一层卷积层分开,如果分开则需要定义多个属性,最后定义一个downsample属性,使用平均池化nn.AvgPool2d()函数
在MultiscaleDiscriminator类的前向传播函数中用到了singleD_forward()方法,这个方法就是单个判别器的前向传播,和前面的NLayerDiscriminator类的前向传播类似,定义result为单个判别器前向传播的结果,并每一次都使用downsample属性,即平均池化对输入进行下采样,如此就形成了多尺度的Patch判别器

initialize()方法中的编码器网络

这一部分根据gen_features属性决定是否定义编码器网络,编码器网络netE同样来自models文件夹下的networks.py文件中的define_G()函数,并由Encoder类定义,这个类在前面已经讲生成器网络的时候已经顺便讲过了,因此这里就不再赘述

initialize()方法中的加载网络

这一部分需要根据isTrain属性、opt.continue_train(表示是否继续训练,即加载最近一次训练的模型,默认为False)以及opt.load_pretrain(表示是否加载预训练网络,参数为一个地址,默认为'')来决定是否定义,opt.continue_train和opt.load_pretrain的定义比较明确,isTrain属性在这里是因为测试的时候也需要加载模型

加载网络主要使用的是load_network()方法,这个方法在前面也介绍过了,因此这里不再赘述,这里应当主要注意传入的参数如何形成存储的模型文件的路径,这个路径应当与后面保存网络模型时保持一致,其中opt.which_epoch表示加载哪一个epoch保存的网络模型

initialize()方法中的设置loss函数

这一部分主要根据isTrain属性来决定是否定义,毕竟在测试阶段不需要训练,也就不需要定义loss函数和优化器,如果是训练阶段,则需要定义。这部分定义了很多属性,下面进行一一说明:

  • fake_pool属性:ImagePool类的返回值,我也不知道这个类起了啥作用,但是传入参数opt.pool_size默认为0,则类的query()方法返回值就是输入,所以可以暂时忽略这个类
  • old_lr:学习率,因为在训练过程中学习率是要动态改变的,因此这里设置的是上一次的学习率,设置为opt.lr(表示初始学习率,默认为0.0002)
  • loss_filter属性:init_loss_filter()方法的返回值,传入方法的参数为:opt.no_ganFeat_loss(表示是否使用判别器特征匹配损失,默认为False,即使用)和opt.no_vgg_loss(表示是否使用VGG的特征匹配损失,默认为False,即使用),方法的代码如下:

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
      flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
      def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
          return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
      return loss_filter

    可见init_loss_filter()设置了一些标志位,每个标志位对应一种损失函数,并定义了一个函数,函数中zip()函数的作用是将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表,因此函数的返回值是根据标志位决定的,标志位为True,即表示这个损失函数使能并返回

  • loss_names属性:loss_filter属性的返回值,传入的参数是五个字符串,对应五个loss函数,可见这个返回值就是上面讲到的对应标志位为True的loss函数

下面的几种属性就是loss函数对应的属性,在介绍之前,需要对pix2pixHD中的loss函数有所了解,如下图所示:
pix2pixHD的loss函数

  • criterionGAN属性:表示GAN的loss函数,是GANLoss类的实例,这个类来自项目中models文件夹中的networks.py文件,下面定义的几种loss函数的属性也是一样的,类的代码如下所示:

    class GANLoss(nn.Module):
      def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                   tensor=torch.FloatTensor):
          super(GANLoss, self).__init__()
          self.real_label = target_real_label
          self.fake_label = target_fake_label
          self.real_label_var = None
          self.fake_label_var = None
          self.Tensor = tensor
          if use_lsgan:
              self.loss = nn.MSELoss()
          else:
              self.loss = nn.BCELoss()
    
      def get_target_tensor(self, input, target_is_real):
          target_tensor = None
          if target_is_real:
              create_label = ((self.real_label_var is None) or
                              (self.real_label_var.numel() != input.numel()))
              if create_label:
                  real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                  self.real_label_var = Variable(real_tensor, requires_grad=False)
              target_tensor = self.real_label_var
          else:
              create_label = ((self.fake_label_var is None) or
                              (self.fake_label_var.numel() != input.numel()))
              if create_label:
                  fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                  self.fake_label_var = Variable(fake_tensor, requires_grad=False)
              target_tensor = self.fake_label_var
          return target_tensor
    
      def __call__(self, input, target_is_real):
          if isinstance(input[0], list):
              loss = 0
              for input_i in input:
                  pred = input_i[-1]
                  target_tensor = self.get_target_tensor(pred, target_is_real)
                  loss += self.loss(pred, target_tensor)
              return loss
          else:            
              target_tensor = self.get_target_tensor(input[-1], target_is_real)
              return self.loss(input[-1], target_tensor)

    可见GANLoss类在初始化方法中定义了几个属性:real_label属性表示真实的图像的label,值为1;fake_label属性表示生成的图像的label,值为0;Tensor属性表示张量数据类型,是GPU的张量(torch.cuda.FloatTensor)还是CPU的张量(torch.Tensor);loss属性就是损失函数,如果使用LSGAN,则损失函数为nn.MSELoss(),如果不使用,则损失函数为nn.BCELoss(),即二分类交叉熵

GANLoss类还定义了__call__()方法,这个方法将类变得可以像函数一样可调用,在__call__()方法中需要判断输入的列表中的元素是否还是列表,这是因为在使能了GAN feature loss函数后,输入就是各个卷积层的输出,因此需要根据这一点不同来分别计算loss函数的结果。__call__()方法中还用到了get_target_tensor()方法,这个方法是生成一个和输入相同尺寸的label,如果是真,则填充为real_label,即填充1;如果为假,则填充为fake_label,即填充0,然后返回这个label和输入进行loss计算

  • criterionFeat属性:torch.nn.L1Loss()函数,表示L1 loss,即计算绝对值
  • criterionVGG属性:表示VGG feature loss函数,是VGGLoss类的实例,类的代码如下所示:

    class VGGLoss(nn.Module):
      def __init__(self, gpu_ids):
          super(VGGLoss, self).__init__()        
          self.vgg = Vgg19().cuda()
          self.criterion = nn.L1Loss()
          self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]        
    
      def forward(self, x, y):              
          x_vgg, y_vgg = self.vgg(x), self.vgg(y)
          loss = 0
          for i in range(len(x_vgg)):
              loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())        
          return loss

    可见VGGLoss类选用的VGG模型为VGG19,计算生成图像和真实图像经过VGG19网络中的几层结果,同样使用nn.L1Loss()计算每一层的loss结果,然后加权求和作为VGGloss最后的值

initialize()方法中的设置优化器

这一部分需要根据opt.niter_fix_global(表示只训练外层的局部增强器的epoch数量)进行判断,这是因为局部增强器和全局生成器的训练是不同的,pix2pixHD中的论文提到,采用先训练全局生成器,再训练局部增强器,最后进行微调的训练策略

如果opt.niter_fix_global等于0,则说明局部增强器没有被定义,则params,即网络参数只有全局生成器的参数,即netG.parameters();如果如果opt.niter_fix_global大于0,则只获取最外层局部增强器的网络参数作为params,表示需要对参数进行opt.niter_fix_global次微调

其中需要注意的是.named_parameters()方法返回的list中,每个元组打包了2个内容,分别是layer-name和layer-param(网络层的名字和参数的迭代器);而与之类似的还有.parameters()方法,该方法只返回参数的迭代器;而.state_dict()方法返回的是将layer_name:layer_param的键值信息,即存储为dict形式,同时.state_dict()方法返回的是该model中包含的所有layer中的所有参数,而.named_parameters()方法返回的是

最后定义了optimizer_G属性,即生成器的优化器,使用的是Adam优化器(torch.optim.Adam()
同理对于判别器,获取params即判别器的网络参数(netD.parameters()),然后定义optimizer_D属性,即判别器的优化器,同样使用Adam优化器

forward()方法

讲完了Pix2PixHDModel类的initialize()方法,下面就看一下Pix2PixHDModel类的前向传播部分,其代码如下所示:

def forward(self, label, inst, image, feat, infer=False):
    # Encode Inputs
    input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)  

    # Fake Generation
    if self.use_features:
        if not self.opt.load_features:
            feat_map = self.netE.forward(real_image, inst_map)                     
        input_concat = torch.cat((input_label, feat_map), dim=1)                        
    else:
        input_concat = input_label
    fake_image = self.netG.forward(input_concat)

    # Fake Detection and Loss
    pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
    loss_D_fake = self.criterionGAN(pred_fake_pool, False)        

    # Real Detection and Loss        
    pred_real = self.discriminate(input_label, real_image)
    loss_D_real = self.criterionGAN(pred_real, True)

    # GAN loss (Fake Passability Loss)        
    pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        
    loss_G_GAN = self.criterionGAN(pred_fake, True)               
    
    # GAN feature matching loss
    loss_G_GAN_Feat = 0
    if not self.opt.no_ganFeat_loss:
        feat_weights = 4.0 / (self.opt.n_layers_D + 1)
        D_weights = 1.0 / self.opt.num_D
        for i in range(self.opt.num_D):
            for j in range(len(pred_fake[i])-1):
                loss_G_GAN_Feat += D_weights * feat_weights * \
                    self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
               
    # VGG feature matching loss
    loss_G_VGG = 0
    if not self.opt.no_vgg_loss:
        loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
    
    # Only return the fake_B image if necessary to save BW
    return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]

InferenceModel类

优化器定义

train.py中的优化器定义部分代码如下:

visualizer = Visualizer(opt)
if opt.fp16:    
    from apex import amp
    model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D], opt_level='O1')             
    model = torch.nn.DataPara
    llel(model, device_ids=opt.gpu_ids)
else:
    optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

total_steps = (start_epoch-1) * dataset_size + epoch_iter

display_delta = total_steps % opt.display_freq
print_delta = total_steps % opt.print_freq
save_delta = total_steps % opt.save_latest_freq

可见在forward()方法中首先使用了encode_input()方法来获得输入的各个图像,encode_input()方法的代码如下所示:

def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):             
    if self.opt.label_nc == 0:
        input_label = label_map.data.cuda()
    else:
        # create one-hot vector for label map 
        size = label_map.size()
        oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
        if self.opt.data_type == 16:
            input_label = input_label.half()

    # get edges from instance map
    if not self.opt.no_instance:
        inst_map = inst_map.data.cuda()
        edge_map = self.get_edges(inst_map)
        input_label = torch.cat((input_label, edge_map), dim=1)         
    input_label = Variable(input_label, volatile=infer)

    # real images for training
    if real_image is not None:
        real_image = Variable(real_image.data.cuda())

    # instance map for feature encoding
    if self.use_features:
        # get precomputed feature maps
        if self.opt.load_features:
            feat_map = Variable(feat_map.data.cuda())
        if self.opt.label_feat:
            inst_map = label_map.cuda()

    return input_label, inst_map, real_image, feat_map

encode_input()方法就是对几种输入的图像进行预处理,其中input_label代表输入的label和instance map经过边缘提取后的结果的concatenate形式,还有inst_map、real_image以及feat_map

获取了几种结果后,在前向传播方法中先根据use_features属性决定是否使用feature,如果需要加载预先的feature则不用经过Encoder网络,如果没有预先的feature,则需要将real_image和inst_map级联并输入Encoder网络中获取feature,同时将获得的feature和input_label级联起来输入生成器网络中。同理如果不需要使用feature,则输入生成器网络中的只有input_label。这里的生成器网络的输出为fake_image,即生成的图像

前向传播中后面的部分就是计算各种loss,用到了discriminate()方法,其代码如下所示:

def discriminate(self, input_label, test_image, use_pool=False):
    input_concat = torch.cat((input_label, test_image.detach()), dim=1)
    if use_pool:            
        fake_query = self.fake_pool.query(input_concat)
        return self.netD.forward(fake_query)
    else:
        return self.netD.forward(input_concat)

可见discriminate()方法就是将输入经过判别器网络的前向传播结果并计算出一个概率值

因此在前向传播方法中,loss_D_fake表示输入判别器的是生成图像的loss函数值,而loss_D_real表示输入判别器的是真实图像的loss函数值


Kazusa
5 声望7 粉丝

梦里不觉秋已深,余情岂是为他人