AICon全球人工智能与机器学习技术大会周四开幕,点击查看完整日程>> 了解详情
写点什么

如何用 PyTorch 构建 GAN?

  • 2021 年 11 月 22 日
  • 本文字数:3705 字

    阅读完需:约 12 分钟

如何用 PyTorch 构建 GAN?

生成对抗网络(Generative Adversarial Network,GAN)由 Goodfellow 等人在 2014 年提出,它彻底改变了计算机视觉中的图像生成领域:没有人能够相信这些令人惊叹而生动的图像实际上是纯粹由机器生成的。


事实上,人们曾经认为生成的任务是不可能的,并且被 GAN 的力量所震惊,因为传统上,根本没有任何事实可以比较我们生成的图像。


本文介绍了创建 GAN 背后的简单直觉,然后介绍了通过 PyTorch 实现的卷积 GAN 及其训练过程。

GAN 背后的直觉

不同于传统分类方法,我们的网络预测可以直接与事实的正确答案相比较,而生成图像的“正确性”是很难定义和衡量的。Goodfellow 等人在他们的原创论文《生成对抗网络》(Generative Adversarial Network)中提出了一个有趣的想法:使用经过训练的分类器来区分生成的图像和实际图像。如果存在这样的分类器,我们可以创建并训练一个生成器网络,直到它输出的图像能完全骗过分类器。


GAN 管道


GAN 是这一过程的产物:它包含一个根据给定的数据集生成图像的生成器,以及一个区分图像是真实的还是生成的判别器(分类器)。GAN 的详细管道见图 1。

损失函数

对生成器和判别器进行优化都很困难,因为正如你所想象的那样,这两个网络的目标完全相反:生成器希望尽可能地创造出真实的东西,但判别器希望区分生成的材料。


为了说明这一点,我们让 D(x) 是判别器的输出,也就是 x 是真实图像的概率,而 G(z) 是我们的生成器的输出。判别器类似于一个二元分类器,因此判别器的目标是使函数最大化:


本质上是二元交叉熵损失,没有开头的负号。另一方面,生成器的目标是使判别器做出正确判断的机会最小化,因此它的目标是最小化函数。所以,最终的损失函数将是两个分类器之间的一个极小极大博弈(minimax game),具体如下:



从理论上讲,这将收敛到判别器,预测所有事件的概率为 0.5。


但在实践中,极小极大博弈往往会导致网络无法收敛,因此仔细调整训练过程非常重要。像学习率这样的超参数对于训练 GAN 时显然更为重要:一个微小的变化会导致 GAN 产生一个输出,而与输入噪声无关。

运算环境

我们通过 PyTorch 库(包括 torchvision)来构建整个程序。GAN 的生成结果的可视化是通过 Matplotlib 库绘制的。下面的代码导入了所有的库:


importGAN.py


"""Import necessary libraries to create a generative adversarial networkThe code is mainly developed using the PyTorch library"""import timeimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import DataLoaderfrom torchvision import datasetsfrom torchvision.transforms import transformsfrom model import discriminator, generatorimport numpy as npimport matplotlib.pyplot as plt
复制代码

数据集

在 GAN 训练中,数据集是一个重要方面。图像的非结构化性质意味着任何给定的类别(如狗、猫或手写的数字)都可以有一个可能的数据分布,而这种分布最终是 GAN 生成内容的基础。


为了演示,本文将使用最简单的 MNIST 数据集,其中包含 60000 张从 0 到 9 的手写数字图像。事实上,像 MNIST 这样的非结构化数据集可以在 Graviti 上找到。这是一家年轻的创业公司,他们希望通过非结构化数据集为社区提供帮助,在他们的平台上有一些最好的公共非结构化数据集,包括 MNIST。

硬件要求

最好的方法是用 GPU 训练神经网络,它可以显著地提高训练速度。但是,如果只有 CPU 可用,你仍然可以测试程序。要使你的程序能够自行确定硬件,你可以使用以下方法:


torchDevice.py


"""Determine if any GPUs are available"""device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
复制代码

实施

网络架构

由于数字的简单性,这两种架构——判别器和生成器,都是由全连接层构建的。请注意,在某些情况下,全连接的 GAN 也比 DCGAN 略微容易收敛。


以下是两种架构的 PyTorch 实现:


GANArchitecture.py


"""Network ArchitecturesThe following are the discriminator and generator architectures"""
class discriminator(nn.Module): def __init__(self): super(discriminator, self).__init__() self.fc1 = nn.Linear(784, 512) self.fc2 = nn.Linear(512, 1) self.activation = nn.LeakyReLU(0.1)
def forward(self, x): x = x.view(-1, 784) x = self.activation(self.fc1(x)) x = self.fc2(x) return nn.Sigmoid()(x)
class generator(nn.Module): def __init__(self): super(generator, self).__init__() self.fc1 = nn.Linear(128, 1024) self.fc2 = nn.Linear(1024, 2048) self.fc3 = nn.Linear(2048, 784) self.activation = nn.ReLU()
def forward(self, x): x = self.activation(self.fc1(x)) x = self.activation(self.fc2(x)) x = self.fc3(x) x = x.view(-1, 1, 28, 28) return nn.Tanh()(x)
复制代码

训练

在训练 GAN 时,我们优化了判别器的结果,同时也改进了我们的生成器。这样,在每次迭代过程中会有两个相互矛盾的损失来同时优化它们。我们送入生成器的是随机噪声,而生成器理应根据给定噪声的微小差异来生成图像:


trainGAN.py


"""Network training procedureEvery step both the loss for disciminator and generator is updatedDiscriminator aims to classify reals and fakesGenerator aims to generate images as realistic as possible"""for epoch in range(epochs):    for idx, (imgs, _) in enumerate(train_loader):        idx += 1
# Training the discriminator # Real inputs are actual images of the MNIST dataset # Fake inputs are from the generator # Real inputs should be classified as 1 and fake as 0 real_inputs = imgs.to(device) real_outputs = D(real_inputs) real_label = torch.ones(real_inputs.shape[0], 1).to(device)
noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5 noise = noise.to(device) fake_inputs = G(noise) fake_outputs = D(fake_inputs) fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)
outputs = torch.cat((real_outputs, fake_outputs), 0) targets = torch.cat((real_label, fake_label), 0)
D_loss = loss(outputs, targets) D_optimizer.zero_grad() D_loss.backward() D_optimizer.step()
# Training the generator # For generator, goal is to make the discriminator believe everything is 1 noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5 noise = noise.to(device)
fake_inputs = G(noise) fake_outputs = D(fake_inputs) fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device) G_loss = loss(fake_outputs, fake_targets) G_optimizer.zero_grad() G_loss.backward() G_optimizer.step()
if idx % 100 == 0 or idx == len(train_loader): print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))
if (epoch+1) % 10 == 0: torch.save(G, 'Generator_epoch_{}.pth'.format(epoch)) print('Model saved.')
复制代码

结果

当 100 个轮数(epoch)之后,我们可以绘制数据集,并看到从随机噪音中生成的数字的结果:



图 2:GAN 生成的结


如上图所示,生成的结果看起来确实相当像真实的结果。鉴于网络非常简单,所以结果看起来确实很有希望!

超越单纯的内容创作

GAN 的创造与计算机视觉领域的先前工作如此不同。随后的众多应用使学术界对深度网络的能力感到惊讶。下面将介绍一些令人惊讶的工作。

CycleGAN

Zhu 等人的 CycleGAN 引入了一种概念,它无需配对样本就可以将图像从 X 域翻译成 Y 域。马被转化为斑马,夏日的阳光被转化为暴风雪,CycleGAN 的结果令人惊讶且准确。


3:Zhu 等人的 CycleGAN 生成的结果。

GauGAN

Nvidia 利用 GAN 的力量,把简单的绘画,根据画笔的语义,转换成优雅而逼真的照片。尽管训练资源的计算成本很高,但它创造了一个全新的研究和应用领域。



4:GaoGAN 的生成结果。左为原图,右为生成的结果。

AdvGAN

GAN 还扩展到清理对抗性图像,并将其转化为不会欺骗分类器的干净样本。关于对抗性攻击和防御的更多信息可以在这里到。

结语

所以,你已经拥有了它!希望这篇文章对如何构建 GAN 提供了一个概览。完整的实现可以在下面的 Github 资源库中找到:


https://github.com/ttchengab/MnistGAN


作者简介:


Ta-ying Cheng,中国香港人,牛津大学哲学博士新生,爱好 3D 视觉、深度学习。


原文链接:


https://towardsdatascience.com/building-a-gan-with-pytorch-237b4b07ca9a

2021 年 11 月 22 日 17:521
用户头像
刘燕 InfoQ记者

发布了 716 篇内容, 共 232.1 次阅读, 收获喜欢 1370 次。

关注

评论

发布
暂无评论
发现更多内容

LocalDateTime、OffsetDateTime、ZonedDateTime互转,这一篇绝对喂饱你

YourBatman

LocalDateTime OffsetDateTime ZonedDateTime

Java程序员福音!阿里最新产物分布式小册:存储+计算+通信+资源调度

Java架构追梦

Java 阿里巴巴 架构 面试 分布式

腾讯T4大牛的10万字《Java架构进阶面试知识笔记》,收藏吃灰系列

Crud的程序员

Java 架构

阿里首推Java微服务架构实战宝典开源,SpringBoot/Cloud+Docker+RabbitMQ彻底玩转微服务!

程序员小毕

Java 架构 面试 微服务 消息中间件

应对新冠病毒传播-粤政协委员建议构建公共卫生区块链平台

Geek_987812

区块链 公共卫生

用APICloud开发iOS App Clip(苹果小程序)详细教程

APICloud

小程序云开发 前端 移动开发 APP开发

见证产品成长,共享AI力量!

百度大脑

架构师训练营第九周作业

zamkai

第一周作业-产品备忘录

Eva

软件架构模式之分层架构

架构精进之路

架构设计 七日更 28天写作

老熟人,新朋友!写作平台邀新季!

InfoQ写作平台官方

活动专区

别让假“努力”毁掉了你!面试了10家企业软件测试岗位,面试题整理

程序员阿沐

程序员 面试 软件测试 自动化测试 测试工程师

数据库表数据量大读写缓慢如何优化(3)【Elasticsearch的使用】

我爱娃哈哈😍

大数据 elasticsearch 优化 死磕Elasticsearch 架构·

谷歌面试题:如何从无序链表中移除重复项?

田维常

面试

【面试必备】Swift 面试题及其答案

ios swift

高承实:区块链是一个技术结构组织 而不是技术

Geek_987812

大数据

红河州加速区块链等新技术与实体经济的深度融合

Geek_987812

数字经济

第四周作业

oooh-la

测试一下

TJJ

重学JS | Set和Map是如何过滤重复值的?

梁龙先森

面试 前端 编程语言 28天写作

Java 程序经验小结: 慎用可变参数

后台技术汇

28天写作

拍乐云技术分享 | 美术教学中视频矫正是怎么做的?

拍乐云Pano

音视频 RTC 图像处理 拍乐云 视频处理

第一章作业

tera

PostgreSQL中Oid和Relfilenode的映射

PostgreSQLChina

数据库 postgresql 开源 软件

图解 | 原来这就是TCP

程序员 网络协议 架构师

常见运维监控系统的技术选型

OpsMind

运维 监控系统

Hbase内核剖析

永健_何

大数据 HBase 底层技术 分布式数据储存

喜讯 | 拍乐云Pano荣获「2020大数据产业创新技术突破」奖

拍乐云Pano

大数据 音视频 RTC 拍乐云

产品思维和产品意识

ALone

想学AI开发很简单:只要你会复制粘贴

华为云开发者社区

GitHub 开源 AI mindspore 推理

目标岗位差异化对比

Geek_6a8931

数据cool谈(第2期)寻找下一代企业级数据库

数据cool谈(第2期)寻找下一代企业级数据库

如何用 PyTorch 构建 GAN?-InfoQ