The concrete steps of the boosting regression tree algorithm are as follows:
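The steps can be sketched in code as follows. This is a minimal illustration under stated assumptions, not the author's implementation: it assumes a single numeric feature with at least two distinct values, squared-error loss, and depth-one regression trees (stumps) as the base learner. The names `fit_stump` and `boost_regression_stumps` are hypothetical helpers introduced only for this sketch.

```python
import numpy as np

def fit_stump(x, r):
    """Fit a depth-one regression tree to targets r; return (split, c1, c2)."""
    xs = np.unique(x)                       # sorted distinct x values
    candidates = (xs[:-1] + xs[1:]) / 2.0   # midpoints between neighbouring values
    best = None
    for s in candidates:
        left = x <= s                       # region R1: x <= s, region R2: x > s
        c1, c2 = r[left].mean(), r[~left].mean()
        loss = np.sum((r[left] - c1) ** 2) + np.sum((r[~left] - c2) ** 2)
        if best is None or loss < best[0]:
            best = (loss, s, c1, c2)
    return best[1:]

def boost_regression_stumps(x, y, n_rounds):
    """Return the fitted stumps and the squared training loss after each round."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    f = np.zeros_like(y)                    # step 1: initialise f_0(x) = 0
    stumps, losses = [], []
    for _ in range(n_rounds):               # step 2: for m = 1, ..., M
        r = y - f                           # (a) residuals of the current model
        s, c1, c2 = fit_stump(x, r)         # (b) fit a tree to the residuals
        f = f + np.where(x <= s, c1, c2)    # (c) update f_m = f_{m-1} + T_m
        stumps.append((s, c1, c2))
        losses.append(float(np.sum((y - f) ** 2)))
    return stumps, losses                   # step 3: f_M is the sum of all stumps
```

Calling `boost_regression_stumps(x, y, n_rounds)` returns one `(split, c1, c2)` triple per round. Because each stump is fitted to the residuals left by the previous rounds, the squared training loss cannot increase from one round to the next.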
2.2 An example (Statistical Learning Methods, p. 149)
Code implementation:
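import numpy as np
'''
Input:  x, y, split_point_array (all candidate split points for this feature)
Output: minloss (the smallest squared loss over all split points),
        the best split point,
        the residual array (the target for the next round of fitting)
'''
def findBestSplitPoint(x_array, y_array, split_point_array):
    # Convert to NumPy arrays so that boolean masks and vectorized arithmetic work
    x_array = np.asarray(x_array, dtype=float)
    y_array = np.asarray(y_array, dtype=float)
    minloss = np.inf
    best_split_point = 0
    best_c1 = 0
    best_c2 = 0
    best_mask = None
    # Iterate over the candidate split points
    for c in split_point_array:
        # Region R1 holds the samples with x <= c, region R2 those with x > c.
        # Splitting on x (not on the value of y) stays correct when y has duplicates.
        mask = x_array <= c
        R1 = y_array[mask]
        R2 = y_array[~mask]
        # The best constant prediction in each region is the mean of its targets
        c1 = np.mean(R1)
        c2 = np.mean(R2)
        # Squared loss of this split
        loss = np.sum((R1 - c1)**2) + np.sum((R2 - c2)**2)
        if loss < minloss:
            minloss = loss
            best_split_point = c
            best_c1 = c1
            best_c2 = c2
            best_mask = mask
    # Residuals, used as the targets of the next round of fitting
    residual_array = np.where(best_mask, y_array - best_c1, y_array - best_c2)
    return minloss, best_split_point, list(residual_array)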
x = [1,2,3,4,5,6,7,8,9,10]
y = [5.56,5.70,5.91,6.40,6.80,7.05,8.9,8.70,9.00,9.05]
# Candidate split points: the midpoints between adjacent x values
split_point_array = [1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5]
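print(findBestSplitPoint(x,y,split_point_array))

Copyright notice: This is an original article by forestForQuietLive, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.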