报名参加CloudWeGo黑客松,奖金直推双丰收! 了解详情
写点什么

MXNet API 入门 —第 3 篇

  • 2017-07-13
  • 本文字数:4703 字

    阅读完需:约 15 分钟

第2 篇文章中,我们介绍了如何使用Symbols 定义计算中使用的Graph,并处理存储在NDArray(在第1 篇文章中有介绍)中的数据。

本文将介绍如何使用Symbol 和NDArray 准备所需数据并构建神经网络。随后将使用 Module API 训练该网络并预测结果。

定义数据集

我们(设想中的)数据集包含1000 个数据样本

  • 每个样本有100 个特征
  • 每个特征体现为一个介于 0 和 1 之间的浮点值
  • 样本被分为10 个类别,我们将使用神经网络预测特定样本的恰当类别,
  • 我们将使用 800 个样本进行训练,使用 200 个样本进行验证
  • 训练和验证过程的批大小为 10。
复制代码
import mxnet as mx
import numpy as np
import logging
logging.basicConfig(level=logging.INFO)
sample_count = 1000
train_count = 800
valid_count = sample_count - train_count
feature_count = 100
category_count = 10
batch=10

生成数据集

我们将通过均匀分布的方式生成这 1000 个样本,将其存储在一个名为“X”的 NDArray 中:1000 行,100 列

复制代码
X = mx.nd.uniform(low=0, high=1, shape=(sample_count,feature_count))
>>> X.shape
(1000L, 100L)
>>> X.asnumpy()
array([[ 0.70029777, 0.28444085, 0.46263582, ..., 0.73365158,
0.99670047, 0.5961988 ],
[ 0.34659418, 0.82824177, 0.72929877, ..., 0.56012964,
0.32261589, 0.35627609],
[ 0.10939316, 0.02995235, 0.97597599, ..., 0.20194994,
0.9266268 , 0.25102937],
...,
[ 0.69691515, 0.52568913, 0.21130568, ..., 0.42498392,
0.80869114, 0.23635457],
[ 0.3562004 , 0.5794751 , 0.38135922, ..., 0.6336484 ,
0.26392782, 0.30010447],
[ 0.40369365, 0.89351988, 0.88817406, ..., 0.13799617,
0.40905532, 0.05180593]], dtype=float32)

这 1000 个样本的类别用介于 0-9 的整数来代表,类别是随机生成的,存储在一个名为“Y”的 NDArray 中。

复制代码
Y = mx.nd.empty((sample_count,))
for i in range(0,sample_count-1):
Y[i] = np.random.randint(0,category_count)
>>> Y.shape
(1000L,)
>>> Y[0:10].asnumpy()
array([ 3., 3., 1., 9., 4., 7., 3., 5., 2., 2.], dtype=float32)

拆分数据集

随后我们将针对训练验证两个用途对数据集进行80/20拆分。为此需要使用 NDArray.crop 函数。在这里,数据集是完全随机的,因此可以使用前 80% 的数据进行训练,用后 20% 的数据进行验证。实际运用中,我们可能需要首先搅乱数据集,这样才能避免按顺序生成的数据可能造成的偏差。

复制代码
X_train = mx.nd.crop(X, begin=(0,0), end=(train_count,feature_count-1))
X_valid = mx.nd.crop(X, begin=(train_count,0), end=(sample_count,feature_count-1))
Y_train = Y[0:train_count]
Y_valid = Y[train_count:sample_count]

至此数据已经准备完毕!

构建网络

这个网络其实很简单,一起看看其中的每一层:

  • 输入层是由一个名为“Data”的 Symbol 代表的,随后会绑定至实际的输入数据。 ```

    data = mx.sym.Variable(‘data’)

复制代码
- fc1 是 ** 第一个隐藏层 **,通过 **64 个相互连接的神经元 ** 构建而来,输入层的每个特征都会连接至所有的 64 个神经元。如你所见,我们使用了高级的 Symbol.FullyConnected 函数,相比手工建立每个连接,这种做法更方便一些! ```
fc1 = mx.sym.FullyConnected(data, name='fc1', num_hidden=64)
  • fc1 的每个输出会进入到一个激活函数 (Activation function) 。在这里我们将使用一个线性整流单元 (Rectified linear unit) ,即“Relu”。之前承诺过尽量少讲理论知识,因此可以这样理解:激活函数将用于决定是否要“启动”某个神经元,例如其输入是否由足够有意义,可以预测出正确的结果。 ```

    relu1 = mx.sym.Activation(fc1, name=‘relu1’, act_type=“relu”)

复制代码
- fc2 是 ** 第二个隐藏层 **,由 **10 个相互连接的神经元 ** 构建而来,可映射至我们的 **10 个分类 **。每个神经元可输出一个任意标度 (Arbitrary scale) 的浮点值。10 个值中最大的那个代表了数据样本 ** 最有可能的类别 **。 ```
fc2 = mx.sym.FullyConnected(relu1, name='fc2', num_hidden=category_count)
  • 输出层会将 Softmax 函数应用给来自 fc2 层的 10 个值:这些值会被转换为 10 个介于 0 和 1 之间的值,所有值的总和为 1。每个值代表预测出的每个类别的可能性,其中最大的值代表最有可能的类别。 ```

    out = mx.sym.SoftmaxOutput(fc2, name=‘softmax’)
    mod = mx.mod.Module(out)

复制代码
## 构建数据迭代器
在第 1 篇文章中,我们了解到神经网络并不会一次只训练一个样本,因为这样做从性能的角度来看效率太低。因此我们会使用 ** 批 **,即 ** 一批固定数量的样本 **
为了给神经网络提供这样的“批”,我们需要使用 NDArrayIter 函数构建一个 ** 迭代器 **。其参数包括 ** 训练数据 **、分类(MXNet 将其称之为 ** 标签 (Label)**),以及 ** 批大小 **
如你所见,我们可以对整个数据集进行迭代,同时对 10 个样本和 10 个标签执行该操作。随后即可调用 reset() 函数将迭代器恢复为初始状态。

train_iter = mx.io.NDArrayIter(data=X_train,label=Y_train,batch_size=batch)

for batch in train_iter:
… print batch.data
… print batch.label

[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]
[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]
[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]

train_iter.reset()

复制代码
网络已经准备完成,开始训练吧!
## 训练模型
首先将输入 Symbol\*\* 绑定\*\* 至实际的数据集(样本和标签),这时候就会用到迭代器。

mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)

复制代码
随后对网络中的神经元权重进行 ** 初始化 **。这个步骤非常重要:使用“恰当”的技术对齐进行初始化可以帮助网络 ** 更快速地 ** 学习。此时可用的技术很多,Xavier 初始化器(名称源自该技术的发明人 Xavier Glorot?—?[PDF](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf))就是其中之一。

Allowed, but not efficient

mod.init_params()

Much better

mod.init_params(initializer=mx.init.Xavier(magnitude=2.))

复制代码
接着需要定义 ** 优化 ** 参数:
- 我们将使用 [随机坡降法 (Stochastic Gradient Descent)](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) 算法(又名 SGD),该算法在机器学习和深度学习领域有着广泛的应用。
- 我们会将 ** 学习速率 ** 设置为 0.1,这是 SGD 算法一个非常普遍的设置。

mod.init_optimizer(optimizer=‘sgd’, optimizer_params=((‘learning_rate’, 0.1), ))

复制代码
最后,终于可以开始训练网络了!我们会执行 50 个 ** 回合 (Epoch)** 的训练,也就是说,整个数据集需要在这个网络中(以 10 个样本为一批)运行 50 次。

mod.fit(train_iter, num_epoch=50)
INFO:root:Epoch[0] Train-accuracy=0.097500
INFO:root:Epoch[0] Time cost=0.085
INFO:root:Epoch[1] Train-accuracy=0.122500
INFO:root:Epoch[1] Time cost=0.074
INFO:root:Epoch[2] Train-accuracy=0.153750
INFO:root:Epoch[2] Time cost=0.087
INFO:root:Epoch[3] Train-accuracy=0.162500
INFO:root:Epoch[3] Time cost=0.082
INFO:root:Epoch[4] Train-accuracy=0.192500
INFO:root:Epoch[4] Time cost=0.094
INFO:root:Epoch[5] Train-accuracy=0.210000
INFO:root:Epoch[5] Time cost=0.108
INFO:root:Epoch[6] Train-accuracy=0.222500
INFO:root:Epoch[6] Time cost=0.104
INFO:root:Epoch[7] Train-accuracy=0.243750
INFO:root:Epoch[7] Time cost=0.110
INFO:root:Epoch[8] Train-accuracy=0.263750
INFO:root:Epoch[8] Time cost=0.101
INFO:root:Epoch[9] Train-accuracy=0.286250
INFO:root:Epoch[9] Time cost=0.097
INFO:root:Epoch[10] Train-accuracy=0.306250
INFO:root:Epoch[10] Time cost=0.100

INFO:root:Epoch[20] Train-accuracy=0.507500

INFO:root:Epoch[30] Train-accuracy=0.718750

INFO:root:Epoch[40] Train-accuracy=0.923750

INFO:root:Epoch[50] Train-accuracy=0.998750
INFO:root:Epoch[50] Time cost=0.077

复制代码
如你所见,训练的准确度有了飞速提升,50 个回合后已经接近 **99% 以上 **。似乎我们的网络已经从训练数据集中学成了。非常惊人!
但针对验证数据集执行的效果如何呢?
## 验证模型
随后将新的数据样本放入网络,例如剩下的那 20%** 尚未 ** 在训练中使用过的数据。
首先构建一个迭代器,这一次将使用 ** 验证 ** 样本和标签。

pred_iter = mx.io.NDArrayIter(data=X_valid,label=Y_valid, batch_size=batch)

复制代码
随后要使用 Module.iter\_predict() 函数,借此让样本在网络中运行。这样做的同时,还需要对 ** 预测的标签 **** 实际标签 ** 进行对比。我们需要追踪比分并显示 ** 验证准确度 **,即,网络针对验证数据集的执行效果到底如何。

pred_count = valid_count
correct_preds = total_correct_preds = 0
for preds, i_batch, batch in mod.iter_predict(pred_iter):
label = batch.label[0].asnumpy().astype(int)
pred_label = preds[0].asnumpy().argmax(axis=1)
correct_preds = np.sum(pred_label==label)
total_correct_preds = total_correct_preds + correct_preds
print(‘Validation accuracy: %2.2f’ % (1.0*total_correct_preds/pred_count))

复制代码
这个过程中发生了不少事 :)
iter\_predict() 返回了:
{1}
- i\_batch:批编号。
- batch:一个 NDArray 数组。这里它其实保存了一个 NDArray,其中存储了当前批的内容。我们将用它找出当前批中 10 个数据样本的标签,随后将其存储在名为 Label 的 Numpy array 中(10 个元素)。
- preds:也是一个 NDArray 数组。这里它保存了一个 NDArray,其中存储了当前批预测出的标签:对于每个样本,我们提供了 ** 所有 10 个分类预测出的可能性 **(10x10 矩阵)。因此我们将使用 argmax() 找出最高值的 ** 指数 **,即 ** 最可能的分类 **。所以 pred\_label 实际上是一个 10 元素数组,其中保存了当前批中每个数据样本预测出的分类。
{1}
随后我们需要使用 Numpy.sum() 将 label 和 pred\_label 中相等值的数量进行对比。
最后需要计算并显示验证准确度。
> 验证准确度:0.09
什么?只有 9%?** 真是太悲催了 **!如果你希望证明我们的数据集真的是随机的,那么你有证据了!
底线在于,我们确实可以通过训练神经网络学习 ** 任何东西 **,但如果数据本身是 ** 无意义的 **(例如我们本例中使用的数据),那么就什么都预测不出来。** 种瓜得瓜,种豆得豆 **
如果你已经读到这里,我猜你是真心希望看到本例的完整代码 ;) 请花些时间用你自己的数据进行验证,这才是学习的最佳方法。
代码已发布至 GitHub:[mxnet\_example1.py](https://gist.github.com/juliensimon/7cfef0423b0183e891774a289e156b49#file-mxnet_example1-py)。
## 后续内容:
- 第 4 篇:使用预训练模型进行图片分类(Inception v3)
- 第 5 篇:进一步了解预训练模型(VGG16 和 ResNet-152)
- 第 6 篇:通过树莓派进行实时物体检测(并让它讲话!)
** 作者 **:[Julien Simon](https://medium.com/@julsimon),** 阅读英文原文 **:[An introduction to the MXNet API?—?part 3](https://medium.com/@julsimon/an-introduction-to-the-mxnet-api-part-3-1803112ba3a8)
- - - - - -
感谢 [杜小芳](http://www.infoq.com/cn/author/%E6%9D%9C%E5%B0%8F%E8%8A%B3) 对本文的审校。
给 InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 [editors@cn.infoq.com](mailto:editors@cn.infoq.com)。也欢迎大家通过新浪微博([@InfoQ](http://www.weibo.com/infoqchina),[@丁晓昀](http://weibo.com/u/1451714913)),微信(微信号:[InfoQChina](http://www.geekbang.org/ivtw))关注我们。
2017-07-13 17:394168
用户头像

发布了 283 篇内容, 共 112.3 次阅读, 收获喜欢 62 次。

关注

评论

发布
暂无评论
发现更多内容

JavaScript 中数组 sort() 方法的基本使用

编程三昧

JavaScript 大前端 数组 排序 js

分布式认知工业互联网如何赋能工业企业数字化转型?

CECBC

Python——输入输出:加减乘除四则运算的程序

在即

6月日更

5分钟速读之Rust权威指南(十九)

wzx

rust 生命周期

高性能 JavaScriptの七 -- 编程实践小技巧

空城机

JavaScript 大前端 6月日更

《原则》(八)

Changing Lin

6月日更

项目管理与项目集管理、项目组合管理的区别?

万事ONES

项目管理 项目 PMO ONES

公司:离职就是一场危机管理

石云升

创业 职场经验 6月日更

云原生推动全云开发与实践

阿里巴巴云原生

云原生

spring-beans 注册 Beans(四)BeanDefinition

梦倚栏杆

Java--JVM运行流程

是老郭啊

Java JVM JVM原理

MySQL基础之六:连接查询

打工人!

myslq 6月日更

阿里云边缘容器服务、申通 IoT 云边端架构入选 2021 云边协同发展阶段性领先成果

阿里巴巴云原生

云原生

【布道API】浅谈API设计风格

devpoint

Rest API 6月日更

这些书都学完,绝对是编程界的大佬

看山

Java 程序员 6月日更

学妹问,学网站开发还是打 ACM?

程序员鱼皮

Java 程序员 算法 大前端 ACM

操作系统内核是什么?Linux内核又是什么?读完这篇文章,我终于知道了

奔着腾讯去

c++ 操作系统 内存管理 Linux内核 进程管理

异构内存及其在机器学习系统的应用与优化

白玉兰开源

人工智能 机器学习 解决方案 第四范式 傲腾

Kubernetes手记(5)- 配置清单使用

雪雷

k8s 6月日更

你愿意被管理么?

escray

学习 极客时间 朱赟的技术管理课 6月日更

【Vue2.x 源码学习】第八篇 - 数组的深层劫持

Brave

源码 vue2 6月日更

缓存穿透、缓存雪崩、缓存击穿问题与优化方案

Skysper

软件研发团队如何做好项目进度管理?

万事ONES

项目管理 研发管理 需求 ONES

做通才还是专才,你会怎么选?

架构精进之路

认知提升 6月日更

【21-1】21 连更第一篇

耳东@Erdong

6月日更

不管是三胎还是App!指望“拉新”太难了,还是要靠老用户!

APP开发

Mybatis 二级缓存简单示例

Java mybatis

数字化转型背景下的测试转型

BY林子

敏捷测试 测试转型

浅谈Java中的TCP超时

Hoswey_洪树伟

Java、

加快技术应用规模化 建设世界先进水平区块链产业生态

CECBC

区块链+金融:当前区块链应用场景中最具活力的领域

CECBC

MXNet API入门 —第3篇_语言 & 开发_Julien Simon_InfoQ精选文章