写点什么

了解学习率及其如何提高深度学习的性能

  • 2018-03-18
  • 本文字数:4152 字

    阅读完需:约 14 分钟

这篇文章记录了我对以下主题的理解:

  • 什么是学习率?它的意义是什么?
  • 如何系统地达到良好的学习率?
  • 为什么要在训练期间改变学习率?
  • 在使用预训练的模型时,如何处理学习率?

这篇文章的大部分内容,是基于过去 fast.ai 的研究员所写的 [1]、[2]、[5] 和 [4] 的内容的简洁版,以一种快速的方式,获得其精髓。要了解更多细节,请仔细参阅文末所列的参考资料。

AI 前线: fast.ai 是一个致力于为所有人提供学习深度学习机会的平台。他们认为深度学习将是一个转型的技术,正在研究综合利用人类与计算机各自优势的混合“人机”解决方案,建立一个随时可用的应用程序和模型库,开发完整的教育框架,并为开发人员和用户编写能够快速上手和易于使用的软件。

首先,什么是学习率?

学习率(Learning Rate,LR。常用η表示。)是一个超参数,考虑到损失梯度,它控制着我们在多大程度上调整网络的权重。值越低,沿着向下的斜率就越慢。虽然这可能是一个好主意(使用低学习率),以确保我们不会错过任何局部最小值;但也有可能意味着我们将耗费很久的时间来收敛——特别是当我们陷入平坦区(plateau region)的时候。

AI 前线:如果使用很高的学习率,训练可能根本不会收敛,甚至会发散。权重的该变量可能会非常大,使得优化越过最小值,导致损失函数变得更糟。

下面的公式显示了这种关系:

复制代码
new_weight = existing_weight — learning_rate * gradient

复制代码
学习率很小(上图)与学习率很大(下图)的梯度下降。(来源:Coursera 机器学习课程,Andrew Ng)

通常,学习率是由用户随机配置的。在最好的情况下,用户可以利用过去的经验(或者其他类型的学习材料)来获得关于设置学习率最佳值的直觉。

因此,很难做到这一点。下图演示了配置学习率时可能会遇到的不同场景。

复制代码
不同学习率对收敛的影响:(图片来源:csn231n)

此外,学习率会影响模型收敛到局部最小值的速度(也就是达到最佳的精度)。因此,在正确的方向做出正确的选择,意味着我们只需更少的时间来训练模型。

训练时间越少,则花在 GPU 云计算上的钱就越少。:)

AI 前线:目前深度学习使用的都是一阶收敛算法:梯度下降法。不管有多少自适应的优化算法,本质上都是对梯度下降法的各种变形。故初始学习率对深层网络的收敛起着决定性的作用。

有没有更好的方法来确定学习率?

在“训练神经网络的循环学习率(Cyclical Learning Rates (CLR)for Training Neural Networks)”[4] 的第 3.3 节中。Leslie N. Smith 认为,通过在每次迭代中以非常低的学习率来增加(线性或指数)的方式训练模型,可以估计好的学习率。

AI 前线:周期性学习率(Cyclical Learning Rates,CLR),即学习率退火与热重启,最初由 Smith 于 2015 年首次提出。这是一种新的学习率方法,和以前的不同,或者固定(fixed)或者单调递减。要使用 CLR,需指定这三个参数:max_lr、base_lr、stepsize。

复制代码
学习率在每个小批量之后增加

如果我们在每次迭代中记录学习率和训练损失,然后据此绘出曲线图;我们将会看到,随着学习率的提高,将会有一个损失停止下降并开始增加的点。在实践中,理想情况下,学习率应该是在左图的最低点(如下图所示)。在该例中为 0.001 到 0.01 之间。

以上看起来很有用。我该如何开始使用它?

目前,它被作为 fast.ai 深度学习库的一个函数来支持。由 Jeremy Howard 开发,是用来抽象 PyTorch 深度学习框架的一种方式,就像 Keras 是对 TensorFlow 框架的抽象。

AI 前线: fast.ai 深度学习库是 fast.ai 基于 PyTorch 的基础上创建的自有软件库,并且他们认为,这将有助于更加清晰地展示深度学习的概念,同时有助于实现最佳编码。采用 Apache 2.0 许可证,可免费使用。

只需输入以下命令,就可以在训练神经网络之前找到最佳学习率。

复制代码
# learn is an instance of Learner class or one of derived classes like ConvLearner
learn.lr_find()
learn.sched.plot_lr()

精益求精

在这个关键时刻,我们已经讨论了学习率的全部内容和它的重要性,以及我们如何在开始训练模型时系统地达到最佳的使用价值。

接下来,我们将讨论如何使用学习率来提高模型的性能。

一般看法

通常情况下,当一个人设定学习率并训练模型时,只有等待学习率随着时间的推移而降低,并且模型最终会收敛。

然而,随着梯度逐渐趋于稳定时,训练损失也变得难以改善。在 [3] 中,Dauphin 等人认为,最大限度地减少损失的难度来自于鞍点,而非局部极小值。

AI 前线:鞍点是梯度接近于 0 的点,在误差曲面中既不是最大值也不是最小值的平滑曲面,则一般结果表现为性能比较差;如果该驻点是局部极小值,那么表现为性能较好,但不是全局最优值。

复制代码
误差曲面中的鞍点。鞍点是函数的导数变为零但点不是所有轴上的局部极值的点。(图片来源:safaribooksonline)

那么我们该如何摆脱呢?

有几个选项我们可以考虑。一般来说,从 [1] 引用一句:

……而不是使用一个固定值的学习率,并随着时间的推移而降低,如果训练不会改善我们的损失,我们将根据一些循环函数 _f_ 来改变每次迭代的学习率。每个周期的迭代次数都是固定的。这种方法让学习率在合理的边界值之间循环变化。这有助于解决问题,因为如果我们被困在鞍点上,提高学习率可以更快速地穿越鞍点。

在 [2] 中,Leslie 提出了一种“Triangular”的方法,在每次迭代之后,学习率都会重新开始。

复制代码
Leslie N. Smith 提出的“Triangular”和“Triangular2”循环学习率的方法。在左边的图上,minmax lr 保持不变。在右边,每个周期之后的差异减半。

另一种同样受欢迎的方法是由 Loshchilov 和 Hutter 提出的热重启的随机梯度下降法(Stochastic Gradient Descent with Warm Restarts,SGDR)[6]。这种方法主要利用余弦函数作为循环函数,并在每个周期的最大值重新开始学习率。“热重启”一词源于这样的一个事实:当学习率重新开始的时候,并不是从头开始,而是来自模型在上一步收敛的参数开始 [7]。

AI 前线:热重启后的初始高学习率用于基本上将参数从它们先前收敛的最小值弹射到不同的损失表面。根据经验,热重启的随机梯度下降法需要的时间比学习率退火要少 2~4 倍,且能达到相当或更好的性能。

虽然有这种变化,下面的图表展示了它的一个实现,其中每个周期都被设置为同一时间周期。

复制代码
SGDR 图,学习率与迭代。

因此,我们现在有一种减少训练时间的方法,基本上就是周期性地在“山脉”周围跳跃(下图)。

复制代码
比较固定学习率和循环学习率(图片来源:ruder.io

除了节省时间外,研究还表明,使用这些方法往往可以提高分类准确性,而无需进行调优,而且可以在更少的迭代次数内完成。


迁移学习(Transfer Learning)中的学习率

在 fast.ai 课程中,在解决 AI 问题时,非常重视利用预先训练的模型。例如,在解决图像分类问题时,教授学生如何使用预先训练好的模型,如 VGG 或 Resnet50,并将其连接到想要预测的任何图像数据集。

总结如何在 fast.ai 中完成模型构建(注意该程序不要与 fast.ai 深度学习库混淆),下面是我们通常采取的几个步骤 [8]:

1. 启用数据增强,precompute=True。
2. 使用 lr_find()查找最高的学习率,在此情况下,损失仍在明显改善。
3. 训练最后一层从预计算激活 1~2 个轮数。
4. 在 cycle_len=1 的情况下训练最后一层数据增加(即 precompute=False)2~3 个轮数。
5. 解除所有层的冻结。
6. 将较早的层设置为比下一个较高层低 3~10 倍的学习率。
7. 再次使用lr_find()
8. 使用 cycle_mult=2 训练完整网络,直到过度拟合。

从上面的步骤中,我们注意到第 2 步、第 5 步和第 7 步关注了学习率。在这篇文章的前半部分,我们已经基本讨论了涵盖了上述步骤中的第 2 项——我们在这里讨论了如何在训练模型之前得出最佳学习率。

AI 前线:轮数,epoch,即对所有训练数据的一轮遍历。

在接下来的部分中,我们通过使用 SGDR 来了解如何通过重新开始学习速率来减少训练时间和提高准确性,以避免梯度接近于 0 的区域。

在最后一节中,我们将重点讨论差分学习,以及它是如何被用来在训练模型与预先训练的模型相结合时确定学习率的。

什么是差分学习?

这是一种在训练期间为网络中的不同层设置不同的学习率的方法。这与人们通常如何配置学习率相反,即在训练期间在整个网络中使用相同的速率。

这是我为什么喜欢 Twitter 的原因之一——可以直接从作者本人得到答案。

在写这篇文章的时候,Jeremy 和 Sebastian Ruder 发表了一篇论文,深入探讨了这个话题。所以我估计差分学习率现在有一个新的名字:判别式微调(discriminative fine-tuning)。 :)

AI 前线:判别式微调对较底层进行微调以调到一个相较于较高层较低的程度,从而保留通过语言建模所获得的的知识。它可以避免微调过程中产生严重的遗忘。

为了更清楚地说明这个概念,我们可以参考下图,其中一个预训练模型被分成 3 个组,每个组都配置了一个递增的学习率值。

复制代码
差分学习率的 CNN 样本。图片来自 [3]

这种配置方法背后的直觉是,最初的几层通常包含数据的非常细粒度的细节,如线条和边缘——我们通常不希望改变太多,并且保留它的信息。因此,没有太多的需要去大量改变它们的权重。

相比之下,在后面的层中,比如上面绿色的层——我们可以获得眼球或嘴巴或鼻子等数据的详细特征;我们可能不一定要保留它们。

与其他微调方法相比,它表现如何?

在 [9] 中,有人认为,对整个模型进行微调的代价太大,因为有些模型可能有 100 多个层。因此,人们通常做的是一次对模型进行微调。

但是,这就引入了顺序的要求,妨碍了并行性,并且需要多次通过数据集,从而导致对小数据集的过度拟合。

也已经证明 [9] 中引入的方法能够在不同的 NLP 分类任务中提高精度和降低错误率(如下图所示):

复制代码
取自 [9] 的结果

参考资料

[0] Understanding Learning Rates and How It Improves Performance in Deep Learning

[1] Improving the way we work with learning rate. [2] The Cyclical Learning Rate technique. [3] Transfer Learning using differential learning rates. [4] Leslie N. Smith. Cyclical Learning Rates for Training Neural Networks. [5] Estimating an Optimal Learning Rate for a Deep Neural Network [6] Stochastic Gradient Descent with Warm Restarts [7] Optimization for Deep Learning Highlights in 2017 [8] Lesson 1 Notebook, fast.ai Part 1 V2 [9] Fine-tuned Language Models for Text Classification

感谢陈利鑫对本文的审校。

2018-03-18 18:085871
用户头像

发布了 375 篇内容, 共 190.7 次阅读, 收获喜欢 946 次。

关注

评论

发布
暂无评论
发现更多内容

掌握 Playwright:元素操作技巧大揭秘

霍格沃兹测试开发学社

IM是什么意思?

BeeWorks

通义灵码企业版正式发布,满足企业私域知识检索、数据合规、统一管理等需求

阿里巴巴云原生

阿里云 云原生 通义灵码

Playwright安装与Python集成:探索跨浏览器测试的奇妙世界

霍格沃兹测试开发学社

通义灵码企业版正式发布,满足企业私域知识检索、数据合规、统一管理等需求

阿里云云效

阿里云 云原生 云效 通义灵码

playwright使用:启动浏览器与多种运行方式

霍格沃兹测试开发学社

Polygon市值机器人

开发丨飞机丨 @aivenli

神器!使用Python 轻松识别验证码

霍格沃兹测试开发学社

使用 Playwright 控制浏览器的启动、停止和等待

霍格沃兹测试开发学社

深入探究 Playwright:Frame 操作技巧

霍格沃兹测试开发学社

有了京东商品详情数据接口,数据采集UP,UP,UP

tbapi

京东商品详情数据接口

使用Lambda表达式和接口的简单Java 8 Predicate示例

码语者

数智制造:机器学习与人工智能的全方位渗透

不在线第一只蜗牛

人工智能 机器学习 数智制造

Pandabuy淘宝代购集运系统赢利点讲解

tbapi

淘宝代购系统 淘宝代购集运系统 Pandabuy 反向海淘

基于51单片机的车辆倒车雷达报警系统

芯动大师

系统 51单片机 倒车

以太ETH链市值机器人

开发丨飞机丨 @aivenli

AI 大模型应用开发实战营毕业总结

海神名

零代码平台助力中国石化江苏油田实现高效评价体系

明道云

使用 Playwright 进行元素定位

霍格沃兹测试开发学社

im(即时通讯)是什么?

BeeWorks

ETLCloud中如何执行Java Bean脚本

RestCloud

Java 脚本 ETL JavaBean

如何快速上手 AI 大模型应用开发?天翼云弹性云主机给你答案!

编程猫

10分钟了解Golang泛型

俞凡

golang

软件测试学习笔记丨黑盒测试-边界值

测试人

软件测试

谈谈分布式事务原理

快乐非自愿限量之名

分布式

IM 是什么?

BeeWorks

如何打破数据管理僵局,释放数据资产价值?[AMT企源案例]

AMT企源

数据库 数据资产 数据管理 主数据

WorkPlus im(即时通讯)集成平台助力政企数字化转型升级

BeeWorks

了解学习率及其如何提高深度学习的性能_语言 & 开发_刘志勇_InfoQ精选文章