# 本程序采用 CART 算法实现分类决策树。理论部分参照李航《统计学习方法》的决策树一章,实现部分参照《机器学习实战》的决策树生成部分。
from numpy import *
def creatDataSet():
    """Return the toy training set and the names of its feature columns.

    Returns:
        (dataset, label):
        ``dataset`` is a list of 15 rows. Each row is a list of four feature
        values followed by the class label ('是' / '否').
        ``label`` names the four feature columns in order: age (年龄),
        has a job (有工作), owns a house (有自己的房子) and credit rating
        (信贷情况).
        Fresh lists are built on every call, so callers may mutate them.
    """
    rows = (
        ('青年', '否', '否', '一般', '否'),
        ('青年', '否', '否', '好', '否'),
        ('青年', '是', '否', '好', '是'),
        ('青年', '是', '是', '一般', '是'),
        ('青年', '否', '否', '一般', '否'),
        ('中年', '否', '否', '一般', '否'),
        ('中年', '否', '否', '好', '否'),
        ('中年', '是', '是', '好', '是'),
        ('中年', '否', '是', '非常好', '是'),
        ('中年', '否', '是', '非常好', '是'),
        ('老年', '否', '是', '非常好', '是'),
        ('老年', '否', '是', '好', '是'),
        ('老年', '是', '否', '好', '是'),
        ('老年', '是', '否', '非常好', '是'),
        ('老年', '否', '否', '一般', '否'),
    )
    feature_names = ['年龄', '有工作', '有自己的房子', '信贷情况']
    # Materialise each immutable row template as a new list.
    return [list(row) for row in rows], feature_names
# 计算某一个数据集的基尼指数
def gini(dataset):
    """Return the Gini index of ``dataset``.

    The last element of every row is taken as its class label. The result is
    ``1 - sum(p_k ** 2)``, where ``p_k`` are the class proportions. An empty
    dataset yields ``1.0``, because no class terms are subtracted.
    """
    total = len(dataset)
    tally = {}
    for row in dataset:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    impurity = 1.0
    # Subtract the squared proportions one at a time, in first-seen class
    # order, so the floating-point result is accumulated the same way.
    for count in tally.values():
        impurity -= (float(count) / total) ** 2
    return impurity
# 划分数据集
def splitdata(dataset, axis, value):
    """Split ``dataset`` in two on feature column ``axis``.

    Args:
        dataset: sequence of rows (lists or tuples).
        axis: index of the feature column to test.
        value: rows whose ``row[axis] == value`` form the first subset.

    Returns:
        (matched, rest): rows whose feature equals ``value``, and every other
        row. In both subsets column ``axis`` is removed. Each returned row is
        a new object of the same sequence type; the input rows are not
        modified.
    """
    matched = []
    rest = []
    for featVec in dataset:
        # Copy the row without the split column. Slice concatenation works for
        # lists and tuples alike (the old `.extend` call crashed on tuples).
        reduced = featVec[:axis] + featVec[axis + 1:]
        # A plain `else`: every row lands in exactly one subset, instead of
        # re-testing with `!=` and silently dropping rows when both tests fail.
        if featVec[axis] == value:
            matched.append(reduced)
        else:
            rest.append(reduced)
    return matched, rest
def choosebestsplit(dataSet):
    """Pick the binary split with the lowest weighted Gini index.

    Every (feature column, value) pair is tried as the test
    ``row[feature] == value``. The cost of a split is the Gini index of each
    half, weighted by that half's share of the rows.

    Args:
        dataSet: non-empty list of rows, with the class label as the last
            element of each row.

    Returns:
        (bestfeature, bestvalue): the column index and value of the best
        split. Ties keep the first candidate examined (columns left to right,
        values in first-occurrence order). Returns (-1, -1) when the rows
        contain no feature columns.
    """
    numFeature = len(dataSet[0]) - 1
    total = len(dataSet)
    bestgini = float("inf")
    bestvalue = -1
    bestfeature = -1
    for i in range(numFeature):
        # dict.fromkeys de-duplicates while keeping first-occurrence order.
        # The previous list(set(...)) order depended on hash randomisation,
        # so ties between equally good splits were broken differently from
        # run to run.
        candidates = dict.fromkeys(example[i] for example in dataSet)
        for value in candidates:
            splitset1, splitset2 = splitdata(dataSet, i, value)
            weighted = (gini(splitset1) * (len(splitset1) / total)
                        + gini(splitset2) * (len(splitset2) / total))
            # Strict `<`: the earliest candidate wins a tie.
            if weighted < bestgini:
                bestgini = weighted
                bestfeature = i
                bestvalue = value
    return bestfeature, bestvalue
def majority(classList):
    """Return the most frequent class label in ``classList``.

    Previously this returned the whole ``{label: count}`` dict, so tree
    leaves built from it were dicts instead of class labels.

    Ties go to the label that appears first in ``classList``.

    Raises:
        ValueError: if ``classList`` is empty (there is no majority).
    """
    classcount = {}
    for vote in classList:
        classcount[vote] = classcount.get(vote, 0) + 1
    # max() returns the first maximal key, and dict iteration follows
    # insertion (first-seen) order, which gives the tie-break above.
    return max(classcount, key=classcount.get)
# 创建决策树
def creatTree(dataSet, labels):
    """Recursively build a binary (CART-style) classification tree.

    Args:
        dataSet: non-empty list of rows. Each row holds the feature values
            followed by the class label as its last element.
        labels: feature names, one per feature column of ``dataSet``.
            The list is NOT modified.

    Returns:
        Either a class label (a leaf) or a nested dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.

        The split is binary. The key equal to the chosen split value maps to
        the "equal" subtree. Every other value of that feature maps to the
        single "not equal" subtree; the same object is stored under each of
        those keys.
    """
    classList = [example[-1] for example in dataSet]
    # Pure node: every sample has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: there is no feature to split on.
    if len(dataSet[0]) == 1:
        return majority(classList)
    bestfeature, bestvalue = choosebestsplit(dataSet)
    bestlabel = labels[bestfeature]
    # splitdata removes column `bestfeature` from both halves, so the children
    # get the labels without that name. Build a new list instead of
    # `del labels[...]`: deleting in place mutated the caller's list and the
    # list shared by sibling branches, which corrupted later branches' labels.
    sublabels = labels[:bestfeature] + labels[bestfeature + 1:]
    data1, data2 = splitdata(dataSet, bestfeature, bestvalue)
    # Grow each side of the binary split exactly once. The "not equal" side
    # used to be rebuilt for every other value of the feature.
    equal_tree = creatTree(data1, sublabels)
    other_tree = creatTree(data2, sublabels) if data2 else None
    myTree = {bestlabel: {}}
    # Feature values in first-occurrence order, so the printed tree is the
    # same on every run (set order depended on hash randomisation).
    for value in dict.fromkeys(example[bestfeature] for example in dataSet):
        myTree[bestlabel][value] = equal_tree if value == bestvalue else other_tree
    return myTree
if __name__ == "__main__":
    # Demo: grow a tree on the bundled toy data and print its nested-dict form.
    toy_rows, feature_names = creatDataSet()
    print(creatTree(toy_rows, feature_names))
# 版权声明:本文为 weixin_40230767 原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。