阿里云飞天发布时刻,领先大模型限免,超7000万 tokens免费体验 了解详情
写点什么

ASGD

  • 2019-11-29
  • 本文字数:1833 字

    阅读完需:约 6 分钟

ASGD

简介

Asynchronous Stochastic Gradient Descent (ASGD)异步的随机梯度下降在深度学习模型的训练中经常被用到,但是会存在 delayed gradients 的问题,就是当一个 worker 向参数 server 端提交它算出的梯度时,server 端其实已经被其它 worker 更新好多次了。因此该工作提出了梯度补偿的概念,主要方法是利用梯度函数的泰勒展开去有效逼近 loss 函数的 Hessian 矩阵。通过在 cifar 和 imagenet 数据集上验证,实验结果显示,新的方法 DC-ASGD 性能优于同步 SGD 和异步 SGD,几乎接近序列 SGD 的性能。

ASGD 介绍

传统的 SGD,更新公式为:



其中,wt 为当前模型,(xt, yt)为随机抽取的数据,g(wt; xt, yt)为(xt, yt)所对应的经验损失函数关于当前模型 wt 的梯度,η为步长/学习率。


同步随机梯度下降法(Synchronous SGD)在优化的每轮迭代中,会等待所有的计算节点完成梯度计算,然后将每个工作节点上计算的随机梯度进行汇总、平均并上面的公式更新模型。之后,工作节点接收更新之后的模型,并进入下一轮迭代。由于 Sync SGD 要等待所有的计算节点完成梯度计算,因此好比木桶效应,Sync SGD 的计算速度会被运算效率最低的工作节点所拖累。


异步随机梯度下降法(Asynchronous SGD)在每轮迭代中,每个工作节点在计算出随机梯度后直接更新到模型上,不再等待所有的计算节点完成梯度计算。因此,异步随机梯度下降法的迭代速度较快,也被广泛应用到深度神经网络的训练中。然而,Async SGD 虽然快,但是用以更新模型的梯度是有延迟的,会对算法的精度带来影响。如下图:



在 Async SGD 运行过程中,某个工作节点 Worker(m)在第 t 次迭代开始时获取到模型的最新参数 [公式] 和数据(xt, yt),计算出相应的随机梯度 [公式] ,并将其返回并更新到全局模型 w 上。由于计算梯度需要一定的时间,当这个工作节点传回随机梯度[公式]时,模型[公式]已经被其他工作节点更新了τ轮,变为了 [公式] 。也就是说,Async SGD 的更新公式为:



可以看到,对参数[公式]更新时所使用的随机梯度是 g(wt),相比 SGD 中应该使用的随机梯度 g(wt+τ)产生了τ步的延迟。因而,我们称 Async SGD 中随机梯度为“延迟梯度”。


延迟梯度所带来的最大问题是,由于每次用以更新模型的梯度并非是正确的梯度,因为 g(wt) ≠ g(wt+τ),所以导致 Async SGD 会损伤模型的准确率,并且这种现象随着机器数量的增加会越来越严重。


因此 DC-ASGD 算法设计了一种可以补偿梯度延迟的方法,他们首先研究了正确梯度 g(wt+τ)和延迟梯度 g(wt)之间的关系,我们将 g(wt+τ)在 wt 处进行泰勒展开得到:



其中,∇g(wt)为梯度的梯度(loss fuction 的 Hessian 矩阵,因此梯度 g(wt)是 loss 函数关于参数 wt 的导数)。H(g(wt))为梯度的 Hessian 矩阵。那么如果将所有的高阶项都计算出来,就可以修正延迟梯度为准确梯度了。然而,由于余项拥有无穷项,并且计算量十分复杂,所以无法被准确计算。因此,可用上述公式中的一阶项进行延迟补偿:



但是上面的公式还是要计算∇g(wt)(参数的 Hessian 矩阵),但是在 DNN 中有上百万甚至更多的参数,计算和存储 Hessian 矩阵∇g(wt)很困难。因此,寻找 Hessian 矩阵的一个良好近似是能否补偿梯度延迟的关键。根据费舍尔信息矩阵的定义,梯度的外积矩阵是 Hessian 矩阵的一个渐近无偏估计:



其实,进一步可以写成:[公式] 。


又可知,在 DNN 中用 Hessian 矩阵的对角元素来近似表示 Hessian 矩阵,可在显著降低运算和存储复杂度的同时还可以保持算法精度,于是我们采用外积矩阵的 diag(G(wt))作为 Hessian 矩阵的近似。为了进一步降低近似的方差,我们使用一个(0,1]之间参数λ来对偏差和方差进行调节。另外由于:



综上,带有延迟补偿的异步随机梯度下降法(DC-ASGD):

具体算法

算法 1 中,worker m 从参数服务器中 pull 最新的模型参数 w,然后计算得到梯度 [公式] 后 push 到参数服务器中。


算法 2 中,当参数服务器接收到 worker m 算出来的梯度 [公式]后,利用梯度补偿公式算出下一个时间刻参数服务器正确的参数。如果参数服务器接受到 worker m 的 pull 参数请求时,将当前参数服务器的参数 wt 备份成 w_bak,并将 wt 发送给 worker m。


实验

在 CIFAR10 数据集和 ImageNet 数据集上对 DC-ASGD 算法进行了评估,实验结果显示:DC-ASGD 算法与 Async SGD 算法相比,在相同的时间内获得的模型准确率有显著的提升,并且也高于 Sync SGD,基本可以达到 SGD 相同的模型准确率。


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/80978479


2019-11-29 08:002178

评论

发布
暂无评论
发现更多内容

蓝鲸研运体系在腾讯内是如何应用实践的?

嘉为蓝鲸

运维 智能运维AIOps

2023 重学 Angular

PingCode研发中心

前端框架

基于云原生技术的融合通信是如何实现的?

阿里云CloudImagine

阿里云 云通信

企业想要高效运营,还需要选择瓴羊Quick BI软件

流量猫猫头

大数据

天翼云Serverless边缘容器下沉服务 促进企业聚焦业务创新

天翼云开发者社区

老板让我在Linux中使用traceroute排查服务器网络问题,幸好我收藏了这篇文章!

wljslmz

Linux 网络故障 11月月更 traceroute

互联网企业面试必问Spring源码?搞定Spring源码,看完这篇就够了

钟奕礼

Java java面试 java编程 程序员‘

MegEngine Inference 卷积优化之 Im2col 和 winograd 优化

MegEngineBot

深度学习框架 卷积 MegEngine

10月&11月书单

图灵社区

书单推荐

图数据技术护航网络安全

Neo4j 图无处不在

网络安全 neo4j 图数据库 知识图谱 图算法

新时代冠军企业成功硬道理:人效管理与可组装式HCM SaaS

ToB行业头条

腾讯云原生容器服务发布三大新能力,创新自研技术助力企业降本增效

科技热闻

8年程序员年初被迫毕业,前后面试30家公司,如今终于上岸

Java永远的神

程序人生 后端 java程序员 java面试 面经分享

MyBatis resultMap元素的用途是什么呢?

@下一站

技术 mybatis java; 11月月更

企业数字营销和运营如何效果更好?瓴羊Quick BI成为了不错的选择

小偏执o

企业内部统一的移动平台,实现安全高效的业务移动化

BeeWorks

嘉为科技吴文豪:重塑运维系统,跨越烟囱式建设的陷阱

嘉为蓝鲸

运维 #WeOps

制造业的敏捷分析,还需要使用瓴羊Quick BI

对不起该用户已成仙‖

阿里云洛神云网络集中式网关丨技术解读与产品实践

云布道师

云网络

三年后端开发:拿下阿里/腾讯/美团等四个大厂的Offer后,总结如下

钟奕礼

Java Java 面试 程序员‘ java 编程

SpringMVC常用注解

@下一站

软件开发 程序 Java‘’ 11月月更

Neo4j CEO Emil Eifrem 解读图数据平台引领数据库未来十年的发展

Neo4j 图无处不在

neo4j 图数据库 知识图谱 图可视化引擎 图数据

嘉为科技宋蕴真:观测不止于监控,让运维不开盲盒

嘉为蓝鲸

运维 智能运维AIOps

对话Neo4j首席科学家Jim Webber:图数据库江湖5年后将尘埃落定

Neo4j 图无处不在

neo4j 图数据库 知识图谱 非关系型数据库 图技术

在结构效率不变情况下的降本增效

PMO实践

数字化转型 数字化 数智化 11月月更

跟误告警说再见,Smart Metrics 帮你用算法配告警

阿里巴巴云原生

阿里云 云原生 Grafana

精彩回顾 | 云原生系统软件的产业应用

BoCloud博云

云原生

焱融科技为国家重点实验室打造海量高性能存储

焱融科技

云计算 分布式系统 高性能 文件存储

瓴羊Quick BI在商业智能BI发展趋势方面如何?

对不起该用户已成仙‖

数字产业化的颠覆创新和生态打法

PMO实践

产业数字化 11月月更

天翼云混合云容灾技术解析

天翼云开发者社区

ASGD_文化 & 方法_Alex-zhai_InfoQ精选文章