
TensorFlow (8): Image Recognition with TensorFlow (KNN)

Posted: 2018-06-29 11:10:38


import tensorflow as tf  # written against the TensorFlow 1.x API (tf.Session, tf.placeholder)
import numpy as np
import matplotlib.pyplot as plt

# MNIST helper shipped with TensorFlow 1.x (not available in TensorFlow 2.x)
from tensorflow.examples.tutorials.mnist import input_data

sess=tf.Session()
mnist= input_data.read_data_sets("MNIST_data/",one_hot=True)
# 10 classes (the digits 0-9); labels are one-hot vectors of length 10
train_size=1000
test_size=102
rand_train_indices=np.random.choice(len(mnist.train.images),train_size,replace=False)
# draw test images from the separate test split so they cannot overlap the training sample
rand_test_indices=np.random.choice(len(mnist.test.images),test_size,replace=False)

x_vals_train=mnist.train.images[rand_train_indices]
x_vals_test=mnist.test.images[rand_test_indices]
y_vals_train=mnist.train.labels[rand_train_indices]
y_vals_test=mnist.test.labels[rand_test_indices]

k=4           # number of nearest neighbours that vote
batch_size=6  # test images classified per sess.run call
x_data_train=tf.placeholder(shape=[None,784],dtype=tf.float32)
x_data_test=tf.placeholder(shape=[None,784],dtype=tf.float32)
y_target_train=tf.placeholder(shape=[None,10],dtype=tf.float32)
y_target_test=tf.placeholder(shape=[None,10],dtype=tf.float32)

# L1 (Manhattan) distance between every test image and every training image, shape=(6, 1000)
# broadcasting: (1000,784) - (6,1,784) -> (6,1000,784), then sum over axis 2
distance=tf.reduce_sum(tf.abs(tf.subtract(x_data_train,tf.expand_dims(x_data_test,1))),axis=2)

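# top-k nearest neighbours: negating the distance turns the smallest distances into the largest values; shape=(6, 4)
top_k_xvals,top_k_indices=tf.nn.top_k(tf.negative(distance),k=k)
# one-hot labels of the neighbours: gather((1000,10), (6,4)) -> shape=(6, 4, 10)
prediction_indices=tf.gather(y_target_train,top_k_indices)
# votes per class, shape=(6, 10)
count_of_prediction=tf.reduce_sum(prediction_indices,axis=1)
# prediction: the class with the most votes, shape=(6,)
prediction=tf.argmax(count_of_prediction,axis=1)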

num_loop=int(np.ceil(len(x_vals_test)/batch_size))
test_output=[]
actual_vals=[]
for i in range(num_loop):
    min_index=i*batch_size
    max_index=min((i+1)*batch_size,len(x_vals_test))
    # fetch the current batch of test data
    x_batch=x_vals_test[min_index:max_index]
    y_batch = y_vals_test[min_index:max_index]
    predictions=sess.run(prediction,feed_dict={x_data_test:x_batch,x_data_train:x_vals_train,y_target_test:y_batch,y_target_train:y_vals_train})
    test_output.extend(predictions)
    actual_vals.extend(np.argmax(y_batch,axis=1))

# compute accuracy
accuracy=sum(1./test_size for i in range(test_size) if test_output[i]==actual_vals[i])
print("Accuracy: "+str(accuracy))

# show the last test batch (6 images) with actual and predicted labels
actuals=np.argmax(y_batch,axis=1)
for i in range(len(actuals)):
    plt.subplot(2,3,i+1)
    plt.imshow(np.reshape(x_batch[i],[28,28]),cmap="Greys_r")
    plt.title("Actual: "+str(actuals[i])+" Pred: "+str(predictions[i]),fontsize=10)
    frame=plt.gca()
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)

plt.show()
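The graph above (L1 distance, `tf.nn.top_k` on the negated distance, `tf.gather` of neighbour labels, `tf.reduce_sum` of votes, `tf.argmax`) can be cross-checked without TensorFlow. The following is a minimal NumPy sketch, not the author's code: the function name `knn_predict` and the two-cluster toy data are hypothetical, chosen only so the result is easy to verify by hand.

```python
import numpy as np


def knn_predict(x_train, y_train_onehot, x_query, k=4):
    """Hypothetical helper: classify each row of x_query by majority vote
    of its k nearest training rows under the L1 (Manhattan) distance."""
    # (q, 1, d) - (1, n, d) broadcasts to (q, n, d); summing |.| over axis 2 gives (q, n)
    distance = np.abs(x_query[:, None, :] - x_train[None, :, :]).sum(axis=2)
    # indices of the k smallest distances per query, shape (q, k);
    # a stable sort keeps the lower training index first when distances tie
    top_k_indices = np.argsort(distance, axis=1, kind="stable")[:, :k]
    # one-hot labels of the neighbours, shape (q, k, num_classes)
    neighbour_labels = y_train_onehot[top_k_indices]
    # votes per class, shape (q, num_classes)
    votes = neighbour_labels.sum(axis=1)
    # class with the most votes; np.argmax returns the lowest index on ties
    return votes.argmax(axis=1)


# Hypothetical toy data: two well-separated 2-D clusters labelled 0 and 1
x_train = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                    [5.0, 5.0], [5.1, 4.9], [4.8, 5.2]])
labels = np.array([0, 0, 0, 1, 1, 1])
y_train = np.eye(2)[labels]  # one-hot encode, shape (6, 2)
x_query = np.array([[0.05, 0.05], [5.0, 5.1]])

pred = knn_predict(x_train, y_train, x_query, k=3)
accuracy = np.mean(pred == np.array([0, 1]))
print(pred, accuracy)
```

The broadcast in the first step materialises a (batch, train_size, 784) array; for the post's 6 × 1000 × 784 float32 values that is about 4.7 million elements (roughly 19 MB), which is one reason to feed the test images in small batches rather than all at once. With an even k such as the post's k=4, two classes can also tie on votes, and `np.argmax` then returns the lower class index.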


Original article: https://www.cnblogs.com/x0216u/p/9241759.html
