在第3 篇文章中,我们构建并训练了第一个神经网络,接下来可以处理一些更复杂的样本了。
最顶尖的深度学习模型通常都复杂到让人难以置信。其中可能包含数百层,就算用不了数周,往往也要数天时间来使用海量数据进行训练。这类模型的构建和优化需要大量经验。
好在这些模型的使用还是很简单的,通常只需要编写几行代码。本文将使用一个名为 Inception v3的预训练模型进行图片分类。
Inception v3
诞生于 2015 年 12 月的 Inception v3 是 GoogleNet 模型(曾赢得 2014 年度 ImageNet 挑战赛)的改进版。本文不准备深入介绍该模型的研究论文,不过打算强调一下论文的结论:相比当时最棒的模型,Inception v3 的准确度高出了15%–25%,同时计算的经济性方面低六倍,并且至少将参数的数量减少了五倍(例如使用该模型对内存的要求更低)。
简直就是神器!那么我们该如何使用?
MXNet model zoo
Model zoo 提供了一系列可直接使用的预训练模型,并且通常还会提供模型定义、模型参数(例如神经元权重),(也许还会提供)使用说明。
首先来下载定义和参数(你也许需要更改文件名)。第一个文件可以直接打开:其中包含了每一层的定义。第二个文件是一个二进制文件,请不要打开 ;)
$ wget http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-symbol.json $ wget http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-0126.params $ mv Inception-BN-0126.params Inception-BN-0000.params
该模型已通过 ImageNet 数据集进行了训练,因此我们还需要下载对应的图片分类清单(共有 1000 个分类)。
$ wget http://data.dmlc.ml/models/imagenet/synset.txt $ wc -l synset.txt 1000 synset.txt $ head -5 synset.txt n01440764 tench, Tinca tinca n01443537 goldfish, Carassius auratus n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias n01491361 tiger shark, Galeocerdo cuvieri n01494475 hammerhead, hammerhead shark
搞定,开始实战。
加载模型
我们需要:
-
加载处于保存状态的模型:MXNet 将其称之为检查点 (Checkpoint)。随后即可得到输入的 Symbol 和模型参数。 ```
import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint(‘Inception-BN’, 0)
- 新建一个 Module 并为其指派输入 Symbol。我们还可以使用一个 Context 参数决定要在哪里运行该模型:默认值为 cpu(0),但也可改为 gpu(0) 以便通过 GPU 运行。 ``` mod = mx.mod.Module(symbol=sym)
- 将输入 Symbol 绑定至输入数据。将其称之为“数据”是因为在网络的输入层中就使用了这样的名称(可以从 JSON 文件的前几行代码中看到)。
- 将“数据”的形态 (Shape)定义为 1x3x224x224。别慌 ;),“224x224”是图片的分辨率,模型就是这样训练出来的。“3”是通道数量:红绿蓝(严格按照这样的顺序),“1”是批大小:我们将一次预测一张图片。
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
-
设置模型参数。 ```
mod.set_params(arg_params, aux_params)
这样就可以了。只需要四行代码!随后可以放入一些数据看看会发生什么。嗯……先别急。 ## 准备数据 数据准备:从七十年代以来,这一直是个痛苦的过程……从关系型数据库到机器学习,再到深度学习,这方面没有任何改进。虽然乏味但很必要。开始吧。 还记得吗,这个模型需要通过四维 NDArray 来保存一张 224x224 分辨率图片的红、绿、蓝通道数据。我们将使用流行的 [OpenCV](http://www.opencv.org/) 库从输入图片中构建这样的 NDArray。如果还没安装 OpenCV,考虑到本例的要求,直接运行 pip install opencv-python 就够了 :)。 随后的步骤如下: - ** 读取 ** 图片:将返回一个 Numpy 数组,其形态为(图片高度, 图片宽度, 3),按顺序代表 **BGR**(蓝、绿、红)三个通道。 ``` img = cv2.imread(filename) {1}
-
将图片转换为 RGB。 ```
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- 将图片 ** 调整大小 ** 至 **224x224**。 ``` img = cv2.resize(img, (224, 224,))
-
重塑数组的形态,从(图片高度, 图片宽度, 3)重塑为(3, 图片高度, 图片宽度)。 ```
img = np.swapaxes(img, 0, 2)
img = np.swapaxes(img, 1, 2)
- 添加一个 ** 第四维度 ** 并构建 NDArray ``` img = img[np.newaxis, :] array = mx.nd.array(img) >>> print array.shape (1L, 3L, 224L, 224L)
晕了?一起用个例子看看吧。输入下列这张图片:
输入 448x336 的图片(来源:metaltraveller.com)
处理完毕后,该图会被缩小尺寸并拆分为 RGB 通道,存储在 array[0] 中(生成下文图片的代码可参阅这里)。
array[0][0]:224x224,红色通道
array 0 :224x224,绿色通道
array 0 :224x224,蓝色通道
如果批大小大于 1,那么可以通过 array 1 指定第二张图片,使用 array 2 指定第三张图片,以此类推。
无论这个过程是乏味还是有趣,接下来我们开始预测吧!
开始预测
你可能还记得第 3 篇文章中提到,Module 对象必须以批为单位向模型提供数据:最常见的做法是使用数据迭代器(因此我们使用了 NDArrayIter 对象)。
在这里我们想要预测一张图片,因此尽管可以使用数据迭代器,不过也没啥必要。但我们可以创建一个名为 Batch 的具名元组 (Named tuple), 它可以充当假的迭代器,在引用数据属性时返回输入的 NDArray。
from collections import namedtuple Batch = namedtuple('Batch', ['data'])
随后即可将这个“Batch”传递给模型开始预测。
mod.forward(Batch([array]))
这个模型会输出一个包含1000 个可能性的 NDArray,每个可能性对应一个分类。由于批大小等于 1,因此只需要一行代码。
prob = mod.get_outputs()[0].asnumpy() >>> prob.shape (1, 1000)
使用 squeeze() 将其转换为数组,随后使用 argsort() 创建第二个数组,其中保存了这些可能性按照降序排列的指数。
prob = np.squeeze(prob) >>> prob.shape (1000,) >> prob [ 4.14978594e-08 1.31608676e-05 2.51907986e-05 2.24045834e-05 2.30327873e-06 3.40798979e-05 7.41563645e-06 3.04062659e-08 etc. sortedprob = np.argsort(prob)[::-1] >> sortedprob.shape (1000,)
根据模型的计算,这张图片最可能的分类是#546,可能性为58%。
>> sortedprob [546 819 862 818 542 402 650 420 983 632 733 644 513 875 776 917 795 etc. >> prob[546] 0.58039135
这个分类叫什么名字呢?我们可以使用 synset.txt 文件构建分类清单,并找出 546 号的名称。
synsetfile = open('synset.txt', 'r') categorylist = [] for line in synsetfile: categorylist.append(line.rstrip()) >>> categorylist[546] 'n03272010 electric guitar'
可能性第二大的分类是什么?
>>> prob[819] 0.27168664 >>> categorylist[819] 'n04296562 stage
挺棒的,你说呢?
就是这样,我们已经了解了如何使用预训练的顶尖模型进行图片分类。而这一切只需要4 行代码……除此之外只要准备好数据就够了。
完整代码如下,请自行尝试并继续保持关注 ??
代码已发布至 GitHub: mxnet_example2.py
后续内容:
- 第 5 篇:进一步了解预训练模型(VGG16 和 ResNet-152)
- 第 6 篇:通过树莓派进行实时物体检测(并让它讲话!)
作者: Julien Simon ,阅读英文原文: An introduction to the MXNet API?—?part 4
感谢杜小芳对本文的审校。
给InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 editors@cn.infoq.com 。也欢迎大家通过新浪微博( @InfoQ , @丁晓昀),微信(微信号: InfoQChina )关注我们。
评论