学习笔记30-Top1和Top5定义与代码复现

定义

Top-1: Accuracy是指排名第一的类别与实际结果相符的准确率,就是你预测的label取最后概率向量里面最大的那一个作为预测结果,如过你的预测结果中概率最大的那个分类正确,则预测正确。否则预测错误。
Top-5: Accuracy是指排名前五的类别包含实际结果的准确率,就是最后概率向量最大的前五名中,只要出现了正确概率即为预测正确。否则预测错误。
TOP-5正确率=(所有测试图片中正确标签包含在前五个分类概率中的个数)除以(总的测试图片数)
TOP-5错误率=(所有测试图片中正确标签不在前五个概率中的个数)除以(总的测试图片数)
注意: 我们平时说的top1就是准确率,Accuracy和F1-Score这些是判断分类模型总体的标准。

代码复现

输入是模型输出(batch_size×num_of_class),目标label(num_of_class向量),元组(分别向求top几)

acc.py

import torch
def accu(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

# torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
# input:输入张量
# k:指定返回的前几位的值
# dim:排序的维度
# largest:返回最大值
# sorted:返回值是否排序
# out:可选输出张量

train.py

#计算Top1
                pred1_train, pred2_train = accu(outputs, lables, topk=(1, ))
                train_top1.update(pred1_train[0], val_images.size(0))
                #train_top2.update(pred2_train[0], val_images.size(0))
                t_top1 = train_top1.avg
                #t_top2 = train_top2.avg
#打印结果
print('[epoch %d] train_loss: %.3f  test_loss: %.3f val_accuracy: %.3f top1: %.4f' %
              (epoch + 1, running_loss / train_steps, testing_loss / test_steps , val_accurate, t_top1))

class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += float(val) * n
        self.count += n
        self.avg = self.sum / self.count

版权声明:本文为LZL2020LZL原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。