ID3算法生成决策树并绘制的代码

1. tree.py是ID3算法生成决策树的代码

2. treePlotter.py是将决策树绘制出来的代码

# tree.py
from math import log
import treePlotter


def calc_shannon_ent(dataset):
    """Return the Shannon entropy (base 2) of the class labels in ``dataset``.

    :param dataset: sequence of samples; the last element of every sample
        is taken as its class label
    :return: the entropy as a float (``0.0`` for an empty dataset)
    """
    total = len(dataset)

    # Tally how many samples carry each class label.
    tally = {}
    for sample in dataset:
        label = sample[-1]
        tally[label] = tally.get(label, 0) + 1

    entropy = 0.0
    for count in tally.values():
        fraction = float(count) / total
        entropy -= fraction * log(fraction, 2)
    return entropy


def create_dataset():
    """Build the small demo dataset together with its feature names.

    :return: ``(dataset, labels)`` where every row of ``dataset`` holds two
        feature values followed by a class label, and ``labels`` names the
        two feature columns
    """
    labels = ['no surfacing', 'flippers']
    rows = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 1, 'no'),
    )
    # Hand out fresh lists on every call so callers may mutate them freely.
    dataset = [list(row) for row in rows]
    return dataset, labels


# Split the dataset on a given feature.
def split_dataset(dataset, axis, value):
    """Select the samples whose feature ``axis`` equals ``value``.

    :param dataset: list of samples, each sample being a list
    :param axis: index of the feature column to test
    :param value: feature value a sample must have to be selected
    :return: a new list holding a copy of each selected sample with the
        ``axis`` column removed; the input samples are left untouched
    """
    def without_column(sample):
        # Everything before the column, followed by everything after it.
        trimmed = sample[:axis]
        trimmed.extend(sample[axis + 1:])
        return trimmed

    return [without_column(sample) for sample in dataset if sample[axis] == value]


# Choose the best way to split the dataset.
def choose_best_feature_to_split(dataset):
    """
    Pick the feature whose split yields the largest information gain.

    For every feature column the dataset is partitioned with split_dataset()
    by each distinct value of that column; the size-weighted entropy of the
    partitions (calc_shannon_ent()) is subtracted from the entropy of the
    whole dataset to obtain the gain.

    :param dataset: list of samples; the last element of each sample is its
        class label
    :return: index of the best feature, or -1 when no feature has a strictly
        positive information gain
    """
    sample_count = float(len(dataset))
    feature_count = len(dataset[0]) - 1
    entropy_before = calc_shannon_ent(dataset)

    winner, winner_gain = -1, 0.0
    for column in range(feature_count):
        # Distinct values that this feature takes in the dataset.
        distinct_values = set([sample[column] for sample in dataset])
        entropy_after = 0.0
        for candidate in distinct_values:
            subset = split_dataset(dataset, column, candidate)
            weight = len(subset) / sample_count
            entropy_after += weight * calc_shannon_ent(subset)
        gain = entropy_before - entropy_after
        if gain > winner_gain:
            winner, winner_gain = column, gain
    return winner


# Majority vote: decides the class of a leaf node.
def majority_cnt(class_list):
    """Return the label that occurs most often in ``class_list``.

    :param class_list: non-empty list of class labels
    :return: the most frequent label; among equally frequent labels the one
        encountered first wins (the sort below is stable)
    """
    tally = {}
    for vote in class_list:
        tally[vote] = tally.get(vote, 0) + 1
    ranked = sorted(tally, key=tally.get, reverse=True)
    return ranked[0]


# Build the decision tree.
def create_tree(dataset, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    :param dataset: non-empty list of samples; each sample is a list of
        feature values followed by the class label
    :param labels: feature names, one per feature column of ``dataset``;
        the list is read but never modified
    :return: a class label when the node is a leaf, otherwise
        ``{feature_name: {feature_value: subtree_or_label, ...}}``
    :raises IndexError: if ``dataset`` is empty
    """
    class_list = [example[-1] for example in dataset]
    # Every sample carries the same label: no further split is needed and
    # that label becomes the leaf.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # No features left to split on: the leaf label is decided by majority vote.
    if len(dataset[0]) == 1:
        return majority_cnt(class_list)
    best_feat = choose_best_feature_to_split(dataset)
    # -1 means no feature gives a positive information gain. Using it as an
    # index would address the class-label column, so stop here and let the
    # majority decide instead.
    if best_feat < 0:
        return majority_cnt(class_list)
    best_feat_label = labels[best_feat]
    # Build the reduced label list without deleting from the caller's list,
    # so ``labels`` can still be passed to classify() afterwards.
    sub_labels = labels[:best_feat] + labels[best_feat + 1:]
    my_tree = {best_feat_label: {}}
    unique_vals = set(example[best_feat] for example in dataset)
    for value in unique_vals:
        my_tree[best_feat_label][value] = create_tree(
            split_dataset(dataset, best_feat, value), sub_labels)
    return my_tree


# Classify a sample with a decision tree.
def classify(input_tree, feat_labels, test_vec):
    """Walk ``input_tree`` and return the class label for ``test_vec``.

    :param input_tree: nested-dict tree of the form
        ``{feature_name: {feature_value: subtree_or_label}}``
    :param feat_labels: feature names; the position of a name is the index
        of that feature inside ``test_vec``
    :param test_vec: feature values of the sample to classify
    :return: the class label stored at the leaf that was reached
    :raises ValueError: if the tree's feature name is missing from
        ``feat_labels`` (raised by ``list.index``), or if the sample's value
        for that feature has no branch in the tree
    """
    first_str = list(input_tree)[0]
    second_dict = input_tree[first_str]
    feat_index = feat_labels.index(first_str)
    for key, subtree in second_dict.items():
        if test_vec[feat_index] == key:
            # A dict (or dict subclass) is a further decision node;
            # anything else is the class label itself.
            if isinstance(subtree, dict):
                return classify(subtree, feat_labels, test_vec)
            return subtree
    # Previously an unmatched value fell through to an unbound local;
    # report the actual problem instead.
    raise ValueError('no branch for value %r of feature %r'
                     % (test_vec[feat_index], first_str))


# Persist a decision tree with the pickle module.
def store_tree(input_tree, filename):
    """Serialize ``input_tree`` to the file ``filename`` using pickle.

    :param input_tree: the tree (any picklable object) to store
    :param filename: path of the file to write; it is opened in binary
        write mode, so existing content is replaced
    """
    import pickle
    # The with-block closes the file even when pickle.dump raises.
    with open(filename, 'wb') as fw:
        pickle.dump(input_tree, fw)


# Load a decision tree back from a file.
def grab_tree(filename):
    """Return the object pickled in the file ``filename``.

    :param filename: path of a file previously written with pickle
    :return: the unpickled object
    """
    import pickle
    # NOTE: unpickling can execute arbitrary code, so only load files that
    # come from a trusted source.
    # The with-block makes sure the file handle is closed after loading.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


def test():
    """Print the entropy of the demo dataset before and after relabelling
    its first sample as 'maybe'."""
    # Exercise the Shannon entropy computation.
    samples, _ = create_dataset()
    entropy_before = calc_shannon_ent(samples)
    print(entropy_before)
    samples[0][-1] = 'maybe'
    entropy_after = calc_shannon_ent(samples)
    print(entropy_after)


def test2():
    """Print the demo dataset split on feature 0, first for value 1 and
    then for value 0."""
    # Exercise dataset splitting.
    samples, _ = create_dataset()
    for feature_value in (1, 0):
        print(split_dataset(samples, 0, feature_value))


def test3():
    """Print the index of the best split feature of the demo dataset,
    followed by the dataset itself."""
    samples, _ = create_dataset()
    best_index = choose_best_feature_to_split(samples)
    print(best_index)
    print(samples)


def test4():
    """Build a decision tree from the demo dataset and print it."""
    samples, feature_names = create_dataset()
    print(create_tree(samples, feature_names))


def test5():
    """Classify the samples [1, 0] and [1, 1] with the first pre-built tree
    and print each result."""
    _, feature_names = create_dataset()
    tree = treePlotter.retrieve_tree(0)
    for sample in ([1, 0], [1, 1]):
        print(classify(tree, feature_names, sample))


def test6():
    """Store the first pre-built tree in 'classifierStorage.txt', read it
    back and print it."""
    storage_path = 'classifierStorage.txt'
    store_tree(treePlotter.retrieve_tree(0), storage_path)
    restored = grab_tree(storage_path)
    print(restored)


# Predict the contact-lens type with a decision tree.
def test7():
    """Build a decision tree from './lenses.txt', print it and plot it.

    Each non-blank line of the file is split on tabs into one sample; the
    four feature columns are named by ``lenses_labels`` below.
    """
    # The with-block closes the file once it has been read. Blank lines are
    # skipped because they would otherwise turn into bogus [''] samples.
    with open('./lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr if inst.strip()]
    lenses_labels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lenses_tree = create_tree(lenses, lenses_labels)
    print(lenses_tree)

    treePlotter.create_plot(lenses_tree)


# Run the contact-lens demo when this file is executed as a script.
if __name__ == '__main__':
    test7()

# treePlotter.py
import matplotlib.pyplot as plt

# Box and arrow styles used when drawing tree nodes as text annotations.
decision_node = {'boxstyle': 'sawtooth', 'fc': '0.8'}
leaf_node = {'boxstyle': 'round4', 'fc': '0.8'}
arrow_args = {'arrowstyle': '<-'}


def plot_node(node_txt, center_pt, parent_pt, node_type):
    """Annotate the axes stored in ``create_plot.ax1`` with one tree node.

    :param node_txt: text shown inside the node box
    :param center_pt: (x, y) position of the text, in axes-fraction coordinates
    :param parent_pt: (x, y) point being annotated, in axes-fraction coordinates
    :param node_type: box style dict passed on as ``bbox``
    """
    axes = create_plot.ax1
    axes.annotate(
        node_txt,
        xy=parent_pt, xycoords='axes fraction',
        xytext=center_pt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=node_type, arrowprops=arrow_args,
    )


def create_plot():
    """Show a demo figure with one decision-node box and one leaf-node box.

    NOTE(review): this file later defines ``create_plot(in_tree)``, which
    rebinds the name; once that definition has run, this zero-argument
    version is no longer reachable under this name.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    create_plot.ax1 = plt.subplot(111, frameon=False)
    demo_nodes = (
        (U'决策节点', (0.5, 0.1), (0.1, 0.5), decision_node),
        (U'叶节点', (0.8, 0.1), (0.3, 0.8), leaf_node),
    )
    for text, center, parent, box_style in demo_nodes:
        plot_node(text, center, parent, box_style)
    plt.show()


# Helpers that measure a tree: its number of leaf nodes and its number of levels.
def get_num_leafs(my_tree):
    """Count the leaf nodes of a nested-dict decision tree.

    :param my_tree: tree of the form
        ``{feature_name: {feature_value: subtree_or_label}}``
    :return: number of leaves, i.e. of non-dict values reachable in the tree
    """
    first_str = list(my_tree)[0]
    second_dict = my_tree[first_str]
    num_leafs = 0
    for child in second_dict.values():
        # isinstance also recognises dict subclasses (OrderedDict,
        # defaultdict, ...) as subtrees; comparing the type name did not.
        if isinstance(child, dict):
            num_leafs += get_num_leafs(child)
        else:
            num_leafs += 1
    return num_leafs


def get_tree_depth(my_tree):
    """Return the number of decision levels in a nested-dict decision tree.

    :param my_tree: tree of the form
        ``{feature_name: {feature_value: subtree_or_label}}``
    :return: 1 for a tree whose branches are all leaves, plus 1 for every
        further level of nested decision nodes on the deepest path
    """
    first_str = list(my_tree)[0]
    second_dict = my_tree[first_str]
    max_depth = 0
    for child in second_dict.values():
        # isinstance also recognises dict subclasses (OrderedDict,
        # defaultdict, ...) as subtrees; comparing the type name did not.
        if isinstance(child, dict):
            this_depth = 1 + get_tree_depth(child)
        else:
            this_depth = 1
        max_depth = max(max_depth, this_depth)
    return max_depth


def retrieve_tree(i):
    """
    Return one of the pre-built sample trees, so that test code does not
    have to build a tree from data every time.

    :param i: index of the sample tree to return
    :return: the selected tree as a nested dict, built afresh on every call
    """
    two_level_tree = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: 'no', 1: 'yes'}},
        },
    }
    three_level_tree = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}},
        },
    }
    list_of_trees = [two_level_tree, three_level_tree]
    return list_of_trees[i]


def test():
    """Print the leaf count and then the depth of the first sample tree."""
    sample_tree = retrieve_tree(0)
    for measure in (get_num_leafs, get_tree_depth):
        print(measure(sample_tree))


def plot_mid_text(cntr_pt, parent_pt, txt_string):
    """Write ``txt_string`` halfway between ``cntr_pt`` and ``parent_pt``.

    :param cntr_pt: (x, y) position of the child node
    :param parent_pt: (x, y) position of the parent node
    :param txt_string: text to place at the midpoint
    """
    midpoint = [(parent_pt[axis] - cntr_pt[axis]) / 2.0 + cntr_pt[axis]
                for axis in (0, 1)]
    create_plot.ax1.text(midpoint[0], midpoint[1], txt_string)


def plot_tree(my_tree, parent_pt, node_txt):
    """Recursively draw ``my_tree`` beneath ``parent_pt``.

    Layout state is kept in attributes of this function, which
    ``create_plot(in_tree)`` initialises before the first call:
    ``total_w`` (leaf count of the whole tree), ``total_d`` (its depth) and
    the running cursor ``x_off`` / ``y_off``.

    :param my_tree: nested-dict (sub)tree to draw
    :param parent_pt: (x, y) position of the parent node
    :param node_txt: text written on the edge leading to this node
    """
    # The depth of the subtree is not needed here; computing it walked the
    # whole subtree once more at every node for no effect.
    num_leafs = get_num_leafs(my_tree)
    first_str = list(my_tree)[0]
    # Centre this decision node horizontally over the leaves it spans.
    cntr_pt = (plot_tree.x_off + (1.0 + float(num_leafs)) / 2.0 / plot_tree.total_w,
               plot_tree.y_off)

    plot_mid_text(cntr_pt, parent_pt, node_txt)
    plot_node(first_str, cntr_pt, parent_pt, decision_node)
    second_dict = my_tree[first_str]
    # Move the cursor down one level before drawing the children.
    plot_tree.y_off = plot_tree.y_off - 1.0 / plot_tree.total_d
    for key, child in second_dict.items():
        # isinstance also treats dict subclasses as subtrees.
        if isinstance(child, dict):
            plot_tree(child, cntr_pt, str(key))
        else:
            # Every leaf takes the next horizontal slot.
            plot_tree.x_off = plot_tree.x_off + 1.0 / plot_tree.total_w
            leaf_pt = (plot_tree.x_off, plot_tree.y_off)
            plot_node(child, leaf_pt, cntr_pt, leaf_node)
            plot_mid_text(leaf_pt, cntr_pt, str(key))
    # Restore the vertical cursor for the caller's remaining siblings.
    plot_tree.y_off = plot_tree.y_off + 1.0 / plot_tree.total_d


def create_plot(in_tree):
    """Draw the decision tree ``in_tree`` in a matplotlib figure and show it.

    :param in_tree: nested-dict decision tree to draw
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Axes without a frame and without tick marks.
    create_plot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    # Initialise the layout state that plot_tree reads and updates.
    plot_tree.total_w = float(get_num_leafs(in_tree))
    plot_tree.total_d = float(get_tree_depth(in_tree))
    plot_tree.x_off = -0.5 / plot_tree.total_w
    plot_tree.y_off = 1.0

    root_pt = (0.5, 1.0)
    plot_tree(in_tree, root_pt, '')
    plt.show()


def test2():
    """Plot the first sample tree, add a third branch to its root and plot
    it again."""
    sample_tree = retrieve_tree(0)
    create_plot(sample_tree)
    root_branches = sample_tree['no surfacing']
    root_branches[3] = 'maybe'
    create_plot(sample_tree)


# Run the plotting demo when this file is executed as a script.
if __name__ == '__main__':
    test2()


版权声明:本文为Atticus_zhang原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。