CART (Part 3)

4. Pruning

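Pruning comes in two forms: pre-pruning and post-pruning. Pre-pruning adds stopping conditions while the tree is being grown, so that it does not split excessively. The code in the previous part already includes pre-pruning.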
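As a rough illustration of what such stopping conditions can look like, here is a minimal sketch. The function name and the thresholds `min_err_drop` and `min_samples` are hypothetical and are not taken from the previous part's code. The sketch only shows the general idea: reject a split when it helps too little, or when it leaves too few samples in a child.

```python
def should_stop_split(err_before, err_after, n_left, n_right,
                      min_err_drop=1.0, min_samples=4):
    """Return True when a candidate split should be rejected (pre-pruning).

    All names and default thresholds here are illustrative assumptions.
    """
    # Condition 1: the split does not reduce the error enough.
    if err_before - err_after < min_err_drop:
        return True
    # Condition 2: one of the two children would hold too few samples.
    if n_left < min_samples or n_right < min_samples:
        return True
    return False


print(should_stop_split(10.0, 9.5, 20, 20))  # True: error drops by only 0.5
print(should_stop_split(10.0, 2.0, 20, 2))   # True: right child has 2 samples
print(should_stop_split(10.0, 2.0, 20, 20))  # False: the split is accepted
```

Larger thresholds stop the tree earlier and give a smaller tree. They have to be chosen by hand, which is one reason post-pruning is often used as well.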

The rest of this part focuses on post-pruning.

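Post-pruning algorithm:

 Input: a tree that has already been grown, and the test data used to evaluate it

 Output: the pruned tree

 Steps:

 (1) Split the test data according to the current node. If either subset falls into a subtree, apply the pruning procedure recursively to that subset.

 (2) Compute the error on the test data after merging the two current leaf nodes.

 (3) Compute the error on the test data without merging.

 (4) If merging lowers the error, merge the two leaf nodes.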
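To make steps (2) to (4) concrete, here is a small worked example. All numbers are made up for illustration. The error is taken to be the sum of squared differences between each test target and the value stored in the leaf it reaches, which is what the code in this part uses.

```python
import numpy as np

# Hypothetical toy numbers: target values of the test samples that reach
# each of the two leaves, and the constants stored in those leaves.
left_targets = np.array([1.0, 1.2, 0.8])
right_targets = np.array([1.1, 0.9])
left_leaf, right_leaf = 3.0, 1.0

# Step (3): squared error on the test data if the two leaves are kept.
error_no_merge = (np.sum((left_targets - left_leaf) ** 2)
                  + np.sum((right_targets - right_leaf) ** 2))

# Step (2): squared error if both leaves are replaced by their average.
merged_leaf = (left_leaf + right_leaf) / 2.0
all_targets = np.concatenate([left_targets, right_targets])
error_merge = np.sum((all_targets - merged_leaf) ** 2)

# Step (4): merge only when doing so lowers the test error.
should_merge = error_merge < error_no_merge

print(round(error_no_merge, 2))  # 12.1
print(round(error_merge, 2))     # 5.1
print(bool(should_merge))        # True
```

In this example the left leaf predicts 3.0 while the test targets that reach it are close to 1.0, so merging the two leaves into the single value 2.0 reduces the test error and the split is removed.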


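import numpy as np
# binSplitDataSet is assumed to be the splitting helper from the previous part.

def isTree(obj):
    # Internal nodes are stored as dicts; leaves are plain numbers.
    return isinstance(obj, dict)

def getMean(tree):
    # Collapse a subtree into one value: the average of its two (collapsed) children.
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    # No test data reaches this subtree: collapse it to its mean.
    if np.shape(testData)[0] == 0: return getMean(tree)
    # Split the test data the same way the tree does, then prune the children.
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # Both children are leaves now: compare the test error with and without merging.
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = (np.sum(np.power(lSet[:, -1] - tree['left'], 2)) +
                        np.sum(np.power(rSet[:, -1] - tree['right'], 2)))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree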
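One detail is worth seeing in isolation. When no test samples reach a subtree, prune falls back to getMean and collapses the whole subtree into a single number. The sketch below re-declares the two helpers under hypothetical snake_case names so that it runs on its own, and the toy tree is invented. The keys 'spInd', 'spVal', 'left' and 'right' follow the tree layout used by the code above.

```python
def is_tree(obj):
    # Internal nodes are dicts; leaves are plain numbers.
    return isinstance(obj, dict)


def get_mean(tree):
    # Same logic as getMean: collapse the children first, then average them.
    if is_tree(tree['right']):
        tree['right'] = get_mean(tree['right'])
    if is_tree(tree['left']):
        tree['left'] = get_mean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0


# Invented toy tree: the left child is itself a subtree with two leaves.
toy_tree = {
    'spInd': 0, 'spVal': 0.5,
    'left': {'spInd': 0, 'spVal': 0.8, 'left': 4.0, 'right': 2.0},
    'right': 1.0,
}

collapsed = get_mean(toy_tree)
print(toy_tree['left'])  # 3.0: the inner subtree was replaced in place
print(collapsed)         # 2.0: the average of 3.0 and 1.0
```

Two things follow from this. The average is unweighted, so it ignores how many training samples each leaf represented. The function also modifies the tree it is given, so pass a copy (for example via copy.deepcopy) if the unpruned tree is still needed.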



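Copyright notice: This is an original article by xiaocong1990, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.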