A Step-by-Step Guide to Implementing Machine Learning X: KMeans
Published: 2019-05-20




Visit  for the complete list of articles in this series.

Introduction

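KMeans is a simple clustering algorithm that forms clusters by computing the distances between samples and centroids. K is the number of clusters, given by the user. Initially, K centroids are chosen at random, and they are adjusted at each iteration to improve the result. A centroid is the mean of the samples in its cluster, hence the name K-"means".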
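As a rough illustration of that loop, here is a minimal, self-contained NumPy sketch of plain KMeans (Lloyd's iterations). It is not the author's class-based implementation shown later: the name kmeans_sketch, the use of k distinct random samples as initial centroids, and the n_iter cap are assumptions made for this example.

```python
import numpy as np


def kmeans_sketch(X, k, n_iter=100, seed=0):
    """Plain KMeans (Lloyd's iterations). Returns (centers, labels)."""
    rng = np.random.default_rng(seed)
    # assumption: initialize with k distinct samples picked at random
    centers = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    labels = np.full(len(X), -1)
    for _ in range(n_iter):
        # distance from every sample to every center, shape (n_samples, k)
        d = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
        new_labels = np.argmin(d, axis=1)
        if np.array_equal(new_labels, labels):
            break  # no sample changed cluster: converged
        labels = new_labels
        # move each centroid to the mean of the samples assigned to it
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:
                centers[j] = members.mean(axis=0)
    return centers, labels


X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [10.0, 10.0], [10.1, 10.0], [10.0, 10.1]])
centers, labels = kmeans_sketch(X, 2)
print(labels)
```

On two well-separated groups like these, each group should end up in its own cluster; which group is labeled 0 and which 1 depends on the random initialization.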

KMeans Models

KMeans

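Plain KMeans is very simple. Denote the K clusters as $C_1, C_2, \ldots, C_K$ and the number of samples in cluster $C_k$ as $N_k$. With $\mu_k$ the centroid of $C_k$, the loss function of KMeans is:

$$J = \sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert^2$$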

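Take the derivative of the loss function with respect to $\mu_k$:

$$\frac{\partial J}{\partial \mu_k} = -2 \sum_{x_i \in C_k} (x_i - \mu_k)$$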

Setting the derivative to 0, we get:

$$\mu_k = \frac{1}{N_k} \sum_{x_i \in C_k} x_i$$

That is, each centroid is the mean of the samples in its cluster. The KMeans code is shown below:

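import numpy as np

def kmeans(self, train_data, k):
    sample_num = len(train_data)
    distances = np.zeros([sample_num, 2])  # per sample: (cluster index, distance to its center)
    # createCenter (not shown) builds the initial centers; it is assumed here to accept k,
    # so that callers such as biKmeans (which passes k=2) get k centers rather than self.k
    centers = self.createCenter(train_data, k)
    # pass k (not self.k) so the update loop matches the requested number of clusters
    centers, distances = self.adjustCluster(centers, distances, train_data, k)
    return centers, distances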

Once the initial centroids have been determined, adjustCluster() moves them to minimize the loss function. The code of adjustCluster is:

def adjustCluster(self, centers, distances, train_data, k):
    sample_num = len(train_data)
    flag = True  # if True, some assignment changed and the centers must be updated
    while flag:
        flag = False
        d = np.zeros([sample_num, len(centers)])
        for i in range(len(centers)):
            # calculate the distance between each sample and each cluster center
            d[:, i] = self.calculateDistance(train_data, centers[i])
        # find the nearest cluster center for each sample
        old_label = distances[:, 0].copy()
        distances[:, 0] = np.argmin(d, axis=1)
        distances[:, 1] = np.min(d, axis=1)
        # compare labels element-wise: summing the differences can cancel out to 0
        if np.any(old_label != distances[:, 0]):
            flag = True
            # update each cluster center to the mean of its cluster
            for j in range(k):
                current_cluster = train_data[distances[:, 0] == j]  # samples of the j-th center
                if len(current_cluster) != 0:
                    centers[j, :] = np.mean(current_cluster, axis=0)
    return centers, distances
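The derivation of the centroid update earlier in this section can also be checked numerically. The sketch below uses one synthetic cluster (random data and the helper name cluster_loss are assumptions for illustration). It confirms that the gradient $-2\sum (x_i - \mu_k)$ vanishes at the cluster mean and that shifting the centroid away from the mean increases the loss; in fact $J(\mu) = J(\bar{x}) + N_k \lVert \mu - \bar{x} \rVert^2$, so any nonzero shift strictly increases it.

```python
import numpy as np


def cluster_loss(points, center):
    """Sum of squared Euclidean distances from every point to one center."""
    return float(np.sum((points - center) ** 2))


rng = np.random.default_rng(42)
points = rng.normal(size=(50, 3))  # one synthetic cluster: N_k = 50 samples, 3 features
mean = points.mean(axis=0)         # closed-form optimum: mu_k = (1 / N_k) * sum of x_i

# gradient of the loss at the mean, -2 * sum(x_i - mu_k); zero up to rounding
grad_at_mean = -2 * np.sum(points - mean, axis=0)

# loss at a few randomly shifted centers; each should exceed the loss at the mean
shifted_losses = [cluster_loss(points, mean + 0.1 * rng.normal(size=3)) for _ in range(5)]
print(np.abs(grad_at_mean).max(), cluster_loss(points, mean), min(shifted_losses))
```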

Bisecting KMeans

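Because KMeans may converge to a local optimum, we introduce another KMeans variant, bisecting KMeans, to alleviate this problem. In bisecting KMeans, all samples are initially treated as one cluster, and that cluster is split into two. Then one of the resulting clusters is chosen and bisected again, and this repeats until there are K clusters. The cluster to bisect is chosen by the minimum sum of squared errors (SSE). Denote the current n clusters as:

$$C = \{c_1, c_2, \ldots, c_n\}$$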

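We choose a cluster $c_i$ in $C$ and bisect it into two parts, $c_{i1}$ and $c_{i2}$, using plain KMeans. The SSE after this split is:

$$SSE_i = SSE(c_{i1}, c_{i2}) + SSE(C - c_i)$$

where $SSE(\cdot)$ sums the squared distances between the samples and the centroids of the clusters they belong to.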

We choose the $c_i$ that yields the minimum SSE as the cluster to bisect, that is:

$$index = \arg\min_{i} SSE_i$$

Repeat the process above until there are K clusters. The code of biKmeans is:

def biKmeans(self, train_data):
    sample_num = len(train_data)
    distances = np.zeros([sample_num, 2])         # per sample: (cluster index, distance to its center)
    initial_center = np.mean(train_data, axis=0)  # initial cluster center, shape (feature_dim,)
    centers = [initial_center]                    # list of cluster centers
    # all samples start in cluster 0; store plain (unsquared) distances,
    # consistent with what adjustCluster stores
    distances[:, 1] = self.calculateDistance(train_data, initial_center)
    # generate cluster centers
    while len(centers) < self.k:
        min_SSE = np.inf
        best_index = None
        best_centers = None
        best_distances = None
        # find the best split
        for j in range(len(centers)):
            centerj_data = train_data[distances[:, 0] == j]  # samples belonging to the j-th center
            if len(centerj_data) < 2:
                continue                                     # a single sample cannot be bisected
            split_centers, split_distances = self.kmeans(centerj_data, 2)
            # SSE is the sum of squared distances, not the square of the summed distances
            split_SSE = np.sum(np.power(split_distances[:, 1], 2))  # SSE of the two new clusters
            other_distances = distances[distances[:, 0] != j]       # samples outside the j-th cluster
            other_SSE = np.sum(np.power(other_distances[:, 1], 2))  # SSE of the remaining clusters
            # save the best split result
            if (split_SSE + other_SSE) < min_SSE:
                best_index = j
                best_centers = split_centers
                best_distances = split_distances
                min_SSE = split_SSE + other_SSE
        if best_index is None:
            break                                            # no cluster can be split any further
        # save the split: part 1 becomes a new cluster, part 0 keeps index best_index
        best_distances[best_distances[:, 0] == 1, 0] = len(centers)
        best_distances[best_distances[:, 0] == 0, 0] = best_index
        centers[best_index] = best_centers[0, :]
        centers.append(best_centers[1, :])
        distances[distances[:, 0] == best_index, :] = best_distances
    centers = np.array(centers)  # transform from list to array
    return centers, distances
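For comparison, the following standalone sketch implements the same bisecting idea on a plain label array rather than the author's (index, distance) matrix. The names two_means, sse and bisecting_kmeans are hypothetical, and each candidate cluster is split with a few iterations of plain 2-means seeded from two random samples (an assumption). The candidate that minimizes $SSE_i$, the SSE of the two halves plus the SSE of all untouched clusters, is the one that gets split.

```python
import numpy as np


def two_means(X, n_iter=50, seed=0):
    """Split X into two clusters with plain KMeans; returns labels in {0, 1}."""
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=2, replace=False)].astype(float)
    labels = np.zeros(len(X), dtype=int)
    for _ in range(n_iter):
        d = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
        labels = np.argmin(d, axis=1)
        # new centroid = mean of its samples (keep the old one if the cluster is empty)
        new_centers = np.array([X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                                for j in range(2)])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


def sse(X):
    """Sum of squared errors of one cluster around its own mean."""
    return float(np.sum((X - X.mean(axis=0)) ** 2)) if len(X) else 0.0


def bisecting_kmeans(X, k):
    labels = np.zeros(len(X), dtype=int)  # every sample starts in cluster 0
    while labels.max() + 1 < k:
        n = labels.max() + 1  # current number of clusters
        best = None           # (SSE_i, i, split labels) of the best candidate so far
        for i in range(n):
            members = labels == i
            if members.sum() < 2:
                continue  # a single sample cannot be bisected
            split = two_means(X[members])
            if len(np.unique(split)) < 2:
                continue  # degenerate split (e.g. duplicate points)
            part = X[members]
            # SSE_i = SSE of the two halves + SSE of every other, unsplit cluster
            total = sse(part[split == 0]) + sse(part[split == 1])
            total += sum(sse(X[labels == j]) for j in range(n) if j != i)
            if best is None or total < best[0]:
                best = (total, i, split)
        if best is None:
            break  # nothing left that can be split
        _, i, split = best
        idx = np.flatnonzero(labels == i)
        labels[idx[split == 1]] = n  # the second half becomes cluster n
    centers = np.array([X[labels == j].mean(axis=0) for j in range(labels.max() + 1)])
    return centers, labels


X = np.array([[0.0, 0.0], [0.0, 0.2], [0.2, 0.0],
              [5.0, 5.0], [5.0, 5.2], [5.2, 5.0],
              [10.0, 0.0], [10.0, 0.2], [10.2, 0.0]])
centers, labels = bisecting_kmeans(X, 3)
print(labels)
```

With three well-separated groups of three points each, as here, the first bisection separates one group from the other two, and the second bisection splits the pair, since that lowers the total SSE far more than splitting a single tight group.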

KMeans++

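Because the initial centroids strongly affect the performance of KMeans, we introduce another KMeans variant, KMeans++, to address this. Denote the current n clusters as:

$$C = \{c_1, c_2, \ldots, c_n\}$$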

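When choosing the (n + 1)-th centroid, the farther a sample is from the existing centroids, the more likely it is to be chosen as the new centroid. First, compute the minimum distance between each sample and the existing centroids:

$$D(x_i) = \min_{c_r \in C} \lVert x_i - c_r \rVert$$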

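Then, compute the probability that each sample is chosen as the next centroid:

$$P(x_i) = \frac{D(x_i)^2}{\sum_{j} D(x_j)^2}$$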
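The two formulas above translate directly into NumPy. This sketch (kmeanspp_probabilities is a hypothetical name) computes $D(x)$ for every sample against the existing centroids and normalizes $D(x)^2$ into selection probabilities. It assumes at least one sample lies away from the current centroids; otherwise the denominator is zero.

```python
import numpy as np


def kmeanspp_probabilities(X, centers):
    """P(x) = D(x)^2 / sum over all samples of D(x)^2, given the current centers."""
    centers = np.atleast_2d(centers)
    # D(x): distance from each sample to its nearest existing center
    d = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2).min(axis=1)
    return d ** 2 / np.sum(d ** 2)


X = np.array([[0.0], [1.0], [3.0]])
p_one = kmeanspp_probabilities(X, np.array([[0.0]]))         # one centroid at 0
p_two = kmeanspp_probabilities(X, np.array([[0.0], [3.0]]))  # centroids at 0 and 3
print(p_one, p_two)
```

For samples at 0, 1 and 3 on a line and a single centroid at 0, $D(x)$ is 0, 1 and 3, so the probabilities are 0, 0.1 and 0.9: the farthest sample is nine times as likely to be picked as the middle one, and a sample that already is a centroid can never be picked again.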

Then, roulette wheel selection is used to pick the next centroid. Once K centroids have been chosen, adjustCluster() is run to refine the result.

def kmeansplusplus(self, train_data):
    sample_num = len(train_data)
    distances = np.zeros([sample_num, 2])  # per sample: (cluster index, distance to its center)
    # randomly select a sample as the initial cluster center
    # (randint's upper bound is exclusive, so pass sample_num to allow every sample)
    initial_center = train_data[np.random.randint(0, sample_num)]
    centers = [initial_center]
    while len(centers) < self.k:
        d = np.zeros([sample_num, len(centers)])
        for i in range(len(centers)):
            # calculate the distance between each sample and each cluster center
            d[:, i] = self.calculateDistance(train_data, centers[i])
        # D(x): minimum distance between each sample and the existing centers
        distances[:, 0] = np.argmin(d, axis=1)
        distances[:, 1] = np.min(d, axis=1)
        # roulette wheel selection with P(x) = D(x)^2 / sum of D(x)^2
        prob = np.power(distances[:, 1], 2) / np.sum(np.power(distances[:, 1], 2))
        index = self.rouletteWheelSelection(prob, sample_num)
        new_center = train_data[index, :]
        centers.append(new_center)
    # adjust cluster
    centers = np.array(centers)  # transform from list to array
    centers, distances = self.adjustCluster(centers, distances, train_data, self.k)
    return centers, distances
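The code above calls self.rouletteWheelSelection, which the article does not show. A hypothetical stand-in could look like this: lay the probabilities end to end on a wheel, spin a uniform random number, and return the slice it lands in. Its signature is an assumption and differs from the author's helper (it takes an optional NumPy Generator instead of sample_num). NumPy's Generator.choice(len(prob), p=prob) does the same job in one call.

```python
import numpy as np


def roulette_wheel_selection(prob, rng=None):
    """Return index i with probability prob[i]; prob is assumed to sum to 1."""
    rng = np.random.default_rng() if rng is None else rng
    cumulative = np.cumsum(prob)       # right edge of each sample's slice of the wheel
    r = rng.random() * cumulative[-1]  # spin: a uniform point in [0, total)
    # first slice whose right edge lies strictly beyond r; zero-width slices are skipped
    i = int(np.searchsorted(cumulative, r, side="right"))
    return min(i, len(prob) - 1)       # guard against floating-point round-off at the end


rng = np.random.default_rng(0)
draws = [roulette_wheel_selection([0.0, 0.1, 0.9], rng) for _ in range(2000)]
counts = np.bincount(draws, minlength=3)
print(counts)
```

With probabilities 0, 0.1 and 0.9, index 0 is never drawn and index 2 is drawn roughly 90% of the time.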

Conclusion and Analysis

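In practice, the clusters also need to be adjusted once the parameter K has been chosen, and several algorithms exist for choosing K. Finally, let's compare the performance of the three clustering algorithms.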

In the author's comparison, KMeans++ performs best.

The code and dataset related to this article can be found at .

Interested readers can also check out  and .

 

Original article:

Reposted from: http://gizhj.baihongyu.com/
