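In this tutorial I mainly show how to freeze a trained TensorFlow model and deploy it as a web server. The web server uses the Python Flask framework (other frameworks work just as well). When training the TensorFlow model, feed the input data through placeholders and give each tensor an explicit name; the names are needed later when freezing the model.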
x = tf.placeholder(tf.float32, shape=[None, img_size, img_size, num_channels], name='x')

y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
Freezing the model
What does it mean to freeze a TensorFlow model, and why do it? Let's first look at the structure of the classic AlexNet network:
conv1 layer: (11*11)*3*96 (weights) + 96 (biases) = 34944
conv2 layer: (5*5)*96*256 (weights) + 256 (biases) = 614656
conv3 layer: (3*3)*256*384 (weights) + 384 (biases) = 885120
conv4 layer: (3*3)*384*384 (weights) + 384 (biases) = 1327488
conv5 layer: (3*3)*384*256 (weights) + 256 (biases) = 884992
fc1 layer: (6*6)*256*4096 (weights) + 4096 (biases) = 37752832
fc2 layer: 4096*4096 (weights) + 4096 (biases) = 16781312
fc3 layer: 4096*1000 (weights) + 1000 (biases) = 4097000
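The network has more than 60 million parameters to compute. On top of that, backpropagation during training has to compute the same number of gradient values. The model files produced by TensorFlow training contain all of these parameters together with the graph nodes used for training, but the gradient computations are not needed once the model is deployed. Freezing is the process of saving just the required TensorFlow graph and weights into a single file.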
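The 60-million figure can be checked by summing the per-layer counts listed above. The following is a minimal sketch in plain Python; the names LAYERS and param_count are made up for this illustration and do not come from the original code.

```python
# (inputs per output unit, number of output units) for each layer,
# copied from the per-layer list above.
LAYERS = {
    "conv1": (11 * 11 * 3, 96),
    "conv2": (5 * 5 * 96, 256),
    "conv3": (3 * 3 * 256, 384),
    "conv4": (3 * 3 * 384, 384),
    "conv5": (3 * 3 * 384, 256),
    "fc1": (6 * 6 * 256, 4096),
    "fc2": (4096, 4096),
    "fc3": (4096, 1000),
}


def param_count(fan_in, units):
    """Weights (fan_in * units) plus one bias per output unit."""
    return fan_in * units + units


counts = {name: param_count(*shape) for name, shape in LAYERS.items()}
total = sum(counts.values())

for name, count in counts.items():
    print(f"{name}: {count}")
print(f"total: {total}")  # total: 62378344
```

The three fully connected layers account for roughly 58.6 million of the 62.4 million parameters, which is why the checkpoint files are so large.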
The parameters of a TensorFlow model are spread across the following four kinds of files:
1) model-ckpt.meta
Contains the complete graph as a serialized MetaGraphDef protocol buffer: the GraphDef that describes the data flow, plus annotations for variables, input pipelines, and other relevant information.
2) model-ckpt.data-00000-of-00001
Contains the values of all saved variables (weights, biases, and any other saved variables such as optimizer state).
3) model-ckpt.index
An immutable table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto, which describes the metadata of that tensor.
4) checkpoint
A small text file that records the available checkpoints, including the most recent one.
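In short, before deploying the model to a web server we strip the unnecessary metadata, gradients, and training-only variables, and pack the remaining parameters into a single file. That single file (with a .pb extension) is called a "frozen graph def".
The code for freezing the graph is as follows: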
import argparse
import os
import sys

import tensorflow as tf
from tensorflow.python.framework import graph_util

# Choose the checkpoint whose tensors are needed for inference
parser = argparse.ArgumentParser()
parser.add_argument(
    '--meta',
    required=True,
    type=str,
    help='input model checkpoint meta data file (.meta)'
)
parser.add_argument(
    '--prefix',
    required=True,
    type=str,
    help='input model data prefix')
FLAGS, unparsed = parser.parse_known_args()
# Decide which outputs of the network to keep. Most of the time only the
# prediction node is used, so only the prediction node is kept here.
output_node_names = "y_pred"
# Load the .meta file that stores the graph; the weights are restored in a session below
# saver = tf.train.import_meta_graph('model.ckpt-74928.meta', clear_devices=True)
saver = tf.train.import_meta_graph(FLAGS.meta, clear_devices=True)

# Convert the graph to a graph_def
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
# Restore the weights, then use graph_util.convert_variables_to_constants
# to keep the graph_def together with the output nodes of the network
# saver.restore(sess, "./model.ckpt-74928")
saver.restore(sess, FLAGS.prefix)
output_graph_def = graph_util.convert_variables_to_constants(
    sess,  # The session is used to retrieve the weights
    input_graph_def,  # The graph_def is used to retrieve the nodes
    output_node_names.split(",")  # The output node names are used to select the useful nodes
)
output_graph = "estate_model.pb"
# Finally, serialize the output graph and write it to the .pb file
with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
sess.close()
# In the end the model shrank from over 600 MB to about 200 MB
# Load and use the frozen model
# Build the graph and load the weights so that they stay in memory
# (otherwise the weights would be reloaded on every request)
def load_graph(trained_model):
    """
    Method 1: load the graph as the default graph.
    # Unpersists graph from file as default graph.
    with tf.gfile.GFile(trained_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    """
    # Method 2: parse the file, then import it into a new tf.Graph and return that graph
    with tf.gfile.GFile(trained_model, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name="")
    return graph
Finally, load the saved .pb file when the web server starts:
from flask import Flask

app = Flask(__name__)
FLAGS, unparsed = parser.parse_known_args()
g1 = load_graph(FLAGS.graph1)
session1 = tf.Session(graph=g1, config=config)

@app.route('/image_classification', methods=['POST'])
def parse_request():
    ...
    ...

app.run(host="10.200.0.174", port=16888, debug=True, use_reloader=False)
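Problems and solutions
Below are some of the pitfalls I ran into while deploying the model:
1. Inference is very slow
When running the frozen model, I found that inference on a single image took several seconds. Tracing the problem showed that TensorFlow was copying all the parameters from host memory to GPU memory on every inference. In essence, every call to run_graph loaded the whole computation again. This is especially visible when running inference on a GPU: GPU memory rises and falls as TensorFlow loads and unloads the model parameters.
Solution:
Remove the with tf.Session() as sess construct and pass a sess argument to run_graph instead. With this change the parameters are copied from host memory to GPU memory only once, which takes a moment when the web server first starts; on every later inference the parameters are already in GPU memory.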
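The difference between the two call patterns can be sketched without TensorFlow. In the example below, StubSession is a hypothetical stand-in for tf.Session that only counts how often it is constructed (the point at which the real session would pay for moving the weights); run_graph_slow and the doubling in run are likewise invented for the illustration.

```python
class StubSession:
    """Hypothetical stand-in for tf.Session; counts how often it is created."""

    loads = 0

    def __init__(self):
        # In the real server this is where the expensive setup would happen.
        StubSession.loads += 1

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def run(self, x):
        return x * 2  # placeholder for real inference


def run_graph_slow(x):
    # Anti-pattern: a new session is opened and closed on every request.
    with StubSession() as sess:
        return sess.run(x)


def run_graph(x, sess):
    # Pattern from the text: the caller owns the session and passes it in.
    return sess.run(x)


StubSession.loads = 0
for i in range(3):
    run_graph_slow(i)
slow_loads = StubSession.loads  # one construction per request

StubSession.loads = 0
shared = StubSession()  # created once, at server start-up
results = [run_graph(i, shared) for i in range(3)]
fast_loads = StubSession.loads  # a single construction for all requests

print(slow_loads, fast_loads, results)
```

With the shared session the setup cost is paid once, no matter how many requests follow.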
2. TensorFlow memory leak and steadily increasing latency
Problematic code:
with tf.Graph().as_default():
    # build graph
    preprocessed_image = tf.placeholder(tf.float32, shape=(image_size, image_size, 3), name="preprocessed_images")
    processed_image = tf.expand_dims(preprocessed_image, 0)
    # execute graph
    with tf.Session() as sess:
        image_string_tmp = tf.gfile.FastGFile(line, 'rb').read()
        # Never create the tf.image.decode_image() op here
        image_decode_tmp = tf.image.decode_image(image_string_tmp, channels=3)
        preprocessed_image_tmp = inception_preprocessing.preprocess_image(image_decode_tmp, image_size, image_size, is_training=False)
        preprocessed_image_tmp_val = sess.run([preprocessed_image_tmp])
        np_probabilities = sess.run(probabilities, {"preprocessed_images:0": preprocessed_image_tmp_val[0]})
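I recorded the elapsed time and memory usage of every step using time.time() and resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.
The logs showed that tf.image.decode_image was what caused the memory leak and the ever-increasing latency.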
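The measurement described above can be wrapped in a small helper. This is a sketch only: log_step and records are hypothetical names, the summation is a stand-in for the real decode, preprocess, and inference steps, and the resource module is available on Unix-like systems only.

```python
import logging
import resource  # Unix-only standard-library module
import time
from contextlib import contextmanager


@contextmanager
def log_step(name, records=None):
    """Log wall-clock time and peak memory for the enclosed block."""
    start = time.time()
    try:
        yield
    finally:
        elapsed = time.time() - start
        # ru_maxrss is the peak resident set size: kilobytes on Linux
        # (so dividing by 1024 gives megabytes), bytes on macOS.
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
        logging.info("%s: %.4f s, peak RSS %.1f", name, elapsed, peak)
        if records is not None:
            records.append((name, elapsed, peak))


records = []
for i in range(3):
    with log_step("request %d" % i, records):
        sum(range(100000))  # stand-in for decode + preprocess + inference
```

Because ru_maxrss is a peak value it never decreases, so a figure that keeps climbing from request to request, together with growing elapsed times, is the pattern to look for.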
Solution:
def preprocess(img_name, height, width,
               central_fraction=0.875, scope=None):
    """
    :param img_name: name of the image file to preprocess
    :param height:
    :param width:
    :param central_fraction: fraction of the image to crop
    :param scope: scope for name_scope
    :return: 3-D float Tensor of prepared image.
    """
    # Image files are binary, so read them in 'rb' mode
    image_raw_data = tf.gfile.FastGFile(img_name, 'rb').read()
    file_extension = img_name.rsplit('.', 1)[1]
    logging.info("file_extension: %s", file_extension)
    if file_extension == 'jpg' or file_extension == 'jpeg':
        image_raw_data = tf.image.decode_jpeg(image_raw_data)
    elif file_extension == 'png':
        # Decode the PNG, then round-trip it through JPEG
        image_raw_data = tf.image.decode_png(image_raw_data)
        image_raw_data = tf.image.encode_jpeg(image_raw_data)
        image_raw_data = tf.image.decode_jpeg(image_raw_data)


def run_graph1(filename, sess):
    # build graph
    with sess.graph.as_default():
        image_width = 256
        image_height = 256
        num_channels = 3
        start_load_graph = time.time()
        y_pred = sess.graph.get_tensor_by_name("y_pred:0")
        ## Let's feed the images to the input placeholders
        x = sess.graph.get_tensor_by_name("x:0")
        # y_true = graph.get_tensor_by_name("y_true:0")
        # y_test_images = np.zeros((1, 2))
        # sess = tf.Session(graph=graph, config=config)
        load_graph_elapsed = time.time() - start_load_graph
        logging.info("load_graph_elapsed: %f:", load_graph_elapsed)

        # compute preprocess image time
        start_process = time.time()
        images = preprocess(os.path.join(UPLOAD_FOLDER, filename), image_height, image_width)
        process_elapsed = time.time() - start_process
        logging.info("process_elapsed: %f:", process_elapsed)
        # execute graph
        image = images.eval(session=sess)
        x_batch = image.reshape(1, image_height, image_width, num_channels)
        feed_data_time = time.time()
        ### Creating the feed_dict that is required to be fed to calculate y_pred
        feed_dict_testing = {x: x_batch}
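Although tf.image.decode_image merely decodes an image (it turns the encoded image string into a tensor, which may involve allocating memory for that tensor), in graph mode every call to such a function adds new nodes to the graph, so calling it once per request makes the graph keep growing. When using TensorFlow, put the operations that create tensors in the graph-building code rather than in the code that runs for every request.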
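3. Loading multiple models
In TensorFlow, every operation runs inside a Session, so to use different models you need to load them into different Sessions and state which Session you mean at the point of use. This avoids errors caused by a Session that does not match the model you intended. With multiple graphs, each graph needs its own Session. A single graph can also be used in several Sessions, in which case each Session must explicitly declare the graph it uses.
Note that with several graphs, sess.graph and tf.get_default_graph() are not the same object, so on entering a session you must call sess.graph.as_default() to make sess.graph the current default graph; otherwise an error is raised.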
def run_graph1(filename, sess):
    with sess.graph.as_default():
        image_width = 256
        image_height = 256
        num_channels = 3
        start_load_graph = time.time()
        y_pred = sess.graph.get_tensor_by_name("y_pred:0")
        ## Let's feed the images to the input placeholders
        x = sess.graph.get_tensor_by_name("x:0")
        load_graph_elapsed = time.time() - start_load_graph
        logging.info("load_graph_elapsed: %f:", load_graph_elapsed)
        # compute preprocess image time
        start_process = time.time()
        images = preprocess(os.path.join(UPLOAD_FOLDER, filename), image_height, image_width)
        process_elapsed = time.time() - start_process
        logging.info("process_elapsed: %f:", process_elapsed)
        image = images.eval(session=sess)
        x_batch = image.reshape(1, image_height, image_width, num_channels)
        feed_data_time = time.time()
        ### Creating the feed_dict that is required to be fed to calculate y_pred
        feed_dict_testing = {x: x_batch}
        feed_data_elapsed = time.time() - feed_data_time
        logging.info("feed_data_elapsed: %f:", feed_data_elapsed)
        start_compute_time = time.time()
        result = sess.run(y_pred, feed_dict=feed_dict_testing)
        compute_elapsed_time = time.time() - start_compute_time
        logging.info("compute_elapsed_time: %f:", compute_elapsed_time)
        return result


def run_graph2(filename, sess):
    with sess.graph.as_default():
        image_width = 256
        image_height = 256
        num_channels = 3
        start_load_graph = time.time()
        y_pred = sess.graph.get_tensor_by_name("y_pred:0")
        ## Let's feed the images to the input placeholders
        x = sess.graph.get_tensor_by_name("x:0")
        load_graph_elapsed = time.time() - start_load_graph
        logging.info("load_graph_elapsed: %f:", load_graph_elapsed)
        # compute preprocess image time
        start_process = time.time()
        images = preprocess(os.path.join(UPLOAD_FOLDER, filename), image_height, image_width)
        process_elapsed = time.time() - start_process
        logging.info("process_elapsed: %f:", process_elapsed)
        image = images.eval(session=sess)
        x_batch = image.reshape(1, image_height, image_width, num_channels)
        feed_data_time = time.time()
        ### Creating the feed_dict that is required to be fed to calculate y_pred
        feed_dict_testing = {x: x_batch}
        feed_data_elapsed = time.time() - feed_data_time
        logging.info("feed_data_elapsed: %f:", feed_data_elapsed)
        start_compute_time = time.time()
        result = sess.run(y_pred, feed_dict=feed_dict_testing)
        compute_elapsed_time = time.time() - start_compute_time
        logging.info("compute_elapsed_time: %f:", compute_elapsed_time)
        return result

# When the web framework starts, assign each graph to its own session
app = Flask(__name__)
FLAGS, unparsed = parser.parse_known_args()
g1 = load_graph(FLAGS.graph1)
g2 = load_graph(FLAGS.graph2)
session1 = tf.Session(graph=g1, config=config)
session2 = tf.Session(graph=g2, config=config)
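One way to keep each request paired with the right Session is to look sessions up by model name instead of writing one function per model. The sketch below is an illustration only: StubSession is a hypothetical stand-in for a tf.Session bound to one graph, and SESSIONS, predict, and the model names are invented here rather than taken from the original server.

```python
class StubSession:
    """Hypothetical stand-in for a tf.Session bound to a single graph."""

    def __init__(self, graph_name):
        self.graph = graph_name

    def run(self, x):
        # Placeholder for sess.run(y_pred, feed_dict=...)
        return "%s:%s" % (self.graph, x)


# Built once at start-up: one session per graph, keyed by model name.
SESSIONS = {
    "model1": StubSession("g1"),
    "model2": StubSession("g2"),
}


def predict(model_name, x):
    """Run inference on the session that belongs to model_name."""
    try:
        sess = SESSIONS[model_name]
    except KeyError:
        raise ValueError("unknown model: %r" % model_name)
    return sess.run(x)


print(predict("model1", "cat.jpg"))
print(predict("model2", "cat.jpg"))
```

Because the lookup is the only place where a session is chosen, a request cannot end up running against a graph other than the one its model name refers to.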
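About the author:
The author, Aibei (爱贝, a company alias), currently works on image processing at Beike Zhaofang (贝壳找房).
This article is reposted from the WeChat official account 贝壳产品技术 (ID: gh_9afeb423f390).
Original link: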
https://mp.weixin.qq.com/s/0cmEmIC_CEgiC_o4JwTHnw