标签:strip orm split line tor 重复 sort 测试 %s
机器学习实战之K-近邻算法:
KNN算法:在带有分类标签的已知数据集中,找出与待预测点距离最近的K个点,然后统计这K个点的分类,其中所占比例最高的分类即为该待预测点的分类。
# k-nearest-neighbours (kNN) classifier plus two small evaluation drivers.
# The star import is retained so code that relied on the names this module
# re-exported keeps working; the functions below use the explicit `np` alias.
from numpy import *  # noqa: F401,F403
import operator  # noqa: F401 - retained for backward compatibility
import os
from collections import Counter

import numpy as np

try:
    import matplotlib
    import matplotlib.pyplot as plt
except ImportError:
    # Nothing in this module plots, so the classifier stays usable on
    # machines where matplotlib is not installed.
    matplotlib = None
    plt = None

# Digit images are stored as 32 text lines of 32 '0'/'1' characters.
_IMG_SIDE = 32


def createDataSet():
    """Return a four-point toy data set.

    Returns:
        (group, labels): ``group`` is a (4, 2) float array of points and
        ``labels`` the list of their class labels ('A' or 'B').
    """
    group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest samples.

    Args:
        inX: feature vector to classify (1-D, or a single row of 2-D data).
        dataSet: 2-D array-like of training samples, one sample per row.
        labels: sequence of class labels, ``labels[i]`` belongs to row ``i``.
        k: number of nearest neighbours that vote.

    Returns:
        The label with the most votes. On a tie, the tied label that was
        encountered first (i.e. belongs to the closest neighbour) wins.

    Raises:
        IndexError: if ``k`` is larger than the number of samples, or if no
            vote was cast (``k`` < 1).
    """
    dataSet = np.asarray(dataSet, dtype=float)
    # Broadcasting subtracts inX from every row without materialising the
    # repeated copy that tile() would build.
    diffMat = np.asarray(inX, dtype=float) - dataSet
    # Euclidean distance from inX to each training sample.
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    # argsort() yields sample indices ordered from nearest to farthest.
    sortedDistIndices = distances.argsort()
    classCount = Counter()
    for i in range(k):
        classCount[labels[sortedDistIndices[i]]] += 1
    return classCount.most_common(1)[0][0]


def file2mat(filename, numFeatures=3):
    """Parse a tab-separated data file into a feature matrix and labels.

    Every non-blank line holds the feature values first and an integer
    class label in its last column.

    Args:
        filename: path of the text file to read.
        numFeatures: number of leading columns to read as features.

    Returns:
        (returnMat, classLabelVector): a (rows, numFeatures) float array and
        the list of integer labels, in file order.

    Raises:
        ValueError: if a line has fewer than ``numFeatures`` columns or a
            field cannot be converted to a number.
    """
    rows = []
    classLabelVector = []
    with open(filename) as fr:
        for line in fr:
            line = line.strip()
            if not line:
                # Tolerate blank lines (typically a trailing newline).
                continue
            listFromLine = line.split('\t')
            if len(listFromLine) < numFeatures:
                raise ValueError(
                    "expected at least %d columns, got %d: %r"
                    % (numFeatures, len(listFromLine), line)
                )
            rows.append([float(v) for v in listFromLine[:numFeatures]])
            classLabelVector.append(int(listFromLine[-1]))
    # reshape keeps the result 2-D even when the file has no data lines.
    returnMat = np.array(rows, dtype=float).reshape(-1, numFeatures)
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """Min-max normalise every column of ``dataSet`` to the range [0, 1].

    Args:
        dataSet: 2-D array-like, one sample per row.

    Returns:
        (normDataSet, ranges, minVals): the normalised data, the per-column
        ``max - min`` and the per-column minimum.
    """
    dataSet = np.asarray(dataSet, dtype=float)
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # A constant column has range 0; divide by 1 there so it normalises to
    # 0 instead of NaN. The returned `ranges` keeps the true values.
    safeRanges = np.where(ranges == 0, 1.0, ranges)
    normDataSet = (dataSet - minVals) / safeRanges
    return normDataSet, ranges, minVals


def datingTest(filename="datingTestSet2.txt", hoRatio=0.1, k=3):
    """Hold-out evaluation of the classifier on a dating data file.

    The first ``hoRatio`` fraction of the rows is classified against the
    remaining rows; each prediction and the final error rate are printed.

    Args:
        filename: data file in the format accepted by ``file2mat``.
        hoRatio: fraction of rows (taken from the top) used for testing.
        k: number of neighbours passed to ``classify0``.

    Returns:
        The error rate as a float; 0.0 when the hold-out set is empty.
    """
    datingDataMat, datingLabels = file2mat(filename)
    normMat, _, _ = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVec = int(m * hoRatio)
    # The training part is the same for every test row; slice it once.
    trainingMat = normMat[numTestVec:m, :]
    trainingLabels = datingLabels[numTestVec:m]
    errorCount = 0.0
    for i in range(numTestVec):
        classifierResult = classify0(normMat[i, :], trainingMat, trainingLabels, k)
        print("the classifier came back with : %d , the real answer is %d"
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    # Avoid ZeroDivisionError when the file is too small to hold anything out.
    errorRate = errorCount / float(numTestVec) if numTestVec else 0.0
    print("the total error rate is : %f" % errorRate)
    return errorRate


def img2Mat(filename):
    """Read a 32x32 text image of '0'/'1' characters into a row vector.

    Args:
        filename: path of the image text file.

    Returns:
        A (1, 1024) float array; pixel (row i, column j) is stored at
        index ``32 * i + j``.
    """
    returnVect = np.zeros((1, _IMG_SIDE * _IMG_SIDE))
    with open(filename) as fr:
        for i in range(_IMG_SIDE):
            lineStr = fr.readline()
            for j in range(_IMG_SIDE):
                returnVect[0, _IMG_SIDE * i + j] = int(lineStr[j])
    return returnVect


def _digitFromFileName(fileNameStr):
    """Return the digit encoded in a file name such as ``'7_12.txt'``."""
    fileStr = fileNameStr.split('.')[0]
    return int(fileStr.split('_')[0])


def handWritingTest(trainingDir='trainingDigits', testDir='testDigits', k=3):
    """Evaluate the classifier on directories of handwritten-digit images.

    Files are named ``<digit>_<index>.txt`` and hold 32x32 text images.
    Each prediction and the final error rate are printed.

    Args:
        trainingDir: directory with the training images.
        testDir: directory with the images to classify.
        k: number of neighbours passed to ``classify0``.

    Returns:
        The error rate as a float; 0.0 when the test directory is empty.
    """
    trainingFileList = os.listdir(trainingDir)
    m = len(trainingFileList)
    hwLabels = []
    trainingMat = np.zeros((m, _IMG_SIDE * _IMG_SIDE))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_digitFromFileName(fileNameStr))
        trainingMat[i, :] = img2Mat(os.path.join(trainingDir, fileNameStr))

    testFileList = os.listdir(testDir)
    mTest = len(testFileList)
    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = _digitFromFileName(fileNameStr)
        vectorUnderTest = img2Mat(os.path.join(testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with : %d , the real answer is %d"
              % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    # Avoid ZeroDivisionError when there is nothing to test.
    errorRate = errorCount / float(mTest) if mTest else 0.0
    print("\n the error rate is : %f" % errorRate)
    return errorRate
对应的代码和注解
标签:strip orm split line tor 重复 sort 测试 %s
原文地址:https://www.cnblogs.com/xiaoxineryi/p/12332404.html