时隔16年Jeff Barr重返10.23-25 QCon上海站,带你看透AI如何重塑软件开发! 了解详情
写点什么

带你跨过神经网络训练常见的 37 个坑

  • 2017-09-17
  • 本文字数:4091 字

    阅读完需:约 13 分钟

神经网络已经持续训练了 12 个小时。它看起来很好:梯度在变化,损失也在下降。但是预测结果出来了:全部都是零值,全部都是背景,什么也检测不到。我质问我的计算机:“我做错了什么?”,它却无法回答。

如果你的模型正在输出垃圾(比如预测所有输出的平均值,或者它的精确度真的很低),那么你从哪里开始检查呢?

无法训练神经网络的原因有很多。在经历了许多次调试之后,我发现有一些检查是经常做的。这张列表汇总了我的经验以及最好的想法,希望对读者也有所帮助。

〇. 使用指南

许多事情都可能出错。但其中有些事情相比于其他方面更容易出问题。在出现问题时,我通常会做以下几件事情。

  1. 从已知的用于该数据类型的简单模型入手(例如 VGG 用于图像处理)。尽可能使用标准误差。
  2. 去掉所有的花哨的预处理程序,例如正则化和数据增强。
  3. 微调模型,仔细检查预处理,应该和原始模型的训练设置保持一致。
  4. 验证输入数据的正确性。
  5. 从较小的数据集开始(2-20 个样本)。在小数据集上过拟合之后再增加数据量。
  6. 慢慢加入之前忽略的项:增强或正则化、自定义损失函数,以及尝试更多的复杂模型。

如果上面的步骤还不能解决,可以开始一项一项的按以下列表进行检查。

Ⅰ. 数据集问题

1. 检查你的输入数据

检查馈送到网络的输入数据是否正确。例如,我不止一次混淆了图像的宽度和高度。有时,我错误地让输入数据全部为零,或者一遍遍地使用同一批数据。所以要打印或显示一些批次的输入和目标输出,并确保它们是正确的。

2. 尝试随机输入

尝试向网络传入随机数而不是真实数据,看看错误的产生方式是否相同。如果是,说明在某些时候你的网络把数据转化为了垃圾。试着逐层调试,并查看出错的地方。

3. 检查数据加载器

你的数据也许很好,但是把输入数据读取到网络的代码可能有问题,所以我们应该在进行其他操作之前打印出第一层的输入并进行检查。

4. 确保输入与输出相关联

检查少许输入样本是否有正确的标签,确保打乱输入样本同样也要打乱输出标签。

5. 输入与输出之间的关系是否太随机

相较于随机的部分(可以认为股票价格也是这种情况),输入与输出之间的非随机部分也许占得比重太小。也就是说输入与输出的关联度太低。没有统一的方法来检测它,因为这取决于数据的性质。

6. 数据集中是否有太多的噪声

我曾经遇到过这种情况,当我从一个食品网站抓取一个图像数据集时,错误标签太多以至于网络无法学习。手动检查一些输入样本并查看标签是否大致正确。例如这篇文章,由于在MNIST 数据集中使用了50% 损坏的标签,只得到了50% 的准确率。

7. 打乱数据集

如果你的数据集没有被随机打乱,并且有特定的序列(按标签排序),这可能给学习带来不利影响。打乱数据集可以避免这一问题。要确保输入和标签都被重新排列。

8. 减少类别失衡

是不是对于一张类别 B 的图像,有 1000 张类别 A 图像?如果是这种情况,那么你也许需要平衡损失函数或者尝试其他解决类别失衡的方法

9. 你有足够的训练实例吗?

如果你从头开始训练一个网络(不是调试),你很可能需要大量数据。对于图像分类,每个类别需要 1000 张图像甚至更多。

10. 确保一批数据不是单一标签

这可能发生在排好顺序的数据集中(即前 10000 个样本属于同一个分类)。可通过打乱数据集轻松修复这个问题。

11. 缩减训练批次大小

这篇文章指出巨大的批次会降低模型的泛化能力。

补充. 使用标准数据集(例如MNIST,cifar10)

测试新的网络结构,或者写了一段新代码时,首先要使用标准数据集,而不是你自己的数据。这是因为在这些数据集上已经有了许多参考结果,他们被证明是“可解的”。不会出现标签噪音、训练/ 测试分布差距、数据集太难等问题。

Ⅱ. 数据归一化/ 增强

12. 归一化特征

你的输入已经归一化到零均值和单位方差了吗?

13. 你是否应用了过量的数据增强?

数据增强有正则化效果。过量的数据增强,加上其它形式的正则化(权重 L2,dropout 操作,等等)可能会导致网络欠拟合。

14. 检查预训练模型的预处理过程

如果你正在使用一个已经预训练过的模型,确保你现在正在使用的归一化和预处理与之前训练模型的设置相同。例如,一个图像的像素是在 [0, 1],[-1, 1] 或 [0, 255] 的范围内吗?

15. 检查训练、验证、测试集的预处理

CS231n 指出了一个常见的陷阱:“任何预处理数据(例如数据均值)必须只在训练数据上进行计算,然后再应用到验证、测试数据中。例如,计算均值,然后在整个数据集的每个图像中都减去它,再把数据分发进训练、验证、测试集中,这是一个典型的错误。”

此外,要在每一个样本或批次(batch)中检查是否存在不同的预处理。

Ⅲ. 实现问题

16. 试着解决某一问题的更简单的版本

这将会有助于找到问题的根源究竟在哪里。例如,如果目标输出是一个物体类别和坐标,那就试着把预测结果仅限制在物体类别当中。

17. “碰巧”寻找正确的损失

还是来源于 CS231n 的技巧:用小参数进行初始化,不使用正则化。例如,如果我们有 10 个类别,“碰巧”就意味着我们将会在 10% 的时间里得到正确类别,Softmax 损失是正确类别的负 log 概率: -ln(0.1) = 2.302。然后,试着增加正则化的强度,这样应该会增加损失。

18. 检查你的损失函数

如果你实现的是你自己的损失函数,那么就要检查错误,并且添加单元测试。通常情况下,损失可能会有些不正确,并且略微损害网络的性能表现。

19. 核实损失输入

如果你正在使用的是框架提供的损失函数,那么要确保你传递给它的东西是它所期望的。例如,在 PyTorch 中,我会混淆 NLLLoss 和 CrossEntropyLoss,因为一个需要 softmax 输入,而另一个不需要。

20. 调整损失权重

如果你的损失由几个更小的损失函数组成,那么确保它们每一个的相应幅值都是正确的。这可能会涉及到测试损失权重的不同组合。

21. 监控其它指标

有时损失并不是衡量你的网络是否被正确训练的最佳预测器。如果可以的话,使用其它指标来帮助你,例如精度。

22. 测试任意的自定义层

你自己在网络中实现过任意层吗?检查并且复核以确保它们的运行符合你的预期。

23. 检查“冷冻”层或变量

检查你是否无意中阻止了一些层或变量的梯度更新,这些层或变量本来应该是可以学习的。

24. 扩大网络规模

可能你网络的表现力不足以捕捉目标函数。试着加入更多的层,或在全连层中增加更多的隐藏单元。

25. 检查隐维度误差

如果你的输入看上去像(k,H,W)= (64, 64, 64),那么很容易错过与错误维度相关的误差。给输入维度使用一些“奇怪”的数值(例如,每一个维度使用不同的质数),并且检查它们是如何通过网络传播的。

26. 探索梯度检查

如果你手动实现了梯度下降,梯度检查会确保你的反向传播能像预期一样工作。

更多信息: 1 2 3

Ⅳ. 训练问题

27. 一个真正小的数据集

过拟合数据的一个小子集,并确保它能正常工作。例如,仅使用 1 个 或 2 个实例训练,并查看你的网络是否能够区分它们。然后再训练每个分类的更多实例。

28. 检查权重初始化

如果不确定,请使用 Xavier He 初始化。同样,初始化也许会给你带来坏的局部最小值,因此尝试不同的初始化,看看是否有效。

29. 改变你的超参数

或许你正在使用一个很糟糕的超参数集。如果可行,尝试一下网格搜索

30. 减少正则化

太多的正则化会导致网络严重地欠拟合。减少正则化,比如 dropout、批归一、权重/偏差 L2 正则化等。在课程《编程人员的深度学习实战》中, Jeremy Howard 建议首先解决欠拟合问题。这意味着你充分地过拟合训练数据,并且只在那时处理过拟合。

31. 给它一些时间

也许你的网络需要更多的时间来训练,在它能做出有意义的预测之前。如果你的损失在稳步下降,那就再多训练一会儿。

32. 从训练模式转换为测试模式

一些框架有批归一化层、Dropout 层,而其他的层在训练和测试时表现并不同。转换到适当的模式有助于网络更好地预测。

33. 可视化训练

  • 监督每层的激活值、权重和更新。确保它们的大小匹配。例如,参数更新的大小幅度(权重和偏差)应该是 1-e3
  • 考虑可视化库,例如 Tensorboard Crayon 。紧要时你也可以打印权重、偏差或激活值。
  • 寻找平均值远大于 0 的层激活。尝试批归一化层或者 ELU 单元。
  • Deeplearning4j 指出了权重和偏差柱状图的期望值应该是什么样的: 对于权重,一段时间之后这些柱状图应该有一个近似高斯的(正态)分布。对于偏差,这些柱状图通常会从 0 开始,并经常以近似高斯(LSTM 是例外情况)结束。留意那些向正无穷或负无穷发散的参数。留意那些变得很大的偏差。这有可能发生在分类网络的输出层,如果类别的分布不均匀。
  • 检查层更新,它们应该呈高斯分布。

34. 尝试不同的优化器

优化器的选择不应当妨碍网络的训练,除非你选择了特别糟糕的超参数。但是,选择一个合适的优化器非常有助于在最短的时间内获得最多的训练结果。描述算法的论文应该指定了优化器,如果没有,我倾向于选择 Adam 或者带有动量的朴素 SGD。

关于梯度下降的优化器可以参考 Sebastian Ruder 的博文

35. 梯度爆炸、梯度消失

  • 检查隐藏层的更新情况,过大的值说明可能出现了梯度爆炸。这时,梯度截断(Gradient clipping)可能会有所帮助。
  • 检查隐藏层的激活值。 Deeplearning4j 中有一个很好的指导方针:“一个好的激活值标准差大约在 0.5 到 2.0 之间。明显超过这一范围可能就代表着激活值消失或爆炸。”

36. 增加、减少学习速率

低学习速率将会导致你的模型收敛很慢。高学习速率将会在开始阶段减少你的损失,但是可能会导致你很难找到一个好的解决方案。

试着把你当前的学习速率乘以 0.1 或 10 然后进行循环。

37. 克服 NaN

据我所知,在训练 RNNs 时得到 NaN(Non-a-Number,非数)是一个很大的问题。一些解决它的方法:

  • 减小学习速率,尤其是如果你在前 100 次迭代中就得到了 NaN。
  • NaNs 的出现可能是由于用零作了除数,或用零或负数作了自然对数。
  • Russell Stewart 在《如何处理 NaN》中分享了很多心得。
  • 尝试逐层评估你的网络,这样就会看见 NaN 到底出现在了哪里。

关于作者:Slav Ivanov 是保加利亚索菲亚的企业家和 ML 实践者。博客主页

查看英文原文: 37 Reasons why your Neural Network is not working


感谢薛命灯对本文的审校。

给InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 editors@cn.infoq.com 。也欢迎大家通过新浪微博( @InfoQ @丁晓昀),微信(微信号: InfoQChina )关注我们。

2017-09-17 17:4018759
用户头像

发布了 52 篇内容, 共 32.7 次阅读, 收获喜欢 73 次。

关注

评论

发布
暂无评论
发现更多内容

SpringBootApplication注解

梦倚栏杆

你愿意被管理么?

escray

学习 极客时间 朱赟的技术管理课 6月日更

Mybatis 二级缓存简单示例

Java mybatis

【21-1】21 连更第一篇

耳东@Erdong

6月日更

读深入ES6记[二]

蛋先生DX

ES6 6月日更

数字化转型背景下的测试转型

BY林子

敏捷测试 测试转型

MySQL基础之六:连接查询

打工人!

myslq 6月日更

浅谈Java中的TCP超时

Hoswey_洪树伟

Java、

高性能 JavaScriptの七 -- 编程实践小技巧

空城机

JavaScript 大前端 6月日更

缓存穿透、缓存雪崩、缓存击穿问题与优化方案

Skysper

Python——字符串转换与处理

在即

6月日更

做通才还是专才,你会怎么选?

架构精进之路

认知提升 6月日更

5分钟速读之Rust权威指南(十九)

wzx

rust 生命周期

阿里云边缘容器服务、申通 IoT 云边端架构入选 2021 云边协同发展阶段性领先成果

阿里巴巴云原生

云原生

【Vue2.x 源码学习】第八篇 - 数组的深层劫持

Brave

源码 vue2 6月日更

Packer 自动化镜像 Windows 安装过程

HoneyMoose

【Flutter 专题】109 图解自定义 ACERadio 单选框

阿策小和尚

Flutter 小菜 0 基础学习 Flutter Android 小菜鸟 6月日更

云原生推动全云开发与实践

阿里巴巴云原生

云原生

Kubernetes手记(5)- 配置清单使用

雪雷

k8s 6月日更

【布道API】浅谈API设计风格

devpoint

Rest API 6月日更

密码学系列之:生日攻击

程序那些事

加密解密 密码学 程序那些事

异构内存及其在机器学习系统的应用与优化

白玉兰开源

人工智能 机器学习 解决方案 第四范式 傲腾

spring-beans 注册 Beans(四)BeanDefinition

梦倚栏杆

Locust完成gRPC协议的性能测试

陈磊@Criss

Java--JVM运行流程

是老郭啊

Java JVM JVM原理

公司:离职就是一场危机管理

石云升

创业 职场经验 6月日更

操作系统内核是什么?Linux内核又是什么?读完这篇文章,我终于知道了

奔着腾讯去

c++ 操作系统 内存管理 Linux内核 进程管理

递归全排列问题(两种方法 Java实现)

若尘

数据结构 递归 6月日更

Python——输入输出:加减乘除四则运算的程序

在即

6月日更

这些书都学完,绝对是编程界的大佬

看山

Java 程序员 6月日更

当人工智能遇上视频直播——基于Agora Web SDK实现目标识别

dajyaretakuya

深度学习 音视频 WebRTC 声网 TensorFlow.js

带你跨过神经网络训练常见的37个坑_语言 & 开发_Slav Ivanov_InfoQ精选文章