xml 没有root 节点怎么解析list_数据集与XML文件一起增强——功能实现

最近在对数据集做增广训练时发现:仅对 JPEGImages 目录中的图像做增强(旋转、翻转)是不够的,Annotation 目录中对应的 XML 标注文件也必须同步更新,包括文件名、图像宽高以及目标框的坐标。实现代码如下:

# -*- coding:utf8 -*-

from PIL import Image #python中的图像处理库PIL来实现不同图像格式的转换。
import numpy as np
# Element类型是一种灵活的容器对象,用于在内存中存储结构化数据。
import xml.etree.ElementTree as ET #
import os
import shutil
# Build the zero-padded file stem used to name the augmented outputs.
def standardize(filename):
    """Return ``filename`` as a string left-padded with '0' to six characters.

    The argument is converted with ``str()`` first, so integers are accepted.
    Values whose string form already has six or more characters are returned
    unchanged (as a string).
    """
    return str(filename).rjust(6, '0')
# Parse an annotation XML: returns the tree, the image size and the boxes.
def xml_parse(xml_path):
    """Parse an annotation XML file and extract image size and bounding boxes.

    Only direct children of the root element are inspected: ``size`` supplies
    ``height``/``width`` and every ``object`` contributes one box read from its
    ``bndbox`` child.

    Args:
        xml_path: File name or file object accepted by ``ElementTree.parse``.

    Returns:
        A tuple ``(tree, height, width, gtboxes)``. ``tree`` is the parsed
        ``ElementTree``; ``height`` and ``width`` are ints (0 when the tag is
        absent); ``gtboxes`` is a NumPy array holding one
        ``[xmin, ymin, xmax, ymax]`` row per ``object`` element.

    Raises:
        ValueError: If a ``bndbox`` lacks one of the four coordinate tags, or
            a size/coordinate value is not numeric.
    """
    coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')

    def as_int(text):
        # Some annotation tools write "12.0"; accept that and truncate to int.
        return int(float(text))

    width, height = 0, 0
    tree = ET.parse(xml_path)
    gtboxes = []
    for child_root in tree.getroot():
        if child_root.tag == 'size':
            for son_item in child_root:
                if son_item.tag == 'height':
                    height = as_int(son_item.text)
                elif son_item.tag == 'width':
                    width = as_int(son_item.text)
        elif child_root.tag == 'object':
            sbox = []
            bndbox = child_root.find('bndbox')
            if bndbox is not None:
                # Look the coordinates up by tag name so the result is always
                # ordered x0, y0, x1, y1 regardless of the order (or presence
                # of extra children) inside <bndbox>.
                for tag in coord_tags:
                    text = bndbox.findtext(tag)
                    if text is None:
                        raise ValueError(
                            'bndbox is missing <%s> in %r' % (tag, xml_path))
                    sbox.append(as_int(text))
            # One entry per <object>, so callers can pair boxes with objects.
            gtboxes.append(sbox)

    return tree, height, width, np.array(gtboxes)

def rotate(img, xml, degree, save_img_path, save_xml_path, filename, format='.png'):
    """Rotate an image by a right angle and save it with an updated annotation.

    The image is turned clockwise by ``degree`` and every bounding box, the
    image size and the ``filename`` tag in the annotation are rewritten to
    match. The image is saved as ``filename + format`` under ``save_img_path``
    and the annotation as ``filename + '.xml'`` under ``save_xml_path``.

    Args:
        img: Image object convertible with ``np.array`` (e.g. a PIL image).
        xml: Path of the source annotation, parsed with ``xml_parse``.
        degree: Clockwise rotation angle; must be 90, 180 or 270.
        save_img_path: Directory receiving the rotated image.
        save_xml_path: Directory receiving the rewritten annotation.
        filename: Output file stem, also written into the ``filename`` tag.
        format: Image file extension including the leading dot.

    Raises:
        ValueError: If ``degree`` is not 90, 180 or 270.
    """
    if degree not in (90, 180, 270):
        # Previously this fell through and failed later with an IndexError,
        # or silently saved an unrotated copy when there were no objects.
        raise ValueError('degree must be 90, 180 or 270, got %r' % (degree,))

    image = np.array(img)
    tree, height, width, gtboxes = xml_parse(xml)

    # Boxes are [x0, y0, x1, y1]. np.rot90 turns counter-clockwise, so k=3 is
    # one clockwise quarter turn and k=1 is three of them (270 degrees).
    if degree == 90:
        height, width = width, height
        image = np.rot90(image, 3)
        rotated = [[width - b[3], b[0], width - b[1], b[2]] for b in gtboxes]
    elif degree == 180:
        image = np.rot90(image, 2)
        rotated = [[width - b[2], height - b[3], width - b[0], height - b[1]]
                   for b in gtboxes]
    else:
        height, width = width, height
        image = np.rot90(image)
        rotated = [[b[1], height - b[2], b[3], height - b[0]] for b in gtboxes]

    # Write coordinates back by tag name rather than by child position so the
    # right value lands in the right tag whatever the order inside <bndbox>.
    coord_index = {'xmin': 0, 'ymin': 1, 'xmax': 2, 'ymax': 3}
    pending = iter(rotated)  # consumed in step with the <object> elements
    for child_root in tree.getroot():
        if child_root.tag == 'filename':
            child_root.text = filename
        elif child_root.tag == 'size':
            for son_item in child_root:
                if son_item.tag == 'height':
                    son_item.text = str(height)
                elif son_item.tag == 'width':
                    son_item.text = str(width)
        elif child_root.tag == 'object':
            sbox = next(pending)
            for son_root in child_root:
                if son_root.tag == 'bndbox':
                    for son_item in son_root:
                        idx = coord_index.get(son_item.tag)
                        if idx is not None:
                            son_item.text = str(sbox[idx])

    # Convert the rotated array back into a PIL image for saving.
    image = Image.fromarray(image)
    imgfilename = os.path.join(save_img_path, filename + format)
    image.save(imgfilename)

    xmlfilename = os.path.join(save_xml_path, filename + '.xml')
    tree.write(xmlfilename)

def flip(img, xml, type, save_img_path, save_xml_path, filename, format='.png'):
    """Mirror an image and save it together with an updated annotation.

    The image is flipped vertically or horizontally and every bounding box
    and the ``filename`` tag in the annotation are rewritten to match. The
    image is saved as ``filename + format`` under ``save_img_path`` and the
    annotation as ``filename + '.xml'`` under ``save_xml_path``.

    Args:
        img: PIL image to flip (its ``transpose`` method is used).
        xml: Path of the source annotation, parsed with ``xml_parse``.
        type: ``'Up_Bottom'`` for a top/bottom flip or ``'Left_Right'`` for a
            left/right flip.
        save_img_path: Directory receiving the flipped image.
        save_xml_path: Directory receiving the rewritten annotation.
        filename: Output file stem, also written into the ``filename`` tag.
        format: Image file extension including the leading dot.

    Raises:
        ValueError: If ``type`` is not one of the two supported values.
    """
    if type not in ('Up_Bottom', 'Left_Right'):
        # Previously this fell through and failed later with an IndexError,
        # or silently saved an unflipped copy when there were no objects.
        raise ValueError(
            "type must be 'Up_Bottom' or 'Left_Right', got %r" % (type,))

    tree, height, width, gtboxes = xml_parse(xml)

    # Boxes are [x0, y0, x1, y1]; a flip mirrors one axis and swaps which
    # edge is the minimum on that axis.
    if type == 'Up_Bottom':
        img = img.transpose(Image.FLIP_TOP_BOTTOM)
        flipped = [[b[0], height - b[3], b[2], height - b[1]] for b in gtboxes]
    else:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        flipped = [[width - b[2], b[1], width - b[0], b[3]] for b in gtboxes]

    # Write coordinates back by tag name rather than by child position so the
    # right value lands in the right tag whatever the order inside <bndbox>.
    coord_index = {'xmin': 0, 'ymin': 1, 'xmax': 2, 'ymax': 3}
    pending = iter(flipped)  # consumed in step with the <object> elements
    for child_root in tree.getroot():
        if child_root.tag == 'filename':
            child_root.text = filename
        elif child_root.tag == 'size':
            # The size is unchanged by a flip; it is rewritten as parsed.
            for son_item in child_root:
                if son_item.tag == 'height':
                    son_item.text = str(height)
                elif son_item.tag == 'width':
                    son_item.text = str(width)
        elif child_root.tag == 'object':
            sbox = next(pending)
            for son_root in child_root:
                if son_root.tag == 'bndbox':
                    for son_item in son_root:
                        idx = coord_index.get(son_item.tag)
                        if idx is not None:
                            son_item.text = str(sbox[idx])

    imgfilename = os.path.join(save_img_path, filename + format)
    img.save(imgfilename)

    xmlfilename = os.path.join(save_xml_path, filename + '.xml')
    tree.write(xmlfilename)

if __name__ == '__main__':
    # Source and destination directories. On Linux use absolute paths such as
    # '/media/<disk>/data_new/JPEGImages/' and '/media/<disk>/data_new/Annotations/'.
    # NOTE(review): the literals below look like Windows paths whose
    # backslashes were lost (e.g. r'F:\oiltank-ori\JPEGImages'); they are kept
    # unchanged here - confirm and adjust before running.
    img_path, xml_path = 'F:oiltank-oriJPEGImages', 'F:oiltank-oriAnnotation'
    save_img_path, save_xml_path, = 'F:oiltank-orioiltank_dataAugJPEGImages', 'F:oiltank-orioiltank_dataAugAnnotation'

    # Image.save / tree.write do not create missing directories.
    os.makedirs(save_img_path, exist_ok=True)
    os.makedirs(save_xml_path, exist_ok=True)

    # Each source image yields six outputs numbered counter .. counter + 5:
    # the original, three rotations and two flips.
    counter = 1

    # Sort so the numbering is reproducible; os.listdir order is arbitrary.
    for file in sorted(os.listdir(img_path)):
        # splitext keeps stems that themselves contain dots intact.
        stem, ext = os.path.splitext(file)
        src_img = os.path.join(img_path, file)
        src_xml = os.path.join(xml_path, stem + '.xml')

        if not os.path.isfile(src_xml):
            # Stray files (thumbnails, hidden files, ...) have no annotation;
            # report and skip them instead of aborting the whole run.
            print('skipping %s: no annotation found at %s' % (src_img, src_xml))
            continue

        # 1) Copy the untouched image and write its annotation under the new name.
        shutil.copy(src_img, os.path.join(save_img_path, standardize(counter) + ext))

        tree, height, width, gtboxes = xml_parse(src_xml)
        for child_root in tree.getroot():
            if child_root.tag == 'filename':
                child_root.text = standardize(counter)
        tree.write(os.path.join(save_xml_path, standardize(counter) + '.xml'))

        # 2) Write the augmented variants; the context manager closes the file.
        with Image.open(src_img) as image:
            rotate(image, src_xml, 90, save_img_path, save_xml_path, standardize(counter + 1), ext)
            rotate(image, src_xml, 180, save_img_path, save_xml_path, standardize(counter + 2), ext)
            rotate(image, src_xml, 270, save_img_path, save_xml_path, standardize(counter + 3), ext)

            flip(image, src_xml, 'Up_Bottom', save_img_path, save_xml_path, standardize(counter + 4), ext)
            flip(image, src_xml, 'Left_Right', save_img_path, save_xml_path, standardize(counter + 5), ext)

        counter += 6