码迷,mamicode.com
首页 > 编程语言 > 详细

KNN算法

时间:2020-02-19 18:49:49      阅读:66      评论:0      收藏:0      [点我收藏+]

标签:strip   orm   split   line   tor   重复   sort   测试   %s   

机器学习实战之K-近邻算法:

  KNN算法:在已知分类标签的数据集中,找出与待预测点距离最近的K个点,然后统计这K个点中各个分类所占的比例,占比最高的分类即为该预测点的分类。

from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
import os
def createDataSet():
    """Return a tiny hand-made training set for trying out the classifier.

    Returns:
        A ``(group, labels)`` pair: ``group`` is a 4x2 array of sample
        points and ``labels`` holds the class name of each row.
    """
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    # The labels are class names, so they must be string literals; as bare
    # identifiers they raise NameError.
    labels = ["A", "A", "B", "B"]
    return group, labels

def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest rows of ``dataSet``.

    Distances are Euclidean and ``labels[i]`` is the class of ``dataSet[i]``.
    The label with the most votes is returned.  On a tie, the label that was
    met first while walking the neighbours from nearest to farthest wins
    (this relies on insertion-ordered dicts and on the sort being stable).
    """
    rowCount = dataSet.shape[0]
    # Stack rowCount copies of inX, one per row, so it can be subtracted
    # from every sample at once.
    repeated = tile(inX, (rowCount, 1))
    squared = (repeated - dataSet) ** 2
    distances = squared.sum(axis=1) ** 0.5
    # argsort gives the row indices ordered from smallest to largest distance.
    nearestFirst = distances.argsort()

    votes = {}
    for rank in range(k):
        neighbourLabel = labels[nearestFirst[rank]]
        if neighbourLabel in votes:
            votes[neighbourLabel] += 1
        else:
            votes[neighbourLabel] = 1

    ranked = list(votes.items())
    ranked.sort(key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]

def file2mat(filename, numFeatures=3):
    """Load a tab-separated data file into a feature matrix and a label list.

    Each non-blank line holds the feature columns followed by an integer
    class label in the last column.

    Args:
        filename: Path of the text file to read.
        numFeatures: Number of leading columns to store as features.
            Defaults to 3, the layout the original code was written for.

    Returns:
        ``(returnMat, classLabelVector)`` where ``returnMat`` has one row per
        sample and ``numFeatures`` columns, and ``classLabelVector`` is the
        list of integer labels in the same order.

    Raises:
        ValueError: If a feature or label field cannot be parsed as a number,
            or a line has fewer than ``numFeatures`` feature columns.
    """
    # The context manager closes the file even if parsing fails later.
    with open(filename) as fr:
        # Drop blank lines up front so the matrix is sized to the real
        # number of samples and a trailing empty line does not crash parsing.
        arrayOfLines = [line.strip() for line in fr if line.strip()]

    returnMat = zeros((len(arrayOfLines), numFeatures))
    classLabelVector = []
    for index, line in enumerate(arrayOfLines):
        listFromLine = line.split("\t")
        returnMat[index, :] = [float(value) for value in listFromLine[0:numFeatures]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector

def autoNorm(dataSet):
    """Rescale every column of ``dataSet`` linearly into the range [0, 1].

    Args:
        dataSet: 2-D NumPy array with one sample per row.

    Returns:
        ``(normDataSet, ranges, minVals)`` where ``normDataSet`` is
        ``(dataSet - minVals) / ranges`` computed column-wise, ``ranges`` is
        the per-column ``max - min`` and ``minVals`` the per-column minimum.
        ``ranges`` is returned unmodified, so a constant column still
        reports a range of 0.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # A constant column has range 0; dividing by it would give NaN.  Divide
    # that column by 1 instead so it normalises to 0.
    safeRanges = where(ranges == 0, 1, ranges)
    # Broadcasting applies the per-column vectors to every row, so no
    # explicit tiling or pre-allocated output array is needed.
    normDataSet = (dataSet - minVals) / safeRanges
    return normDataSet, ranges, minVals

def datingTest():
    """Hold out the first 10% of the dating data set and print the kNN error rate.

    The samples are loaded from "datingTestSet2.txt" and normalised.  The
    first tenth are then classified with k=3 against the remaining rows.
    Each prediction is printed next to the true label, followed by the
    overall error rate.
    """
    holdoutFraction = 0.1
    features, answers = file2mat("datingTestSet2.txt")
    scaled, _ranges, _mins = autoNorm(features)
    total = scaled.shape[0]
    testCount = int(total * holdoutFraction)

    # Rows after the hold-out slice are the reference set for every query.
    referenceRows = scaled[testCount:total, :]
    referenceAnswers = answers[testCount:total]

    mistakes = 0.0
    for sample, truth in zip(scaled[:testCount], answers[:testCount]):
        guess = classify0(sample, referenceRows, referenceAnswers, 3)
        print("the classifier came back with : %d , the real answer is %d" % (guess, truth))
        if guess != truth:
            mistakes += 1.0

    print("the total error rate is : %f" % (mistakes / float(testCount)))


def img2Mat(filename):
    """Read a 32x32 text image of digit characters into a 1x1024 row vector.

    The first 32 characters of each of the first 32 lines are converted
    with ``int`` and stored row after row.

    Args:
        filename: Path of the text image file.

    Returns:
        A NumPy array of shape (1, 1024).

    Raises:
        IndexError: If one of the first 32 lines has fewer than 32 characters.
        ValueError: If a character is not a decimal digit.
    """
    returnVect = zeros((1, 1024))
    # The context manager closes the file once the 32 rows are read, or
    # earlier if parsing fails.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def handWritingTest():
    """Run the kNN classifier on the handwritten-digit files and print the error rate.

    Every file in the "trainingDigits" directory becomes one row of the
    training matrix.  Every file in "testDigits" is then classified with
    k=3.  The true class of a file is the integer before the first
    underscore of its name, taken after dropping everything from the first
    dot onward.  Each prediction is printed next to the true class,
    followed by the overall error rate.
    """
    hwLabels = []
    # Directory names, separators and path templates are string literals;
    # without quotes the function does not even parse.
    trainingFileList = os.listdir("trainingDigits")
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split(".")[0]
        classNumStr = int(fileStr.split("_")[0])
        hwLabels.append(classNumStr)
        trainingMat[i, :] = img2Mat("trainingDigits/%s" % fileNameStr)

    testFileList = os.listdir("testDigits")
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split(".")[0]
        classNumStr = int(fileStr.split("_")[0])
        vectorUnderTest = img2Mat("testDigits/%s" % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)

        print("the classifier came back with : %d , the real answer is %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\n the error rate is : %f" % (errorCount / float(mTest)))

 

    对应的代码和注解

 

KNN算法

标签:strip   orm   split   line   tor   重复   sort   测试   %s   

原文地址:https://www.cnblogs.com/xiaoxineryi/p/12332404.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!