写点什么

谷歌刷新世界纪录!2 分钟搞定 ImageNet 训练

  • 2018-11-23
  • 本文字数:3086 字

    阅读完需:约 10 分钟

谷歌刷新世界纪录!2分钟搞定ImageNet训练

AI 前线导读:随着技术、算力的发展,在 ImageNet 上训练 ResNet-50 的速度被不断刷新。2018 年 7 月,腾讯机智机器学习平台团队在 ImageNet 数据集上仅用 6.6 分钟就训练好 ResNet-50,创造了 AI 训练世界纪录;一周前,壕无人性的索尼用 2176 块 V100 GPU 将这一纪录缩短到了 224 秒;如今,这一纪录再次被谷歌刷新……


深度学习非常依赖于硬件条件,它是一个计算密集型的任务。硬件供应商通过在大型计算集群中部署更快的加速器来做出更快的相应。在 petaFLOPS(运算能力单位,每秒千万亿次浮点数运算)规模的设备上训练深度学习模型需要同时面临算法和系统软件两方面的挑战。Google 于近日推出了一种大规模计算集群的图像分类人物训练解决方案,相关论文发表于 Arxiv:Image Classification at Supercomputer Scale。本文的作者使用 Google TPU v3 Pod 训练 ResNet-50,在识别率没有降低的情况下,仅使用了 2.2 分钟的时间。

背景

深度神经网络的成功应用与发展离不开疯狂增长的算力,在许多领域,深度学习的发展可以说是由硬件驱动的。在深度网络的训练过程中,最关键的部分就是使用随机梯度下降算法(SGD)优化网络权重。通常情况下,模型需要使用 SGD 在一个数据集上进行多次的便利才能达到收敛。在整个过程中,浮点数运算能力显得至关重要。例如,在 ImageNet 数据库上训练 ResNet-50 模型,遍历一次数据库需要 3.2 万万亿次浮点数运算。而使模型达到收敛,通常需要遍历 90 次数据库。


尽管硬件加速设备(例如 GPU、TPU)已经加快了迭代的次数,使用单个加速设备在大规模数据库训练大型的神经网络仍然需要几个小时或数天的时间。最常见的加速方法便是通过分布式的 SGD 算法使用多个设备并行训练,将每个 mini-batch 分布在多个相同的加速设备上。


以往大家都喜欢用异步分布式 SGD 算法在将多个线程联合起来进行训练,但是近期的一些工作发现,异步分布式 SGD 算法优化的模型在收敛程度和验证准确率方面都不如同步分布式 SGD 训练出的模型。但是,为了保证在提速的同时模型的质量不会有所损失,在使用同步分布式 SGD 算法的过程中,会遇到很多技术和硬件方面的瓶颈,作者总结出以下几点:


  1. 模型的准确率依赖于全局的 batch size 和计算集群中每个节点的 batch size。

  2. 在加速设备计算能力足够高时,CPU 向 GPU 等专用设备的输入过程成为了训练过程中的瓶颈。

  3. 使用同步分布式 SGD 算法需要大规模的高速并行通信方案,即如何解决一个计算集群内部各个节点之间通信速度的瓶颈。


本文的作者提出了一种同步的分布式 SGD 优化算法,同时还提出了几个大规模分布式深度学习训练过程中使用的机器学习方法和优化方法,在加速收敛的过程中保证模型的质量没有损失。


图:左图为4-chip的云TPU v2设备,峰值计算能力为180 teraFLOPS(每秒万亿次浮点运算),使用64GB的HBM(高带宽内存);右图为使用水冷的4-chip云TPU v3设备,峰值计算能力为420 teraFLOPS,使用128GBHBM。TPU v2设备可组成最高256-chip的计算集群,称为TPU Pod,可提供高达11.5petaFLOPS的混合精度吞吐量。TPU v3 Pod的规模可达1024-chip,是TPU v2 Pod的四倍,理论上可提供107.5 petaFLOPS的混合精度吞吐量

方法

本文的作者受之前大规模训练方法的启发,在实验过程中使用了以下一些技术:


  • 混合精度:在实验过程中,卷积操作使用了 bfloat16 数据,这是一种 TPU 上的半精度 16 位浮点数。此外,卷积层之间的激活函数也使用了 bfloat16 的格式。为了保证计算精度与 32 位浮点数网络不相上下的精度,对于所有的非卷积的操作(例如,批归一化、损失函数计算、梯度求和)都使用了 32 位浮点数。由于网络训练过程中的主要计算和内存消耗都是在卷积操作上,因此使用 bfloat16 可以获得更高的训练吞吐量。

  • 学习率配置:先前的一些研究表明,学习率应当与 batch size 成比例。在实验过程中,作者使用了线性变化的学习率策略进行配置(例如,batch size 设成两倍,则学习率也设为两倍)。同时作者也使用了平缓的学习率预热(warm-up)方法和学习率衰减。

  • 分层自适应速率缩放(LARS):尽管使用动量(momentum)的随机梯度下降算法已经可以将 batch 最高设为 8192,但使用 LARS 优化器可以达到 32786 的 batch size 并且对于模型质量没有影响。更大的 batch size 也增加了模型在 TPU 集群上执行时的吞吐量。

分布式批归一化

批归一化在图像分类任务中有着不可或缺的作用,它通过对一个 mini-batch 内的数据进行归一化,使得经过 batch-norm 层的数据服从相同均值与方差的分布,使得下层神经元可以更好的对数据分布情况进行学习。


在分布式训练过程中,通常让每个计算节点独立的进行 batch norm, 这样的好处是可以大大缩短训练时间,因为每个计算节点之间无需额外的通信过程。在实验过程中,作者发现 BN 的批大小(例如计算节点的批大小)对模型的验证准确率有重要影响。已经有研究证明在计算节点的批大小小于 32 时,ResNet-50 的最终训练结果在验证数据上的准确率并不能收敛。


当使用数据并行的方法在大规模计算机集群上进行部署时,需要同时对全局的 batch size 大小进行扩大,同时对每个节点的局部 batch size 进行缩小。考虑到 BN 层的影响,作者主要针对每个节点上的 batch size 较小的情况进行研究。


作者通过对几个计算节点组成的子节点做分布式的批归一化来实现对 BN 这一过程的增强。具体算法如图所示:


图:分布式批归一化算法示意图,图中集群包含两个计算节点


  1. 首先各个节点计算独立的局部均值与方差

  2. 计算一个子集群(图中子集群包含两个计算节点的)中的分布式均值和方差。

  3. 使用分布式均值和方差对子集群中的所有节点进行归一化

输入管道优化

训练模型过程中,输入管道包括了数据读取、数据分析、预处理、旋转和批量化等操作。如果输入管道的吞吐量不能和 TPU 等模型管道(前向或反向传播过程)的吞吐量相匹配,整个过程将会由于输入管道的问题产生吞吐量上的瓶颈。导致输入管道与模型管道吞吐量差异的主要原因是专用硬件加速设备与 CPU 之间的性能差异,因为模型管道是完全在专用硬件加速设备上执行的。


在本文中, 作者使用了很多关键的优化方法来解决输入管道导致的瓶颈。此前,还未有工作对这些技术进行整合。具体方法如下:


  • 数据共享与缓存:理想情况下,所有的数据会一次性读取并缓存在内存中以备直接使用,但是对于真实情况中的大规模数据集这种做法往往是不可行的。由于计算集群之间是可以共享内存与数据的,因此在大规模计算集群中,作者使用这种数据集共享与缓存的方法来提高输入管道的吞吐量。

  • 预提取并计算:在计算当前批的数据同时对下一批的数据进行提取和处理,当前批计算完时便可直接提取数据使用。

  • 混合 JPEG 解码与裁剪:使用原始的编码数据进行数据增强等操作然后只对有效的部分进行解码

  • 并行数据分析:对于输入管道来说,数据分析与处理是非常消耗算力的,多核 CPU 可以使用多线程进行加速。

二维梯度求和

本文的作者提出了一种二维梯度求和方法,用于多个计算节点之间的梯度的计算和传播。在传统的一维方法中,梯度求和这一步的时间复杂度是 O(n^2),使用二维求和后,时间复杂度可以降到 O(n)。具体计算方法如下图所示。


图:二维环形梯度传播,第一阶段,蓝色张量在Y轴方向进行求和,红色张量在X轴方向进行求和。第二阶段,维度进行转换再次求和。

实验与分析

作者进行了多个实验,对文中提到的几个技术细节进行论证。

分布式批归一化

分布式归一化的结果如下图所示,实验使用了 TPU v2 Pod 进行训练,并且没有使用 LARS 优化。


输入管道优化

左图是逐渐增加每种优化方法的实验结果,中间的图是组合优化的结果与逐渐减少其他优化方法的结果对比,右图是并行化数量对实验结果的影响。所有的实验结果都以数据吞吐量为指标。


二维梯度求和

下图是二维梯度求和算法与一维梯度求和算法的比较,可见使用二维梯度求和在各个配置的情况下都可以有效的减少分布式求和的时间。


与已有最好方法的对比

最后,作者与目前最好的分布式计算方法进行了比较,在准确率相同的情况下,本文提出的方法相比之前的方法大大减少了时间消耗。



目前谷歌云已经上线 Cloud TPU v3 测试版,单台设备价格每小时 2.4 美元到 8 美元,也不是很贵,你也可以动手试试看哦~




会议推荐:12 月 20-21,AICon 将于北京开幕,在这里可以学习来自 Google、微软、BAT、360、京东、美团等 40+AI 落地案例,与国内外一线技术大咖面对面交流。


2018-11-23 18:452741

评论 1 条评论

发布
暂无评论
发现更多内容

【TcaplusDB知识库】表操作—如何申请查询申请单信息

TcaplusDB

week4作业

Asha

【TcaplusDB知识库】表操作—如何查询事务详情

TcaplusDB

异构注册中心机制在中国工商银行的探索实践

SOFAStack

GitHub 开源 分布式架构 注册中心 工商银行

【TcaplusDB知识库】表操作—如何审核删除表申请

TcaplusDB

春暖花开,等你而来!4月月更挑战开始啦!

InfoQ写作社区官方

热门活动 4月月更

【TcaplusDB知识库】事务操作—如何执行事务

TcaplusDB

【TcaplusDB知识库】事务操作—如何恢复(处于挂起状态的)事务

TcaplusDB

如何快速实现持续交付

阿里云云效

云计算 阿里云 软件开发 CI/CD 持续交付

云时代,租电脑还是初创型企业最好的选择吗?

阿里云弹性计算

远程办公 无影云电脑 初创型企业

墨天轮访谈 | 华为云温云博:从客户视角出发,GaussDB(for Redis)究竟“香”在哪里?

墨天轮

数据库 redis 华为云 国产数据库 键值数据库

从二十年开源经历出发,70 后大龄程序员谈成长、困境与突围

TDengine

数据库 tdengine 开源

【TcaplusDB知识库】表操作—如何申请复制表数据

TcaplusDB

【TcaplusDB知识库】表操作—如何申请审核复制表数据

TcaplusDB

VuePress 博客搭建系列 33 篇正式完结!

冴羽

JavaScript Vue 前端 vuepress 博客搭建

明天直播:如何测试硬件设备与龙蜥操作系统的兼容性?

OpenAnolis小助手

硬件 直播 开源社区 sig 兼容性

【TcaplusDB知识库】表操作—如何申请重建表

TcaplusDB

Facebook 开源 Golang 实体框架 Ent 现已支持 TiDB

Geek_2d6073

《LeetCode 刷题报告》题解内容Ⅱ

謓泽

3月月更

【TcaplusDB知识库】表操作—如何审核重建表申请

TcaplusDB

【TcaplusDB知识库】表操作—如何设置表数据淘汰

TcaplusDB

谷歌刷新世界纪录!2分钟搞定ImageNet训练_AI_Chris Ying_InfoQ精选文章