文本生成图像之AttnGAN

一个新系列文本生成图像,这个是之前一直在研究的东西,有一些idea,但gan的训练有点坑而且很费时。先记录下来,放这。

AttnGAN

一、github与论文链接

github链接: [https://github.com/cn-boop/At...]()

论文链接:AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks

二、阅读总结

1.Abstract

在本文中作者提出了一个 Attentional Generative Ad-
versarial Network(AttnGAN),一种attention-driven的多stage的细粒度文本到图像生成器。
并借助一个深层注意多模态相似模型(deep attentional multimodal similarity model)来训练该生成器。
它首次表明 the layered attentional GAN 能够自动选择单词级别的condition来生成图像的不同部分。

2.Model Structure


模型由两部分组成:
image

  • attentional generative network

该部分使用了注意力机制来生成图像中的子区域,并且在生成每个子区域时还考虑了文本中与该子区域最相关的词。

  • Deep Attentional Multimodal Similarity Model (DAMSM)

该部分用来计算生成的图像与文本的匹配程度。用来训练生成器。

3.Pipeline

  • 输入的文本通过一个Text Encoder 得到 sentence feature 和word features
  • 用sentence feature 生成一个低分辨率的图像I0
  • 基于I0 加入 word features 和setence feature 生成更高分辨率细粒度的图像

三、代码详解

1.attentional generative network

  • step1:使用text_encoder的得到sentence features 和word features.
  • step2:sentence features(sentence embedding)提取condition ,然后与z结合产生低分辨率的图像以及对应的图像特征h0.
  • step3: 每一层低分辨图像的特征被用来生成下一层的高分辨图像特征

word feature(e,大小为DT) 与h0 (大小为D’N)通过attention model ,输出大小为D’*N的张量

1.text_encoder

text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM
class RNN_ENCODER(nn.Module):
    def __init__(self, ntoken, ninput=300, drop_prob=0.5,
             nhidden=128, nlayers=1, bidirectional=True):
    super(RNN_ENCODER, self).__init__()
    ......
得到的word_emb为[2,256,18]   sent_emb:[2,256]
num_words = words_embs.size(2) #每句话18个词

2.image_encoder


image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)# 256
self.emb_features = conv1x1(768, self.nef)  #self.enf:256
self.emb_cnn_code = nn.Linear(2048, self.nef)
def forward(self, x):
    features = None
    # --> fixed-size input: batch x 3 x 299 x 299
    x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
    ......
image_encoder用于最后提取256*256图像的图像特征,输入是256*256的图像,输出是2048维的向量。采用的是inceptionv3的网络架构

3.generative network

第一个生成器

 if cfg.TREE.BRANCH_NUM > 0:
self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)
self.img_net1 = GET_IMAGE_G(ngf)
class INIT_STAGE_G(nn.Module):  #可以看出  和stackgan的第一个生成器类似
def forward(self, z_code, c_code):
    """
    :param z_code: batch x cfg.GAN.Z_DIM
    :param c_code: batch x cfg.TEXT.EMBEDDING_DIM
    :return: batch x ngf/16 x 64 x 64
    """
    ......
输入sent_emb,得到[2,32,64,64]的h_code1和[2,3,64,64]的fake_img1。(这时候还没用到word embedding)

后面的生成器


if cfg.TREE.BRANCH_NUM > 1:
    self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)
    self.img_net2 = GET_IMAGE_G(ngf)
class NEXT_STAGE_G(nn.Module):
    def __init__(self, ngf, nef, ncf):
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        ......
输入上一层的h_code以及c_code,word_embs,mask.mask是个啥不知道,在论文和源码中均未找到解释。大小为[2,18].返回[2,32,128,128]的h_code2和[2,18,64,64]的attn

生成器Loss Function


def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
               words_embs, sent_emb, match_labels,
               cap_lens, class_ids):
    numDs = len(netsD)
    batch_size = real_labels.size(0)
    logs = ''
    # Forward
    errG_total = 0
    ......

3.discriminator network

class D_NET64(nn.Module):
    def __init__(self, b_jcu=True):
        super(D_NET64, self).__init__()
        ndf = cfg.GAN.DF_DIM  #64
        nef = cfg.TEXT.EMBEDDING_DIM  #256
        self.img_code_s16 = encode_image_by_16times(ndf)
        if b_jcu:
            self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)
                ......

辨别器Loss Function

def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
                   real_labels, fake_labels):
    # Forward
    real_features = netD(real_imgs)
    fake_features = netD(fake_imgs.detach())
    # loss
    #
    cond_real_logits = netD.COND_DNET(real_features, conditions)
    ......

2.Deep Attentional Multimodal Similarity Model (DAMSM)

DAMSM Structure

  • text_encoder

    是一个双向LSTM.输出sentence embedding和word embedding (e :D*T)
  • image encoder

    
    它的输出是一个2048维的向量,代表了整个图像的特征:f'
  • s=e.transpose()*v,

这样把图像和句子结合在了一起.s(i,j)代表了句子中的第i个单词和第图像中第j个区域的相关性.
最后s和h结合得到相关损失.

DAMSM Loss Function


Q:generative pictures             D:text description
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
        cnn_model.eval()
       rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)
            words_features, sent_code = cnn_model(real_imgs[-1])
        # nef = words_features.size(1)
        # words_features = words_features.view(batch_size, nef, -1)
        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)
        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)



nlp learner

0 声望
5 粉丝
0 条评论
推荐阅读
关于latex写论文的几个坑
最近写自己的第一篇paper真的有点难受,在用latex写论文的时候也出现了一些坑,在这记录一下: 1.双栏模板中,表格太大,超过单栏的位置: 使用 {代码...} 在begin{tabular}{l c c c}加上: {代码...} 如果上述两...

cn_boop1阅读 2.8k评论 1

算法可视化:一文弄懂 10 大排序算法
在本文中,我们将通过动图可视化加文字的形式,循序渐进全面介绍不同类型的算法及其用途(包括原理、优缺点及使用场景)并提供 Python 和 JavaScript 两种语言的示例代码。除此之外,每个算法都会附有一些技术说...

破晓L7阅读 906

封面图
TOPI 简介
这是 TVM 算子清单(TOPI)的入门教程。 TOPI 提供了 numpy 风格的通用操作和 schedule,其抽象程度高于 TVM。本教程将介绍 TOPI 是如何使得 TVM 中的代码不那么样板化的。

超神经HyperAI1阅读 90.7k

编译 PyTorch 模型
本篇文章译自英文文档 Compile PyTorch Models。作者是 Alex Wong。更多 TVM 中文文档可访问 →TVM 中文站。本文介绍了如何用 Relay 部署 PyTorch 模型。首先应安装 PyTorch。此外,还应安装 TorchVision,并将其...

超神经HyperAI1阅读 92.9k

编译 MXNet 模型
本篇文章译自英文文档 Compile MXNet Models。作者是 Joshua Z. Zhang,Kazutaka Morita。更多 TVM 中文文档可访问 →TVM 中文站。本文将介绍如何用 Relay 部署 MXNet 模型。首先安装 mxnet 模块,可通过 pip 快速...

超神经HyperAI1阅读 42k

横向对比 11 种算法,多伦多大学推出机器学习模型,加速长效注射剂新药研发
内容一览:长效注射剂是解决慢性病的有效药物之一,不过,该药物制剂的研发耗时、费力,颇具挑战。对此,多伦多大学研究人员开发了一个基于机器学习的模型,该模型能预测长效注射剂药物释放速率,从而提速药物整...

超神经HyperAI1阅读 31k

封面图
科罗拉多州立大学发布CSU-MLP模型,用随机森林预测中期恶劣天气
内容一览:近期,来自美国科罗拉多州立大学与 SPC 的相关学者联合发布了一个基于随机森林的机器学习模型 CSU-MLP,该模型能够对中期 (4-8天) 范围内恶劣天气进行准确预报。目前该成果刊已发表在《Weather and For...

超神经HyperAI阅读 48.9k

封面图

nlp learner

0 声望
5 粉丝
宣传栏