Tensorflow的checkpoint文件转pb模型

  |   0 评论   |   0 浏览

背景

使用 tf.train.saver()保存模型时会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。

保存代码

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
    sess.run(init_op)
    print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
    print("v2:", sess.run(v2))
    saver_path = saver.save(sess, "save/model.ckpt")  # 将模型保存到save/model.ckpt文件
    print("Model saved in file:", saver_path)

如:

checkpoint  lm.data-00000-of-00001  lm.index  lm.meta

其中:

  • checkpoint:保存了一个目录下所有的模型文件列表;
  • *.meta:保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
  • *.data:保存模型中每个变量的取值

模型导出

我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),来在别的环境中部署。

用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而用tf.train.Saver().save()导出的文件graph_def与权重是分离的。

由于graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

ensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存,而且保存的模型可以移植到别的平台。

初体验

查看output结点

from tensorflow.python import pywrap_tensorflow
checkpoint_path = 'result/model/final/lm'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)

CKPT转PB格式

import tensorflow as tf

from tensorflow.python.tools import freeze_graph
import os

def freeze(input_checkpoint, output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "outputs"
    saver = tf.train.import_meta_graph(
        input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢复图并得到数据

        output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=sess.graph_def, # 等于:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点


if __name__ == '__main__':
    # 输入ckpt模型路径
    input_checkpoint = 'result/model/final/lm'

    # 输出pb模型的路径
    out_pb_path = "lm-model.pb"

    # 调用freeze_graph将ckpt转为pb
    freeze(input_checkpoint, out_pb_path)
    if os.path.exists(out_pb_path):
        print("success")
    else:
        print("failed")

参考