Import the required libraries
import os
from glob import glob
import torch as t
# Fix the random seeds so that results are reproducible
t.random.manual_seed(0)
t.cuda.manual_seed_all(0)
# Benchmark mode speeds up computation, but its auto-tuning introduces slight
# run-to-run variation in the forward pass
t.backends.cudnn.benchmark = True
# Counteract the nondeterminism introduced by the line above
t.backends.cudnn.deterministic = True
from PIL import Image
import torch.nn as nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, MultiStepLR, CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patch
import torch.nn.functional as F
import json
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.resnet import resnet18, resnet34
from torchsummary import summary
%matplotlib inline
Set up the network configuration
class Config:
    batch_size = 16
    # initial learning rate
    lr = 1e-2
    # momentum
    momentum = 0.9
    # weight decay coefficient
    weights_decay = 1e-5
    class_num = 11
    # evaluate the network every eval_interval epochs
    eval_interval = 1
    # save a checkpoint every checkpoint_interval epochs
    checkpoint_interval = 1
    # update the progress bar / print a log line every print_interval iterations
    print_interval = 50
    # directory for saving checkpoints
    checkpoints = 'drive/My Drive/Data/Datawhale-DigitsRecognition/checkpoints/'
    # path of the pretrained model to load
    pretrained = '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/checkpoints/epoch-32_acc-0.67.pth'
    # epoch to start training from
    start_epoch = 0
    # total number of training epochs
    epoches = 50
    # label smoothing factor; 0 disables label smoothing
    smooth = 0.1
    # probability of random erasing; 0 disables erasing
    erase_prob = 0.5

config = Config()
Build the network model

Generally, when building a baseline you choose a lightweight network with as few parameters and as little complexity as possible for the backbone. Only once that works do you swap in a more complex backbone.

Here MobileNet V2 is used as the backbone to build a classification network.
class DigitsMobilenet(nn.Module):
    def __init__(self, class_num=11):
        super(DigitsMobilenet, self).__init__()
        self.net = mobilenet_v2(pretrained=True)
        # the backbone already global-average-pools to a [N, 1280] vector before
        # the classifier, so replace the classifier with a pass-through layer
        self.net.classifier = nn.Identity()
        self.fc1 = nn.Linear(1280, class_num)
        self.fc2 = nn.Linear(1280, class_num)
        self.fc3 = nn.Linear(1280, class_num)
        self.fc4 = nn.Linear(1280, class_num)
        self.fc5 = nn.Linear(1280, class_num)

    def forward(self, img):
        """
        Params:
            img(tensor): shape [N, C, H, W]
        Returns:
            fc1(tensor): representation of the 1st character
            fc2(tensor): representation of the 2nd character
            fc3(tensor): representation of the 3rd character
            fc4(tensor): representation of the 4th character
            fc5(tensor): representation of the 5th character
        """
        features = self.net(img).view(-1, 1280)
        fc1 = self.fc1(features)
        fc2 = self.fc2(features)
        fc3 = self.fc3(features)
        fc4 = self.fc4(features)
        fc5 = self.fc5(features)
        return fc1, fc2, fc3, fc4, fc5
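As a quick sanity check (a sketch of mine, not part of the original notebook), we can feed a dummy batch through the model and confirm that each of the five heads produces [N, class_num] logits; the 128×224 input resolution here is an arbitrary choice, since the global pooling handles any reasonable size.

# Sanity check: five heads, each producing [N, class_num] logits.
# Reuses the imports at the top of the notebook (torch as t).
model = DigitsMobilenet(class_num=11)
dummy = t.randn(2, 3, 128, 224)          # arbitrary batch and spatial size
outs = model(dummy)
print([tuple(o.shape) for o in outs])    # expected: [(2, 11)] * 5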
class DigitsResnet18(nn.Module):
    def __init__(self, class_num=11):
        super(DigitsResnet18, self).__init__()
        self.net = resnet18(pretrained=True)
        # nn.Identity is a pass-through layer: output equals input
        self.net.fc = nn.Identity()
        self.fc1 = nn.Linear(512, class_num)
        self.fc2 = nn.Linear(512, class_num)
        self.fc3 = nn.Linear(512, class_num)
        self.fc4 = nn.Linear(512, class_num)
        self.fc5 = nn.Linear(512, class_num)

    def forward(self, img):
        # flatten explicitly; squeeze() would also drop the batch dim when N == 1
        features = self.net(img).view(-1, 512)
        fc1 = self.fc1(features)
        fc2 = self.fc2(features)
        fc3 = self.fc3(features)
        fc4 = self.fc4(features)
        fc5 = self.fc5(features)
        return fc1, fc2, fc3, fc4, fc5
Build the training module

Several tricks are used here:

- Label smoothing
Label smoothing is a regularization technique that mitigates overfitting caused by limited data.
Its formula is given below, where ε is the smoothing factor (set to 0.1 in the experiments), C is the number of classes, and P_i is the softened label probability.
$$P_i=\begin{cases} 1-\epsilon & \text{if } i = y \\ \dfrac{\epsilon}{C-1} & \text{if } i \neq y \end{cases}$$
For example, a one-hot label vector [0, 1, 0, 0] becomes [0.033, 0.9, 0.033, 0.033] after label smoothing.
- Cosine annealing + warmup
Gradients tend to be very unstable at the start of training, so a smaller learning rate should be used for the first few iterations before returning to the initial learning rate for normal training.
Warmup linearly ramps the learning rate up to the initial value over the first n iterations (n is set to 10). This stabilizes early training and helps the model converge to a better minimum.
The cosine annealing schedule, in turn, makes it easier to jump out of local minima, raising the chance of settling into a better one.

The sketch below plots the learning-rate curves produced by warmup and the cosine annealing schedule.
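Since the original figure is not reproduced here, the following sketch (my addition; it assumes 10 warmup steps and the initial LR of 1e-2 from Config) generates an equivalent plot, mirroring the CosineAnnealingWarmRestarts configuration used in the Trainer below.

# Plot the LR curve: 10 steps of linear warmup, then cosine annealing with
# warm restarts (T_0=10, T_mult=2). Reuses the imports at the top of the notebook.
dummy_model = nn.Linear(1, 1)                     # placeholder; only the LR matters
opt = SGD(dummy_model.parameters(), lr=1e-2, momentum=0.9)
scheduler = CosineAnnealingWarmRestarts(opt, 10, 2, eta_min=1e-3)

warmup_steps, lrs = 10, []
for step in range(100):
    if step < warmup_steps:
        # linear warmup: ramp the LR from near 0 up to the initial LR
        for group in opt.param_groups:
            group['lr'] = 1e-2 * (step + 1) / warmup_steps
    else:
        # cosine annealing with warm restarts afterwards
        scheduler.step(step - warmup_steps)
    lrs.append(opt.param_groups[0]['lr'])

plt.plot(lrs)
plt.xlabel('step')
plt.ylabel('learning rate')
plt.show()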
# ----------------------------------- LabelSmoothEntropy ----------------------------------- #
class LabelSmoothEntropy(nn.Module):
    def __init__(self, smooth=0.1, class_weights=None, size_average='mean'):
        super(LabelSmoothEntropy, self).__init__()
        self.size_average = size_average
        self.smooth = smooth
        self.class_weights = class_weights

    def forward(self, preds, targets):
        # positive class gets 1 - epsilon, each negative class gets epsilon / (C - 1),
        # where C = preds.shape[1] is the number of classes
        lb_pos, lb_neg = 1 - self.smooth, self.smooth / (preds.shape[1] - 1)
        smoothed_lb = t.zeros_like(preds).fill_(lb_neg).scatter_(1, targets[:, None], lb_pos)
        log_soft = F.log_softmax(preds, dim=1)
        if self.class_weights is not None:
            loss = -log_soft * smoothed_lb * self.class_weights[None, :]
        else:
            loss = -log_soft * smoothed_lb
        loss = loss.sum(1)
        if self.size_average == 'mean':
            return loss.mean()
        elif self.size_average == 'sum':
            return loss.sum()
        else:
            raise NotImplementedError
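A minimal usage sketch (my addition, with made-up logits): with smooth=0.1 and 11 classes, the target class receives probability 0.9 and each of the 10 remaining classes receives 0.1 / 10 = 0.01, exactly as in the formula above.

# Smoke test for LabelSmoothEntropy on random logits.
criterion = LabelSmoothEntropy(smooth=0.1)
preds = t.randn(4, 11)                   # batch of 4 samples, 11 classes
targets = t.tensor([0, 3, 7, 10])
print(criterion(preds, targets).item())  # a scalar loss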
class Trainer:
    def __init__(self):
        self.device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
        # DigitsDataset and data_dir are defined in the earlier data-loading section
        self.train_set = DigitsDataset(data_dir['train_data'], data_dir['train_label'])
        self.train_loader = DataLoader(self.train_set, batch_size=config.batch_size, num_workers=8, pin_memory=True, drop_last=True)
        self.val_loader = DataLoader(DigitsDataset(data_dir['val_data'], data_dir['val_label'], aug=False), batch_size=config.batch_size,
                                     num_workers=8, pin_memory=True, drop_last=True)
        self.model = DigitsMobilenet(config.class_num).to(self.device)
        # use label smoothing
        self.criterion = LabelSmoothEntropy().to(self.device)
        self.optimizer = SGD(self.model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weights_decay, nesterov=True)
        # use the cosine annealing (with warm restarts) learning-rate schedule
        self.lr_scheduler = CosineAnnealingWarmRestarts(self.optimizer, 10, 2, eta_min=1e-3)
        # self.lr_scheduler = MultiStepLR(self.optimizer, [10, 20, 30], 0.5)
        self.best_acc = 0
        if config.pretrained is not None:
            self.load_model(config.pretrained)
            # print('Load model from %s' % config.pretrained)
            acc = self.eval()
            self.best_acc = acc
            print('Load model from %s, Eval Acc: %.2f' % (config.pretrained, acc * 100))
    def train(self):
        for epoch in range(config.start_epoch, config.epoches):
            self.train_epoch(epoch)
            if (epoch + 1) % config.eval_interval == 0:
                print('Start Evaluation')
                acc = self.eval()
                if acc > self.best_acc:
                    os.makedirs(config.checkpoints, exist_ok=True)
                    save_path = config.checkpoints + 'epoch-%d_acc-%.2f.pth' % (epoch + 1, acc)
                    self.save_model(save_path)
                    print('%s saved successfully...' % save_path)
                    self.best_acc = acc
    def train_epoch(self, epoch):
        total_loss = 0
        corrects = 0
        tbar = tqdm(self.train_loader)
        self.model.train()
        for i, (img, label) in enumerate(tbar):
            img = img.to(self.device)
            label = label.to(self.device)
            self.optimizer.zero_grad()
            pred = self.model(img)
            loss = (self.criterion(pred[0], label[:, 0]) +
                    self.criterion(pred[1], label[:, 1]) +
                    self.criterion(pred[2], label[:, 2]) +
                    self.criterion(pred[3], label[:, 3]) +
                    self.criterion(pred[4], label[:, 4]))
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            temp = t.stack([
                pred[0].argmax(1) == label[:, 0],
                pred[1].argmax(1) == label[:, 1],
                pred[2].argmax(1) == label[:, 2],
                pred[3].argmax(1) == label[:, 3],
                pred[4].argmax(1) == label[:, 4],
            ], dim=1)
            # a sample counts as correct only if all predicted digits are correct
            corrects += t.all(temp, dim=1).sum().item()
            if (i + 1) % config.print_interval == 0:
                self.lr_scheduler.step()
                tbar.set_description('loss: %.3f, acc: %.3f' % (total_loss / (i + 1), corrects * 100 / ((i + 1) * config.batch_size)))
    def eval(self):
        self.model.eval()
        corrects = 0
        with t.no_grad():
            tbar = tqdm(self.val_loader)
            for i, (img, label) in enumerate(tbar):
                img = img.to(self.device)
                label = label.to(self.device)
                pred = self.model(img)
                temp = t.stack([
                    pred[0].argmax(1) == label[:, 0],
                    pred[1].argmax(1) == label[:, 1],
                    pred[2].argmax(1) == label[:, 2],
                    pred[3].argmax(1) == label[:, 3],
                    pred[4].argmax(1) == label[:, 4],
                ], dim=1)
                corrects += t.all(temp, dim=1).sum().item()
                tbar.set_description('Val Acc: %.2f' % (corrects * 100 / ((i + 1) * config.batch_size)))
        self.model.train()
        return corrects / (len(self.val_loader) * config.batch_size)
    def save_model(self, save_path, save_opt=False, save_config=False):
        # save the model (and optionally the optimizer state and config)
        dicts = {}
        dicts['model'] = self.model.state_dict()
        if save_opt:
            dicts['opt'] = self.optimizer.state_dict()
        if save_config:
            dicts['config'] = {s: config.__getattribute__(s) for s in dir(config) if not s.startswith('_')}
        t.save(dicts, save_path)

    def load_model(self, load_path, load_opt=False, load_config=False):
        # load the model (and optionally the optimizer state and config)
        dicts = t.load(load_path)
        self.model.load_state_dict(dicts['model'])
        if load_opt:
            self.optimizer.load_state_dict(dicts['opt'])
        if load_config:
            for k, v in dicts['config'].items():
                config.__setattr__(k, v)
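Putting it together, the training entry point would look like the sketch below; note that DigitsDataset and data_dir are assumed to be defined in the earlier data-loading sections.

# Kick off training; resumes from config.pretrained if it is set.
trainer = Trainer()
trainer.train()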
Summary

Overall, I find the classification-based framing quite novel; at first it never occurred to me to treat this as a classification problem. If a classification model can handle the task, there is no need to resort to object detection, although for the competition itself object detection should perform better.

This part is closely tied to the previous sections and reuses their code.

The code lives in my GitHub repository; feel free to give it a star.

I have also shared all the data via cloud drive at this address.

OK, that's it for now.