TensorFlow中读取图像数据的三种方式(转)

附加一个链接关于DatasetAPI:https://zhuanlan.zhihu.com/p/30751039

本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片、大量图片,和TFRecorder读取方式。并且还补充了功能相近的tf函数。

处理单张图片

我们训练完模型之后,常常要用图片测试,有的时候,我们并不需要对很多图像做测试,可能就是几张甚至一张。这种情况下没有必要用队列机制。


 
  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. def read_image(file_name):
  4. img = tf.read_file(filename=file_name) #默认读取格式为uint8
  5. print( "img 的类型是",type(img));
  6. img = tf.image.decode_jpeg(img,channels= 0) # channels 为1得到的是灰度图,为0则按照图片格式来读
  7. return img
  8. def main( ):
  9. with tf.device( "/cpu:0"):
  10. img_path= './1.jpg'
  11. img=read_image(img_path)
  12. with tf.Session() as sess:
  13. image_numpy=sess.run(img)
  14. print(image_numpy)
  15. print(image_numpy.dtype)
  16. print(image_numpy.shape)
  17. plt.imshow(image_numpy)
  18. plt.show()
  19. if __name__== "__main__":
  20. main()

img_path是文件所在地址包括文件名称,地址用相对地址或者绝对地址都行

输出结果为:


 
  1. img 的类型是 <class 'tensorflow.python.framework.ops.Tensor'>
  2. [[[ 196 219 209]
  3. [ 196 219 209]
  4. [ 196 219 209]
  5. ...
  6. [[ 71 106 42]
  7. [ 59 89 39]
  8. [ 34 63 19]
  9. ...
  10. [ 21 52 46]
  11. [ 15 45 43]
  12. [ 22 50 53]]]
  13. uint8
  14. ( 675, 1200, 3)

和tf.read_file用法相似的函数还有tf.gfile.FastGFile  tf.gfile.GFile,只是要指定读取方式是'r' 还是'rb' 。

需要读取大量图像用于训练

这种情况就需要使用Tensorflow队列机制。首先是获得每张图片的路径,把他们都放进一个list里面,然后用string_input_producer创建队列,再用tf.WholeFileReader读取。具体请看下例:


 
  1. def get_image_batch(data_file,batch_size):
  2. data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
  3. #这个num_epochs函数在整个Graph是local Variable,所以在sess.run全局变量的时候也要加上局部变量。 filenames_queue=tf.train.string_input_producer(data_names,num_epochs=50,shuffle=True,capacity=512)
  4. reader=tf.WholeFileReader()
  5. _,img_bytes=reader.read(filenames_queue)
  6. image=tf.image.decode_png(img_bytes,channels= 1) #读取的是什么格式,就decode什么格式
  7. #解码成单通道的,并且获得的结果的shape是[?, ?,1],也就是Graph不知道图像的大小,需要set_shape
  8. image.set_shape([ 180, 180, 1]) #set到原本已知图像的大小。或者直接通过tf.image.resize_images
  9. image=tf.image.convert_image_dtype(image,tf.float32)
  10. #预处理 下面的一句代码可以换成自己想使用的预处理方式
  11. #image=tf.divide(image,255.0)
  12. return tf.train.batch([image],batch_size)

这里的date_file是指文件夹所在的路径,不包括文件名。第一句是遍历指定目录下的文件名称,存放到一个list中。当然这个做法有很多种方法,比如glob.glob,或者tf.train.match_filename_once

全部代码如下:


 
  1. import tensorflow as tf
  2. import os
  3. def read_image(data_file,batch_size):
  4. data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
  5. filenames_queue=tf.train.string_input_producer(data_names,num_epochs= 5,shuffle= True,capacity= 30)
  6. reader=tf.WholeFileReader()
  7. _,img_bytes=reader.read(filenames_queue)
  8. image=tf.image.decode_jpeg(img_bytes,channels= 1)
  9. image=tf.image.resize_images(image,( 180, 180))
  10. image=tf.image.convert_image_dtype(image,tf.float32)
  11. return tf.train.batch([image],batch_size)
  12. def main( ):
  13. img_path= r'F:\dataSet\WIDER\WIDER_train\images\6--Funeral' #本地的一个数据集目录,有足够的图像
  14. img=read_image(img_path,batch_size= 10)
  15. image=img[ 0] #取出每个batch的第一个数据
  16. print(image)
  17. init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
  18. with tf.Session() as sess:
  19. sess.run(init)
  20. coord = tf.train.Coordinator()
  21. threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  22. try:
  23. while not coord.should_stop():
  24. print(image.shape)
  25. except tf.errors.OutOfRangeError:
  26. print( 'read done')
  27. finally:
  28. coord.request_stop()
  29. coord.join(threads)
  30. if __name__== "__main__":
  31. main()

这段代码可以说写的很是规整了。注意到init里面有对local变量的初始化,并且因为用到了队列,当然要告诉电脑什么时候队列开始, tf.train.Coordinator 和 tf.train.start_queue_runners 就是两个管理队列的类,用法如程序所示。

输出如下:


 
  1. (180, 180, 1)
  2. (180, 180, 1)
  3. (180, 180, 1)
  4. (180, 180, 1)
  5. (180, 180, 1)

与 tf.train.string_input_producer相似的函数是 tf.train.slice_input_producer。 tf.train.slice_input_producer和tf.train.string_input_producer的第一个参数形式不一样。等有时间再做一个二者比较的博客

 

对TFRecorder解码获得图像数据

其实这块和上一种方式差不多的,更重要的是怎么生成TFRecorder文件,这一部分我会补充到另一篇博客上。

仍然使用 tf.train.string_input_producer。


 
  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import os
  4. import cv2
  5. import numpy as np
  6. import glob
  7. def read_image(data_file,batch_size):
  8. files_path=glob.glob(data_file)
  9. queue=tf.train.string_input_producer(files_path,num_epochs= None)
  10. reader = tf.TFRecordReader()
  11. print(queue)
  12. _, serialized_example = reader.read(queue)
  13. features = tf.parse_single_example(
  14. serialized_example,
  15. features={
  16. 'image_raw': tf.FixedLenFeature([], tf.string),
  17. 'label_raw': tf.FixedLenFeature([], tf.string),
  18. })
  19. image = tf.decode_raw(features[ 'image_raw'], tf.uint8)
  20. image = tf.cast(image, tf.float32)
  21. image.set_shape(( 12* 12* 3))
  22. label = tf.decode_raw(features[ 'label_raw'], tf.float32)
  23. label.set_shape(( 2))
  24. # 预处理部分省略,大家可以自己根据需要添加
  25. return tf.train.batch([image,label],batch_size=batch_size,num_threads= 4,capacity= 5*batch_size)
  26. def main( ):
  27. img_path= r'F:\python\MTCNN_by_myself\prepare_data\pnet*.tfrecords' #本地的几个tf文件
  28. img,label=read_image(img_path,batch_size= 10)
  29. image=img[ 0]
  30. init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
  31. with tf.Session() as sess:
  32. sess.run(init)
  33. coord = tf.train.Coordinator()
  34. threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  35. try:
  36. while not coord.should_stop():
  37. print(image.shape)
  38. except tf.errors.OutOfRangeError:
  39. print( 'read done')
  40. finally:
  41. coord.request_stop()
  42. coord.join(threads)
  43. if __name__== "__main__":
  44. main()

在read_image函数中,先使用glob函数获得了存放tfrecord文件的列表,然后根据TFRecord文件是如何存的就如何parse,再set_shape

这里有必要提醒下parse的方式。我们看到这里用的是tf.decode_raw ,因为做TFRecord是将图像数据string化了,数据是串行的,丢失了空间结果。从features中取出image和label的数据,这时就要用 tf.decode_raw  解码,得到的结果当然也是串行的了,所以set_shape 成一个串行的,再reshape。这种方式是取决于你的编码TFRecord方式的。

再举一种例子:


 
  1. reader=tf.TFRecordReader()
  2. _,serialized_example=reader.read(file_name_queue)
  3. features = tf.parse_single_example(serialized_example, features={
  4. 'data': tf.FixedLenFeature([ 256, 256], tf. float32),
  5. 'label': tf.FixedLenFeature([], tf. int64),
  6. 'id': tf.FixedLenFeature([], tf. int64)
  7. })
  8. img = features[ 'data']
  9. label =features[ 'label']
  10. id = features[ 'id']

这个时候就不需要任何解码了。因为做TFRecord的方式就是直接把图像数据append进去了。


版权声明:本文为monk1992原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。