提前说明
GAN生成某某图片数据估计已经被各大博客做烂了。我只是贴一下我的理解和我的步骤。各位加油,找到一个好博客努力搞懂。
文末有完整代码。最好是看着代码就着思路讲解下饭。
GAN
generative adversirial network,经典理论主要由两个部分组成,generator和discriminator,generator生成和数据集相似的新图片,让discriminator分辨这个图片是真实图片还是生成的图片。二者对立统一,discriminator分辨能力提高,促使generator生成更接近真实图片的图片;generator生成更“真”的图片后促使discriminator提高辨识能力。
当然我们想要的东西往往是generator,去生成新图片(或者其他数据)。
自己训练的时候注意:
- generator和discriminator要“势均力敌”,才能得到好的generator
- 可以适当调节两者的学习进度,比如在真图片中加噪声干扰discriminator学习、调整二者学习率、调整训练次数(比如训练1次discriminator就训练5次generator)等等。
实现思路
准备
导入包,image_size是28×28,后面会用到,我把一张图片直接作为一个[1,28×28]的向量来处理了。
然后一个工具类,为了查看dataloader里面的数据,其实没什么用,你自己可以写个自己的版本的。
import torch
import torch.utils.data
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision
DEVICE= 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE=128
IMAGE_SIZE= 28 * 28 # it denotes vector length as well
def ShowDataLoader(dataloader,num):
i=0
for imgs,labs in dataloader:
print("imgs",imgs.shape)
print("labs", labs.shape)
i += 1
if i==num: break
return
加载数据集
加载图片自然少不了先对图片预处理,torchvision提供了大量的函数帮助。transform先变成tensor,然后对其进行正则化。由于我不想用那么多的数据,所以我对数据进行了分割,只拿出了32*1500条数据。
里米那个MNIST如果没有数据集,可以直接下载的,download=True设置一下就可以了。
然后pytroch dataloader加载进去,pytorch基本操作了。
if __name__ == '__main__':
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5],std=[0.5])
])
#load mnist
mnist=torchvision.datasets.MNIST(root="./data-source",train=True,transform=transform)
#split data set
whole_length=len(mnist)
# data length should be number that can be divided by bath_size
sub_length=32*1500
sub_minist1,sub_minist2=torch.utils.data.random_split(mnist, [sub_length, whole_length - sub_length])
#load dataset
dataloader=torch.utils.data.DataLoader(dataset=sub_minist1, batch_size=BATCH_SIZE,shuffle=True)
# plt.imshow(next(iter(dataloader))[0][0][0])
# plt.show()
创建discriminator和generator等
用sequential方便一点,声明了两个网络。
选用BCELoss作为loss function,optimizer也可以用别的,两个网络一人一个optimizer。
discriminator的输入好说,不管是真的还是假的,就是一批batch的图片。
generator的输入经典论文是推荐一个latent vector,其实就是个随机生成的向量,经过generator变化后,把它生成到一个batch的图片数据。
learning rate也可调,具体看训练情况。
Discriminitor=nn.Sequential(
nn.Linear(IMAGE_SIZE, 300),
nn.LeakyReLU(0.2),
nn.Linear(300,150),
nn.LeakyReLU(0.2),
nn.Linear(150,1),
nn.Sigmoid()
)
Discriminitor = Discriminitor.to(DEVICE)
latent_size=64
Generator=nn.Sequential(
nn.Linear(latent_size,150),
nn.ReLU(True),
nn.Linear(150,300),
nn.ReLU(True),
nn.Linear(300, IMAGE_SIZE),
nn.Tanh()#change it into range of (-1,1)
)
Generator=Generator.to(DEVICE)
loss_fn=nn.BCELoss()
d_optimizer=torch.optim.SGD(Discriminitor.parameters(), lr=0.002)
g_optimizer=torch.optim.Adam(Generator.parameters(), lr=0.002)
训练
训练代码格式,pytorch经典写法我就不必多解释。
我实现的时候秉着这样的想法:
- 通过矩阵变换,变成一张图片一个向量,维度[1,28*28]。
- 先计算discriminator的loss,其来自两部分,一部分是真实数据,一部分是generator生成的数据。
- 由于我们的label并不是MNIST的label,而是应该表示图片真假的label,所以我们需要自己做label。
- 两部分loss进行backward之后,discriminator的训练就算完成。
- generator生成的图片如果被discriminator判别程假的那么就是失败的,所以其希望其生成的图片的标签应当是“真”。所以loss是discriminator的结果和全真向量所比较的BCEloss。
loader_len=len(dataloader)
EPOCH =30
G_EPOCH=1
for epoch in range(EPOCH):
for i,(images, _) in enumerate(dataloader):
images=images.reshape(images.shape[0], IMAGE_SIZE).to(DEVICE)
# noise distraction
# noise=torch.randn(images.shape[0], IMAGE_SIZE)
# images=noise+images
#make labels for training
label_real_pic = torch.ones(BATCH_SIZE, 1).to(DEVICE)
label_fake_pic = torch.zeros(BATCH_SIZE, 1).to(DEVICE)
#have a glance at real image
if i%100==0:
plt.title('real')
data=images.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
plt.imshow(data[0])
plt.pause(1)
#calculate loss of the "real part"
res_real=Discriminitor(images)
d_loss_real=loss_fn(res_real, label_real_pic)
#calculate loss of the "fake part"
#generate fake image
z=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
fake_imgs=Generator(z)
res_fake=Discriminitor(fake_imgs.detach()) #detach means to fix the param.
d_loss_fake=loss_fn(res_fake,label_fake_pic)
d_loss=d_loss_fake+d_loss_real
#update discriminator model
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
#change G_EPOCH to modify epoch of discriminator
for dummy in range(G_EPOCH):
tt=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
fake1=Generator(tt)
res_fake2=Discriminitor(fake1)
g_loss=loss_fn(res_fake2,label_real_pic)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
if i %50==0:
print("Epoch [{}/{}], Step [ {}/{} ], d_loss: {:.4f}, g_loss: {:.4f}, "
.format(epoch,EPOCH,i,loader_len,d_loss.item(),g_loss.item()))
#take a look at how generator is at the moment
temp = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
fake_temp = Generator(temp)
ff = fake_temp.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
plt.title('generated')
plt.imshow(ff[0])
plt.pause(1)
效果
我这个程序会隔一段时间显示当前训练的batch中的一张,真实数据和生成数据都有,代表当前情况。
其中生成数据可以看作当前generator能生成到什么程度了。当然你也可以看两方的loss,在console可以看到。
完整代码
import torch
import torch.utils.data
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision
DEVICE= 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE=128
IMAGE_SIZE= 28 * 28 # it denotes vector length as well
def ShowDataLoader(dataloader,num):
i=0
for imgs,labs in dataloader:
print("imgs",imgs.shape)
print("labs", labs.shape)
i += 1
if i==num: break
return
if __name__ == '__main__':
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5],std=[0.5])
])
#load mnist
mnist=torchvision.datasets.MNIST(root="./data-source",train=True,transform=transform)
#split data set
whole_length=len(mnist)
# data length should be number that can be divided by bath_size
sub_length=32*1500
sub_minist1,sub_minist2=torch.utils.data.random_split(mnist, [sub_length, whole_length - sub_length])
#load dataset
dataloader=torch.utils.data.DataLoader(dataset=sub_minist1, batch_size=BATCH_SIZE,shuffle=True)
# plt.imshow(next(iter(dataloader))[0][0][0])
# plt.show()
Discriminitor=nn.Sequential(
nn.Linear(IMAGE_SIZE, 300),
nn.LeakyReLU(0.2),
nn.Linear(300,150),
nn.LeakyReLU(0.2),
nn.Linear(150,1),
nn.Sigmoid()
)
Discriminitor = Discriminitor.to(DEVICE)
latent_size=64
Generator=nn.Sequential(
nn.Linear(latent_size,150),
nn.ReLU(True),
nn.Linear(150,300),
nn.ReLU(True),
nn.Linear(300, IMAGE_SIZE),
nn.Tanh()#change it into range of (-1,1)
)
Generator=Generator.to(DEVICE)
loss_fn=nn.BCELoss()
d_optimizer=torch.optim.SGD(Discriminitor.parameters(), lr=0.002)
g_optimizer=torch.optim.Adam(Generator.parameters(), lr=0.002)
loader_len=len(dataloader)
EPOCH =30
G_EPOCH=1
for epoch in range(EPOCH):
for i,(images, _) in enumerate(dataloader):
images=images.reshape(images.shape[0], IMAGE_SIZE).to(DEVICE)
# noise=torch.randn(images.shape[0], IMAGE_SIZE)
# images=noise+images
label_real_pic = torch.ones(BATCH_SIZE, 1).to(DEVICE)
label_fake_pic = torch.zeros(BATCH_SIZE, 1).to(DEVICE)
if i%100==0:
plt.title('real')
data=images.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
plt.imshow(data[0])
plt.pause(1)
res_real=Discriminitor(images)
d_loss_real=loss_fn(res_real, label_real_pic)
#generate fake image
z=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
fake_imgs=Generator(z)
res_fake=Discriminitor(fake_imgs.detach()) #detach means to fix the param.
d_loss_fake=loss_fn(res_fake,label_fake_pic)
d_loss=d_loss_fake+d_loss_real
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
for j in range(G_EPOCH):
tt=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
fake1=Generator(tt)
res_fake2=Discriminitor(fake1)
g_loss=loss_fn(res_fake2,label_real_pic)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
if i %50==0:
print("Epoch [{}/{}], Step [ {}/{} ], d_loss: {:.4f}, g_loss: {:.4f}, "
.format(epoch, EPOCH, i, loader_len, d_loss.item(), g_loss.item()))
temp = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
fake_temp = Generator(temp)
ff = fake_temp.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
plt.title('generated')
plt.imshow(ff[0])
plt.pause(1)
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。