Yolo v3中anchor的k-means聚类代码分析

Yolo v3中的9个anchor是通过k-means聚类<自己数据集>的ground truth得到的。
本篇文章是我对聚类代码的分析

下边的代码是我根据github上一个作者的代码得到的,加了一丢丢自己的理解在注释里
这个代码可以直接用,没什么问题,只需把路径改为自己数据集标注的路径即可

原地址在这里:
https://github.com/PaulChongPeng/darknet/blob/master/tools/k_means_yolo.py

# BuSiNiao  2020/2/12.
# coding=utf-8
# k-means ++ for YOLOv2 anchors
# 通过k-means ++ 算法获取YOLOv2需要的anchors的尺寸
import numpy as np
# 定义Box类,描述bounding box的坐标


class Box:
    def __init__(self, x, y, w, h):
        self.x = x
        self.y = y
        self.w = w
        self.h = h


# 计算两个box在某个轴上的重叠部分
# x1是box1的中心在该轴上的坐标
# len1是box1在该轴上的长度
# x2是box2的中心在该轴上的坐标
# len2是box2在该轴上的长度
# 返回值是该轴上重叠的长度


def overlap(x1, len1, x2, len2):
    len1_half = len1 / 2
    len2_half = len2 / 2
    left = max(x1 - len1_half, x2 - len2_half)
    right = min(x1 + len1_half, x2 + len2_half)
    return right - left


# 计算box a 和box b 的交集面积
# a和b都是Box类型实例
# 返回值area是box a 和box b 的交集面积


def box_intersection(a, b):
    w = overlap(a.x, a.w, b.x, b.w)
    h = overlap(a.y, a.h, b.y, b.h)
    if w < 0 or h < 0:
        return 0
    area = w * h
    return area


# 计算 box a 和 box b 的并集面积
# a和b都是Box类型实例
# 返回值u是box a 和box b 的并集面积


def box_union(a, b):
    i = box_intersection(a, b)
    u = a.w * a.h + b.w * b.h - i
    return u


# 计算 box a 和 box b 的 iou
# a和b都是Box类型实例
# 返回值是box a 和box b 的iou


def box_iou(a, b):
    return box_intersection(a, b) / box_union(a, b)


# 使用k-means ++ 初始化 centroids,减少随机初始化的centroids对最终结果的影响   centroids:最开始给的框
# boxes是所有bounding boxes的Box对象列表     就是所有的真正的标签中的正确的框的合集 数组
# n_anchors是k-means的k值      就是需要n个anchor_box
# 返回值centroids 是初始化的n_anchors个centroid     就是说这个函数的返回值就是初始化的几个anchor boxes
# k means++  尽可能的使初始的类的距离远远远

def init_centroids(boxes, n_anchors):
    centroids = []
    boxes_num = len(boxes)   # 是一个整数
    centroid_index = np.random.choice(boxes_num, 1)
    # choice(a, size=None, replace=True, p=None) 从a中选取size个数,r为正表示可以重复,p表示概率,默认为均分
    centroids.append(boxes[centroid_index])   # append() 方法用于在列表末尾添加新的对象
    print(centroids[0].w, centroids[0].h)
    for centroid_index in range(0, n_anchors-1):
        sum_distance = 0
        distance_thresh = 0
        distance_list = []
        cur_sum = 0
        for box in boxes:
            min_distance = 1
            for centroid_i, centroid in enumerate(centroids):
                distance = (1 - box_iou(box, centroid))
                if distance < min_distance:
                    min_distance = distance
            sum_distance += min_distance
            distance_list.append(min_distance)
        distance_thresh = sum_distance*np.random.random()
        for i in range(0,boxes_num):
            cur_sum += distance_list[i]
            if cur_sum > distance_thresh:
                centroids.append(boxes[i])
                print(boxes[i].w, boxes[i].h)
                break
    return centroids


# 进行 k-means 计算新的centroids
# boxes是所有bounding boxes的Box对象列表
# n_anchors是k-means的k值
# centroids是所有簇的中心
# 返回值new_centroids 是计算出的新簇中心
# 返回值groups是n_anchors个簇包含的boxes的列表    就是属于每一个中心类别的boxes的集合
# 返回值loss是所有box距离所属的最近的centroid的距离的和


def do_kmeans(n_anchors, boxes, centroids):
    loss = 0
    groups = []
    new_centroids = []
    for i in range(n_anchors):          # 创建了一个有n个元素的空的groups和一个空的new_centroids
        groups.append([])
        new_centroids.append(Box(0, 0, 0, 0))
    for box in boxes:
        min_distance = 1   # 完全重合时,距离是0;完全不重合时,距离是1(此时为距离最远)
        group_index = 0
        for centroid_index, centroid in enumerate(centroids):   # 给第i个“box”分类
            distance = (1 - box_iou(box, centroid))
            if distance < min_distance:    # 当距离小于1时表示两个box有重合
                min_distance = distance
                group_index = centroid_index   # 找到box所属的类,然后把类索引放到
        groups[group_index].append(box)    # 找到box所属的类,然后把这个box放到group中
        loss += min_distance     # 每分类一个box,就叠加一次loss
        new_centroids[group_index].w += box.w
        # 每分类一次box,就将他的宽度叠加到这个类别的box的宽度上,最后得到一个属于此类别的宽度的总和,以便求均值
        new_centroids[group_index].h += box.h
        # 每分类一次box,就将他的高度叠加到这个类别的box的高度上,最后得到一个属于此类别的高度的总和,以便求均值
    for i in range(n_anchors):    # 以均值形成新的类别box
        new_centroids[i].w /= len(groups[i])
        new_centroids[i].h /= len(groups[i])
    return new_centroids, groups, loss


# 计算给定bounding boxes的n_anchors数量的centroids
# label_path是训练集列表文件地址
# n_anchors 是anchors的数量
# loss_convergence是允许的loss的最小变化值     阈值:结束迭代的判断标准
# grid_size * grid_size 是栅格数量    网格是几成几的网格?
# iterations_num是最大迭代次数
# plus = 1时启用k means ++ 初始化centroids


def compute_centroids(label_path, n_anchors, loss_convergence, grid_size, iterations_num, plus):
    boxes = []
    # label_files = []
    # f = open(label_path)
    # for line in f:
    #     label_path = line.rstrip().replace('images', 'labels')  # rstrip 对字符串的只有尾部删除指定字符,默认为空格
    #     label_path = label_path.replace('JPEGImages', 'labels')
    #     label_path = label_path.replace('.jpg', '.txt')
    #     label_path = label_path.replace('.JPEG', '.txt')
    #     label_files.append(label_path)
    # f.close()
    # for label_file in label_files:       # label_files中是一组一组的标签地址
    #     f = open(label_file)
    #     for line in f:
    #         temp = line.strip().split(" ")
    #         # strip 对字符串的头部尾部都删除指定字符,默认为空格;spilt将一组字符串切片,默认以空格切片,返回字符串列表
    #         if len(temp) > 1:
    #             boxes.append(Box(0, 0, float(temp[2]), float(temp[3])))    # 所有的box中心都在原点
    # 这一步之后就得到了boxes

    # 得到我的boxes
    f = open(label_path)
    for line in f:
        temp = line.strip().split(" ")
        if len(temp) > 1:
            boxes.append(Box(0, 0, float(temp[2]), float(temp[3])))  # 所有的box中心都在原点
    f.close()
    if plus:
        centroids = init_centroids(boxes, n_anchors)
    else:
        centroid_indices = np.random.choice(len(boxes), n_anchors)       # 在len长度的数字上随机抽取n个数字,返回的是数组
        centroids = []
        # print(centroid_indices)
        for centroid_index in centroid_indices:
            centroids.append(boxes[centroid_index])
        # print('******init******', centroids)
    # iterate k-means
    centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)
    iterations = 1
    i = 0
    while True:
        i = i+1
        centroids, groups, loss = do_kmeans(n_anchors, boxes, centroids)
        iterations = iterations + 1
        print("number:", i, "************loss = %f" % loss)
        if abs(old_loss - loss) < loss_convergence or iterations > iterations_num:
            break
        old_loss = loss
        for centroid in centroids:
            print(centroid.w * grid_size, centroid.h * grid_size)   #
    # print result
    for centroid in centroids:
        print("k-means result:\n")
        print(centroid.w * grid_size, centroid.h * grid_size)


# label_path = "/raid/pengchong_data/Data/Lists/paul_train.txt"
label_path = "G:/iSAID annotation/bbox.txt"
n_anchors = 5
loss_convergence = 1e-6
grid_size = 1
iterations_num = 100
plus = 0
print('**************Start**************')
compute_centroids(label_path, n_anchors, loss_convergence, grid_size, iterations_num, plus)
print('*********Already Finish*********')