《机器学习 (周志华)》习题7.3答案

编程实现拉普拉斯修正的朴素贝叶斯,西瓜3.0训练集,“测1”样本测试。

书上求得的标准差是除以(N-1)即np.std(x, ddof=1)得到的,用与numpy直接用std计算结果存在偏差。

不加拉普拉斯修正跑的数据,部分和书上不一致(P(蜷缩,是)和P(凹陷,是)),经检查是书中错误。

只能通过测试样例,对count为0的数据存在bug,待修改。。。

代码如下:

# coding: utf-8
import math
import numpy as np

file = open('西瓜数据集3.csv'.decode('utf-8'))
filedata = [line.strip('\n').split(',')[1:] for line in file]
idx1 = filedata[0].index('密度')
idx2 = filedata[0].index('含糖率')
for i in range(1, len(filedata)):
	filedata[i][idx1] = float(filedata[i][idx1])
	filedata[i][idx2] = float(filedata[i][idx2])
filedata = filedata[1:]

def fit(filedata, lapula_correct=True):
	diff_class = {i:set() for i in range(len(filedata[0]))}
	for raw in filedata:
		for j in range(len(raw)):
			diff_class[j].add(raw[j])
	count = {}
	for raw in filedata:
		for j in range(len(raw)):
			label = raw[-1]
			# discrete attribute 
			if type(raw[j]) is not float:
				tup = (raw[j], label)
				count[tup] = (count.get(tup, [0])[0] + 1, len(diff_class[j]))
			# continuous attribute
			else:
				tup = (j, label)
				if tup not in count:
					count[tup] = [raw[j]]
				else:
					count[tup].append(raw[j])

	prob = {}
	total_case = len(filedata)
	for i in count:
		if type(count[i]) is list:
			mean = np.mean(count[i])
			std = np.std(count[i])
			# std = np.std(count[i], ddof=1)
			prob[i] = (mean, std)
		else:
			x, c = i
			if lapula_correct:
				if x == c:
					prob[x] = float(count[i][0] + 1) / (total_case + count[i][1])
				else:
					prob[i] = float(count[i][0] + 1) / (count[(c, c)][0] + count[i][1])
			else:
				if x == c:
					prob[x] = float(count[i][0]) / total_case
				else:
					prob[i] = float(count[i][0]) / count[(c, c)][0]

	return prob

def predict(data, prob):
	label = ['是', '否']
	p1, p2 = prob[label[0]], prob[label[1]]
	val = [np.log(p1), np.log(p2)]
	# val = [p1, p2]
	for i in data:
		for j in range(2):
			if type(i) is float:
				idx = data.index(i)
				tup = (idx, label[j])
				mean, std = prob[tup]
				p = np.exp(-(i-mean)**2/(2*std**2))/(np.sqrt(2*np.pi)*std)	
			else:
				tup = (i, label[j])
				p = prob[tup]
			val[j] += np.log(p)
			# val[j] *= p
	return max(label, key=lambda x:val[label.index(x)])


prob = fit(filedata)
test_data = ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460]
res = predict(test_data, prob)
print 'res', res



版权声明:本文为Wiking__acm原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。