前言
深度学习在图像处理、语音识别、自然语言处理领域的应用取得了巨大成功,但是它通常在功能强大的服务器端进行运算。如果智能手机通过网络远程连接服务器,也可以利用深度学习技术,但这样可能会很慢,而且只有在设备处于良好的网络连接环境下才行,这就需要把深度学习模型迁移到智能终端。
由于智能终端 CPU 和内存资源有限,为了提高运算性能和内存利用率,需要对服务器端的模型进行量化处理并支持低精度算法。TensorFlow 版本增加了对 Android、iOS 和 Raspberry Pi 硬件平台的支持,允许它在这些设备上执行图像分类等操作。这样就可以创建在智能手机上工作并且不需要云端每时每刻都支持的机器学习模型,带来了新的 APP。
本文主要基于看花识名 APP 应用,讲解 TensorFlow 模型如何应用于 Android 系统;在服务器端训练 TensorFlow 模型,并把模型文件迁移到智能终端;TensorFlow Android 开发环境构建以及应用开发 API。
看花识名 APP
使用 AlexNet 模型、Flowers 数据以及 Android 平台构建了“看花识名”APP。TensorFlow 模型对五种类型的花数据进行训练。如下图所示:
Daisy:雏菊
(点击放大图像)
Dandelion:蒲公英
(点击放大图像)
Roses:玫瑰
(点击放大图像)
Sunflowers:向日葵
(点击放大图像)
Tulips:郁金香
(点击放大图像)
在服务器上把模型训练好后,把模型文件迁移到Android 平台,在手机上安装APP。使用效果如下图所示,界面上端显示的是模型识别的置信度,界面中间是要识别的花:
(点击放大图像)
TensorFlow 模型如何应用于看花识名 APP 中,主要包括以下几个关键步骤:模型选择和应用、模型文件转换以及 Android 开发。如下图所示:
(点击放大图像)
(点击放大图像)
模型训练及模型文件
本章采用AlexNet 模型对Flowers 数据进行训练。AlexNet 在2012 取得了ImageNet 最好成绩,top 5 准确率达到80.2%。这对于传统的机器学习分类算法而言,已经相当出色。模型结构如下:
(点击放大图像)
本文采用TensorFlow 官方Slim( https://github.com/tensorflow/models/tree/master/slim )AlexNet 模型进行训练。
- 首先下载 Flowers 数据,并转换为 TFRecord 格式: ```
DATA_DIR=/tmp/data/flowers
python download_and_convert_data.py --dataset_name=flowers
–dataset_dir="${DATA_DIR}"
- 执行模型训练,经过 36618 次迭代后,模型精度达到 85% ``` TRAIN_DIR=/tmp/data/train python train_image_classifier.py --train_dir=${TRAIN_DIR} --dataset_dir=${DATASET_DIR} --dataset_name=flowers --dataset_split_name=train --model_name=alexnet_v2 --preprocessing_name=vgg
- 生成 Inference Graph 的 PB 文件 ```
python export_inference_graph.py --alsologtostderr
–model_name=alexnet_v2 --dataset_name=flowers --dataset_dir=${DATASET_DIR}
–output_file=alexnet_v2_inf_graph.pb
- 结合 CheckPoint 文件和 Inference GraphPB 文件,生成 Freeze Graph 的 PB 文件 ``` python freeze_graph.py --input_graph=alexnet_v2_inf_graph.pb --input_checkpoint= ${TRAIN_DIR}/model.ckpt-36618 --input_binary=true --output_graph=frozen_alexnet_v2.pb --output_node_names=alexnet_v2/fc8/squeezed
- 对 Freeze Graph 的 PB 文件进行数据量化处理,减少模型文件的大小,生成的 quantized_alexnet_v2_graph.pb 为智能终端中应用的模型文件 ```
bazel-bin/tensorflow/tools/graph_transforms/transform_graph
–in_graph=frozen_alexnet_v2.pb --outputs=“alexnet_v2/fc8/squeezed”
–out_graph=quantized_alexnet_v2_graph.pb --transforms=‘add_default_attributes
strip_unused_nodes(type=float, shape=“1,224,224,3”) remove_nodes(op=Identity,
op=CheckNumerics) fold_constants(ignore_errors=true) fold_batch_norms
fold_old_batch_norms quantize_weights quantize_nodes
strip_unused_nodes sort_by_execution_order’
为了减少智能终端上模型文件的大小,TensorFlow 中常用的方法是对模型文件进行量化处理,本文对 AlexNet CheckPoint 文件进行 Freeze 和 Quantized 处理后的文件大小变化如下图所示: (点击放大图像) [![](https://static001.infoq.cn/resource/image/5d/df/5d593ed8d32b07c574bc72de1dd597df.jpg)](/mag4media/repositories/fs/articles//zh/resources/9.jpg) 量化操作的主要思想是在模型的 Inference 阶段采用等价的 8 位整数操作代替 32 位的浮点数操作,替换的操作包括:卷积操作、矩阵相乘、激活函数、池化操作等。量化节点的输入、输出为浮点数,但是内部运算会通过量化计算转换为 8 位整数(范围为 0 到 255)的运算,浮点数和 8 位量化整数的对应关系示例如下图所示: (点击放大图像) [![](https://static001.infoq.cn/resource/image/bf/3c/bfcd220108faeae76f17696d8003983c.jpg)](/mag4media/repositories/fs/articles//zh/resources/10.jpg) 量化 Relu 操作的基本思想如下图所示: (点击放大图像) [![](https://static001.infoq.cn/resource/image/6e/e9/6e497e3b56fa00e127d65a98634274e9.jpg)](/mag4media/repositories/fs/articles//zh/resources/11.jpg) ## TensorFlow Android 应用开发环境构建 在 Android 系统上使用 TensorFlow 模型做 Inference 依赖于两个文件 libtensorflow\_inference.so 和 libandroid\_tensorflow\_inference\_java.jar。这两个文件可以通过下载 TensorFlow 源代码后,采用 bazel 编译出来,如下所示: - 下载 TensorFlow 源代码 git clone --recurse-submodules <https://github.com/tensorflow/tensorflow.git> - [下载安装 Android NDK](https://developer.android.com/ndk/downloads/older_releases.html#ndk-12b-downloads) - [下载安装 Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html) - 配置 tensorflow/WORKSPACE 中 android 开发工具路径 ``` android_sdk_repository(name = "androidsdk", api_level = 23, build_tools_version = "25.0.2", path = "/opt/android",) android_ndk_repository(name="androidndk", path="/opt/android/android-ndk-r12b", api_level=14) {1}
- 编译 libtensorflow_inference.so ```
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so
–crosstool_top=//external:android/crosstool --host_crosstool_top=
@bazel_tools//tools/cpp:toolchain --cpu=armeabi-v7a
- 编译 libandroid\_tensorflow\_inference\_java.jar `bazel build //tensorflow/contrib/android:android_tensorflow_inference_java` TensorFlow[提供了 Android 开发的示例框架](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android),下面基于 AlexNet 模型的看花识名 APP 做一些相应源码的修改,并编译生成 Android 的安装包: - 基于 AlexNet 模型,修改 Inference 的输入、输出的 Tensor 名称 ``` private static final String INPUT_NAME = "input"; private static final String OUTPUT_NAME = "alexnet_v2/fc8/squeezed";
- 放置 quantized_alexnet_v2_graph.pb 和对应的 labels.txt 文件到 assets 目录下,并修改 Android 文件路径 ```
private static final String MODEL_FILE = “file:///android_asset/quantized_alexnet_v2_graph.pb”;
private static final String LABEL_FILE = “file:///android_asset/labels.txt”;
- 编译生成安装包 `bazel build -c opt //tensorflow/examples/android:tensorflow_demo` - 拷贝 tensorflow\_demo.apk 到手机上,并执行安装,太阳花识别效果如下图所示: (点击放大图像) [![](https://static001.infoq.cn/resource/image/25/0e/25dd09b1e992248d565cfae7b4dc760e.jpg)](/mag4media/repositories/fs/articles//zh/resources/12.jpg) ## TensorFlow 移动端应用开发 API 在 Android 系统中执行 TensorFlow Inference 操作,需要调用 libandroid\_tensorflow\_inference\_java.jar 中的 JNI 接口,主要接口如下: - 构建 TensorFlow Inference 对象,构建该对象时候会加载 TensorFlow 动态链接库 libtensorflow\_inference.so 到系统中;参数 assetManager 为 android asset 管理器;参数 modelFilename 为 TensorFlow 模型文件在 android\_asset 中的路径。 ``` TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
- 向 TensorFlow 图中加载输入数据,本 App 中输入数据为摄像头截取到的图片;参数 inputName 为 TensorFlow Inference 中的输入数据 Tensor 的名称;参数 floatValues 为输入图片的像素数据,进行预处理后的浮点值;[1,inputSize,inputSize,3] 为裁剪后图片的大小,比如 1 张 224*224*3 的 RGB 图片。
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
- 执行模型推理; outputNames 为 TensorFlow Inference 模型中要运算 Tensor 的名称,本 APP 中为分类的 Logist 值。
inferenceInterface.run(outputNames);
- 获取模型 Inference 的运算结果,其中 outputName 为 Tensor 名称,参数 outputs 存储 Tensor 的运算结果。本 APP 中,outputs 为计算得到的 Logist 浮点数组。
inferenceInterface.fetch(outputName, outputs);
总结
本文基于看花识名 APP,讲解了 TensorFlow 在 Android 智能终端中的应用技术。首先回顾了 AlexNet 模型结构,基于 AlexNet 的 slim 模型对 Flowers 数据进行训练;对训练后的 CheckPoint 数据,进行 Freeze 和 Quantized 处理,生成智能终端要用的 Inference 模型。然后介绍了 TensorFlow Android 应用开发环境的构建,编译生成 TensorFlow 在 Android 上的动态链接库以及 java 开发包;文章最后介绍了 Inference API 的使用方式。
参考文献
- http://www.tensorflow.org
- 深度学习利器: 分布式 TensorFlow 及实例分析
- 深度学习利器:TensorFlow 使用实战
- 深度学习利器:TensorFlow 系统架构与高性能程序设计
- 深度学习利器:TensorFlow 与深度卷积神经网络
- 深度学习利器:TensorFlow 与 NLP 模型
作者简介
武维(微信:3381209@qq.com):博士,系统架构师,主要从事大数据,深度学习,云计算等领域的研发工作。
感谢蔡芳芳对本文的审校。
给InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 editors@cn.infoq.com 。也欢迎大家通过新浪微博( @InfoQ , @丁晓昀),微信(微信号: InfoQChina )关注我们。
评论