Tensorflow的checkpoint文件转pb模型
背景
使用 tf.train.saver()保存模型时会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。
保存代码
import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
sess.run(init_op)
print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
print("v2:", sess.run(v2))
saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
print("Model saved in file:", saver_path)
如:
checkpoint lm.data-00000-of-00001 lm.index lm.meta
其中:
checkpoint
:保存了一个目录下所有的模型文件列表;*.meta
:保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。*.data
:保存模型中每个变量的取值
模型导出
我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),来在别的环境中部署。
用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而用tf.train.Saver().save()导出的文件graph_def与权重是分离的。
由于graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。
ensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存,而且保存的模型可以移植到别的平台。
初体验
查看output结点
from tensorflow.python import pywrap_tensorflow
checkpoint_path = 'result/model/final/lm'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
CKPT转PB格式
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
import os
def freeze(input_checkpoint, output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
output_node_names = "outputs"
saver = tf.train.import_meta_graph(
input_checkpoint + '.meta', clear_devices=True)
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) # 恢复图并得到数据
output_graph_def = tf.graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=sess.graph_def, # 等于:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
if __name__ == '__main__':
# 输入ckpt模型路径
input_checkpoint = 'result/model/final/lm'
# 输出pb模型的路径
out_pb_path = "lm-model.pb"
# 调用freeze_graph将ckpt转为pb
freeze(input_checkpoint, out_pb_path)
if os.path.exists(out_pb_path):
print("success")
else:
print("failed")