写点什么

Java 机器学习工具箱:Amazon Deep Java Library

作者:Xinyu Liu, Frank Liu 等

  • 2020-06-17
  • 本文字数:4023 字

    阅读完需:约 13 分钟

Java机器学习工具箱:Amazon Deep Java Library

本文要点


  • 目前还没有用 Java 开发机器学习应用程序的标准

  • JSR 381 的提出就是为了填补这项空白

  • Amazon 的 Deep Java Library(DJL)是这个新标准的其中一种实现

  • VisRec 是 JSR 381 的一部分,用于图像的视觉识别

  • DJL 包含许多预训练的模型


近年来,人们对机器学习的兴趣稳步增长。具体来说,现在,企业在各种各样的场景中使用机器学习进行图像识别。它在汽车工业医疗保健安全零售仓库、农场和农业的自动化产品跟踪食品识别,甚至通过手机摄像头进行实时翻译等方面都有应用。借助机器学习和视觉识别,机器可以从 MRI 和 CT 扫描结果中发现癌症COVID-19


如今,这些解决方案主要是用 Python 开发的,使用了开源和专有的 ML 工具包,每个工具包都有自己的 API。尽管Java在企业中很流行,但是 Java 中没有任何标准是针对机器学习应用程序开发的。JSR-381的提出就是为了填补这项空白,它为 Java 应用程序开发人员提供了一套标准的、灵活的、Java 友好的、面向视觉识别(VisRec)应用程序(如图像分类和对象检测)的 API。JSR-381 有几个依赖于 TensorFlow、MXNet 和 DeepNetts 等机器学习平台的实现。其中一个实现是基于Deep Java Library(DJL)的,这是一个由 Amazon 开发的开源库,用于使用 Java 构建机器学习应用。DJL 通过绑定必要的图像处理例程,提供了流行机器学习框架(如TensorFlowMXNetPyTorch)的钩子,对于 JSR-381 的用户来说,这是一个灵活而简单的选项。


在本文中,我们将演示 Java 开发人员如何使用 JSR-381 VisRec API 在不到 10 行代码内利用 DJL 的预训练模型实现图像分类或对象检测。我们还通过两个例子演示了用户如何在 10 分钟内使用预先训练好的机器学习模型。让我们开始吧!

使用预训练的模型识别手写数字

识别手写数字是一个有用的应用,也是视觉识别的一个“hello world”示例。对人类来说,识别手写数字似乎很容易。得益于我们大脑中视觉和模式匹配子系统的处理能力和协作,我们通常可以从潦草的手写文件中正确地识别出数字。然而,由于可能存在许多变化,这个看似简单的任务对于机器来说是难以置信的复杂。这是机器学习,特别是视觉识别的一个很好的用例。JSR 381 库中有一个很好的示例,使用 JSR-381 VisRec API 正确地识别出了手写数字。这个示例将手写数字与MNIST手写数字数据集进行比较,后者是一个包含超过 6 万幅图像的公开数据库。预测图像所代表的内容称为图像分类。我们的示例查看一副新图像,并确定它具体是哪个数字的概率。


对于这项任务,VisRec API 提供了一个ImageClassifier接口,可以使用泛型参数具体化为输入图像的特定 Java 类。它还提供了一个classie()方法,该方法执行图像分类并返回所有可能的图像类别与概率的Map。根据 VisRec API 的约定,每个模型都提供一个静态的builder()方法,它返回一个对应的builder对象,并允许开发者配置所有相关的设置,例如imageHeightimageWidth


在我们的手写数字示例中,要定义一个图像分类器,就需要使用inputClass(BufferedImage.class) 配置输入句柄。你可以通过它指定使用哪个类来表示图像。你可以使用imageHeight(28)imageWidth(28) 将图像尺寸调整到 28x28,模型最初训练时就用的这个大小。


分类器对象构建完成后,将输入图像输入到分类器以识别图像。


File input = new File("../jsr381/src/test/resources/0.png");// 使用mlp文件夹里的预训练模型Path modelPath = Paths.get("../jsr381/src/test/resources/mlp");ImageClassifier<BufferedImage> classifier =       NeuralNetImageClassifier.builder()           // 输入时一个图像文件,应该作为BufferImage进行处理           .inputClass(BufferedImage.class)           // 图像尺寸应该调整到28 x 28           .imageHeight(28)           .imageWidth(28)           .importModel(modelPath)           .build();// 执行推断并获取分类结果Map<String, Float> result = classifier.classify(input);// 打印结果for (Map.Entry<String, Float> entry : result.entrySet()) {   System.out.println(entry.getKey() + ": " + entry.getValue());}
复制代码


执行上述代码会产生以下输出:


0: 0.99976332: 6.915607E-55: 2.7744078E-56: 6.1097984E-59: 3.8322916E-5
复制代码


对于图像中的数字,该模型识别出五种可能的选项,以及每个选项的概率。分类器以 99.98%的压倒性概率正确地预测了数字 0。


推而广之,如果需要检测出同一副图像中的多个不同的对象该怎么办?

使用预训练的单帧检测器(SSD)模型识别物体

单帧检测器(SSD)是一种利用一个深度神经网络从图像中检测物体的机制。本例使用预先训练好的 SSD 模型识别图像中的对象。对象检测是一项比较具有挑战性的视觉识别任务。除了对图像中的对象进行分类外,对象检测还可以识别图像中对象的位置。它还可以在关注对象周围绘制一个边框并添加一个类别(文本)标签。


SSD 机制是机器学习领域的一项最新进展,它检测对象的速度非常快,与此同时,还能保持与需要更大计算量的模型相媲美的准确性。要了解关于 SSD 模型的更多信息,可以阅读博文“理解SSD MultiBox——深度学习中的实时对象检测”以及《深入机器学习》这本书里的这个练习


使用 DJL 的 JSR-381 实现,用户可以访问预先训练好的、开箱即用的 SSD 模型实现。DJL 使用ModelZoo来简化模型部署。下面的代码块使用ModelZoo.loadModel()加载一个预先训练好的模型,实例化一个对象检测器类,并将这个模型应用到一副示例图像上。


// 定义一个满足用户需求的模型查找标准Criteria<BufferedImage, DetectedObjects> criteria =        Criteria.builder()                .setTypes(BufferedImage.class, DetectedObjects.class)                // 查找一个对象检测模型                .optApplication(Application.CV.OBJECT_DETECTION)                .build();// 加载模型,创建一个SimpleObjectDectector对象try (ZooModel<BufferedImage, DetectedObjects> model = ModelZoo.loadModel(criteria)) {   // SimpleObjectDetector是一个负责检测对象的高级JSR-381 API   SimpleObjectDetector objectDetector = new SimpleObjectDetector(model);   // 加载图像   BufferedImage input =       BufferedImageUtils.fromUrl(           "https://djl-ai.s3.amazonaws.com/resources/images/dog_bike_car.jpg");   // 检测对象   Map<String, List<BoundingBox>> result = objectDetector.detectObject(input);   for (List<BoundingBox> boundingBoxes : result.values()) {       for (BoundingBox boundingBox : boundingBoxes) {           System.out.println(boundingBox.toString());       }   }}
复制代码


下面是一副可供我们使用的新图像。



在这幅图像上运行代码将产生如下结果:


BoundingBox{id=0, x=124.0, y=119.0, width=456.45093, height=338.8393, label=bicycle, score=0.9538524}BoundingBox{id=0, x=469.0, y=78.0, width=225.19464, height=92.147675, label=car, score=0.99991035}BoundingBox{id=0, x=128.0, y=201.0, width=210.51933, height=341.7647, label=dog, score=0.9375212}
复制代码


如果你希望给从图像上检测到的每个对象添加边框,只需几行代码即可。要了解更多信息,请参见完整的GitHub示例。该模型对三个关注对象(自行车、汽车和狗)进行分类,在每个对象周围画一个边框,并提供一个由概率反映的置信度。



值得注意的是,预训练模型的检测精度取决于用于训练模型的图像。模型的精度可以通过再训练来提高,也可以使用一组更能代表最终应用程序的图像开发一个自定义的模型。然而,这种方法非常耗时,并且需要使用大量的训练数据。对于许多 ML 应用程序,使用预先训练好的模型建立基线通常是值得的。这可以节省大量收集、准备数据和从头训练模型的时间。

未来展望

在这篇文章中,我们仅仅了解了使用 JSR-381 API 的 DJL 实现可以做些什么。你可以使用 ModelZoo 中预先训练好的模型库探索和实现更多的模型,或者引入自己的模型。


感兴趣的读者可以检出DJL,这是一个由 Amazon 的 Java 开发人员为 Java 社区构建的开源库。我们试图简化 Java 中机器学习的开发和部署。欢迎加入我们!


DJL 有很多用例,你可以开发一个客服问答应用程序,实现瑜伽姿势的姿态估计,或者训练你自己的模型来检测后院的入侵者。我们的Spring Boot入门套件还简化了 ML 与 Spring Boot 应用程序的集成。读者可以通过我们的介绍性博客网站示例库了解更多关于 DJL 的信息。请访问我们的Github库,在我们的 Slack频道与我们合作。



参考资料




作者简介:


Frank Liu AWS AI软件工程师。他专注于为软件工程师和科学家打造创新型深度学习工具。在业余时间,他喜欢与朋友和家人一起徒步旅行。


Xinyu Liu AWS AI软件开发经理。他热衷于机器学习和大规模分布式系统。


Frank Greco 是 Crossroads Technologies 公司的创始人和首席执行官。他是高级技术顾问和企业架构师,致力于为开发人员提供云计算和 AI/ML 工具。他是一名 Java 冠军程序员、NYJavaSIG的主席,并在欧洲举办了企业机器学习国际会议。他业余时间喜欢弹吉他。


Zoran Sevarac Deep Netts的 CEO。他致力于为 Java 开发人员构建用户友好的深度学习工具,并创建 AI Java 标准。他是贝尔格莱德大学的教授和 Java 冠军程序员。他业余时间喜欢弹吉他。


Balaji Kamakoti AWS AI高级产品经理。他致力于让开发者更容易使用深度学习产品。在业余时间,他喜欢打网球,弹萨罗德琴(一种无琴格的弦乐器)。


特别感谢JCPJSR-381团队的宝贵贡献:Kevin Berendsen、Sandhya Kapoor、Werner Keil、Constantin Drabo、Ankara Parida、Melissa Mckay、Buddha Jyoti Prasad、Shreya Gupta、Amit Nagesh、Heather VanCura 和 Harold Ogle。


原文链接:


Machine Learning in Java With Amazon Deep Java Library


2020-06-17 10:392389

评论

发布
暂无评论
发现更多内容

VS Code 如何设置大小写转换快捷键

AlwaysBeta

vscode

一口气搞懂【Linux内存管理】,就靠这60张图、59个问题了

奔着腾讯去

内存泄露 内存管理 Linux Kenel 内存映射 内存池

瞰见 | 初创1个月就融到3亿美金,ClickHouse 你凭什么?

OpenTEKr

狄安瞰源

费用节省 50%,函数计算 FC 助力分众传媒降本增效

阿里巴巴云原生

阿里云 云原生 合作 函数计算FC 分众传媒

盘点 2021|自己一个人扛起了公司的半边天

liuzhen007

技术人生 盘点2021 盘点 2021

Git基础 |打tag

xcbeyond

git 28天写作 tag 12月日更

【架构实战营】模块三:命题作业

wgl

「架构实战营」

一个cpp协程库的前世今生(一)缘起

SkyFire

协程 cpp cocpp

「架构实战营」模块三《如何保证设计出合理的架构》作业

DaiChen

作业 模块三 「架构实战营」

如何设计贴合业务的高性能高可用中间件系统

天天向上

架构实战营

外包学生管理系统架构设计文档

李晓笛

「架构实战营」

【架构实战营】模块三:知识点总结

wgl

「架构实战营」

详细架构设计文档

Anlumina

#架构实战营

从人工到智能!百度AI开发者大会分论坛,探寻国球乒乓背后的AI之路

百度大脑

人工智能

百度智能云发布零碳园区解决方案,助力实现双碳目标

百度大脑

人工智能

外包学生管理系统详细设计文档

糖糖学编程

架构实战营

瞰见 | 开源,会不会变成开源创业的焦油坑?

OpenTEKr

狄安瞰源

第三模块学习总结

Anlumina

#架构实战营

学习总结 2021.12.30

mj4ever

学习笔记

百度飞桨EasyDL桌面版正式上线,没网也能训练AI!

百度大脑

人工智能

架构实战营模块三作业

lchx08

「架构实战营」

ALC北京发起人 姜宁:通过开放与协作,我们可以实现一个人想都不敢想的事情 I OpenTEKr 大话开源 Vol.6

OpenTEKr

大话开源

Apache 海豚调度 PMC 郭炜:开源,不是天才的甜点,而是执着者的盛宴 I OpenTEKr 大话开源 Vol.7

OpenTEKr

大话开源

引领人工智能技术自立自强 百度吴甜获评“首都最美巾帼奋斗者”

百度大脑

人工智能「

阿里巴巴超大规模 Kubernetes 基础设施运维体系揭秘

阿里巴巴云原生

阿里云 Serverless Kubernetes 云原生 ASI

第三周学习总结

糖糖学编程

架构实战营

Golang中文件的基本操作

liuzhen007

Go 28天写作 Go 语言 12月日更

Java 数据持久化系列之池化技术

程序员历小冰

MySQL 持久化 28天写作 池化技术 12月日更

元宇宙100讲-0x011

hackstoic

元宇宙

深入理解一下Python中的面向对象编程

宇宙之一粟

Python 面向对象 12月日更

架构模块三作业

holdzhu

「架构实战营」

Java机器学习工具箱:Amazon Deep Java Library_AI&大模型_InfoQ精选文章