Loading a TensorFlow Model and Running Predictions

Driven by the need to make a living, I spent this afternoon learning how to load a trained TensorFlow model and run inference with it.

1 Installing TensorFlow

Why give such a simple task its own section?

Because it only looks simple.

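With Python 3 (I haven't tried Python 2), a plain `pip install tensorflow` gives you TensorFlow 2.x. TF2 removed much of the TF1 API from the top-level namespace (`tf.Session`, `tf.placeholder`, `tf.saved_model.loader` and so on), so TF1 code does not run on it unchanged. Most tutorials are written for TensorFlow 1.x, so installing TF1 is the more reliable choice.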

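If TF2 is already installed, you can also work around this with `import tensorflow.compat.v1 as tf` (usually followed by `tf.disable_v2_behavior()`), but that approach is far from elegant.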
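For reference, the compatibility route can be wrapped in a single entry point that works whether TF 1.x or TF 2.x is installed. This is a minimal sketch. The helper names `needs_compat_v1` and `tf1_api` are hypothetical. `tensorflow.compat.v1` and `disable_v2_behavior()` are the real TF2 compatibility API. TensorFlow is imported lazily, so the version check can be used without TensorFlow installed.

```python
def needs_compat_v1(version: str) -> bool:
    """Return True if a TensorFlow version string is 2.x or later,
    i.e. the TF1 graph/session API must be reached via tf.compat.v1."""
    major = int(version.split(".")[0])
    return major >= 2


def tf1_api():
    """Return a module exposing the TF 1.x API (tf.Session, tf.saved_model.loader, ...)."""
    import tensorflow as tf  # imported lazily so needs_compat_v1 works without TF
    if not needs_compat_v1(tf.__version__):
        return tf  # already TF 1.x, use it directly
    import tensorflow.compat.v1 as tf_v1
    tf_v1.disable_v2_behavior()  # switch TF2 back to TF1-style graph mode (no eager execution)
    return tf_v1


# Usage (requires TensorFlow to be installed):
#   tf = tf1_api()
#   with tf.Session(graph=tf.Graph()) as sess: ...
```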

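But again, running `pip install tensorflow==1.x` directly fails with an error (most likely because PyPI has no TensorFlow 1.x wheels built for Python 3.8).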

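Taking TensorFlow 1.12.0 as an example, one installation method that works is to install the wheel directly from Google's storage bucket: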

```shell
python3 -m pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.12.0-py3-none-any.whl
```

My environment: macOS Big Sur 11.4 & Python 3.8.8
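Before running any model code, it is worth confirming which TensorFlow actually got installed. The sketch below uses the standard-library `importlib.metadata` (Python 3.8+), which reads the installed package's version without importing TensorFlow. The function name `installed_version` is hypothetical. It assumes the CPU package name `tensorflow`; TF 1.x GPU builds were published as `tensorflow-gpu`.

```python
from importlib import metadata  # standard library, Python 3.8+
from typing import Optional


def installed_version(dist_name: str = "tensorflow") -> Optional[str]:
    """Return the installed version of a pip distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_version("tensorflow")
    if v is None:
        print("tensorflow is not installed")
    else:
        print("tensorflow", v, "(1.x)" if v.startswith("1.") else "(2.x or later)")
```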

2 Loading the Model

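TF can save models in several formats. Here I only describe how to load a model in the SavedModel format, because that format makes deployment in production convenient. A SavedModel directory looks like this: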

```
saved_model
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00001
    └── variables.index
```
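A quick sanity check before loading is a hypothetical helper, `looks_like_saved_model`, that tests whether a directory matches the layout shown above. It is only a heuristic based on that listing and does not parse the files.

```python
from pathlib import Path


def looks_like_saved_model(export_dir: str) -> bool:
    """Heuristic: does export_dir contain saved_model.pb and variables/variables.index?"""
    root = Path(export_dir)
    has_graph = (root / "saved_model.pb").is_file()
    has_index = (root / "variables" / "variables.index").is_file()
    return has_graph and has_index
```

It also catches a common mistake. If you pass the path of `saved_model.pb` itself instead of its directory, it returns `False`.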

The main function for loading the model (see the official documentation):

```python
tf.saved_model.loader.load(
    sess, tags, export_dir, import_scope=None, **saver_kwargs
)
```
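`loader.load` returns the loaded `MetaGraphDef`. Its `signature_def` map records the logical input/output names of each signature and the graph tensors behind them. The sketch below shows how to read that map. The function `signature_tensor_names` is hypothetical. It only touches the protobuf fields `signature_def`, `inputs`, `outputs` and `name`, so it can be exercised with stand-in objects when TensorFlow is not installed. `"serving_default"` is TensorFlow's default serving signature key.

```python
def signature_tensor_names(meta_graph_def, signature_key="serving_default"):
    """Map a signature's logical input/output keys to graph tensor names.

    `meta_graph_def` is the MetaGraphDef returned by tf.saved_model.loader.load;
    only its signature_def / inputs / outputs / name fields are used.
    """
    sig = meta_graph_def.signature_def[signature_key]
    inputs = {key: info.name for key, info in sig.inputs.items()}
    outputs = {key: info.name for key, info in sig.outputs.items()}
    return inputs, outputs


# Usage with TensorFlow 1.x (tags and path are placeholders):
#   with tf.Session(graph=tf.Graph()) as sess:
#       meta = tf.saved_model.loader.load(sess, ['serve'], './saved_model')
#       inputs, outputs = signature_tensor_names(meta)
```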

`tags` are the tags that were specified when the model was saved. If you don't know them, you can still look them up.

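The TensorFlow installation directory contains a script named `saved_model_cli.py` (under `tensorflow/python/tools/`). The pip package also installs it as the `saved_model_cli` command. Point `--dir` at the directory that contains `saved_model.pb`, not at the file itself: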

```
> python ${path_to_saved_model_cli.py} show --dir ${saved_model_dir}

The given SavedModel contains the following tag-sets:
serve
```
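If you want to pick up the tags from a script, the CLI output can be parsed. The sketch below assumes the format shown above: a header line followed by one tag-set per line, with tags inside a set separated by `, `. Quotes are stripped in case your version prints them. The function name `parse_tag_sets` is hypothetical. Each inner list it returns is one tag-set, which can be passed as `tags` to `loader.load`.

```python
def parse_tag_sets(cli_output: str) -> list:
    """Extract tag-sets from `saved_model_cli show --dir ...` output."""
    header = "contains the following tag-sets:"
    lines = cli_output.splitlines()
    for i, line in enumerate(lines):
        if header in line:
            tag_sets = []
            for raw in lines[i + 1:]:
                raw = raw.strip().strip("'\"")  # tolerate quoted tag-sets
                if not raw:
                    break  # a blank line ends the list
                tag_sets.append([tag.strip() for tag in raw.split(",")])
            return tag_sets
    return []  # header not found


# Usage (assumes the saved_model_cli command is on PATH):
#   import subprocess
#   out = subprocess.run(["saved_model_cli", "show", "--dir", "./saved_model"],
#                        capture_output=True, text=True).stdout
#   tags = parse_tag_sets(out)[0]
```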

3 Inference

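Roughly speaking, TensorFlow works like this. You build a graph (think of it as the network structure) made of many nodes (think of them as the network's individual components). Then you open a session (think of it as a job) that runs the graph.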

The main function for inference (see the official documentation):

```python
sess.run(fetches, feed_dict)
```

`fetches`: the nodes (tensors) whose outputs you want `sess.run` to compute and return.

`feed_dict`: the data the graph needs at run time, as a dict mapping input tensors (usually placeholders) to values.
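To make `fetches` and `feed_dict` concrete, here is a toy pure-Python imitation of the graph/session model. It is not TensorFlow code, and `Node`, `placeholder`, `op` and `run` are hypothetical names. Nodes compute values from other nodes, placeholders get values only through `feed_dict`, and `run` evaluates just the nodes the requested fetches depend on.

```python
class Node:
    """A graph node: a placeholder (fn is None) or an op computing fn(*inputs)."""

    def __init__(self, name, fn=None, inputs=()):
        self.name, self.fn, self.inputs = name, fn, tuple(inputs)


def placeholder(name):
    return Node(name)


def op(name, fn, *inputs):
    return Node(name, fn, inputs)


def run(fetches, feed_dict):
    """Evaluate each fetched node, pulling placeholder values from feed_dict."""
    cache = {}

    def evaluate(node):
        if node in feed_dict:          # fed values take precedence
            return feed_dict[node]
        if node.fn is None:
            raise ValueError(f"placeholder {node.name!r} must be fed")
        if node not in cache:          # evaluate each node at most once per run
            cache[node] = node.fn(*(evaluate(i) for i in node.inputs))
        return cache[node]

    return [evaluate(f) for f in fetches]


# Build a graph for y = x * w + b, then "run" it
x, w, b = placeholder("x"), placeholder("w"), placeholder("b")
xw = op("mul", lambda a, c: a * c, x, w)
y = op("add", lambda a, c: a + c, xw, b)
print(run([y, xw], {x: 3, w: 2, b: 1}))  # [7, 6]
```

Fetching only `xw` needs only `x` and `w` to be fed. TensorFlow behaves the same way: a placeholder that the fetched nodes do not depend on does not have to be fed.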

Listing the graph's nodes:

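```python
import tensorflow as tf  # TensorFlow 1.x

tags = ['serve']              # tag-set reported by saved_model_cli
export_dir = './saved_model'  # directory containing saved_model.pb

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, tags, export_dir)
    # Each NodeDef in the GraphDef is one node (op) of the graph
    node_names = [node.name for node in sess.graph.as_graph_def().node]
    print(node_names)
```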

Finding the inputs the graph needs (its placeholders):

```python
import tensorflow as tf  # TensorFlow 1.x

tags = ['serve']
export_dir = './saved_model'

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, tags, export_dir)
    # Placeholder ops are the graph's inputs; op.values() yields their output tensors
    placeholders = [t for op in sess.graph.get_operations()
                    if op.type == 'Placeholder' for t in op.values()]
    print(placeholders)
```

Inspecting the model's input and output signatures:

```
> python ${path_to_saved_model_cli.py} show --dir ${saved_model_dir} --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, -1)
        name: input_ids:0
    inputs['input_mask'] tensor_info:
        dtype: DT_INT32
        shape: (-1, -1)
        name: input_mask:0
    inputs['segment_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, -1)
        name: segment_ids:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: app/ez_dense/BiasAdd:0
    outputs['predictions'] tensor_info:
        dtype: DT_INT32
        shape: (-1)
        name: ArgMax:0
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: Softmax:0
  Method name is: tensorflow/serving/predict
```
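The `name:` lines in this listing are exactly the tensor names needed for `feed_dict` and `fetches`. The sketch below collects them from the `--all` output, assuming the format shown above and a single signature; with several signatures, later entries with the same key overwrite earlier ones. The function name `parse_tensor_names` is hypothetical.

```python
import re


def parse_tensor_names(cli_output: str) -> dict:
    """Collect {'inputs': {key: tensor_name}, 'outputs': {key: tensor_name}}
    from `saved_model_cli show --all` output."""
    result = {"inputs": {}, "outputs": {}}
    current = None  # (section, key) of the tensor_info block being read
    for line in cli_output.splitlines():
        line = line.strip()
        m = re.match(r"(inputs|outputs)\['(.+)'\] tensor_info:", line)
        if m:
            current = (m.group(1), m.group(2))
        elif line.startswith("name:") and current is not None:
            section, key = current
            result[section][key] = line[len("name:"):].strip()
            current = None
    return result
```

For the model above, `parse_tensor_names(out)["outputs"]["probabilities"]` would be `'Softmax:0'`. The values can serve as the `inputs` keys and `outputs_keys` of the `TFPredictor` class in section 4.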
Method name is: tensorflow/serving/predict

4 Complete Code

```python
import tensorflow as tf  # TensorFlow 1.x


class TFPredictor(object):
    """Loads a SavedModel and runs one forward pass on the given inputs."""

    def __init__(self, tags: list, saved_model_path: str, inputs: dict, outputs_keys: list) -> None:
        self.saved_model_path = saved_model_path
        self.inputs_keys = list(inputs.keys())      # input tensor names, e.g. 'input_ids:0'
        self.inputs_values = list(inputs.values())  # data fed into those tensors
        self.outputs_keys = list(outputs_keys)      # output tensor names to fetch
        self.tags = tags

    def forward(self) -> dict:
        with tf.Session(graph=tf.Graph()) as sess:
            # Load the model into this session's graph
            tf.saved_model.loader.load(sess, self.tags, self.saved_model_path)
            # Build feed_dict: input tensor -> data
            feed_dict = dict()
            for key, value in zip(self.inputs_keys, self.inputs_values):
                placeholder = sess.graph.get_tensor_by_name(key)
                feed_dict[placeholder] = value
            # Look up the output tensors to fetch
            fetches = [sess.graph.get_tensor_by_name(key) for key in self.outputs_keys]
            # Run inference
            outputs_values = sess.run(fetches, feed_dict)
            # Map each output tensor name to its value
            return dict(zip(self.outputs_keys, outputs_values))
```

Constructor arguments:

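| Parameter | Meaning |
|---|---|
| tags | list; the tag names needed to load the graph |
| saved_model_path | str; path to the SavedModel directory |
| inputs | dict; maps input tensor names to their data |
| outputs_keys | list; names of the output tensors to fetch |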
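The keys of `inputs` and the entries of `outputs_keys` are tensor names, which TensorFlow writes as `<op_name>:<output_index>`. For example, `input_ids:0` is output 0 of the op `input_ids`. Passing a bare op name such as `input_ids` to `get_tensor_by_name` raises an error. The helper below is a hypothetical way to normalize names before handing them to the predictor.

```python
def to_tensor_name(name: str, output_index: int = 0) -> str:
    """Return `name` as a tensor name '<op_name>:<output_index>'.

    Names that already end in ':<digits>' are returned unchanged.
    """
    _, sep, index = name.rpartition(":")
    if sep and index.isdigit():
        return name
    return f"{name}:{output_index}"


# Usage: inputs = {to_tensor_name(k): v for k, v in inputs.items()}
```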

Example:

```python
import numpy as np

MAX_LEN = 30
tags = ['serve']
saved_model_path = './model'
# Tensor names and dtypes (DT_INT32) come from the saved_model_cli output above
inputs = {
    'input_ids:0': np.zeros((1, MAX_LEN), dtype=np.int32),
    'input_mask:0': np.zeros((1, MAX_LEN), dtype=np.int32),
    'segment_ids:0': np.zeros((1, MAX_LEN), dtype=np.int32),
}
outputs_keys = ['app/ez_dense/BiasAdd:0', 'ArgMax:0', 'Softmax:0']

predictor = TFPredictor(tags, saved_model_path, inputs, outputs_keys)
result = predictor.forward()
print(result)
```
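The tensor names (`Softmax:0`, `ArgMax:0`) and shapes in the signature suggest that `probabilities` is the softmax of `logits` and `predictions` is its argmax over the last axis. That is an inference from the names, not something the SavedModel guarantees. If it holds, you could fetch only the logits and derive the rest yourself, as in this pure-Python sketch. The function name `postprocess` is hypothetical.

```python
import math


def postprocess(logits_rows):
    """For each row of logits, return (softmax probabilities, argmax class index)."""
    results = []
    for row in logits_rows:
        row = [float(v) for v in row]
        m = max(row)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        pred = max(range(len(probs)), key=probs.__getitem__)
        results.append((probs, pred))
    return results


# Usage: postprocess(result['app/ez_dense/BiasAdd:0'])
```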