tensorflow之TFRcords文件读取

TFRcords文件读取与储存

(1)TFRecords分析、存取:

TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,
它能更好的利用内存,更方便复制和移动

为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中

文件格式:*.tfrecords

写入文件内容:Example协议块

(2)TFRecords存储API

1、建立TFRecord存储器
tf.python_io.TFRecordWriter(path)
写入tfrecords文件
path: TFRecords文件的路径
return:写文件

method
write(record):向文件中写入一个字符串记录

close():关闭文件写入器

:字符串为一个序列化的Example,Example.SerializeToString()

(3)TFRecords读取方法API:

同文件阅读器流程,中间需要解析过程

解析TFRecords的example协议内存块
tf.parse_single_example(serialized,features=None,name=None)
解析一个单一的Example原型
serialized:标量字符串Tensor,一个序列化的Example
features:dict字典数据,键为读取的名字,值为FixedLenFeature
return:一个键值对组成的字典,键为读取的名字


tf.FixedLenFeature(shape,dtype)
shape:输入数据的形状,一般不指定,为空列表
dtype:输入数据类型,与存储进文件的类型要一致
类型只能是float32,int64,string

(4)CIFAR-10批处理结果存入tfrecords流程:

1、构造存储器

2、构造每一个样本的Example

3、写入序列化的Example

(5)读取tfrecords流程:

1、构造TFRecords阅读器

2、解析Example

3、转换格式,bytes解码

(6)代码实现:

   def write_ro_tfrecords(self, image_batch, label_batch):
        """
        将图片的特征值和目标值存进tfrecords
        :param image_batch: 10张图片的特征值
        :param label_batch: 10张图片的目标值
        :return: None
        """
        # 1、建立TFRecord存储器
        writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)

        # 2、循环将所有样本写入文件,每张图片样本都要构造example协议
        for i in range(10):
            # 取出第i个图片数据的特征值和目标值
            image = image_batch[i].eval().tostring()

            label = int(label_batch[i].eval()[0])

            # 构造一个样本的example
            example =  tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }))

            # 写入单独的样本
            writer.write(example.SerializeToString())

        # 关闭
        writer.close()
        return None

    def read_from_tfrecords(self):

        # 1、构造文件队列
        file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])

        # 2、构造文件阅读器,读取内容example,value=一个样本的序列化example
        reader = tf.TFRecordReader()

        key, value = reader.read(file_queue)

        # 3、解析example
        features = tf.parse_single_example(value, features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })

        # 4、解码内容, 如果读取的内容格式是string需要解码, 如果是int64,float32不需要解码
        image = tf.decode_raw(features["image"], tf.uint8)

        # 固定图片的形状,方便与批处理
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])

        label = tf.cast(features["label"], tf.int32)

        print(image_reshape, label)

        # 进行批处理
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)

        return image_batch, label_batch


if __name__ == "__main__":
    # 1、找到文件,放入列表   路径+名字  ->列表当中
    file_name = os.listdir(FLAGS.cifar_dir)

    filelist = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]

    # print(file_name)
    cf = CifarRead(filelist)

    # image_batch, label_batch = cf.read_and_decode()

    image_batch, label_batch = cf.read_from_tfrecords()

    # 开启会话运行结果
    with tf.Session() as sess:
        # 定义一个线程协调器
        coord = tf.train.Coordinator()

        # 开启读文件的线程
        threads = tf.train.start_queue_runners(sess, coord=coord)

        # 存进tfrecords文件
        # print("开始存储")
        #
        # cf.write_ro_tfrecords(image_batch, label_batch)
        #
        # print("结束存储")

        # 打印读取的内容
        print(sess.run([image_batch, label_batch]))

        # 回收子线程
        coord.request_stop()

        coord.join(threads)


版权声明:本文为XST1520203418原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。