码迷,mamicode.com
首页 > 其他好文 > 详细

[MNIST数据集]输入图像的预处理

时间:2019-01-06 01:08:22      阅读:295      评论:0      收藏:0      [点我收藏+]

标签:state   else   int   oat   print   pat   二维   car   部分   

因为MNIST数据是28*28的黑底白字图像,而且输入时要将其拉直,也就是可以看成1*784的二维张量(张量的值在0~1之间),所以我们要对图片进行预处理操作,是图片能被网络识别。

以下是代码部分

import tensorflow as tf
import numpy as np
from PIL import Image
import backward as bw
import forward as fw

def restore(testPicArr):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, fw.INPUT_NODES])
        y_ = tf.placeholder(tf.float32, [None, fw.OUTPUT_NODES])
        y = fw.get_y(x, None)
        preValue = tf.arg_max(y, 1)
        
        ema = tf.train.ExponentialMovingAverage(bw.MOVING_ARVERAGE_DECAY)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)
        
        with tf.Session() as sess:
            tf.logging.set_verbosity(tf.logging.WARN)#降低警告等级
            ckpt = tf.train.get_checkpoint_state("./model/")
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                
                preValue = sess.run(preValue, feed_dict = {x: testPicArr})
                return preValue
            else:
                print("NO!!!")
                return -1
    
def pre_pic(picName):
    img = Image.open(picName)
    reIm = img.resize((28, 28), Image.ANTIALIAS)
    im_arr = np.array(reIm.convert(L))#变为灰度图
    threshold = 50#阈值,将图片二值化操作
    for i in range(28):
        for j in range(28):
            im_arr[i][j] = 255 - im_arr[i][j]#进行反色处理
            if(im_arr[i][j] < threshold):
                im_arr[i][j] = 0
            else: im_arr[i][j] = 255
    
    nm_arr = im_arr.reshape([1,784])
    nm_arr = nm_arr.astype(np.float32)#类型转换
    img_ready = np.multiply(nm_arr, 1.0/255.0)#把值变为0~1之间的数值
    
    return img_ready

def app():
    testNum = input("Input the number of test pictutre:")
    for i in range(int(testNum)):
        testPic = input("the path of test picture:")
        testPicArr = pre_pic(testPic)
        preValue = restore(testPicArr)
        print("The prediction number is :" , preValue)
        
def main():
    app()
    
if __name__ == __main__:
    main()
            

 

[MNIST数据集]输入图像的预处理

标签:state   else   int   oat   print   pat   二维   car   部分   

原文地址:https://www.cnblogs.com/1by1/p/10226900.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!