提前说明

GAN生成某某图片数据估计已经被各大博客做烂了。我只是贴一下我的理解和我的步骤。各位加油,找到一个好博客努力搞懂。
文末有完整代码。最好是看着代码就着思路讲解下饭。

GAN

generative adversirial network,经典理论主要由两个部分组成,generator和discriminator,generator生成和数据集相似的新图片,让discriminator分辨这个图片是真实图片还是生成的图片。二者对立统一,discriminator分辨能力提高,促使generator生成更接近真实图片的图片;generator生成更“真”的图片后促使discriminator提高辨识能力。
当然我们想要的东西往往是generator,去生成新图片(或者其他数据)。
自己训练的时候注意:

  1. generator和discriminator要“势均力敌”,才能得到好的generator
  2. 可以适当调节两者的学习进度,比如在真图片中加噪声干扰discriminator学习、调整二者学习率、调整训练次数(比如训练1次discriminator就训练5次generator)等等。

实现思路

准备

导入包,image_size是28×28,后面会用到,我把一张图片直接作为一个[1,28×28]的向量来处理了。
然后一个工具类,为了查看dataloader里面的数据,其实没什么用,你自己可以写个自己的版本的。

import torch
import torch.utils.data
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision

DEVICE= 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE=128
IMAGE_SIZE= 28 * 28 # it denotes vector length as well

def ShowDataLoader(dataloader,num):
    i=0
    for imgs,labs in dataloader:
        print("imgs",imgs.shape)
        print("labs", labs.shape)
        i += 1
        if i==num: break
    return

加载数据集

加载图片自然少不了先对图片预处理,torchvision提供了大量的函数帮助。transform先变成tensor,然后对其进行正则化。由于我不想用那么多的数据,所以我对数据进行了分割,只拿出了32*1500条数据。
里米那个MNIST如果没有数据集,可以直接下载的,download=True设置一下就可以了。
然后pytroch dataloader加载进去,pytorch基本操作了。

if __name__ == '__main__':

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ])

    #load mnist
    mnist=torchvision.datasets.MNIST(root="./data-source",train=True,transform=transform)
    #split data set
    whole_length=len(mnist)
    # data length should be number that can be divided by bath_size
    sub_length=32*1500
    sub_minist1,sub_minist2=torch.utils.data.random_split(mnist, [sub_length, whole_length - sub_length])


    #load dataset
    dataloader=torch.utils.data.DataLoader(dataset=sub_minist1, batch_size=BATCH_SIZE,shuffle=True)
    # plt.imshow(next(iter(dataloader))[0][0][0])
    # plt.show()

创建discriminator和generator等

用sequential方便一点,声明了两个网络。
选用BCELoss作为loss function,optimizer也可以用别的,两个网络一人一个optimizer。

discriminator的输入好说,不管是真的还是假的,就是一批batch的图片。
generator的输入经典论文是推荐一个latent vector,其实就是个随机生成的向量,经过generator变化后,把它生成到一个batch的图片数据。

learning rate也可调,具体看训练情况。

Discriminitor=nn.Sequential(
        nn.Linear(IMAGE_SIZE, 300),
        nn.LeakyReLU(0.2),
        nn.Linear(300,150),
        nn.LeakyReLU(0.2),
        nn.Linear(150,1),
        nn.Sigmoid()
    )
    Discriminitor = Discriminitor.to(DEVICE)


    latent_size=64
    Generator=nn.Sequential(
        nn.Linear(latent_size,150),
        nn.ReLU(True),
        nn.Linear(150,300),
        nn.ReLU(True),
        nn.Linear(300, IMAGE_SIZE),
        nn.Tanh()#change it into range of (-1,1)
    )

    Generator=Generator.to(DEVICE)

    loss_fn=nn.BCELoss()

    d_optimizer=torch.optim.SGD(Discriminitor.parameters(), lr=0.002)
    g_optimizer=torch.optim.Adam(Generator.parameters(), lr=0.002)

训练

训练代码格式,pytorch经典写法我就不必多解释。
我实现的时候秉着这样的想法:

  1. 通过矩阵变换,变成一张图片一个向量,维度[1,28*28]。
  2. 先计算discriminator的loss,其来自两部分,一部分是真实数据,一部分是generator生成的数据。
  3. 由于我们的label并不是MNIST的label,而是应该表示图片真假的label,所以我们需要自己做label。
  4. 两部分loss进行backward之后,discriminator的训练就算完成。
  5. generator生成的图片如果被discriminator判别程假的那么就是失败的,所以其希望其生成的图片的标签应当是“真”。所以loss是discriminator的结果和全真向量所比较的BCEloss。
    loader_len=len(dataloader)
    EPOCH =30
    G_EPOCH=1
    for epoch in range(EPOCH):
        for i,(images, _) in enumerate(dataloader):
            images=images.reshape(images.shape[0], IMAGE_SIZE).to(DEVICE)

            # noise distraction
            # noise=torch.randn(images.shape[0], IMAGE_SIZE)
            # images=noise+images
            
            #make labels for training
            label_real_pic = torch.ones(BATCH_SIZE, 1).to(DEVICE)
            label_fake_pic = torch.zeros(BATCH_SIZE, 1).to(DEVICE)
            
            #have a glance at real image
            if i%100==0:
                plt.title('real')
                data=images.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.imshow(data[0])
                plt.pause(1)

            #calculate loss of the "real part"
            res_real=Discriminitor(images)
            d_loss_real=loss_fn(res_real, label_real_pic)
            
            #calculate loss of the "fake part"
            #generate fake image
            z=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
            fake_imgs=Generator(z)

            res_fake=Discriminitor(fake_imgs.detach()) #detach means to fix the param.
            d_loss_fake=loss_fn(res_fake,label_fake_pic)

            d_loss=d_loss_fake+d_loss_real
            
            #update discriminator model
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            #change G_EPOCH to modify epoch of discriminator
            for dummy in range(G_EPOCH):
                tt=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
                fake1=Generator(tt)
                res_fake2=Discriminitor(fake1)
                g_loss=loss_fn(res_fake2,label_real_pic)
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

            if i %50==0:
                print("Epoch [{}/{}], Step [ {}/{} ], d_loss: {:.4f}, g_loss: {:.4f}, "
                      .format(epoch,EPOCH,i,loader_len,d_loss.item(),g_loss.item()))
                #take a look at how generator is at the moment
                temp = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
                fake_temp = Generator(temp)
                ff = fake_temp.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.title('generated')
                plt.imshow(ff[0])
                plt.pause(1)

效果

我这个程序会隔一段时间显示当前训练的batch中的一张,真实数据和生成数据都有,代表当前情况。
其中生成数据可以看作当前generator能生成到什么程度了。当然你也可以看两方的loss,在console可以看到。

完整代码

import torch
import torch.utils.data
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision


DEVICE= 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE=128
IMAGE_SIZE= 28 * 28 # it denotes vector length as well

def ShowDataLoader(dataloader,num):
    i=0
    for imgs,labs in dataloader:
        print("imgs",imgs.shape)
        print("labs", labs.shape)
        i += 1
        if i==num: break
    return


if __name__ == '__main__':

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5],std=[0.5])
    ])

    #load mnist
    mnist=torchvision.datasets.MNIST(root="./data-source",train=True,transform=transform)
    #split data set
    whole_length=len(mnist)
    # data length should be number that can be divided by bath_size
    sub_length=32*1500
    sub_minist1,sub_minist2=torch.utils.data.random_split(mnist, [sub_length, whole_length - sub_length])


    #load dataset
    dataloader=torch.utils.data.DataLoader(dataset=sub_minist1, batch_size=BATCH_SIZE,shuffle=True)
    # plt.imshow(next(iter(dataloader))[0][0][0])
    # plt.show()



    Discriminitor=nn.Sequential(
        nn.Linear(IMAGE_SIZE, 300),
        nn.LeakyReLU(0.2),
        nn.Linear(300,150),
        nn.LeakyReLU(0.2),
        nn.Linear(150,1),
        nn.Sigmoid()
    )
    Discriminitor = Discriminitor.to(DEVICE)


    latent_size=64
    Generator=nn.Sequential(
        nn.Linear(latent_size,150),
        nn.ReLU(True),
        nn.Linear(150,300),
        nn.ReLU(True),
        nn.Linear(300, IMAGE_SIZE),
        nn.Tanh()#change it into range of (-1,1)
    )

    Generator=Generator.to(DEVICE)

    loss_fn=nn.BCELoss()

    d_optimizer=torch.optim.SGD(Discriminitor.parameters(), lr=0.002)
    g_optimizer=torch.optim.Adam(Generator.parameters(), lr=0.002)


    loader_len=len(dataloader)
    EPOCH =30
    G_EPOCH=1
    for epoch in range(EPOCH):
        for i,(images, _) in enumerate(dataloader):
            images=images.reshape(images.shape[0], IMAGE_SIZE).to(DEVICE)

            # noise=torch.randn(images.shape[0], IMAGE_SIZE)
            # images=noise+images
            
            label_real_pic = torch.ones(BATCH_SIZE, 1).to(DEVICE)
            label_fake_pic = torch.zeros(BATCH_SIZE, 1).to(DEVICE)

            if i%100==0:
                plt.title('real')
                data=images.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.imshow(data[0])
                plt.pause(1)

            res_real=Discriminitor(images)

            d_loss_real=loss_fn(res_real, label_real_pic)

            #generate fake image
            z=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
            fake_imgs=Generator(z)

            res_fake=Discriminitor(fake_imgs.detach()) #detach means to fix the param.

            d_loss_fake=loss_fn(res_fake,label_fake_pic)
            d_loss=d_loss_fake+d_loss_real

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()


            for j in range(G_EPOCH):
                tt=torch.randn(BATCH_SIZE,latent_size).to(DEVICE)
                fake1=Generator(tt)
                res_fake2=Discriminitor(fake1)
                g_loss=loss_fn(res_fake2,label_real_pic)
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()

            if i %50==0:
                print("Epoch [{}/{}], Step [ {}/{} ], d_loss: {:.4f}, g_loss: {:.4f}, "
                      .format(epoch, EPOCH, i, loader_len, d_loss.item(), g_loss.item()))

                temp = torch.randn(BATCH_SIZE, latent_size).to(DEVICE)
                fake_temp = Generator(temp)
                ff = fake_temp.view(BATCH_SIZE, 28, 28).data.cpu().numpy()
                plt.title('generated')
                plt.imshow(ff[0])
                plt.pause(1)

Yonggie
95 声望4 粉丝