码迷,mamicode.com
首页 > 编程语言 > 详细

k-近邻分类的Python实现

时间:2014-05-26 18:25:13      阅读:196      评论:0      收藏:0      [点我收藏+]

标签:style   c   class   blog   code   java   

参见《机器学习实战》

bubuko.com,布布扣
 1 # -*- coding:cp936 -*-
 2 #===============================================================================
 3 # 设计KNN最近邻分类器:
 4 #     找出每个元素在数据集中的最近邻的K个数据,统计这K个数据所属的类,所属类最多的那个类就是该元素所属的类
 5 #===============================================================================
 6 import numpy as np
 7 
 8 def loadHaiLunData(f_name):
 9     with open(f_name) as fHandle:
10         fLines = fHandle.readlines()
11         dataLines = len(fLines)
12         label  = []
13         dataSetMat = np.zeros((dataLines,3))
14         for i in range(dataLines):
15             lineList = fLines[i].strip().split(\t)
16             dataSetMat[i,:] = lineList[0:3]
17             label.append(int(lineList[-1]))
18         return dataSetMat,label
19     
20 
21 def dataNorm(dataSet):
22     numOfEle = dataSet.shape[0]
23     minEle = dataSet.min(0)
24     maxEle = dataSet.max(0)
25     normedData = (dataSet-np.tile(minEle,(numOfEle,1)))/np.tile(maxEle-minEle,(numOfEle,1))
26     return normedData
27 
28 def classifyKnn(inX, dataSet, label, k):
29     #===========================================================================
30     # inX:输入向量
31     # dataSet:保存数据特征的数组,每一行为若干个特征的参数,与label对应
32     # label:表明当前这个数据集中的每一个元素属于哪一类
33     # k:设定最近邻的个数
34     #===========================================================================
35     
36     #首先对数据集进行归一化
37 #     dataSet = dataNorm(dataSet)
38     numOfEle = dataSet.shape[0]
39     index = 0
40     diffDistance = dataSet - np.tile(inX, (numOfEle,1))
41     diffDistance = diffDistance**2
42     squareDistance = diffDistance.sum(1)
43 #     squareDistance = squareDistance**0.5
44     knnIndex = squareDistance.argsort()
45     #统计最近的k个近邻的label,看哪个label类别最多就可将该训练元素判为对应类
46     staticDict = {}
47     for i in range(k):
48         staticDict[label[knnIndex[i]]]=staticDict.get(label[knnIndex[i]],0)+1
49     itemList = staticDict.items()
50     argmax = np.argmax(itemList, axis = 0)
51     return itemList[argmax[1]][0]
52     
53 def testHaiLunClassify(k = 3, hRatio = 0.5):
54     dataSet,label = loadHaiLunData(datingTestSet2.txt)
55 #     hRatio = 0.5
56     totalNum = dataSet.shape[0]
57     testNum = int(totalNum*hRatio)
58     dataNormed = dataNorm(dataSet)
59     errorClass = 0
60     for i in range(testNum):
61         classRes = classifyKnn(dataNormed[i,:], dataNormed[testNum:,:], label[testNum:], k)
62         if classRes != label[i]:
63             errorClass += 1
64 #             print "classify error, No. %d should be label %d but got %d"%(i, label[i],classRes)
65     errorRate = errorClass/float(testNum)
66 #     print
67 #     print "Error rate: %f"%(errorRate)
68     return errorRate
69 
70 if __name__ == __main__:
71     errorList = []
72     kRange = range(1,50,1)
73     for k in kRange:
74         errorList.append(testHaiLunClassify(k))
75     print errorList
76     import matplotlib.pyplot as  plt
77     fig = plt.figure(1)
78 #     ax  = fig.add_subplot(111)
79     plt.plot(kRange, errorList,rs-)
80     plt.show()
81     
bubuko.com,布布扣

 

k-近邻分类的Python实现,布布扣,bubuko.com

k-近邻分类的Python实现

标签:style   c   class   blog   code   java   

原文地址:http://www.cnblogs.com/mmhx/p/3752521.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!