写点什么

Deep Java Library (DJL) 简介:与引擎无关的 Java 深度学习框架

  • 2020-01-13
  • 本文字数:8821 字

    阅读完需:约 29 分钟

Deep Java Library (DJL) 简介:与引擎无关的Java深度学习框架

本文要点

  • 开发人员可以使用 Java 和他们喜欢的 IDE 来构建、训练和部署机器学习(ML)和深度学习(DL)模型

  • DJL 简化了深度学习(DL)框架的使用,目前支持 Apache MXNet

  • DJL 的开源对于工具包及其用户来说都是互惠互利的

  • DJL 是引擎无关的,这意味着开发人员只需编写一次代码就可以在任何引擎上运行

  • 在尝试使用 DJL 之前,Java 开发人员应该了解 ML 生命周期和常用的 ML 术语


亚马逊(Amazon)的 DJL(Deep Java Library )是一个深度学习工具包,使用它可在 Java 中原生地进行机器学习(ML)和深度学习(DL)模型开发,从而简化深度学习框架的使用。DJL 是在 2019 年 re:Invent 大会上开源的工具包,它提供了一组高级 API 来训练、测试和运行在线推理(inference)。Java 开发人员可以开发自己的模型,也可以在他们的 Java 代码中使用数据科学家用 Python 开发的预先训练的模型。


DJL 秉承了 Java 的座右铭,“编写一次,到处运行(WORA)”,因为它是引擎和深度学习框架无关的。开发人员只需编写一次就可在任何引擎上运行。DJL 目前提供了一个 Apache MXNet 的实现,这是一个可以简化深度神经网络开发的 ML 引擎。DJL API 使用 JNA(Java Native Access)来调用相应的 Apache MXNet 操作。DJL 编排管理基础设施,基于硬件配置来提供自动的 CPU/GPU 检测,以确保良好的运行效果。


DJL API 通过抽象常用的功能来开发模型,这使 Java 开发人员能够利用现有的知识,从而可以轻松地过渡到 ML。为了了解 DJL 的实际效果,我们开发一个“鞋”的分类模型作为一个简单的示例。

机器学习生命周期

我们建立“鞋”分类模型遵循了机器学习的生命周期。ML 生命周期与传统的软件开发生命周期有所不同,它包含六个具体的步骤:


  1. 获取数据

  2. 清洗并准备数据

  3. 生成模型

  4. 评估模型

  5. 部署模型

  6. 从模型中获得预测(或推理)


生命周期的最终结果是一个可以查询并返回答案(或预测)的机器学习模型。



模型只是数据中趋势和模式的数学表示。好的数据才是所有 ML 项目的基础。


在步骤 1 中,从可靠的来源中获取数据。在步骤 2 中,数据被清洗、转换并以机器可以学习的格式存储。清洗和转换过程通常是机器学习生命周期中最耗时的部分。DJL 提供了利用翻译器(translator)来对图像进行预处理的能力,这能为开发人员简化清洗和转换过程。翻译器可以执行一些图像任务,比如,可以根据预设参数调整图像的大小或将图像从彩色图转换为灰度图。


刚刚过渡向机器学习的开发人员常常会低估清洗和转换数据所需的时间,因此翻译器是快速启动该过程的好方法。步骤 3,在训练过程中,一个机器学习算法会对数据进行多遍(或多代)处理,不断研究它们,以试图学习到不同类型的“鞋”。训练过程中发现的与“鞋”相关的趋势和模式会被存储在模型中。当需要评估模型以确定其在识别“鞋”方面的能力时,第 4 步会作为训练的一部分;如果发现了错误,则予以纠正。在步骤 5 中,将模型部署到生产环境中。模型投入生产后,步骤 6 允许其他系统使用该模型。


通常,可以在代码中动态地加载模型,或者通过基于 REST 的 HTTPS 端点访问模型。

数据

“鞋”分类模型是一个多级分类计算机视觉(CV)模型,它使用有监督学习进行训练,可以将“鞋”分为四类:靴子(boots)、凉鞋(sandals)、鞋子(shoes)或拖鞋(slippers)。有监督学习必须包含已经标记了我们想要预测的目标(或答案)的数据;这就是机器学习的方式。


“鞋”分类模型的数据源是德克萨斯大学奥斯汀分校(The University of Texas at Austin)提供的 UTZappos50k 数据集(dataset),它可免费用于学术和非商业用途。下面这个“鞋子”数据集包含了从 Zappos.com 收集的 50025 张带标签的目录图像。



“鞋”数据保存在本地,并使用 DJL 的 ImageFolder 数据集对其进行加载,该数据集可以从本地文件夹中检索图像。


// 识别训练数据的位置String trainingDatasetRoot = "src/test/resources/imagefolder/train";
// 识别验证数据的位置String validateDatasetRoot = "src/test/resources/imagefolder/validate";
// 创建训练数据 ImageFolder 数据集ImageFolder trainingDataset = initDataset(trainingDatasetRoot);
//创建验证数据 ImageFolder 数据集ImageFolder validateDataset = initDataset(validateDatasetRoot);
复制代码


在本地构造数据时,我并没有深入到 UTZappos50k 数据集所标识的最细粒度的分类等级,比如到脚踝的、膝盖等高的、到达小腿中部的、过膝的等靴子的最细粒度等级的分类标签。我的本地数据使用的是最高等级的分类,仅包括靴子、凉鞋、鞋子和拖鞋等四类。



在 DJL 术语中,数据集只用于保存训练数据。有些数据集的实现可用于下载数据(基于我们提供的 URL)、提取数据、以及自动地将数据分为训练集和验证集。


自动分离是一个特别有用的特性,因为不使用相同的数据来训练和验证模型这一点是至关重要的。该模型所使用的训练数据集用于查找“鞋”数据中的趋势和模式。验证数据集通过提供对“鞋”分类模型精度无偏差的估计来检验模型的效果。


如果用训练的数据验证模型,则会降低我们对模型分类鞋子能力的信心,因为模型是用它已经看到的数据进行测试的。在现实世界中,老师也不会使用和学习指南上完全相同的题目来测试学生,因为这不能衡量一个学生的真实知识或对资料的理解;当然,同样的概念也适用于机器学习模型。

训练

现在我们已经将“鞋”数据分为训练集和验证集,下面我们将使用神经网络来训练(或生成)模型。


public final class Training extends AbstractTraining {
. . .
@Override protected void train(Arguments arguments) throws IOException {
// 识别训练数据的位置 String trainingDatasetRoot = "src/test/resources/imagefolder/train";
// 识别验证数据的位置 String validateDatasetRoot = "src/test/resources/imagefolder/validate";
//创建训练数据 ImageFolder 数据集 ImageFolder trainingDataset = initDataset(trainingDatasetRoot);
//创建验证数据 ImageFolder 数据集 ImageFolder validateDataset = initDataset(validateDatasetRoot);
. . . try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) { TrainingConfig config = setupTrainingConfig(loss);
try (Trainer trainer = model.newTrainer(config)) { trainer.setMetrics(metrics);
trainer.setTrainingListener(this);
Shape inputShape = new Shape(1, 3, NEW_HEIGHT, NEW_WIDTH);
// 根据相应输入的形状初始化训练器 trainer.initialize(inputShape);
//在数据中查找模式 fit(trainer, trainingDataset, validateDataset, "build/logs/training");
//设置模型属性 model.setProperty("Epoch", String.valueOf(EPOCHS)); model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));
// 训练完成后保存模型,为后面的推理做准备 //模型保存为 shoeclassifier-0000.params model.save(Paths.get(modelParamsPath), modelParamsName); } } }
}
复制代码


第一步是通过调用 Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) 来获取模型实例。深度学习是机器学习的一种形式,它使用神经网络来训练模型。神经网络是以人脑中的神经元来进行建模的;神经元是可以将信息(或数据)传递给其他细胞的细胞。


ResNet-50 是一种常用于图像分类的神经网络,50 表示从初始输入数据和最终预测之间有 50 个学习层(或神经元)。getModel() 方法用于创建一个空模型,构造一个 ResNet-50 神经网络,并将神经网络设置到该模型中。


public class Models {   public static ai.djl.Model getModel(int numOfOutput, int height, int width) {       //创建一个空模型的新实例       ai.djl.Model model = ai.djl.Model.newInstance();
//是构建神经网络所需的可组合单元;可以像像乐高积木一样将它们连结在一起, //形成一个复杂的网络 Block resNet50 = //构建网络 new ResNetV1.Builder() .setImageShape(new Shape(3, height, width)) .setNumLayers(50) .setOutSize(numOfOutput) .build();
//将神经网络设置到模型中 model.setBlock(resNet50); return model; }}
复制代码


下一步是通过调用 model.newTrainer(config) 方法来设置和配置训练器。通过调用 setupTrainingConfig(loss) 方法来初始化配置对象,该方法通过设置训练的配置(或超参)来决定如何训练网络。


接下来的步骤使我们可以通过设置以下内容来向 Trainer 中添加功能:


  • 使用 trainer.setMetrics(metrics) 来设置 Metrics

  • 使用 trainer.setTrainingListener(this) 来设置训练监听器

  • 使用 trainer.initialize(inputShape) 来设置合适的输入形状


Metrics 在训练期间收集并报告关键绩效指标(KPI),该 KPI 可用于分析和监控训练的效果和稳定性。下一步是通过调用 fit(trainer, trainingDataset, validateDataset, “build/logs/training”) 方法来启动训练过程,该方法将迭代训练数据并存储在模型中找到的模式。训练结束时,使用 model.save(Paths.get(modelParamsPath) 方法将一个表现良好的、经过验证的模型工件及属性保存在本地。


训练过程中报告的度量指标如下所示。注意,随着每代(epoch)(或每遍(pass))的递增,模型的精度都会提高;第 9 代(epoch)的最终训练精度为 90%。


推理

现在我们已经生成了模型,它可以用于对我们不知道类型(或目标)的新数据执行推理(或预测)。


private Classifications predict() throws IOException, ModelException, TranslateException  {   //在训练期间保存到模型的位置   String modelParamsPath = "build/logs";
//训练时设置的模型名称 String modelParamsName = "shoeclassifier";
//需要分类的图像路径 String imageFilePath = "src/test/resources/slippers.jpg";
//从路径加载图像文件 BufferedImage img = BufferedImageUtils.fromFile(Paths.get(imageFilePath));
//持有每个标签的概率分数 Classifications predictResult;
try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) { //加载模型 model.load(Paths.get(modelParamsPath), modelParamsName);
//定义用于预处理和后置处理的翻译器 Translator<BufferedImage, Classifications> translator = new MyTranslator();
//使用预测器运行推理 try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator)) { predictResult = predictor.predict(img); } }
return predictResult;}
复制代码


在设置了模型和要分类的图像的必要路径之后,使用 Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) 方法获取一个空模型实例,并使用 model.load(Paths.get(modelParamsPath), modelParamsName) 方法对其进行初始化。它将会加载上一步训练的模型。


接下来,使用 model.newPredictor(translator) 方法初始化一个带有指定的 Translator 的 Predictor。在 DJL 术语中,Translator 提供了模型预处理和置后处理的能力。例如,对于 CV 模型,需要将图像重塑为灰度图;Translator 是可以做到的。Predictor 使我们可以利用 predictor.predict(img) 方法来对加载的 Model 进行推理,并传入图像进行分类。


这个示例展示的是单个的预测,但是 DJL 也支持批量预测。推理存储在 predictResult 中,predictResult 包含了每个标签的概率估计。


推理(每张图片)及其对应的概率得分如下所示。






(表格对应的图片如上所示)


图像概率得分
如图1[信息] - [                 分类: “0”, 概率: 0.98985                 分类: “1”, 概率: 0.00225                 分类: “2”, 概率: 0.00224                 分类: “3”, 概率: 0.00564             ] 分类0 代表靴子,概率得分为 98.98%
图2[信息] - [                分类: “0”, 概率: 0.02111                分类: “1”, 概率: 0.76524                分类: “2”, 概率: 0.01159                分类: “3”, 概率: 0.20204           ] 分类1 代表凉鞋,概率得分为 o76.52%
图3[信息] - [                分类: “0”, 概率: 0.05523                分类: “1”, 概率: 0.01417                分类: “2”, 概率: 0.87900                分类: “3”, 概率: 0.05158               ] 分类2 代表鞋子,概率得分为 87.90%
图4[信息] - [                 分类: “0”, 概率: 0.00003                 分类: “1”, 概率: 0.01133                分类: “2”, 概率: 0.00179                 分类: “3”, 概率: 0.98682               ] 分类3 代表拖鞋,概率得分为of 98.68%.


DJL 提供了与其他 Java 库一样的原生 Java 开发体验和功能。设计这些 API 是为了指导开发人员能够用最佳实践来完成深度学习任务。在开始使用 DJL 之前,需要对 ML 生命周期有一个很好的理解。如果您是 ML 初学者,请先阅读这篇概述或 InfoQ 的系列文章《软件开发人员机器学习入门》。在理解了生命周期和常见的 ML 术语之后,开发人员就可以快速地掌握 DJL 的 API 了。


亚马逊已经开源了 DJL,有关该工具包的更多详细信息可以在 DJL 网站Java 库 API 规范(Java Library API Specification) 页面上找到。您也可以回顾下“鞋”分类模型的代码,以进一步探索该示例。

作者介绍

Kesha Williams 是一位屡获殊荣的软件工程师、机器学习实践者和 A Cloud Guru 的技术讲师,拥有 24 年的经验。在大学任教期间,她曾培训并指导了数千名来自美国、欧洲和亚洲的 Java 软件工程师。她经常带领创新团队验证新兴技术,并在全球各地的会议上分享她的经验教训。作为 TED 的 Spotlight Presentation Academy 的获得者,她在 TED 舞台上做过机器学习的演讲。此外,她在人工智能领域的开创性工作为她赢得了亚马逊的 Alexa Champion 和 AWS Machine Learning Hero 的殊荣。在业余时间,她通过在线社交专业网络平台 Colors of STEM 指导女性科技从业者。


原文链接:


Getting to Know Deep Java Library (DJL)


2020-01-13 09:157176

评论

发布
暂无评论
发现更多内容

EMQX企业版正式入驻华为云云商城,成为华为云联营联运合作伙伴

EMQ映云科技

物联网 IoT 华为云 云端 企业号 1 月 PK 榜

全景剖析阿里云容器网络数据链路(二):Terway EN

阿里巴巴云原生

阿里云 容器 云原生

Hive查询语句

mm

漏洞优先级排序的六大关键因素

SEAL安全

安全 漏洞 企业号 1 月 PK 榜 优先级排序

PolarDB for PostgreSQL 14 开源实战训练营免费报名中!

阿里云数据库开源

数据库 阿里云 开源 postgre PolarDB for PostgreSQL

如何定义算法?10分钟带你弄懂算法的基本概念

九章云极DataCanvas

机器学习 机器学习算法

送给SQL开发者的一份新年礼物!一款100%自主研发的纯Web化SQL开发工具——SQL Studio 1.0正式发布

雨果

sql 数据库管理工具 SQL开发工具

让开源和标准成为云原生的确定性力量

阿里巴巴云原生

阿里云 开源 云原生

不懂任务调度系统,快来看这篇

华为云开发者联盟

后端 开发 华为云 企业号 1 月 PK 榜

Apache Spark + 海豚调度:PB 级数据调度挑战,教你如何构建高效离线工作流

白鲸开源

海豚调度 Apache Spark 大数据 开源

软件测试/测试开发 | 跨平台设备管理方案 Selenium Grid

测试人

软件测试 自动化测试 测试开发 selenium Grid

Spring Boot 3.0横空出世,快来看看是不是该升级了

程序那些事

Java spring 程序那些事 spring boot3

KubeVela 获得 2022 “开源新锐”和“开发者最喜爱”双料年度项目

阿里巴巴云原生

阿里云 开源 云原生

4个因素会影响LED显示屏的安全防火问题

Dylan

LED显示屏 全彩LED显示屏 led显示屏厂家

聊聊Cookie、Session、Token 背后的故事

华为云开发者联盟

前端 华为云 企业号 1 月 PK 榜

智能流程机器人助你“聚划算”

华为云开发者联盟

人工智能 机器人 华为云 企业号 1 月 PK 榜

使用服务网格提升应用和网络安全

HummerCloud

服务网格 云原生安全

从一个Demo说起Dubbo3

宋小生

dubbo RPC Dubbo3

市面上数一数二的双机热备系统当属Skybility HA!

行云管家

高可用 厂商 双机热备 双机热备系统

一站式云原生体验|龙蜥云原生ACNS + Rainbond

北京好雨科技有限公司

Kubernetes 云原生

DNS 代理?Pipy:这我也可以

Flomesh

Pipy 可编程代理 流量管理

Payso×OceanBase:云上拓新,开启云数据库的智能托管

OceanBase 数据库

数据库 oceanbase

【Redis 技术探索】「数据迁移实战」手把手教你如何实现在线 + 离线模式进行迁移 Redis 数据实战指南(scan模式迁移)

码界西柚

redis 数据同步 1月日更 RedisShake

使用无代码构建移动应用程序

间隔

基于单机最高能效270亿参数GPT模型的文本生成与理解

阿里云大数据AI技术

自然语言处理 机器学习 GPT 企业号 1 月 PK 榜

软件测试 | 测试开发| 跨平台设备管理方案Selenium Grid

测吧(北京)科技有限公司

成功上岸字节全靠这份Redis技术笔记,深入浅出值得一看

小小怪下士

Java redis 程序员 面试 字节

年度重磅!《2022华为开发者宝典》免费下载

华为云开发者联盟

开源 华为云 鲲鹏 昇腾 企业号 1 月 PK 榜

解决Redis缓存穿透/击穿/雪崩以及数据一致性的方案

风铃架构日知录

Java redis 缓存穿透 缓存雪崩 数据一致性

桌面云是什么?有什么优势?桌面云是云桌面吗?

行云管家

云计算 桌面云 云桌面

不会还有人不知道吧?BOM上的器件也能在PCB上快速定位啦!(内附高效手焊攻略)

华秋PCB

工艺 PCB PCB设计 焊接 PCB工艺

Deep Java Library (DJL) 简介:与引擎无关的Java深度学习框架_AI&大模型_Kesha Williams_InfoQ精选文章