"""李航 统计学习方法 例5.4 — 使用CART算法生成分类决策树。

本程序采用CART算法实现分类决策树；理论部分参照李航《统计学习方法》决策树一章，
实践操作参照《机器学习实战》决策树生成部分。
"""

from numpy import *

def creatDataSet():
    """Build the loan-application toy dataset of 李航 example 5.4.

    Returns:
        (samples, feature_names): ``samples`` is a list of 15 rows; each
        row is [age, has_job, owns_house, credit, class_label], with the
        class label ('是'/'否') in the last position.  ``feature_names``
        names the four feature columns, in order.
    """
    samples = [
        ['青年', '否', '否', '一般', '否'],
        ['青年', '否', '否', '好', '否'],
        ['青年', '是', '否', '好', '是'],
        ['青年', '是', '是', '一般', '是'],
        ['青年', '否', '否', '一般', '否'],
        ['中年', '否', '否', '一般', '否'],
        ['中年', '否', '否', '好', '否'],
        ['中年', '是', '是', '好', '是'],
        ['中年', '否', '是', '非常好', '是'],
        ['中年', '否', '是', '非常好', '是'],
        ['老年', '否', '是', '非常好', '是'],
        ['老年', '否', '是', '好', '是'],
        ['老年', '是', '否', '好', '是'],
        ['老年', '是', '否', '非常好', '是'],
        ['老年', '否', '否', '一般', '否'],
    ]
    feature_names = ['年龄', '有工作', '有自己的房子', '信贷情况']
    return samples, feature_names


# Gini impurity of a dataset (class label = last element of each row).
def gini(dataset):
    """Return the Gini impurity ``1 - sum(p_k**2)`` of *dataset*.

    Fixes over the original: the local accumulator no longer shadows the
    function name ``gini``, and counting uses ``dict.get`` instead of the
    ``not in d.keys()`` anti-idiom.  Behavior is unchanged, including the
    empty-dataset result of 1.0 (the accumulator's start value; callers
    weight it by ``len(dataset)``, so an empty split contributes 0).

    Args:
        dataset: list of rows; each row's last element is its class label.

    Returns:
        float in [0, 1]; 0.0 for a pure (single-class) dataset.
    """
    total = len(dataset)
    counts = {}
    for row in dataset:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    impurity = 1.0
    for count in counts.values():
        impurity -= (float(count) / total) ** 2
    return impurity

# Binary split of a dataset on one feature value (CART-style).
def splitdata(dataset, axis, value):
    """Split *dataset* into (equal, not-equal) halves on feature *axis*.

    Fixes over the original: the redundant ``elif featVec[axis] != value``
    (the exact negation of the ``if``) is now a plain ``else``, and the
    duplicated slice-and-extend code is hoisted out of both branches.

    Args:
        dataset: list of rows (lists).
        axis: index of the feature column to split on.
        value: the feature value defining the "equal" side.

    Returns:
        (matching, rest): rows whose feature equals *value* and the
        remainder.  The feature column itself is removed from every
        returned row (the feature is consumed once chosen).
    """
    matching = []
    rest = []
    for row in dataset:
        # Drop the split column; route the reduced row by equality.
        reduced = row[:axis] + row[axis + 1:]
        if row[axis] == value:
            matching.append(reduced)
        else:
            rest.append(reduced)
    return matching, rest


def choosebestsplit(dataSet):
    """Choose the (feature, value) pair whose binary split minimises the
    weighted Gini index.

    Fix over the original: the sentinel ``inf`` relied on the name leaking
    from ``from numpy import *``; ``float('inf')`` removes that hidden
    dependency.  The weighting arithmetic is kept in the original form so
    floating-point results (and tie resolution) are identical.

    Args:
        dataSet: list of rows; last column is the class label.

    Returns:
        (best_feature_index, best_value); (-1, -1) if *dataSet* has no
        feature columns.
    """
    num_features = len(dataSet[0]) - 1
    n = len(dataSet)
    best_gini = float('inf')
    best_feature = -1
    best_value = -1
    for i in range(num_features):
        # Try every distinct value of feature i as the binary split point.
        for value in set(example[i] for example in dataSet):
            left, right = splitdata(dataSet, i, value)
            weighted = gini(left) * (len(left) / n) + gini(right) * (len(right) / n)
            if weighted < best_gini:
                best_gini = weighted
                best_feature = i
                best_value = value
    return best_feature, best_value

def majority(classList):
    """Return the class label that occurs most often in *classList*.

    Bug fixed: the original returned the entire count dict instead of the
    majority label, so exhausted-feature leaves in ``creatTree`` would
    hold a dict rather than a class label.  Ties resolve to the label
    first reaching the maximum count (insertion order).

    Args:
        classList: non-empty list of class labels.

    Returns:
        The most frequent label.
    """
    counts = {}
    for label in classList:
        counts[label] = counts.get(label, 0) + 1
    return max(counts, key=counts.get)

# Recursively grow the CART classification tree as nested dicts.
def creatTree(dataSet, labels):
    """Build a decision tree: ``{feature_label: {value: subtree_or_label}}``.

    Bug fixed: the original ``del(labels[bestfeature])`` mutated the shared
    label list in place, so (a) the caller's list was clobbered and (b) the
    second and later recursive calls inside the value loop indexed a list
    already shortened by their sibling calls, attaching wrong feature
    labels.  Each recursion now receives its own fresh copy and the
    caller's ``labels`` is left untouched.

    Args:
        dataSet: list of rows; last column is the class label.
        labels: feature names, parallel to the feature columns.

    Returns:
        Either a class label (leaf) or a nested dict (internal node).
    """
    classList = [example[-1] for example in dataSet]
    # Pure node: every sample carries the same class — emit a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Features exhausted (only the class column remains): majority leaf.
    if len(dataSet[0]) == 1:
        return majority(classList)
    bestfeature, bestvalue = choosebestsplit(dataSet)
    bestlabel = labels[bestfeature]
    myTree = {bestlabel: {}}
    # splitdata removes the chosen column, so drop its label too — in a
    # new list, without mutating the caller's ``labels``.
    sublabels = labels[:bestfeature] + labels[bestfeature + 1:]
    data1, data2 = splitdata(dataSet, bestfeature, bestvalue)
    for value in set(example[bestfeature] for example in dataSet):
        # CART is a binary split: the "== bestvalue" branch grows from
        # data1; every other observed value shares the data2 subtree.
        branch = data1 if value == bestvalue else data2
        myTree[bestlabel][value] = creatTree(branch, sublabels[:])
    return myTree


if __name__ == "__main__":
    dataSet,labels=creatDataSet()
    myTree=creatTree(dataSet,labels)
    print(myTree)



# 版权声明:本文为weixin_40230767原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。