码迷,mamicode.com
首页 > 其他好文 > 详细

KNN——图像分类

时间:2018-04-05 16:28:20      阅读:362      评论:0      收藏:0      [点我收藏+]

标签:class   https   lan   tps   简单的   gre   byte   load   ast   

 

内容参考自:https://zhuanlan.zhihu.com/p/20894041?refer=intelligentunit

 

用像素点的rgb值来判断图片的分类准确率并不高,但是作为一个练习knn的题目,还是挺不错的。

 

1. CIFAR-10

CIFAR-10是一个图像分类数据集。数据集包含60000张32*32像素的小图片,每张图片都有一个类别标注(总共有10类),分成了50000张的训练集和10000张的测试集。

然后下载后得到的并不是实实在在的图片(不然60000张有点可怕...),而是序列化之后的,需要我们用代码来打开来获得图片的rgb值。

1 import pickle
2 
3 def unpickle(file):
4    with open(file, rb) as f:
5    dict = pickle.load(f, encoding=bytes)
6    return dict 

由此得到的是一个字典,有data和labels两个值。

data:

一个10000*3072的numpy数组,这个数组的每一行存储了32*32大小的彩色图像。前1024个数是red,然后分别是green,blue。

labels:
一个范围在0-9的含有10000个数的一维数组。第i个数就是第i个图像的类标。

 

2. 基于曼哈顿距离的1NN分类

先简单的直接找最接近的那个图片吧。

 1 #! /usr/bin/dev python
 2 # coding=utf-8
 3 import os
 4 import sys
 5 import pickle
 6 import numpy as np
 7 
 8 def load_data(file):
 9     with open(file, rb) as f:
10         datadict = pickle.load(f, encoding=latin1)
11         X = datadict[data]
12         Y = datadict[labels]
13         X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)
14         Y = np.array(Y)
15         return X, Y
16 
17 def load_all(root):
18     xs = []
19     ys = []
20     for n in range(1, 2):
21         f = os.path.join(root, data_batch_%d %(n,))
22         X, Y = load_data(f)
23         xs.append(X)
24         ys.append(Y)
25     X_train = np.concatenate(xs)
26     Y_train = np.concatenate(ys)
27     del X, Y
28     X_test, Y_test = load_data(os.path.join(root, test_batch))
29     return X_train, Y_train, X_test, Y_test
30 
31 
32 def classTest(Xtr_rows, Xte_rows, Y_train):
33     count = 0
34     numTest = Xte_rows.shape[0]
35     result = np.zeros(numTest)   #构造一维向量的结果
36     for i in range(numTest):
37         distance = np.sum(np.abs(Xtr_rows - Xte_rows[i,:]), axis=1)
38         min_dis = np.argmin(distance)
39         result[i] = Y_train[min_dis]
40         print(%d:  %d %(count, result[i]))
41         count += 1
42     return result
43 
44 if __name__ == __main__:
45     X_train, Y_train, X_test, Y_test = load_all(D:\学习资料\机器学习\cifar-10-python\\)
46     Xtr_rows = X_train.reshape(X_train.shape[0], 32 * 32 * 3)
47     Xte_rows = X_test.reshape(X_test.shape[0], 32 * 32 * 3)
48     result = classTest(Xtr_rows, Xte_rows, Y_train)
49     print(accuracy: %f % (np.mean(result == Y_test)))

 

KNN——图像分类

标签:class   https   lan   tps   简单的   gre   byte   load   ast   

原文地址:https://www.cnblogs.com/zyb993963526/p/8722866.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!