原项目地址: https://github.com/wizyoung/YOLOv3_TensorFlow
这个项目训练加运行自己的数据集,可以参考这篇博客
但是这个项目中测试部分,只能处理一张图片,并不能批量处理一批图片。个人对代码修改一下,可以批量处理某一个文件夹下的所有图片。
首先,需要将要测试的图片保存到一个指定的文件夹中,我保存到了./test_picture_myself文件下。
.txt文本中的数字存的是图片的名字,要把包含这些名字的图片保存到另一个文件夹中。可以参考: https://blog.csdn.net/weixin_43384257/article/details/98374743
然后批量处理这个文件夹下的图片:
# coding: utf-8
# python test_single_image.py ./insulator/VOC2018/JPEGImages/000336.jpg 这张图片有缺陷
#000342 000388
#######################可以批量处理某一个图片文件夹下的所有图片./test_picture_myself#####################################
from __future__ import division, print_function
import tensorflow as tf
import numpy as np
import argparse
import cv2
from utils.misc_utils import parse_anchors, read_class_names
from utils.nms_utils import gpu_nms
from utils.plot_utils import get_color_table, plot_one_box
from utils.data_aug import letterbox_resize
import glob
from model import yolov3
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
parser = argparse.ArgumentParser(description="YOLO-V3 test single image test procedure.")
parser.add_argument("--input_image", type=str,default="./insulator/VOC2018/JPEGImages/000007.jpg",
help="The path of the input image.")
#如果用上面那句话就会默认一个位置,然后运行脚本文件时就不用非得输入某张图片的文字了
#但是注意必须是通过脚本文件运行 python test_single_image.py
# parser.add_argument("input_image", type=str,
# help="The path of the input image.")
parser.add_argument("--anchor_path", type=str, default="./data/yolo_anchors.txt",
help="The path of the anchor txt file.")
parser.add_argument("--new_size", nargs='*', type=int, default=[416, 416],
help="Resize the input image with `new_size`, size format: [width, height]")
parser.add_argument("--letterbox_resize", type=lambda x: (str(x).lower() == 'true'), default=True,
help="Whether to use the letterbox resize.")
parser.add_argument("--class_name_path", type=str, default="./data/voc_names.txt",
help="The path of the class names.")
parser.add_argument("--restore_path", type=str, default="./checkpoint_5/model-epoch_290_step_12512_loss_0.1421_lr_0.0002",
help="The path of the weights to restore.")
args = parser.parse_args()
args.anchors = parse_anchors(args.anchor_path)
args.classes = read_class_names(args.class_name_path)
args.num_class = len(args.classes)
color_table = get_color_table(args.num_class)
with tf.Session() as sess:
input_data = tf.placeholder(tf.float32, [1, args.new_size[1], args.new_size[0], 3], name='input_data')
#input_data="/home/dlut/网络/YOLOv3_TensorFlow-JYZ_6/insulator/VOC2018/JPEGImages/000336.jpg "
yolo_model = yolov3(args.num_class, args.anchors)
with tf.variable_scope('yolov3'):
pred_feature_maps = yolo_model.forward(input_data, False)
pred_boxes, pred_confs, pred_probs = yolo_model.predict(pred_feature_maps)
pred_scores = pred_confs * pred_probs
boxes, scores, labels = gpu_nms(pred_boxes, pred_scores, args.num_class, max_boxes=200, score_thresh=0.3, nms_thresh=0.3) #0.45 nms_thresh越小重叠的框越少
saver = tf.train.Saver()
saver.restore(sess, args.restore_path)
# for line in open("/home/dlut/网络/YOLOv3_TensorFlow-JYZ_6/insulator/VOC2018/ImageSets/Main/test_change.txt"):
for jpgfile in glob.glob(r'./test_picture_myself/*.jpg'):
args.input_image = jpgfile
img_ori = cv2.imread(args.input_image)
if args.letterbox_resize:
img, resize_ratio, dw, dh = letterbox_resize(img_ori, args.new_size[0], args.new_size[1])
else:
height_ori, width_ori = img_ori.shape[:2]
img = cv2.resize(img_ori, tuple(args.new_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.asarray(img, np.float32)
img = img[np.newaxis, :] / 255.
boxes_, scores_, labels_ = sess.run([boxes, scores, labels], feed_dict={input_data: img})
if args.letterbox_resize:
boxes_[:, [0, 2]] = (boxes_[:, [0, 2]] - dw) / resize_ratio
boxes_[:, [1, 3]] = (boxes_[:, [1, 3]] - dh) / resize_ratio
else:
boxes_[:, [0, 2]] *= (width_ori/float(args.new_size[0]))
boxes_[:, [1, 3]] *= (height_ori/float(args.new_size[1]))
print("box coords:")
print(boxes_)
print('*' * 30)
print("scores:")
print(scores_)
print('*' * 30)
print("labels:")
print(labels_)
for i in range(len(boxes_)):
x0, y0, x1, y1 = boxes_[i]
plot_one_box(img_ori, [x0, y0, x1, y1], label=args.classes[labels_[i]] + ', {:.2f}%'.format(scores_[i] * 100), color=color_table[labels_[i]])
#cv2.imshow('Detection result', img_ori) #显示图片,可以手动不显示
#cv2.imwrite('detection_result.jpg', img_ori)#原来的代码
#cv2.imwrite(args.input_image , img_ori) # ./checkpoint_5/Result #第一修改,图片保存到args.input_image 的路径下去了
cv2.imencode('.jpg', img_ori)[1].tofile('./checkpoint_5/Result/'+args.input_image[-10:-4]+'.jpg') #正确方法
#cv2.waitKey(0) #按键时图片显示桌面退出
最终将处理好的图片保存到了./checkpoint_5/Result/文件中,并以图片的名字命名。
结果:
版权声明:本文为weixin_43384257原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。