欢迎大家来到IT世界,在知识的湖畔探索吧!
图:pixabay
本教程将介绍如何在MNIST图像上构建和训练条件生成式对抗网络(CGAN)。
GAN如何进行工作的
一般来说,生成式对抗模型是同时训练两个模型的:一个是学习从未知分布中输出假样本的生成器,而另一个是学习区分真假样本的鉴别器。
CGAN是GAN的条件变体,其中生成器被指示生成具有特定特征的真实样本,而不是来自完全分布的通用样本。这样的条件可以是与本教程中的图像相关联的标签或者是更为详细的标签,如下图示例所示:
图片来源:Scott Reed
初始设置
运行本教程需要以下软件包:
完整的演示是由GitHub提供的以下两个脚本组成的:
•CGAN_mnist_setup.R:准备数据并定义模型结构
•CGAN_train.R:执行训练操作
准备数据
我们需要的MNIST数据集可在Kaggle上获得。一旦我们将train.csv下载到数据/文件夹后,我们就可以将其导入到R中去。
自定义迭代器在iterators.R中定义,并由CGAN_mnist_setup.R导入。
生成器
生成器是一个从2个输入中创建新样本(MNIST图像)的网络:
•噪声矢量
•定义对象条件的标签(要生成哪个数字)
噪声矢量为Generator模型提供了构建块,它将学习如何将噪声结构化为样本。mx.symbol.Deconvolution操作符用于将初始输入从1×1形状向上采样到28×28图像。
用于生成假样本的标签上的信息是由附加到随机噪声的标签索引的独热编码(one-hot encoding)来提供的。对于MNIST来说,0-9索引因此被转换为长度为10的二进制向量。更复杂的应用将需要的是嵌入而不是简单的单向编码来编码条件。
鉴别器
鉴别器尝试区分生成器产生的假样本和从MNIST训练数据中抽取的真实样本。
在条件式GAN中,与样品相关联的标签也被提供给鉴别器。而在此次的演示中,这些信息将作为一个独热的编码标签,以便传播从而匹配图像的尺寸(10 – >28x28x10)。
训练逻辑
鉴别器的训练过程是最为明显的:损耗就是一个简单的二进制TRUE / FALSE响应,而且损耗可以传播回CNN网络。因此它可以理解为一个简单的二进制分类问题。
生成器损耗来自鉴别器损耗反向传播到其产生的输出。通过将生成器标签伪装成真实样本进入到鉴别器中,鉴别器反向传播损耗为生成器提供了如何最佳地调整其参数,从而欺骗鉴别器相信假样本是真实的信息。
这需要将梯度反向传播到鉴别器的输入数据中(而在普通前馈网络中通常忽略该输入梯度)。
上述训练步骤在CGAN_train.R脚本中执行。
监督训练
在训练期间,相机包(imager package)可以方便进行假样本的视觉质量评估。
以下是在不同训练阶段获得的样本。
从噪音开始:
慢慢地得到下面这个结果——迭代200:
根据需要生成指定的数字图像——迭代2400:
推理
一旦模型被训练,可以通过用固定标签而不是训练期间使用的随机生成的图像馈送到生成器来产生所需数字的合成图像。
在这里我们会产生假的“9”:
CGAN方法的进一步细节可以在Generative Adversarial Text to Image Synthesis论文中找到。
作者:Jeremie Desgagne-Bouchard
来源:http://dmlc.ml
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://itzsg.com/94481.html