码迷,mamicode.com
首页 > 编程语言 > 详细

BP算法 python实现

时间:2015-03-04 12:47:23      阅读:560      评论:0      收藏:0      [点我收藏+]

标签:

0-9 数字识别,即 MNIST 数据集的识别。

完整代码和 MNIST 数据集见附件。

参考资料是 Tom Mitchell《机器学习》中介绍 BP(反向传播)算法的那一章。

 

技术分享
# coding:utf-8
# NOTE: multi-byte header fields are read as big-endian (">"); little-endian files are not handled.
import struct
import numpy

def loadImages(filename):
    """Load an IDX3 image file (e.g. the MNIST image sets) into a matrix.

    The file starts with four big-endian unsigned 32-bit integers (magic
    number, image count, rows, columns) followed by one unsigned byte per
    pixel.

    Args:
        filename: path of the uncompressed ``*.idx3-ubyte`` file.

    Returns:
        numpy.ndarray of dtype float32 with shape
        (imageNum, rowNum * colNum); each row is one flattened image holding
        the unsigned-byte pixel values exactly as stored in the file.

    Raises:
        SystemExit: if the file cannot be opened or read.
        AssertionError: if the magic number is not 2051.
        ValueError: if the file holds fewer pixel bytes than the header
            announces (raised by numpy.frombuffer).
    """
    try:
        # The context manager closes the handle even when read() fails.
        with open(filename, "rb") as f:
            bins = f.read()
    except (IOError, OSError) as instance:
        raise SystemExit("cannot read %s: %s: %s"
                         % (filename, type(instance).__name__, instance))

    headerFmt = ">IIII"
    magicNum, imageNum, rowNum, colNum = struct.unpack_from(headerFmt, bins, 0)

    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type is kept for backward compatibility.
    if magicNum != 2051:
        raise AssertionError("dataset damaged | little endian")

    # Take the image size from the header instead of assuming 28 * 28 = 784.
    pixelsPerImage = rowNum * colNum
    pixels = numpy.frombuffer(bins, dtype=numpy.uint8,
                              count=imageNum * pixelsPerImage,
                              offset=struct.calcsize(headerFmt))
    return pixels.reshape(imageNum, pixelsPerImage).astype("float32")

def loadLabels(filename):
    """Load an IDX1 label file (e.g. the MNIST label sets) into a column.

    The file starts with two big-endian unsigned 32-bit integers (magic
    number, label count) followed by one unsigned byte per label.

    Args:
        filename: path of the uncompressed ``*.idx1-ubyte`` file.

    Returns:
        numpy.ndarray of dtype float32 with shape (labelNum, 1); row i holds
        the label of sample i, so a single label is read as ``labels[i][0]``.

    Raises:
        SystemExit: if the file cannot be opened or read.
        AssertionError: if the magic number is not 2049.
        ValueError: if the file holds fewer label bytes than the header
            announces (raised by numpy.frombuffer).
    """
    try:
        # The context manager closes the handle even when read() fails.
        with open(filename, "rb") as f:
            bins = f.read()
    except (IOError, OSError) as instance:
        raise SystemExit("cannot read %s: %s: %s"
                         % (filename, type(instance).__name__, instance))

    headerFmt = ">II"
    magicNum, labelNum = struct.unpack_from(headerFmt, bins, 0)

    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type is kept for backward compatibility.
    if magicNum != 2049:
        raise AssertionError("dataset damaged | little endian")

    labels = numpy.frombuffer(bins, dtype=numpy.uint8, count=labelNum,
                              offset=struct.calcsize(headerFmt))
    # Keep the (labelNum, 1) layout that callers index as labels[i][0].
    return labels.reshape(labelNum, 1).astype("float32")
        
    
    

if __name__ == "__main__":
    # Manual smoke test: show the first three test images, each titled with
    # its label, to confirm that the two loaders agree with each other.
    images = loadImages("t10k-images.idx3-ubyte")
    labels = loadLabels("t10k-labels.idx1-ubyte")
    import matplotlib.pyplot as plt
    for x in range(3):
        plt.figure()
        # Each row is a flattened image; restore the 28 x 28 grid for display.
        shown = images[x].reshape(28, 28)
        plt.imshow(shown, cmap="gray")
        plt.title(str(labels[x]))
        plt.show()
     
     
     
     
     
     
     
     
     
     
读取MNIST
# -*- coding: utf-8 -*-
import dataLoad
import numpy as np
import sys
import warnings

def bp(trainSet,eta=0.01,nin=None,nhid=None,nout=None,iterNum = 10):
    """Train a one-hidden-layer sigmoid network by per-sample backpropagation.

    Args:
        trainSet: sequence of (instance, label) pairs,
            [(instance, label), ...]; the update rules below treat instance
            as an (nin, 1) column and label as an (nout, 1) column, e.g.
            (784, 1) and (10, 1) arrays.
        eta: learning rate applied to both weight updates.
        nin: number of input units.
        nhid: number of hidden units.
        nout: number of output units.
        iterNum: maximum number of passes over trainSet.

    Returns:
        Tuple (Wkh, Whi): hidden-to-output weights of shape (nout, nhid) and
        input-to-hidden weights of shape (nhid, nin).
        NOTE: this is the reverse of the (Whi, Wkh) order that testAnn() and
        fit() take as ``model``; swap the pair before passing it on.

    Side effects:
        Prints progress and calls testAnn() before every pass; training stops
        early once the error rate it reports is at most 0.04.
    """
    # Small random weights, uniform in [-0.05, 0.05).
    Wkh = (np.random.rand(nout, nhid) - 0.5) / 10.0
    Whi = (np.random.rand(nhid, nin) - 0.5) / 10.0
    iteration = 0
    er = 1
    while iteration < iterNum and er > 0.04:
        print("iteration= %d" % iteration)
        # Evaluated before the pass, so the weights produced by the final
        # pass are returned without being scored.
        er = testAnn((Whi, Wkh))
        iteration += 1
        for (x, label) in trainSet:
            # Min-max normalisation to [0, 1]. The previous expression,
            # (x - min) / max - min, was only right when min == 0 and produced
            # NaN for an all-zero sample; a constant sample now maps to zeros.
            lo = x.min()
            span = x.max() - lo
            x = x - lo
            if span > 0:
                x = x / span
            # Forward pass.
            neth = np.dot(Whi, x)    # (nhid, nin) . (nin, 1) -> (nhid, 1)
            oh = sigmoid(neth)       # (nhid, 1)
            netk = np.dot(Wkh, oh)   # (nout, nhid) . (nhid, 1) -> (nout, 1)
            ok = sigmoid(netk)       # (nout, 1)
            # Error terms; dh must use Wkh from before this sample's update.
            dk = ok * (1 - ok) * (label - ok)
            dh = oh * (1 - oh) * np.dot(Wkh.T, dk)   # (nhid, nout) . (nout, 1)
            # Weight updates.
            Wkh = Wkh + eta * np.dot(dk, oh.T)   # (nout, 1) . (1, nhid)
            Whi = Whi + eta * np.dot(dh, x.T)    # (nhid, 1) . (1, nin)
    print("iteration over")
    return Wkh, Whi

def testAnn(model):
    """Return the misclassification rate of ``model`` on the test set.

    Reads the module-level ``testImages`` and ``testLabels`` (set up by the
    script section of this file) and classifies every image with fit().

    Args:
        model: (Whi, Wkh) weight pair as expected by fit().

    Returns:
        float: fraction of test samples whose predicted digit differs from
        its label.

    Side effects:
        Prints the prediction for the last nine samples and the error rate.
    """
    total = len(testLabels)
    err = 0
    for i in range(total):
        # reshape(-1, 1) turns the flat image into a column of any length.
        res = fit(model, testImages[i].reshape(-1, 1))
        # Sample output for the tail of the set; equals ``i > 9990`` on the
        # 10,000-image MNIST test set.
        if i > total - 10:
            print("\t %d was recognized as %d" % (int(testLabels[i][0]), res[1]))
        if testLabels[i][0] != res[1]:
            err += 1
    errorRate = float(err) / float(total)
    print("error rate %s \n" % errorRate)
    return errorRate

def fit(model,Image):
    """Classify one image with a trained network.

    Args:
        model: (Whi, Wkh) pair: input-to-hidden weights of shape (nhid, nin)
            and hidden-to-output weights of shape (nout, nhid).
        Image: (nin, 1) column of pixel values.

    Returns:
        Tuple (ok, index): ``ok`` is the list of output activations and
        ``index`` is the position of the largest one (the first on a tie).
    """
    Whi, Wkh = model
    # bp() scales every training sample by its own min/max before the forward
    # pass; apply the same scaling here so inference sees inputs in the range
    # the weights were trained on. A constant image maps to zeros.
    lo = Image.min()
    span = Image.max() - lo
    scaled = Image - lo
    if span > 0:
        scaled = scaled / span
    oh = sigmoid(np.dot(scaled.T, Whi.T))       # (1, nin) . (nin, nhid) -> (1, nhid)
    ok = list(sigmoid(np.dot(oh, Wkh.T))[0])    # (1, nhid) . (nhid, nout) -> nout values
    return ok, ok.index(max(ok))

def sigmoid(y):
    """Return the logistic function 1 / (1 + exp(-y)), element-wise.

    Args:
        y: a number or numpy array.

    Returns:
        Value(s) in [0, 1] with the same shape as ``y``.
    """
    # exp(-y) overflows to inf for very negative y, and 1 / (1 + inf) == 0 is
    # still the right answer. Silence only that overflow, and only inside this
    # call, instead of disabling every warning for the whole process.
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-y))

if __name__ == "__main__":
    # Fixed seed so the random initial weights are reproducible.
    np.random.seed(207)
    trainImages = dataLoad.loadImages("train-images.idx3-ubyte")
    trainLabels = dataLoad.loadLabels("train-labels.idx1-ubyte")
    # testImages / testLabels must stay module-level names: testAnn() reads
    # them as globals.
    testImages = dataLoad.loadImages("t10k-images.idx3-ubyte")
    testLabels = dataLoad.loadLabels("t10k-labels.idx1-ubyte")
    dataSet = []
    for image, label in zip(trainImages, trainLabels):
        # One-hot (10, 1) target column for the digit.
        tmp = np.zeros((10, 1), dtype="float32")
        # Each label row holds one value (testAnn reads it as labels[i][0]);
        # index it explicitly, because int() on a 1-element array is
        # deprecated in NumPy.
        tmp[int(label[0]), 0] = 1
        dataSet.append((image.reshape(784, 1), tmp))
    bp(trainSet=dataSet, eta=0.05, nin=784, nhid=20, nout=10, iterNum=20)
  

 

BP算法 python实现

标签:

原文地址:http://www.cnblogs.com/myli-aslp/p/4312850.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!