QCon 演讲火热征集中,快来分享技术实践与洞见! 了解详情
写点什么

带你跨过神经网络训练常见的 37 个坑

  • 2017-09-17
  • 本文字数:4091 字

    阅读完需:约 13 分钟

神经网络已经持续训练了 12 个小时。它看起来很好:梯度在变化,损失也在下降。但是预测结果出来了:全部都是零值,全部都是背景,什么也检测不到。我质问我的计算机:“我做错了什么?”,它却无法回答。

如果你的模型正在输出垃圾(比如预测所有输出的平均值,或者它的精确度真的很低),那么你从哪里开始检查呢?

无法训练神经网络的原因有很多。在经历了许多次调试之后,我发现有一些检查是经常做的。这张列表汇总了我的经验以及最好的想法,希望对读者也有所帮助。

〇. 使用指南

许多事情都可能出错。但其中有些事情相比于其他方面更容易出问题。在出现问题时,我通常会做以下几件事情。

  1. 从已知的用于该数据类型的简单模型入手(例如 VGG 用于图像处理)。尽可能使用标准误差。
  2. 去掉所有的花哨的预处理程序,例如正则化和数据增强。
  3. 微调模型,仔细检查预处理,应该和原始模型的训练设置保持一致。
  4. 验证输入数据的正确性。
  5. 从较小的数据集开始(2-20 个样本)。在小数据集上过拟合之后再增加数据量。
  6. 慢慢加入之前忽略的项:增强或正则化、自定义损失函数,以及尝试更多的复杂模型。

如果上面的步骤还不能解决,可以开始一项一项的按以下列表进行检查。

Ⅰ. 数据集问题

1. 检查你的输入数据

检查馈送到网络的输入数据是否正确。例如,我不止一次混淆了图像的宽度和高度。有时,我错误地让输入数据全部为零,或者一遍遍地使用同一批数据。所以要打印或显示一些批次的输入和目标输出,并确保它们是正确的。

2. 尝试随机输入

尝试向网络传入随机数而不是真实数据,看看错误的产生方式是否相同。如果是,说明在某些时候你的网络把数据转化为了垃圾。试着逐层调试,并查看出错的地方。

3. 检查数据加载器

你的数据也许很好,但是把输入数据读取到网络的代码可能有问题,所以我们应该在进行其他操作之前打印出第一层的输入并进行检查。

4. 确保输入与输出相关联

检查少许输入样本是否有正确的标签,确保打乱输入样本同样也要打乱输出标签。

5. 输入与输出之间的关系是否太随机

相较于随机的部分(可以认为股票价格也是这种情况),输入与输出之间的非随机部分也许占得比重太小。也就是说输入与输出的关联度太低。没有统一的方法来检测它,因为这取决于数据的性质。

6. 数据集中是否有太多的噪声

我曾经遇到过这种情况,当我从一个食品网站抓取一个图像数据集时,错误标签太多以至于网络无法学习。手动检查一些输入样本并查看标签是否大致正确。例如这篇文章,由于在MNIST 数据集中使用了50% 损坏的标签,只得到了50% 的准确率。

7. 打乱数据集

如果你的数据集没有被随机打乱,并且有特定的序列(按标签排序),这可能给学习带来不利影响。打乱数据集可以避免这一问题。要确保输入和标签都被重新排列。

8. 减少类别失衡

是不是对于一张类别 B 的图像,有 1000 张类别 A 图像?如果是这种情况,那么你也许需要平衡损失函数或者尝试其他解决类别失衡的方法

9. 你有足够的训练实例吗?

如果你从头开始训练一个网络(不是调试),你很可能需要大量数据。对于图像分类,每个类别需要 1000 张图像甚至更多。

10. 确保一批数据不是单一标签

这可能发生在排好顺序的数据集中(即前 10000 个样本属于同一个分类)。可通过打乱数据集轻松修复这个问题。

11. 缩减训练批次大小

这篇文章指出巨大的批次会降低模型的泛化能力。

补充. 使用标准数据集(例如MNIST,cifar10)

测试新的网络结构,或者写了一段新代码时,首先要使用标准数据集,而不是你自己的数据。这是因为在这些数据集上已经有了许多参考结果,他们被证明是“可解的”。不会出现标签噪音、训练/ 测试分布差距、数据集太难等问题。

Ⅱ. 数据归一化/ 增强

12. 归一化特征

你的输入已经归一化到零均值和单位方差了吗?

13. 你是否应用了过量的数据增强?

数据增强有正则化效果。过量的数据增强,加上其它形式的正则化(权重 L2,dropout 操作,等等)可能会导致网络欠拟合。

14. 检查预训练模型的预处理过程

如果你正在使用一个已经预训练过的模型,确保你现在正在使用的归一化和预处理与之前训练模型的设置相同。例如,一个图像的像素是在 [0, 1],[-1, 1] 或 [0, 255] 的范围内吗?

15. 检查训练、验证、测试集的预处理

CS231n 指出了一个常见的陷阱:“任何预处理数据(例如数据均值)必须只在训练数据上进行计算,然后再应用到验证、测试数据中。例如,计算均值,然后在整个数据集的每个图像中都减去它,再把数据分发进训练、验证、测试集中,这是一个典型的错误。”

此外,要在每一个样本或批次(batch)中检查是否存在不同的预处理。

Ⅲ. 实现问题

16. 试着解决某一问题的更简单的版本

这将会有助于找到问题的根源究竟在哪里。例如,如果目标输出是一个物体类别和坐标,那就试着把预测结果仅限制在物体类别当中。

17. “碰巧”寻找正确的损失

还是来源于 CS231n 的技巧:用小参数进行初始化,不使用正则化。例如,如果我们有 10 个类别,“碰巧”就意味着我们将会在 10% 的时间里得到正确类别,Softmax 损失是正确类别的负 log 概率: -ln(0.1) = 2.302。然后,试着增加正则化的强度,这样应该会增加损失。

18. 检查你的损失函数

如果你实现的是你自己的损失函数,那么就要检查错误,并且添加单元测试。通常情况下,损失可能会有些不正确,并且略微损害网络的性能表现。

19. 核实损失输入

如果你正在使用的是框架提供的损失函数,那么要确保你传递给它的东西是它所期望的。例如,在 PyTorch 中,我会混淆 NLLLoss 和 CrossEntropyLoss,因为一个需要 softmax 输入,而另一个不需要。

20. 调整损失权重

如果你的损失由几个更小的损失函数组成,那么确保它们每一个的相应幅值都是正确的。这可能会涉及到测试损失权重的不同组合。

21. 监控其它指标

有时损失并不是衡量你的网络是否被正确训练的最佳预测器。如果可以的话,使用其它指标来帮助你,例如精度。

22. 测试任意的自定义层

你自己在网络中实现过任意层吗?检查并且复核以确保它们的运行符合你的预期。

23. 检查“冷冻”层或变量

检查你是否无意中阻止了一些层或变量的梯度更新,这些层或变量本来应该是可以学习的。

24. 扩大网络规模

可能你网络的表现力不足以捕捉目标函数。试着加入更多的层,或在全连层中增加更多的隐藏单元。

25. 检查隐维度误差

如果你的输入看上去像(k,H,W)= (64, 64, 64),那么很容易错过与错误维度相关的误差。给输入维度使用一些“奇怪”的数值(例如,每一个维度使用不同的质数),并且检查它们是如何通过网络传播的。

26. 探索梯度检查

如果你手动实现了梯度下降,梯度检查会确保你的反向传播能像预期一样工作。

更多信息: 1 2 3

Ⅳ. 训练问题

27. 一个真正小的数据集

过拟合数据的一个小子集,并确保它能正常工作。例如,仅使用 1 个 或 2 个实例训练,并查看你的网络是否能够区分它们。然后再训练每个分类的更多实例。

28. 检查权重初始化

如果不确定,请使用 Xavier He 初始化。同样,初始化也许会给你带来坏的局部最小值,因此尝试不同的初始化,看看是否有效。

29. 改变你的超参数

或许你正在使用一个很糟糕的超参数集。如果可行,尝试一下网格搜索

30. 减少正则化

太多的正则化会导致网络严重地欠拟合。减少正则化,比如 dropout、批归一、权重/偏差 L2 正则化等。在课程《编程人员的深度学习实战》中, Jeremy Howard 建议首先解决欠拟合问题。这意味着你充分地过拟合训练数据,并且只在那时处理过拟合。

31. 给它一些时间

也许你的网络需要更多的时间来训练,在它能做出有意义的预测之前。如果你的损失在稳步下降,那就再多训练一会儿。

32. 从训练模式转换为测试模式

一些框架有批归一化层、Dropout 层,而其他的层在训练和测试时表现并不同。转换到适当的模式有助于网络更好地预测。

33. 可视化训练

  • 监督每层的激活值、权重和更新。确保它们的大小匹配。例如,参数更新的大小幅度(权重和偏差)应该是 1-e3
  • 考虑可视化库,例如 Tensorboard Crayon 。紧要时你也可以打印权重、偏差或激活值。
  • 寻找平均值远大于 0 的层激活。尝试批归一化层或者 ELU 单元。
  • Deeplearning4j 指出了权重和偏差柱状图的期望值应该是什么样的: 对于权重,一段时间之后这些柱状图应该有一个近似高斯的(正态)分布。对于偏差,这些柱状图通常会从 0 开始,并经常以近似高斯(LSTM 是例外情况)结束。留意那些向正无穷或负无穷发散的参数。留意那些变得很大的偏差。这有可能发生在分类网络的输出层,如果类别的分布不均匀。
  • 检查层更新,它们应该呈高斯分布。

34. 尝试不同的优化器

优化器的选择不应当妨碍网络的训练,除非你选择了特别糟糕的超参数。但是,选择一个合适的优化器非常有助于在最短的时间内获得最多的训练结果。描述算法的论文应该指定了优化器,如果没有,我倾向于选择 Adam 或者带有动量的朴素 SGD。

关于梯度下降的优化器可以参考 Sebastian Ruder 的博文

35. 梯度爆炸、梯度消失

  • 检查隐藏层的更新情况,过大的值说明可能出现了梯度爆炸。这时,梯度截断(Gradient clipping)可能会有所帮助。
  • 检查隐藏层的激活值。 Deeplearning4j 中有一个很好的指导方针:“一个好的激活值标准差大约在 0.5 到 2.0 之间。明显超过这一范围可能就代表着激活值消失或爆炸。”

36. 增加、减少学习速率

低学习速率将会导致你的模型收敛很慢。高学习速率将会在开始阶段减少你的损失,但是可能会导致你很难找到一个好的解决方案。

试着把你当前的学习速率乘以 0.1 或 10 然后进行循环。

37. 克服 NaN

据我所知,在训练 RNNs 时得到 NaN(Non-a-Number,非数)是一个很大的问题。一些解决它的方法:

  • 减小学习速率,尤其是如果你在前 100 次迭代中就得到了 NaN。
  • NaNs 的出现可能是由于用零作了除数,或用零或负数作了自然对数。
  • Russell Stewart 在《如何处理 NaN》中分享了很多心得。
  • 尝试逐层评估你的网络,这样就会看见 NaN 到底出现在了哪里。

关于作者:Slav Ivanov 是保加利亚索菲亚的企业家和 ML 实践者。博客主页

查看英文原文: 37 Reasons why your Neural Network is not working


感谢薛命灯对本文的审校。

给InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 editors@cn.infoq.com 。也欢迎大家通过新浪微博( @InfoQ @丁晓昀),微信(微信号: InfoQChina )关注我们。

2017-09-17 17:4018107
用户头像

发布了 52 篇内容, 共 30.4 次阅读, 收获喜欢 73 次。

关注

评论

发布
暂无评论
发现更多内容

用这4招优雅的实现Spring Boot 异步线程间数据传递

小小怪下士

Java spring 程序员 springboot

Redis高级数据结构Stream和HyperLogLog

做梦都在改BUG

Java redis stream HyperLogLog

5 步带你入门 GaussDB (DWS) 的 GDS 导入导出

华为云开发者联盟

数据库 华为云 企业号 2 月 PK 榜 华为云开发者联盟

宽表为什么横行?

王磊

OKR之剑·实战篇06:OKR致胜法宝-氛围&业绩双轮驱动(下)

vivo互联网技术

团队管理 OKR

实现一个简单的Database10(译文)

GreatSQL

sqlite myslq greatsql greatsql社区

智能汽车商业化、产业化演进及投资机会分析

不脱发的程序猿

汽车电子 智能汽车商业化 汽车行业投资机会分析

入门数据分析师的最强秘籍,都在这4本书里!

博文视点Broadview

飞桨框架v2.4 API新升级!全面支持稀疏计算、图学习、语音处理等任务

飞桨PaddlePaddle

paddle API 飞桨

面试官:你来谈一下Synchronized-轻量级锁

做梦都在改BUG

Java synchronized 轻量级锁

进击中的 Zebec 生态,Web2 与 Web3 世界的连接器

西柚子

Java Map操作解锁新姿势

派大星

2023-02-14:魔物了占领若干据点,这些据点被若干条道路相连接, roads[i] = [x, y] 表示编号 x、y 的两个据点通过一条道路连接。 现在勇者要将按照以下原则将这些据点逐一夺回:

福大大架构师每日一题

算法 rust 福大大

ChatGPT入门案例|商务智能对话客服(一)| 社区征文

TiAmo

AI ChatGPT

基于文心大模型套件ERNIEKit实现文本匹配算法,模块化方便应用落地

汀丶人工智能

自然语言处理 nlp 2月月更 2月日更 文本匹配算法

明晚 8 点直播!OpenCloudOS 中的海光国密算法分析

OpenCloudOS

Linux

模块六作业

程序员小张

「架构实战营」

Linux安装ElasticSearch

Geek_7ubdnf

Java elasticsearch

中国工商银行签约易观千帆,夯实数字基石,助力用户价值增长

易观分析

金融 银行

软件测试/测试开发 | Selenium多浏览器处理

测试人

软件测试 自动化测试 测试开发 selenium web测试

OneCode开源低代码引擎白皮书

codebee

低代码 开发工具 低代码平台 java UI

微服务 SpringBoot 整合 Redis GEO 实现附近商户功能

做梦都在改BUG

Java redis 微服务 Spring Boot

分布式事务解决方案

Java 分布式事务 事务

记一次SpringBoot启动优化实践

做梦都在改BUG

Java spring Spring Boot

Three.js 进阶之旅:物理效果-碰撞和声音 💥

dragonir

CSS JavaScript html 前端 three.js

从实战出发,聊聊缓存数据库一致性

做梦都在改BUG

Java 数据库 缓存 一致性

ArkUI新能力,助力应用开发更便捷

HarmonyOS开发者

HarmonyOS

怎样快速地迁移 MySQL 中的数据?

做梦都在改BUG

Java MySQL 数据库

面试官:分库分表,真的有必要吗?

做梦都在改BUG

Java 分库分表

软件测试/测试开发 | web自动化测试-执行 JavaScript 脚本

测试人

软件测试 自动化测试 测试开发 Web自动化测试 selenium

面试官:如果 MySQL 数据库中的数据丢失,有哪些补救的办法呢?

做梦都在改BUG

Java MySQL 数据库

带你跨过神经网络训练常见的37个坑_语言 & 开发_Slav Ivanov_InfoQ精选文章