写点什么

AWS 与微软合作发布 Gluon API 可快速构建机器学习模型

  • 2017-10-18
  • 本文字数:1422 字

    阅读完需:约 5 分钟

2017 年 10 月 12 日, AWS 与微软合作发布了 Gluon 开源项目,该项目旨在帮助开发者更加简单快速的构建机器学习模型,同时保留了较好的性能。

根据 Gluon 项目官方 Github 页面上的描述,Gluon API 支持任意一种深度学习框架,其相关规范已经在 Apache MXNet 项目中实施,开发者只需安装最新版本的 MXNet(master)即可体验。AWS 用户可以创建一个AWS Deep Learning AMI 进行体验。

该页面提供了一段简易使用说明,摘录如下:

本教程以一个两层神经网络的构建和训练为例,我们将它称呼为多层感知机(multilayer perceptron)。(本示范建议使用Python 3.3 或以上,并且使用 Jupyter notebook 来运行。详细教程可参考这个页面。)

首先,进行如下引用声明:

复制代码
import mxnet as mx
from mxnet import gluon, autograd, ndarray
import numpy as np

然后,使用gluon.data.DataLoader承载训练数据和测试数据。这个 DataLoader 是一个 iterator 对象类,非常适合处理规模较大的数据集。

复制代码
train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=lambda data, label: (data.astype(np.float32)/255, label)),
batch_size=32, shuffle=True)
test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=lambda data, label: (data.astype(np.float32)/255, label)),
batch_size=32, shuffle=False)

接下来,定义神经网络:

复制代码
# 先把模型做个初始化
net = gluon.nn.Sequential()
# 然后定义模型架构
with net.name_scope():
net.add(gluon.nn.Dense(128, activation="relu")) # 第一层设置 128 个节点
net.add(gluon.nn.Dense(64, activation="relu")) # 第二层设置 64 个节点
net.add(gluon.nn.Dense(10)) # 输出层

然后把模型的参数设置一下:

复制代码
# 先随机设置模型参数
# 数值从一个标准差为 0.05 正态分布曲线里面取
net.collect_params().initialize(mx.init.Normal(sigma=0.05))
# 使用 softmax cross entropy loss 算法
# 计算模型的预测能力
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
# 使用随机梯度下降算法 (sgd) 进行训练
# 并且将学习率的超参数设置为 .1
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})

之后就可以开始跑训练了,一共分四个步骤。一、把数据放进去;二、在神经网络模型算出输出之后,比较其与实际结果的差距;三、用 Gluon 的autograd计算模型各参数对此差距的影响;四、用 Gluon 的trainer方法优化这些参数以降低差距。以下我们先让它跑 10 轮的训练:

复制代码
epochs = 10
for e in range(epochs):
for i, (data, label) in enumerate(train_data):
data = data.as_in_context(mx.cpu()).reshape((-1, 784))
label = label.as_in_context(mx.cpu())
with autograd.record(): # Start recording the derivatives
output = net(data) # the forward iteration
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(data.shape[0])
# Provide stats on the improvement of the model over each epoch
curr_loss = ndarray.mean(loss).asscalar()
print("Epoch {}. Current Loss: {}.".format(e, curr_loss))

若想了解更多 Gluon 说明与用法,可以查看 gluon.mxnet.io 这个网站。

2017-10-18 20:241927

评论

发布
暂无评论
发现更多内容

2024南京国际智能机器人展览会

AIOTE智博会

机器人展 智能机器人展

拓展AI边界:去中心化人工智能的应用场景和主要项目盘点

区块链软件开发推广运营

dapp开发 区块链开发 链游开发 NFT开发 公链开发

一文熟悉PolarDB-PG 分区表核心特性

阿里云数据库开源

数据库 阿里云 polarDB PolarDB-PG

软件测试学习笔记丨Allure2 失败重试功能应用场景

测试人

软件测试

Web3 游戏周报(3.17-3.23)

Footprint Analytics

Web3 游戏

不给灰暗留下死角:华为应用市场的安全之光

脑极体

应用

宁德时代与特斯拉合作;钟睒睒连续四次中国首富丨 RTE 开发者日报 Vol.171

声网

深圳站回顾|隐语最新功能、隐私计算硬核技术、数据要素实践干货全记录(附演讲视频)

隐语SecretFlow

Netflix微服务经验教训

俞凡

微服务 最佳实践 netflix 大厂实践

AI时代来临我们要如何面对?

小魏写代码

青亦学爬虫:根据淘宝天猫商品链接封装淘宝天猫商品详情数据接口

tbapi

淘宝API接口 淘宝商品详情接口 天猫商品详情接口 淘宝数据爬虫 天猫数据爬虫

Solana链狙击机器人:交易者的新宠

开发丨飞机丨 @aivenli

网心科技入选2023中国ToB行业影响力价值榜

网心科技

从数据存储的演迁,看芯赛云分布式存储应用

科技热闻

SpringBoot如何优雅的进行参数校验

不在线第一只蜗牛

Java 后端 springboot

什么样的商品管理系统可以驱动品牌增长?

第七在线

iOS开发优势解析,费用探究以及软件开发详解

华为云亮相KubeCon EU 2024,以持续开源创新开启智能时代

华为云开发者联盟

开源 开发 华为云 华为云开发者联盟

智达方通全面预算管理系统,为企业带来更可靠的交付

智达方通

全面预算管理 全面预算管理系统

容器镜像加速指南:探索 Kubernetes 缓存最佳实践

不在线第一只蜗牛

Kubernetes 容器化 集群

深入解析以太坊Dencun升级:提升网络性能与安全的关键举措

区块链软件开发推广运营

dapp开发 区块链开发 链游开发 NFT开发 公链开发

那位拿了多个Offer的大佬分享了最新Go面经

王中阳Go

Go 后端 Go 面试题 面经 后端 大厂

u-blox 面向多个大众应用市场推出最新 Wi-Fi 6 模块NORA-W4

科技之家

Databend 开源周报第 137 期

Databend

利用Python和数据获取技术实现智能旅游情报系统

阿Q说代码

Python 后端 数据获取

AWS与微软合作发布Gluon API 可快速构建机器学习模型_微软_sai_InfoQ精选文章