2天时间,聊今年最热的 Agent、上下文工程、AI 产品创新等话题。2025 年最后一场~ 了解详情
写点什么

MXNet API 入门 —第 3 篇

  • 2017-07-13
  • 本文字数:4703 字

    阅读完需:约 15 分钟

第2 篇文章中,我们介绍了如何使用Symbols 定义计算中使用的Graph,并处理存储在NDArray(在第1 篇文章中有介绍)中的数据。

本文将介绍如何使用Symbol 和NDArray 准备所需数据并构建神经网络。随后将使用 Module API 训练该网络并预测结果。

定义数据集

我们(设想中的)数据集包含1000 个数据样本

  • 每个样本有100 个特征
  • 每个特征体现为一个介于 0 和 1 之间的浮点值
  • 样本被分为10 个类别,我们将使用神经网络预测特定样本的恰当类别,
  • 我们将使用 800 个样本进行训练,使用 200 个样本进行验证
  • 训练和验证过程的批大小为 10。
复制代码
import mxnet as mx
import numpy as np
import logging
logging.basicConfig(level=logging.INFO)
sample_count = 1000
train_count = 800
valid_count = sample_count - train_count
feature_count = 100
category_count = 10
batch=10

生成数据集

我们将通过均匀分布的方式生成这 1000 个样本,将其存储在一个名为“X”的 NDArray 中:1000 行,100 列

复制代码
X = mx.nd.uniform(low=0, high=1, shape=(sample_count,feature_count))
>>> X.shape
(1000L, 100L)
>>> X.asnumpy()
array([[ 0.70029777, 0.28444085, 0.46263582, ..., 0.73365158,
0.99670047, 0.5961988 ],
[ 0.34659418, 0.82824177, 0.72929877, ..., 0.56012964,
0.32261589, 0.35627609],
[ 0.10939316, 0.02995235, 0.97597599, ..., 0.20194994,
0.9266268 , 0.25102937],
...,
[ 0.69691515, 0.52568913, 0.21130568, ..., 0.42498392,
0.80869114, 0.23635457],
[ 0.3562004 , 0.5794751 , 0.38135922, ..., 0.6336484 ,
0.26392782, 0.30010447],
[ 0.40369365, 0.89351988, 0.88817406, ..., 0.13799617,
0.40905532, 0.05180593]], dtype=float32)

这 1000 个样本的类别用介于 0-9 的整数来代表,类别是随机生成的,存储在一个名为“Y”的 NDArray 中。

复制代码
Y = mx.nd.empty((sample_count,))
for i in range(0,sample_count-1):
Y[i] = np.random.randint(0,category_count)
>>> Y.shape
(1000L,)
>>> Y[0:10].asnumpy()
array([ 3., 3., 1., 9., 4., 7., 3., 5., 2., 2.], dtype=float32)

拆分数据集

随后我们将针对训练验证两个用途对数据集进行80/20拆分。为此需要使用 NDArray.crop 函数。在这里,数据集是完全随机的,因此可以使用前 80% 的数据进行训练,用后 20% 的数据进行验证。实际运用中,我们可能需要首先搅乱数据集,这样才能避免按顺序生成的数据可能造成的偏差。

复制代码
X_train = mx.nd.crop(X, begin=(0,0), end=(train_count,feature_count-1))
X_valid = mx.nd.crop(X, begin=(train_count,0), end=(sample_count,feature_count-1))
Y_train = Y[0:train_count]
Y_valid = Y[train_count:sample_count]

至此数据已经准备完毕!

构建网络

这个网络其实很简单,一起看看其中的每一层:

  • 输入层是由一个名为“Data”的 Symbol 代表的,随后会绑定至实际的输入数据。 ```

    data = mx.sym.Variable(‘data’)

复制代码
- fc1 是 ** 第一个隐藏层 **,通过 **64 个相互连接的神经元 ** 构建而来,输入层的每个特征都会连接至所有的 64 个神经元。如你所见,我们使用了高级的 Symbol.FullyConnected 函数,相比手工建立每个连接,这种做法更方便一些! ```
fc1 = mx.sym.FullyConnected(data, name='fc1', num_hidden=64)
  • fc1 的每个输出会进入到一个激活函数 (Activation function) 。在这里我们将使用一个线性整流单元 (Rectified linear unit) ,即“Relu”。之前承诺过尽量少讲理论知识,因此可以这样理解:激活函数将用于决定是否要“启动”某个神经元,例如其输入是否由足够有意义,可以预测出正确的结果。 ```

    relu1 = mx.sym.Activation(fc1, name=‘relu1’, act_type=“relu”)

复制代码
- fc2 是 ** 第二个隐藏层 **,由 **10 个相互连接的神经元 ** 构建而来,可映射至我们的 **10 个分类 **。每个神经元可输出一个任意标度 (Arbitrary scale) 的浮点值。10 个值中最大的那个代表了数据样本 ** 最有可能的类别 **。 ```
fc2 = mx.sym.FullyConnected(relu1, name='fc2', num_hidden=category_count)
  • 输出层会将 Softmax 函数应用给来自 fc2 层的 10 个值:这些值会被转换为 10 个介于 0 和 1 之间的值,所有值的总和为 1。每个值代表预测出的每个类别的可能性,其中最大的值代表最有可能的类别。 ```

    out = mx.sym.SoftmaxOutput(fc2, name=‘softmax’)
    mod = mx.mod.Module(out)

复制代码
## 构建数据迭代器
在第 1 篇文章中,我们了解到神经网络并不会一次只训练一个样本,因为这样做从性能的角度来看效率太低。因此我们会使用 ** 批 **,即 ** 一批固定数量的样本 **
为了给神经网络提供这样的“批”,我们需要使用 NDArrayIter 函数构建一个 ** 迭代器 **。其参数包括 ** 训练数据 **、分类(MXNet 将其称之为 ** 标签 (Label)**),以及 ** 批大小 **
如你所见,我们可以对整个数据集进行迭代,同时对 10 个样本和 10 个标签执行该操作。随后即可调用 reset() 函数将迭代器恢复为初始状态。

train_iter = mx.io.NDArrayIter(data=X_train,label=Y_train,batch_size=batch)

for batch in train_iter:
… print batch.data
… print batch.label

[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]
[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]
[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]

train_iter.reset()

复制代码
网络已经准备完成,开始训练吧!
## 训练模型
首先将输入 Symbol\*\* 绑定\*\* 至实际的数据集(样本和标签),这时候就会用到迭代器。

mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)

复制代码
随后对网络中的神经元权重进行 ** 初始化 **。这个步骤非常重要:使用“恰当”的技术对齐进行初始化可以帮助网络 ** 更快速地 ** 学习。此时可用的技术很多,Xavier 初始化器(名称源自该技术的发明人 Xavier Glorot?—?[PDF](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf))就是其中之一。

Allowed, but not efficient

mod.init_params()

Much better

mod.init_params(initializer=mx.init.Xavier(magnitude=2.))

复制代码
接着需要定义 ** 优化 ** 参数:
- 我们将使用 [随机坡降法 (Stochastic Gradient Descent)](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) 算法(又名 SGD),该算法在机器学习和深度学习领域有着广泛的应用。
- 我们会将 ** 学习速率 ** 设置为 0.1,这是 SGD 算法一个非常普遍的设置。

mod.init_optimizer(optimizer=‘sgd’, optimizer_params=((‘learning_rate’, 0.1), ))

复制代码
最后,终于可以开始训练网络了!我们会执行 50 个 ** 回合 (Epoch)** 的训练,也就是说,整个数据集需要在这个网络中(以 10 个样本为一批)运行 50 次。

mod.fit(train_iter, num_epoch=50)
INFO:root:Epoch[0] Train-accuracy=0.097500
INFO:root:Epoch[0] Time cost=0.085
INFO:root:Epoch[1] Train-accuracy=0.122500
INFO:root:Epoch[1] Time cost=0.074
INFO:root:Epoch[2] Train-accuracy=0.153750
INFO:root:Epoch[2] Time cost=0.087
INFO:root:Epoch[3] Train-accuracy=0.162500
INFO:root:Epoch[3] Time cost=0.082
INFO:root:Epoch[4] Train-accuracy=0.192500
INFO:root:Epoch[4] Time cost=0.094
INFO:root:Epoch[5] Train-accuracy=0.210000
INFO:root:Epoch[5] Time cost=0.108
INFO:root:Epoch[6] Train-accuracy=0.222500
INFO:root:Epoch[6] Time cost=0.104
INFO:root:Epoch[7] Train-accuracy=0.243750
INFO:root:Epoch[7] Time cost=0.110
INFO:root:Epoch[8] Train-accuracy=0.263750
INFO:root:Epoch[8] Time cost=0.101
INFO:root:Epoch[9] Train-accuracy=0.286250
INFO:root:Epoch[9] Time cost=0.097
INFO:root:Epoch[10] Train-accuracy=0.306250
INFO:root:Epoch[10] Time cost=0.100

INFO:root:Epoch[20] Train-accuracy=0.507500

INFO:root:Epoch[30] Train-accuracy=0.718750

INFO:root:Epoch[40] Train-accuracy=0.923750

INFO:root:Epoch[50] Train-accuracy=0.998750
INFO:root:Epoch[50] Time cost=0.077

复制代码
如你所见,训练的准确度有了飞速提升,50 个回合后已经接近 **99% 以上 **。似乎我们的网络已经从训练数据集中学成了。非常惊人!
但针对验证数据集执行的效果如何呢?
## 验证模型
随后将新的数据样本放入网络,例如剩下的那 20%** 尚未 ** 在训练中使用过的数据。
首先构建一个迭代器,这一次将使用 ** 验证 ** 样本和标签。

pred_iter = mx.io.NDArrayIter(data=X_valid,label=Y_valid, batch_size=batch)

复制代码
随后要使用 Module.iter\_predict() 函数,借此让样本在网络中运行。这样做的同时,还需要对 ** 预测的标签 **** 实际标签 ** 进行对比。我们需要追踪比分并显示 ** 验证准确度 **,即,网络针对验证数据集的执行效果到底如何。

pred_count = valid_count
correct_preds = total_correct_preds = 0
for preds, i_batch, batch in mod.iter_predict(pred_iter):
label = batch.label[0].asnumpy().astype(int)
pred_label = preds[0].asnumpy().argmax(axis=1)
correct_preds = np.sum(pred_label==label)
total_correct_preds = total_correct_preds + correct_preds
print(‘Validation accuracy: %2.2f’ % (1.0*total_correct_preds/pred_count))

复制代码
这个过程中发生了不少事 :)
iter\_predict() 返回了:
{1}
- i\_batch:批编号。
- batch:一个 NDArray 数组。这里它其实保存了一个 NDArray,其中存储了当前批的内容。我们将用它找出当前批中 10 个数据样本的标签,随后将其存储在名为 Label 的 Numpy array 中(10 个元素)。
- preds:也是一个 NDArray 数组。这里它保存了一个 NDArray,其中存储了当前批预测出的标签:对于每个样本,我们提供了 ** 所有 10 个分类预测出的可能性 **(10x10 矩阵)。因此我们将使用 argmax() 找出最高值的 ** 指数 **,即 ** 最可能的分类 **。所以 pred\_label 实际上是一个 10 元素数组,其中保存了当前批中每个数据样本预测出的分类。
{1}
随后我们需要使用 Numpy.sum() 将 label 和 pred\_label 中相等值的数量进行对比。
最后需要计算并显示验证准确度。
> 验证准确度:0.09
什么?只有 9%?** 真是太悲催了 **!如果你希望证明我们的数据集真的是随机的,那么你有证据了!
底线在于,我们确实可以通过训练神经网络学习 ** 任何东西 **,但如果数据本身是 ** 无意义的 **(例如我们本例中使用的数据),那么就什么都预测不出来。** 种瓜得瓜,种豆得豆 **
如果你已经读到这里,我猜你是真心希望看到本例的完整代码 ;) 请花些时间用你自己的数据进行验证,这才是学习的最佳方法。
代码已发布至 GitHub:[mxnet\_example1.py](https://gist.github.com/juliensimon/7cfef0423b0183e891774a289e156b49#file-mxnet_example1-py)。
## 后续内容:
- 第 4 篇:使用预训练模型进行图片分类(Inception v3)
- 第 5 篇:进一步了解预训练模型(VGG16 和 ResNet-152)
- 第 6 篇:通过树莓派进行实时物体检测(并让它讲话!)
** 作者 **:[Julien Simon](https://medium.com/@julsimon),** 阅读英文原文 **:[An introduction to the MXNet API?—?part 3](https://medium.com/@julsimon/an-introduction-to-the-mxnet-api-part-3-1803112ba3a8)
- - - - - -
感谢 [杜小芳](http://www.infoq.com/cn/author/%E6%9D%9C%E5%B0%8F%E8%8A%B3) 对本文的审校。
给 InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 [editors@cn.infoq.com](mailto:editors@cn.infoq.com)。也欢迎大家通过新浪微博([@InfoQ](http://www.weibo.com/infoqchina),[@丁晓昀](http://weibo.com/u/1451714913)),微信(微信号:[InfoQChina](http://www.geekbang.org/ivtw))关注我们。
2017-07-13 17:394532
用户头像

发布了 283 篇内容, 共 123.0 次阅读, 收获喜欢 63 次。

关注

评论

发布
暂无评论
发现更多内容

成功入职阿里,薪资翻倍~ 感谢这份顶级版,linux教程入门教程PDF

Java 程序员 后端

我上高中的弟弟都能看懂的Docker学习教程,你看看讲的怎么样

Java 程序员 后端

手把手讲解-一个复杂动效的自定义绘制3,最全153道Spring全家桶面试题

Java 程序员 后端

怎样成为全栈工程师(Full Stack Developer)(1),已拿offer

Java 程序员 后端

意犹未尽的一篇Nginx原理详解,面试官看了都忍不住点赞

Java 程序员 后端

懵逼!阿里一面就被虐了,幸获内推华为技术四面,kafka高性能原理

Java 程序员 后端

意犹未尽的一篇Nginx原理详解,面试官看了都忍不住点赞(1)

Java 程序员 后端

成功拿到大厂offer的我熬夜整理了这份Java高频面试题(含答案)

Java 程序员 后端

我是全网最硬核的Java中间件领域作者,CSDN最值得关注的博主,大家同意吗

Java 程序员 后端

手撕ArrayList底层,透彻分析源码,mysql索引优化面试题

Java 程序员 后端

总结历年各大厂面试官传授的面试经验+阿里P8级架构师整理的Java高频核心知识点

Java 程序员 后端

恕我直言,我怀疑你们并不会用 Java 枚举,java分布式架构面试题

Java 程序员 后端

想进阿里、京东?这些多线程并发的技术要点你需要知道,Java程序员怎么优雅迈过30K+这道坎

Java 程序员 后端

怎么用Redis分布式锁才能确保万无一失?,15个经典面试问题及答案

Java 程序员 后端

总是说spring难学?看完这些spring的注解及其解释,对你来说就是So-easy!

Java 程序员 后端

扫盲帖:聊聊微服务与分布式系统,Java校招面试指南

Java 程序员 后端

我,48岁,上海外企高管,9次Java面试经验总结

Java 程序员 后端

怎样成为全栈工程师(Full Stack Developer),sqlproformysql使用教程

Java 程序员 后端

我凭借这1000道java高频真题,顺利拿下京东、饿了么,java高级开发面试总结

Java 程序员 后端

我在北京已经几年了,Java百度网盘

Java 程序员 后端

手把手教你应用三种工厂模式在SpringIOC中创建对象实例【案例详解】

Java 程序员 后端

技术分享成就现在的我:中间件兴趣圈荣获CSDN2020博客之星亚军

Java 程序员 后端

快速鸟瞰并发编程,-呕心沥血整理的架构技术【2】,分层展示的架构图

Java 程序员 后端

成为架构师之前,你一定要懂的-CAP-定理,Java程序员必备书籍

Java 程序员 后端

我出息了,给 JDK 上报了一个 BUG,mongodb入门到精通

Java 程序员 后端

快醒醒吧!互联网大厂面试必问的JVM底层原理,你还搞不清楚

Java 程序员 后端

怎么可能?面试会被Spring难住?Spring框架从入门到精通

Java 程序员 后端

我用思维导图整理好了Java并发基础知识,还学不会就没救了!

Java 程序员 后端

我丢,去面试初级Java开发岗位,被问到泛型,mysql索引原理面试题

Java 程序员 后端

我画了19张图,彻底帮你搞定Redis,mybatisgenerator教程

Java 程序员 后端

我见过最详细的Redis解析:不懂Redis为什么高性能?如何做高可用

Java 程序员 后端

MXNet API入门 —第3篇_语言 & 开发_Julien Simon_InfoQ精选文章