Tensorflow——TFRecord文件的制作及读取

Tensorflow——TFRecord文件的制作及读取

一、TFRecord文件的制作

TFRecord是TensorFlow提供的统一存储数据的一种格式。在训练一个深度学习的模型之前,往往需要先制作数据集。这里主要介绍在TensorFlow框架下,将图像数据库及其对应的标签制作成TFRecord文件,便于图像数据的输入。直接贴个样例代码讲解吧。

import tensorflow as tf
import os
from PIL import Image

#———————————————————————————TFRecord文件制作———————————————————————————————
image_train_path = '' # label路径,label是txt文件
label_train_path = ''
image_test_path = '' # 图像文件路径
label_test_path = ''
data_path = './data' # .tfrecords 文件保存路径

tfRecord_train = './data/train.tfrecords' # tfrecord文件名
tfRecord_test = './data/test.tfrecords'


def write_tfRecord(tfRecordName, image_path, label_path):
    writer = tf.python_io.TFRecordWriter(tfRecordName)
    num_pic = 0  #计数器
    f = open(label_path,'r')  #读label '.txt'文件:每行由图片名与标签组成
    contents = f.readlines()  #获取整个文件
    f.close()
    for content in contents:  # 遍历每行的内容
        value = content.split(' ')  # 空格分隔每行的图片名与标签,分割后组成列表value
        img = Image.open(image_path + value[0])  # 图片文件的路径
        img_raw = img.tobytes()  # 将图片转换成二进制数据
        labels = float(value[1])
        # width = int(value[2]) # 标签不止一个数据的情况
        # height = int(value[3])
        
		# TFRecord文件的固定格式
        example = tf.train.Example(features=tf.train.Features(feature={
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),#放二进制图片
            'label': tf.train.Feature(float_list=tf.train.FloatList(value=[labels])),#放图片标签,需要注意数据类型
            # 'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
            # 'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
        })) #封装
        writer.write(example.SerializeToString()) #序列化example
        num_pic += 1
    print('the number of picture:',num_pic) #打印进度提示
    writer.close()
    print('write tfrecord successful')


def generate_tfRecord():
    isExists = os.path.exists(data_path) #判断保存路径
    if not isExists:
        os.makedirs(data_path)
        print('The directory was created successfully.')
    else:
        print('directory already exists')
    write_tfRecord(tfRecord_train, image_train_path, label_train_path) 
    write_tfRecord(tfRecord_test, image_test_path, label_test_path)


def main():
    generate_tfRecord()


if __name__ == '__main__':
    main()

二、TFRecord文件的读取

TensorFlow还提供了一套数据处理框架——数据集(Dataset)。用队列读取TFReord文件也是可以的,但会出现警告,TensorFlow推荐使用数据集作为输入数据的首选框架。接着介绍利用Dataset读取文件,首先需要解析TFRecord格式的文件,在利用Dataset读取数据。下面的代码中列出解析函数,以及直接读取数据的例程,为了便于放入整体的代码中,我另外写了一个读取数据的函数,仅供参考。

import tensorflow as tf
# 获取TFRecord文件
train_files = tf.train.match_filenames_once("./data/train.tfrecords*")
test_files = tf.train.match_filenames_once("./data/test.tfrecords*")


#——————————————————————————— 解析TFRecord文件————————————————————————————————
def parser(record):
    features = tf.parse_single_example(
        record,
        features={
            'label': tf.FixedLenFeature([], tf.float32),
            'img_raw': tf.FixedLenFeature([], tf.string),
            # 'height':tf.FixedLenFeature([],tf.int64),
            # 'width':tf.FixedLenFeature([],tf.int64),
        })
	
	# 恢复数据类型及shape
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    label = features['label']
    image_size = 256 # 根据数据库中图像大小定义
    # height = tf.cast(features['height'], tf.int64)
    # width = tf.cast(features['width'], tf.int64)  # 转换数据类型
    image_norm = tf.cast(image, tf.float32) * (1. / 255)
    decoded_image = tf.reshape(image_norm, [image_size, image_size, 3])
    decoded_label = tf.reshape(label,[1])

    return decoded_image, decoded_label

# ——————————————————————————————读取数据例程————————————————————————————————
shuffle_buffer = 100       # 定义随机打乱顺序是buffer的大小
batch_size = 10

# 定义读取训练数据的数据集
dataset = tf.data.TFRecordDataset(train_files)
dataset = dataset.map(parser) # 解析数据

dataset = dataset.shuffle(shuffle_buffer).batch(batch_size)

NUM_EPOCHS = 10         # 数据集重复次数
dataset = dataset.repeat(NUM_EPOCHS) # 训练多个EPOCH

# 定义数据迭代器
iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run((tf.global_variables_initializer(),
              tf.local_variables_initializer()))
    sess.run(iterator.initializer) # 初始化迭代器
    img_val, labels = sess.run([image_batch, label_batch])
    print(img_val, labels)

# —————————————————————————————读取文件函数——————————————————————————————
def get_batch_data(file, buffer_size, batch_size, epoch_num):
    
    dataSet = tf.data.TFRecordDataset(file)
    dataSet = dataSet.map(parser)
    dataSet = dataSet.shuffle(buffer_size)
    dataSet = dataSet.batch(batch_size)
    dataSet = dataSet.repeat(epoch_num)

    iterator = dataSet.make_initializable_iterator()
    next_batch = iterator.get_next()

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer()) # 初始化局部变量
        sess.run(iterator.initializer) # 初始化迭代器
    image_batch, label_batch = next_batch
    return image_batch, label_batch
    

版权声明:本文为weixin_43424744原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。