码迷,mamicode.com
首页 > 其他好文 > 详细

Hopfield神经网络实现污染字体的识别

时间:2016-06-05 23:12:01      阅读:364      评论:0      收藏:0      [点我收藏+]

标签:

这个网络内部使用Hebb学习规则来计算神经元之间的连接权重。

 

下面贴出两段代码：第一段是测试程序，用数字0、1、2训练网络，再让网络识别一个被污染的“2”；第二段是自定义的学习规则类StandHopfieldLearning。

 

  

package geym.nn.hopfiled;

import java.util.Arrays;

import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.nnet.Hopfield;
import org.neuroph.nnet.comp.neuron.InputOutputNeuron;
import org.neuroph.nnet.learning.HopfieldLearning;
import org.neuroph.util.NeuronProperties;
import org.neuroph.util.TransferFunctionType;

/**
 * Recognizes the digits 0, 1 and 2 with a fully connected Hopfield network.
 *
 * <p>Each digit is a 5-row by 6-column bitmap (30 pixels) written as 0/1 values and converted to
 * bipolar -1/+1 form by {@link #format(double[])}. The network is trained on the three clean
 * digits and is then asked to restore a "2" that has one corrupted pixel.
 *
 * @author Administrator
 */
public class HopfieldSample2 {

    /** Pixels per pattern (5 rows x 6 columns); also the number of neurons in the network. */
    private static final int PATTERN_SIZE = 30;

    /** Upper bound on recall passes, so a network that never settles cannot hang the demo. */
    private static final int MAX_ITERATIONS = 100;

    /** Clean digit "0" as a 0/1 bitmap. Never pass directly to format(); use bipolar(). */
    private static final double[] DIGIT_0 = {
            0,1,1,1,1,0,
            1,0,0,0,0,1,
            1,0,0,0,0,1,
            1,0,0,0,0,1,
            0,1,1,1,1,0};

    /** Clean digit "1" as a 0/1 bitmap. */
    private static final double[] DIGIT_1 = {
            0,0,0,0,0,0,
            1,0,0,0,0,0,
            1,1,1,1,1,1,
            0,0,0,0,0,0,
            0,0,0,0,0,0};

    /** Clean digit "2" as a 0/1 bitmap; used for training and as the expected recall result. */
    private static final double[] DIGIT_2 = {
            1,0,0,0,0,0,
            1,0,0,1,1,1,
            1,0,0,1,0,1,
            1,0,0,1,0,1,
            0,1,1,0,0,1};

    /**
     * Converts a 0/1 pattern to bipolar form by replacing every 0 with -1.
     * Values other than 0 are left unchanged.
     *
     * @param data pattern to convert; it is modified in place
     * @return the same array instance, for call chaining
     */
    public static double[] format(double[] data) {
        for (int i = 0; i < data.length; i++) {
            if (data[i] == 0) {
                data[i] = -1;
            }
        }
        return data;
    }

    /** Returns a bipolar copy of {@code raw}; the shared constant itself is not modified. */
    private static double[] bipolar(double[] raw) {
        return format(raw.clone());
    }

    /**
     * Feeds {@code input} to the network and recalculates it until two consecutive outputs are
     * equal.
     *
     * @param network trained Hopfield network
     * @param input   bipolar pattern to start recall from
     * @return the settled output, or the last output if it did not settle within MAX_ITERATIONS
     */
    private static double[] recall(Hopfield network, double[] input) {
        network.setInput(input);
        double[] previous = null;
        for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
            network.calculate();
            // Copy the output. getOutput() may hand back the network's internal buffer; in that
            // case 'previous' and 'current' would be the same array and the comparison below
            // would succeed after two passes even if the state is still changing.
            double[] current = network.getOutput().clone();
            if (previous != null && Arrays.equals(current, previous)) {
                return current;
            }
            previous = current;
        }
        System.out.println("Warning: network did not settle within " + MAX_ITERATIONS + " iterations");
        return previous;
    }

    public static void main(String[] args) {
        NeuronProperties neuronProperties = new NeuronProperties();
        neuronProperties.setProperty("neuronType", InputOutputNeuron.class);
        neuronProperties.setProperty("bias", Double.valueOf(0.0D));
        neuronProperties.setProperty("transferFunction", TransferFunctionType.STEP);
        // Bipolar outputs, matching the -1/+1 encoding produced by format().
        neuronProperties.setProperty("transferFunction.yHigh", Double.valueOf(1.0D));
        neuronProperties.setProperty("transferFunction.yLow", Double.valueOf(-1.0D));

        // Training set: the clean digits 0, 1 and 2 as bipolar 5x6 bitmaps.
        DataSet trainingSet = new DataSet(PATTERN_SIZE);
        trainingSet.addRow(new DataSetRow(bipolar(DIGIT_0)));
        trainingSet.addRow(new DataSetRow(bipolar(DIGIT_1)));
        trainingSet.addRow(new DataSetRow(bipolar(DIGIT_2)));

        // Create the Hopfield network and learn the training set.
        // NOTE(review): in the accompanying listing StandHopfieldLearning is declared in package
        // com.cgjr.com.hopfield, not in this package; it must be imported or moved here.
        Hopfield myHopfield = new Hopfield(PATTERN_SIZE, neuronProperties);
        myHopfield.setLearningRule(new StandHopfieldLearning());
        myHopfield.learn(trainingSet);

        System.out.println("Testing network");

        // A "2" with one corrupted pixel: row 4, last column flipped from 1 to 0.
        // It is deliberately NOT added to the training set.
        DataSetRow h = new DataSetRow(format(new double[] {
                1,0,0,0,0,0,
                1,0,0,1,1,1,
                1,0,0,1,0,1,
                1,0,0,1,0,0,
                0,1,1,0,0,1}));

        double[] networkOutput = recall(myHopfield, h.getInput());

        System.out.print("Input: " + Arrays.toString(h.getInput()));
        System.out.println(" Output: " + Arrays.toString(networkOutput));

        // true when the network restored the clean "2".
        System.out.println(Arrays.equals(bipolar(DIGIT_2), networkOutput));
    }

}

 

 

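下面是StandHopfieldLearning类的实现。其中带有“Hebb学习规则”注释的那一行（原文中标红）就是Hebb学习规则：神经元 $i$ 与 $j$ 之间的权重等于所有 $M$ 个训练样本中第 $i$ 个分量与第 $j$ 个分量乘积之和，即 $w_{ij}=w_{ji}=\sum_{k=1}^{M} x_i^{(k)} x_j^{(k)}$（$i \neq j$，不设置自连接）。注意：两段代码的包名不同（geym.nn.hopfiled 与 com.cgjr.com.hopfield），编译时需要在HopfieldSample2中导入该类，或把两个类放在同一个包中：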

 

  

package com.cgjr.com.hopfield;

import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.LearningRule;

/**
 * Learning algorithm for the Hopfield neural network.
 * 
 * @author Zoran Sevarac <sevarac@gmail.com>
 */
public class StandHopfieldLearning extends LearningRule {
    
    /**
     * The class fingerprint that is set to indicate serialization
     * compatibility with a previous version of the class.
     */    
    private static final long serialVersionUID = 1L;

    /**
     * Creates new HopfieldLearning
     */
    public StandHopfieldLearning() {
        super();
    }


    /**
     * Calculates weights for the hopfield net to learn the specified training
     * set
     * 
     * @param trainingSet
     *            training set to learn
     */
    public void learn(DataSet trainingSet) {
        int M = trainingSet.size();
        int N = neuralNetwork.getLayerAt(0).getNeuronsCount();
        Layer hopfieldLayer = neuralNetwork.getLayerAt(0);

        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                if (j == i)
                    continue;
                Neuron ni = hopfieldLayer.getNeuronAt(i);
                Neuron nj = hopfieldLayer.getNeuronAt(j);
                Connection cij = nj.getConnectionFrom(ni);
                Connection cji = ni.getConnectionFrom(nj);
                
                double wij=0;
                for(int k = 0;k < M;k++){
                    DataSetRow row=trainingSet.getRowAt(k);
                    double[] inputs=row.getInput();
                    wij+=inputs[i]*inputs[j];//Hebb学习规则
                }
                cij.getWeight().setValue(wij);
                cji.getWeight().setValue(wij);
            }// j
        } // i

    }

}

 

Hopfield神经网络实现污染字体的识别

标签:

原文地址:http://www.cnblogs.com/beigongfengchen/p/5562032.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!