参考文章:
http://www.csuldw.com/2016/02/25/2016-02-25-machine-learning-MNIST-dataset/
# Load the MNIST training set from its IDX binary files, binarize the pixels,
# and save the first image(s) to disk as PNG files named "<index>_<label>.png".
import os
import struct

import numpy as np

# IDX header magic numbers (big-endian uint32): unsigned-byte 3-D array for
# image files, unsigned-byte 1-D array for label files.
IMAGES_MAGIC = 2051  # 0x00000803
LABELS_MAGIC = 2049  # 0x00000801

TRAIN_IMAGES_FILE = "data_AI/MNIST/train-images.idx3-ubyte"
TRAIN_LABELS_FILE = "data_AI/MNIST/train-labels.idx1-ubyte"
path_trainset = "data_AI/MNIST/imgs_train"
path_testset = "data_AI/MNIST/imgs_test"


def parse_idx_images(buf, binarize=True):
    """Decode an IDX image buffer into an (n_images, rows * cols) int array.

    Args:
        buf: Raw bytes of an ``*-images.idx3-ubyte`` file.
        binarize: When True (the default, matching the original script),
            every pixel value greater than 1 is set to 1, so non-zero pixels
            become 1 and zero pixels stay 0.

    Returns:
        A 2-D NumPy array of Python-default ints, one flattened image per row.

    Raises:
        ValueError: If the header is missing, the magic number is not the
            IDX image magic, or ``buf`` holds fewer pixels than the header
            declares.
    """
    header_fmt = ">IIII"
    header_size = struct.calcsize(header_fmt)
    if len(buf) < header_size:
        raise ValueError("buffer too short for an IDX image header")
    magic, num_images, num_rows, num_cols = struct.unpack_from(header_fmt, buf, 0)
    if magic != IMAGES_MAGIC:
        raise ValueError("not an IDX image file (magic number {})".format(magic))

    pixels_per_image = num_rows * num_cols
    # frombuffer raises ValueError itself if the payload is truncated.
    pixels = np.frombuffer(
        buf,
        dtype=np.uint8,
        count=num_images * pixels_per_image,
        offset=header_size,
    )
    # astype(int) copies into a writable array with the same default integer
    # dtype that np.array() of Python ints produced in the original script.
    images = pixels.reshape(num_images, pixels_per_image).astype(int)
    if binarize:
        images[images > 1] = 1
    return images


def parse_idx_labels(buf):
    """Decode an IDX label buffer into a 1-D int array of class labels.

    Args:
        buf: Raw bytes of an ``*-labels.idx1-ubyte`` file.

    Returns:
        A 1-D NumPy array of Python-default ints, one label per item.

    Raises:
        ValueError: If the header is missing, the magic number is not the
            IDX label magic, or ``buf`` holds fewer labels than declared.
    """
    header_fmt = ">II"
    header_size = struct.calcsize(header_fmt)
    if len(buf) < header_size:
        raise ValueError("buffer too short for an IDX label header")
    magic, num_items = struct.unpack_from(header_fmt, buf, 0)
    if magic != LABELS_MAGIC:
        raise ValueError("not an IDX label file (magic number {})".format(magic))
    labels = np.frombuffer(buf, dtype=np.uint8, count=num_items, offset=header_size)
    return labels.astype(int)


def save_images(images, labels, out_dir, count=1, image_shape=(28, 28)):
    """Save the first ``count`` images as black-and-white PNG plots.

    Each file is written to ``out_dir`` as ``<index>_<label>.png``.

    Args:
        images: Sequence/array of flattened images, one per row.
        labels: Sequence/array of labels aligned with ``images``.
        out_dir: Existing directory to write the PNG files into.
        count: Number of leading images to save (clamped to ``len(images)``).
        image_shape: (rows, cols) used to un-flatten each image.
    """
    # Imported here so the parsing helpers are usable without matplotlib.
    import matplotlib.pyplot as plt

    for i in range(min(count, len(images))):
        img = np.asarray(images[i])
        print(img)
        outfile = str(i) + "_" + str(labels[i]) + ".png"
        fig = plt.figure()
        plt.imshow(img.reshape(image_shape), cmap="binary")  # black-and-white display
        fig.savefig(os.path.join(out_dir, outfile))
        # Close each figure so saving many images does not accumulate
        # open figures (and their memory) inside pyplot.
        plt.close(fig)
        print("save" + str(i) + "张")


def main():
    """Load the MNIST training images/labels and save the first image."""
    with open(TRAIN_IMAGES_FILE, "rb") as f:
        train_images = parse_idx_images(f.read())

    with open(TRAIN_LABELS_FILE, "rb") as f:
        train_labels = parse_idx_labels(f.read())
    print(np.shape(train_labels))

    # Output directories for the rendered images; makedirs also creates
    # missing parent directories and tolerates existing ones.
    os.makedirs(path_trainset, exist_ok=True)
    os.makedirs(path_testset, exist_ok=True)

    save_images(train_images, train_labels, path_trainset, count=1)


if __name__ == "__main__":
    main()