码迷,mamicode.com
首页 > 编程语言 > 详细

Python实现K-means聚类算法

时间:2015-11-23 18:41:18      阅读:782      评论:0      收藏:0      [点我收藏+]

标签:

因为自己对python也有一定的了解,之前也用R做过一些数据分析,又恰好看到几篇文章介绍python实现算法的,觉得挺有意思,所以参考了一些书籍来自己实现一个K-means的聚类算法。《Python数据分析基础教程:NumPy学习指南(第2版)》和 《 Matplotlib手册》是做数据分析的挺不错的两个入门级教材,推荐给大家。

链接:http://pan.baidu.com/s/1FSheY 密码:ulsa

数据聚类是对于静态数据分析的一门技术,在许多领域内都被广泛地应用,包括机器学习数据挖掘模式识别图像分析、信息检索以及生物信息等。聚类是把相似的对象通过静态分类的方法分成不同的组别或者更多的子集,这样让在同一个子集中的成员对象都有相似的一些属性,常见的包括在坐标系中更加短的空间距离等。

一、K-means算法的简单介绍

K-means算法是很典型的基于距离的聚类算法,采用距离作为相似性的评价指标,即认为两个对象的距离越近,其相似度就越大。该算法认为簇是由距离靠近的对象组成的,因此把得到紧凑且独立的簇作为最终目标。

用一个简单的例子来说明K-means算法的原理。如下图所展示,将n个样本点随机的分布在二维图表中,从点的聚集和松散程度可以大致将点分为三个簇,分别用不同的颜色表示出来。

技术分享

K-menas算法首先选择K个初始质心,其中K是用户指定的参数,即所期望的簇的个数。这样做的前提是我们已经知道数据集中包含多少个簇,但很多情况下,我们并不知道数据的分布情况,所以只能假设有K个簇,然后做进一步的迭代求平方和来判断各聚类(簇)是否为最优。

K-means算法是一个基于距离的迭代算法,下面是K-means的代价函数:
技术分享

式中的J函数表示每个样本点到其质心的距离平方和,μc(i)表示第i个聚类的均值,只需对所有类所得到的误差平方求和,即可验证分为k个类时,各聚类是否是最优的。其实这里只需将J函数开方就得到了欧氏距离,更加的直观的想到J函数的值越小,各类内的样本越相似。

下面是K-means算法的求解过程:

1、随机选取 k个聚类质心点

2、重复下面过程直到收敛  {

      对于每一个样例 i,计算其应该属于的类:

技术分享

      对于每一个类 j,重新计算该类的质心:

技术分享

}

这其实就是一个多次迭代的过程:

1)在所有的样本点中随机抽取出k个聚类中心点

(2)遍历其他的观测点到聚类中心点的最近距离,得到这个点,将其加入到该聚类中

(3)计算每个聚类的平均值,并作为新的中心点

(4)重复(2)-(3),直到这k个中心点不再变化(收敛了)

 

下图展示了对n个样本点进行K-means聚类的效果,这里k取2。

技术分享

 

二、代码实现K-means算法

1、创建K个观测点作为聚类的中心点

2、计算质心到观测点的距离

3、将观测点分配到最近的簇

4、对每一个簇,计算簇中所有点的均值,并将均值作为质心

 

 下面的代码用到了Python中的两个库Numpy和Matplotlib

 1 from numpy import *  
 2 import time  
 3 import matplotlib.pyplot as plt  
 4   
 5   
 6 # 判读两点之间的距离  
 7 def euclDistance(x1, x2):  
 8     return sqrt(sum(power(x2 - x1, 2)))  
 9   
10 # 初始化质心与随机样本点 
11 def initCentroids(data, k):
12 len_line, dim = data.shape #读取矩阵的长度,返回一个(len_line,dim)的元组 13 shuzu = zeros((k, dim)) #创建k行dim列的数组 14 for i in range(k): 15 index = int(random.uniform(0, len_line)) #uniform()方法将随机生成下一个实数,它在[x,y]范围内。 16 shuzu[i, :] = data[index, :] 17 return shuzu 18 19 # k-means cluster 20 def kmeans(data, k): 21 len_line = data.shape[0] #读取数组第一行的长度 22 23 24 clusterAssment = mat(zeros((len_line, 2))) 25 clusterChanged = True 26 27 ## step 1: init centroids 28 centroids = initCentroids(data, k) 29 30 while clusterChanged: 31 clusterChanged = False 32 ## for each sample 33 for i in xrange(len_line): 34 minDist = 100000.0 35 minIndex = 0 36 ## for each centroid 37 ## step 2: find the centroid who is closest 38 for j in range(k): 39 distance = euclDistance(centroids[j, :], data[i, :]) 40 if distance < minDist: 41 minDist = distance 42 minIndex = j 43 44 ## step 3: update its cluster 45 if clusterAssment[i, 0] != minIndex: 46 clusterChanged = True 47 clusterAssment[i, :] = minIndex, minDist**2 48 49 ## step 4: update centroids 50 for j in range(k): 51 pointsInCluster = data[nonzero(clusterAssment[:, 0].A == j)[0]] 52 centroids[j, :] = mean(pointsInCluster, axis = 0) 53 54 print Congratulations, cluster complete! 55 return centroids, clusterAssment 56 57 # show your cluster only available with 2-D data 58 def showCluster(data, k, centroids, clusterAssment): 59 len_line, dim = data.shape 60 if dim != 2: 61 print "Sorry! I can not draw because the dimension of your data is not 2!" 62 return 1 63 64 mark = [or, ob, og, ok, ^r, +r, sr, dr, <r, pr] 65 if k > len(mark): 66 print "Sorry! Your k is too large! please contact Zouxy" 67 return 1 68 69 # draw all samples 70 for i in xrange(len_line): 71 markIndex = int(clusterAssment[i, 0]) 72 plt.plot(data[i, 0], data[i, 1], mark[markIndex]) 73 74 mark = [Dr, Db, Dg, Dk, ^b, +b, sb, db, <b, pb] 75 # draw the centroids 76 for i in range(k): 77 plt.plot(centroids[i, 0], centroids[i, 1], mark[i], markersize = 12) 78 79 plt.show()

三、测试结果

 下面是80个样本,分成4个类

    1.658985    4.285136  
    -3.453687   3.424321  
    4.838138    -1.151539  
    -5.379713   -3.362104  
    0.972564    2.924086  
    -3.567919   1.531611  
    0.450614    -3.302219  
    -3.487105   -1.724432  
    2.668759    1.594842  
    -3.156485   3.191137  
    3.165506    -3.999838  
    -2.786837   -3.099354  
    4.208187    2.984927  
    -2.123337   2.943366  
    0.704199    -0.479481  
    -0.392370   -3.963704  
    2.831667    1.574018  
    -0.790153   3.343144  
    2.943496    -3.357075  
    -3.195883   -2.283926  
    2.336445    2.875106  
    -1.786345   2.554248  
    2.190101    -1.906020  
    -3.403367   -2.778288  
    1.778124    3.880832  
    -1.688346   2.230267  
    2.592976    -2.054368  
    -4.007257   -3.207066  
    2.257734    3.387564  
    -2.679011   0.785119  
    0.939512    -4.023563  
    -3.674424   -2.261084  
    2.046259    2.735279  
    -3.189470   1.780269  
    4.372646    -0.822248  
    -2.579316   -3.497576  
    1.889034    5.190400  
    -0.798747   2.185588  
    2.836520    -2.658556  
    -3.837877   -3.253815  
    2.096701    3.886007  
    -2.709034   2.923887  
    3.367037    -3.184789  
    -2.121479   -4.232586  
    2.329546    3.179764  
    -3.284816   3.273099  
    3.091414    -3.815232  
    -3.762093   -2.432191  
    3.542056    2.778832  
    -1.736822   4.241041  
    2.127073    -2.983680  
    -4.323818   -3.938116  
    3.792121    5.135768  
    -4.786473   3.358547  
    2.624081    -3.260715  
    -4.009299   -2.978115  
    2.493525    1.963710  
    -2.513661   2.642162  
    1.864375    -3.176309  
    -3.171184   -3.572452  
    2.894220    2.489128  
    -2.562539   2.884438  
    3.491078    -3.947487  
    -2.565729   -2.012114  
    3.332948    3.983102  
    -1.616805   3.573188  
    2.280615    -2.559444  
    -2.651229   -3.103198  
    2.321395    3.154987  
    -1.685703   2.939697  
    3.031012    -3.620252  
    -4.599622   -2.185829  
    4.196223    1.126677  
    -2.133863   3.093686  
    4.668892    -2.562705  
    -2.793241   -2.149706  
    2.884105    3.043438  
    -2.967647   2.848696  
    4.479332    -1.764772  
    -4.905566   -2.911070  

 

 测试代码:

#################################################  
# kmeans: k-means cluster  
# Author : zouxy  
# Date   : 2013-12-25  
# HomePage : http://blog.csdn.net/zouxy09  
# Email  : zouxy09@qq.com  
#################################################  
  
from numpy import *  
import time  
import matplotlib.pyplot as plt  
  
## step 1: load data  
print "step 1: load data..."  
dataSet = []  
fileIn = open(E:/Python/Machine Learning in Action/testSet.txt)  
for line in fileIn.readlines():  
    lineArr = line.strip().split(\t)  
    dataSet.append([float(lineArr[0]), float(lineArr[1])])  
  
## step 2: clustering...  
print "step 2: clustering..."  
dataSet = mat(dataSet)  
k = 4  
centroids, clusterAssment = kmeans(dataSet, k)  
  
## step 3: show the result  
print "step 3: show the result..."  
showCluster(dataSet, k, centroids, clusterAssment) 

 

 运行的前后结果是:

技术分享

 不同的类用不同的颜色来表示,其中的大菱形是对应类的均值质心点。

Python实现K-means聚类算法

标签:

原文地址:http://www.cnblogs.com/xbkp/p/4988730.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!