标签:
0-9 手写数字识别：基于 MNIST 数据集的识别。
具体代码（包括 MNIST 数据文件）见附件。
参考资料：Tom Mitchell《机器学习》（Machine Learning）中讲解 BP（反向传播）算法的一章。
# coding:utf-8 # 没考虑大小端 import struct import numpy def loadImages(filename): try: f = open(filename,‘rb‘) except Exception as instance: print type(instance) exit() allImage = [] bins = f.read() index = 0 magicNum,imageNum,rowNum,colNum = struct.unpack_from(‘>IIII‘,bins,index) index = index + struct.calcsize(‘>IIII‘) assert 2051 == magicNum,‘dataset damaged | little endian‘ for ct in xrange(imageNum): allImage.append(struct.unpack_from(‘>784B‘,bins,index)) index = index + struct.calcsize(‘>784B‘) return numpy.array(allImage,dtype=‘float32‘) def loadLabels(filename): try: f = open(filename,‘rb‘) except Exception as instance: print type(instance) exit() allLabels = [] bins = f.read() index = 0 magicNum,labelNum = struct.unpack_from(‘>II‘,bins,index) index = index + struct.calcsize(‘>II‘) assert 2049 == magicNum,‘dataset damaged | little endian‘ for ct in xrange(labelNum): allLabels.append(struct.unpack_from(‘B‘,bins,index)) index = index + struct.calcsize(‘B‘) return numpy.array(allLabels,dtype=‘float32‘) if ‘__main__‘ == __name__: images = loadImages(‘t10k-images.idx3-ubyte‘) labels = loadLabels(‘t10k-labels.idx1-ubyte‘) import matplotlib.pyplot as plt for x in range(3): plt.figure() shown = images[x].reshape(28,28) # shown 28*28 numpy matrix plt.imshow(shown,cmap=‘gray‘) plt.title(str(labels[x])) plt.show()
# -*- coding: utf-8 -*- import dataLoad import numpy as np import sys import warnings def bp(trainSet,eta=0.01,nin=None,nhid=None,nout=None,iterNum = 10): ‘‘‘[(instance,label),((784,1)array,(10,1)array)……]‘‘‘ Wkh = (np.random.rand(nout,nhid)-0.5) / 10.0 Whi = (np.random.rand(nhid,nin )-0.5) / 10.0 iteration = 0 er = 1 ‘‘‘iteration‘‘‘ while iteration < iterNum and er > 0.04: print ‘iteration=‘,iteration er = testAnn((Whi,Wkh)) iteration += 1 for (x,label) in trainSet: # 最大最小归一化 x = (x - x.min())/x.max()-x.min() #前向 neth = np.dot(Whi,x) # nhid*nin nin*1 -> nhid*1 oh = sigmoid(neth) # nhid*1 netk = np.dot(Wkh,oh) #nout*nhid nhid*1 -> nout*1 ok = sigmoid(netk) # nout*1 #求误差 dk = ok*(1-ok)*(label-ok) dh = oh*(1-oh)*np.dot(Wkh.T,dk) #(nhid,1) = (nout,nhid).T * (nout,1) #更新权值矩阵 Wkh = Wkh + eta * np.dot(dk,oh.T) #nout*nhid + nout*1*1*hid Whi = Whi + eta * np.dot(dh,x.T) #nhid*nin + nhid*1*1*nin print ‘iteration over‘ return Wkh,Whi def testAnn(model): err = 0 for i in range(len(testLabels)): res = fit(model,testImages[i].reshape(784,1)) if i > 9990: print ‘\t ‘,int(testLabels[i][0]),‘was recognized as‘,res[1] if testLabels[i][0] != res[1]: err += 1 errorRate = float(err)/float(len(testLabels)) print ‘error rate‘,errorRate,‘\n‘ return errorRate def fit(model,Image): Whi,Wkh = model ok = list(sigmoid(sigmoid(Image.T.dot(Whi.T)).dot(Wkh.T))[0]) return ok,ok.index(max(ok)) def sigmoid(y): # 讨厌的溢出警告 warnings.filterwarnings("ignore") return 1/(1+np.exp(-y)) if ‘__main__‘ == __name__: np.random.seed(207) trainImages = dataLoad.loadImages(‘train-images.idx3-ubyte‘) trainLabels = dataLoad.loadLabels(‘train-labels.idx1-ubyte‘) testImages = dataLoad.loadImages(‘t10k-images.idx3-ubyte‘) testLabels = dataLoad.loadLabels(‘t10k-labels.idx1-ubyte‘) dataSet = [] for i in range(len(trainLabels)): tmp = np.zeros((10,1),dtype=‘float32‘) tmp[int(trainLabels[i]),0] = 1 dataSet.append((trainImages[i].reshape(784,1),tmp)) bp(trainSet=dataSet,eta=0.05,nin=784,nhid=20,nout=10,iterNum=20)
标签:
原文地址:http://www.cnblogs.com/myli-aslp/p/4312850.html