在【深度学习个人笔记】从实际深度学习项目来学习深度学习1——pix2pixHD开源项目学习中对pix2pixHD项目的README.md文件,并将对于pix2pixHD项目的学习分为两个部分:训练部分和测试部分,因此在这一篇博客中主要分析训练部分
pix2pixHD项目的训练部分从train.py作为起点,先不用管.py文件开头import了什么模块,等后面用到了再说,这里为了方便根据不同的功能将train.py分为四个部分(不包括.py文件开头的import部分),分别进行学习,其中第四五部分由于内容比较多,而且涉及深度学习的核心部分,将放在别的博客进行讲解
train.py的第一部分——项目的参数设置
train.py的第一部分的代码如下所示:
opt = TrainOptions().parse()
iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
if opt.continue_train:
try:
start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
except:
start_epoch, epoch_iter = 1, 0
print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
else:
start_epoch, epoch_iter = 1, 0
opt.print_freq = lcm(opt.print_freq, opt.batchSize)
if opt.debug:
opt.display_freq = 1
opt.print_freq = 1
opt.niter = 1
opt.niter_decay = 0
opt.max_dataset_size = 10
可见这一部分主要是定义了一些参数,下面就对这些参数的含义以及项目是如何定义这些参数
首先,TrainOptions
来自项目中options
文件夹下的train_options.py
文件,TrainOptions是BaseOptions
的子类,而BaseOptions来自项目中options文件夹的base_options.py
文件,也是一个类,包含类的初始化方法(__init__)、initialize方法以及parse方法
而TrainOptions只有一个方法initialize,如下图所示:
可见这个方法是先调用父类,即BaseOptions的initialize方法,因此先来看看BaseOptions类的初始化方法和initialize方法,接下来就按照顺序对这些方法进行讲解
BaseOptions类的初始化方法
BaseOptions类的初始化方法的代码如下:
def __init__(self):
self.parser = argparse.ArgumentParser()
self.initialized = False
可见初始化方法定义了两个属性:parser
和initialized
,initialized就是一个标志位,parser则是来自argparse模块的ArgumentParser对象
argparse模块是Python用于解析命令行参数和选项的标准模块,可以帮助编写用户友好的命令行接口,具体用法可参考下面的博客:
python之argparse模块常见用法包含实例(超详细)
python学习之argparse模块
argparse的参数说明(一文精通)
python argparse中action的可选参数store_true的作用
对argparse的使用可以总结为如下五个步骤:
import argparse
,即导入argparse模块parser = argparse.ArgumentParser(description='xxxx')
,即创建一个解析对象,该对象包含将命令行输入内容解析成Python数据的过程所需的全部功能,而description
是该对象的描述信息,可以在命令中加入-h
查看parser.add_argument('xxx', type=int, help='xxx')
,即添加需要输入的命令行参数,括号中依次为参数名、参数类型(这里是int,默认数据类型为str)、描述信息args = parser.parse_args()
,ArgumentParser通过parse_args()方法解析参数,获取到命令行中输入的参数- 调用命令行中的输入的参数完成功能
可见初始化方法中的parser就是完成了第一步和第二步,至于后面的步骤以及它们的一些具体用法,在别的方法中再进行说明
BaseOptions类的initialize方法
BaseOptions类的initialize方法的代码比较长,这里只截取其中的一部分,后面还有很多:
def initialize(self):
# experiment specifics
self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP')
self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
.....
可见这一方法实际上就是完成了上一节中提到的对argparse的使用的五个步骤的第三个步骤,即添加需要输入的命令行参数,这一部分的内容大同小异,所以只需要看一部分就能明白剩下的是什么意思
首先对于添加命令行参数的基本用法进行讲解,命令行参数分为位置参数和选项参数,都是通过parser.add_argument('xxx', type=int, help='xxx')
,而对于位置参数和选项参数,一般分别用下面两种方式:
- 位置参数:
parser.add_argument('xxx', type=int, help='xxx')
- 选项参数:
parser.add_argument('-x', '--xxxx', type=int, help='xxxx')
可见对于选项参数可以在参数名前加两个-
,或者在比较简略的形式前加一个-
在添加命令行参数中除了type和help外,添加命令行参数还有一些常用的选项:
- default:当命令行没有设置具体的参数值时的默认参数
- required:表示参数是否一定需要设置
- choices:参数值只能从固定选项中选择
- metavar:参数的名字,在显示帮助信息时才会用到
- dest:argparse默认的变量名是--或-后面的字符串,可以通过dest=xxx来设置参数的变量名,然后在代码中只能用
args.xxx
来获取参数的值,而不能用--后面的字符串 action:对于True/False类型的参数,向add_argument方法中加入参数
action='store_true'/'store_false'
- 当输入命令时,不指定相应的参数时,store_true默认显示为False,store_false默认显示为True;指定相应的参数时,store_true的参数变为Ture,store_false的参数变为False
根据上面的知识可以对initialize方法中的参数进行分析,它们默认是什么值,是什么类型,有什么用等,这些参数在后面的数据集搭建和模型搭建中都用得到,因此等后面用到这些实际参数后再回过头看比较好,这里只需要掌握这种定义命令行参数的方式即可
在initialize方法的最后,除了设置命令行参数之外,还设置了initialized属性的值:
self.initialized = True
可见initialized属性就是表示是否初始化了这些命令行参数
TrainOptions类的initialize方法
TrainOptions类的initialize方法和BaseOptions类的initialize方法类似,由于比较长,这里也只给出部分代码如下:
def initialize(self):
BaseOptions.initialize(self)
# for displays
self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
可见TrainOptions类的initialize方法先调用了父类的initialize方法,然后同样添加了一些命令行参数,只不过从类的名字也可以看出,这里添加的是和训练相关的参数
最后initialize()方法同样设置了一个属性:
self.isTrain = True
可见属性isTrain表示当前进行的是训练过程
BaseOptions类的parse方法
BaseOptions类的parse方法的代码如下所示:
def parse(self, save=True):
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args()
self.opt.isTrain = self.isTrain # train or test
str_ids = self.opt.gpu_ids.split(',')
self.opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
self.opt.gpu_ids.append(id)
# set gpu ids
if len(self.opt.gpu_ids) > 0:
torch.cuda.set_device(self.opt.gpu_ids[0])
args = vars(self.opt)
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
# save to the disk
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
util.mkdirs(expr_dir)
if save and not self.opt.continue_train:
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write('------------ Options -------------\n')
for k, v in sorted(args.items()):
opt_file.write('%s: %s\n' % (str(k), str(v)))
opt_file.write('-------------- End ----------------\n')
return self.opt
可见parse方法首先根据initialized属性决定是否调用initialize()方法,即添加命令行参数,然后就调用了args = parser.parse_args()
,即使用argparse的第四个步骤,通过parse_args()方法解析参数,获取到命令行中输入的参数,并把这些参数放在的属性opt
中
parse方法后面的内容则完成了使用argparse的第五个步骤,调用命令行中的输入的参数设置一些参数,下面就依次介绍一下设置的参数的含义(结合代码进行阅读):
- opt.isTrain:表示训练还是测试,前面已经将isTrain属性设置为True
str_ids、opt.gpu_ids:表示GPU的编号,如果有多个GPU的话,命令行输入的参数为类似"0,1,2"这样的字符串形式,程序中根据逗号将其分隔开,并用列表将每个编号重新赋给opt.gpu_ids,最后根据opt.gpu_ids的长度使用
torch.cuda.set_device()
来设置GPU设备
但是这里值得一提的是,尽管在命令行中输入了多个GPU,但是程序仍然只采用了一个GPU进行训练,对应的程序如下:torch.cuda.set_device(self.opt.gpu_ids[0]) # 只设置了列表的第一个元素
事实上,我在项目的README.md文件中也发现了下面这句话:
- args = vars(self.opt):vars()函数返回类的对象的属性和属性值的字典对象,然后程序中使用了
sorted函数
(对所有可迭代的对象进行排序操作)以及.items()方法
(字典的方法,以列表返回可遍历的(键, 值)元组数组)来打印这些参数 - expr_dir:存储路径,用
os.path.join
拼接文件路径,这里是checkpoint的文件夹(./checkpoints)+项目的名字,然后使用util.mkdirs来创建目录 - 程序最后根据save(默认为True)以及opt.continue_train(是否继续训练,默认为False)来决定是否保存self.opt,即以opt.txt的文件名存放在上面的创建的目录中,最后程序返回self.opt,作为项目的参数供后续程序使用
train.py的第一部分的剩余部分
train.py的第一部分中除了前面讲的参数配置外还有剩余的一小部分,因此在这一节中进行讲解
首先根据opt.continue_train,即是否继续训练来确定训练开始的epoch(start_epoch
)以及epoch的iteration参数(epoch_iter
),由于此时还没有进行参数,因此这两个参数设置为默认值:
- start_epoch = 1
- epoch_iter = 0
有关epoch、iteration以及batch_size在深度学习中的关系,可以看一下下面的文章:
快速搞定 epoch, batch, iteration
可以总结为:迭代次数(iteration)=样本总数(epoch)/批尺寸(batchsize)
然后程序重新设置了opt.print_freq参数,用定义的lcm函数来进行计算,lcm函数的定义在train.py的import部分,代码如下:
def lcm(a,b): return abs(a * b)/fractions.gcd(a,b) if a and b else 0
其中fractions.gcd用于求最大公约数,lcm函数应该是用于求解最小公倍数,因为两个数的乘积等于这两个数的最大公约数与最小公倍数的积,这样求是为了保证print_freq是batchSize的倍数
值得一提的是, fractions.gcd自从python3.5以来被弃用,python3.9的时候更是已经删除了,所以如果报错应该改成math.gcd
最后train.py的第一部分根据opt.debug参数确定程序是否进行debug模式,这个参数默认是False,在参数配置部分给出了它的描述:only do one epoch and displays at each iteration
,可见代码进行debug模式后只进行一个epoch的训练,并且在每个iteration都进行训练结果的显示,从后面的参数也可以看出这一点:
- opt.display_freq = 1:在屏幕(screen)上显示训练结果的频率
- opt.print_freq = 1:在控制台(console)上显示训练结果的频率
- opt.niter = 1:开始的学习率持续的iteration
- opt.niter_decay = 0:开始线性降低学习率到0的iteration
- opt.max_dataset_size = 10:数据集允许的最大样本数量,如果数据集目录下的样本数大于最大数量,则只有一部分数据会被加载
train.py的第二部分——数据集相关定义
train.py的第二部分的代码如下所示:
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)
可见这一部分主要和数据集的相关定义有关,首先CreateDataLoader()
是来自项目data文件夹下的data_loader.py文件的函数,其代码如下:
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
可见这个函数主要用到了来自data文件夹下的CustomDatasetDataLoader
类,函数中使用这个类定义了一个对象,然后调用了这个对象的.name()
方法和.initialize()
方法,CustomDatasetDataLoader类的代码如下:
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self.dataloader
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)
在对CustomDatasetDataLoader类进行说明之前,先来看看在Pytorch中如何构建自定义的数据集,这需要用到torch.utils.data
中的Dataset
和DataLoader
类,使用它们的流程大概可以总结如下:
- 定义Dataset的子类并实例化
- 使用DataLoader加载数据
- 循环使用DataLoader加载的数据进行训练或验证
构建数据集的模板代码如下:
from torch.utils.data import Dataset, DataLoader
import numpy as np
class MyDataSet(Dataset):
def __init__(self, path):
"""可以在初始化函数当中对数据进行一些操作,比如读取、归一化等"""
self.data = np.loadtxt(path) # 读取 txt 数据
self.x = self.data[:, 1:] # 输入变量
self.y = self.data[:, 0] # 输出变量
def __len__(self):
"""返回数据集当中的样本个数"""
return len(self.data)
def __getitem__(self, index):
"""返回样本集中的第 index 个样本;输入变量在前,输出变量在后"""
return self.x[index, :], self.y[index]
data_set = MyDataSet('data.txt')
data_loader = DataLoader(dataset=data_set, batch_size=400, shuffle=True, drop_last=False)
可见通过Dataset类定义数据集需要实现__len__()
和__getitem__()
方法,其中__len__()方法返回数据集中的样本总个数。而__getitem__()方法需要根据索引返回对应的样本;而DataLoader则用于创建迭代器,它有几个参数,说明如下:
- dataset:要读取的数据集,要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的子类的对象
- batch_size:训练时一个batch的数量
- shuffle:是否打乱数据
- num_workers:是否多线程读取数据
- drop_last:当样本数不能被batchsize整除时,最后一批数据是否舍弃
- collate_fn:如何取样本,可自己定义函数实现想要的功能
了解了使用Pytorch构建数据集的基本流程,下面就看一下CustomDatasetDataLoader类的几个方法:
- name()方法:可见就是返回一个字符串命名
initialize()方法:首先CustomDatasetDataLoader类是BaseDataLoader的子类,在initialize()方法中首先调用了BaseDataLoader类的initialize()方法,BaseDataLoader类的代码如下所示:
class BaseDataLoader(): def __init__(self): pass def initialize(self, opt): self.opt = opt pass def load_data(): return None
可见BaseDataLoader()只是定义了一些方法,没有具体实现方法的功能,而在initialize()方法中,仅是将传入的命令行参数opt赋给了self.opt属性
然后回到CustomDatasetDataLoader类的initialize()方法,方法调用了BaseDataLoader()类的initialize()方法后,定义了两个属性dataset和dataloader,首先dataloader属性在前面介绍Pytorch中如何构建自定义的数据集时已经提到了,dataloader属性就是torch.utils.data.DataLoader创建的迭代器,其中数据集参数为dataset属性,batch_size参数设置为命令行参数中的batchSize(程序中默认为1),shuffle参数设置为命令行参数中的serial_batches(程序中默认为False)的取反,num_workers参数设置为命令行参数中的nThreads(程序中默认为2,表示加载数据的线程数)
至于dataset属性,由于涉及到数据集的定义,内容比较多也很重要,因此这里就先介绍CustomDatasetDataLoader类的其他方法,具体的等这些方法介绍完后放在一个专门的小节里进行介绍
- load_data()方法:直接返回initialize()方法中定义的属性dataloader
- __len__()方法:
还需要了解类中的__len__()方法的相关知识,对于类而言,len()函数是没有办法直接计算类的长度的,因为在类中包含着众多的属性以及方法,是一种抽象的实体。如果在类中没有定义__len__()方法来指明程序到底该计算哪个属性的长度时,必须采用len(对象.属性)才能得到我们想要的结果。如果直接采用len(对象)的方法,程序会报错并提示类并没有len()方法
有关类的__len__()方法的详细说明,可以参考下面这篇博客:
python的__len__()方法
CustomDatasetDataLoader类中initialize()方法的dataset属性介绍
下面就来看一下dataset属性,也就是数据集是如何进行配置的,dataset属性是CreateDataset()
函数的返回值,其代码如下所示:
def CreateDataset(opt):
dataset = None
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset
可见dataset属性是项目中的data文件夹下的aligned_dataset.py文件中的AlignedDataset
类的对象,而在CreateDataset()函数首先就是创建了AlignedDataset类的对象,然后调用了.name()
方法和.initialize()
方法,下面就来看看AlignedDataset类如何定义,其代码如下所示:
class AlignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
### input A (label maps)
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
self.A_paths = sorted(make_dataset(self.dir_A))
### input B (real images)
if opt.isTrain or opt.use_encoded_image:
dir_B = '_B' if self.opt.label_nc == 0 else '_img'
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
self.B_paths = sorted(make_dataset(self.dir_B))
### instance maps
if not opt.no_instance:
self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
self.inst_paths = sorted(make_dataset(self.dir_inst))
### load precomputed instance-wise encoded features
if opt.load_features:
self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
print('----------- loading features from %s ----------' % self.dir_feat)
self.feat_paths = sorted(make_dataset(self.dir_feat))
self.dataset_size = len(self.A_paths)
def __getitem__(self, index):
### input A (label maps)
A_path = self.A_paths[index]
A = Image.open(A_path)
params = get_params(self.opt, A.size)
if self.opt.label_nc == 0:
transform_A = get_transform(self.opt, params)
A_tensor = transform_A(A.convert('RGB'))
else:
transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
A_tensor = transform_A(A) * 255.0
B_tensor = inst_tensor = feat_tensor = 0
### input B (real images)
if self.opt.isTrain or self.opt.use_encoded_image:
B_path = self.B_paths[index]
B = Image.open(B_path).convert('RGB')
transform_B = get_transform(self.opt, params)
B_tensor = transform_B(B)
### if using instance maps
if not self.opt.no_instance:
inst_path = self.inst_paths[index]
inst = Image.open(inst_path)
inst_tensor = transform_A(inst)
if self.opt.load_features:
feat_path = self.feat_paths[index]
feat = Image.open(feat_path).convert('RGB')
norm = normalize()
feat_tensor = norm(transform_A(feat))
input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
'feat': feat_tensor, 'path': A_path}
return input_dict
def __len__(self):
return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize
def name(self):
return 'AlignedDataset'
可见AlignedDataset类是BaseDataset
类的子类,BaseDataset类的代码如下所示:
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
可见BaseDataset类和之前提到的BaseDataLoader类类似,只是定义了一些方法,没有具体实现方法的功能,唯一值得注意的是BaseDataset类是torch.utils.data.Dataset
的子类,这个在前面说过是Pytorch提供的自定义数据集的类
下面就回到AlignedDataset类中看看它有哪些方法:
- name()方法:可见就是返回一个字符串命名
- initialize()方法:initialize()方法完成一些初始化功能,主要是根据命令行参数定义了一些属性,下面就来看看这些属性的含义
- self.opt:直接代表命令行参数
- self.root:数据集的根目录,默认为项目中datasets文件夹中的cityscapes文件夹下
- self.dir_A:使用os.path.join()函数拼接文件路径,这里拼接了opt.dataroot、opt.phase(训练模式下为'train')以及dir_A(当opt.label_nc为0时设置为'_A',不为0时设置为'_label'),可见这个参数指的是label图像的路径(至于label图像是什么,可以去看一下对cityscapes数据集的描述)
self.A_paths:
sorted(make_dataset())
两个函数的返回值
首先sorted()函数是对可迭代对象进行排序操作,而make_dataset()函数的代码如下:def make_dataset(dir): images = [] assert os.path.isdir(dir), '%s is not a valid directory' % dir for root, _, fnames in sorted(os.walk(dir)): for fname in fnames: if is_image_file(fname): path = os.path.join(root, fname) images.append(path) return images
可见make_dataset()函数使用了os.walk()函数来遍历文件夹及子文件夹下所有文件并得到路径,返回的是一个三元组(root,dirs,files),其中root指当前正在遍历的这个文件夹的本身的地址,dirs是一个 list,内容是该文件夹中所有的目录的名字(不包括子目录),而files也是list,内容是该文件夹中所有的文件(不包括子目录),函数会自己改变root的值以遍历所有子文件夹
对于每一个文件(也就是程序里的fname和fnames),先用is_image_file()函数判断是否为图像文件,相关的代码如下所示:IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
可见is_image_file()就是通过扩展名来检查是否是图像文件
检查是否是图像文件后,将图像文件的路径加入到images列表中并返回,因此self.A_paths实际上就是对应路径下的图像文件名排序后构成的列表- self.dir_B和self.B_paths:这两个属性是由opt.isTrain(训练模式下为True)和opt.use_encoded_image(来自测试的命令行参数,默认为False)决定是否定义的,和上面的属性类似,分别表示路径和图像文件名的列表,同时可以看出这里的图像代表encode的图像(至于具体是什么,可以参考pix2pixHD的论文,或者等到讲模型的时候再详细说)
- self.dir_inst和self.inst_paths:这两个属性是由opt.no_instance(默认为False)决定是否定义的,同样分别代表路径和图像文件名的列表,这里的图像代表instance的图像
- self.dir_feat和self.feat_paths:这两个属性是由opt.load_features(默认为False)决定是否定义的,同样分别代表路径和图像文件名的列表,这里的图像代表feature的图像
- self.dataset_size:数据集的长度,实际上就是self.A_paths列表的长度
__getitem__()方法:这个方法就是前面提到的Pytorch通过Dataset类定义数据集时必须要实现的方法之一,即根据index返回对应的样本
其中根据index返回对应的样本和上面的initialize()方法是对应的,因为在initialize()方法中的几个属性是通过判断来决定是否定义的,而在__getitem__()方法中因为要取样本,因此定义了属性才能取样本,因为这些属性就是这些样本存放的路径
首先__getitem__()方法定义了A_path,根据index取出图像列表中对应index的图像文件路径,然后定义了A,用PIL模块中的Image来打开图像文件,然后定义了params,为get_params()
函数的返回值,get_params()函数的代码如下所示:def get_params(opt, size): w, h = size new_h = h new_w = w if opt.resize_or_crop == 'resize_and_crop': new_h = new_w = opt.loadSize elif opt.resize_or_crop == 'scale_width_and_crop': new_w = opt.loadSize new_h = opt.loadSize * h // w x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) flip = random.random() > 0.5 return {'crop_pos': (x, y), 'flip': flip}
可见get_params()函数就是返回一些参数,首先根据传入的参数size(也就是原图像的尺寸)计算新的图像尺寸,
opt.resize_or_crop
代表是否对图像进行缩放(scale)和裁剪(crop),有四种选择:resize_and_crop、crop、scale_width、scale_width_and_crop,函数中的判断判断了以下两种情况:- 'resize_and_crop':图像新的尺寸设置为opt.loadSize(默认为1024)
- 'scale_width_and_crop':等比例缩放,即图像新的宽度设置为opt.loadSize,而图像新的高度设置为opt.loadSize除以原来的宽度乘以原来的高度
然后就定义了开始进行图像裁剪的坐标,使用了random.randint()函数(random.randint()函数生成[a,b]之间的随机整数),生成0到new_w/new_h-opt.fineSize之间的随机整数,其中opt.fineSize代表裁剪后的尺寸,则为了保证尺寸只能在一个范围内随机确定开始裁剪的坐标
最后定义了flip,由于random.random()函数是生成[0,1)之间的随机数,因此和0.5比较以1/2的概率确定是否对图像进行翻转,函数返回一个字典,包含上面说的随机裁剪的坐标以及是否flip,即图像翻转
从对get_params()函数的分析可以看出,params就是对数据集进行预处理的参数,通过这些参数后面可以调用相应的预处理函数对数据集进行处理
然后__getitem__()方法根据opt.label_nc(这个参数的意义是输入label通道数,默认为35)来定义transform_A和A_tensor变量,首先transform_A是get_transform()
函数的返回值,函数的代码如下:
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
if 'resize' in opt.resize_or_crop:
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
elif 'scale_width' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
if 'crop' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
可见get_transform()函数是以params,即对数据集进行预处理的参数来获取一系列图像的预处理操作列表,也就是代码中的transform_list
首先根据opt.resize_or_crop参数中含有'resize'还是'scale_width'来决定对图像是直接缩放还是等比例缩放,需要注意的是,这里引入了torchvision.transforms模块,transforms.Scale()
函数的作用是对图像进行缩放,传入的参数是期望输出的尺寸以及内插方法(这里默认是Image.BICUBIC,即三次插值);而transforms.Lambda()
函数的作用是用于自定义变换操作,其参数表示用于进行变换的函数,代码中采用lambda匿名函数的方式实现,其中使用了__scale_width()
函数,其代码如下所示:
def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)
可见__scale_width()就是根据图像原来的尺寸以及目标宽度,计算出等比例的目标高度,再调用.resize()方法,采用插值对图像进行缩放
然后在get_transform()函数中,根据opt.resize_or_crop参数中是否含有'crop'来决定是否对图像进行裁剪,变换函数同样采用transforms.Lambda()自定义变换操作,并采用匿名函数的形式,使用了__crop()函数,其代码如下所示:
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img
可见__crop()函数就是调用了.crop()方法根据传入的坐标和大小参数对图像进行裁剪
然后在get_transform()函数中又根据opt.resize_or_crop参数是否为'none'来定义变换函数,由于这里跟模型的结构有关,没有对模型进行描述之前不太好讲,因此这里就先跳过。然后get_transform()函数中根据opt.isTrain以及opt.no_flip(如果为True表示不对图像进行翻转,默认为False)决定是否进行图像翻转,同样用匿名函数,具体采用了__flip()函数,其代码如下所示:
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
可见根据上面的随机1/2概率决定是否翻转,如果要翻转则调用.transpose()方法,其中Image.FLIP_LEFT_RIGHT表示左右翻转
最后,get_transform()函数添加了transforms.ToTensor()
操作将PIL Image或者ndarray转换为tensor,形状从HWC转换为CHW,并且除以255归一化至[0-1];然后还需要根据normalize决定是否对图像进行标准化,即使用transforms.Normalize()
函数,传入的参数为三个通道上的均值和标准差,然后通过(x-mean)/std进行标准化,关于标准化的操作可以参考下面这篇文章以及评论区的讨论:
pytorch中归一化transforms.Normalize的真正计算过程
get_transform()函数中生成的变换函数列表transform_list传入transforms.Compose()
函数,这个函数就是将变换操作列表中的变换操作进行遍历,因此只需调用这个函数的返回值就相当于完成了列表中的所有图像变换操作
了解了get_transform()函数后就可以回到AlignedDataset类中的__getitem__()方法中,transform_A就是get_transform()函数返回的一系列图像变换操作,A_tensor就是通过图像变换操作后的张量,然后面的类似分别定义了transform_B和B_tensor(对应encode的图像)、inst_tensor(对应instance的图像)、feat_tensor(对应feature的图像)
值得注意的是,代码中有很多.convert('RGB')
操作,这是因为有些图像读出来的图像是RGBA四通道的,A通道为透明通道,该通道值对深度学习模型训练来说暂时用不到,因此使用convert('RGB')进行通道转换
最后__getitem__()方法将这些张量整合成了一个字典,并作为函数的返回值
- __len__()方法:这个方法也是通过Dataset类定义数据集时必须要实现的方法之一,即返回数据集中的样本总个数,这里就直接计算A_paths路径列表的长度与opt.batchSize相除并向下取整的结果,然后再乘以opt.batchSize,这是为了保证数据集中的样本个数等于opt.batchSize的整数倍
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。