Out of necessity, I spent an afternoon learning how to load a trained TensorFlow model and run inference with it.
## 1 Installing TensorFlow

Why dedicate a whole section to something this simple? Because it only looks simple.
In a Python 3 environment (I haven't tried Python 2), a plain `pip install tensorflow` gives you TensorFlow 2.x. But TF1 and TF2 are essentially incompatible, and since most tutorials are still written against TensorFlow 1.x, sticking with TF1 is the safer bet. If you already have TF2 installed, `import tensorflow.compat.v1 as tf` works around the incompatibility, but that approach is far from elegant.

But again: a direct `pip install tensorflow==1.x` fails with an error.
Taking tensorflow 1.12.0 as an example, one installation method that works is:

```shell
python3 -m pip install --upgrade https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.12.0-py3-none-any.whl
```
My environment: macOS Big Sur 11.4 & Python 3.8.8.
## 2 Loading the Model

TF can save models in several formats. Here I only cover loading the saved_model format, since it is the most convenient for deployment in production.
```
saved_model
.
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00001
    └── variables.index
```
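Before loading, it can be handy to verify that a directory actually has this layout. A tiny helper of my own (not part of TF, purely a convenience sketch):

```python
import os


def looks_like_saved_model(export_dir):
    """Return True if export_dir has the saved_model layout shown above:
    a saved_model.pb file next to a variables/ directory."""
    return (os.path.isfile(os.path.join(export_dir, "saved_model.pb"))
            and os.path.isdir(os.path.join(export_dir, "variables")))
```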
The main function for loading the model (official documentation):

```python
tf.saved_model.loader.load(
    sess,
    tags,
    export_dir,
    import_scope=None,
    **saver_kwargs
)
```
The `tags` were specified when the model was saved, but if you don't know them, there is still a way to find out. The TensorFlow installation directory contains a file called `saved_model_cli.py`, which can be used like this:

```shell
> python ${path_to_saved_model_cli.py} show --dir ${path_to_saved_model_dir}
The given SavedModel contains the following tag-sets:
serve
```

Note that `--dir` takes the directory containing `saved_model.pb`, not the `.pb` file itself.
## 3 Inference

TensorFlow roughly works like this: you build a graph (think of it as the network structure), which consists of many nodes (the network's individual components), and then open a session (think of it as a job) to run that graph.
The main function for inference (official documentation):

```python
sess.run(fetches, feed_dict)
```
`fetches` names the nodes whose outputs you want; `feed_dict` supplies the data the graph needs at run time.
Getting the list of nodes:

```python
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, tag_name, folder_to_saved_model)
    node_name = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
    print(node_name)
```
Getting the data the graph needs (its placeholders):

```python
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, tag_name, folder_to_saved_model)
    placeholders = [placeholder
                    for op in tf.get_default_graph().get_operations()
                    if op.type == 'Placeholder'
                    for placeholder in op.values()]
    print(placeholders)
```
Inspecting the model's input and output formats:

```shell
> python ${path_to_saved_model_cli.py} show --dir ${path_to_saved_model_dir} --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, -1)
        name: input_ids:0
    inputs['input_mask'] tensor_info:
        dtype: DT_INT32
        shape: (-1, -1)
        name: input_mask:0
    inputs['segment_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, -1)
        name: segment_ids:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: app/ez_dense/BiasAdd:0
    outputs['predictions'] tensor_info:
        dtype: DT_INT32
        shape: (-1)
        name: ArgMax:0
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: Softmax:0
  Method name is: tensorflow/serving/predict
```
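Note that every `name:` field in the signature ends with `:0`. A TensorFlow tensor name has the form `op_name:output_index`, which is why the keys passed to `get_tensor_by_name` below carry the `:0` suffix. A tiny helper of my own (purely illustrative) to split such names:

```python
def split_tensor_name(tensor_name):
    """Split a TF tensor name like 'input_ids:0' into (op_name, output_index)."""
    op_name, _, index = tensor_name.rpartition(":")
    return op_name, int(index)
```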
## 4 Complete Code

```python
import tensorflow as tf


class TFPredictor(object):
    def __init__(self, tags: list, saved_model_path: str,
                 inputs: dict, outputs_keys: list):
        self.saved_model_path = saved_model_path
        self.inputs_keys = [key for key in inputs.keys()]
        self.inputs_values = [value for value in inputs.values()]
        self.outputs_keys = [key for key in outputs_keys]
        self.tags = tags

    def forward(self) -> dict:
        with tf.Session(graph=tf.Graph()) as sess:
            tf.saved_model.loader.load(sess, self.tags, self.saved_model_path)
            # map each input tensor (looked up by name) to its data
            feed_dict = dict()
            for key, value in zip(self.inputs_keys, self.inputs_values):
                placeholder = sess.graph.get_tensor_by_name(key)
                feed_dict[placeholder] = value
            # fetch each requested output tensor by name
            operations = [sess.graph.get_tensor_by_name(key) for key in self.outputs_keys]
            outputs_values = sess.run(operations, feed_dict)
            # pair the fetched values back with their tensor names
            outputs = dict()
            for key, value in zip(self.outputs_keys, outputs_values):
                outputs[key] = value
            return outputs
```
Inputs:

| Variable | Meaning |
| --- | --- |
| `tags` | list; the tag names needed to load the graph |
| `saved_model_path` | string; path to the saved model |
| `inputs` | dict; input tensor names and their data |
| `outputs_keys` | list; the names of the output ops |
Example (the placeholders are DT_INT32 per the signature above, so the dummy inputs are created with an integer dtype):

```python
import numpy as np

MAX_LEN = 30
tags = ['serve']
saved_model_path = './model'
inputs = {
    'input_ids:0': np.zeros((1, MAX_LEN), dtype=np.int32),
    'input_mask:0': np.zeros((1, MAX_LEN), dtype=np.int32),
    'segment_ids:0': np.zeros((1, MAX_LEN), dtype=np.int32)
}
outputs_keys = ['app/ez_dense/BiasAdd:0', 'ArgMax:0', 'Softmax:0']
predictor = TFPredictor(tags, saved_model_path, inputs, outputs_keys)
result = predictor.forward()
```
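Judging from the op names in the signature (`BiasAdd`, `ArgMax`, `Softmax`), the three outputs should be related: `probabilities` is the softmax of `logits`, and `predictions` is their argmax. A quick NumPy sketch of that relationship (an assumption based on the op names, not verified against the model):

```python
import numpy as np

logits = np.array([[1.5, -0.5]])  # shape (1, 2), like 'app/ez_dense/BiasAdd:0'

# 'Softmax:0': normalize the logits into probabilities along the last axis
probabilities = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# 'ArgMax:0': index of the most likely class, shape (1,)
predictions = logits.argmax(axis=-1)
```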