python中TFRecord的读写
背景
TFRecord是一种数据封装格式,做AI必备技能。
使用TFRocord
存储数据的好处:
- 为了更加方便的建图,原来使用
placeholder
的话,还要每次feed_dict
一下,使用TFRecord
+Dataset
的时候直接就把数据读入操作当成一个图中的节点,就不用每次都feed
了。 - 可以方便的和
Estimator
进行对接。
初体验
创建
先从0创建一个简单的tfrecord文件。文件中的每1个样本由3个feature:标量、向量和矩阵组成,样本共重复3次。
#encoding=utf-8
import tensorflow as tf
import numpy as np
writer = tf.python_io.TFRecordWriter('data.tfrecord')
# 标量
scalars = np.array([31,32,33],dtype=np.int64)
# 向量
vectors = np.array([[0.1,0.1,0.1],
[0.2,0.2,0.2],
[0.3,0.3,0.3]],dtype=np.float32)
# 矩阵
matrices = np.array([np.array((vectors[0],vectors[0])),
np.array((vectors[1],vectors[1])),
np.array((vectors[2],vectors[2]))],dtype=np.float32)
# #张量
# img=mpimg.imread('/home/leiwei/桌面/flower.jpeg')
# tensors = np.array([img,img,img])
# print(scalars)
# print('*'*50)
# print(vectors)
# print('*'*50)
# print(matrices)
# print('*'*50)
# print(tensors)
# 这里我们将会写3个样本,每个样本里有4个feature:标量,向量,矩阵,张量
for i in range(3):
# 创建字典
features={}
# 写入标量,类型Int64,由于是标量,所以"value=[scalars[i]]" 变成list
features['scalar'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[scalars[i]]))
# 写入向量,类型float,本身就是list,所以"value=vectors[i]"没有中括号
features['vector'] = tf.train.Feature(float_list = tf.train.FloatList(value=vectors[i]))
# 写入矩阵,类型float,本身是矩阵,一种方法是将矩阵flatten成list
features['matrix'] = tf.train.Feature(float_list = tf.train.FloatList(value=matrices[i].reshape(-1)))
# 然而矩阵的形状信息(2,3)会丢失,需要存储形状信息,随后可转回原形状
features['matrix_shape'] = tf.train.Feature(int64_list = tf.train.Int64List(value=matrices[i].shape))
# # 写入张量,类型float,本身是三维张量,另一种方法是转变成字符类型存储,随后再转回原类型
# features['tensor'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[tensors[i].tostring()]))
# # 存储丢失的形状信息(806,806,3)
# features['tensor_shape'] = tf.train.Feature(int64_list = tf.train.Int64List(value=tensors[i].shape))
#将存有所有feature的字典送入tf.train.Features中
tf_features = tf.train.Features(feature= features)
# 再将其变成一个样本example
tf_example = tf.train.Example(features = tf_features)
# 序列化该样本
tf_serialized = tf_example.SerializeToString()
# 写入一个序列化的样本
writer.write(tf_serialized)
# 由于上面有循环3次,所以到此我们已经写了3个样本
# 关闭文件
writer.close()
读取解析
直接读取
读取TFRecord格式
# encoding=utf-8
import tensorflow as tf
import numpy as np
def _parse_record(example_photo):
features = {
'scalar': tf.FixedLenFeature(1, tf.int64),
'vector': tf.FixedLenFeature(3, tf.float32),
'matrix': tf.FixedLenFeature(6, tf.float32),
'matrix_shape': tf.FixedLenFeature(2, tf.int64),
}
parsed_features = tf.parse_single_example(example_photo, features=features)
return parsed_features
if __name__ == '__main__':
dataset = tf.data.TFRecordDataset('data.tfrecord')
# dataset = dataset.repeat(100)
batch_num = 1
dataset = dataset.map(_parse_record).batch(batch_num)
iterator = dataset.make_one_shot_iterator()
sess = tf.Session()
flag = True
for i in range(3):
try:
features = sess.run(iterator.get_next())
print(features)
except:
print("get next Error: i={0}".format(i))
结果
{'matrix_shape': array([[2, 3]]), 'scalar': array([[31]]), 'vector': array([[0.1, 0.1, 0.1]], dtype=float32), 'matrix': array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float32)}
{'matrix_shape': array([[2, 3]]), 'scalar': array([[32]]), 'vector': array([[0.2, 0.2, 0.2]], dtype=float32), 'matrix': array([[0.2, 0.2, 0.2, 0.2, 0.2, 0.2]], dtype=float32)}
{'matrix_shape': array([[2, 3]]), 'scalar': array([[33]]), 'vector': array([[0.3, 0.3, 0.3]], dtype=float32), 'matrix': array([[0.3, 0.3, 0.3, 0.3, 0.3, 0.3]], dtype=float32)}
不定长读取
解析方式有两种:
-
定长特征解析:tf.FixedLenFeature(shape, dtype, default_value)
- shape:可当reshape来用,如vector的shape从(3,)改动成了(1,3)。
- 注:如果写入的feature使用了.tostring() 其shape就是()
- dtype:必须是tf.float32, tf.int64, tf.string中的一种。
- default_value:feature值缺失时所指定的值。
-
不定长特征解析:tf.VarLenFeature(dtype)
-
注:可以不明确指定shape,但得到的tensor是SparseTensor。
这里改变上一节代码中的其中一行:'matrix': tf.VarLenFeature(tf.float32),
将matrix设置为可变长的。
# encoding=utf-8
import tensorflow as tf
import numpy as np
def _parse_record(example_photo):
features = {
'scalar': tf.FixedLenFeature(1, tf.int64),
'vector': tf.FixedLenFeature(3, tf.float32),
'matrix': tf.VarLenFeature(tf.float32),
'matrix_shape': tf.FixedLenFeature(2, tf.int64),
}
parsed_features = tf.parse_single_example(example_photo, features=features)
return parsed_features
if __name__ == '__main__':
dataset = tf.data.TFRecordDataset('data.tfrecord')
# dataset = dataset.repeat(100)
batch_num = 1
dataset = dataset.map(_parse_record).batch(batch_num)
iterator = dataset.make_one_shot_iterator()
sess = tf.Session()
flag = True
for i in range(3):
try:
features = sess.run(iterator.get_next())
print(features)
except:
print("get next Error: i={0}".format(i))
结果
{'matrix_shape': array([[2, 3]]), 'scalar': array([[31]]), 'vector': array([[0.1, 0.1, 0.1]], dtype=float32), 'matrix': SparseTensorValue(indices=array([[0, 0],
[0, 1],
[0, 2],
[0, 3],
[0, 4],
[0, 5]]), values=array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32), dense_shape=array([1, 6]))}
{'matrix_shape': array([[2, 3]]), 'scalar': array([[32]]), 'vector': array([[0.2, 0.2, 0.2]], dtype=float32), 'matrix': SparseTensorValue(indices=array([[0, 0],
[0, 1],
[0, 2],
[0, 3],
[0, 4],
[0, 5]]), values=array([0.2, 0.2, 0.2, 0.2, 0.2, 0.2], dtype=float32), dense_shape=array([1, 6]))}
{'matrix_shape': array([[2, 3]]), 'scalar': array([[33]]), 'vector': array([[0.3, 0.3, 0.3]], dtype=float32), 'matrix': SparseTensorValue(indices=array([[0, 0],
[0, 1],
[0, 2],
[0, 3],
[0, 4],
[0, 5]]), values=array([0.3, 0.3, 0.3, 0.3, 0.3, 0.3], dtype=float32), dense_shape=array([1, 6]))}
格式转换
不定长读取中,可以使用格式转换方法。
转变特征
# 解码字符
parsed_example['tensor'] = tf.decode_raw(parsed_example['tensor'], tf.uint8)
# 稀疏表示 转为 密集表示
parsed_example['matrix'] = tf.sparse_tensor_to_dense(parsed_example['matrix'])
改变形状
# 转变matrix形状
parsed_example['matrix'] = tf.reshape(matrix, parsed_features['matrix_shape'])
# 转变tensor形状
parsed_example['tensor'] = tf.reshape(parsed_example['tensor'], parsed_example['tensor_shape'])
可初始化读取
下面示例中,可以感知到原始数据读取循环的次数。
# encoding=utf-8
import tensorflow as tf
import numpy as np
def _parse_record(example_photo):
features = {
'scalar': tf.FixedLenFeature(1, tf.int64),
'vector': tf.FixedLenFeature(3, tf.float32),
'matrix': tf.VarLenFeature(tf.float32),
'matrix_shape': tf.FixedLenFeature(2, tf.int64),
}
parsed_features = tf.parse_single_example(example_photo, features=features)
return parsed_features['scalar']
if __name__ == '__main__':
dataset = tf.data.TFRecordDataset('data.tfrecord')
# dataset = dataset.shuffle(3)
dataset = dataset.repeat(3)
batch_size = 2
dataset = dataset.map(_parse_record).batch(batch_size)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
sess.run(iterator.initializer)
print('Starting epoch %d / %d' % (epoch + 1, NUM_EPOCHS))
while True:
try:
features = sess.run(next_element)
print(features)
except Exception as e:
print("get next Error: epoch={0}".format(epoch))
break
重复读取
重复读取3次:dataset = dataset.repeat(3)
num为空表示无限重复下去
不设置则表示只重复一次
shuffle
shuffle:
dataset.shuffle(buffer_size,
seed=None,
reshuffle_each_iteration=True)
buffer_size=1时,其实就是不shuffle。
batch_size
batch_size,即一个batch中读取几个样本。
参数 drop_remainder = False,即最后一个batch如果样本数不足的话,是否丢弃。
batch_size = 2
dataset = dataset.map(_parse_record).batch(batch_size, drop_remainder=False)
iterator
iterator有多种实现。用可重新初始化迭代器初始化训练数据(为了可以有训练完一个数据集的信号),用单次迭代器配合无限重复次数使用验证集(我的程序会运行一次训练集的同时运行测试集判断有没有过拟合)。
单次
单次迭代器是最简单的迭代器形式,仅支持对数据集进行一次迭代,不需要显式初始化。
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
assert i == value
可初始化
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value
可重新初始化
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
可馈送
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
迭代结果
在 dataset.map
中的函数处理中,返回结果可以根据需求来自定义。
异常处理
数据读取完成后,读不到数据时,会报异常:OutOfRangeError。