Abstract: U2Net is an excellent salient object detection algorithm, published by Qin Xuebin et al. in the journal Pattern Recognition (2020) [arXiv]. The name U2Net comes from its network structure: a two-level nested U-Net that can be trained from scratch, without a pre-trained backbone, and delivers excellent performance.
This article is shared from the Huawei Cloud Community post "ModelArts Notebook: Hands-On with an Open Source Project (U2Net)", by shpity.
One. Introduction to U2Net
U2Net is an excellent salient object detection algorithm, published by Qin Xuebin et al. in the journal Pattern Recognition (2020) [arXiv]. The name U2Net comes from its network structure: a two-level nested U-Net that can be trained from scratch, without a pre-trained backbone, and delivers excellent performance. Its network structure is shown in Fig. 1.
Figure 1. The overall architecture of U2Net is an encoder-decoder structure similar to U-Net, but each block is replaced with the newly proposed residual U-block (RSU) module
Project open source address: https://github.com/xuebinqin/U-2-Net
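For intuition, the sketch below illustrates the residual U-block (RSU) idea in a simplified form: an input convolution produces a feature map, a small U-Net-like encoder-decoder extracts multi-scale context from it, and the two are summed. This is only a minimal illustration with assumed layer counts and channel widths, not the official implementation (see model/u2net.py in the repository for that).

```python
# A minimal, illustrative sketch of the residual U-block (RSU) idea -- simplified
# assumptions, NOT the official U2Net implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=dilation, dilation=dilation)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))

class TinyRSU(nn.Module):
    """Residual U-block: output = F1(x) + U(F1(x)), where U is a small
    encoder-decoder that captures multi-scale context."""
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        self.conv_in = ConvBNReLU(in_ch, out_ch)               # F1(x)
        self.enc1 = ConvBNReLU(out_ch, mid_ch)                  # encoder level 1
        self.enc2 = ConvBNReLU(mid_ch, mid_ch)                  # encoder level 2 (after downsampling)
        self.bottom = ConvBNReLU(mid_ch, mid_ch, dilation=2)    # bottom, dilated conv
        self.dec2 = ConvBNReLU(mid_ch * 2, mid_ch)              # decoder level 2
        self.dec1 = ConvBNReLU(mid_ch * 2, out_ch)              # decoder level 1

    def forward(self, x):
        fx = self.conv_in(x)
        e1 = self.enc1(fx)
        e2 = self.enc2(F.max_pool2d(e1, 2))
        b = self.bottom(e2)
        d2 = self.dec2(torch.cat([b, e2], dim=1))
        d2_up = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        d1 = self.dec1(torch.cat([d2_up, e1], dim=1))
        return fx + d1                                           # residual connection

# quick shape check
print(TinyRSU()(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 3, 64, 64])
```

Stacking RSU blocks inside an outer U-Net-style encoder-decoder is what gives the "U within U" structure that the name U2Net refers to.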
Two. Create a Notebook Development Environment
1. Enter the ModelArts console
2. Select Development Environment -> Notebook -> Create
3. Create Notebook
3.1 Choose a name related to the task to make the notebook easier to manage;
3.2 To avoid unnecessary resource consumption, it is recommended to enable auto-stop;
3.3 The runtime environment required by U2Net is already included in the public image; you can choose pytorch1.4-cuda10.1-cudnn7-ubuntu18.04;
3.4 A GPU flavor is recommended so that the model can be trained quickly;
3.5 Select Create Now -> Submit, wait for the notebook to be created, and then open it.
4. Import the source code of the open source project (git/manual upload)
4.1 Use git to clone the remote repository in the Terminal
cd work  # Note: only files under /home/ma-user/work and its subdirectories are preserved after the Notebook instance is stopped
git clone https://github.com/xuebinqin/U-2-Net.git
4.2 If git is slow, you can also download the code locally and upload it: drag the compressed package into the file directory bar on the left, or upload it via OBS. The archive can then be extracted inside the Notebook, as sketched below.
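A minimal sketch of the extraction step, assuming the archive was uploaded to /home/ma-user/work as U-2-Net-master.zip (the actual file name depends on your download):

```python
# Extract an uploaded source archive inside the Notebook.
# The archive path below is an assumption; adjust it to the file you uploaded.
import zipfile

archive = "/home/ma-user/work/U-2-Net-master.zip"
with zipfile.ZipFile(archive) as zf:
    zf.extractall("/home/ma-user/work")  # files under /home/ma-user/work persist after the instance stops
```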
Three. Data Preparation
1. Download the training data: the APDrawing dataset
Use wget to download it directly in the Notebook, or download it locally and then drag it into the Notebook.
wget https://cg.cs.tsinghua.edu.cn/people/~Yongjin/APDrawingDB.zip
unzip APDrawingDB.zip
Note: If the dataset is large (>5 GB) and has to be downloaded to a directory other than work (which is cleared after the instance is stopped), it is recommended to store it in OBS and pull it whenever needed.
# Pull the data from OBS to the specified directory
sh-4.4$ source /home/ma-user/anaconda3/bin/activate PyTorch-1.4
sh-4.4$ python
>>> import moxing as mox
>>> mox.file.copy_parallel('obs://bucket-xxxx/APDrawingDB', '/home/ma-user/work/APDrawingDB')
2. Split the training data
The directory ./APDrawingDB/data/train contains 420 training images, each 1024×512 pixels (width × height): the left half is the input image and the right half is the corresponding ground truth. We need to split each image down the middle into two sub-images.
2.1 In the Notebook development environment, create a new PyTorch-1.4 Jupyter notebook file; the name can be split.ipynb. The script below generates 840 sub-images in the ./APDrawingDB/data/train/split_train directory: original images end in .jpg and ground-truth images end in .png, which makes them easy for the subsequent training code to read. The test folder is split in the same way (original images only).
from PIL import Image
import os

# Split the 1024x512 training images into left (original, .jpg) and right (ground truth, .png) halves
train_img_dir = os.path.join("./APDrawingDB/data/train")
img_list = os.listdir(train_img_dir)
for image in img_list:
    img_path = os.path.join(train_img_dir, image)
    if not os.path.isdir(img_path):
        img = Image.open(img_path)
        # print(img.size)
        save_img_dir = os.path.join(train_img_dir, 'split_train')
        if not os.path.exists(save_img_dir):
            os.mkdir(save_img_dir)
        save_img_path = os.path.join(save_img_dir, image)
        cropped_left = img.crop((0, 0, 512, 512))      # (left, upper, right, lower)
        cropped_right = img.crop((512, 0, 1024, 512))  # (left, upper, right, lower)
        cropped_left.save(save_img_path[:-3] + 'jpg')  # original image -> .jpg
        cropped_right.save(save_img_path)              # ground truth   -> .png

# Split the test images; only the left half (original image) is kept
test_img_dir = os.path.join("./APDrawingDB/data/test")
img_list = os.listdir(test_img_dir)
for image in img_list:
    img_path = os.path.join(test_img_dir, image)
    if not os.path.isdir(img_path):
        img = Image.open(img_path)
        # print(img.size)
        save_img_dir = os.path.join(test_img_dir, 'split')
        if not os.path.exists(save_img_dir):
            os.mkdir(save_img_dir)
        save_img_path = os.path.join(save_img_dir, image)
        cropped_left = img.crop((0, 0, 512, 512))      # (left, upper, right, lower)
        cropped_right = img.crop((512, 0, 1024, 512))  # not saved for the test set
        cropped_left.save(save_img_path[:-3] + 'jpg')
3. Organize the split data into the datasets folder required for training and testing, following the hierarchy below (a minimal copy script is sketched after the note).
datasets/
├── test (70 split images, original images only, no ground truth)
└── train (840 split images: 420 original images and their corresponding ground truth)
Note: You can also save the split dataset to OBS to reduce the disk space used under ./work.
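A minimal sketch of this organization step is shown below. It assumes the split images were written to the split_train and split directories created by split.ipynb, and that it is run from the directory containing both APDrawingDB and U-2-Net; adjust the paths to your setup.

```python
# Copy the split images into the datasets/ layout expected by the training and test scripts.
import os
import shutil

copies = [
    ("./APDrawingDB/data/train/split_train", "./U-2-Net/datasets/train"),  # 420 .jpg + 420 .png
    ("./APDrawingDB/data/test/split",        "./U-2-Net/datasets/test"),   # 70 .jpg, originals only
]

for src_dir, dst_dir in copies:
    os.makedirs(dst_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))
    print(dst_dir, "->", len(os.listdir(dst_dir)), "files")
```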
4. The complete U-2-Net project structure is as follows:
U-2-Net/
├── .git
├── LICENSE
├── README.md
├── __pycache__
├── clipping_camera.jpg
├── data_loader.py
├── datasets
├── figures
├── gradio
├── model
├── requirements.txt
├── saved_models
├── setup_model_weights.py
├── test_data
├── u2net_human_seg_test.py
├── u2net_portrait_demo.py
├── u2net_portrait_test.py
├── u2net_test.py
└── u2net_train.py
Four. Training
1. The data paths in the official training code differ somewhat from our datasets directory, so the training script needs a few modifications. It is recommended to use a Jupyter notebook to troubleshoot errors.
Create a new PyTorch-1.4 Jupyter notebook file; the name can be train.ipynb.
import moxing as mox
# If you need to copy the split training data from OBS
#mox.file.copy_parallel('obs://bucket-test-xxxx', '/home/ma-user/work/U-2-Net/datasets')
INFO:root:Using MoXing-v1.17.3-43fbf97f
INFO:root:Using OBS-Python-SDK-3.20.7
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms
import numpy as np
import glob
import os
from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET
from model import U2NETP
/home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/skimage/io/manage_plugins.py:23: UserWarning: Your installed pillow version is < 7.1.0. Several security issues (CVE-2020-11538, CVE-2020-10379, CVE-2020-10994, CVE-2020-10177) have been fixed in pillow 7.1.0 or higher. We recommend to upgrade this library.
from .collection import imread_collection_wrapper
bce_loss = nn.BCELoss(size_average=True)
/home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
warnings.warn(warning.format(ret))
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(), loss6.data.item()))
    return loss0, loss
model_name = 'u2net' #'u2netp'
data_dir = os.path.join(os.getcwd(), 'datasets', 'train' + os.sep)
# tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
# tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)
image_ext = '.jpg'
label_ext = '.png'
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)
epoch_num = 100000
batch_size_train = 24
batch_size_val = 1
train_num = 0
val_num = 0
tra_img_name_list = glob.glob(data_dir + '*' + image_ext)
tra_lbl_name_list = []
for img_path in tra_img_name_list:
    img_name = img_path.split(os.sep)[-1]
    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1, len(bbb)):
        imidx = imidx + "." + bbb[i]
    tra_lbl_name_list.append(data_dir + imidx + label_ext)
print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")
train_num = len(tra_img_name_list)
---
train images: 420
train labels: 420
---
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(320),
        RandomCrop(288),
        ToTensorLab(flag=0)]))
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)
# ------- 3. define model --------
# define the net
if(model_name=='u2net'):
    net = U2NET(3, 1)
elif(model_name=='u2netp'):
    net = U2NETP(3, 1)
if torch.cuda.is_available():
    net.cuda()
# ------- 4. define optimizer --------
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
---define optimizer...
# ------- 5. training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 2000 # save the model every 2000 iterations
---start training...
for epoch in range(0, epoch_num):
    net.train()
    for i, data in enumerate(salobj_dataloader):
        ite_num = ite_num + 1
        ite_num4val = ite_num4val + 1
        inputs, labels = data['image'], data['label']
        inputs = inputs.type(torch.FloatTensor)
        labels = labels.type(torch.FloatTensor)
        # wrap them in Variable
        if torch.cuda.is_available():
            inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
        else:
            inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
        loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.data.item()
        running_tar_loss += loss2.data.item()
        # del temporary outputs and loss
        del d0, d1, d2, d3, d4, d5, d6, loss2, loss
        print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
            epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
        if ite_num % save_frq == 0:
            # save a checkpoint and back it up to OBS
            model_weight = model_dir + model_name + "_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)
            torch.save(net.state_dict(), model_weight)
            mox.file.copy_parallel(model_weight, 'obs://bucket-xxxx/output/model_save/' + model_weight.split('/')[-1])
            running_loss = 0.0
            running_tar_loss = 0.0
            net.train()  # resume training
            ite_num4val = 0
l0: 0.167562, l1: 0.153742, l2: 0.156246, l3: 0.163096, l4: 0.176632, l5: 0.197176, l6: 0.247590
[epoch: 1/100000, batch: 24/ 420, ite: 500] train loss: 1.189413, tar: 0.159183
l0: 0.188048, l1: 0.179041, l2: 0.180086, l3: 0.187904, l4: 0.198345, l5: 0.218509, l6: 0.269199
[epoch: 1/100000, batch: 48/ 420, ite: 501] train loss: 1.266652, tar: 0.168805
l0: 0.192491, l1: 0.187615, l2: 0.188043, l3: 0.197142, l4: 0.203571, l5: 0.222019, l6: 0.261745
[epoch: 1/100000, batch: 72/ 420, ite: 502] train loss: 1.313146, tar: 0.174727
l0: 0.169403, l1: 0.155883, l2: 0.157974, l3: 0.164012, l4: 0.175975, l5: 0.195938, l6: 0.244896
[epoch: 1/100000, batch: 96/ 420, ite: 503] train loss: 1.303333, tar: 0.173662
l0: 0.171904, l1: 0.157170, l2: 0.156688, l3: 0.162020, l4: 0.175565, l5: 0.200576, l6: 0.258133
[epoch: 1/100000, batch: 120/ 420, ite: 504] train loss: 1.299787, tar: 0.173369
l0: 0.177398, l1: 0.166131, l2: 0.169089, l3: 0.176976, l4: 0.187039, l5: 0.205449, l6: 0.248036
Five. Testing
Create a new PyTorch-1.4 Jupyter notebook file; the name can be test.ipynb.
import moxing as mox
# Copy the trained model weights from OBS
mox.file.copy_parallel('obs://bucket-xxxx/output/model_save/u2net.pth', '/home/ma-user/work/U-2-Net/saved_models/u2net/u2net.pth')
import os
import sys
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim
import numpy as np
from PIL import Image
import glob
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn
def save_output(image_name, pred, d_dir, show=False):
    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)
    pb_np = np.array(imo)
    if show:
        show_on_notebook(image, im)
    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1, len(bbb)):
        imidx = imidx + "." + bbb[i]
    imo.save(d_dir + imidx + '.png')
    return im
def show_on_notebook(image_original, pred):  # display the original image and the model prediction side by side in the notebook
    plt.subplot(1, 2, 1)
    imshow(np.array(image_original))
    plt.subplot(1, 2, 2)
    imshow(np.array(pred))
# --------- 1. get image path and name ---------
model_name = 'u2net'  # or 'u2netp'
image_dir = os.path.join(os.getcwd(), 'datasets', 'test')  # note: datasets/test contains only the original test images, no ground truth
prediction_dir = os.path.join(os.getcwd(), 'output', model_name + '_results' + os.sep)
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
img_name_list = glob.glob(os.path.join(os.getcwd(), 'datasets/test/*.jpg'))
# print(img_name_list)
# --------- 2. dataloader ---------
#1. dataloader
test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                    lbl_name_list=[],
                                    transform=transforms.Compose([RescaleT(320),
                                                                  ToTensorLab(flag=0)]))
test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                    batch_size=1,
                                    shuffle=False,
                                    num_workers=1)
# --------- 3. model define ---------
if(model_name=='u2net'):
    print("...load U2NET---173.6 MB")
    net = U2NET(3, 1)
elif(model_name=='u2netp'):
    print("...load U2NETP---4.7 MB")
    net = U2NETP(3, 1)
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_dir))
    net.cuda()
else:
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
net.eval()
# --------- 4. inference for each image ---------
for i_test, data_test in enumerate(test_salobj_dataloader):
    # print("inferencing:", img_name_list[i_test].split(os.sep)[-1])
    inputs_test = data_test['image']
    inputs_test = inputs_test.type(torch.FloatTensor)
    if torch.cuda.is_available():
        inputs_test = Variable(inputs_test.cuda())
    else:
        inputs_test = Variable(inputs_test)
    d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
    # normalization
    pred = d1[:, 0, :, :]
    pred = normPRED(pred)
    # save results to the prediction folder
    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir, exist_ok=True)
    save_output(img_name_list[i_test], pred, prediction_dir, show=True)
    # sys.exit(0)
    del d1, d2, d3, d4, d5, d6, d7
Six. Attachments
See the attachment below.
If you want to learn more about AI technologies, visit the AI zone of HUAWEI CLOUD, where six hands-on camps, including AI programming and Python, are currently free to learn. (Link to the six camps: http://su.modelarts.club/qQB9)
Attachment: .zip (71.28 KB)