码迷,mamicode.com
首页 > 编程语言 > 详细

LSTM java 实现

时间:2016-11-11 17:20:57      阅读:310      评论:0      收藏:0      [点我收藏+]

标签:comm   ast   keyword   listen   开始   arguments   min   特性   article   

由于实验室事情缘故,需要将Python写的神经网络转成Java版本的,但是python中的numpy等啥包也不知道在Java里面对应的是什么工具,所以索性直接寻找一个现成可用的Java神经网络框架,于是就找到了JOONE,JOONE是一个神经网络的开源框架,使用的是BP算法进行迭代计算参数,使用起来比较方便也比较实用,下面介绍一下JOONE的一些使用方法。

 

JOONE需要使用一些外部的依赖包,这在官方网站上有,也可以在这里下载。将所需的包引入工程之后,就可以进行编码实现了。

 

首先看下完整的程序,这个是上面那个超链接给出的程序,应该是官方给出的一个示例吧,因为好多文章都用这个,这其实是神经网络训练一个异或计算器:

 

[java] view plain copy
 
  1. import org.joone.engine.*;  
  2. import org.joone.engine.learning.*;  
  3. import org.joone.io.*;  
  4. import org.joone.net.*;  
  5.   
  6.   
  7. /* 
  8.  *  
  9.  * JOONE实现 
  10.  *  
  11.  * */  
  12. public class XOR_using_NeuralNet implements NeuralNetListener  
  13. {  
  14.     private NeuralNet nnet = null;  
  15.     private MemoryInputSynapse inputSynapse, desiredOutputSynapse;  
  16.     LinearLayer input;  
  17.     SigmoidLayer hidden, output;  
  18.     boolean singleThreadMode = true;  
  19.   
  20.     // XOR input  
  21.     private double[][] inputArray = new double[][]  
  22.     {  
  23.     { 0.0, 0.0 },  
  24.     { 0.0, 1.0 },  
  25.     { 1.0, 0.0 },  
  26.     { 1.0, 1.0 } };  
  27.   
  28.     // XOR desired output  
  29.     private double[][] desiredOutputArray = new double[][]  
  30.     {  
  31.     { 0.0 },  
  32.     { 1.0 },  
  33.     { 1.0 },  
  34.     { 0.0 } };  
  35.   
  36.     /** 
  37.      * @param args 
  38.      *            the command line arguments 
  39.      */  
  40.     public static void main(String args[])  
  41.     {  
  42.         XOR_using_NeuralNet xor = new XOR_using_NeuralNet();  
  43.   
  44.         xor.initNeuralNet();  
  45.         xor.train();  
  46.         xor.interrogate();  
  47.     }  
  48.   
  49.     /** 
  50.      * Method declaration 
  51.      */  
  52.     public void train()  
  53.     {  
  54.   
  55.         // set the inputs  
  56.         inputSynapse.setInputArray(inputArray);  
  57.         inputSynapse.setAdvancedColumnSelector(" 1,2 ");  
  58.         // set the desired outputs  
  59.         desiredOutputSynapse.setInputArray(desiredOutputArray);  
  60.         desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");  
  61.   
  62.         // get the monitor object to train or feed forward  
  63.         Monitor monitor = nnet.getMonitor();  
  64.   
  65.         // set the monitor parameters  
  66.         monitor.setLearningRate(0.8);  
  67.         monitor.setMomentum(0.3);  
  68.         monitor.setTrainingPatterns(inputArray.length);  
  69.         monitor.setTotCicles(5000);  
  70.         monitor.setLearning(true);  
  71.   
  72.         long initms = System.currentTimeMillis();  
  73.         // Run the network in single-thread, synchronized mode  
  74.         nnet.getMonitor().setSingleThreadMode(singleThreadMode);  
  75.         nnet.go(true);  
  76.         System.out.println(" Total time=  "  
  77.                 + (System.currentTimeMillis() - initms) + "  ms ");  
  78.     }  
  79.   
  80.     private void interrogate()  
  81.     {  
  82.   
  83.         double[][] inputArray = new double[][]  
  84.         {  
  85.         { 1.0, 1.0 } };  
  86.         // set the inputs  
  87.         inputSynapse.setInputArray(inputArray);  
  88.         inputSynapse.setAdvancedColumnSelector(" 1,2 ");  
  89.         Monitor monitor = nnet.getMonitor();  
  90.         monitor.setTrainingPatterns(4);  
  91.         monitor.setTotCicles(1);  
  92.         monitor.setLearning(false);  
  93.         MemoryOutputSynapse memOut = new MemoryOutputSynapse();  
  94.         // set the output synapse to write the output of the net  
  95.   
  96.         if (nnet != null)  
  97.         {  
  98.             nnet.addOutputSynapse(memOut);  
  99.             System.out.println(nnet.check());  
  100.             nnet.getMonitor().setSingleThreadMode(singleThreadMode);  
  101.             nnet.go();  
  102.   
  103.             for (int i = 0; i < 4; i++)  
  104.             {  
  105.                 double[] pattern = memOut.getNextPattern();  
  106.                 System.out.println(" Output pattern # " + (i + 1) + " = "  
  107.                         + pattern[0]);  
  108.             }  
  109.             System.out.println(" Interrogating Finished ");  
  110.         }  
  111.     }  
  112.   
  113.     /** 
  114.      * Method declaration 
  115.      */  
  116.     protected void initNeuralNet()  
  117.     {  
  118.   
  119.         // First create the three layers  
  120.         input = new LinearLayer();  
  121.         hidden = new SigmoidLayer();  
  122.         output = new SigmoidLayer();  
  123.   
  124.         // set the dimensions of the layers  
  125.         input.setRows(2);  
  126.         hidden.setRows(3);  
  127.         output.setRows(1);  
  128.   
  129.         input.setLayerName(" L.input ");  
  130.         hidden.setLayerName(" L.hidden ");  
  131.         output.setLayerName(" L.output ");  
  132.   
  133.         // Now create the two Synapses  
  134.         FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */  
  135.         FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */  
  136.   
  137.         // Connect the input layer whit the hidden layer  
  138.         input.addOutputSynapse(synapse_IH);  
  139.         hidden.addInputSynapse(synapse_IH);  
  140.   
  141.         // Connect the hidden layer whit the output layer  
  142.         hidden.addOutputSynapse(synapse_HO);  
  143.         output.addInputSynapse(synapse_HO);  
  144.   
  145.         // the input to the neural net  
  146.         inputSynapse = new MemoryInputSynapse();  
  147.   
  148.         input.addInputSynapse(inputSynapse);  
  149.   
  150.         // The Trainer and its desired output  
  151.         desiredOutputSynapse = new MemoryInputSynapse();  
  152.   
  153.         TeachingSynapse trainer = new TeachingSynapse();  
  154.   
  155.         trainer.setDesired(desiredOutputSynapse);  
  156.   
  157.         // Now we add this structure to a NeuralNet object  
  158.         nnet = new NeuralNet();  
  159.   
  160.         nnet.addLayer(input, NeuralNet.INPUT_LAYER);  
  161.         nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);  
  162.         nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);  
  163.         nnet.setTeacher(trainer);  
  164.         output.addOutputSynapse(trainer);  
  165.         nnet.addNeuralNetListener(this);  
  166.     }  
  167.   
  168.     public void cicleTerminated(NeuralNetEvent e)  
  169.     {  
  170.     }  
  171.   
  172.     public void errorChanged(NeuralNetEvent e)  
  173.     {  
  174.         Monitor mon = (Monitor) e.getSource();  
  175.         if (mon.getCurrentCicle() % 100 == 0)  
  176.             System.out.println(" Epoch:  "  
  177.                     + (mon.getTotCicles() - mon.getCurrentCicle()) + "  RMSE: "  
  178.                     + mon.getGlobalError());  
  179.     }  
  180.   
  181.     public void netStarted(NeuralNetEvent e)  
  182.     {  
  183.         Monitor mon = (Monitor) e.getSource();  
  184.         System.out.print(" Network started for  ");  
  185.         if (mon.isLearning())  
  186.             System.out.println(" training. ");  
  187.         else  
  188.             System.out.println(" interrogation. ");  
  189.     }  
  190.   
  191.     public void netStopped(NeuralNetEvent e)  
  192.     {  
  193.         Monitor mon = (Monitor) e.getSource();  
  194.         System.out.println(" Network stopped. Last RMSE= "  
  195.                 + mon.getGlobalError());  
  196.     }  
  197.   
  198.     public void netStoppedError(NeuralNetEvent e, String error)  
  199.     {  
  200.         System.out.println(" Network stopped due the following error:  "  
  201.                 + error);  
  202.     }  
  203.   
  204. }  

 

 

现在我会逐步解释上面的程序。

 【1】 从main方法开始说起,首先第一步新建一个对象:

[java] view plain copy
 
  1. XOR_using_NeuralNet xor = new XOR_using_NeuralNet();  

【2】然后初始化神经网络:

 

[java] view plain copy
 
  1. xor.initNeuralNet();  

初始化神经网络的方法中:

[java] view plain copy
 
  1. // First create the three layers  
  2.         input = new LinearLayer();  
  3.         hidden = new SigmoidLayer();  
  4.         output = new SigmoidLayer();  
  5.   
  6.         // set the dimensions of the layers  
  7.         input.setRows(2);  
  8.         hidden.setRows(3);  
  9.         output.setRows(1);  
  10.   
  11.         input.setLayerName(" L.input ");  
  12.         hidden.setLayerName(" L.hidden ");  
  13.         output.setLayerName(" L.output ");  

 

 

上面代码解释:

input=new LinearLayer()是新建一个输入层,因为神经网络的输入层并没有训练参数,所以使用的是线性层;

hidden = new SigmoidLayer();这里是新建一个隐含层,使用sigmoid函数作为激励函数,当然你也可以选择其他的激励函数,如softmax激励函数

output则是新建一个输出层

之后的三行代码是建立输入层、隐含层、输出层的神经元个数,这里表示输入层为2个神经元,隐含层是3个神经元,输出层是1个神经元

最后的三行代码是给每个输出层取一个名字。

[java] view plain copy
 
  1. // Now create the two Synapses  
  2.         FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */  
  3.         FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */  
  4.   
  5.         // Connect the input layer whit the hidden layer  
  6.         input.addOutputSynapse(synapse_IH);  
  7.         hidden.addInputSynapse(synapse_IH);  
  8.   
  9.         // Connect the hidden layer whit the output layer  
  10.         hidden.addOutputSynapse(synapse_HO);  
  11.         output.addInputSynapse(synapse_HO);  

 

上面代码解释:

 

上面代码的主要作用是将三个层连接起来,synapse_IH用来连接输入层和隐含层,synapse_HO用来连接隐含层和输出层

[java] view plain copy
 
  1. // the input to the neural net  
  2.         inputSynapse = new MemoryInputSynapse();  
  3.   
  4.         input.addInputSynapse(inputSynapse);  
  5.   
  6.         // The Trainer and its desired output  
  7.         desiredOutputSynapse = new MemoryInputSynapse();  
  8.   
  9.         TeachingSynapse trainer = new TeachingSynapse();  
  10.   
  11.         trainer.setDesired(desiredOutputSynapse);  

 

上面代码解释: 

 

上面的代码是在训练的时候指定输入层的数据和目的输出的数据,

 inputSynapse = new MemoryInputSynapse();这里指的是使用了从内存中输入数据的方法,指的是输入层输入数据,当然还有从文件输入的方法,这点在文章后面再谈。同理,desiredOutputSynapse = new MemoryInputSynapse();也是从内存中输入数据,指的是从输入层应该输出的数据

[java] view plain copy
 
  1. // Now we add this structure to a NeuralNet object  
  2.         nnet = new NeuralNet();  
  3.   
  4.         nnet.addLayer(input, NeuralNet.INPUT_LAYER);  
  5.         nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);  
  6.         nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);  
  7.         nnet.setTeacher(trainer);  
  8.         output.addOutputSynapse(trainer);  
  9.         nnet.addNeuralNetListener(this);  

上面代码解释:

 

这段代码指的是将之前初始化的构件连接成一个神经网络,NeuralNet是JOONE提供的类,主要是连接各个神经层,最后一个nnet.addNeuralNetListener(this);这个作用是对神经网络的训练过程进行监听,因为这个类实现了NeuralNetListener这个接口,这个接口有一些方法,可以实现观察神经网络训练过程,有助于参数调整。

【3】然后我们来看一下train这个方法:

[java] view plain copy
 
  1. inputSynapse.setInputArray(inputArray);  
  2.         inputSynapse.setAdvancedColumnSelector(" 1,2 ");  
  3.         // set the desired outputs  
  4.         desiredOutputSynapse.setInputArray(desiredOutputArray);  
  5.         desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");  

 

上面代码解释:

 

inputSynapse.setInputArray(inputArray);这个方法是初始化输入层数据,也就是指定输入层数据的内容,inputArray是程序中给定的二维数组,这也就是为什么之前初始化神经网络的时候使用的是MemoryInputSynapse,表示从内存中读取数据

inputSynapse.setAdvancedColumnSelector(" 1,2 ");这个表示的是输入层数据使用的是inputArray的前两列数据。

desiredOutputSynapse这个也同理

[java] view plain copy
 
  1. Monitor monitor = nnet.getMonitor();  
  2.   
  3.         // set the monitor parameters  
  4.         monitor.setLearningRate(0.8);  
  5.         monitor.setMomentum(0.3);  
  6.         monitor.setTrainingPatterns(inputArray.length);  
  7.         monitor.setTotCicles(5000);  
  8.         <span style="line-height: 1.5;">monitor.setLearning(true);  

 上面代码解释:

这个monitor类也是JOONE框架提供的,主要是用来调节神经网络的参数,monitor.setLearningRate(0.8);是用来设置神经网络训练的步长参数,步长越大,神经网络梯度下降的速度越快,monitor.setTrainingPatterns(inputArray.length);这个是设置神经网络的输入层的训练数据大小size,这里使用的是数组的长度;monitor.setTotCicles(5000);这个指的是设置迭代数目;monitor.setLearning(true);这个true表示是在训练过程。

[java] view plain copy
 
  1. nnet.getMonitor().setSingleThreadMode(singleThreadMode);  
  2.         nnet.go(true);  

上面代码解释:

 

nnet.getMonitor().setSingleThreadMode(singleThreadMode);这个指的是是不是使用多线程,但是我不太清楚这里的多线程指的是什么意思

nnet.go(true)表示的是开始训练。

【4】最后来看一下interrogate方法

[java] view plain copy
 
  1. double[][] inputArray = new double[][]  
  2.         {  
  3.         { 1.0, 1.0 } };  
  4.         // set the inputs  
  5.         inputSynapse.setInputArray(inputArray);  
  6.         inputSynapse.setAdvancedColumnSelector(" 1,2 ");  
  7.         Monitor monitor = nnet.getMonitor();  
  8.         monitor.setTrainingPatterns(4);  
  9.         monitor.setTotCicles(1);  
  10.         monitor.setLearning(false);  
  11.         MemoryOutputSynapse memOut = new MemoryOutputSynapse();  
  12.         // set the output synapse to write the output of the net  
  13.   
  14.         if (nnet != null)  
  15.         {  
  16.             nnet.addOutputSynapse(memOut);  
  17.             System.out.println(nnet.check());  
  18.             nnet.getMonitor().setSingleThreadMode(singleThreadMode);  
  19.             nnet.go();  
  20.   
  21.             for (int i = 0; i < 4; i++)  
  22.             {  
  23.                 double[] pattern = memOut.getNextPattern();  
  24.                 System.out.println(" Output pattern # " + (i + 1) + " = "  
  25.                         + pattern[0]);  
  26.             }  
  27.             System.out.println(" Interrogating Finished ");  
  28.         }  

 

这个方法相当于测试方法,这里的inputArray是测试数据, 注意这里需要设置monitor.setLearning(false);,因为这不是训练过程,并不需要学习,monitor.setTrainingPatterns(4);这个是指测试的数量,4表示有4个测试数据(虽然这里只有一个)。这里还给nnet添加了一个输出层数据对象,这个对象mmOut是初始测试结果,注意到之前我们初始化神经网络的时候并没有给输出层指定数据对象,因为那个时候我们在训练,而且指定了trainer作为目的输出。

 

 

接下来就是输出结果数据了,pattern的个数和输出层的神经元个数一样大,这里输出层神经元的个数是1,所以pattern大小为1.

 

【5】我们看一下测试结果:

 

[java] view plain copy
 
  1. Output pattern # 1 = 0.018303527517809233  

 

 

表示输出结果为0.01,根据sigmoid函数特性,我们得到的输出是0,和预期结果一致。如果输出层神经元个数大于1,那么输出值将会有多个,因为输出层结果是0|1离散值,所以我们取输出最大的那个神经元的输出值取为1,其他为0

 

 

 

【6】最后我们来看一下神经网络训练过程中的一些监听函数:

cicleTerminated:每个循环结束后输出的信息

errorChanged:神经网络错误率变化时候输出的信息

netStarted:神经网络开始运行的时候输出的信息

netStopped:神经网络停止的时候输出的信息

 

【7】好了,JOONE基本上内容就是这些。还有一些额外东西需要说明:

 

1,从文件中读取数据构建神经网络

2.如何保存训练好的神经网络到文件夹中,只要测试的时候直接load到内存中就行,而不用每次都需要训练。

 

 

【8】先看第一个问题:

从文件中读取数据:

文件的格式:

0;0;0

1;0;1

1;1;0

0;1;1

 

中间使用分号隔开,使用方法如下,也就是把上文的MemoryInputSynapse换成FileInputSynapse即可。

[java] view plain copy
 
  1. fileInputSynapse = new FileInputSynapse();  
  2. input.addInputSynapse(fileInputSynapse);  
  3. fileDisireOutputSynapse = new FileInputSynapse();  
  4. TeachingSynapse trainer = new TeachingSynapse();  
  5. trainer.setDesired(fileDisireOutputSynapse);  

 我们看下文件是如何输出数据的:

[java] view plain copy
 
  1. private File inputFile = new File(Constants.TRAIN_WORD_VEC_PATH);  
  2. fileInputSynapse.setInputFile(inputFile);  
  3. fileInputSynapse.setFirstCol(2);//使用文件的第2列到第3列作为输出层输入  
  4. fileInputSynapse.setLastCol(3);  

 

[java] view plain copy
 
  1. fileDisireOutputSynapse.setInputFile(inputFile);  
  2. fileDisireOutputSynapse.setFirstCol(1);//使用文件的第1列作为输出数据  
  3. fileDisireOutputSynapse.setLastCol(1);  

 

 其余的代码和上文的是一样的。

 

 

【9】然后看第二个问题:

如何保存神经网络

其实很简单,直接序列化nnet对象就行了,然后读取该对象就是java的反序列化,这个就不多做介绍了,比较简单。但是需要说明的是,保存神经网络的时机一定是在神经网络训练完毕后,可以使用下面代码:

[java] view plain copy
 
    1. public void netStopped(NeuralNetEvent e) {  
    2.         Monitor mon = (Monitor) e.getSource();  
    3.         try {  
    4.             if (mon.isLearning()) {  
    5.                 saveModel(nnet); //序列化对象  
    6.             }  
    7.         } catch (IOException ee) {  
    8.             // TODO Auto-generated catch block  
    9.             ee.printStackTrace();  
    10.         }  

LSTM java 实现

标签:comm   ast   keyword   listen   开始   arguments   min   特性   article   

原文地址:http://www.cnblogs.com/bob-wzb/p/6054884.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!