码迷,mamicode.com
首页 > 其他好文 > 详细

用 CNN 构建多层神经网络来识别 MNIST 中的图片

时间:2017-10-28 11:16:54      阅读:303      评论:0      收藏:0      [点我收藏+]

标签:with   extra   bat   roo   ror   深度   code   stdout   ide   

mnist.py

import argparse
import gzip
import os
import sys
import urllib
import urllib.request

import numpy as np
import tensorflow as tf

# Base URL the gzipped MNIST idx files are fetched from; a file name is appended.
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
# Split sizes: TRAIN_SIZE + VALIDATE_SIZE (55000 + 5000) = 60000 examples in the
# training file; the test file is used as-is.
TRAIN_SIZE = 55000
TEST_SIZE = 10000
VALIDATE_SIZE = 5000
IMAGE_SIZE = 28  # images are IMAGE_SIZE x IMAGE_SIZE pixels
NUMBER_CHANNEL = 1  # grayscale
# File names of the four MNIST archives inside the data directory.
TRAIN_DATA = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_DATA = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'


class Mnist:
  """Load MNIST from gzipped idx files in ``FLAGS.data_dir`` and serve mini-batches.

  Missing files are downloaded from ``SOURCE_URL`` first. After construction
  the instance holds numpy arrays:

    train_data, validate_data, test_data
        float32, one flattened image per row (rows*cols*NUMBER_CHANNEL values).
        Pixel values are kept in the raw 0..255 range (no rescaling).
    train_labels, validate_labels, test_labels
        float32 one-hot rows with 10 columns.

  The first ``VALIDATE_SIZE`` examples of the training file become the
  validation split; the remaining training examples are shuffled once.
  """

  def __init__(self, FLAGS):
    """Download (if needed) and parse all four MNIST files.

    Args:
      FLAGS: object with a ``data_dir`` attribute naming the data directory.

    Raises:
      ValueError: if a file is malformed/truncated or image and label counts differ.
    """
    self.FLAGS = FLAGS
    self.start = 0  # read cursor into train_data, advanced by get_batch
    self.validate_size = VALIDATE_SIZE

    self._maybe_download(TRAIN_DATA)
    train_data = self._extract_images(TRAIN_DATA)
    self._maybe_download(TRAIN_LABELS)
    train_labels = self._one_hot(self._extract_labels(TRAIN_LABELS))
    self._check_counts()

    self.validate_data, self.train_data = self._get_validate_data(train_data, self.validate_size)
    self.validate_labels, self.train_labels = self._get_validate_data(train_labels, self.validate_size)

    self._maybe_download(TEST_DATA)
    self.test_data = self._extract_images(TEST_DATA)
    self._maybe_download(TEST_LABELS)
    self.test_labels = self._one_hot(self._extract_labels(TEST_LABELS))
    self._check_counts()

    # Take sizes from the data actually loaded rather than the TRAIN_SIZE /
    # TEST_SIZE constants so they can never disagree with the arrays.
    self.train_size = self.train_data.shape[0]
    self.test_size = self.test_data.shape[0]

    self.train_data, self.train_labels = self._shuffle(self.train_data, self.train_labels)
    print('Done preparing data')

  def _check_counts(self):
    """Raise ValueError unless the most recently parsed image and label files have equal counts."""
    if self.num_images != self.num_labels:
      raise ValueError("number of images and number of labels don't match")

  def _shuffle(self, images, labels):
    """Return *images* and *labels* reordered by one shared random permutation."""
    if images.shape[0] != labels.shape[0]:
      raise ValueError('images and labels must have the same length')
    perm = np.random.permutation(images.shape[0])
    return images[perm], labels[perm]

  def _read32(self, bufstream):
    """Read one big-endian 32-bit integer header field from *bufstream* and return it as int."""
    buf = bufstream.read(4)
    if len(buf) != 4:
      raise ValueError('unexpected end of stream while reading a 32-bit header field')
    return int(np.frombuffer(buf, dtype=np.dtype('>i4'))[0])

  def _maybe_download(self, filename):
    """Fetch SOURCE_URL + *filename* into FLAGS.data_dir unless it is already there.

    Returns:
      The local path of the file.
    """
    data_dir = self.FLAGS.data_dir
    filepath = os.path.join(data_dir, filename)
    if not os.path.isfile(filepath):
      os.makedirs(data_dir, exist_ok=True)
      # Download to a temporary name and rename afterwards so an interrupted
      # download never leaves a partial file that isfile() would accept later.
      tmp_path = filepath + '.part'
      urllib.request.urlretrieve(SOURCE_URL + filename, tmp_path)
      os.replace(tmp_path, filepath)
    return filepath

  def _extract_images(self, filename):
    """Parse a gzipped idx3-ubyte image file from FLAGS.data_dir.

    Sets ``self.num_images`` and ``self.image_size`` (the row count).

    Returns:
      float32 array of shape (num_images, rows * cols * NUMBER_CHANNEL).

    Raises:
      ValueError: on a wrong magic number or a truncated pixel section.
    """
    filepath = os.path.join(self.FLAGS.data_dir, filename)
    with gzip.open(filepath) as bufstream:
      magic = self._read32(bufstream)
      if magic != 2051:
        raise ValueError('Invalid magic number %d in MNIST image file %s (expected 2051)'
                         % (magic, filepath))
      num_data = self._read32(bufstream)
      rows = self._read32(bufstream)
      cols = self._read32(bufstream)
      self.num_images = num_data
      self.image_size = rows
      print('num_data %d rows %d cols %d' % (num_data, rows, cols))
      row_len = rows * cols * NUMBER_CHANNEL
      buf = bufstream.read(num_data * row_len)
    if len(buf) != num_data * row_len:
      raise ValueError('%s is truncated: expected %d pixel bytes, got %d'
                       % (filepath, num_data * row_len, len(buf)))
    # astype returns a new array; the result must be kept.
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    return data.reshape(num_data, row_len)

  def _extract_labels(self, filename):
    """Parse a gzipped idx1-ubyte label file from FLAGS.data_dir.

    Sets ``self.num_labels``.

    Returns:
      int32 vector of class indices.

    Raises:
      ValueError: on a wrong magic number or a truncated label section.
    """
    filepath = os.path.join(self.FLAGS.data_dir, filename)
    with gzip.open(filepath) as bufstream:
      magic = self._read32(bufstream)
      if magic != 2049:
        raise ValueError('Invalid magic number %d in MNIST label file %s (expected 2049)'
                         % (magic, filepath))
      num_labels = self._read32(bufstream)
      self.num_labels = num_labels
      print('num_labels %d' % num_labels)
      buf = bufstream.read(num_labels)
    if len(buf) != num_labels:
      raise ValueError('%s is truncated: expected %d labels, got %d'
                       % (filepath, num_labels, len(buf)))
    return np.frombuffer(buf, dtype=np.uint8).astype(np.int32)

  def _get_validate_data(self, data, validate_size):
    """Split *data* along axis 0 into (first validate_size rows, remaining rows)."""
    if validate_size > data.shape[0]:
      raise ValueError('validate size out of bound')
    return data[:validate_size, ...], data[validate_size:, ...]

  def _one_hot(self, labels, num_classes=10):
    """Return a float32 (len(labels), num_classes) one-hot matrix for integer *labels*."""
    labels = np.asarray(labels, dtype=np.intp)
    one_hot_labels = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    one_hot_labels[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot_labels

  def get_batch(self, batch_size):
    """Return the next (images, labels) mini-batch, wrapping around the training set.

    Consecutive calls walk through train_data in order starting at ``self.start``.
    A batch that runs past the end continues from the beginning. Every batch
    has exactly *batch_size* rows, even when batch_size exceeds the set size.
    """
    num_train = self.train_data.shape[0]
    if num_train == 0:
      raise ValueError('no training data to draw batches from')
    start = self.start
    end = start + batch_size
    self.start = end % num_train
    if end <= num_train:
      return self.train_data[start:end, ...], self.train_labels[start:end, ...]
    # Wrap-around: modular indices avoid copying the whole dataset.
    idx = np.arange(start, end) % num_train
    return self.train_data[idx], self.train_labels[idx]
  

mnist_cnn.py

import tensorflow as tf
import numpy as np
import mnist
import argparse
import sys
import os
from tensorflow.examples.tutorials.mnist import input_data

# Module-level holder for the parsed command-line options. The `__main__` block
# at the bottom of this script rebinds FLAGS to the argparse namespace before
# tf.app.run() invokes main(), so this initial value is only a placeholder.
FLAGS = tf.flags.FLAGS

def _weight(name, shape):
  """Create a trainable float32 weight variable of the given *shape*.

  Initial values come from a truncated normal distribution with stddev 0.1.
  *name* is attached to the tf.Variable itself (not to the initializer op),
  so the variable shows up under a readable name in the graph / TensorBoard.
  """
  initial = tf.truncated_normal(shape=shape, stddev=0.1, dtype=tf.float32)
  return tf.Variable(initial, name=name)

def _bias(name, shape):
  """Create a trainable float32 bias variable of the given *shape*, initialised to 0.1.

  *name* is attached to the tf.Variable itself rather than to the constant op
  that supplies its initial value.
  """
  initial = tf.constant(0.1, shape=shape, dtype=tf.float32)
  return tf.Variable(initial, name=name)

def _conv(x, w):
  """2-D convolution of *x* with filter *w*, stride 1 and 'SAME' padding."""
  return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
 
def _max_pool(x):
  """2x2 max pooling with stride 2 and 'SAME' padding (halves height and width, rounding up)."""
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

def _rm_prefiles(paths):
  """Delete every file beneath each directory in *paths*, recursing into subdirectories.

  Directories themselves are left in place, so the tree keeps its shape but
  ends up holding no files. A path that does not exist is skipped silently,
  because os.walk yields nothing for it.
  """
  for top in paths:
    for dirpath, _subdirs, filenames in os.walk(top):
      doomed = [os.path.join(dirpath, fname) for fname in filenames]
      for target in doomed:
        os.remove(target)
      

def cnn_structure(x):
  """Build the two-convolution-layer CNN used to classify MNIST digits.

  Architecture: conv(5x5, 32) -> pool -> conv(5x5, 64) -> pool ->
  fully connected(1024) -> dropout -> fully connected(10).

  Args:
    x: float32 tensor of flattened 784-value images; reshaped here to [-1, 28, 28, 1].

  Returns:
    (y, keep_prob, saver):
      y         -- unnormalised logits of shape [batch, 10]
      keep_prob -- float32 placeholder for the dropout keep probability
      saver     -- tf.train.Saver over the two convolution layers' variables
  """
  x_images = tf.reshape(x, [-1, 28, 28, 1])

  with tf.name_scope('conv1'):
    w_cov1 = _weight('w_cov1', [5, 5, 1, 32])
    with tf.name_scope('w_cov1'):
      tf.summary.scalar('w_cov1_mean', tf.reduce_mean(w_cov1))
      # Despite the "_dev" tag this logs the variance (mean squared deviation).
      tf.summary.scalar('w_cov1_dev', tf.reduce_mean(tf.square(w_cov1 - tf.reduce_mean(w_cov1))))
      tf.summary.histogram('w_cov1_hist', w_cov1)
    b_cov1 = _bias('b_cov1', [32])
    h_cov1 = tf.nn.relu(_conv(x_images, w_cov1) + b_cov1)

  h_pool1 = _max_pool(h_cov1)  # 28x28 -> 14x14

  with tf.name_scope('conv2'):
    w_cov2 = _weight('w_cov2', [5, 5, 32, 64])
    b_cov2 = _bias('b_cov2', [64])
    h_cov2 = tf.nn.relu(_conv(h_pool1, w_cov2) + b_cov2)

  h_pool2 = _max_pool(h_cov2)  # 14x14 -> 7x7
  # One expression for the flattened size keeps the reshape and fc1 weights in sync.
  flat_size = 7 * 7 * 64
  h_pool2_flat = tf.reshape(h_pool2, [-1, flat_size])

  with tf.name_scope('fc1'):
    w_fc1 = _weight('w_fc1', [flat_size, 1024])
    b_fc1 = _bias('b_fc1', [1024])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

  with tf.name_scope('dropout'):
    keep_prob = tf.placeholder(dtype=tf.float32)
    h_dropout = tf.nn.dropout(h_fc1, keep_prob)

  with tf.name_scope('fc2'):
    w_fc2 = _weight('w_fc2', [1024, 10])
    b_fc2 = _bias('b_fc2', [10])
    y = tf.matmul(h_dropout, w_fc2) + b_fc2

  # NOTE(review): only the convolution variables are checkpointed; w_fc1, b_fc1,
  # w_fc2 and b_fc2 are not, so the saved file cannot restore the full classifier.
  # Confirm this is intentional (e.g. for feature-extractor reuse).
  saver = tf.train.Saver({'w_cov1': w_cov1, 'b_cov1': b_cov1, 'w_cov2': w_cov2, 'b_cov2': b_cov2})

  return y, keep_prob, saver
 
def main(_):
  """Train the CNN on MNIST, save the checkpoint, and print the test-set accuracy.

  Configuration is read from the module-level FLAGS (data_dir, model_dir,
  summary_dir, batch_size, train_steps). Summary and model directories are
  emptied before training starts.
  """
  # NOTE(review): the returned dataset is not used; the call is kept because
  # read_data_sets also fetches the MNIST .gz files into data_dir, which
  # mnist.Mnist reads next. Confirm before removing.
  input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
  mymnist = mnist.Mnist(FLAGS)
  x = tf.placeholder(shape=[None, 784], dtype=tf.float32)
  y_ = tf.placeholder(shape=[None, 10], dtype=tf.float32)

  y, keep_prob, saver = cnn_structure(x)

  with tf.name_scope('train_section'):
    with tf.name_scope('Cross_entropy'):
      # Summed (not averaged) over the batch, so its scale grows with batch_size.
      cross_entropy = tf.reduce_sum(
          tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))
      tf.summary.scalar('cross_entropy', cross_entropy)

    with tf.name_scope('Adam_Optimizer'):
      optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
      train_op = optimizer.minimize(cross_entropy)

  with tf.name_scope('Test_section'):
    with tf.name_scope('Accuracy'):
      prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
      accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))

  merged = tf.summary.merge_all()

  _rm_prefiles([FLAGS.summary_dir, FLAGS.model_dir])

  total_steps = FLAGS.train_steps
  bar_width = 60
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter(FLAGS.summary_dir, sess.graph)
    try:
      for i in range(total_steps):
        if i % 20 == 0:
          # Pure integer arithmetic: no zero divisor even when total_steps < bar_width.
          filled = i * bar_width // total_steps
          sys.stdout.write('\r%4.2f%% ' % (float(i) / float(total_steps) * 100))
          sys.stdout.write('->' + '#' * filled + (bar_width - filled) * ' ' + '<-')
          sys.stdout.flush()
        batch_xs, batch_ys = mymnist.get_batch(FLAGS.batch_size)
        feed_dict = {x: batch_xs, y_: batch_ys, keep_prob: 0.5}
        _, summaries = sess.run([train_op, merged], feed_dict=feed_dict)
        train_writer.add_summary(summaries, i)
    finally:
      # Flush pending events to disk even if training is interrupted.
      train_writer.close()
    # End the '\r'-rewritten progress line before printing anything else.
    sys.stdout.write('\n')

    # Saver.save needs the target directory to exist.
    os.makedirs(FLAGS.model_dir, exist_ok=True)
    saver.save(sess, os.path.join(FLAGS.model_dir, 'model.ckpt'))
    feed_dict = {x: mymnist.test_data, y_: mymnist.test_labels, keep_prob: 1.0}
    print('%5.2f%%' % (sess.run(accuracy, feed_dict=feed_dict) * 100))
    
 



if __name__ == '__main__':
    # Defaults are machine-specific Windows paths; override them on the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data_dir', type=str, default=r'W:\workspace\tensorflow\my\data',
                        help='Directory for storing input data')
    parser.add_argument('-m', '--model_dir', type=str, default=r'W:\workspace\tensorflow\mnistexpert\model',
                        help='Directory for storing model')
    parser.add_argument('-s', '--summary_dir', type=str, default=r'W:\workspace\tensorflow\mnistexpert\summary',
                        help='Directory for storing summary data')
    parser.add_argument('-b', '--batch_size', type=int, default=50, help='Size of batch')
    parser.add_argument('-t', '--train_steps', type=int, default=20000,
                        help='Number of steps to train the model')
    # Rebinds the module-level FLAGS that main() reads; unknown arguments are
    # forwarded to tf.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)



在这次实现中，将深度神经网络的结构独立出来，写进了一个函数（cnn_structure），从而大大提高了代码的可读性。

用 CNN 构建多层神经网络来识别 MNIST 中的图片

标签:with   extra   bat   roo   ror   深度   code   stdout   ide   

原文地址:http://www.cnblogs.com/Yorkisme/p/7745822.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!