Viewing the Training Process with TensorBoard


Background

Write summary files during training, then view them with TensorBoard.

Recording

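In the example below, the summaries are written to the tensorboard/lm directory. Every 10 training steps, the merged summaries are added to the event file, using the global step as the x-axis.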

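import tensorflow as tf  # TensorFlow 1.x API

# g (the model graph, with x, y, mean_loss and train_op), lm_util, epochs,
# batch_size, input_num and label_num are defined elsewhere in the project.
# At least one summary op (e.g. tf.summary.scalar('mean_loss', g.mean_loss))
# must exist in the graph, otherwise merge_all() returns None.
saver = tf.train.Saver()
with tf.Session() as sess:
    merged = tf.summary.merge_all()
    sess.run(tf.global_variables_initializer())
    # Event files go to tensorboard/lm; passing the graph enables the Graphs tab
    writer = tf.summary.FileWriter('tensorboard/lm', tf.get_default_graph())
    batch_num = len(input_num) // batch_size
    for k in range(epochs):
        total_loss = 0
        batch = lm_util.get_batch(input_num, label_num, batch_size)
        for i in range(batch_num):
            input_batch, label_batch = next(batch)
            feed = {g.x: input_batch, g.y: label_batch}
            step = k * batch_num + i  # global step, used as the x-axis in TensorBoard
            if step % 10 == 0:
                # Fetch the summaries in the same run as the training step (no extra forward pass)
                cost, _, rs = sess.run([g.mean_loss, g.train_op, merged], feed_dict=feed)
                writer.add_summary(rs, step)
            else:
                cost, _ = sess.run([g.mean_loss, g.train_op], feed_dict=feed)
            total_loss += cost
        print('epoch', k + 1, ': average loss =', total_loss / batch_num)
    writer.close()  # flush remaining events to disk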
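The loop above assumes the graph already contains summary ops for merge_all() to collect. The self-contained sketch below shows one way to set that up on a toy linear-regression model. The model, the names train_with_summaries and log_every, and the use of tf.compat.v1 (so that the TF1-style graph API also runs on TensorFlow 2.x) are assumptions made for illustration. They are not part of the original project.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x is installed; the TF1-style graph API is reached
# through tf.compat.v1, and FileWriter requires eager execution to be disabled.
tf.compat.v1.disable_eager_execution()


def train_with_summaries(logdir, steps=100, log_every=10):
    """Hypothetical helper: fit y = 3x + 2 and log the loss every `log_every` steps."""
    tf.compat.v1.reset_default_graph()
    x = tf.compat.v1.placeholder(tf.float32, shape=[None])
    y = tf.compat.v1.placeholder(tf.float32, shape=[None])
    w = tf.Variable(0.0)
    b = tf.Variable(0.0)
    mean_loss = tf.reduce_mean(tf.square(w * x + b - y))
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(mean_loss)

    # Register a scalar summary; without any summary op, merge_all() returns None.
    tf.compat.v1.summary.scalar('mean_loss', mean_loss)
    merged = tf.compat.v1.summary.merge_all()

    rng = np.random.default_rng(0)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        # FileWriter creates logdir if it does not exist yet
        writer = tf.compat.v1.summary.FileWriter(logdir, sess.graph)
        for step in range(steps):
            xs = rng.uniform(-1, 1, size=32).astype(np.float32)
            feed = {x: xs, y: 3 * xs + 2}
            if step % log_every == 0:
                # Train and collect the summaries in a single run
                _, summary = sess.run([train_op, merged], feed_dict=feed)
                writer.add_summary(summary, step)
            else:
                sess.run(train_op, feed_dict=feed)
        writer.close()  # flush events to disk
    return logdir


if __name__ == '__main__':
    out = train_with_summaries(os.path.join(tempfile.mkdtemp(), 'lm'))
    print('event files:', os.listdir(out))
```

Running tensorboard with --logdir set to the returned directory should then show a mean_loss curve in the Scalars tab. Fetching merged together with train_op, as in the original loop, avoids a second forward pass just for logging.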

Viewing

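On the machine where training runs, execute tensorboard --logdir=. from the directory that contains tensorboard/. TensorBoard searches the log directory recursively, so the run appears as tensorboard/lm; pointing it directly at the run with --logdir=tensorboard/lm also works. Then open http://localhost:6006 (TensorBoard's default port) in a browser.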
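If the page at localhost:6006 shows no data, the --logdir path is a likely culprit. A rough check is to read the event files directly from Python. This sketch assumes TensorFlow 2.x with the tf.compat.v1 API. read_scalars and write_demo_scalars are hypothetical helpers, and the second one exists only so that the example has something to read.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x; FileWriter needs graph mode, so eager execution is disabled.
tf.compat.v1.disable_eager_execution()


def read_scalars(logdir):
    """Hypothetical helper: return {tag: [(step, value), ...]} from logdir's event files."""
    scalars = {}
    for name in sorted(os.listdir(logdir)):
        if not name.startswith('events.out.tfevents'):
            continue
        for event in tf.compat.v1.train.summary_iterator(os.path.join(logdir, name)):
            # Events without a summary (e.g. the file-version header) have no values
            for value in event.summary.value:
                if value.HasField('simple_value'):
                    scalars.setdefault(value.tag, []).append((event.step, value.simple_value))
    return scalars


def write_demo_scalars(logdir):
    """Hypothetical helper: write two hand-made loss values so there is data to read."""
    writer = tf.compat.v1.summary.FileWriter(logdir)
    for step, loss in [(0, 0.5), (10, 0.25)]:
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(tag='mean_loss', simple_value=loss)])
        writer.add_summary(summary, step)
    writer.close()  # flush so the reader sees both values


if __name__ == '__main__':
    demo_dir = os.path.join(tempfile.mkdtemp(), 'lm')
    write_demo_scalars(demo_dir)
    print(read_scalars(demo_dir))
```

For the training script in this post, read_scalars('tensorboard/lm') should return the scalar tags that were logged, each with its (step, value) pairs. An empty result means the directory holds no scalar summaries, so TensorBoard will have nothing to plot there either.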

References