写点什么

AWS 与微软合作发布 Gluon API 可快速构建机器学习模型

  • 2017-10-18
  • 本文字数:1422 字

    阅读完需:约 5 分钟

2017 年 10 月 12 日, AWS 与微软合作发布了 Gluon 开源项目,该项目旨在帮助开发者更加简单快速的构建机器学习模型,同时保留了较好的性能。

根据 Gluon 项目官方 Github 页面上的描述,Gluon API 支持任意一种深度学习框架,其相关规范已经在 Apache MXNet 项目中实施,开发者只需安装最新版本的 MXNet(master)即可体验。AWS 用户可以创建一个AWS Deep Learning AMI 进行体验。

该页面提供了一段简易使用说明,摘录如下:

本教程以一个两层神经网络的构建和训练为例,我们将它称呼为多层感知机(multilayer perceptron)。(本示范建议使用Python 3.3 或以上,并且使用 Jupyter notebook 来运行。详细教程可参考这个页面。)

首先,进行如下引用声明:

复制代码
import mxnet as mx
from mxnet import gluon, autograd, ndarray
import numpy as np

然后,使用gluon.data.DataLoader承载训练数据和测试数据。这个 DataLoader 是一个 iterator 对象类,非常适合处理规模较大的数据集。

复制代码
train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=lambda data, label: (data.astype(np.float32)/255, label)),
batch_size=32, shuffle=True)
test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=lambda data, label: (data.astype(np.float32)/255, label)),
batch_size=32, shuffle=False)

接下来,定义神经网络:

复制代码
# 先把模型做个初始化
net = gluon.nn.Sequential()
# 然后定义模型架构
with net.name_scope():
net.add(gluon.nn.Dense(128, activation="relu")) # 第一层设置 128 个节点
net.add(gluon.nn.Dense(64, activation="relu")) # 第二层设置 64 个节点
net.add(gluon.nn.Dense(10)) # 输出层

然后把模型的参数设置一下:

复制代码
# 先随机设置模型参数
# 数值从一个标准差为 0.05 正态分布曲线里面取
net.collect_params().initialize(mx.init.Normal(sigma=0.05))
# 使用 softmax cross entropy loss 算法
# 计算模型的预测能力
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
# 使用随机梯度下降算法 (sgd) 进行训练
# 并且将学习率的超参数设置为 .1
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})

之后就可以开始跑训练了,一共分四个步骤。一、把数据放进去;二、在神经网络模型算出输出之后,比较其与实际结果的差距;三、用 Gluon 的autograd计算模型各参数对此差距的影响;四、用 Gluon 的trainer方法优化这些参数以降低差距。以下我们先让它跑 10 轮的训练:

复制代码
epochs = 10
for e in range(epochs):
for i, (data, label) in enumerate(train_data):
data = data.as_in_context(mx.cpu()).reshape((-1, 784))
label = label.as_in_context(mx.cpu())
with autograd.record(): # Start recording the derivatives
output = net(data) # the forward iteration
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(data.shape[0])
# Provide stats on the improvement of the model over each epoch
curr_loss = ndarray.mean(loss).asscalar()
print("Epoch {}. Current Loss: {}.".format(e, curr_loss))

若想了解更多 Gluon 说明与用法,可以查看 gluon.mxnet.io 这个网站。

2017-10-18 20:242200

评论

发布
暂无评论
发现更多内容

protocol buffer没那么难,不信你看这篇

程序那些事

Java protobuf 程序那些事

MySQL 系列教程之(十二)扩展了解 MySQL 的存储过程,视图,触发器

若尘

MySQL 数据库 8月日更

在java程序中使用protobuf

程序那些事

Java protobuf 程序那些事

基于Mybatis-plus实现多租户架构

码农参上

多租户 8月日更 Mybatis-Plus

神策 Android 全埋点插件介绍

神策技术社区

程序员 数据分析 埋点

神策分析 Web JS SDK 功能介绍

神策技术社区

程序员 代码 埋点

保护亿万数据安全,Spring有“声明式事务”绝招

华为云开发者联盟

spring 数据安全 事务管理

前端、后端、测试、研发经理必备技能-ApiPost接口管理工具

CodeNongXiaoW

大前端 测试 后端 接口工具

拿捏!隔离级别、幻读、Gap Lock、Next-Key Lock

艾小仙

MySQL sql 面试 大前端

LeetCode刷题07-简单 整数翻转

ベ布小禅

8月日更

从 FFmpeg 性能加速到端云一体媒体系统优化

阿里云CloudImagine

开源 ffmpeg 视频处理 视频流 视频云

神策分析 iOS SDK 代码埋点解析 | 数据采集

神策技术社区

程序员 数据 代码 埋点

国产数据库的挑战与机遇

晨山资本

数据库 大数据 云原生 超融合

神策分析 Android SDK 网络模块解析

神策技术社区

程序员 代码 神策数据

图文并茂的聊聊ReentrantReadWriteLock的位运算

程序猿阿星

ReentrantReadWriteLock 位运算

LeetCode题解:28. 实现 strStr(),暴力法,JavaScript,详细注释

Lee Chen

算法 大前端 LeetCode

vivo商城计价中心 - 从容应对复杂场景价格计算

vivo互联网技术

Java 架构 后端 促销系统

容器监控薅光了头发?这篇你再也不能错过!

观测云

json Docker 云计算 Linux 容器

2021 年 8 月国产数据库排行榜:秋日胜春朝

墨天轮

数据库 TiDB oceanbase 国产数据库 达梦

架构实战营 模块六作业

孫影

架构实战营 #架构实战营

书单 | 无所不能的Python,从技术到办公,总有一款适合你!

博文视点Broadview

手把手教你写 Gradle 插件 | 数据采集

神策技术社区

程序员 埋点 数据化 神策数据

架构实战营模块一作业

michael

架构实战营

带你认识MRS CDL架构

华为云开发者联盟

数据库 大数据 FusionInsight MRS MRS CDL 实时同步

原来一条select语句在MySQL是这样执行的《死磕MySQL系列 一》

咔咔

MySQL 数据库

SphereEx CEO 张亮:数据库上云是大势所趋|初心·问

SphereEx

数据库 开源

架構實戰營 - 畢業設計

Frank Yang

架构实战营

FL Studio基本功能介绍

懒得勤快

神策分析 iOS SDK 全埋点解析之启动与退出

神策技术社区

ios 代码 埋点 神策数据

揭秘环境管理 Noah 的技术实现

Qunar技术沙龙

测试 Dev QA 环境 资源池

支持 10 亿日流量的基础设施:当 Apahce APISIX 遇上腾讯

API7.ai 技术团队

案例 API网关 APISIX Meetup 腾讯游戏

AWS与微软合作发布Gluon API 可快速构建机器学习模型_微软_sai_InfoQ精选文章