# -*- coding:utf-8 -*-
# kmeans简单实现
"""
kmeans是经典的无监督聚类方法。步骤可以分为以下几步:
1.确定聚类的数目k
2.随机初始化k个聚类中心
3.根据准则(一般计算欧式距离)将数据分配到对应的聚类中心
4.更新每个类别的聚类中心(均值)
5.重复步骤2-步骤4若干次
"""
import numpy as np
import random
from collections import defaultdict
import matplotlib.pyplot as plt
class KMeans():
def __init__(self, data, k, max_iter=5):
# 初始化变量
self._data = data # 数据
self._k = k # 簇
self._max_iter = max_iter # 最大迭代次数
self._example_num = data.shape[0] # 有多少个数据
self._centroids = None # 设置质心
self._cluster_data_indices = None # 具体的每个簇
def _random_init_centroid(self):
# 随机初始化质心,保证质心在数据点之中
# sample函数,随机选择一个点做初始化质心
random_centroid_indicss = random.sample(
range(3, self._example_num), self._k)
centroids = self._data[random_centroid_indicss]
self._centroids = centroids
@staticmethod
def get_closeset_centroid(data, centroids):
# 得到最近的质心
distance = np.sum(np.power((data - centroids), 2), 1)
indices = np.argmin(distance)
return indices
def _assign_data_closest_centroids(self):
# 创建空列表,用于存储每个簇
ddict = defaultdict(list)
for row in range(0, self._example_num):
# 遍历分出每个簇
closest_centroid = self.get_closeset_centroid(
self._data[row, :], self._centroids)
ddict[closest_centroid].append(row)
self._cluster_data_indices = ddict
def _update_centroids(self):
for i in range(self._k):
# 分别得到每个簇的索引
data_indices = self._cluster_data_indices[i]
# 分别得到每个簇新的质心
self._centroids[i] = np.mean(self._data[data_indices, :], axis=0)
def _plot_cluster(self):
plt.figure()
for i in range(0, self._k):
# 分别得到每个簇的索引
data_indices = self._cluster_data_indices[i]
# 画出每个簇的散点图
plt.scatter(self._data[data_indices, 0], self._data[data_indices, 1], s=10)
# 保存图
plt.savefig('means{}.png'.format(self._index))
def fit(self):
# 调整函数
# 初始化质心
self._random_init_centroid()
for i in range(self._max_iter):
self._index = 1
self._assign_data_closest_centroids()
self._update_centroids()
self._plot_cluster()
if __name__ == '__main__':
arra = np.random.uniform(1, 6, (200, 2))
arrb = np.random.uniform(4, 10, (200, 2))
arrc = np.random.uniform(9, 15, (200, 2))
arr = np.concatenate((arra, arrb, arrc), axis=0)
kmens = KMeans(arr, 3)
kmens.fit()
print("done")
版权声明:本文为k2325原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。