写点什么

MXNet API 入门 —第 5 篇

2017 年 7 月 19 日

第4 篇中,我们介绍了如何轻松地使用预训练版Inception v3 模型进行物体识别。本文将使用另外两个著名的卷积神经网络(CNN)模型(VGG19 和ResNet-152),并将其与Inception v3 的效果进行对比。

CNN 的架构(来源:Nvidia)

VGG16

诞生于 2014 年的 VGG16 是一种16 层模型(研究论文),该模型以低至 7.4%的物体识别错误率赢得了 2014 年度 ImageNet 挑战赛

ResNet-152

诞生于 2015 年的 ResNet-152 是一种152 层模型(研究论文),在物体识别方面,该模型以破纪录的 3.57%错误率赢得了 2015 年度 ImageNet 挑战赛。这一准确率甚至超过了普通的人工识别,人工识别的错误率通常为 5% 左右。

下载模型

首先再次访问 Model zoo 。与 Inception v3 类似,我们需要下载模型的定义和参数。这三个模型都针对相同的分类进行了训练,因此可以继续使用之前的 synset.txt 文件。

复制代码
$ wget http://data.dmlc.ml/models/imagenet/vgg/vgg16-symbol.json
$ wget http://data.dmlc.ml/models/imagenet/vgg/vgg16-0000.params
$ wget http://data.dmlc.ml/models/imagenet/resnet/152-layers/resnet-152-symbol.json
$ wget http://data.dmlc.ml/models/imagenet/resnet/152-layers/resnet-152-0000.params

加载模型

这三个模型都使用 ImageNet 数据集进行过训练,该数据集中图片通常为 224x224 像素大小。由于数据形态和分类完全相同,因此也可以继续重用之前的代码

这次只需要更改模型的名称 :) 我们可以为 loadModel() 和 init() 函数添加一个参数。

复制代码
def loadModel(modelname):
sym, arg_params, aux_params = mx.model.load_checkpoint(modelname, 0)
mod = mx.mod.Module(symbol=sym)
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
mod.set_params(arg_params, aux_params)
return mod
def init(modelname):
model = loadModel(modelname)
cats = loadCategories()
return model, cats

预测结果对比

我们可以通过同一批图片对三个模型的结果进行对比。

复制代码
*** VGG16
[(0.58786136, 'n03272010 electric guitar'), (0.29260877, 'n04296562 stage'),
(0.013744719, 'n04487394 trombone'), (0.013494448, 'n04141076 sax, saxophone'),
(0.00988709, 'n02231487 walking stick, walkingstick, stick insect')]

可能性最高的两个分类判断结果让人满意,但后续三个结果错得离谱。似乎竖立放置的麦克风支架干扰了模型的识别。

复制代码
*** ResNet-152
[(0.91063803, 'n04296562 stage'), (0.039011702, 'n03272010 electric guitar'),
(0.031426914, 'n03759954 microphone, mike'), (0.011822623,
'n04286575 spotlight, spot'), (0.0020199812, 'n02676566 acoustic guitar')]

排名第一的预测非常准确,但接下来的四个预测风马牛不相及。

复制代码
*** Inception v3
[(0.58039135, 'n03272010 electric guitar'), (0.27168664, 'n04296562 stage'),
(0.090769522, 'n04456115 torch'), (0.023762707, 'n04286575 spotlight, spot'),
(0.0081428187, 'n03250847 drumstick')]

前两个分类的结果与 VGG16 极为类似,另外三个结果良莠不齐。

再换张图片试试看。

复制代码
*** VGG16
[(0.96909302, 'n04536866 violin, fiddle'), (0.026661994, 'n02992211 cello, violoncello'),
(0.0017284016, 'n02879718 bow'), (0.00056815811, 'n04517823 vacuum, vacuum cleaner'),
(0.00024804732, 'n04090263 rifle')]
*** ResNet-152
[(0.96826887, 'n04536866 violin, fiddle'), (0.028052919, 'n02992211 cello, violoncello'),
(0.0008367821, 'n02676566 acoustic guitar'), (0.00070532493, 'n02787622 banjo'),
(0.00039021231, 'n02879718 bow')]
*** Inception v3
[(0.82023674, 'n04536866 violin, fiddle'), (0.15483995, 'n02992211 cello, violoncello'),
(0.0044540241, 'n02676566 acoustic guitar'), (0.0020963412, 'n02879718 bow'),
(0.0015099624, 'n03447721 gong, tam-tam')]

三个模型排名第一的预测分数都很高。不过也可以理解,毕竟小提琴的外形在神经网络看来还是很有特点的。

很明显,单凭少数例子还不能得出准确的结论。如果要挑选预训练模型,那么绝对要先准备好训练数据集,使用自己的数据进行测试并酌情决定!

技术指标对比

不同模型的详细评测结果可参阅现有研究论文,例如这一篇。对开发者来说,也许更需要考虑另外两个重要因素:

  • 模型需要多少内存
  • 模型的预测能有多快

为了回答第一个问题,可以通过参数文件的大小进行一个粗略的猜测:

  • VGG16:528MB(约 1.4 亿个参数)
  • ResNet-152:230MB(约 6000 万个参数)
  • Inception v3:43MB(约 2500 万个参数)

可见,目前的趋势是使用参数较少的更深度的网络 **。这样做有两方面收益:训练速度更快(因为网络需要学习的参数更少),并且可以 ** 降低内存使用 *。

第二个问题更复杂一些,并且取决于很多因素,例如批大小。所以我们对预测调用进行计时然后再次运行看看。

复制代码
t1 = time.time()
model.forward(Batch([array]))
t2 = time.time()
t = 1000*(t2-t1)
print("Predicted in %2.2f microseconds" % t)

结果就是这样(结果对多次调用取平均值获得)。

复制代码
*** VGG16
0.30 微秒完成预测
*** ResNet-152
0.90 微秒完成预测
*** Inception v3
0.40 微秒完成预测

总结来说(请自行套用标准免责声明):

  • 所有三个网络中(目前)ResNet-152 准确率最高,但速度也慢了 2–3 倍。
  • VGG16 速度最快(莫非因为层数少?)但内存用量最高并且准确率最差。
  • Inception v3 几乎最快,同时在准确率和内存使用方面较为平均。这也使得该模型成为某些条件受限的环境中最佳的选择。最后一篇文章中,还将进一步讨论这个问题 :)。

代码已发布至 GitHub: mxnet_example3.py , 请自行尝试。

后续内容:

  • 第 6 篇:通过树莓派进行实时物体检测(并让它讲话!)

作者 Julien Simon 阅读英文原文 An introduction to the MXNet API?—?part 5


感谢杜小芳对本文的审校。

给InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 editors@cn.infoq.com 。也欢迎大家通过新浪微博( @InfoQ @丁晓昀),微信(微信号: InfoQChina )关注我们。

2017 年 7 月 19 日 17:031734
用户头像

发布了 283 篇内容, 共 84.6 次阅读, 收获喜欢 34 次。

关注

评论

发布
暂无评论
发现更多内容

区块链数字钱包系统开发方案,区块链钱包APP源码

13530558032

DataPipeline CPO 陈雷:实时数据融合之法,稳定高容错

DataPipeline

数据融合

快进收藏吃灰!字节跳动大佬用最通俗方法讲明白了红黑树算法

小Q

Java 学习 架构 面试 算法

6. 自定义容器类型元素验证,类级别验证(多字段联合验证)

YourBatman

Hibernate-Validator Bean Validation 多字段联合验证

Scrum指南这么改,我看要完蛋!

华为云开发者社区

Scrum 敏捷 改版

微信官方将打击恶意营销号:自媒体不可过度消费粉丝

石头IT视角

万字图文 | 聊一聊 ReentrantLock 和 AQS 那点事(看完不会你找我)

源码兴趣圈

架构 AQS ReentrantLock JUC CLH

架构师训练营第九周作业

_

极客大学架构师训练营 第九周作业

接口测试学习之json

测试人生路

json 接口测试

阿里达摩院副院长亲自所写Java架构29大核心知识体系+大厂面试真题+微服务

Java架构追梦

Java 学习 阿里巴巴 架构 面试

Springboot过滤器和拦截器详解及使用场景

996小迁

Java 编程 架构 面试 springboot

前嗅教你大数据——史上最全代理IP服务商对比

前嗅大数据

大数据 数据采集 动态代理 静态代理 代理IP

号外!5G+X联创营华为云官网上线,5G 创业春天来了!

华为云开发者社区

华为 程序员 AI 5G

强化学习入门必看之强化学习导识

Alocasia

人工智能 学习

区块链数字货币钱包开发,去中心化钱包搭建app

WX13823153201

11月阿里Spring全家桶+MQ微服务架构笔记:源码+实战

小Q

Java 学习 程序员 面试 微服务

DataPipeline CPO 陈雷:实时数据融合之法,便捷可管理

DataPipeline

数据融合

AI技术在音乐类产品中的应用场景

HIFIVE嗨翻屋

人工智能 AI 音乐 音乐制作

公众号高频被调整,它不是企业生产文章的机器

Linkflow

客户数据平台 CDP 私域流量

媲美物理机,裸金属云主机如何轻松应对11.11大促

京东智联云开发者

云计算 服务器 云主机 裸金属容器

企业工作流设计原则及多项目整合开发注意事项

Marilyn

敏捷开发 工作流 企业开发

合约跟单源码案例,合约跟单模式开发

13530558032

面试官问:如何排除GC引起的CPU飙高?我脱口而出5个步骤

田维常

cpu飙满

DataPipeline CTO 陈肃:构建批流一体数据融合平台的一致性语义保证

DataPipeline

数据融合

UNISKIN COO Kevin|营销数字化:数据沉淀和数据系统化运营一定要趁早!

Linkflow

营销数字化 客户数据平台 CDP

DataPipeline 王睿:业务异常实时自动化检测 — 基于人工智能的系统实战

DataPipeline

大数据

【JDD京智大咖说】AI 未来,路在何方?NLP、CV 技术的探索与展望

京东智联云开发者

人工智能 CV nlp

数字货币交易所开发有哪些模式?区块链交易平台

13530558032

区块链社交即时通许系统开发,区块链社交app开发价格

13530558032

《JAVA多线程设计模式》.pdf

田维常

多线程

DataPipeline CPO 陈雷:实时数据融合之道,博观约取,价值驱动

DataPipeline

数据融合

InfoQ 极客传媒开发者生态共创计划线上发布会

InfoQ 极客传媒开发者生态共创计划线上发布会

MXNet API入门 —第5篇-InfoQ