码迷,mamicode.com
首页 > 编程语言 > 详细

python用K近邻(KNN)算法分类MNIST数据集和Fashion MNIST数据集

时间:2018-07-28 20:34:04      阅读:586      评论:0      收藏:0      [点我收藏+]

标签:getter   col   err   array   属性   orm   分析   简单   [1]   

一、KNN算法的介绍

  K最近邻(k-Nearest Neighbor,KNN)分类算法是最简单的机器学习算法之一,理论上比较成熟。KNN算法首先将待分类样本表达成和训练样本一致的特征向量;然后根据距离计算待测试样本和每个训练样本的距离,选择距离最小的K个样本作为近邻样本;最后根据K个近邻样本判断待分类样本的类别。KNN算法的正确选取是分类正确的关键因素之一,而近邻样本是通过计算测试样本与每个训练集样本的距离来选定的,故定义合适的距离是KNN正确分类的前提。

本文中在上述研究的基础上,将特征属性值对类别判断的重要性视为同样重要,将样本距离重新定义为任意两样本间像素点间的相关距离,并且距离计算使用的是距离。

二、算法原理

  k-近邻算法(KNN),其工作原理是存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。

  收集和准备数据,这里使用的是mnist数据集和fashion mnist数据集,输入样本数据和结构化的输出结果,可以调整k的值,然后运行k-近邻算法判断输入数据分别属于哪个分类,最后计算错误率和准确率。

KNN算法(k邻近算法分类算法),就是k个最近的邻居的,说的是每个样本都可以用它最接近的k个邻居来代表,核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。在KNN中,通过计算对象间距离来作为各个对象之间的非相似性指标,避免了对象之间的匹配问题,在这里距离使用的是欧氏距离。

详细实现:将mnist数据集和fashion mnist数据集包括训练集和验证集导入到工程文件中,接着计算验证集和训练集的距离,并从小到达排序得到距离最近的k个邻居,并通过投票得到所属类别最高的类别,并判断该验证集的图片属于该类别,接着讲该类别的标签和验证集的标签进行比对,如果相符合则是正确的,如果不相符合,则是属于出错,最后输出计算出的错误率和准确率。

三、数据集介绍
  MNIST数据集,训练集60000张图片和标签;测试集有10000张图片和标签。读取28*28图片以后,要将每张图片转换为1*784的向量。
四、KNN算法实现和结果分析
代码实现:
from numpy import *
import operator
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from os import listdir
from mpl_toolkits.mplot3d import Axes3D
import struct

#读取图片
def read_image(file_name):
#先用二进制方式把文件都读进来
file_handle=open(file_name,"rb") #以二进制打开文档
file_content=file_handle.read() #读取到缓冲区中

offset=0
head = struct.unpack_from(‘>IIII‘, file_content, offset) # 取前4个整数,返回一个元组
offset += struct.calcsize(‘>IIII‘)
imgNum = head[1] #图片数
rows = head[2] #宽度
cols = head[3] #高度
# print(imgNum)
# print(rows)
# print(cols)

#测试读取一个图片是否读取成功
#im = struct.unpack_from(‘>784B‘, file_content, offset)
#offset += struct.calcsize(‘>784B‘)

images=np.empty((imgNum , 784))#empty,是它所常见的数组内的所有元素均为空,没有实际意义,它是创建数组最快的方法
image_size=rows*cols#单个图片的大小
fmt=‘>‘ + str(image_size) + ‘B‘#单个图片的format

for i in range(imgNum):
images[i] = np.array(struct.unpack_from(fmt, file_content, offset))
# images[i] = np.array(struct.unpack_from(fmt, file_content, offset)).reshape((rows, cols))
offset += struct.calcsize(fmt)
return images

‘‘‘bits = imgNum * rows * cols # data一共有60000*28*28个像素值
bitsString = ‘>‘ + str(bits) + ‘B‘ # fmt格式:‘>47040000B‘
imgs = struct.unpack_from(bitsString, file_content, offset) # 取data数据,返回一个元组
imgs_array=np.array(imgs).reshape((imgNum,rows*cols)) #最后将读取的数据reshape成 【图片数,图片像素】二维数组
return imgs_array‘‘‘

#读取标签
def read_label(file_name):
file_handle = open(file_name, "rb") # 以二进制打开文档
file_content = file_handle.read() # 读取到缓冲区中

head = struct.unpack_from(‘>II‘, file_content, 0) # 取前2个整数,返回一个元组
offset = struct.calcsize(‘>II‘)

labelNum = head[1] # label数
# print(labelNum)
bitsString = ‘>‘ + str(labelNum) + ‘B‘ # fmt格式:‘>47040000B‘
label = struct.unpack_from(bitsString, file_content, offset) # 取data数据,返回一个元组
return np.array(label)

#KNN算法
def KNN(test_data, dataSet, labels, k):
dataSetSize = dataSet.shape[0]#dataSet.shape[0]表示的是读取矩阵第一维度的长度,代表行数
# distance1 = tile(test_data, (dataSetSize,1)) - dataSet#欧氏距离计算开始
# print("dataSetSize:")
# print(dataSetSize)
distance1 = tile(test_data, (dataSetSize)).reshape((60000,784))-dataSet#tile函数在行上重复dataSetSizec次,在列上重复1次
# print("distance1.shape")
# print(distance1.shape)
distance2 = distance1**2 #每个元素平方
distance3 = distance2.sum(axis=1)#矩阵每行相加
distances4 = distance3**0.5#欧氏距离计算结束
# print(distances4[53843])
# print(distances4[38620])
# print(distances4[16186])
sortedDistIndicies = distances4.argsort() #返回从小到大排序的索引
classCount=np.zeros((10), np.int32)#10是代表10个类别
for i in range(k): #统计前k个数据类的数量
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] += 1
max = 0
id = 0
print(classCount.shape[0])
# print(classCount.shape[1])

for i in range(classCount.shape[0]):
if classCount[i] >= max:
max = classCount[i]
id = i
print(id)

# sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)#从大到小按类别数目排序
return id

def test_KNN():
# 文件获取
#mnist数据集
# train_image = "F:\mnist\\train-images-idx3-ubyte"
# test_image = "F:\mnist\\t10k-images-idx3-ubyte"
# train_label = "F:\mnist\\train-labels-idx1-ubyte"
# test_label = "F:\mnist\\t10k-labels-idx1-ubyte"
#fashion mnist数据集
train_image = "train-images-idx3-ubyte"
test_image = "t10k-images-idx3-ubyte"
train_label = "train-labels-idx1-ubyte"
test_label = "t10k-labels-idx1-ubyte"
# 读取数据
train_x = read_image(train_image) # train_dataSet
test_x = read_image(test_image) # test_dataSet
train_y = read_label(train_label) # train_label
test_y = read_label(test_label) # test_label

# print(train_x.shape)
# print(test_x.shape)
# print(train_y.shape)
# print(test_y.shape)
# plt.imshow(train_x[0])
# plt.show()

testRatio = 1 # 取数据集的前0.1为测试数据,这个参数比重可以改变
train_row = train_x.shape[0] # 数据集的行数,即数据集的总的样本数
test_row=test_x.shape[0]
testNum = int(test_row * testRatio)
errorCount = 0 # 判断错误的个数
for i in range(testNum):
result = KNN(test_x[i], train_x, train_y, 30)
# print(‘返回的结果是: %s, 真实结果是: %s‘ % (result, train_y[i]))

print(result, test_y[i])
if result != test_y[i]:
errorCount += 1.0# 如果mnist验证集的标签和本身标签不一样,则出错
error_rate = errorCount / float(testNum) # 计算出错率
acc = 1.0 - error_rate
print(errorCount)
print("\nthe total number of errors is: %d" % errorCount)
print("\nthe total error rate is: %f" % (error_rate))
print("\nthe total accuracy rate is: %f" % (acc))

if __name__ == "__main__":
test_KNN()#test()函数中调用了读取数据集的函数,并调用分类函数对数据集进行分类,最后对分类情况进行计算
结果分析:

 

输入:mnist数据集或者fashion mnist数据集

输出:出错率和准确率

Mnist数据集:

取k=30,验证集是50个的时候,准确率是1;

取k=30,验证集是500个的时候,准确率是0.98;

取k=30,验证集是10000个的时候,准确率是0.84。

Fashion Mnist数据集

K=30,验证集是10000的时候,一共的出错个数是1666,准确率是0.8334。

本文中的数据集采用KNN算法得到了较高的准确率,但是本文中考虑特征属性值对类别判断的重要性一样,改进算法时应该考虑特征属性值对类别判断的重要性不同,两样本间属性的相关距离可以用来度量属性值对类别的重要性,相关距离熵越小,两样本的相似程度越大,类可信度越大;此外本文中应该对不同取值的k进行分别的试验,得到使准确率较高的k,同时在实验多个k的时候,可以采用多线程进行跑实验,缩短时间。



 

python用K近邻(KNN)算法分类MNIST数据集和Fashion MNIST数据集

标签:getter   col   err   array   属性   orm   分析   简单   [1]   

原文地址:https://www.cnblogs.com/BlueBlue-Sky/p/9383120.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!