TensorFlow 2.0: Reading Data from HDFS

1. Background

In real-world work, training data is often far too large to download locally and load with pandas. This post describes how to read data directly from HDFS for training.
P.S.: the code is written for TensorFlow 2.0.

2. Code

No more preamble, straight to the code.

2.1 From CSV

The basic idea is to read the CSV files with TextLineDataset and parse each line with decode_csv.
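Before the full implementation, here is a minimal sketch of that pattern; the HDFS path and the column defaults are hypothetical:

import tensorflow as tf

# hypothetical file; record_defaults declares one float column and one string column
lines = tf.data.TextLineDataset(["hdfs://namenode/data/part-00000.csv"])
parsed = lines.map(
    lambda line: tf.io.decode_csv(line, record_defaults=[[0.0], ['0']], field_delim='\t'))

The full implementation follows: first a helper that lists the training files on HDFS, then the input_fn itself.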

import re
import subprocess

import tensorflow as tf

def get_file_list(root_path, path_pattern=None):
    """
    Generate the list of training files on HDFS.
    :param root_path: HDFS directory to list recursively
    :param path_pattern: optional list of patterns used to filter paths
    :return: list of matching file paths
    """
    path_pattern = path_pattern or []
    cmd = """
        hadoop fs -ls -R {0}
    """.format(root_path)
    if len(path_pattern) > 0:
        # build an alternation regex from the given patterns
        pattern = "|".join("(" + re.escape(str(p)) + ")" for p in path_pattern)
    else:
        pattern = ""

    # keep only matching data files, skipping _SUCCESS markers
    def validate_path_pattern(path):
        if pattern != "" and re.search(pattern, path) and '_SUCCESS' not in path:
            return True
        elif pattern == "" and '_SUCCESS' not in path:
            return True
        else:
            return False

    status, output = subprocess.getstatusoutput(cmd)
    output = output.split('\n')
    output = list(filter(validate_path_pattern, output))
    file_list = list()
    # a well-formed `hadoop fs -ls` line has exactly 8 whitespace-separated fields
    polluted = any(len(info.split()) != 8 for info in output)
    if status == 0 and len(output) > 0 and not polluted:
        # the last field is the path; a leading '-' marks a regular file, not a directory
        file_list = [info.split()[-1] for info in output if info[0] == '-']
    return file_list
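
For example, a hypothetical call (the root path and pattern are invented):

files = get_file_list("hdfs://namenode/user/me/train_data", path_pattern=["2023-10"])
print(len(files), files[:3])

Next, the input_fn that turns the file list into a tf.data.Dataset: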

def input_fn(files, batch_size=32, perform_shuffle=False, separator='\t', has_header=False):
    """
    input_fn for use with tf.estimator
    :param files:
    :param batch_size:
    :param perform_shuffle:
    :param separator:
    :param has_header: whether the csv files contain a header row
    :return:
    """
    def get_columns(file):
        # read the header row of the file directly from HDFS
        cmd = """hadoop fs -cat {0} | head -1""".format(file)
        status, output = subprocess.getstatusoutput(cmd)
        return output.split("\n")[0].split(separator)

    def map_fn(line):
        defaults = []
        for col in all_columns:
            if col in CONTINUOUS_COLUMNS + ['label']:
                defaults.append([0.0])
            else:
                defaults.append(['0'])
        columns = tf.io.decode_csv(line, defaults, separator, use_quote_delim=False)

        feature_map = dict()

        for fea, col in zip(all_columns, columns):
            if fea not in USE_COLUMNS:
                continue
            feature_map[fea] = col
        # pop the label so it does not remain in the feature dict
        labels = feature_map.pop('label')

        return feature_map, labels

    if has_header:
        all_columns = get_columns(files[0])
        # use .skip(1) to drop the header row of each csv file
        dataset = tf.data.Dataset.from_tensor_slices(files)
        dataset = dataset.flat_map(lambda filename: (
            tf.data.TextLineDataset(filename).skip(1).map(map_fn)))
    else:
        all_columns = COLUMNS
        dataset = tf.data.TextLineDataset(files).map(map_fn)

    if perform_shuffle:
        dataset = dataset.shuffle(512)
    dataset = dataset.batch(batch_size)
    return dataset
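
The function above references a few module-level constants that the original post does not show (COLUMNS, CONTINUOUS_COLUMNS, USE_COLUMNS). A hypothetical definition, with invented column names, might look like:

# all csv columns, in file order
COLUMNS = ['label', 'age', 'income', 'gender', 'city']
# columns parsed as float; everything else is parsed as string
CONTINUOUS_COLUMNS = ['age', 'income']
# columns kept in the feature dict ('label' must be among them)
USE_COLUMNS = ['label', 'age', 'income', 'gender']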

# define your own estimator here
# note: every file path must carry the hdfs://xxx/ prefix
files = get_file_list(...)
model = tf.estimator.LinearClassifier(...)
model.train(input_fn=lambda: input_fn(...))
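
For concreteness, a hypothetical end-to-end wiring (the column names, paths, and hyperparameters are invented):

feature_cols = [tf.feature_column.numeric_column(c) for c in CONTINUOUS_COLUMNS]
files = get_file_list("hdfs://namenode/user/me/train_data", path_pattern=["part-"])
model = tf.estimator.LinearClassifier(feature_columns=feature_cols)
model.train(input_fn=lambda: input_fn(files, batch_size=64, perform_shuffle=True))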

The Python code is hardly different from reading a local csv; when launching, the Hadoop classpath needs to be exported so that TensorFlow can access HDFS:

export HADOOP_USER_NAME=xxxxx
export HADOOP_USER_PASSWORD=xxxxx
CLASSPATH=$($HADOOP_HDFS_HOME/bin/hdfs classpath --glob) python3 test_read_hdfs.py 

2.2 From TFRecord

A TFRecord file is a binary format that stores data (e.g. images) together with labels in one place; it makes good use of memory and allows fast copying, moving, reading, and writing inside TensorFlow.
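
For reference, here is a minimal sketch of how such a record could be written; the schema matches the feature_description used in the reader below, and the output path and values are hypothetical:

import tensorflow as tf

def libsvm_to_example(label, indices, values):
    # one Example holding parallel index/value lists plus a single int label
    feature = {
        "feature_index": tf.train.Feature(float_list=tf.train.FloatList(value=indices)),
        "feature_value": tf.train.Feature(float_list=tf.train.FloatList(value=values)),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# with FixedLenFeature([feature_len], ...) in the reader, every record must
# contain exactly feature_len indices and values (here feature_len would be 3)
with tf.io.TFRecordWriter("sample.tfrecord") as writer:
    writer.write(libsvm_to_example(1, [3.0, 10.0, 42.0], [1.0, 0.5, 2.0]).SerializeToString())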

Take the following snippet, which reads libsvm-format data, as an example:

def input_fn(files: list, feature_len, batch_size=32, perform_shuffle=False):
    """
    input_fn
    :param files:
    :param feature_len:
    :param batch_size:
    :param perform_shuffle:
    :return:
    """

    def map_fn(record):
        # each record must contain exactly feature_len indices and values
        feature_description = {
            "feature_index": tf.io.FixedLenFeature([feature_len], tf.float32),
            "feature_value": tf.io.FixedLenFeature([feature_len], tf.float32),
            "label": tf.io.FixedLenFeature([1], tf.int64)
        }
        parsed = tf.io.parse_single_example(record, feature_description)
        return parsed["feature_index"], parsed["feature_value"], parsed["label"]

    data = tf.data.TFRecordDataset(files).map(map_fn)
    if perform_shuffle:
        data = data.shuffle(512)
    data = data.batch(batch_size)
    return data
 
# the rest of the code is the same as above ...
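
A hypothetical sanity check of this pipeline (the root path and feature_len are invented):

files = get_file_list("hdfs://namenode/user/me/tfrecord_data")
dataset = input_fn(files, feature_len=100, batch_size=32, perform_shuffle=True)
for idx, val, label in dataset.take(1):
    print(idx.shape, val.shape, label.shape)  # (32, 100) (32, 100) (32, 1)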

Copyright notice: this is an original article by liukanglucky, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.