今晚这段代码愣是调试了好一会儿, 弄得我满脑子都在飘"为啥是这样"。网上搜了一堆博客, 也没有解决我的问题。先看一下这个神奇的存在:
错误代码段
image_rotated = tf.py_func(_rotated, [image], [tf.float64])
报错:
tensorflow.python.framework.errors_impl.InvalidArgumentError: 0-th value returned by pyfunc_0 is float, but expects double
[[Node: cond/PyFunc = PyFunc[Tin=[DT_FLOAT], Tout=[DT_DOUBLE], token="pyfunc_0"](cond/Switch_1)]]
[[Node: IteratorGetNext_1 = IteratorGetNext[output_shapes=[<unknown>, [?]], output_types=[DT_FLOAT, DT_INT32],
_device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_1)]]
很明显, bug定位在PyFunc节点处, 而且错误信息也能看明白: _rotated实际返回值的类型(float, 即float32)与Tout中声明的期望类型(double, 即float64)不一致。明白了这一点, 回头查看调用tf.py_func()的地方的输入, 愣是觉得没啥问题, 到底哪里出错了? 后来经过一番琢磨和尝试, 终于修改成功了: 把tf.py_func()的第三个参数(输出类型Tout)由tf.float64改为tf.float32, 使其与函数的实际返回类型一致即可:
正确代码
image_rotated = tf.py_func(_rotated, [image], [tf.float32])
真是坑人不浅啊......
下面附上详细的将图片文件转换为Dataset数据集的代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ProjectName : 06_create_image_dataset.py
# @DateTime : 2019-11-30 13:32
# @Author : 皮皮虾
# 变化丰富的数据集会使模型的精度和泛化性能成倍的提升
# 一套成熟的代码,可以使开发数据集的工作简化很多
import os
import logging
import argparse
import numpy as np
import tensorflow as tf
from skimage import transform
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
def load_sample(src_path):
    """Collect image paths and integer labels from a folder-per-class tree.

    Every sub-directory of ``src_path`` is treated as one class, and every
    entry inside it as one sample of that class. The sub-directory name is
    the raw label and is mapped to an integer: ``"man"`` -> 0, any other
    name -> 1.

    Args:
        src_path: Root directory that contains one sub-directory per class.

    Returns:
        The result of ``shuffle(paths, labels)``, where ``paths`` and
        ``labels`` are NumPy arrays of equal length.
    """
    image_path_list = []
    map_label_list = []
    for class_name in os.listdir(src_path):
        class_dir = os.path.join(src_path, class_name)
        # Stray files beside the class folders (e.g. ".DS_Store") would make
        # the inner listdir raise, so only real directories count as classes.
        if not os.path.isdir(class_dir):
            continue
        # Map the raw label to an integer: {"man": 0, anything else: 1}.
        mapped_label = 0 if class_name == "man" else 1
        for image_name in os.listdir(class_dir):
            image_path_list.append(os.path.join(class_dir, image_name))
            map_label_list.append(mapped_label)
    return shuffle(np.asarray(image_path_list), np.asarray(map_label_list))
def _distorted_image(image, size, channal=1, shuffleflag=False, cropflag=False,
                     brightnessflag=False, contrastflag=False):
    """Apply random augmentations to one image tensor.

    A random left/right flip and a random up/down flip are always applied;
    the remaining distortions are switched on by their flags.

    Args:
        image: Image tensor to augment.
        size: Target size; only ``size[0]`` is read, as the upper bound of
            the random crop side (the lower bound is ``int(size[0] * 0.8)``).
        channal: Channel count used for the crop shape.
        shuffleflag: If truthy, shuffle the tensor along dimension 0.
        cropflag: If truthy, take a random square crop.
        brightnessflag: If truthy, randomly change brightness (max_delta=10).
        contrastflag: If truthy, randomly change contrast (0.2 .. 1.8).

    Returns:
        The augmented image tensor.
    """
    # Random left/right flip.
    distorted_image = tf.image.random_flip_left_right(image=image)
    # Random square crop.
    if cropflag:
        # Draw the crop side length; only element [0][0] of the draw is used.
        s = tf.random_uniform(shape=(1, 2), minval=int(size[0] * 0.8), maxval=size[0], dtype=tf.int32)
        distorted_image = tf.random_crop(value=distorted_image, size=[s[0][0], s[0][0], channal])
    # Random up/down flip.
    distorted_image = tf.image.random_flip_up_down(image=distorted_image)
    # Random brightness change.
    if brightnessflag:
        distorted_image = tf.image.random_brightness(image=distorted_image, max_delta=10)
    # Random contrast change.
    if contrastflag:
        distorted_image = tf.image.random_contrast(image=distorted_image, lower=0.2, upper=1.8)
    # Random shuffle along dimension 0.
    if shuffleflag:
        distorted_image = tf.random_shuffle(value=distorted_image)
    return distorted_image
def _norm_image(image, size, channal=1, flattenflag=False):
# 归一化 压平
image_decoded = image / 255.0
if flattenflag == True:
image_decoded = tf.reshape(tensor=image_decoded, shape=[size[0]*size[1]*channal])
return image_decoded
# 在整个数据集的处理流程中,对图片的操作都是基于张量进行的, 因为第三方函数无法操作tensorflow
# 中的张量,所以需要对其进行额外的封装
def _random_rotated30(image, label):
    """Rotate ``image`` by 30 degrees about its centre with probability 1/2.

    The rotation itself is done by skimage on NumPy data, so it is wrapped
    in ``tf.py_func`` to be usable inside the TensorFlow graph.

    Args:
        image: Image tensor; the non-rotated branch returns it unchanged, so
            it should be float32 to match the rotated branch.
        label: Label belonging to ``image``; passed through untouched.

    Returns:
        The pair ``(image_or_rotated_image, label)``.
    """

    def _rotated(image):
        # Rotate around the image centre: shift the centre to the origin,
        # rotate by 30 degrees, then shift back.
        shift_y, shift_x = np.array(image.shape[:2], np.float32) / 2.
        tf_rotate = transform.SimilarityTransform(rotation=np.deg2rad(30))
        tf_shift = transform.SimilarityTransform(translation=[-shift_x, -shift_y])
        tf_shift_inv = transform.SimilarityTransform(translation=[shift_x, shift_y])
        image_rotated = transform.warp(image, (tf_shift + (tf_rotate + tf_shift_inv)).inverse)
        # py_func compares the returned dtype with Tout exactly (a mismatch
        # raises InvalidArgumentError), so pin the dtype here rather than
        # relying on whatever dtype warp() happens to return.
        return np.asarray(image_rotated, dtype=np.float32)

    def _rotatedwrap():
        # Tout must agree with the dtype that _rotated returns (float32).
        image_rotated = tf.py_func(_rotated, [image], [tf.float32])
        return tf.cast(image_rotated, tf.float32)[0]

    # Draw 0 or 1: keep the image on 0, rotate it otherwise.
    a = tf.random_uniform([1], 0, 2, tf.int32)
    image_decoded = tf.cond(tf.equal(tf.constant(0), a[0]), lambda: image, _rotatedwrap)
    return image_decoded, label
def dataset(directory, size, batchsize, random_rotated=False):
    """Build a batched ``tf.data.Dataset`` of (image, label) pairs.

    Args:
        directory: Root folder handed to ``load_sample``.
        size: Target image size passed to ``tf.image.resize_images``.
        batchsize: Number of samples per batch.
        random_rotated: If truthy, additionally map ``_random_rotated30``
            over the samples.

    Returns:
        The batched dataset; images are cast to float32 and labels to
        scalar int32.
    """
    (filenames, labels) = load_sample(directory)

    def _parseone(filename, label):
        """Read and decode one image file and prepare its label."""
        image_string = tf.read_file(filename=filename)
        image_decoded = tf.image.decode_image(contents=image_string)
        # decode_image leaves the static shape unset; declare rank 3 so the
        # following image ops accept the tensor.
        image_decoded.set_shape([None, None, None])
        # Random distortion (flips only, with the default flags).
        image_decoded = _distorted_image(image=image_decoded, size=size)
        # Resize to the target size.
        image_decoded = tf.image.resize_images(images=image_decoded, size=size)
        # Scale pixel values by 1/255.
        image_decoded = _norm_image(image=image_decoded, size=size)
        # Cast to float32.
        image_decoded = tf.cast(x=image_decoded, dtype=tf.float32)
        # Turn the label into a scalar int32 tensor.
        label = tf.cast(x=tf.reshape(tensor=label, shape=[]), dtype=tf.int32)
        return image_decoded, label

    # Dataset of (filename, label) pairs. The local is deliberately not
    # called "dataset" so it does not shadow this function's own name.
    image_dataset = tf.data.Dataset.from_tensor_slices(tensors=(filenames, labels))
    # Convert it into a dataset of decoded images.
    image_dataset = image_dataset.map(_parseone)
    if random_rotated:
        image_dataset = image_dataset.map(_random_rotated30)
    # Group the samples into batches.
    image_dataset = image_dataset.batch(batch_size=batchsize)
    return image_dataset
def show_single_image(subplot, label, image):
    """Draw one image with its label as title into the given subplot slot.

    Args:
        subplot: Subplot position spec handed to ``plt.subplot``.
        label: Value used as the subplot title.
        image: Image data handed to ``plt.imshow``.
    """
    # Select the target subplot first; the calls below act on it.
    plt.subplot(subplot)
    plt.imshow(image)
    plt.title(label)
    # Hide the axis for this subplot.
    plt.axis("off")
def show_batch_image(label, image, top):
    """Show up to ``top`` images of a batch side by side in one row.

    Args:
        label: Sequence of labels, one per image.
        image: Sequence of images.
        top: Maximum number of images to show. At most 9 are shown, because
            the three-digit subplot code below has a single digit for the
            column count.
    """
    plt.figure(figsize=(20, 10))
    plt.axis("off")
    # Never index past the end of a short batch: the last batch of a dataset
    # can hold fewer samples than the requested count.
    top = min(top, 9, len(label), len(image))
    for i in range(top):
        # Subplot code 1-row / top-columns / position i + 1.
        show_single_image(subplot=100 + 10 * top + 1 + i, label=label[i], image=image[i])
    plt.show()
def get_one(dataset):
    """Return the ``get_next()`` element of a new one-shot iterator.

    Args:
        dataset: Object providing ``make_one_shot_iterator()``.

    Returns:
        Whatever the iterator's ``get_next()`` returns.
    """
    # Create the iterator and fetch its next-element handle in one step.
    return dataset.make_one_shot_iterator().get_next()
if __name__ == '__main__':
    # Demo: build the dataset twice (without and with random rotation),
    # fetch one batch from each and display the rotated batch.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(filename)s - %(lineno)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path",
        default=r"",
        type=str,
        required=False,
        help="image input path"
    )
    parser.add_argument(
        "--batch_size",
        default=10,
        type=int,
        required=False,
        help="batch size"
    )
    # type=list would split the argument string into single characters
    # ("96" -> ['9', '6']); take two integers instead: --size 96 96
    parser.add_argument(
        "--size",
        default=[96, 96],
        nargs=2,
        type=int,
        required=False,
        help="image size"
    )
    FLAGS, _ = parser.parse_known_args()
    dataset_1 = dataset(directory=FLAGS.input_path, size=FLAGS.size, batchsize=FLAGS.batch_size)
    dataset_2 = dataset(directory=FLAGS.input_path, size=FLAGS.size, batchsize=FLAGS.batch_size,
                        random_rotated=True)
    one_element1 = get_one(dataset=dataset_1)
    one_element2 = get_one(dataset=dataset_2)
    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        try:
            for step in range(1):
                # The original author's dataset held only 10 images, so a
                # single iteration covers it.
                value1 = sess.run(one_element1)
                value2 = sess.run(one_element2)
                # Display the batch; pixel values are scaled back by 255.
                show_batch_image(label=value2[1], image=np.asarray(value2[0] * 255, np.uint8), top=10)
        except tf.errors.OutOfRangeError:
            print("finish!!!")
版权声明:本文为qq_36400655原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。