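Preface
This post works through real code to introduce the TFRecord data format in TensorFlow. For the official documentation, see the TensorFlow guides "Walkthrough: Reading and writing image data" and TFRecordDataset. The official site describes three ways for TensorFlow to read data:
- Feeding: Python code supplies the data at every step of the TensorFlow program
- Reading from files: an input pipeline at the start of the TensorFlow graph reads the data from files
- Preloaded data: constants or variables defined in the TensorFlow graph hold all the data
TensorFlow provides a unified format for storing data, and that format is TFRecord.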
TFRecord
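TFRecord data types
tf.Example is a message type that represents data as a {string: value} mapping. TensorFlow commonly uses tf.Example to write and read TFRecord data. Each value in a tf.Example holds one of the following list types:
- tf.train.BytesList: accepts string and byte values
- tf.train.FloatList: accepts float (float32) and double (float64) values
- tf.train.Int64List: accepts enum, bool, int32, uint32, int64, and uint64 values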
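The three list types can be wrapped in small helper functions so that each Python value lands in the right container. This is only a sketch: the helper names (bytes_feature, float_list_feature, int64_list_feature) and the demo keys are made up for illustration and do not come from the original code.

```python
import tensorflow as tf


def bytes_feature(value):
    """Wrap one bytes or str value in a tf.train.Feature (hypothetical helper)."""
    if isinstance(value, str):
        value = value.encode("utf8")  # BytesList stores bytes, so encode str first
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def float_list_feature(values):
    """Wrap a sequence of floats in a tf.train.Feature (hypothetical helper)."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))


def int64_list_feature(values):
    """Wrap a sequence of ints or bools in a tf.train.Feature (hypothetical helper)."""
    return tf.train.Feature(
        int64_list=tf.train.Int64List(value=[int(v) for v in values]))


# A toy example with one feature of each kind; the keys are illustrative only.
demo_example = tf.train.Example(features=tf.train.Features(feature={
    "name": bytes_feature("cat.jpg"),
    "box": float_list_feature([0.1, 0.2, 0.8, 0.9]),
    "flags": int64_list_feature([True, False, 3]),
}))
print(demo_example)
```

Note that a bool is stored as 0 or 1 in an Int64List, and a str has to be encoded to bytes before it goes into a BytesList.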
Taking the construction of a TFRecord for an object detection dataset as an example, the code looks like this:
example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
    'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
    'image/filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[
        annotation['filename'].encode('utf8')])),
    'image/source_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[
        annotation['filename'].encode('utf8')])),
    'image/key/sha256': tf.train.Feature(bytes_list=tf.train.BytesList(value=[key.encode('utf8')])),
    'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
    'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=['jpeg'.encode('utf8')])),
    'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmin)),
    'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmax)),
    'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymin)),
    'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymax)),
    'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=classes_text)),
    'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=classes)),
    'image/object/difficult': tf.train.Feature(int64_list=tf.train.Int64List(value=difficult_obj)),
    'image/object/truncated': tf.train.Feature(int64_list=tf.train.Int64List(value=truncated)),
    'image/object/view': tf.train.Feature(bytes_list=tf.train.BytesList(value=views)),
}))
As the code above shows, this TFRecord stores the image's height, width, file name, bounding boxes, object classes, and related information.
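Writing a TFRecord
Take the VOC2012 object detection dataset as an example. Its directory structure is as follows:
The Annotations folder holds the XML files with the object annotations. ImageSets contains a Main folder, which holds the txt files listing the training, test, and validation sets. JPEGImages contains the image data. The code below implements the conversion: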
import os

import lxml.etree
import tensorflow as tf
import tqdm

# FLAGS and logging are assumed to be defined elsewhere (for example via absl);
# their setup is not shown in this post.


def main(_argv):
    # Map each class name to its index (one class name per line)
    class_map = {name: idx for idx, name in enumerate(
        open(FLAGS.classes).read().splitlines())}
    logging.info("Class mapping loaded: %s", class_map)
    # Create the TFRecord writer
    writer = tf.io.TFRecordWriter(FLAGS.output_file)
    image_list = open(os.path.join(
        FLAGS.data_dir, 'ImageSets', 'Main', '%s.txt' % FLAGS.split)).read().splitlines()
    logging.info("Image list loaded: %d", len(image_list))
    for name in tqdm.tqdm(image_list):
        xml_path = os.path.join(
            FLAGS.data_dir, 'Annotations', name + '.xml')
        try:
            # Parse the XML annotation file
            annotation_xml = lxml.etree.fromstring(open(xml_path).read().encode('utf-8'))
            annotation = parse_xml(annotation_xml)['annotation']
            # Build the tf.train.Example
            tf_example = build_example(annotation, class_map)
            # Serialize the example and write it to the file
            writer.write(tf_example.SerializeToString())
        except Exception:
            # Report which annotation failed, then re-raise with the original traceback
            print(xml_path)
            raise
    writer.close()
    logging.info("Done")
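The write loop above boils down to "open a writer, write serialized bytes, close the writer". The sketch below shows that round trip in isolation. The payloads and the temporary file path are placeholders standing in for tf_example.SerializeToString() and FLAGS.output_file, not values from the original script.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical payloads standing in for tf_example.SerializeToString().
payloads = [b"record-0", b"record-1", b"record-2"]

# Placeholder output location for this demo.
path = os.path.join(tempfile.mkdtemp(), "demo.tfrecord")

# Used as a context manager, the writer is closed even if a write raises.
with tf.io.TFRecordWriter(path) as writer:
    for payload in payloads:
        writer.write(payload)

# Read the raw serialized records back in the order they were written.
read_back = [raw.numpy() for raw in tf.data.TFRecordDataset(path)]
print(read_back)
```

Using the writer in a with block is one way to avoid a half-written file when an exception interrupts the loop before writer.close() is reached.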
Now let's look at the code that parses the XML file:
def parse_xml(xml):
    if not len(xml):
        return {xml.tag: xml.text}
    result = {}
    for child in xml:
        child_result = parse_xml(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    return {xml.tag: result}
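To see what parse_xml returns, it helps to run it on a tiny annotation. The sketch below uses the standard library's xml.etree.ElementTree instead of lxml so that it runs without extra dependencies, and the sample XML is made up for illustration in the VOC layout.

```python
import xml.etree.ElementTree as ET


def parse_xml(xml):
    """Same recursion as above: leaves map to their text, 'object' tags collect into a list."""
    if not len(xml):
        return {xml.tag: xml.text}
    result = {}
    for child in xml:
        child_result = parse_xml(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            result.setdefault(child.tag, []).append(child_result[child.tag])
    return {xml.tag: result}


# Made-up annotation in the VOC layout, for illustration only.
SAMPLE = """
<annotation>
  <filename>demo.jpg</filename>
  <size><width>500</width><height>375</height></size>
  <object><name>dog</name><bndbox><xmin>50</xmin><ymin>75</ymin><xmax>250</xmax><ymax>300</ymax></bndbox></object>
  <object><name>cat</name><bndbox><xmin>100</xmin><ymin>30</ymin><xmax>400</xmax><ymax>150</ymax></bndbox></object>
</annotation>
"""

annotation = parse_xml(ET.fromstring(SAMPLE.strip()))['annotation']
print(annotation['size'])
print([obj['name'] for obj in annotation['object']])
```

Every leaf value comes back as a string, which is why the code that consumes the annotation has to convert sizes and coordinates with int() and float().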
Building the tf.train.Example object
import hashlib


def build_example(annotation, class_map):
    img_path = os.path.join(
        FLAGS.data_dir, 'JPEGImages', annotation['filename'])
    print(img_path)
    # Read the encoded JPEG bytes; the image is stored as-is, not decoded
    img_raw = open(img_path, 'rb').read()
    key = hashlib.sha256(img_raw).hexdigest()
    width = int(annotation['size']['width'])
    height = int(annotation['size']['height'])
    xmin = []
    ymin = []
    xmax = []
    ymax = []
    classes = []
    classes_text = []
    truncated = []
    views = []
    difficult_obj = []
    if 'object' in annotation:
        for obj in annotation['object']:
            difficult = bool(int(obj['difficult']))
            difficult_obj.append(int(difficult))
            # Normalize box coordinates to [0, 1] by the image size
            xmin.append(float(obj['bndbox']['xmin']) / width)
            ymin.append(float(obj['bndbox']['ymin']) / height)
            xmax.append(float(obj['bndbox']['xmax']) / width)
            ymax.append(float(obj['bndbox']['ymax']) / height)
            classes_text.append(obj['name'].encode('utf8'))
            classes.append(class_map[obj['name']])
            truncated.append(int(obj['truncated']))
            views.append(obj['pose'].encode('utf8'))
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
        'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
        'image/filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[
            annotation['filename'].encode('utf8')])),
        'image/source_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[
            annotation['filename'].encode('utf8')])),
        'image/key/sha256': tf.train.Feature(bytes_list=tf.train.BytesList(value=[key.encode('utf8')])),
        'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=['jpeg'.encode('utf8')])),
        'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmin)),
        'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmax)),
        'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymin)),
        'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymax)),
        'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=classes_text)),
        'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=classes)),
        'image/object/difficult': tf.train.Feature(int64_list=tf.train.Int64List(value=difficult_obj)),
        'image/object/truncated': tf.train.Feature(int64_list=tf.train.Int64List(value=truncated)),
        'image/object/view': tf.train.Feature(bytes_list=tf.train.BytesList(value=views)),
    }))
    return example
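Reading a TFRecord
With the TFRecord written, let's look at reading it back. Reading covers two cases: visualizing the dataset before training, and loading the data for training.
Visualizing the data
The function that parses a TFRecord: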
def parse_tfrecord(tfrecord, class_table, size):
    x = tf.io.parse_single_example(tfrecord, IMAGE_FEATURE_MAP)
    x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
    x_train = tf.image.resize(x_train, (size, size))
    class_text = tf.sparse.to_dense(
        x['image/object/class/text'], default_value='')
    labels = tf.cast(class_table.lookup(class_text), tf.float32)
    y_train = tf.stack([tf.sparse.to_dense(x['image/object/bbox/xmin']),
                        tf.sparse.to_dense(x['image/object/bbox/ymin']),
                        tf.sparse.to_dense(x['image/object/bbox/xmax']),
                        tf.sparse.to_dense(x['image/object/bbox/ymax']),
                        labels], axis=1)
    paddings = [[0, FLAGS.yolo_max_boxes - tf.shape(y_train)[0]], [0, 0]]
    y_train = tf.pad(y_train, paddings)
    return x_train, y_train
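parse_tfrecord depends on an IMAGE_FEATURE_MAP that this post does not show. The sketch below assumes the box coordinates are declared as variable-length float features (tf.io.VarLenFeature), which matches the tf.sparse.to_dense calls above, and demonstrates the stack-then-pad step on a record with two boxes. BOX_KEYS, BOX_FEATURE_MAP, and MAX_BOXES are illustrative names, with MAX_BOXES standing in for FLAGS.yolo_max_boxes.

```python
import tensorflow as tf

MAX_BOXES = 5  # stand-in for FLAGS.yolo_max_boxes

BOX_KEYS = ('image/object/bbox/xmin', 'image/object/bbox/ymin',
            'image/object/bbox/xmax', 'image/object/bbox/ymax')

# Assumed spec: each coordinate is a variable-length list of floats.
BOX_FEATURE_MAP = {key: tf.io.VarLenFeature(tf.float32) for key in BOX_KEYS}


def float_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


# A record with two boxes, serialized the same way the writer does it.
serialized = tf.train.Example(features=tf.train.Features(feature={
    'image/object/bbox/xmin': float_feature([0.1, 0.5]),
    'image/object/bbox/ymin': float_feature([0.2, 0.6]),
    'image/object/bbox/xmax': float_feature([0.3, 0.7]),
    'image/object/bbox/ymax': float_feature([0.4, 0.8]),
})).SerializeToString()

x = tf.io.parse_single_example(serialized, BOX_FEATURE_MAP)

# One row per box: (xmin, ymin, xmax, ymax).
boxes = tf.stack([tf.sparse.to_dense(x[key]) for key in BOX_KEYS], axis=1)

# Pad with all-zero rows so every record ends up with MAX_BOXES rows.
# In eager mode the row count is a plain int; inside Dataset.map, use tf.shape.
num_boxes = int(boxes.shape[0])
padded = tf.pad(boxes, [[0, MAX_BOXES - num_boxes], [0, 0]])
print(padded.shape)
```

Padding every record to the same number of rows is what makes batching possible later, and the all-zero rows are easy to skip when visualizing.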
visualize_dataset
def main(_argv):
    class_names = [c.strip() for c in open(FLAGS.classes).readlines()]
    logging.info('classes loaded')
    dataset = load_tfrecord_dataset(FLAGS.dataset, FLAGS.classes, FLAGS.size)
    dataset = dataset.shuffle(512)
    for image, labels in dataset.take(4):
        boxes = []
        scores = []
        classes = []
        for x1, y1, x2, y2, label in labels:
            if x1 == 0 and x2 == 0:
                continue
            boxes.append((x1, y1, x2, y2))
            scores.append(1)
            classes.append(label)
        nums = [len(boxes)]
        boxes = [boxes]
        scores = [scores]
        classes = [classes]
        logging.info('labels:')
        for i in range(nums[0]):
            logging.info('\t{}, {}, {}'.format(class_names[int(classes[0][i])],
                                               np.array(scores[0][i]),
                                               np.array(boxes[0][i])))
        img = cv2.cvtColor(image.numpy(), cv2.COLOR_RGB2BGR)
        img = draw_outputs(img, (boxes, scores, classes, nums), class_names)
        cv2.imwrite(FLAGS.output, img)
        logging.info('output saved to: {}'.format(FLAGS.output))
The code above relies on two important functions: load_tfrecord_dataset and draw_outputs.
def load_tfrecord_dataset(file_pattern, class_file, size=416):
    LINE_NUMBER = -1  # TODO: use tf.lookup.TextFileIndex.LINE_NUMBER
    '''
    tf.lookup.StaticHashTable: maps class names to integer ids
    keys_tensor = tf.constant([1, 2])
    vals_tensor = tf.constant([3, 4])
    input_tensor = tf.constant([1, 5])
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
    print(table.lookup(input_tensor))
    output: tf.Tensor([ 3 -1], shape=(2,), dtype=int32)
    tf.lookup.TextFileInitializer: Table initializers from a text file.
    '''
    class_table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
        class_file, tf.string, 0, tf.int64, LINE_NUMBER, delimiter="\n"), -1)
    files = tf.data.Dataset.list_files(file_pattern)
    dataset = files.flat_map(tf.data.TFRecordDataset)
    return dataset.map(lambda x: parse_tfrecord(x, class_table, size))
In the code above, tf.data.TFRecordDataset can also be used to read the TFRecord data directly. The complete code is as follows:
def load_tfrecord_dataset(file_pattern, class_file, size=416):
    LINE_NUMBER = -1  # TODO: use tf.lookup.TextFileIndex.LINE_NUMBER
    '''
    tf.lookup.StaticHashTable: maps class names to integer ids
    keys_tensor = tf.constant([1, 2])
    vals_tensor = tf.constant([3, 4])
    input_tensor = tf.constant([1, 5])
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
    print(table.lookup(input_tensor))
    output: tf.Tensor([ 3 -1], shape=(2,), dtype=int32)
    tf.lookup.TextFileInitializer: Table initializers from a text file.
    '''
    class_table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
        class_file, tf.string, 0, tf.int64, LINE_NUMBER, delimiter="\n"), -1)
    # files = tf.data.Dataset.list_files(file_pattern)
    # dataset = files.flat_map(tf.data.TFRecordDataset)
    # Note: TFRecordDataset takes file names and does not expand wildcards,
    # so file_pattern must be a concrete path (or a list of paths) here.
    dataset = tf.data.TFRecordDataset(file_pattern)
    # debug
    # for ds in dataset:
    #     parse_tfrecord(ds, class_table, size)
    return dataset.map(lambda x: parse_tfrecord(x, class_table, size))
Drawing the ground-truth boxes for object detection
def draw_outputs(img, outputs, class_names):
    boxes, objectness, classes, nums = outputs
    boxes, objectness, classes, nums = boxes[0], objectness[0], classes[0], nums[0]
    # img.shape is (height, width, channels), so flipping gives (width, height)
    wh = np.flip(img.shape[0:2])
    for i in range(nums):
        # Scale the normalized corners back to pixel coordinates
        x1y1 = tuple((np.array(boxes[i][0:2]) * wh).astype(np.int32))
        x2y2 = tuple((np.array(boxes[i][2:4]) * wh).astype(np.int32))
        # Keep the box inside the image: x is bounded by the width, y by the height
        max_x = img.shape[1]
        max_y = img.shape[0]
        x1y1 = checkBox(x1y1, max_x, max_y)
        x2y2 = checkBox(x2y2, max_x, max_y)
        img = cv2.rectangle(img, x1y1, x2y2, (255, 0, 0), 2)
        img = cv2.putText(img, '{} {:.4f}'.format(
            class_names[int(classes[i])], objectness[i]),
            x1y1, cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 0, 255), 2)
    return img
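draw_outputs calls a checkBox helper that this post does not show. The version below is only a guess at what it might do: it clamps an (x, y) pixel coordinate so that x stays within [0, max_x - 1] and y within [0, max_y - 1]. The real helper may behave differently.

```python
def checkBox(point, max_x, max_y):
    """Clamp an (x, y) pixel coordinate to the image bounds (hypothetical helper)."""
    x, y = point
    # Convert to plain ints and keep each coordinate inside its valid range.
    x = min(max(int(x), 0), max_x - 1)
    y = min(max(int(y), 0), max_y - 1)
    return (x, y)


print(checkBox((-5, 10), 416, 416))
print(checkBox((500, 500), 416, 300))
```

Returning plain Python ints also keeps the coordinates in a form that drawing functions accept without further conversion.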
Reading data for training
train_dataset = dataset.load_tfrecord_dataset(FLAGS.dataset, FLAGS.classes, FLAGS.size)
train_dataset = train_dataset.shuffle(buffer_size=512)
train_dataset = train_dataset.batch(FLAGS.batch_size)
train_dataset = train_dataset.map(lambda x, y: (
    dataset.transform_images(x, FLAGS.size),
    dataset.transform_targets(y, anchors, anchor_masks, FLAGS.size)))
train_dataset = train_dataset.prefetch(
    buffer_size=tf.data.experimental.AUTOTUNE)
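The whole process uses tf.data.Dataset to read the TFRecord data. With the tf.data API, a single call to tf.data.TFRecordDataset is enough to read the data conveniently and efficiently. I strongly recommend tf.data for data loading. It is worth noting that the legacy reading interfaces, such as the queue-based readers of TensorFlow 1.x, are no longer part of the main TensorFlow 2.x API and survive only under tf.compat.v1. In the training code, dataset.load_tfrecord_dataset is the same load_tfrecord_dataset shown earlier, and I find the following part of it somewhat redundant: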
files = tf.data.Dataset.list_files(file_pattern)
dataset = files.flat_map(tf.data.TFRecordDataset)
return dataset.map(lambda x: parse_tfrecord(x, class_table, size))
It is simpler to read the TFRecord data directly with tf.data.TFRecordDataset and then apply map, as follows:
def load_tfrecord_dataset(file_pattern, class_file, size=416):
    LINE_NUMBER = -1  # TODO: use tf.lookup.TextFileIndex.LINE_NUMBER
    '''
    tf.lookup.StaticHashTable: maps class names to integer ids
    keys_tensor = tf.constant([1, 2])
    vals_tensor = tf.constant([3, 4])
    input_tensor = tf.constant([1, 5])
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
    print(table.lookup(input_tensor))
    output: tf.Tensor([ 3 -1], shape=(2,), dtype=int32)
    tf.lookup.TextFileInitializer: Table initializers from a text file.
    '''
    class_table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
        class_file, tf.string, 0, tf.int64, LINE_NUMBER, delimiter="\n"), -1)
    # files = tf.data.Dataset.list_files(file_pattern)
    # dataset = files.flat_map(tf.data.TFRecordDataset)
    dataset = tf.data.TFRecordDataset(file_pattern)
    return dataset.map(lambda x: parse_tfrecord(x, class_table, size))
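Feeding the model
After the TFRecord data is loaded, the data has to be preprocessed before it reaches the model. This covers shuffle, batch, the transform_images step (resizing and so on), the corresponding transform_targets step, and prefetch. The sections below go through these operations one by one.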
shuffle
Randomly shuffles the elements of this dataset.
This dataset fills a buffer with buffer_size elements, then randomly samples elements from this buffer, replacing the selected elements with new elements. For perfect shuffling, a buffer size greater than or equal to the full size of the dataset is required.
For instance, if your dataset contains 10,000 elements but buffer_size is set to 1,000, then shuffle will initially select a random element from only the first 1,000 elements in the buffer. Once an element is selected, its space in the buffer is replaced by the next (i.e. 1,001-st) element, maintaining the 1,000 element buffer.
reshuffle_each_iteration controls whether the shuffle order should be different for each epoch. In the training code of this post, shuffle is applied with a buffer of 512 elements:
train_dataset = train_dataset.shuffle(buffer_size=512)
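The buffer behaviour described above can be modelled in a few lines of plain Python. This is only a sketch of the idea, not how TensorFlow implements shuffle; the function name buffered_shuffle and its seed parameter are made up for illustration.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Pure-Python model of shuffle(buffer_size): sample from a bounded buffer."""
    rng = random.Random(seed)
    iterator = iter(items)
    buffer = []
    # Fill the buffer with the first buffer_size elements.
    for item in iterator:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    output = []
    for item in iterator:
        # Emit a random buffered element and put the next input in its slot.
        index = rng.randrange(len(buffer))
        output.append(buffer[index])
        buffer[index] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    output.extend(buffer)
    return output


shuffled = buffered_shuffle(range(10), buffer_size=3)
print(shuffled)
```

With buffer_size=3 the first element emitted can only be 0, 1, or 2, which is why a small buffer gives only a local shuffle and a buffer at least as large as the dataset is needed for a uniform one.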
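transform_images and transform_targets
How these two functions preprocess the data depends on what the network needs, for example resizing images to a given size or normalizing them. transform_targets applies the matching preprocessing to the labels.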
prefetch
Creates a Dataset that prefetches elements from this dataset.
Most dataset input pipelines should end with a call to prefetch. This allows later elements to be prepared while the current element is being processed. This often improves latency and throughput, at the cost of using additional memory to store prefetched elements.
batch
batch is straightforward to use: it groups consecutive elements of the dataset into batches. Its API is:
batch(
batch_size, drop_remainder=False, num_parallel_calls=None, deterministic=None,
name=None
)
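An important parameter is num_parallel_calls, which sets how many batches are computed asynchronously in parallel. This is similar in spirit to prefetch, described above.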
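The effect of drop_remainder is easiest to see on a small input. The function below is a plain-Python model of the grouping that batch performs, written for illustration only. It is not the TensorFlow implementation and ignores num_parallel_calls, deterministic, and name.

```python
def batch(items, batch_size, drop_remainder=False):
    """Pure-Python model of Dataset.batch: group consecutive elements."""
    batches = []
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            batches.append(current)
            current = []
    # The last batch may be smaller than batch_size unless it is dropped.
    if current and not drop_remainder:
        batches.append(current)
    return batches


full = batch(range(7), 3)
trimmed = batch(range(7), 3, drop_remainder=True)
print(full)
print(trimmed)
```

Dropping the remainder is useful when the model needs every batch to have exactly the same shape.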
Besides the methods covered above, tf.data.Dataset offers many others, such as apply and cache. For details, see:
the tf.data.Dataset usage guide.