Tensorflow Dataset数据集制作专题【二】— 将图片文件制作成Dataset数据集

今晚这个代码愣是调试了好一会,弄的我是满脑子在飘"为啥是这样", 网上博客搜索一堆, 也没有解决我的问题,先看一下这个神奇的存在:

错误代码段
image_rotated = tf.py_func(_rotated, [image], [tf.float64])

报错:
tensorflow.python.framework.errors_impl.InvalidArgumentError: 0-th value returned by pyfunc_0 is float, but expects double
	 [[Node: cond/PyFunc = PyFunc[Tin=[DT_FLOAT], Tout=[DT_DOUBLE], token="pyfunc_0"](cond/Switch_1)]]
	 [[Node: IteratorGetNext_1 = IteratorGetNext[output_shapes=[<unknown>, [?]], output_types=[DT_FLOAT, DT_INT32], 
	 _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_1)]]

很明显, 我的bug定位在Node几点的PyFunc处,而且错误也能看的明白, 是Tin返回值的类型(float)错误, 与expect(double)不一致导致的,明白了这一点, 回头查看调用tf.py_func()函数的地方的输入, 愣是觉得没啥问题呀, 到底哪里出错了, 后来经过我一番琢磨和尝试,终于修改成功了, 将原来的代码修改为如下即可:

正确代码
image_rotated = tf.py_func(_rotated, [image], [tf.float32])

真是坑人不浅啊啊啊......

下面附上详细的将图片文件转换为Dataset数据集的代码:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ProjectName : 06_create_image_dataset.py
# @DateTime :  2019-11-30 13:32
# @Author : 皮皮虾

# 变化丰富的数据集会使模型的精度和泛化性能成倍的提升
# 一套成熟的代码,可以使开发数据集的工作简化很多


import os
import logging
import argparse
import numpy as np
import tensorflow as tf
from skimage import transform
import matplotlib.pyplot as plt
from sklearn.utils import shuffle


def load_sample(src_path):
    image_path_list = []
    real_label_list = []
    for _dir_ in os.listdir(src_path):
        image_dirname_path = os.path.join(src_path, _dir_)
        for image in os.listdir(image_dirname_path):
            image_path = os.path.join(image_dirname_path, image)
            image_path_list.append(image_path)
            real_label_list.append(_dir_)
    # 将原始的label进行转化,{"man": 0, "woman": 1}
    map_label_list = []
    for label in real_label_list:
        if label == "man":
            map_label_list.append(0)
        else:
            map_label_list.append(1)
    return shuffle(np.asarray(image_path_list), np.asarray(map_label_list))

def _distorted_image(image, size, channal=1, shuffleflag=False, cropflag=False,
                     brightnessflag=False, contrastflag=False):
    # 数据增强
    # 随机左右翻转
    distorted_image = tf.image.random_flip_left_right(image=image)
    # 随机裁剪
    if cropflag == True:
        # 产生随机数
        s = tf.random_uniform(shape=(1, 2), minval=int(size[0] * 0.8), maxval=size[0], dtype=tf.int32)
        distorted_image = tf.random_crop(value=distorted_image, size=[s[0][0], s[0][0], channal])

    # 随机变化亮度
    distorted_image = tf.image.random_flip_up_down(image=distorted_image)
    if brightnessflag == True:
        distorted_image = tf.image.random_brightness(image=distorted_image, max_delta=10)

    # 随机变化对比度
    if contrastflag == True:
        distorted_image = tf.image.random_contrast(image=distorted_image, lower=0.2, upper=1.8)

    # 随机打乱
    if shuffleflag == True:
        # 沿着第0维度打乱数据
        distorted_image = tf.random_shuffle(value=distorted_image)

    return distorted_image


def _norm_image(image, size, channal=1, flattenflag=False):
    # 归一化 压平
    image_decoded = image / 255.0
    if flattenflag == True:
        image_decoded = tf.reshape(tensor=image_decoded, shape=[size[0]*size[1]*channal])

    return image_decoded

# 在整个数据集的处理流程中,对图片的操作都是基于张量进行的, 因为第三方函数无法操作tensorflow
# 中的张量,所以需要对其进行额外的封装

def _random_rotated30(image, label):  
    # 定义函数实现图片随机旋转操作

    def _rotated(image):  
        # 封装好的skimage模块,来进行图片旋转30度
        shift_y, shift_x = np.array(list(image.shape)[:2], np.float32) / 2.
        tf_rotate = transform.SimilarityTransform(rotation=np.deg2rad(30))
        tf_shift = transform.SimilarityTransform(translation=[-shift_x, -shift_y])
        tf_shift_inv = transform.SimilarityTransform(translation=[shift_x, shift_y])
        # 兼容transform函数
        image_rotated = transform.warp(image, (tf_shift + (tf_rotate + tf_shift_inv)).inverse)
        return image_rotated

    def _rotatedwrap():
        # 下面一行代码的类型 tf.float64 --> tf.float32  调用第三方函数py_func
        image_rotated = tf.py_func(_rotated, [image], [tf.float32])  
        return tf.cast(image_rotated, tf.float32)[0]

    a = tf.random_uniform([1], 0, 2, tf.int32)  # 实现随机功能
    image_decoded = tf.cond(tf.equal(tf.constant(0), a[0]), lambda: image, _rotatedwrap)

    return image_decoded, label


def dataset(directory, size, batchsize, random_rotated=False):
    """创建数据集"""
    (filenames, labels) = load_sample(directory)

    def _parseone(filename, label):
        """读取解析一个图片文件"""
        image_string = tf.read_file(filename=filename)
        image_decoded = tf.image.decode_image(contents=image_string)
        # 对图片做扭曲变换
        image_decoded.set_shape([None, None, None])
        image_decoded = _distorted_image(image=image_decoded, size=size)
        # 变化尺寸
        image_decoded = tf.image.resize_images(images=image_decoded, size=size)
        # 归一化
        image_decoded = _norm_image(image=image_decoded, size=size)
        # 类型转化
        image_decoded = tf.cast(x=image_decoded, dtype=tf.float32)
        # 将label转换为张量
        label = tf.cast(x=tf.reshape(tensor=label, shape=[]), dtype=tf.int32)
        return image_decoded, label

    # 生成Dataset对象
    dataset = tf.data.Dataset.from_tensor_slices(tensors=(filenames, labels))
    # 转换为图片数据集
    dataset = dataset.map(_parseone)
    # print("dataset:", dataset)
    if random_rotated == True:
        dataset = dataset.map(_random_rotated30)
    # 批次组合数据集
    dataset = dataset.batch(batch_size=batchsize)

    return dataset


def show_single_image(subplot, label, image):
    plt.subplot(subplot)
    plt.axis("off")
    plt.imshow(image)
    plt.title(label=label)


def show_batch_image(label, image, top):
    plt.figure(figsize=(20, 10))
    plt.axis("off")
    top = min(top, 9)
    for i in range(top):
        show_single_image(subplot=100 + 10 * top + 1 + i, label=label[i], image=image[i])
    plt.show()


def get_one(dataset):
    # 生成一个迭代器
    iterator = dataset.make_one_shot_iterator()
    # 从迭代器中取出一个元素
    one_element = iterator.get_next()
    return one_element


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(filename)s - %(lineno)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path",
        default=r"",
        type=str,
        required=False,
        help="image input path"
    )
    parser.add_argument(
        "--batch_size",
        default=10,
        type=int,
        required=False,
        help="batch size"
    )
    parser.add_argument(
        "--size",
        default=[96, 96],
        type=list,
        required=False,
        help="image size"
    )
    FLAGS, _ = parser.parse_known_args()

    dataset_1 = dataset(directory=FLAGS.input_path, size=FLAGS.size, batchsize=FLAGS.batch_size)
    dataset_2 = dataset(directory=FLAGS.input_path, size=FLAGS.size, batchsize=FLAGS.batch_size, random_rotated=True)

    one_element1 = get_one(dataset=dataset_1)
    one_element2 = get_one(dataset=dataset_2)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        try:
            for step in range(1):
                # 原始数据集里面只有10张图片, 这里迭代一次就可以完成
                value1 = sess.run(one_element1)
                value2 = sess.run(one_element2)
                # 显示图片
                show_batch_image(label=value2[1], image=np.asarray(value2[0]*255, np.uint8), top=10)
        except tf.errors.OutOfRangeError:
            print("finish!!!")



 


版权声明:本文为qq_36400655原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。