码迷,mamicode.com
首页 > 编程语言 > 详细

k近邻算法

时间:2020-02-01 00:48:24      阅读:87      评论:0      收藏:0      [点我收藏+]

标签:介绍   int   spl   不同   getter   pen   items   pl2   span   

介绍

k近邻算法(KNN)属于监督学习的分类算法,通过测量不同特征值之间的距离进行分类,算法过程如下

  • 计算数据点与已知数据集中每个点的距离
  • 对距离从小到大进行排序
  • 选取前k个距离值
  • 确定前k个距离值所在类别的出现的概率
  • 将前k个点出现频率最高的类别作为当前数据的预测分类

主要代码如下

def classfiy(inData, dataSet, labels, k):
    dataSize = dataSet.shape[0]  # 得到数组的行维度,即数据的个数
    # 先通过tile将输入的数据扩展为与dataSet相同维度的数组,再通过距离公式计算距离
    distance = (((tile(inData, (dataSize, 1)) - dataSet) ** 2).sum(axis=1)) ** 0.5
    sortIndex = distance.argsort()  # 返回数组值从小到大的索引值
    classCount = {}
    for i in range(k):  # 只对前k个计数
        headLabel = labels[sortIndex[i]]
        classCount[headLabel] = classCount.get(headLabel, 0) + 1  # 统计前k个中出现标签的次数
    # 对字典按照第二个值(即出现的次数)进行排序,用reverse指定从大到小排
    sortCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortCount[0][0]  # 返回第一个的标签

其中距离计算,通过公式,如\((x_{1},y_{1})(x_{2},y_{2})\)两点的距离d为\(d=\sqrt{(x_{1}-x_{2})^2+(y_{1}-y_{2})^2}\)

用KNN识别数字图片中的数字

只是个玩具程序

收集数据

每个数字准备了10张图片,分别存在digit中的以各个数字命名的文件夹下
技术图片
又为每个数据准备了5张图片,以同样的规则存在digit2的各个文件夹下
技术图片

准备数据

缩放图像
采用了pillow中的resize函数,同一将图像缩放为50*50
newImg = img.resize((50, 50))
二值化图像
开始想直接通过convet(‘1‘)直接将图像二值化,但出现了很多噪音
所以通过以下程序将图像二值化。其中230为设定的阀值,多次尝试,发现230效果较好

    for i in range(rows):
        for j in range(cols):
            if (imgArray[i, j] <= 230):
                imgArray[i, j] = 0
            else:
                imgArray[i, j] = 255

转化为一维向量
将读取的处理后的图片的像素值转化为一维向量

测试

通过读取测试集中的数据,进行预测,和实际的类别比对,查看正确率

程序

from PIL import Image
from numpy import *
import os
import operator

#缩放为相同大小
def toSame(img):
    newImg = img.resize((50, 50))
    return newImg

#二值化处理
def toBinarry(img):
    imgArray = array(img)
    rows, cols = imgArray.shape
    for i in range(rows):
        for j in range(cols):
            if (imgArray[i, j] <= 230):
                imgArray[i, j] = 0
            else:
                imgArray[i, j] = 255
    return imgArray

#读取每个文件夹下的每张图片
def readImage(filePath):
    dataList = []
    labels = []
    for i in range(10):
        imagePath = filePath + '/' + str(i)
        files = os.listdir(imagePath)
        for j in files:
            labels.append(j.split('_')[0])#因为每张图片采用‘数字_第几张的命名方式’,所以通过下横线分割,取得前面的作为图片的分类标签
            img = Image.open(imagePath + '/' + j).convert('L')#先灰度化处理
            imgArray = toBinarry(toSame(img))
            dataList.append(imgArray.ravel())#转变为一维后加入列表
    dataSet = array(dataList)
    return dataSet, labels

#分类算法
def classfiy(inData, dataSet, labels, k):
    dataSize = dataSet.shape[0]  # 得到数组的行维度,即数据的个数
    # 先通过tile将输入的数据扩展为与dataSet相同维度的数组,再通过距离公式计算距离
    distance = (((tile(inData, (dataSize, 1)) - dataSet) ** 2).sum(axis=1)) ** 0.5
    sortIndex = distance.argsort()  # 返回数组值从小到大的索引值
    classCount = {}
    for i in range(k):  # 只对前k个计数
        headLabel = labels[sortIndex[i]]
        classCount[headLabel] = classCount.get(headLabel, 0) + 1  # 统计前k个中出现标签的次数
    # 对字典按照第二个值(即出现的次数)进行排序,用reverse指定从大到小排
    sortCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortCount[0][0]  # 返回第一个的标签

# 进行测试
dataSet, labels = readImage('./digit')
dataSet2, labels2 = readImage('./digit2')
n = 0
for i in range(len(dataSet2)):
    predict = classfiy(dataSet2[i], dataSet, labels, 10)
    print(predict + ' ' + labels2[i])
    if (predict == labels2[i]):
        n = n + 1
# 查看准确率
print(n / len(dataSet2))

运行结果

技术图片
发现准确率只有0.62

总结

  • 准确率如此低,可能是数据不足,也可能对图像处理不好。在二值化时,效果其实并不完美。也可能需要对图像进行一些裁剪。在二值化时,本程序也只适合一些浅色底子的数字图片
  • 采用不同的k,预测的效果也是不同,也需要找到一个最佳的k

其它

  • 在处理数据时,通常用到的归一化
def toNormal(dataSet):
    # 归一化
    min = dataSet.min(0)
    max = dataSet.max(0)
    # 公式normal=(x-min)/(max-min)
    normalArray = (dataSet - tile(min, (dataSet.shape[0], 1))) / tile(max - min, (dataSet.shape[0], 1))
    return normalArray
def toClear(imgArray):
    rows, cols = imgArray.shape
    for y in range(1, cols - 1):
        for x in range(1, rows - 1):
            count = 0
            if imgArray[x, y - 1] == 255:  # 上
                count = count + 1
            if imgArray[x, y + 1] == 255:  # 下
                count = count + 1
            if imgArray[x - 1, y] == 255:  # 左
                count = count + 1
            if imgArray[x + 1, y] == 255:  # 右
                count = count + 1
            if imgArray[x - 1, y - 1] == 255:  # 左上
                count = count + 1
            if imgArray[x - 1, y + 1] == 255:  # 左下
                count = count + 1
            if imgArray[x + 1, y - 1] == 255:  # 右上
                count = count + 1
            if imgArray[x + 1, y + 1] == 255:  # 右下
                count = count + 1
            if count > 4:
                imgArray[x, y] = 255
    return imgArray

k近邻算法

标签:介绍   int   spl   不同   getter   pen   items   pl2   span   

原文地址:https://www.cnblogs.com/Qi-Lin/p/12247163.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!