网上介绍K-近邻算法的例子很多,其Python实现版本基本都是来自于机器学习的入门书籍《机器学习实战》,虽然K-近邻算法本身很简单,但很多初学者对其Python版本的源代码理解不够,所以本文将对其源代码进行分析。
什么是K-近邻算法?
简单的说,K-近邻算法采用不同特征值之间的距离方法进行分类。所以它是一个分类算法。
优点:无数据输入假定,对异常值不敏感
缺点:复杂度高
好了,直接先上代码,等会在分析:(这份代码来自《机器学习实战》)
def classify0(inx, dataset, lables, k): dataSetSize = dataset.shape[0] diffMat = tile(inx, (dataSetSize, 1)) - dataset sqDiffMat = diffMat**2 sqDistance = sqDiffMat.sum(axis=1) distances = sqDistance**0.5 sortedDistances = distances.argsort() classCount={} for i in range(k): label = lables[sortedDistances[i]] classCount[label] = classCount.get(label, 0) + 1 sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0]
该函数的原理是:
存在一个样本数据集合,也称为训练集,在样本集中每个数据都存在标签。在我们输入没有标签的新数据后,将新数据的每个特征与样本集中对应的特征进行比较,然后提取最相似(最近邻)的分类标签。一般我们只选样本数据集中前K 个最相似的数据。最后,出现次数最多的分类就是新数据的分类。
classify0函数的参数意义如下:
inx : 是输入没有标签的新数据,表示为一个向量。
dataset: 是样本集。表示为向量数组。
labels:对应样本集的标签。
k:即所选的前K。
用于产生数据样本的简单函数:
def create_dataset(): group = array([[1.0, 1.1], [1.0, 1.1], [0, 0], [0, 0.1]]) labels = ['A', 'A', 'B', 'B'] return group, labels
注意,array是numpy里面的。我们需要实现import进来。
from numpy import * import operator
我们在调用时,
group,labels = create_dataset() result = classify0([0,0], group, labels, 3) print result
知道了这些,初学者应该对实际代码还是很陌生。不急,正文开始了!
源码分析
dataSetSize = dataset.shape[0]
shape是array的属性,它描述了一个数组的“形状”,也就是它的维度。比如,
In [2]: dataset = array([[1.0, 1.1], [1.0, 1.1], [0, 0], [0, 0.1]]) In [3]: print dataset.shape (4, 2)
diffMat = tile(inx, (dataSetSize, 1)) - dataset
我们看看tile(inx, (4, 1))的结果,
In [5]: tile(x, (4, 1)) Out[5]: array([[0, 0], [0, 0], [0, 0], [0, 0]])
为证实上面的结论,
In [6]: tile(x,(4,2)) Out[6]: array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
In [7]: tile(x,(2,2)) Out[7]: array([[0, 0, 0, 0], [0, 0, 0, 0]])
得到tile后,减去dataset。这类似一个矩阵的减法,结果仍是一个 4 * 2的数组。
In [8]: tile(x, (4, 1)) - dataset Out[8]: array([[-1. , -1.1], [-1. , -1.1], [ 0. , 0. ], [ 0. , -0.1]])
我们看看求和的方法,
sqDiffMat.sum(axis=1)
In [14]: sqDiffmat Out[14]: array([[ 1. , 1.21], [ 1. , 1.21], [ 0. , 0. ], [ 0. , 0.01]])
求和的结果是对行求和,是一个N*1的数组。
如果要对列求和,
sqlDiffMat.sum(axis=0)
classCount是一个字典,key是标签,value是该标签出现的次数。
这样,算法的一些具体代码细节就清楚了。
原文地址:http://blog.csdn.net/chenloveit/article/details/39969245