GBDT Regression Tree Implementation (Example from Statistical Learning Methods)

The concrete steps of the boosting tree algorithm for regression:
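Assuming squared loss, as in the book's example, the algorithm can be summarized as follows. The notation is assumed to follow the book, where $T(x;\Theta_m)$ denotes the $m$-th regression tree:

```latex
\begin{aligned}
&\text{Input: training data } \{(x_i,y_i)\}_{i=1}^{N},\ \text{number of trees } M.\\
&1.\ \ f_0(x)=0.\\
&2.\ \ \text{For } m=1,\dots,M:\\
&\qquad\text{(a) compute residuals } r_{mi}=y_i-f_{m-1}(x_i),\quad i=1,\dots,N;\\
&\qquad\text{(b) fit a regression tree } T(x;\Theta_m) \text{ to } \{(x_i,r_{mi})\}_{i=1}^{N};\\
&\qquad\text{(c) update } f_m(x)=f_{m-1}(x)+T(x;\Theta_m).\\
&3.\ \ \text{Output } f_M(x)=\sum_{m=1}^{M}T(x;\Theta_m).
\end{aligned}
```

Step (b) fits the residuals because, under squared loss, $L\big(y, f_{m-1}(x)+T(x;\Theta_m)\big) = \big(y - f_{m-1}(x) - T(x;\Theta_m)\big)^2 = \big(r - T(x;\Theta_m)\big)^2$. Each new tree is therefore an ordinary least-squares fit to the current residuals.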

2.2 An Example (Statistical Learning Methods, p. 149)
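In the example, the training data are $x_i = i$ for $i = 1,\dots,10$, with the $y$ values listed in the code. As in the code, each tree is assumed to be a single split (a stump). For a candidate split point $s$, the stump's loss $m(s)$ is computed as shown below. The last two lines work out the first round:

```latex
\begin{aligned}
R_1(s) &= \{x \mid x \le s\},\qquad R_2(s) = \{x \mid x > s\},\qquad s \in \{1.5, 2.5, \dots, 9.5\}\\
c_j &= \frac{1}{|R_j|}\sum_{x_i \in R_j} y_i \qquad (j = 1, 2)\\
m(s) &= \sum_{x_i \in R_1}(y_i - c_1)^2 + \sum_{x_i \in R_2}(y_i - c_2)^2\\[4pt]
s^{*} &= 6.5:\quad c_1 = \tfrac{37.42}{6} \approx 6.24,\quad c_2 = \tfrac{35.65}{4} \approx 8.91,\quad m(6.5) \approx 1.93\\
T_1(x) &= \begin{cases} 6.24, & x < 6.5\\ 8.91, & x \ge 6.5 \end{cases}
\end{aligned}
```

Among the candidates, $s = 6.5$ gives the smallest $m(s)$. The residuals $y_i - T_1(x_i)$ then become the targets for $T_2$.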

Code implementation:

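import numpy as np

'''
Input: x_array, y_array, split_point_array (all candidate split points of this feature)
Output: minloss (the smallest loss over all split points), the best split point,
        and the residual array (the next tree is fitted to it)
'''
def findBestSplitPoint(x_array,y_array,split_point_array):
    # Work on NumPy arrays so region/residual arithmetic is element-wise
    x_array = np.asarray(x_array, dtype=float)
    y_array = np.asarray(y_array, dtype=float)
    minloss = np.inf
    best_split_point = None
    best_c1 = 0.0
    best_c2 = 0.0
    # Iterate over the candidate split points
    for s in split_point_array:
        # Split the samples into R1 = {x <= s} and R2 = {x > s}
        in_R1 = x_array <= s
        R1 = y_array[in_R1]
        R2 = y_array[~in_R1]

        # Under squared loss, the optimal output of each region is its mean
        c1 = np.mean(R1)
        c2 = np.mean(R2)

        # Squared-error loss of this split (no value-based membership tests)
        loss = np.sum((R1 - c1)**2) + np.sum((R2 - c2)**2)

        if loss < minloss:
            minloss = loss
            best_split_point = s
            best_c1 = c1
            best_c2 = c2
    # Residuals y - T(x) in the original sample order, used as targets for the next fit
    residual_array = np.where(x_array <= best_split_point, y_array - best_c1, y_array - best_c2)

    return minloss,best_split_point,list(residual_array)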

x = [1,2,3,4,5,6,7,8,9,10]
y = [5.56,5.70,5.91,6.40,6.80,7.05,8.9,8.70,9.00,9.05]
# Candidate split points: midpoints between adjacent x values
split_point_array = [1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5]
print(findBestSplitPoint(x,y,split_point_array))
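The function above performs only one round. The sketch below chains several rounds into a full boosting loop, following the algorithm steps: start from $f_0(x)=0$, fit a stump to the current residuals, add it, and repeat. `fit_stump`, `predict` and `boost` are hypothetical helper names introduced here for illustration, and choosing `M=6` rounds is also an assumption. The sketch uses the exact (unrounded) residuals. Later-round losses may therefore differ slightly from hand calculations that round to two decimals.

```python
import numpy as np

# Hypothetical helper: fit a one-split regression tree (stump) to targets r,
# trying every candidate split point and minimizing squared error.
def fit_stump(x, r, split_points):
    best = None
    for s in split_points:
        left = x <= s
        c1, c2 = r[left].mean(), r[~left].mean()
        loss = ((r[left] - c1) ** 2).sum() + ((r[~left] - c2) ** 2).sum()
        if best is None or loss < best[0]:
            best = (loss, s, c1, c2)
    return best[1:]  # (split point, c1, c2)

# f_M(x) = sum of stump outputs; with no stumps this is f_0(x) = 0
def predict(stumps, x):
    f = np.zeros_like(x, dtype=float)
    for s, c1, c2 in stumps:
        f += np.where(x <= s, c1, c2)
    return f

# Hypothetical driver: M rounds of boosting under squared loss
def boost(x, y, split_points, M):
    stumps, losses = [], []
    for _ in range(M):
        r = y - predict(stumps, x)                 # residuals r_mi = y_i - f_{m-1}(x_i)
        stumps.append(fit_stump(x, r, split_points))
        losses.append(((y - predict(stumps, x)) ** 2).sum())  # training loss of f_m
    return stumps, losses

x = np.arange(1, 11, dtype=float)
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05])
split_points = np.arange(1.5, 10, 1.0)  # 1.5, 2.5, ..., 9.5
stumps, losses = boost(x, y, split_points, M=6)
for m, ((s, c1, c2), loss) in enumerate(zip(stumps, losses), start=1):
    print(f"T{m}: s={s}, c1={c1:.2f}, c2={c2:.2f}, loss={loss:.2f}")
```

The first round reproduces the single-split result: $s = 6.5$, $c_1 \approx 6.24$, $c_2 \approx 8.91$, loss $\approx 1.93$. The second stump splits at $s = 3.5$. The training loss never increases from one round to the next, because each stump's region means can only lower the squared error of the residuals they are fitted to.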

Reference: https://www.cnblogs.com/ModifyRong/p/7744987.html


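Copyright notice: This is an original article by forestForQuietLive, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.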