写点什么

Analytics Zoo:在 Apache Spark 上实现分布式 Tensorflow 和 BigDL 管道的统一分析和 AI 平台

  • 2018-12-22
  • 本文字数:5824 字

    阅读完需:约 19 分钟

Analytics Zoo:在Apache Spark上实现分布式Tensorflow和BigDL管道的统一分析和AI平台
  • 如今,将深度学习应用于大数据管道往往需要手工“拼接”许多独立的组件(如 TensorFlow、Apache Spark、Apache HDFS 等),这个过程可能非常复杂,而且容易出错。

  • Analytics Zoo提供了一个在 Apache Spark 上实现分布式 TensorFlow、Keras 和 BigDL 管道的统一分析和 AI 平台,简化了这个过程。

  • 它将 Spark、TensorFlow、Keras 和 BigDL 程序无缝地合并到一个集成管道中,可以透明地扩展到大型 Apache Hadoop/Spark 集群,用于分布式训练或推理。

  • 早期用户(如世界银行CrayTalrooBaosight美的/库卡等)已经基于 Analytics Zoo 构建了分析+AI 应用程序,它可以应用于范围广泛的工作负载(包括基于迁移学习的图像分类、用于短时降水预测的 sequence-to-sequence 预测、用于推荐工作的神经协同过滤、无监督时序异常检测等等)。

  • 本文提供了几个具体的教程,介绍如何使用 Analytics Zoo 在 Apache Spark 上实现分布式 TensorFlow 管道,以及在实际的用例中使用 Analytics Zoo 实现端到端的文本分类管道。


人工智能应用程序的不断进步将深度学习带到新一代数据分析开发的前沿。特别是,我们看到越来越多的组织需要将深度学习技术(如计算机视觉、自然语言处理、生成对抗神经网络等)应用到他们的大数据平台和管道。如今,这常常需要手工“拼接”许多独立的组件(例如 Apache Spark、TensorFlow、Caffe、Apache Hadoop 分布式文件系统 HDFS、Apache Storm/Kafka 等),这可能是一个复杂且容易出错的过程。


在英特尔,我们一直与开源社区用户以及京东UCSFMastercard等合作伙伴和客户广泛合作,在 Apache Spark 上构建深度学习(DL)和 AI 应用程序。为了简化端到端的开发和部署,我们开发了Analytics Zoo,这是一个统一的分析+ AI 平台,它将 Spark、TensorFlow、Keras 和 BigDL 程序无缝地合并到一个集成管道中,可以透明地扩展到大型 Apache Hadoop/Spark 集群,用于分布式训练或推理。


早期用户(如世界银行CrayTalrooBaosight美的/库卡等)已经基于 Analytics Zoo 构建了分析+AI 应用程序,它可以应用于范围广泛的工作负载,其中包括基于迁移学习的图像分类、用于短时降水预测的 sequence-to-sequence 预测、用于推荐工作的神经协同过滤、无监督时序异常检测等等。


本文提供了几个具体的教程,介绍如何使用 Analytics Zoo 在 Apache Spark 上实现分布式 TensorFlow 管道,以及在实际的用例中使用 Analytics Zoo 实现端到端的文本分类管道。

Apache Spark 上的分布式 TensorFlow

使用 Analytics Zoo,用户可以方便地使用 Spark 和 TensorFlow 在大型集群上构建端到端的深度学习管道,如下所述。

使用 PySpark 进行数据预处理和分析

举例来说,要使用分布式的方式处理对象检测管道的训练数据,可以使用 PySpark 简单地把原始图像数据读入一个 RDD(弹性分布式数据集),这是一个跨集群分区的不可变记录集合,然后运用一些转换解码图像,并提取边界框和类标签,如下所示。


train_rdd = sc.parallelize(examples_list)  .map(lambda x: read_image_and_label(x))  .map(lambda image: decode_to_ndarrays(image))
复制代码


结果 RDD (train_rdd)中的每条记录都包含一个 NumPy ndrray 列表(即图像、边界框、类和检测到的框的数量),然后可以直接在 Analytics Zoo 上用于 TensorFlow 模型的分布式训练;这是通过从结果 RDD 创建TFDataset来完成的(如下所示)。


dataset = TFDataset.from_rdd(train_rdd,             names=["images", "bbox", "classes", "num_detections"],             shapes=[[300, 300, 3],[None, 4], [None], [1)]],             types=[tf.float32, tf.float32, tf.int32, tf.int32],             batch_size=BATCH_SIZE,             hard_code_batch_size=True)
复制代码

使用 TensorFlow 开发深度学习模型

在 Analytics Zoo 中,TFDataset 表示一个分布式元素集,其中每个元素包含一个或多个 Tensorflow Tensor 对象。然后,我们可以直接使用这些 Tensor(作为输入)来构建 Tensorflow 模型;例如,我们可以使用 Tensorflow Object Detection API 构建一个 SSDLite+MobileNet V2 模型(如下所示)。


# 使用tensorflow对象检测api来构造模型# https://github.com/tensorflow/models/tree/master/research/object_detectionfrom object_detection.builders import model_builder
images, bbox, classes, num_detections = dataset.tensors
detection_model = model_builder.build(model_config, is_training=True)resized_images, true_image_shapes = detection_model.preprocess(images)detection_model.provide_groundtruth(bbox, classes)prediction_dict = detection_model.predict(resized_images, true_image_shapes)losses = detection_model.loss(prediction_dict, true_image_shapes)total_loss = tf.add_n(losses.values())
复制代码

在 Spark 和 BigDL 上进行分布式训练/推理

在构造好模型之后,我们可以直接在 Spark 上(利用 BigDL 框架)以分布式的方式训练模型。例如,在下面的代码片段中,我们应用迁移学习技术来训练一个在 MS COCO 数据集上预训练过的 Tensoflow 模型。


with tf.Session() as sess:    init_from_checkpoint(sess, CHECKPOINT_PATH)    optimizer = TFOptimizer(total_loss, RMSprop(LR), sess)    optimizer.optimize(end_trigger=MaxEpoch(20))    save_to_new_checkpoint(sess, NEW_CHEKCPOINT_PATH)
复制代码


在后台,从磁盘读取输入数据并进行预处理,利用 PySpark 生成 Tensorflow Tensor 的 RDD;然后,在 BigDL 和 Spark(如BigDL技术报告所述)上以分布式的方式对 Tensorflow 模型进行训练。整个训练管道可以自动从单个节点扩展到基于 Xeon 的大规模 Hadoop/Spark 集群(无需修改代码或手动配置)。


另外,模型训练好以后,我们可以使用 PySpark、TensorFlow 和 BigDL(类似于上面的训练管道)在 Analytics Zoo 上执行大规模的分布式评估/推断。或者,我们也可以使用 Analytics Zoo 提供的 POJO 风格的服务API来部署低延迟的在线服务(例如,Web 服务、Apache Storm、Apache Flink 等)模型,如下所示。


AbstractInferenceModel model = new AbstractInferenceModel(){};model.loadTF(modelPath, 0, 0, false);List<List<JTensor>> output = model.predict(inputs);
复制代码


下图显示了 Analytics Zoo 中 Apache Spark 管道上的分布式 TensorFlow 的整个工作流(包括训练、评估/推断和在线服务)。


端到端分析和 AI 管道

Analytics Zoo 还为用户提供了丰富的端到端管道分析和 AI 支持,包括:


  • 易于使用的抽象,如 Spark Dataframe 和 ML 管道支持、迁移学习支持、Keras 风格的 API、POJO 风格的模型服务 API 等等;

  • 面向图象、文本和 3D 图象的常见特征工程操作*;*

  • 内置的深度学习模型,如文本分类、推荐、对象检测、图象分类等;

  • 参考用例,如时间序列异常检测、欺诈检测、图像相似性搜索等。


使用这些高级管道支持,用户可以在几行代码中轻松构建复杂的数据分析和深度学习应用程序,如下所述。

使用 NNImageReader 将图象加载到 Spark DataFrames 中

from zoo.common.nncontext import *
from zoo.pipeline.nnframes import *
sc = init_nncontext()
imageDF = NNImageReader.readImages(image_path, sc)
复制代码

使用 DataFrames 转换处理加载的数据

getName = udf(lambda row: ...)
getLabel = udf(lambda name: ...)
df = imageDF.withColumn("name", getName(col("image"))) \
.withColumn("label", getLabel(col('name')))
复制代码

使用内置的特征工程操作处理图像

from zoo.feature.imageimport *transformer = ChainedPreprocessing(        [RowToImageFeature(), ImageChannelNormalize(123.0, 117.0, 104.0),         ImageMatToTensor(), ImageFeatureToTensor()])
复制代码


使用迁移学习 API 加载已有的 Caffe 模型,删除最后几层,冻结开始几层,追加几个新层(使用 Keras 风格的 API)


from zoo.pipeline.api.netimport *full_model = Net.load_caffe(def_path, model_path)# 删除pool5之后的层Remove layers after pool5model = full_model.new_graph(outputs=["pool5"])# 冻结从输入到res4f 之间的层,包括res4fmodel.freeze_up_to(["res4f"])# 追加几个层image = Input(name="input", shape=(3, 224, 224))resnet = model.to_keras()(image)resnet50 = Flatten()(resnet)
logits = Dense(2)(flatten)
newModel = Model(inputs, logits)
复制代码

使用 Spark ML 管道训练模型

estimater = NNEstimater(newModel, CrossEntropyCriterion(), transformer) \                .setLearningRate(0.003).setBatchSize(40).setMaxEpoch(2) \                .setFeaturesCol("image").setCachingSample(False)nnModel = estimater.fit(df)
复制代码

基于 Analytics Zoo 的真实 AI 案例

如上所述,有许多早期用户已经在 Analytics Zoo 上构建了真实的应用程序。在本节中,我们将更详细地描述如何在Microsoft Azure的 Analytics Zoo 上使用 NLP 技术构建端到端的文本分类管道。

文本分类概述

文本分类是一种常见的自然语言处理任务,其目的是将输入文本语料库分类为一个或多个类别。例如,垃圾邮件检测将电子邮件的内容分为垃圾邮件或非垃圾邮件类别。


一般来说,文本分类模型的训练包括以下步骤:收集和准备训练数据集及验证数据集、数据清理和预处理、训练模型、验证和评估模型、优化模型(包括但不限于添加数据、调整超参数、调整模型)。


Analytics Zoo 中有几个预定义的文本分类器可以开箱即用,即 CNN、LSTM、GRU。我们选择从 CNN 开始。我们在下面的文本中使用 Python API 来说明训练过程。


from zoo.models.textclassificationimport TextClassifiertext_classifier = TextClassifier(class_num, embedding_file, \
sequence_length=500, encoder="cnn", encoder_output_dim=256)
复制代码


在上面的 API 中,class_num 是这个问题中的类别数量,embedding_fileis 是预训练词向量文件的路径(目前只支持 Glove ),sequence_length 是每个文本记录中包含的单词数,encoder 是词编码器的类型(可以是 CNN、LSTM 或 GRU),encoder_output_dim 是这个编码器的输出。该模型接收词索引序列作为输入,输出标签。

数据收集和预处理

训练数据集中的每个记录包含两个字段,一个是 dialogue 和一个是 label。我们收集了数千条这样的记录,并通过手动和半自动的方式收集标签。然后,我们对原始文本进行数据清理,去掉无意义的标记和混淆的部分,并将它们转换为文本 RDD,每个记录的格式为一个(文本、标签)对。接下来,我们对文本 RDD 进行预处理,并输出我们的模型可以接收的正确形式。请确保你的数据清洗和处理对训练和预测都是一样的!


(How to get invoice …, 1)
(Can you send invoice to me…,1)
(Remote service connection failure…,2)
(How to buy…, 3)


上面是数据清理只有的文本 RDD 示例(每条记录是一个文本-标签对)。


1.数据读取


我们可以使用 Analytics Zoo 提供的 TextSet 以分布式方式读取文本数据,如下所示。


from zoo.feature.text import TextSetfrom zoo.common.nncontext import init_nncontext
sc = init_nncontext("Text Classification")text_set = TextSet.read(data_path, sc)
复制代码


2.分词


然后我们将句子分解为单词,将每个输入文本转换为一个标记(单词)数组,并对标记进行规范化(例如,删除未知字符并转换为小写)。


text_set = text_set.tokenize()   \                   .normalize()
复制代码


3.序列对齐


不同的文本可能会生成不同大小的标记数组。但是,文本分类模型要求输入的所有记录大小固定。因此,我们必须将标记数组对齐到相同的大小(在 parametersequence_lengthin 文本分类器中指定)。如果标记数组的大小大于所需的大小,则从开头或结尾对单词进行删除;否则,我们将无意义的单词填充到数组的末尾(例如“##”)。


text_set= text_set.shape_sequence(sequence_length)
复制代码


4.词索引


标记数组大小对齐后,需要将每个标记(单词)转换为索引,可用来查找其词向量(在文本分类器模型中)。在单词转换为索引的过程中,我们还通过删除文本中出现频率最高 N 个单词来移除停用词(即经常出现在文本中但无助于语义理解的单词,如“the”、“of”等)。


text_set= text_set.word2idx(remove_topN=10, max_words_num)
复制代码


5.转换成样本


经过以上步骤,每个文本都变成一个有形状的张量(sequence_length, 1),然后我们从每个记录构造一个 BigDL样本,以生成的张量作为特征,以标签整数作为标签。


text_set = text_set.generate_sample()
复制代码

模型训练、测试、评估和优化

在以相同的方式准备好训练数据集(train_rdd)和验证数据集(val_rdd)之后,我们实例化一个新的 TextClassifier 模型(text_classifier),然后创建一个优化器以分布式方式训练模型。我们使用稀疏分类交叉熵(Sparse Categorical Cross Entropy)作为损失函数。


train_set, val_set= text_set.random_split( \    [training_split, 1 - training_split])
model.compile(optimizer=Adagrad(learningrate=float(options.learning_rate), \ learningrate_decay=0.001), loss="sparse_categorical_crossentropy", \ metrics=['accuracy'])
model.fit(train_set,batch_size=int(options.batch_size), \ nb_epoch=max_epoch, validation_data=val_set)
复制代码


训练时可调整的参数包括轮数、批大小、学习率等。你可以指定输出指标的验证选项,如在训练过程中设置准确度验证,以检测过拟合或欠拟合。


如果在验证数据集上获得的结果不好,我们就必须优化模型。通常,这是指重复调整超参数/数据/模型、训练和验证的过程,直到结果足够好为止。通过调整学习速率、添加新数据和扩充停用词字典,我们的准确率得到了显著提高。


有关 Analytics Zoo 中的文本处理分类支持的更多细节,请参阅这些文档。


作者简介Jason Dai 是英特尔大数据技术高级首席工程师兼首席技术官,负责领导全球工程团队(包括硅谷和上海)开发先进的大数据分析和机器学习技术。他是 Apache Spark 的提交者、PMC 成员、Apache MXNet 导师、Strata Data Conference 北京站的联合主席以及 BigDL(Apache Spark 的分布式深度学习框架)的创建者。


查看英文原文:Analytics Zoo: Unified Analytics + AI Platform for Distributed Tensorflow, and BigDL on Apache Spark


2018-12-22 11:162017
用户头像

发布了 1008 篇内容, 共 387.8 次阅读, 收获喜欢 344 次。

关注

评论 1 条评论

发布
暂无评论
发现更多内容

技术分享| 基于RTM 实现的呼叫邀请如何添加推送功能?

anyRTC开发者

音视频 IM 实时消息 呼叫邀请 推送

《TiDB跨版本升级》 --流程概述

TiDB 社区干货传送门

迁移 实践案例 版本升级 管理与运维 安装 & 部署

为什么我要迁移SpringBoot到函数计算

Serverless Devs

MASA Framework 获取配置信息的方法

MASA技术团队

.net MASA Framewrok MASA

LeetCode-66. 加一(java)

bug菌

Leet Code 每日一题 9月月更

LeetCode-58. 最后一个单词的长度(java)

bug菌

Leet Code 每日一题 9月月更

计算机网络体概念

StackOverflow

编程 计算机网络 9月月更

主流定时任务解决方案全横评

Serverless Devs

spring Linux

iptables与firewalld防火墙是怎么样工作的呢?

阿柠xn

防火墙 Linux Kenel 运维‘ 9月月更

转:工业软件上云很难吗?可以微创呀!

小江

工业软件云化

中国移动NZONE 50 Pro 5G手机正式开售

Geek_2d6073

Apache DolphinScheduler PMC:开源不一定也要九死一生

Apache DolphinScheduler

海豚调度 开源社区 Apache DolphinScheduler 开源文化 #开源

千锋锋友学盟分享会:程序员百万年薪进阶指

千锋IT教育

软件测试 | 测试开发 | 持续交付-Blue Ocean 应用

测吧(北京)科技有限公司

软件测试 | 测试开发 | 跨平台API对接(Java)

测吧(北京)科技有限公司

jenkins、

软件测试 | 测试开发 | 测试左移之Sonarqube scanner使用

测吧(北京)科技有限公司

SonarQube

软件测试 | 测试开发 | 黑盒测试方法论—边界值

测吧(北京)科技有限公司

边界测试

基于函数计算自定义运行时快速部署一个 springboot 项目

Serverless Devs

「龙蜥开发者说」征稿啦!

OpenAnolis小助手

开源 征文 获奖 龙蜥开发者说 龙蜥技术

购买小间距LED显示屏前需要了解什么?

Dylan

LED显示屏 led显示屏厂家

软件测试 | 测试开发 | 构建测试平台与对应的组织架构需要哪些能力?

测吧(北京)科技有限公司

测试

PCTP考试学习笔记之二:TiDB 数据库 schema 设计

TiDB 社区干货传送门

集群管理 管理与运维 数据库架构设计

软件测试 | 测试开发 | 这些常用测试平台,你们公司在用的是哪些呢?

测吧(北京)科技有限公司

测试

参加了个算法比赛,真是一言难尽啊

捉虫大师

Go 算法 map 比赛 9月月更

20个既简单又实用的JavaScript小技巧

千锋IT教育

51单片机定时器原理及相关器件

孤衫

C语言 单片机 9月月更

设计模式的艺术 第十七章命令设计模式练习(开发一个基于Windows平台的公告板系统。该系统提供了一个主菜单(Menu),主菜单包含一些菜单项,Menu类可以增加菜单项。菜单项主要方法是click(),每个菜单项包含一个抽象命令类)

代廉洁

设计模式的艺术

dbt-tidb 1.2.0 尝鲜

TiDB 社区干货传送门

新版本/特性解读

软件测试 | 测试开发 | 一文带你了解K8S 容器编排(下)

测吧(北京)科技有限公司

测试

「工作小记」不同内容相似结构?按个开关试试

叶一一

JavaScript 前端 React Hooks 9月月更

软件测试 | 测试开发 | 测试开发基础 mvn test | 利用 Maven Surefire Plugin 做测试用例基础执行管理

测吧(北京)科技有限公司

maven

Analytics Zoo:在Apache Spark上实现分布式Tensorflow和BigDL管道的统一分析和AI平台 _AI&大模型_Jason Dai_InfoQ精选文章