写点什么

如何应用 TFGAN 快速实践生成对抗网络?

  • 2018-06-03
  • 本文字数:4252 字

    阅读完需:约 14 分钟

前言

生成对抗网络(Generative Adversarial Nets ,GAN)目前已广泛应用于图像生成、超分辨率图片生成、图像压缩、图像风格转换、数据增强、文本生成等场景。越来越多的研发人员从事 GAN 网络的研究,提出了各种 GAN 模型的变种,包括 CGAN、InfoGAN、WGAN、CycleGAN 等。为了更容易地应用及实践 GAN 模型,谷歌开源了名为 TFGAN 的 TensorFlow 库,可快速实践各种 GAN 模型。本文主要讲解 TFGAN 如何应用于原生 GAN、CGAN、InfoGAN、WGAN 等场景,如下所示:

其中,原生GAN 生成的Mnist 图像不可控:CGAN 可按照数字标签生成相应标签的数字图像;InfoGAN 可认为是无监督的CGAN,前两行表示用分类潜变量控制数字的生成类别,中间两行表示用连续型潜变量控制数字的粗细,最后两行表示用连续型潜变量控制数字的倾斜方向;ImageToImage 是CGAN 的一种,实现图像的风格转换。

生成对抗网络与TFGAN

GAN 由 Goodfellow 首先提出,主要由两部分构成:Generator(生成器),简称 G;Discriminator(判别器), 简称 D。生成器主要用噪声 z 生成一个类似真实数据的样本,样本越逼真越好;判别器用于估计一个样本来自于真实数据还是生成数据,判定越准确越好。如下图所示:

上图中,对于真实的采样数据,通过判别网络后,生成D(x)。D(x) 的输出是0-1 范围内的一个实数,用来判断这个图片是一个真实图片的概率是多大。这样对于真实数据,D(x) 越接近1 越好。对于随机噪声z,通过生成网络G 后,G 将这个随机噪声转化为生成数据x。如果是图片生成问题,G 网络的输出就是一张生成的假图片,用G(z) 表示。判别模型D 要使得D(G(z)) 接近与0,即能够判断生成的图片是假的;生成模型G 要使得D(G(z)) 接近于1,即要能够要欺骗判别模型,使得D 认为G(z) 生成的假数据是真的。这样通过判别模型D 和生成模型G 的博弈,使得D 无法判断一张图片是生成出来的还是真实的而结束。

假设P_r 和P_g 分别代表真实数据的分布与生成数据的分布,这样判别模型的目标函数可以表示为:

而生成模型的是让判别模型D 无法区别真实数据与生成数据,这样优化目标函数为:

TFGAN 库的地址为 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan ,主要包含以下几个组件:

  1. 核心架构,主要包括创建 TFGAN 模型,添加 Loss 值,创建训练 operation,运行训练 operation。
  2. 常用操作,主要提供了梯度修剪操作,归一化操作及条件化操作等。
  3. 损失函数,主要提供了 GAN 中常用的损失和惩罚函数,如 Wasserstein 损失、梯度惩罚、互信息惩罚等。
  4. 模型评估,提供了 Inception Score 和 Frechet Distance 指标,用于评估无条件生成模型。
  5. 示例,谷歌同时开源了常用的 GAN 网络示例代码,包括 unconditional GAN,conditional GAN, InfoGAN,WGAN 等。相关用例可从 https://github.com/tensorflow/models/tree/master/research/gan/ 地址下载。

使用 TFGAN 库训练 GAN 网络主要包含如下几个步骤:

1. 确定 GAN 网络的输入,如下所示:

复制代码
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

2. 设定 GANModel 中的生成模型和判别模型,如下所示:

复制代码
gan_model = tfgan.gan_model(
generator_fn=mnist.unconditional_generator, # you define
discriminator_fn=mnist.unconditional_discriminator, # you define
real_data=images,
generator_inputs=noise)

3. 设定 GANLoss 中的损失方程,如下所示:

复制代码
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

4. 设定 GANTrainOps 中的训练操作,如下所示:

复制代码
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

5. 运行模型训练,如下所示:

复制代码
tfgan.gan_train(
train_ops,
hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
logdir=FLAGS.train_log_dir)

CGAN

CGAN(Conditional Generative Adversarial Nets),针对 GAN 本身不可控的缺点,加入监督信息,训练从无监督变成有监督,指导 GAN 网络进行生成。例如输入分类的标签,可生成相应标签的图像。这样 CGAN 的目标方程可以转换为:

其中,y 是加入的监督信息,D(x|y) 表示在y 的条件下判定真实数据x,D(G(z|y)) 表示在y 的条件下判定生成数据G(z|y)。例如,MNIST 数据集可根据数字label 信息,生成相应标签的图片;人脸生成数据集,可根据性别、是否微笑、年龄等信息,生成相应的人脸图片。CGAN 的架构如下图所示:

在TFGAN 中提供了,基于one_hot_labels 变量和输入tensor 生成condition tensor 的API,如下所示:

tfgan.features.condition_tensor_from_onehot (tensor, one_hot_labels, embedding_size)其中,tensor 为输入数据,one_hot_labels 为 onehot 标签,shape 为 [batch_size, num_classes],embedding_size 为每个 label 对应的 embedding 大小,返回值为 condition tensor。

ImageToImage

Phillip Isola 等提出了基于 CGAN 的图片生成图片的对抗神经网络《Image-to-Image Translation with Conditional Adversarial Networks》。网络设计的基本思想如下所示:

其中,x 为输入的线条图,G(x) 为生成图片,y 为线条图x 对应渲染后的真图片,生成模型G 用于生成图片,判断模型D 用于判定生成图片的真假。判别网络能够最大化判断(x,y) 的数据为真,判断(x,G(x)) 数据为假。而生成网络使得判别网络判断(x,G(x)) 数据为真,从而进行生成模型和判别模型的相互博弈。为了使生成模型不仅能够欺骗判别模型,还要使得生成图像要像真实图片,这样在目标函数中加入了真实图像和生成图像的L1 距离,如下所示:

TFGAN 库,提供了 ImageToImage 生成对抗网络的相关损失方程 API 使用示例,如下所示:

复制代码
# 定义真实数据与生成数据的 L1 损失
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) / FLAGS.patch_size ** 2
# gan_loss 为目标函数损失
gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

InfoGAN

在 GAN 中,生成器用噪声 z 生成数据时,没有加任何的条件限制,很难用 z 的任何一个维度信息表示相关的语义特征。所以在数据生成过程中,无法控制什么样的噪声 z 可以生成什么样的数据,在很大程度上限制了 GAN 的使用。InfoGAN 可以认为是无监督的 CGAN,在噪声 z 上增加潜变量 c,使得生成模型生成的数据与浅变量 c 具有较高的互信息,其中 Info 就是代表互信息的含义。互信息定义为两个熵的差值,H(x) 是先验分布的熵,H(x|y) 代表后验分布的熵。如果 x,y 是相互独立的变量,那么互信息的值为 0,表示 x,y 没有关系;如果 x,y 有相关性,那么互信息大于 0。这样在已知 y 的情况下,可以推断出那些 x 的值出现高。这样 InfoGAN 的目标方程为:

InfoGAN 的网络结构如下所示:

上图中InfoGAN 与GAN 的区别在于,对应判别网络的输出D(x),生成变分分布Q(c|x),从而能用Q(c|x) 来逼近P(c|x),从而增大生成数据与潜变量c 的互信息。

TFGAN 中提供了 InfoGan 相关 API,如下所示:

复制代码
#通过 tfgan.infogan_model,定义 infogan 模型
infogan_model = tfgan.infogan_model(
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
real_data=real_images,
unstructured_generator_inputs=unstructured_inputs,
structured_generator_inputs=structured_inputs)
#通过 tfgan.gan_loss,生成 infogan 模型的 loss 值:
infogan_loss = tfgan.gan_loss(
infogan_model,
gradient_penalty_weight=1.0,
mutual_information_penalty_weight=1.0)

#InfoGan 的 Loss 值为在 GAN 的 loss 值上,加上互信息 I(c;G(z,c)),TFGAN 中提供了互信息计算的 API,如下所示。其中 structured_generator_inputs 为潜变量的噪音信息,predicted_distributions 为变分分布 Q(c|x)。

def mutual_information_penalty(structured_generator_inputs, predicted_distributions)## WGAN

Martin Arjovsky 等提出了 WGAN(Wasserstein GAN),解决了传统 GAN 训练困难、生成器和判别器的 loss 很难指示训练进程、生成样本缺乏多样性等问题,主要有以下优点:

  1. 能够平衡生成器和判别器的训练程度,使得 GAN 的模型训练稳定。
  2. 能够保证生产样本的多样性。
  3. 提出使用 Wasserstein 距离来衡量模型训练的程度,数值越小表示训练得越好,成器生成的图像质量越高。

WGAN 的算法与原始 GAN 算法的差异主要体现在:

  1. 去掉判别模型最后一层的 sigmoid 操作。
  2. 生成模型和判别模型的 loss 值不取 log 操作。
  3. 每次更新判别模型的参数之后把模型参数的绝对值截断到不超过固定常数 c。
  4. 使用 RMSProp 算法,不用基于动量的优化算法,例如 momentum 和 Adam。

WGAN 的算法结构如下所示:

TFGAN 中提供了 WGan 相关 API,如下所示:

复制代码
#生成网络损失方程
generator_loss_fn=tfgan_losses.wasserstein_generator_loss
#判别网络损失方程
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss

总结

本文首先介绍了生成对抗网络和 TFGAN,生成对抗网络模型用于图像生成、超分辨率图片生成、图像压缩、图像风格转换、数据增强、文本生成等场景;TFGAN 是 TensorFlow 库,用于快速实践各种 GAN 模型。然后讲解了 CGAN、ImageToImage、InfoGAN、WGAN 模型的主要思想,并对关键技术进行了分析,主要包括目标函数、网络架构、损失方程及相应的 TFGAN API。用户可基于 TFGAN 快速实践生成对抗网络模型,并应用到工业领域中的相关场景。

参考文献

[1] Generative Adversarial Networks.
[2] Conditional Generative Adversarial Nets.
[3] InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets.
[4] Wasserstein GAN.
[5] Image-to-Image Translation with Conditional Adversarial Networks.
[6] https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan .
[7] https://github.com/tensorflow/models/tree/master/research/gan .

作者简介

武维(微信:allawnweiwu):博士,现为 IBM 架构师。主要从事深度学习平台及应用研究,大数据领域的研发工作。

2018-06-03 17:582190

评论

发布
暂无评论
发现更多内容

Flink CDC 2.0 正式发布,详解核心改进

Apache Flink

flink

疫情之下,延期返工,我竟然“远程面试”了3家公司(备战春招)

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

巧用Python访问台达AS228交互

林建

Python Modbus协议 台达 AS228T

终于有人!把双十一电商秒杀系统高并发架构全部讲清楚了

Java 程序员 面试 高并发 计算机

疫情在家“闭关修炼”,读完这些Java技术栈,愿金三银四过五斩六

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

远程办公一星期,竟等来了阿里新零售视频面(Java岗,已过2面)

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

中国如何应对中美博弈?

石云升

学习 贸易战 8月日更

【共识专栏】HotStuff共识

趣链科技

区块链 共识机制 拜占庭容错 共识算法

什么是工控主机?工控主机安卓主板有哪些配置?

双赞工控

区块链钱包搭建,去中心钱包搭建,仿IM钱包

GitHub星标63K霸榜半月!阿里大牛的微服务分布式架构笔记已上线

Java 编程 IT 计算机 知识

20张图让你彻底掌握负载均衡的秘密

负载均衡 编程 程序员 计算机

FastApi-12-Form表单

Python研究所

FastApi 8月日更

如何在多云环境中建立信任

云计算

如何实现H.264的实时传输?

拍乐云Pano

TCP 四次挥手

W🌥

计算机网络 TCP/IP 8月日更

冲击“金九银十”的利器!《Java权威面试指南(阿里版)》人手一份吊打面试官轻轻松松!

Java 编程 IT 计算机 知识分享

【SpringCloud 技术专题】「原生态 Fegin」打开 Fegin 之 RPC 技术的开端,你会使用原生态的 Fegin 吗?(下)

码界西柚

SpringCloud OpenFegin Fegin 8月日更

去中心化DeFi系统开发

Geek_23f0c3

智能合约 DeFi去中心化系统开发 DAPP智能合约交易系统开发

金三银四,如何远程面试拿下大厂offer?(附大厂面经+面试宝典)

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

Fil价格今日行情?Fil有投资的价值吗?

区块链 分布式存储 IPFS fil fil价格今日行情怎么样

原理分析!如何将springboot项目打成war包放入tomcat中运行

Summer

Java 学习 程序员 架构 springboot

如何快速定位程序Core?

百度Geek说

Linux 后端

啃完这些Spring知识点,我竟吊打了阿里面试官(附面经+笔记

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

通俗易懂的ReentrantLock,不懂你来砍我

程序猿阿星

AQS 公平锁 非公平锁 独占锁 ReentrantLock;

更智能更高效!区块链打造更“美” 服装行业

旺链科技

区块链 服装产业

微博SDK初始化问题 please init sdk before use it. Wb.install()

mengxn

微博sdk

肺炎在家“闭关”,阿里竟发来视频面试,4面顺利拿下offer

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

iOS 开发技术栈与进阶

iOSer

ios 面试 iOS 知识体系 iOS技术栈

Flutter Android 端 FlutterInjector 及依赖流程源码分析

工匠若水

flutter android 8月日更

最全总结 | 聊聊 Python 数据处理全家桶(PgSQL篇)

星安果

Python 数据库 postgresql PgSQL

如何应用TFGAN快速实践生成对抗网络?_Google_武维_InfoQ精选文章