码迷,mamicode.com
首页 > 其他好文 > 详细

Tensorflow Mnist数据集

时间:2019-07-01 18:37:09      阅读:132      评论:0      收藏:0      [点我收藏+]

标签:print   ima   rand   rom   bat   inpu   example   相关   class   

Tensorflow自带的Mnist数据集相关情况

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#数据会自动在线下载,第一次较慢,第二次之后就好了
mnist = input_data.read_data_sets(‘data/‘,one_hot=True)
print(type(mnist))
print(mnist.train.num_examples)#55000
print(mnist.test.num_examples)#10000

img_train = mnist.train.images
label_train = mnist.train.labels

img_test = mnist.test.images
label_test = mnist.test.labels

print(type(img_train))#<class ‘numpy.ndarray‘>
print(type(label_train))#<class ‘numpy.ndarray‘>
print(type(img_test))#<class ‘numpy.ndarray‘>
print(type(label_test))#<class ‘numpy.ndarray‘>
print(img_train.shape)#(55000, 784) 28*28的图片
print(label_train.shape)#(55000, 10)
print(img_test.shape)#(10000, 784)
print(label_test.shape)#(10000, 10) #one hot coding便于取最大概率

num_sample = 5
rand_idx = np.random.randint(img_train.shape[0], size=num_sample)

for i in rand_idx:
    cur_img = np.reshape(img_train[i, :],(28,28))
    cur_label = np.argmax(label_train[i,:])
    plt.matshow(cur_img, cmap = plt.get_cmap(‘gray‘))
    print(str(i) + "训练数据的标签是" + str(cur_label))
    # plt.show()

#取batch数据
batch_size = 100
batch_x, batch_y = mnist.train.next_batch(batch_size)
print(type(batch_x))#<class ‘numpy.ndarray‘>
print(type(batch_y))#<class ‘numpy.ndarray‘>
print(batch_x.shape)#(100, 784)
print(batch_y.shape)#(100, 10)

Tensorflow Mnist数据集

标签:print   ima   rand   rom   bat   inpu   example   相关   class   

原文地址:https://blog.51cto.com/5669384/2415956

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!