码迷,mamicode.com
首页 > 编程语言 > 详细

KNN算法的感受 1

时间:2016-04-10 01:01:13      阅读:279      评论:0      收藏:0      [点我收藏+]

标签:

本来预计的打算是一天一个十大挖掘算法,然而由于同时要兼顾数据结构面试的事情,所以 很难办到,但至少在回家前要把数据挖掘十大算法看完,过个好年,在course上学习老吴的课程还是帮了我很大的忙,虽然浪费了时间,但是也无形中帮助我 很多,所以说还是很值得的,今天就总结KNN算法的一部分,这部分老吴的课程中没有太多涉及到,所以我又重新关注了一下,下面是我的总结,希望能对大家有 所帮组。

     介绍环镜:python2.7  IDLE  Pycharm5.0.3

     操作系统:windows

    第一步:因为没有numpy,所以要安装numpy,详情见另一篇安装numpy的博客,这里不再多说.

    第二步:贴代码:

 1     from numpy import *  
 2     import operator  
 3     from os import listdir
 5   def classify0(inX, dataSet, labels, k):
 6         dataSetSize = dataSet.shape[0]  
 7         diffMat = tile(inX, (dataSetSize,1)) - dataSet  
 8         sqDiffMat = diffMat**2  
 9         sqDistances = sqDiffMat.sum(axis=1)  
10         distances = sqDistances**0.5  
11         sortedDistIndicies = distances.argsort()       
12         classCount={}            
13         for i in range(k):  
14             voteIlabel = labels[sortedDistIndicies[i]]  
15             classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1  
16         sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  
17         return sortedClassCount[0][0]  
18       
19   def createDataSet():
20         group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])  
21         labels = [A,A,B,B]  
22         return group, labels  
23       
24    def file2matrix(filename):  
25         fr = open(filename)  
26         numberOfLines = len(fr.readlines())         #get the number of lines in the file  
27         returnMat = zeros((numberOfLines,3))        #prepare matrix to return  
28         classLabelVector = []                       #prepare labels return     
29         fr = open(filename)  
30         index = 0  
31         for line in fr.readlines():  
32             line = line.strip()  
33             listFromLine = line.split(\t)  
34             returnMat[index,:] = listFromLine[0:3]  
35             classLabelVector.append(int(listFromLine[-1]))  
36             index += 1  
37         return returnMat,classLabelVector  
38           
39     def autoNorm(dataSet):
40         minVals = dataSet.min(0)  
41         maxVals = dataSet.max(0)  
42         ranges = maxVals - minVals  
43         normDataSet = zeros(shape(dataSet))  
44         m = dataSet.shape[0]  
45         normDataSet = dataSet - tile(minVals, (m,1))  
46         normDataSet = normDataSet/tile(ranges, (m,1))   #element wise divide  
47         return normDataSet, ranges, minVals  
48          
49     def datingClassTest():
50         hoRatio = 0.50      #hold out 10%  
51         datingDataMat,datingLabels = file2matrix(datingTestSet2.txt)       #load data setfrom file  
52         normMat, ranges, minVals = autoNorm(datingDataMat)  
53         m = normMat.shape[0]  
54         numTestVecs = int(m*hoRatio)  
55         errorCount = 0.0  
56         for i in range(numTestVecs):  
57             classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)  
58             print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])  
59             if (classifierResult != datingLabels[i]): errorCount += 1.0  
60         print "the total error rate is: %f" % (errorCount/float(numTestVecs))  
61         print errorCount  
62           
63     def img2vector(filename): 
64         returnVect = zeros((1,1024))  
65         fr = open(filename)  
66         for i in range(32):  
67             lineStr = fr.readline()  
68             for j in range(32):  
69                 returnVect[0,32*i+j] = int(lineStr[j])  
70         return returnVect  
71       
72     def handwritingClassTest():  
73         hwLabels = []  
74         trainingFileList = listdir(trainingDigits)           #load the training set  
75         m = len(trainingFileList)  
76         trainingMat = zeros((m,1024))  
77         for i in range(m):  
78             fileNameStr = trainingFileList[i]  
79             fileStr = fileNameStr.split(.)[0]     #take off .txt  
80             classNumStr = int(fileStr.split(_)[0])  
81             hwLabels.append(classNumStr)  
82             trainingMat[i,:] = img2vector(trainingDigits/%s % fileNameStr)  
83         testFileList = listdir(testDigits)        #iterate through the test set  
84         errorCount = 0.0  
85         mTest = len(testFileList)  
86         for i in range(mTest):  
87             fileNameStr = testFileList[i]  
88             fileStr = fileNameStr.split(.)[0]     #take off .txt  
89             classNumStr = int(fileStr.split(_)[0])  
90             vectorUnderTest = img2vector(testDigits/%s % fileNameStr)  
91             classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)  
92             print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)  
93             if (classifierResult != classNumStr): errorCount += 1.0  
94         print "\nthe total number of errors is: %d" % errorCount  
95         print "\nthe total error rate is: %f" % (errorCount/float(mTest))  

  第三步:通过命令行交互

      (1):先将上述代码保存为kNN.py

      (2):再在IDLE下的run菜单下run一下,将其生成python模块

      (3): import  kNN(因为上一步已经生成knn模块)
      (4): kNN.classify0([0,0],group,labels,3) (讨论[0,0]点属于哪一个类)

   注:其中【0,0】可以随意换

即【】内的坐标就是我们要判断的点的坐标:

>>> kNN.classify0([0,0],group,labels,3)
‘B‘
>>> kNN.classify0([0,1],group,labels,3)
‘B‘
>>> kNN.classify0([0.6,0.6],group,labels,3)
‘A‘

 

KNN算法的感受 1

标签:

原文地址:http://www.cnblogs.com/hellochennan/p/5373123.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!