TensorFlow 2.0: Reading Data from HDFS
1. Background
In real projects the training data is often far too large to download to the local machine and load with pandas. This post describes how to read data directly from HDFS for training.
PS: the code is implemented with TensorFlow 2.0.
2. Code
No rambling, straight to the code.
2.1 From CSV
The basic idea is to read the csv files with TextLineDataset and parse every line with decode_csv.
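As a quick illustration of what decode_csv does per line (the row content, delimiter and defaults below are made-up values, not from the original data):

import tensorflow as tf

# hypothetical 3-column row: label, one continuous feature, one string feature
line = tf.constant("1.0\t3.5\tfoo")
defaults = [[0.0], [0.0], ['0']]   # one default per column; its dtype drives the parsing
label, f1, f2 = tf.io.decode_csv(line, defaults, field_delim='\t')
print(label.numpy(), f1.numpy(), f2.numpy())   # 1.0 3.5 b'foo'

The complete code for listing HDFS files and building the dataset follows.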
import re
import subprocess
import tensorflow as tf


def get_file_list(root_path, path_pattern=[]):
    """
    Build the list of HDFS data files under root_path.
    :param root_path: HDFS directory that is listed recursively
    :param path_pattern: optional list of path fragments; only matching paths are kept
    :return: list of file paths
    """
    cmd = """
    hadoop fs -ls -R {0}
    """.format(root_path)
    if len(path_pattern) > 0:
        pattern = "|".join(["(" + str(p.replace('/', '\\/')) + ")" for p in path_pattern])
    else:
        pattern = ""

    # keep only paths that match the pattern and skip _SUCCESS markers
    def validate_path_pattern(path):
        if pattern != "" and re.search(pattern, path) and '_SUCCESS' not in path:
            return True
        elif pattern == "" and '_SUCCESS' not in path:
            return True
        else:
            return False

    status, output = subprocess.getstatusoutput(cmd)
    output = output.split('\n')
    output = list(filter(validate_path_pattern, output))
    file_list = list()
    # every valid `hadoop fs -ls` line has exactly 8 whitespace-separated fields
    polluted = any(len(info.split()) != 8 for info in output)
    if status == 0 and len(output) > 0 and not polluted:
        # lines starting with '-' are files (directories start with 'd')
        file_list = [info.split()[-1] for info in output if info[0] == '-']
    return file_list
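A possible call looks like this; the root path and pattern are placeholders for illustration, not values from the original post:

# hypothetical: collect the part files of one day's partition
files = get_file_list(root_path="/user/xxx/train_data",
                      path_pattern=["dt=20200101"])
# add the hdfs://xxx/ prefix expected by tf.data (see the note further below)
files = ["hdfs://xxx" + f for f in files]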
def input_fn(files, batch_size=32, perform_shuffle=False, separator='\t', has_header=False):
    """
    input_fn for tf.estimator
    :param files: list of csv paths, each prefixed with hdfs://
    :param batch_size:
    :param perform_shuffle: whether to shuffle the examples
    :param separator: csv field delimiter
    :param has_header: whether the csv files contain a header row with column names
    :return: tf.data.Dataset of (feature_map, labels)
    """
    def get_columns(file):
        # read the header line straight from HDFS to obtain the column names
        cmd = """hadoop fs -cat {0} | head -1""".format(file)
        status, output = subprocess.getstatusoutput(cmd)
        return output.split("\n")[0].split(separator)

    def map_fn(line):
        # one default per column; its dtype decides how decode_csv parses the field
        # CONTINUOUS_COLUMNS, USE_COLUMNS and COLUMNS are assumed to be defined globally
        defaults = []
        for col in all_columns:
            if col in CONTINUOUS_COLUMNS + ['label']:
                defaults.append([0.0])
            else:
                defaults.append(['0'])
        columns = tf.io.decode_csv(line, defaults, field_delim=separator, use_quote_delim=False)
        feature_map = dict()
        for fea, col in zip(all_columns, columns):
            if fea not in USE_COLUMNS:
                continue
            feature_map[fea] = col
        labels = feature_map['label']
        return feature_map, labels

    if has_header:
        all_columns = get_columns(files[0])
        # use .skip(1) to drop the header row of every csv file
        dataset = tf.data.Dataset.from_tensor_slices(files)
        dataset = dataset.flat_map(lambda filename: (
            tf.data.TextLineDataset(filename).skip(1).map(map_fn)))
    else:
        all_columns = COLUMNS
        dataset = tf.data.TextLineDataset(files).map(map_fn)
    if perform_shuffle:
        dataset = dataset.shuffle(512)
    dataset = dataset.batch(batch_size)
    return dataset
# define your own estimator here
# note: every path in `files` must carry the hdfs://xxx/ prefix
files = get_file_list(...)
model = tf.estimator.LinearClassifier(...)
model.train(input_fn=lambda: input_fn(...))
The Python code differs little from reading a local csv; the only extra step is adding the HDFS jars to the CLASSPATH when launching:
export HADOOP_USER_NAME=xxxxx
export HADOOP_USER_PASSWORD=xxxxx
CLASSPATH=$($HADOOP_HDFS_HOME/bin/hdfs classpath --glob) python3 test_read_hdfs.py
2.2 From TFRecord
A TFRecord file is a binary format that stores the data (e.g. images) and their labels together; it makes better use of memory and can be copied, moved, read and stored quickly within TensorFlow.
Take the following snippet, which reads libsvm-format data, as an example:
def input_fn(files: list, feature_len, batch_size=32, perform_shuffle=False):
    """
    input_fn
    :param files:
    :param feature_len:
    :param batch_size:
    :param perform_shuffle:
    :return:
    """
    def map_fn(record):
        feature_description = {
            "feature_index": tf.io.FixedLenFeature([feature_len], tf.float32),
            "feature_value": tf.io.FixedLenFeature([feature_len], tf.float32),
            "label": tf.io.FixedLenFeature([1], tf.int64)
        }
        parsed = tf.io.parse_single_example(record, feature_description)
        return parsed["feature_index"], parsed["feature_value"], parsed["label"]

    data = tf.data.TFRecordDataset(files).map(map_fn)
    if perform_shuffle:
        data = data.shuffle(512)
    data = data.batch(batch_size)
    return data
# the rest of the code is the same as above ...
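For completeness, here is a minimal sketch of how such TFRecord files could be produced. The feature keys match the feature_description above, but the helper name, the sample format and the output path are assumptions for illustration, not part of the original post:

import tensorflow as tf

def write_tfrecord(path, samples):
    """samples: iterable of (indices, values, label) triples parsed from libsvm lines."""
    with tf.io.TFRecordWriter(path) as writer:
        for indices, values, label in samples:
            example = tf.train.Example(features=tf.train.Features(feature={
                "feature_index": tf.train.Feature(float_list=tf.train.FloatList(value=indices)),
                "feature_value": tf.train.Feature(float_list=tf.train.FloatList(value=values)),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))
            writer.write(example.SerializeToString())

# e.g. write locally first, then push to HDFS with `hadoop fs -put`
write_tfrecord("train.tfrecord", [([1.0, 5.0, 9.0], [0.5, 1.0, 2.0], 1)])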
Copyright notice: this is an original article by liukanglucky, released under the CC 4.0 BY-SA license. Please attach a link to the original source and this notice when reposting.