写点什么

MXNet API 入门 —第 3 篇

  • 2017-07-13
  • 本文字数:4703 字

    阅读完需:约 15 分钟

第2 篇文章中,我们介绍了如何使用Symbols 定义计算中使用的Graph,并处理存储在NDArray(在第1 篇文章中有介绍)中的数据。

本文将介绍如何使用Symbol 和NDArray 准备所需数据并构建神经网络。随后将使用 Module API 训练该网络并预测结果。

定义数据集

我们(设想中的)数据集包含1000 个数据样本

  • 每个样本有100 个特征
  • 每个特征体现为一个介于 0 和 1 之间的浮点值
  • 样本被分为10 个类别,我们将使用神经网络预测特定样本的恰当类别,
  • 我们将使用 800 个样本进行训练,使用 200 个样本进行验证
  • 训练和验证过程的批大小为 10。
复制代码
import mxnet as mx
import numpy as np
import logging
logging.basicConfig(level=logging.INFO)
sample_count = 1000
train_count = 800
valid_count = sample_count - train_count
feature_count = 100
category_count = 10
batch=10

生成数据集

我们将通过均匀分布的方式生成这 1000 个样本,将其存储在一个名为“X”的 NDArray 中:1000 行,100 列

复制代码
X = mx.nd.uniform(low=0, high=1, shape=(sample_count,feature_count))
>>> X.shape
(1000L, 100L)
>>> X.asnumpy()
array([[ 0.70029777, 0.28444085, 0.46263582, ..., 0.73365158,
0.99670047, 0.5961988 ],
[ 0.34659418, 0.82824177, 0.72929877, ..., 0.56012964,
0.32261589, 0.35627609],
[ 0.10939316, 0.02995235, 0.97597599, ..., 0.20194994,
0.9266268 , 0.25102937],
...,
[ 0.69691515, 0.52568913, 0.21130568, ..., 0.42498392,
0.80869114, 0.23635457],
[ 0.3562004 , 0.5794751 , 0.38135922, ..., 0.6336484 ,
0.26392782, 0.30010447],
[ 0.40369365, 0.89351988, 0.88817406, ..., 0.13799617,
0.40905532, 0.05180593]], dtype=float32)

这 1000 个样本的类别用介于 0-9 的整数来代表,类别是随机生成的,存储在一个名为“Y”的 NDArray 中。

复制代码
Y = mx.nd.empty((sample_count,))
for i in range(0,sample_count-1):
Y[i] = np.random.randint(0,category_count)
>>> Y.shape
(1000L,)
>>> Y[0:10].asnumpy()
array([ 3., 3., 1., 9., 4., 7., 3., 5., 2., 2.], dtype=float32)

拆分数据集

随后我们将针对训练验证两个用途对数据集进行80/20拆分。为此需要使用 NDArray.crop 函数。在这里,数据集是完全随机的,因此可以使用前 80% 的数据进行训练,用后 20% 的数据进行验证。实际运用中,我们可能需要首先搅乱数据集,这样才能避免按顺序生成的数据可能造成的偏差。

复制代码
X_train = mx.nd.crop(X, begin=(0,0), end=(train_count,feature_count-1))
X_valid = mx.nd.crop(X, begin=(train_count,0), end=(sample_count,feature_count-1))
Y_train = Y[0:train_count]
Y_valid = Y[train_count:sample_count]

至此数据已经准备完毕!

构建网络

这个网络其实很简单,一起看看其中的每一层:

  • 输入层是由一个名为“Data”的 Symbol 代表的,随后会绑定至实际的输入数据。 ```

    data = mx.sym.Variable(‘data’)

复制代码
- fc1 是 ** 第一个隐藏层 **,通过 **64 个相互连接的神经元 ** 构建而来,输入层的每个特征都会连接至所有的 64 个神经元。如你所见,我们使用了高级的 Symbol.FullyConnected 函数,相比手工建立每个连接,这种做法更方便一些! ```
fc1 = mx.sym.FullyConnected(data, name='fc1', num_hidden=64)
  • fc1 的每个输出会进入到一个激活函数 (Activation function) 。在这里我们将使用一个线性整流单元 (Rectified linear unit) ,即“Relu”。之前承诺过尽量少讲理论知识,因此可以这样理解:激活函数将用于决定是否要“启动”某个神经元,例如其输入是否由足够有意义,可以预测出正确的结果。 ```

    relu1 = mx.sym.Activation(fc1, name=‘relu1’, act_type=“relu”)

复制代码
- fc2 是 ** 第二个隐藏层 **,由 **10 个相互连接的神经元 ** 构建而来,可映射至我们的 **10 个分类 **。每个神经元可输出一个任意标度 (Arbitrary scale) 的浮点值。10 个值中最大的那个代表了数据样本 ** 最有可能的类别 **。 ```
fc2 = mx.sym.FullyConnected(relu1, name='fc2', num_hidden=category_count)
  • 输出层会将 Softmax 函数应用给来自 fc2 层的 10 个值:这些值会被转换为 10 个介于 0 和 1 之间的值,所有值的总和为 1。每个值代表预测出的每个类别的可能性,其中最大的值代表最有可能的类别。 ```

    out = mx.sym.SoftmaxOutput(fc2, name=‘softmax’)
    mod = mx.mod.Module(out)

复制代码
## 构建数据迭代器
在第 1 篇文章中,我们了解到神经网络并不会一次只训练一个样本,因为这样做从性能的角度来看效率太低。因此我们会使用 ** 批 **,即 ** 一批固定数量的样本 **
为了给神经网络提供这样的“批”,我们需要使用 NDArrayIter 函数构建一个 ** 迭代器 **。其参数包括 ** 训练数据 **、分类(MXNet 将其称之为 ** 标签 (Label)**),以及 ** 批大小 **
如你所见,我们可以对整个数据集进行迭代,同时对 10 个样本和 10 个标签执行该操作。随后即可调用 reset() 函数将迭代器恢复为初始状态。

train_iter = mx.io.NDArrayIter(data=X_train,label=Y_train,batch_size=batch)

for batch in train_iter:
… print batch.data
… print batch.label

[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]
[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]
[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]

train_iter.reset()

复制代码
网络已经准备完成,开始训练吧!
## 训练模型
首先将输入 Symbol\*\* 绑定\*\* 至实际的数据集(样本和标签),这时候就会用到迭代器。

mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)

复制代码
随后对网络中的神经元权重进行 ** 初始化 **。这个步骤非常重要:使用“恰当”的技术对齐进行初始化可以帮助网络 ** 更快速地 ** 学习。此时可用的技术很多,Xavier 初始化器(名称源自该技术的发明人 Xavier Glorot?—?[PDF](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf))就是其中之一。

Allowed, but not efficient

mod.init_params()

Much better

mod.init_params(initializer=mx.init.Xavier(magnitude=2.))

复制代码
接着需要定义 ** 优化 ** 参数:
- 我们将使用 [随机坡降法 (Stochastic Gradient Descent)](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) 算法(又名 SGD),该算法在机器学习和深度学习领域有着广泛的应用。
- 我们会将 ** 学习速率 ** 设置为 0.1,这是 SGD 算法一个非常普遍的设置。

mod.init_optimizer(optimizer=‘sgd’, optimizer_params=((‘learning_rate’, 0.1), ))

复制代码
最后,终于可以开始训练网络了!我们会执行 50 个 ** 回合 (Epoch)** 的训练,也就是说,整个数据集需要在这个网络中(以 10 个样本为一批)运行 50 次。

mod.fit(train_iter, num_epoch=50)
INFO:root:Epoch[0] Train-accuracy=0.097500
INFO:root:Epoch[0] Time cost=0.085
INFO:root:Epoch[1] Train-accuracy=0.122500
INFO:root:Epoch[1] Time cost=0.074
INFO:root:Epoch[2] Train-accuracy=0.153750
INFO:root:Epoch[2] Time cost=0.087
INFO:root:Epoch[3] Train-accuracy=0.162500
INFO:root:Epoch[3] Time cost=0.082
INFO:root:Epoch[4] Train-accuracy=0.192500
INFO:root:Epoch[4] Time cost=0.094
INFO:root:Epoch[5] Train-accuracy=0.210000
INFO:root:Epoch[5] Time cost=0.108
INFO:root:Epoch[6] Train-accuracy=0.222500
INFO:root:Epoch[6] Time cost=0.104
INFO:root:Epoch[7] Train-accuracy=0.243750
INFO:root:Epoch[7] Time cost=0.110
INFO:root:Epoch[8] Train-accuracy=0.263750
INFO:root:Epoch[8] Time cost=0.101
INFO:root:Epoch[9] Train-accuracy=0.286250
INFO:root:Epoch[9] Time cost=0.097
INFO:root:Epoch[10] Train-accuracy=0.306250
INFO:root:Epoch[10] Time cost=0.100

INFO:root:Epoch[20] Train-accuracy=0.507500

INFO:root:Epoch[30] Train-accuracy=0.718750

INFO:root:Epoch[40] Train-accuracy=0.923750

INFO:root:Epoch[50] Train-accuracy=0.998750
INFO:root:Epoch[50] Time cost=0.077

复制代码
如你所见,训练的准确度有了飞速提升,50 个回合后已经接近 **99% 以上 **。似乎我们的网络已经从训练数据集中学成了。非常惊人!
但针对验证数据集执行的效果如何呢?
## 验证模型
随后将新的数据样本放入网络,例如剩下的那 20%** 尚未 ** 在训练中使用过的数据。
首先构建一个迭代器,这一次将使用 ** 验证 ** 样本和标签。

pred_iter = mx.io.NDArrayIter(data=X_valid,label=Y_valid, batch_size=batch)

复制代码
随后要使用 Module.iter\_predict() 函数,借此让样本在网络中运行。这样做的同时,还需要对 ** 预测的标签 **** 实际标签 ** 进行对比。我们需要追踪比分并显示 ** 验证准确度 **,即,网络针对验证数据集的执行效果到底如何。

pred_count = valid_count
correct_preds = total_correct_preds = 0
for preds, i_batch, batch in mod.iter_predict(pred_iter):
label = batch.label[0].asnumpy().astype(int)
pred_label = preds[0].asnumpy().argmax(axis=1)
correct_preds = np.sum(pred_label==label)
total_correct_preds = total_correct_preds + correct_preds
print(‘Validation accuracy: %2.2f’ % (1.0*total_correct_preds/pred_count))

复制代码
这个过程中发生了不少事 :)
iter\_predict() 返回了:
{1}
- i\_batch:批编号。
- batch:一个 NDArray 数组。这里它其实保存了一个 NDArray,其中存储了当前批的内容。我们将用它找出当前批中 10 个数据样本的标签,随后将其存储在名为 Label 的 Numpy array 中(10 个元素)。
- preds:也是一个 NDArray 数组。这里它保存了一个 NDArray,其中存储了当前批预测出的标签:对于每个样本,我们提供了 ** 所有 10 个分类预测出的可能性 **(10x10 矩阵)。因此我们将使用 argmax() 找出最高值的 ** 指数 **,即 ** 最可能的分类 **。所以 pred\_label 实际上是一个 10 元素数组,其中保存了当前批中每个数据样本预测出的分类。
{1}
随后我们需要使用 Numpy.sum() 将 label 和 pred\_label 中相等值的数量进行对比。
最后需要计算并显示验证准确度。
> 验证准确度:0.09
什么?只有 9%?** 真是太悲催了 **!如果你希望证明我们的数据集真的是随机的,那么你有证据了!
底线在于,我们确实可以通过训练神经网络学习 ** 任何东西 **,但如果数据本身是 ** 无意义的 **(例如我们本例中使用的数据),那么就什么都预测不出来。** 种瓜得瓜,种豆得豆 **
如果你已经读到这里,我猜你是真心希望看到本例的完整代码 ;) 请花些时间用你自己的数据进行验证,这才是学习的最佳方法。
代码已发布至 GitHub:[mxnet\_example1.py](https://gist.github.com/juliensimon/7cfef0423b0183e891774a289e156b49#file-mxnet_example1-py)。
## 后续内容:
- 第 4 篇:使用预训练模型进行图片分类(Inception v3)
- 第 5 篇:进一步了解预训练模型(VGG16 和 ResNet-152)
- 第 6 篇:通过树莓派进行实时物体检测(并让它讲话!)
** 作者 **:[Julien Simon](https://medium.com/@julsimon),** 阅读英文原文 **:[An introduction to the MXNet API?—?part 3](https://medium.com/@julsimon/an-introduction-to-the-mxnet-api-part-3-1803112ba3a8)
- - - - - -
感谢 [杜小芳](http://www.infoq.com/cn/author/%E6%9D%9C%E5%B0%8F%E8%8A%B3) 对本文的审校。
给 InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 [editors@cn.infoq.com](mailto:editors@cn.infoq.com)。也欢迎大家通过新浪微博([@InfoQ](http://www.weibo.com/infoqchina),[@丁晓昀](http://weibo.com/u/1451714913)),微信(微信号:[InfoQChina](http://www.geekbang.org/ivtw))关注我们。
2017-07-13 17:394012
用户头像

发布了 283 篇内容, 共 107.7 次阅读, 收获喜欢 62 次。

关注

评论

发布
暂无评论
发现更多内容

恒源云(GPUSHARE)_[SimCSE]:对比学习,只需要 Dropout?

恒源云

深度学习

大凉山的新衣,产业AI的未来

脑极体

如何把 MySQL 备份验证性能提升 10 倍

Juicedata

MySQL 数据库 云存储 数据备份

模块三作业

cqyanbo

[架构实战营] 模块八作业

张祥

架构实战营

模块八

侠客行

「架构实战营」

你只会用 split?试试 StringTokenizer,性能可以快 4 倍!!

CRMEB

百度智能云开物秀出全年成绩,发布和升级五大新产品

百度大脑

人工智能 百度

2021年末28天写作营总结

mtfelix

28天写作

1.6(下周四)直播 | 观测云实践学堂03期:K8S太复杂,可观测实践一筹莫展?全新K8S实践干货直播间等你!

观测云

直播

NFG定期赚币专场在虎符开启 APY高达1200%

区块链前沿News

Hoo虎符 虎符交易所

怎么借助Camtasia制作回忆录

淋雨

Camtasia 录屏 luping

网络安全审计之CMS代码审计

网络安全学海

黑客 网络安全 信息安全 渗透测试 代码审计

Presto 在字节跳动的内部实践与优化(优化篇)

字节跳动数据平台

大数据 字节跳动 presto

在字节,大规模埋点数据治理这么做!

字节跳动数据平台

大数据 字节跳动 埋点 流量 埋点治理

物业资产管理系统解决方案

低代码小观

低代码 企业管理 资产管理 CRM CRM系统

Spring框架基础知识(02)

海拥(haiyong.site)

28天写作 12月日更

Presto 在字节跳动的内部实践与优化(实践篇)

字节跳动数据平台

大数据 字节跳动 presto

LabVIEW图像灰度分析与变换(基础篇—4)

不脱发的程序猿

机器视觉 图像处理 LabVIEW 图像灰度分析与变换

直播整理 | TDengine 技术内幕分享:兼容 OpenTSDB

TDengine

数据库 tdengine OpenTSDB

Apache APISIX Dashboard 未授权访问漏洞公告(CVE-2021-45232)

API7.ai 技术团队

漏洞修复 CVE Apache APISIX

2021,用「创新」重新定义ToB

ToB行业头条

面试被问spring ioc,这样说让面试官眼前一亮(1)

公众号:程序猿成神之路

spring 5

数字中国建设再提速,智慧金融发展如何跑出“加速度”?

百度大脑

人工智能 数字化 智能化

Java 数据持久化系列之JDBC

程序员历小冰

数据库 持久化 28天写作 12月日更

2022 年第一场云原生技术实践营开启报名

阿里巴巴云原生

阿里云 云原生 线下活动 布道师 实践营

httprouter源码刨析

王博

云原生 Serverless Database 使用体验

阿里巴巴云原生

阿里云 Serverless 云原生 弹性 表格存储

【安全漏洞】利用CodeQL分析并挖掘Log4j漏洞

H

网络安全 信息安全 漏洞

28天写作总结

wood

28天写作

一套架构框架如何满足流批数据质量监控

字节跳动数据平台

大数据 字节跳动 数据质量

MXNet API入门 —第3篇_语言 & 开发_Julien Simon_InfoQ精选文章