写点什么

AWS 与微软合作发布 Gluon API 可快速构建机器学习模型

  • 2017-10-18
  • 本文字数:1422 字

    阅读完需:约 5 分钟

2017 年 10 月 12 日, AWS 与微软合作发布了 Gluon 开源项目,该项目旨在帮助开发者更加简单快速的构建机器学习模型,同时保留了较好的性能。

根据 Gluon 项目官方 Github 页面上的描述,Gluon API 支持任意一种深度学习框架,其相关规范已经在 Apache MXNet 项目中实施,开发者只需安装最新版本的 MXNet(master)即可体验。AWS 用户可以创建一个AWS Deep Learning AMI 进行体验。

该页面提供了一段简易使用说明,摘录如下:

本教程以一个两层神经网络的构建和训练为例,我们将它称呼为多层感知机(multilayer perceptron)。(本示范建议使用Python 3.3 或以上,并且使用 Jupyter notebook 来运行。详细教程可参考这个页面。)

首先,进行如下引用声明:

复制代码
import mxnet as mx
from mxnet import gluon, autograd, ndarray
import numpy as np

然后,使用gluon.data.DataLoader承载训练数据和测试数据。这个 DataLoader 是一个 iterator 对象类,非常适合处理规模较大的数据集。

复制代码
train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=lambda data, label: (data.astype(np.float32)/255, label)),
batch_size=32, shuffle=True)
test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=lambda data, label: (data.astype(np.float32)/255, label)),
batch_size=32, shuffle=False)

接下来,定义神经网络:

复制代码
# 先把模型做个初始化
net = gluon.nn.Sequential()
# 然后定义模型架构
with net.name_scope():
net.add(gluon.nn.Dense(128, activation="relu")) # 第一层设置 128 个节点
net.add(gluon.nn.Dense(64, activation="relu")) # 第二层设置 64 个节点
net.add(gluon.nn.Dense(10)) # 输出层

然后把模型的参数设置一下:

复制代码
# 先随机设置模型参数
# 数值从一个标准差为 0.05 正态分布曲线里面取
net.collect_params().initialize(mx.init.Normal(sigma=0.05))
# 使用 softmax cross entropy loss 算法
# 计算模型的预测能力
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
# 使用随机梯度下降算法 (sgd) 进行训练
# 并且将学习率的超参数设置为 .1
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})

之后就可以开始跑训练了,一共分四个步骤。一、把数据放进去;二、在神经网络模型算出输出之后,比较其与实际结果的差距;三、用 Gluon 的autograd计算模型各参数对此差距的影响;四、用 Gluon 的trainer方法优化这些参数以降低差距。以下我们先让它跑 10 轮的训练:

复制代码
epochs = 10
for e in range(epochs):
for i, (data, label) in enumerate(train_data):
data = data.as_in_context(mx.cpu()).reshape((-1, 784))
label = label.as_in_context(mx.cpu())
with autograd.record(): # Start recording the derivatives
output = net(data) # the forward iteration
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(data.shape[0])
# Provide stats on the improvement of the model over each epoch
curr_loss = ndarray.mean(loss).asscalar()
print("Epoch {}. Current Loss: {}.".format(e, curr_loss))

若想了解更多 Gluon 说明与用法,可以查看 gluon.mxnet.io 这个网站。

2017-10-18 20:241751

评论

发布
暂无评论
发现更多内容

冲击“金九银十”的利器!《Java权威面试指南(阿里版)》人手一份吊打面试官轻轻松松!

Java 编程 IT 计算机 知识分享

【Takin应用日记】记一次TransmittableThreadLocal引起的业务异常

TakinTalks稳定性社区

高可用 性能压测 生产环境全链路压测 takin

fil挖矿官网有哪些?fil挖矿平台有哪些?

fil挖矿平台有哪些 fil挖矿官网有哪些

极客星球 | 应用开发的性能优化探索

MobTech袤博科技

性能

如何快速定位程序Core?

百度Geek说

Linux 后端

GitHub星标63K霸榜半月!阿里大牛的微服务分布式架构笔记已上线

Java 编程 IT 计算机 知识

架构训练营 - 模块四 - 作业

姑射仙人

架构训练营

滴滴架构师被迫离职后,只留下这份731页Java程序性能优化手册

Java 编程 架构 面试 调优

员工流动大难管理?织信低代码+人事管理系统轻松掌控员工档案信息

优秀

低代码

肺炎在家“闭关”,阿里竟发来视频面试,4面顺利拿下offer

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

中国如何应对中美博弈?

石云升

学习 贸易战 8月日更

【共识专栏】HotStuff共识

趣链科技

区块链 共识机制 拜占庭容错 共识算法

小心这个陷阱:为什么总是你赔钱?

非著名程序员

认知提升 个人提升 投资理财 8月日更

博睿数据分布式手机真机监测+两大核心技术,轻松掌控短信服务质量与用户体验

博睿数据

如何在多云环境中建立信任

云计算

云服务器在市场变化下的技术突破,企业运维中的基础保障

九河云安全

一个完整的内网渗透是什么样子的

网络安全学海

网络安全 信息安全 网络 渗透测试 漏洞分析

深耕城市治理场景,百度智能云联合慧联无限推内涝智能检测预警

百度大脑

人工智能 洪水

通俗易懂的ReentrantLock,不懂你来砍我

程序猿阿星

AQS 公平锁 非公平锁 独占锁 ReentrantLock;

去中心化DeFi系统开发

Geek_23f0c3

智能合约 DeFi去中心化系统开发 DAPP智能合约交易系统开发

极客星球 | Android SDK架构设计之路

MobTech袤博科技

架构 sdk andiod

如何实现H.264的实时传输?

拍乐云Pano

Aosp 之 Property

Qunar技术沙龙

android API properties 字典树 内存映射

iOS 开发技术栈与进阶

iOSer

ios 面试 iOS 知识体系 iOS技术栈

TCP 四次挥手

W🌥

计算机网络 TCP/IP 8月日更

DCS_FunTester分布式压测框架更新(二)

FunTester

分布式 性能测试 测试框架 测试开发 FunTester

啃了三个月!靠着这份大厂Java面试全秘籍,成功入职京东,税前30K

Java 程序员 架构 面试 计算机

嗨!你有一封来自百度世界大会的“情书”,818等你开启

百度大脑

人工智能

微博SDK初始化问题 please init sdk before use it. Wb.install()

mengxn

微博sdk

远程办公一星期,竟等来了阿里新零售视频面(Java岗,已过2面)

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

运维工程师核心工作是什么?用什么运维工具好?

行云管家

云计算 运维 IT运维

AWS与微软合作发布Gluon API 可快速构建机器学习模型_微软_sai_InfoQ精选文章