# Python实现LDA降维过程可视化 (Python: visualizing the LDA dimensionality-reduction process)

import random
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D 

class MyModel(object):
	"""Two random 3-D point clouds and a 1-D projection that separates them.

	Class 1 lives in ``x1``/``y1``/``z1`` and class 2 in ``x2``/``y2``/``z2``
	(parallel lists of floats, one entry per point). The projection direction
	is the difference of the two class centres (see ``get_LDA``).
	"""

	def __init__(self):
		# get_data returns the coordinates interleaved: x1, x2, y1, y2, z1, z2.
		self.x1, self.x2, self.y1, self.y2, self.z1, self.z2 = self.get_data()

	def get_data(self, x1Max=2.2, y1Max=2.2, z1Max=2.2,
			x1Min=1, y1Min=1, z1Min=1,
			x2Max=3, y2Max=3, z2Max=3,
			x2Min=1.7, y2Min=1.7, z2Min=1.7, numData=30):
		"""Sample ``numData`` points per class uniformly from two axis-aligned boxes.

		Class 1 coordinates are drawn from [x1Min, x1Max] x [y1Min, y1Max] x
		[z1Min, z1Max]; class 2 from the corresponding ``*2*`` bounds. With the
		defaults the two boxes overlap on [1.7, 2.2] along every axis.

		Returns:
			Six lists ``(x1, x2, y1, y2, z1, z2)``, each of length ``numData``.
		"""
		x1, x2, y1, y2, z1, z2 = [], [], [], [], [], []

		# Keep this exact draw order per point so a seeded ``random`` module
		# reproduces the same dataset.
		for _ in range(numData):
			x1.append(random.uniform(x1Min, x1Max))
			x2.append(random.uniform(x2Min, x2Max))
			y1.append(random.uniform(y1Min, y1Max))
			y2.append(random.uniform(y2Min, y2Max))
			z1.append(random.uniform(z1Min, z1Max))
			z2.append(random.uniform(z2Min, z2Max))

		return x1, x2, y1, y2, z1, z2

	def get_centre(self):
		"""Return the per-axis means of both classes.

		Returns:
			``([mean_x1, mean_y1, mean_z1], [mean_x2, mean_y2, mean_z2])``.

		Raises:
			ZeroDivisionError: if either class has no points.
		"""
		centre_1 = [sum(axis) / len(axis) for axis in (self.x1, self.y1, self.z1)]
		centre_2 = [sum(axis) / len(axis) for axis in (self.x2, self.y2, self.z2)]
		return centre_1, centre_2

	def get_LDA(self):
		"""Return the projection direction ``centre_1 - centre_2``, rounded to 3 decimals.

		NOTE: this is only the mean-difference direction. Full Fisher LDA uses
		``S_w^-1 (m1 - m2)``, where ``S_w`` is the pooled within-class scatter.
		The two agree up to scale when ``S_w`` is a multiple of the identity.
		"""
		centre_1, centre_2 = self.get_centre()
		return np.round(np.array(centre_1) - np.array(centre_2), 3)

	def ues_LDA(self, vec, lda=None):
		"""Project the 3-D point ``vec`` onto the LDA direction (a dot product).

		Args:
			vec: a length-3 sequence ``[x, y, z]``.
			lda: optional precomputed direction. When omitted it is recomputed
				from all data via ``get_LDA``, which costs O(n) per call.
		"""
		direction = self.get_LDA() if lda is None else np.asarray(lda)
		return np.dot(direction, np.asarray(vec))

	# Correctly spelled alias; ``ues_LDA`` is kept for existing callers.
	use_LDA = ues_LDA

	def _project(self, xs, ys, zs, lda):
		"""Project each point ``(xs[i], ys[i], zs[i])`` onto ``lda``; return a list."""
		return [self.ues_LDA([x, y, z], lda) for x, y, z in zip(xs, ys, zs)]

	def show_data_3D(self):
		"""Scatter both classes in 3-D (class 1 red, class 2 blue) and join their centres."""
		fig = plt.figure()
		# ``Axes3D(fig)`` no longer adds itself to the figure (deprecated in
		# matplotlib 3.4, removed in 3.6), which leaves an empty window.
		# ``add_subplot`` works on both old and new versions.
		ax = fig.add_subplot(111, projection='3d')
		ax.scatter(self.x1, self.y1, self.z1, c='r')
		ax.scatter(self.x2, self.y2, self.z2, c='b')

		centre_1, centre_2 = self.get_centre()
		# Regroup the two centre points into per-axis coordinate pairs.
		x, y, z = ([c1, c2] for c1, c2 in zip(centre_1, centre_2))
		ax.scatter(x, y, z, c='pink')
		ax.plot(x, y, z, c='yellow')

		plt.show()

	def show_data_1D(self):
		"""Plot the 1-D projections of both classes along the line y = 0.

		The colours match ``show_data_3D``: class 1 is red, class 2 is blue.
		"""
		plt.figure()

		# Compute the direction once, instead of once per point (which was O(n^2)).
		lda = self.get_LDA()
		classes = (
			(self.x1, self.y1, self.z1, 'r'),
			(self.x2, self.y2, self.z2, 'b'),
		)
		for xs, ys, zs, colour in classes:
			projected = self._project(xs, ys, zs, lda)
			plt.scatter(projected, np.zeros(len(projected)), s=10, c=colour)

		plt.show()

if __name__ == "__main__":
	# Build a fresh random dataset, then show the projected 1-D view
	# before the original 3-D point clouds.
	model = MyModel()
	model.show_data_1D()
	model.show_data_3D()

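# 版权声明:本文为weixin_43628432原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。
# (Copyright notice: original article by weixin_43628432, licensed under CC 4.0 BY-SA; when reposting, include a link to the original source and this notice.)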