TFRcords文件读取与储存
(1)TFRecords分析、存取:
TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,
它能更好的利用内存,更方便复制和移动
为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中
文件格式:*.tfrecords
写入文件内容:Example协议块
(2)TFRecords存储API
1、建立TFRecord存储器
tf.python_io.TFRecordWriter(path)
写入tfrecords文件
path: TFRecords文件的路径
return:写文件
method
write(record):向文件中写入一个字符串记录
close():关闭文件写入器
注:字符串为一个序列化的Example,Example.SerializeToString()
(3)TFRecords读取方法API:
同文件阅读器流程,中间需要解析过程
解析TFRecords的example协议内存块
tf.parse_single_example(serialized,features=None,name=None)
解析一个单一的Example原型
serialized:标量字符串Tensor,一个序列化的Example
features:dict字典数据,键为读取的名字,值为FixedLenFeature
return:一个键值对组成的字典,键为读取的名字
tf.FixedLenFeature(shape,dtype)
shape:输入数据的形状,一般不指定,为空列表
dtype:输入数据类型,与存储进文件的类型要一致
类型只能是float32,int64,string
(4)CIFAR-10批处理结果存入tfrecords流程:
1、构造存储器
2、构造每一个样本的Example
3、写入序列化的Example
(5)读取tfrecords流程:
1、构造TFRecords阅读器
2、解析Example
3、转换格式,bytes解码
(6)代码实现:
def write_ro_tfrecords(self, image_batch, label_batch):
"""
将图片的特征值和目标值存进tfrecords
:param image_batch: 10张图片的特征值
:param label_batch: 10张图片的目标值
:return: None
"""
# 1、建立TFRecord存储器
writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)
# 2、循环将所有样本写入文件,每张图片样本都要构造example协议
for i in range(10):
# 取出第i个图片数据的特征值和目标值
image = image_batch[i].eval().tostring()
label = int(label_batch[i].eval()[0])
# 构造一个样本的example
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
}))
# 写入单独的样本
writer.write(example.SerializeToString())
# 关闭
writer.close()
return None
def read_from_tfrecords(self):
# 1、构造文件队列
file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
# 2、构造文件阅读器,读取内容example,value=一个样本的序列化example
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)
# 3、解析example
features = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64),
})
# 4、解码内容, 如果读取的内容格式是string需要解码, 如果是int64,float32不需要解码
image = tf.decode_raw(features["image"], tf.uint8)
# 固定图片的形状,方便与批处理
image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
label = tf.cast(features["label"], tf.int32)
print(image_reshape, label)
# 进行批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
return image_batch, label_batch
if __name__ == "__main__":
# 1、找到文件,放入列表 路径+名字 ->列表当中
file_name = os.listdir(FLAGS.cifar_dir)
filelist = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
# print(file_name)
cf = CifarRead(filelist)
# image_batch, label_batch = cf.read_and_decode()
image_batch, label_batch = cf.read_from_tfrecords()
# 开启会话运行结果
with tf.Session() as sess:
# 定义一个线程协调器
coord = tf.train.Coordinator()
# 开启读文件的线程
threads = tf.train.start_queue_runners(sess, coord=coord)
# 存进tfrecords文件
# print("开始存储")
#
# cf.write_ro_tfrecords(image_batch, label_batch)
#
# print("结束存储")
# 打印读取的内容
print(sess.run([image_batch, label_batch]))
# 回收子线程
coord.request_stop()
coord.join(threads)
版权声明:本文为XST1520203418原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。