MMDetection 绘制 mAP 曲线(自用)

一、代码描述

  • 通过log文件提取mAP关键词
  • 提取相应的mAP列表
  • 支持多图合一(通过argparse可选参数控制)

二、代码

import matplotlib.pyplot as plt
import os
import numpy as np
import argparse
import warnings

def get_logs(path):
    """Map short experiment labels to the log file names found in ``path``.

    The file name up to its first ``.`` selects the label:
    ``1280_RefineHead_detach`` -> ``1280_RD``,
    ``1280_RefineHead_nodetach`` -> ``1280_RND``,
    ``2048_RefineHead_detach`` -> ``2048_RD``; any other name falls back to
    ``2048_RND``.

    Args:
        path: Directory that contains the log files.

    Returns:
        dict: ``{label: file_name}``. Keys always appear in the order
        ``1280_RD, 1280_RND, 2048_RD, 2048_RND`` (labels without a file are
        omitted), so callers that pair the keys with curves by position get
        a stable, matching order.
    """
    aliases = {
        '1280_RefineHead_detach': '1280_RD',
        '1280_RefineHead_nodetach': '1280_RND',
        '2048_RefineHead_detach': '2048_RD',
    }
    fallback = '2048_RND'
    found = {}
    # os.listdir returns entries in arbitrary order; sort so the result is
    # deterministic when several files compete for the fallback label.
    for log in sorted(os.listdir(path)):
        # Sub-directories (e.g. saved figures) are not logs and cannot be
        # opened for reading later on.
        if not os.path.isfile(os.path.join(path, log)):
            continue
        found[aliases.get(log.split('.')[0], fallback)] = log
    order = list(aliases.values()) + [fallback]
    return {label: found[label] for label in order if label in found}

def read_log(path):
    """Collect every ``bbox_mAP`` value recorded in one log file.

    A line contributes a value when it contains the keyword ``bbox_mAP:``;
    the value is the text between that keyword and the next comma.

    Args:
        path: Path of the log file to read.

    Returns:
        list: ``np.float32`` values in the order they appear in the file.

    Raises:
        ValueError: If the text after the keyword is not a number.
    """
    keyword = 'bbox_mAP:'
    bbox_mAP = []
    with open(path, 'r') as f:
        for line in f:
            ind = line.find(keyword)
            if ind == -1:
                continue
            # Parse from the keyword position rather than from the first
            # ': ' on the line, so a prefix such as 'INFO: ' cannot shift
            # the field that gets read.
            map50_95 = line[ind + len(keyword):].split(',')[0].strip()
            bbox_mAP.append(np.float32(map50_95))
    return bbox_mAP

def get_bbox_mAPs(main_path,info_dict):
    """Read the bbox_mAP series of every log listed in ``info_dict``.

    Args:
        main_path: Directory that holds the log files.
        info_dict: ``{label: file_name}`` mapping; the label decides which
            of the four returned slots receives the series. Labels other
            than ``1280_RD``, ``1280_RND`` and ``2048_RD`` go to the
            ``2048_RND`` slot.

    Returns:
        tuple: Four lists in the order ``1280_RD``, ``1280_RND``,
        ``2048_RD``, ``2048_RND``; a slot with no log stays an empty list.
    """
    named_slots = ('1280_RD', '1280_RND', '2048_RD')
    catch_all = '2048_RND'
    series = {'1280_RD': [], '1280_RND': [], '2048_RD': [], catch_all: []}
    for label, file_name in info_dict.items():
        values = read_log(os.path.join(main_path, file_name))
        slot = label if label in named_slots else catch_all
        series[slot] = values
    return (series['1280_RD'], series['1280_RND'],
            series['2048_RD'], series[catch_all])

def MultiInOnePlot(args,info_dict,y1,y2,y3,y4):
    """Plot several mAP curves per figure ("multi in one" mode).

    Without ``args.multi2one`` all curves share one figure saved as
    ``Multi_in_one_fusion`` with the y-axis fixed to ``(0.13, 0.43)``.
    With ``args.multi2one`` two figures are produced, ``Multi_in_one_0``
    (``y1``, ``y2``) and ``Multi_in_one_1`` (``y3``, ``y4``).

    Each curve gets a dashed vertical line and a text annotation at the
    epoch of its maximum mAP.

    Args:
        args: Parsed CLI namespace; only ``args.multi2one`` is read.
        info_dict: Kept for interface compatibility. Labels are bound to the
            curves by position instead of being taken from this dict's key
            order, which is not guaranteed to match ``y1``..``y4``.
        y1, y2, y3, y4: mAP series for ``1280_RD``, ``1280_RND``,
            ``2048_RD`` and ``2048_RND`` (the order returned by
            ``get_bbox_mAPs``). Empty series are skipped.
    """
    def plot_multi(curves, index='fusion'):
        plt.figure()
        for label, y, color, marker in curves:
            # Coordinates of the best epoch for this curve.
            peak = [np.argmax(y), np.max(y)]
            plt.plot(np.arange(len(y)), y, f'-{color}{marker}', markersize=3, label=label)
            plt.axvline(x=peak[0], color=f'{color}', linestyle='--')
            plt.annotate(text='({},{:.2})'.format(peak[0], peak[1]), xy=peak)
        plt.xlabel('epoch', fontsize=15)
        plt.ylabel('mAP', fontsize=15)
        # The fixed y-range is only applied to the single fused figure.
        if not args.multi2one:
            plt.ylim((0.13, 0.43))
        plt.legend(loc='best', fontsize=10)
        plt.savefig(f'Multi_in_one_{index}', dpi=300)
        plt.show()

    # Label, data, colour and marker travel together so a curve can never
    # be drawn with another curve's label.
    styled = [
        ('1280_RD', y1, 'r', 's'),
        ('1280_RND', y2, 'c', 'd'),
        ('2048_RD', y3, 'm', '^'),
        ('2048_RND', y4, 'g', 'v'),
    ]
    if args.multi2one:
        groups = [(styled[:2], '0'), (styled[2:], '1')]
    else:
        groups = [(styled, 'fusion')]
    for curves, index in groups:
        # np.argmax / np.max raise on an empty series, so drop those first.
        drawable = [curve for curve in curves if len(curve[1]) > 0]
        if not drawable:
            warnings.warn(f'No mAP data for figure Multi_in_one_{index}; skipped.')
            continue
        plot_multi(drawable, index)

def PlotOneByOne(info_dict,y1,y2,y3,y4):
    """Plot every mAP curve in its own figure (single-figure mode).

    Each figure shows one curve, a dashed gray vertical line and a text
    annotation at the epoch of its maximum mAP, and is saved as
    ``Single_<label>``.

    Args:
        info_dict: Kept for interface compatibility. Labels are bound to the
            curves by position instead of being taken from this dict's key
            order, which is not guaranteed to match ``y1``..``y4``.
        y1, y2, y3, y4: mAP series for ``1280_RD``, ``1280_RND``,
            ``2048_RD`` and ``2048_RND`` (the order returned by
            ``get_bbox_mAPs``). Empty series are skipped.
    """
    styled = [
        ('1280_RD', y1, 'r', 's'),
        ('1280_RND', y2, 'c', 'd'),
        ('2048_RD', y3, 'm', '^'),
        ('2048_RND', y4, 'g', 'v'),
    ]
    # np.argmax / np.max raise on an empty series, so drop those first.
    drawable = [curve for curve in styled if len(curve[1]) > 0]
    if not drawable:
        warnings.warn('No mAP data to plot.')
        return
    for label, y, color, marker in drawable:
        # A fresh figure per curve: reusing one figure would make every
        # saved image also contain the previously drawn curves whenever
        # plt.show() does not discard the figure.
        plt.figure()
        peak = [np.argmax(y), np.max(y)]
        plt.plot(np.arange(len(y)), y, f'-{color}{marker}', markersize=4, label=label)
        plt.axvline(x=peak[0], color='gray', linestyle='--')
        plt.annotate(text='({},{:.2})'.format(peak[0], peak[1]), xy=peak)
        plt.xlabel('epoch', fontsize=15)
        plt.ylabel('mAP', fontsize=15)
        plt.legend(loc='best', fontsize=13)
        plt.savefig(f'Single_{label}', dpi=300)
        plt.show()
def parse_args():
    """Define and parse the command-line arguments.

    Returns:
        argparse.Namespace: with attributes

        * ``main_path`` (str): directory containing the logs; defaults to
          ``log_data`` when the positional argument is omitted.
        * ``multi_in_one`` (bool): draw the curves together instead of one
          figure per curve.
        * ``multi2one`` (bool): in multi-in-one mode, draw two figures with
          two curves each instead of one figure with all four.
    """
    parser = argparse.ArgumentParser(description='Plot descriptions for mmdetlogs')
    # nargs='?' is required for the default to apply: without it argparse
    # treats the positional argument as mandatory and ignores `default`.
    parser.add_argument('main_path', nargs='?', default='log_data', type=str,
                        help='directory that contains the log files')
    parser.add_argument('--multi_in_one', action='store_true', help='whether plot in one figure')
    parser.add_argument('--multi2one', action='store_true',
                        help='with --multi_in_one, plot two figures of two curves each '
                             'instead of one figure with all curves')
    return parser.parse_args()

def main():
    """Entry point: parse the CLI, load all mAP series and plot them.

    Prints the length of the four series, then dispatches to
    ``MultiInOnePlot`` when ``--multi_in_one`` is set and to
    ``PlotOneByOne`` otherwise.
    """
    cli = parse_args()
    log_dir = cli.main_path
    logs_by_label = get_logs(log_dir)
    curves = get_bbox_mAPs(log_dir, logs_by_label)
    rd_1280, rnd_1280, rd_2048, rnd_2048 = curves
    print(len(rd_1280), len(rnd_1280), len(rd_2048), len(rnd_2048))
    if not cli.multi_in_one:
        PlotOneByOne(logs_by_label, rd_1280, rnd_1280, rd_2048, rnd_2048)
        return
    MultiInOnePlot(cli, logs_by_label, rd_1280, rnd_1280, rd_2048, rnd_2048)

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

三、效果

在这里插入图片描述
在这里插入图片描述


版权声明:本文为q7672345原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。