飞天发布时刻:2024年 Forrester 公有云平台Wave™评估报告解读 了解详情
写点什么

带你跨过神经网络训练常见的 37 个坑

  • 2017-09-17
  • 本文字数:4091 字

    阅读完需:约 13 分钟

神经网络已经持续训练了 12 个小时。它看起来很好:梯度在变化,损失也在下降。但是预测结果出来了:全部都是零值,全部都是背景,什么也检测不到。我质问我的计算机:“我做错了什么?”,它却无法回答。

如果你的模型正在输出垃圾(比如预测所有输出的平均值,或者它的精确度真的很低),那么你从哪里开始检查呢?

无法训练神经网络的原因有很多。在经历了许多次调试之后,我发现有一些检查是经常做的。这张列表汇总了我的经验以及最好的想法,希望对读者也有所帮助。

〇. 使用指南

许多事情都可能出错。但其中有些事情相比于其他方面更容易出问题。在出现问题时,我通常会做以下几件事情。

  1. 从已知的用于该数据类型的简单模型入手(例如 VGG 用于图像处理)。尽可能使用标准误差。
  2. 去掉所有的花哨的预处理程序,例如正则化和数据增强。
  3. 微调模型,仔细检查预处理,应该和原始模型的训练设置保持一致。
  4. 验证输入数据的正确性。
  5. 从较小的数据集开始(2-20 个样本)。在小数据集上过拟合之后再增加数据量。
  6. 慢慢加入之前忽略的项:增强或正则化、自定义损失函数,以及尝试更多的复杂模型。

如果上面的步骤还不能解决,可以开始一项一项的按以下列表进行检查。

Ⅰ. 数据集问题

1. 检查你的输入数据

检查馈送到网络的输入数据是否正确。例如,我不止一次混淆了图像的宽度和高度。有时,我错误地让输入数据全部为零,或者一遍遍地使用同一批数据。所以要打印或显示一些批次的输入和目标输出,并确保它们是正确的。

2. 尝试随机输入

尝试向网络传入随机数而不是真实数据,看看错误的产生方式是否相同。如果是,说明在某些时候你的网络把数据转化为了垃圾。试着逐层调试,并查看出错的地方。

3. 检查数据加载器

你的数据也许很好,但是把输入数据读取到网络的代码可能有问题,所以我们应该在进行其他操作之前打印出第一层的输入并进行检查。

4. 确保输入与输出相关联

检查少许输入样本是否有正确的标签,确保打乱输入样本同样也要打乱输出标签。

5. 输入与输出之间的关系是否太随机

相较于随机的部分(可以认为股票价格也是这种情况),输入与输出之间的非随机部分也许占得比重太小。也就是说输入与输出的关联度太低。没有统一的方法来检测它,因为这取决于数据的性质。

6. 数据集中是否有太多的噪声

我曾经遇到过这种情况,当我从一个食品网站抓取一个图像数据集时,错误标签太多以至于网络无法学习。手动检查一些输入样本并查看标签是否大致正确。例如这篇文章,由于在MNIST 数据集中使用了50% 损坏的标签,只得到了50% 的准确率。

7. 打乱数据集

如果你的数据集没有被随机打乱,并且有特定的序列(按标签排序),这可能给学习带来不利影响。打乱数据集可以避免这一问题。要确保输入和标签都被重新排列。

8. 减少类别失衡

是不是对于一张类别 B 的图像,有 1000 张类别 A 图像?如果是这种情况,那么你也许需要平衡损失函数或者尝试其他解决类别失衡的方法

9. 你有足够的训练实例吗?

如果你从头开始训练一个网络(不是调试),你很可能需要大量数据。对于图像分类,每个类别需要 1000 张图像甚至更多。

10. 确保一批数据不是单一标签

这可能发生在排好顺序的数据集中(即前 10000 个样本属于同一个分类)。可通过打乱数据集轻松修复这个问题。

11. 缩减训练批次大小

这篇文章指出巨大的批次会降低模型的泛化能力。

补充. 使用标准数据集(例如MNIST,cifar10)

测试新的网络结构,或者写了一段新代码时,首先要使用标准数据集,而不是你自己的数据。这是因为在这些数据集上已经有了许多参考结果,他们被证明是“可解的”。不会出现标签噪音、训练/ 测试分布差距、数据集太难等问题。

Ⅱ. 数据归一化/ 增强

12. 归一化特征

你的输入已经归一化到零均值和单位方差了吗?

13. 你是否应用了过量的数据增强?

数据增强有正则化效果。过量的数据增强,加上其它形式的正则化(权重 L2,dropout 操作,等等)可能会导致网络欠拟合。

14. 检查预训练模型的预处理过程

如果你正在使用一个已经预训练过的模型,确保你现在正在使用的归一化和预处理与之前训练模型的设置相同。例如,一个图像的像素是在 [0, 1],[-1, 1] 或 [0, 255] 的范围内吗?

15. 检查训练、验证、测试集的预处理

CS231n 指出了一个常见的陷阱:“任何预处理数据(例如数据均值)必须只在训练数据上进行计算,然后再应用到验证、测试数据中。例如,计算均值,然后在整个数据集的每个图像中都减去它,再把数据分发进训练、验证、测试集中,这是一个典型的错误。”

此外,要在每一个样本或批次(batch)中检查是否存在不同的预处理。

Ⅲ. 实现问题

16. 试着解决某一问题的更简单的版本

这将会有助于找到问题的根源究竟在哪里。例如,如果目标输出是一个物体类别和坐标,那就试着把预测结果仅限制在物体类别当中。

17. “碰巧”寻找正确的损失

还是来源于 CS231n 的技巧:用小参数进行初始化,不使用正则化。例如,如果我们有 10 个类别,“碰巧”就意味着我们将会在 10% 的时间里得到正确类别,Softmax 损失是正确类别的负 log 概率: -ln(0.1) = 2.302。然后,试着增加正则化的强度,这样应该会增加损失。

18. 检查你的损失函数

如果你实现的是你自己的损失函数,那么就要检查错误,并且添加单元测试。通常情况下,损失可能会有些不正确,并且略微损害网络的性能表现。

19. 核实损失输入

如果你正在使用的是框架提供的损失函数,那么要确保你传递给它的东西是它所期望的。例如,在 PyTorch 中,我会混淆 NLLLoss 和 CrossEntropyLoss,因为一个需要 softmax 输入,而另一个不需要。

20. 调整损失权重

如果你的损失由几个更小的损失函数组成,那么确保它们每一个的相应幅值都是正确的。这可能会涉及到测试损失权重的不同组合。

21. 监控其它指标

有时损失并不是衡量你的网络是否被正确训练的最佳预测器。如果可以的话,使用其它指标来帮助你,例如精度。

22. 测试任意的自定义层

你自己在网络中实现过任意层吗?检查并且复核以确保它们的运行符合你的预期。

23. 检查“冷冻”层或变量

检查你是否无意中阻止了一些层或变量的梯度更新,这些层或变量本来应该是可以学习的。

24. 扩大网络规模

可能你网络的表现力不足以捕捉目标函数。试着加入更多的层,或在全连层中增加更多的隐藏单元。

25. 检查隐维度误差

如果你的输入看上去像(k,H,W)= (64, 64, 64),那么很容易错过与错误维度相关的误差。给输入维度使用一些“奇怪”的数值(例如,每一个维度使用不同的质数),并且检查它们是如何通过网络传播的。

26. 探索梯度检查

如果你手动实现了梯度下降,梯度检查会确保你的反向传播能像预期一样工作。

更多信息: 1 2 3

Ⅳ. 训练问题

27. 一个真正小的数据集

过拟合数据的一个小子集,并确保它能正常工作。例如,仅使用 1 个 或 2 个实例训练,并查看你的网络是否能够区分它们。然后再训练每个分类的更多实例。

28. 检查权重初始化

如果不确定,请使用 Xavier He 初始化。同样,初始化也许会给你带来坏的局部最小值,因此尝试不同的初始化,看看是否有效。

29. 改变你的超参数

或许你正在使用一个很糟糕的超参数集。如果可行,尝试一下网格搜索

30. 减少正则化

太多的正则化会导致网络严重地欠拟合。减少正则化,比如 dropout、批归一、权重/偏差 L2 正则化等。在课程《编程人员的深度学习实战》中, Jeremy Howard 建议首先解决欠拟合问题。这意味着你充分地过拟合训练数据,并且只在那时处理过拟合。

31. 给它一些时间

也许你的网络需要更多的时间来训练,在它能做出有意义的预测之前。如果你的损失在稳步下降,那就再多训练一会儿。

32. 从训练模式转换为测试模式

一些框架有批归一化层、Dropout 层,而其他的层在训练和测试时表现并不同。转换到适当的模式有助于网络更好地预测。

33. 可视化训练

  • 监督每层的激活值、权重和更新。确保它们的大小匹配。例如,参数更新的大小幅度(权重和偏差)应该是 1-e3
  • 考虑可视化库,例如 Tensorboard Crayon 。紧要时你也可以打印权重、偏差或激活值。
  • 寻找平均值远大于 0 的层激活。尝试批归一化层或者 ELU 单元。
  • Deeplearning4j 指出了权重和偏差柱状图的期望值应该是什么样的: 对于权重,一段时间之后这些柱状图应该有一个近似高斯的(正态)分布。对于偏差,这些柱状图通常会从 0 开始,并经常以近似高斯(LSTM 是例外情况)结束。留意那些向正无穷或负无穷发散的参数。留意那些变得很大的偏差。这有可能发生在分类网络的输出层,如果类别的分布不均匀。
  • 检查层更新,它们应该呈高斯分布。

34. 尝试不同的优化器

优化器的选择不应当妨碍网络的训练,除非你选择了特别糟糕的超参数。但是,选择一个合适的优化器非常有助于在最短的时间内获得最多的训练结果。描述算法的论文应该指定了优化器,如果没有,我倾向于选择 Adam 或者带有动量的朴素 SGD。

关于梯度下降的优化器可以参考 Sebastian Ruder 的博文

35. 梯度爆炸、梯度消失

  • 检查隐藏层的更新情况,过大的值说明可能出现了梯度爆炸。这时,梯度截断(Gradient clipping)可能会有所帮助。
  • 检查隐藏层的激活值。 Deeplearning4j 中有一个很好的指导方针:“一个好的激活值标准差大约在 0.5 到 2.0 之间。明显超过这一范围可能就代表着激活值消失或爆炸。”

36. 增加、减少学习速率

低学习速率将会导致你的模型收敛很慢。高学习速率将会在开始阶段减少你的损失,但是可能会导致你很难找到一个好的解决方案。

试着把你当前的学习速率乘以 0.1 或 10 然后进行循环。

37. 克服 NaN

据我所知,在训练 RNNs 时得到 NaN(Non-a-Number,非数)是一个很大的问题。一些解决它的方法:

  • 减小学习速率,尤其是如果你在前 100 次迭代中就得到了 NaN。
  • NaNs 的出现可能是由于用零作了除数,或用零或负数作了自然对数。
  • Russell Stewart 在《如何处理 NaN》中分享了很多心得。
  • 尝试逐层评估你的网络,这样就会看见 NaN 到底出现在了哪里。

关于作者:Slav Ivanov 是保加利亚索菲亚的企业家和 ML 实践者。博客主页

查看英文原文: 37 Reasons why your Neural Network is not working


感谢薛命灯对本文的审校。

给InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 editors@cn.infoq.com 。也欢迎大家通过新浪微博( @InfoQ @丁晓昀),微信(微信号: InfoQChina )关注我们。

2017-09-17 17:4018284
用户头像

发布了 52 篇内容, 共 30.9 次阅读, 收获喜欢 73 次。

关注

评论

发布
暂无评论
发现更多内容

比特币领涨,反转行情即将开启?市场双位数反弹与未来展望

区块链软件开发推广运营

dapp开发 区块链开发 链游开发 NFT开发 公链开发

Forrester Wave™报告:天翼云三项产品能力获评最高分!

天翼云开发者社区

云计算 公有云 云平台

【YashanDB数据库】Mybatis-plus分页框架识别不到Yashandb

YashanDB

yashandb 崖山数据库 崖山DB

AI自动化应用开发,让创意与效率并驾齐驱!

测吧(北京)科技有限公司

测试

京东小程序数据中心架构设计与最佳实践

京东科技开发者

实际上手体验maven面对冲突Jar包的加载规则

京东科技开发者

六个策略,打造网络安全宣传周峰值体验

我再BUG界嘎嘎乱杀

网络安全 信息安全 网络安全宣传周

如何成为网络安全架构师?

我再BUG界嘎嘎乱杀

黑客 网络安全 信息安全 架构师 网安

【YashanDB数据库】由于网络带宽不足导致的jdbc向yashandb插入数据慢

YashanDB

yashandb 崖山数据库 崖山DB

跨越边界:京东商品详情API的全球拓展之旅

代码忍者

拼多多API接口:通过商品ID获取拼多多商品详情数据接口

tbapi

拼多多商品详情接口 拼多多API 拼多多商品数据采集

全国高校软件测试开发教学师资培训会圆满落幕

测吧(北京)科技有限公司

测试

如何制作巡逻巡更二维码?扫码就能快速上报异常情况

草料二维码

设备巡检 草料二维码 二维码系统 巡逻巡更二维码 巡逻巡更

商品计划管理系统助力企业实现高效决策与资源配置

第七在线

AI入门之深度学习:基本概念篇

京东科技开发者

怎么用云手机进行TikTok矩阵运营

Ogcloud

云手机 海外云手机 tiktok云手机 云手机海外版 tiktok矩阵

如何构建高效的 CRUD 应用程序?

NocoBase

软件开发 crud crudapi

一招致胜!天翼云对象存储攻克数据存、管、用难题!

天翼云开发者社区

云计算 对象存储 云服务 天翼云

3 x 2 + 1 !安 全 能 力 权 威 认 可 !

天翼云开发者社区

云计算 安全 天翼云

AI 时代,网关更能打了?

阿里巴巴云原生

阿里云 云原生 网关

并发性能提升 4 倍!云帐房用 Serverless 轻松应对瞬时业务洪峰

阿里巴巴云原生

阿里云 Serverless 云原生

从理念到实践,解构HBlock降本增效黑科技!

天翼云开发者社区

数据库 云计算 存储 天翼云

【YashanDB数据库】PHP无法通过ODBC连接到数据库

YashanDB

yashandb 崖山数据库 崖山DB

苹果电脑防火墙Radio Silence for mac v3.2激活版 附安装教程

Rose

苹果电脑 mac防火墙 Radio Silence下载 Radio Silence破解版

课件ppt怎么做?3个在线网站轻松制作教学ppt!

职场工具箱

效率 职场 PPT 办公软件 AI生成PPT

云手机在海外社交媒体运营中的作用

Ogcloud

云手机 海外云手机 云手机海外版 海外社媒运营 海外社媒营销

最佳实践:解读GaussDB(DWS) 统计信息自动收集方案

不在线第一只蜗牛

Java 人工智能 GuassDB

带你跨过神经网络训练常见的37个坑_语言 & 开发_Slav Ivanov_InfoQ精选文章