码迷,mamicode.com
首页 > 其他好文 > 详细

用LSTM分类 MNIST

时间:2017-09-03 11:16:00      阅读:320      评论:0      收藏:0      [点我收藏+]

标签:学习   input   网络模型   线图   技术   add   size   算法   util   

 

    LSTM是RNN的一种算法, 在序列分类中比较有用。常用于语音识别,文字处理(NLP)等领域。 

等同于VGG等CNN模型在在图像识别领域的位置。  本篇文章是叙述LSTM 在MNIST 手写图中的使用。

用来给初步学习RNN的一个范例,便于学习和理解LSTM .

    先把工作流程图贴一下

技术分享

 

代码片段

   数据准备

def makedata():
    img_rows, img_cols = 28, 28

    mnist = fetch_mldata("MNIST original")
    # rescale the data, use the traditional train/test split
    X_1D, y_int = mnist.data / 255., mnist.target
    y = np_utils.to_categorical(y_int, num_classes=10)

    X = X_1D.reshape(X_1D.shape[0], img_rows, img_cols )

    input_shape = (img_rows, img_cols, 1)
    x_train, x_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]

    return X, y
    pass

下载 MNIST数据, 进行归一化  mnist.data / 255, 把数据[7000,784 ] 转成[ 70000,28,28] 

 

构建模型:

def buildlstm():

    import numpy as np

    data_dim = 28
    timesteps = 28
    num_classes = 10

    # expected input data shape: (batch_size, timesteps, data_dim)
    model = Sequential()
    model.add(LSTM(32, return_sequences=True,   input_shape=(timesteps, data_dim+14)))   
    model.add(LSTM(32, return_sequences=True))  
    model.add(LSTM(32))  
    model.add(Dense(10, activation=softmax))

    model.compile(loss=categorical_crossentropy,
                  optimizer=rmsprop,
                  metrics=[accuracy])
    print model.summary()
    return  model
    pass

基础参数: data_dim, timesteps, num_classes   分别为 28,28, 10
网络层级 :    LSTM ----》LSTM ----》LSTM ----》Dense
注意点: input_shape=(timesteps, data_dim+14))   此处 应该为  data_dim , data_dim+14是我做第二个试验使用。
网络理解: RNN是用前一部分数据对当前数据的影响,并共同作用于最后结果。 用基础的深度神经网络(只有Dense层),是把MNIST一个图形,
提取成784个像素数据,把784个数据扔给神经网络,784个数据是同等的概念。 训练出权重来确定最终的分类值。   

RNN 之于MNIST, 是把MNIST 分成 28x28 数据。可以理解为用一个激光扫描一个图片,扫成28个(行)数据, 每行为28个像素。 站在时间序列
的角度,其实图片没有序列概念。但是我们可以这样理解, 每一行于下一行是有位置关系的,不能进行顺序变化。 比如一个手写 “7”字, 如果把28行
的上下行顺序打乱, 那么7 上面的一横就可能在中间位置,也可能在下面的位置。  这样,最终的结果就不应该是 7 .  
所以MNIST 的 28x28可以理解为 有时序关系的数据。 

训练预测:

def runTrain(model, x_train, x_test, y_train, y_test):
    model.fit(x_train, y_train,  batch_size= nbatch_size, epochs= nEpoches)
    score = model.evaluate(x_test, y_test, batch_size=nbatch_size)
    print evaluate score:, score
    pass

这部分应该没什么好说的

主程序:

def test():

    X,y = makedata2()
    x_train, x_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]
    model = buildlstm()
    runTrain(model, x_train, x_test, y_train, y_test )
    pass


运行结果

结构:
Layer (type)                 Output Shape              Param #
=================================================================
lstm_1 (LSTM)                (None, 28, 32)            7808
_________________________________________________________________
lstm_2 (LSTM)                (None, 28, 32)            8320
_________________________________________________________________
lstm_3 (LSTM)                (None, 32)                8320
_________________________________________________________________
dense_1 (Dense)              (None, 10)                330
=================================================================
Total params: 24,778
Trainable params: 24,778
Non-trainable params: 0
_________________________________________________________________


结果:
base    lstm for mnist
acc : 98.56%

结果2:
把数据最后增加 50%  的 0 , (dim X 0.5)
acc : 98.39%
结果基本上 与原数据一致

 

该实验证明两个结论:
1.  LSTM可用于图形识别
2.  在数据中 每行28个基础像素后面 + 14 个空白(0)的元素,不影分类识别。 


写在最后:  本实验的目的是为了理解RNN(LSTM),  只有理解了才能很好的使用。 本文章的目的是为记录和分享。
再说下 RNN在其它领域的应用。  比如在语音识别领域,一个音谱,识别成一个单词(词语),可以理解成一个
竖向扫描的MNIST ,   一个股票的K线图,也可以理解一个竖向扫描的MNIST。  还有其它领域,可以归纳递推。 
入门之后, 如何在自己的领域,再深入(构建复杂模型,优化数据的处理),提高网络模型的识别准确,那需要
见仁见智的。 

代码文件链接:

源码下载

 
有对 金融程序化 和 深度学习结合有兴趣的可以加群 , 个人群: 杭州程序化交易群  375129936

用LSTM分类 MNIST

标签:学习   input   网络模型   线图   技术   add   size   算法   util   

原文地址:http://www.cnblogs.com/xiaoxuebiye/p/7468732.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!