写点什么

为什么卷积神经网络优于传统机器学习算法?

  • 2020-09-25
  • 本文字数:2944 字

    阅读完需:约 10 分钟

为什么卷积神经网络优于传统机器学习算法?

本文最初发表于 Towards Data Science 博客,经原作者 Rade Nježić 授权,InfoQ 中文站翻译并分享。


近十年来,随着深度学习的发现,图像分类领域经历了复兴。传统的机器学习方法已被更新的、更强大的深度学习算法所取代,例如卷积神经网络。然而,要真正理解并欣赏深度学习,我们必须知道为什么其他方法失败了,而深度学习却成功了。在本文中,我将试图通过对 Fashion MNIST 数据集应用不同的分类算法来回答其中一些问题。

数据集信息

Fashion MNIST 是由 Zalando Fashion 研究室于 2017 年 8 月推出的。随着 MNIST 变得过于容易和过度使用,Fashion MNIST 的目标是成为测试机器学习算法的新基准。MNIST 是由手写的数字组成的,而 Fashion MNIST 是由 10 个不同服装对象的图像组成的。每个图像都具有以下属性:


  • 它的大小为 28 × 28 像素。

  • 相应地旋转并以灰度表示,整数值范围为 0~255。

  • 空白空间用黑色表示,值为 0。


在数据集中,我们区分以下服装对象:


  • T 恤/上衣

  • 长裤

  • 套衫

  • 连衣裙

  • 大衣

  • 凉鞋

  • 衬衫

  • 运动鞋

  • 手提包

  • 短靴

探索性数据分析

由于数据集作为 Keras 库的一部分可用,并且图像已经经过处理,因此我们无需进行太多的预处理。我们所做的唯一更改是将图像从二维数组转换为一维数组,因为这样可以使它们更容易处理。


该数据集由 70000 幅图像组成,其中 60000 幅为训练集,10000 幅为测试集。与原始 MNIST 数据集一样,项目也是分布均匀的(每个训练集 6000 幅,测试集 1000 幅)。



不同服装项目的图片示例


然而,一张图片仍然有 784 个维度,因此我们转向主成分分析(Principal component analysis,PCA),看看哪些像素是最重要的。我们设定了累积方差的 80% 的传统基准,而该图告诉我们,只有大约 25 个主成分(占主成分总数的 3%)才有可能实现这一目标。然而,这并不奇怪,因为,我们可以在上面的图片中看到,在每幅图像中都有很多共享的未使用空间,并且不同类别的服装有不同的图像部分是黑色的。后者可能与以下事实有关:大约 70% 的累积方差仅由 8 个主成分解释。



可解释的累积方差百分比


我们将在逻辑回归、随机森林和支持向量机中应用主成分。


图像分类问题只是分类问题的一个小子集。最常用的图像分类方法是深度学习算法,卷积神经网络就是其中之一。其余采用的方法将是一小部分常用分类方法的集合。由于分类标签是均匀分布的,没有错误分类的惩罚,故我们将使用正确率来评估算法。

卷积神经网络

我们采用的第一种方法是卷积神经网络。由于图像是灰度图像,因此我们仅应用一个通道。我们选择了以下架构:


  • 两个卷积层,具有 32 和 64 个过滤器(Filter),3 × 3 内核大小,以及 ReLU 激活函数。

  • 选择轮询层对大小为 2 × 2 的块进行操作,并从中选择最大的元素。

  • 两组密集层,其中第一层选择 128 个特征,具有 ReLU 或 Softmax 激活函数。


这种架构并没有什么特别之处。事实上,它是我们可以用于卷积神经网络的最简单的架构之一。这向我们展示了这类方法的真正威力:使用基准结构可以获得很好的结果。


对于损失函数,我们选择了分类交叉熵(Categorical cross-entropy)。为避免出现过拟合,我们从训练集中选择了 9400 幅图像作为参数的验证集。我们使用了新颖的优化器 ADAM,它比标准的梯度下降法有所改进,并且对每个参数使用不同的学习率,批大小为 64。该模型训练了 50 个轮数。我们在下图中给出了正确率和损失值。




我们可以看到,该算法训练 15 个轮数后出现收敛,但并没有过度训练,所以我们对其进行了测试。得到的测试正确率为 89%,这是所有方法中得到的最好的结果!


在继续其他方法之前,让我们先解释一下卷积层是做什么的。直观的解释是,第一层捕获直线,第二层捕获曲线。在这两个层上,我们都应用了最大池化(Max pooling),即在内核中选择最大的值,将服装部分与空白空间分开。如此一来,我们就抓住了数据的代表性。在其他方面,神经网络自己进行特征选择。经过最后一个池化层,我们就得到了一个人工神经网络。因为我们要处理的是分类问题,因此最后一层使用 Softmax 激活来获得类别概率。当类别概率遵循一定的分布时,交叉熵表示与网络首选分布的距离。

多分类逻辑回归

由于像素值是分类变量,因此我们可以应用多分类逻辑回归(Multinomial Logistic Regression)。我们以一对余的方式应用它,训练 10 个二元逻辑回归分类器,我们将使用这些分类器来选择项目。为避免过度训练,我们使用了 L2 正则化。结果我们得到了 80% 的正确率,比卷积神经网络的正确率低 9%。但是,我们必须考虑到,该算法适用于灰度图像,这些图像是居中且正常旋转的,有很多空白的地方,因此它可能对更为复杂的图像不起作用。

近邻算法和质心算法

我们使用了两种不同的最近距离算法:


  • K-最近邻算法

  • 近质心近邻算法


近质心近邻算法找出每个类别的元素的平均值,并将测试元素分配给最近质心指定的类别。这两种算法都是相对于 L1 和 L2 距离实现的。K-最近邻算法的正确率为 85%,近质心近邻算法的正确率为 67%。这些结果是在 K=12 时得到的。K-最近邻算法的高正确率告诉我们,属于同一类别的图像往往占据相似的位置,也有类似的像素密度。虽然最近邻算法获得了很好的结果,但它们的表现仍然比卷积神经网络要差,因为它们不在每个特定特征邻近工作;而质心算法失败了,因为它们并不能区分相似的对象(例如套衫与 T 恤/上衣)。

随机森林

为了选择最佳的参数进行估计,我们使用平方根(Bagging)和特征的全部数量、Gini 和熵准则,以及树的最大深度为 5 和 6 进行网格搜索。网格搜索建议使用熵准则的特征平方根数(均适用于分类任务)。然而,得到的正确率仅为 77%,这意味着随机森林方法也不是什么特别好的方法。它失败的原因是主成分并不能代表图像可以具有的矩形分区,而随机森林就是在该矩形分区上操作的。同样的道理也适用于全尺寸图像,因为树会太深,可解释性将会失去。

支持向量机

我们使用径向核和多项式核的支持向量机。径向核的正确率为 77%,而多项式核的正确率仅为 46%。虽然图像分类并不是它们的强项,但对于其他二值分类任务仍然非常有用。它们最大的警告是,它们需要特征选择,而这会降低正确率;但如果没有特征选择,它们的计算成本可能会很高。此外,它们以一对余的方式应用多类分类,这使得更难有效地创建分离超平面(Separating hyperplane),因此在处理非二元分类任务时失去价值。

总结

我们在本文中,针对一个图像分类问题,采用了多种分类方法。我们已经解释了为什么卷积神经网络是我们可以采用的最好的方法,而其他方法为什么会失败。卷积神经网络是最为实用且通常最为正确的方法的一些原因如下:


  • 它们可以通过层传递学习,保存推理并在后续层上创建新的推理。

  • 在使用该算法之前,无需进行特征提取,而是在训练过程中完成的。

  • 它可以识别重要的特征。


但是,它们也有自己的警告。目前已知它们在旋转和缩放方式不同的图像上会失效,但这里的情况并非如此,因为数据已经经过预处理。而且,尽管其他方法在这个数据集上未能提供良好的结果,但它们仍然可以用于其他与图像处理相关的任务(如锐化、平滑等等)。


代码:


https://github.com/radenjezic153/Stat_ML/blob/master/project.ipynb


原文链接:


https://towardsdatascience.com/image-classification-with-fashion-mnist-why-convolutional-neural-networks-outperform-traditional-df531e0533c2


2020-09-25 08:102549
用户头像
刘燕 InfoQ高级技术编辑

发布了 1123 篇内容, 共 605.5 次阅读, 收获喜欢 1982 次。

关注

评论

发布
暂无评论
发现更多内容

如何在业务开发中使用适配器模式?

Karmada v1.5发布:多调度组助力成本优化

华为云开发者联盟

云原生 后端 华为云 华为云开发者联盟 企业号 4 月 PK 榜

再聊 MySQL 聚簇索引

江南一点雨

Java MySQL

尚能饭否|技术越来越新,我对老朋友jQuery还是一如既往热爱

浅羽技术

jquery 前端 Web 框架 三周年连更

通过小程序容器技术让App实现灰度发布

没有用户名丶

“分割一切”大模型SAM、超轻量PP-MobileSeg、工业质检工具、全景分割方案,PaddleSeg全新版本等你来体验!

飞桨PaddlePaddle

计算机视觉 飞桨 图像分割

DAYU200关闭自动息屏的几种方式

坚果

OpenHarmony 三周年连更

Spring Boot整合多数据源实践

Java Spring Boot

全栈开发实战|Spring Boot文件上传与下载

TiAmo

Spring Boot 三周年连更 Apache Commons 文件上传下载

DeepSpeed Chat: 一键式RLHF训练,让你的类ChatGPT千亿大模型提速省钱15倍

汀丶人工智能

人工智能 自然语言处理 深度学习 ChatGPT

Java线程中的wait、notify和notifyAll解析

共饮一杯无

Java 多线程 三周年连更

基于阿里云物联网平台设计的实时图传系统_采用MQTT协议传输图像

DS小龙哥

三周年连更

MySQL8.0 优化器介绍(三)

GreatSQL

MySQL greatsql greatsql社区

企业微信接入系列-自建应用

六月的雨在InfoQ

企业微信 应用配置 三周年连更 自建应用

软件架构生态化-多角色交付的探索实践

京东科技开发者

架构 架构师 交付能力 企业号 4 月 PK 榜

CentOS7 离线安装 Zabbix5.0

A-刘晨阳

Linux zabbix 三周年连更 离线安装

华为云网站安全解决方案:中小型企业的云上云下安全守护专家

YG科技

跨平台应用开发进阶(五十)uni-app ios web-view嵌套H5项目白屏问题分析及解决

No Silver Bullet

uni-app ios 跨平台应用开发 三周年连更 web-view

《社区人员管理》实战案例设计&个人案例分享

京东科技开发者

架构 测试 编码 在线设计平台 企业号 4 月 PK 榜

Intents ,快速完成任务的最强辅助

鼎道智联

AI

MobPush Android SDK厂商通道申请指南

MobTech袤博科技

实习生疑问:为什么要在需要排序的字段上加索引呢?

架构精进之路

MySQL 数据库 索引 三周年连更

《设计模式之禅》Proxy_Pattern--代理模式

浅辄

设计模式 代理模式 三周年连更

Java枚举和注解

timerring

Java 三周年连更

SAP Emarsys 的前后台技术栈

汪子熙

SaaS Cloud SAP 思爱普 三周年连更

一篇神文就把java多线程,锁,JMM,JUC和高并发设计模式讲明白了

Java 多线程 高并发

Springboot如何手动连接库并获取指定表结构|超级详细,建议收藏

bug菌

springboot 三周年连更

Typescript-类型检测和变量的定义

格斗家不爱在外太空沉思

typescript 三周年连更

华为云网站安全解决方案助力客户——构建风险全面可控的网站安全架构

YG科技

使用 docker manifest 构建跨平台镜像

江湖十年

Docker Desktop docker image docker build Docker 镜像

跨平台图像浏览器:XnViewMP 中文激活版

真大的脸盆

Mac Mac 软件 图像查看 图像浏览

为什么卷积神经网络优于传统机器学习算法?_AI&大模型_Rade Nježić_InfoQ精选文章