def tfrecord_pipeline(cls, tfrecord_file, batch_size, prebatch,
                      epochs, shuffle=True, shuffle_buffer_size=128,
                      num_parallel_calls=4, prefetch_buffer_size=512):
    """Build a batched TFRecord input pipeline and return its iterator.

    Reads a manifest text file listing TFRecord paths, decodes every record
    into a dict of fixed-length features, and batches the result.

    Args:
        cls: The owning class (not used by this function's body).
        tfrecord_file: Path to a text file with one TFRecord file path per
            line; absolute paths are expected. Blank lines are ignored.
        batch_size: Number of decoded records per batch (the final batch may
            be smaller, since ``batch`` is called without ``drop_remainder``).
        prebatch: Multiplier applied to every feature length. The label has
            ``prebatch`` entries, which suggests each record packs
            ``prebatch`` examples (confirm against the writer).
        epochs: Passed to ``Dataset.repeat``.
        shuffle: Whether to shuffle records before repeating.
        shuffle_buffer_size: Buffer size for ``Dataset.shuffle``, counted in
            records. Defaults to the previously hard-coded 128.
        num_parallel_calls: Parallelism for the decoding ``map``. Defaults to
            the previously hard-coded 4.
        prefetch_buffer_size: Number of batches to prefetch. Defaults to the
            previously hard-coded 512.

    Returns:
        An initializable iterator (``make_initializable_iterator``) whose
        elements are dicts keyed by ``CATEGORICAL_FEATURE_NAME``, ``'label'``
        and ``'numerical'``. Each value has shape
        ``[batch, prebatch * width]``: width is
        ``FLAGS.sparse_features`` (int64) for the categorical feature, 1 for
        ``'label'`` (int64) and ``FLAGS.dense_features`` (float32) for
        ``'numerical'``.

    Raises:
        ValueError: If ``tfrecord_file`` is not a regular file, or if it lists
            no TFRecord paths.
    """
    # The manifest is a plain text file holding one TFRecord path per line.
    if not os.path.isfile(tfrecord_file):
        raise ValueError('{} should be a text file'.format(tfrecord_file))
    with open(tfrecord_file) as f:
        # Drop blank lines so no empty filename reaches TFRecordDataset.
        stripped = (line.strip() for line in f)
        record_files = [path for path in stripped if path]
    if not record_files:
        # Fail loudly instead of silently producing an empty dataset.
        raise ValueError(
            '{} does not list any tfrecord files'.format(tfrecord_file))

    # Every fixed-length feature is `prebatch` times its per-example width.
    flags = tf.app.flags.FLAGS
    feature_map = {
        CATEGORICAL_FEATURE_NAME: tf.FixedLenFeature(
            [prebatch * flags.sparse_features], tf.int64),
        'label': tf.FixedLenFeature([prebatch], tf.int64),
        'numerical': tf.FixedLenFeature(
            [prebatch * flags.dense_features], tf.float32),
    }

    def parser(record):
        """Decode one serialized Example into a dict of dense tensors."""
        return tf.parse_single_example(record, features=feature_map)

    dataset = tf.data.TFRecordDataset(filenames=record_files)
    if shuffle:
        # Shuffling before repeat() keeps epoch boundaries intact.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = (dataset.repeat(epochs)
               .map(parser, num_parallel_calls=num_parallel_calls)
               .batch(batch_size)
               .prefetch(buffer_size=prefetch_buffer_size))
    return dataset.make_initializable_iterator()
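# 版权声明：本文为OliverChrist原创文章，遵循CC 4.0 BY-SA版权协议，转载请附上原文出处链接和本声明。
# (Copyright notice: original article by OliverChrist, licensed under CC BY-SA 4.0;
# when reposting, include a link to the original source and this notice.)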