用 TensorBoard 查看训练过程
背景
在训练过程中把 summary 写入事件文件，然后使用 TensorBoard 查看。
记录
如下示例中，summary 会记录到 `tensorboard/lm` 目录下。
# Train the model and write TensorBoard summaries to tensorboard/lm.
# Relies on names defined elsewhere in the training script: tf, lm_util,
# g (exposing x, y, mean_loss and train_op), epochs, batch_size,
# input_num and label_num.
saver = tf.train.Saver()
log_every = 10  # write a summary once every `log_every` global steps

# Number of full batches per epoch; it does not change between epochs.
batch_num = len(input_num) // batch_size
if batch_num == 0:
    # Otherwise no training step runs and the average-loss print divides by 0.
    raise ValueError(
        'batch_size ({}) is larger than the number of training examples ({})'
        .format(batch_size, len(input_num)))

with tf.Session() as sess:
    # merge_all() returns None when the graph defines no summary ops.
    merged = tf.summary.merge_all()
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('tensorboard/lm', tf.get_default_graph())
    try:
        for k in range(epochs):
            total_loss = 0
            batch = lm_util.get_batch(input_num, label_num, batch_size)
            for i in range(batch_num):
                step = k * batch_num + i  # global step counted across epochs
                input_batch, label_batch = next(batch)
                feed = {g.x: input_batch, g.y: label_batch}
                if merged is not None and step % log_every == 0:
                    # Fetch the summary in the same run as the train op instead
                    # of a second sess.run, avoiding an extra forward pass.
                    # NOTE(review): summaries that read variables directly (e.g.
                    # weight histograms) may see pre- or post-update values.
                    cost, _, rs = sess.run(
                        [g.mean_loss, g.train_op, merged], feed_dict=feed)
                    writer.add_summary(rs, step)
                else:
                    cost, _ = sess.run(
                        [g.mean_loss, g.train_op], feed_dict=feed)
                total_loss += cost
            print('epochs', k+1, ': average loss = ', total_loss/batch_num)
    finally:
        # Flush buffered events to disk and release the event file, even if
        # training is interrupted by an exception.
        writer.close()
查看
在运行训练脚本的工作目录下执行 `tensorboard --logdir=.`（TensorBoard 会递归扫描子目录，因此能找到 `tensorboard/lm` 中的事件文件），然后用浏览器打开 http://localhost:6006 即可。