码迷,mamicode.com
首页 > 编程语言 > 详细

用java写bp神经网络(二)

时间:2014-11-27 16:03:45      阅读:191      评论:0      收藏:0      [点我收藏+]

标签:des   style   blog   io   ar   color   os   sp   java   

接上篇。

Net和Propagation具备后,我们就可以训练了。训练师要做的事情就是,怎么把一大批样本分成小批训练,然后把小批的结果合并成完整的结果(批量/增量);什么时候调用学习师根据训练的结果进行学习,然后改进网络的权重和状态;什么时候决定训练结束。

那么这两位老师长得什么样子,又是怎么做到这些的呢?

/**
 * A trainer drives the training of a network on the data supplied by a {@link DataProvider}.
 */
public interface Trainer {
	/**
	 * Trains {@code net} using the inputs and targets obtained from {@code provider}.
	 *
	 * @param net      the network to train
	 * @param provider the source of training inputs and targets
	 */
	void train(Net net, DataProvider provider);
}

/**
 * A learner adjusts a network's parameters from the outcome of a training pass.
 */
public interface Learner {
	/**
	 * Updates the weights of {@code net} based on {@code trainResult}.
	 *
	 * @param net         the network whose parameters are adjusted
	 * @param trainResult the result of the training pass (gradients, cost, ...)
	 */
	void learn(Net net, TrainResult trainResult);
}

 所谓Trainer即是给定数据,对指定网络进行训练;所谓Learner即是给定训练结果,然后对指定网络进行权重调整。

下面给出这两个接口的简单实现。

Trainer

CommonTrainer实现了简单的批量训练功能:每一轮(epoch)先把样本分批,并行地做前向与反向传播;再把各批梯度按样本数加权平均,交给Learner更新网络;达到给定的迭代轮数后停止。代码示例如下。

/**
 * Simple batch trainer.
 *
 * <p>Each epoch splits the data into batches, runs forward and backward propagation on the
 * batches in parallel, averages the gradients over all samples, and hands the result to a
 * {@link Learner} that updates the network. Training stops after a fixed number of epochs.
 */
public class CommonTrainer implements Trainer {
	/** Number of epochs to run (field name kept for compatibility). */
	int ecophs;
	/** Applies the averaged gradients to the network after each epoch. */
	Learner learner;
	/** Mean cost of each epoch, appended when it is non-null. */
	List<Double> costs = new ArrayList<>();
	/** Mean accuracy of each epoch, appended when it is non-null. */
	List<Double> accuracys = new ArrayList<>();
	/** Batch size; presumably read by getBatches (not shown here) — confirm. */
	int batchSize = 1;

	/**
	 * Creates a trainer.
	 *
	 * @param ecophs  number of epochs to run
	 * @param learner learner used to update the network; when {@code null}, a
	 *                {@link MomentAdaptLearner} with moment 0.9 and eta 0.01 is used
	 */
	public CommonTrainer(int ecophs, Learner learner) {
		super();
		this.ecophs = ecophs;
		// MomentAdaptLearner only exposes a (moment, eta) constructor; pass its default values.
		this.learner = learner == null ? new MomentAdaptLearner(0.9, 0.01) : learner;
	}

	/**
	 * Creates a trainer with an explicit batch size.
	 *
	 * @param ecophs    number of epochs to run
	 * @param learner   learner used to update the network ({@code null} selects the default)
	 * @param batchSize batch size
	 */
	public CommonTrainer(int ecophs, Learner learner, int batchSize) {
		this(ecophs, learner);
		this.batchSize = batchSize;
	}

	/**
	 * Runs a single epoch.
	 *
	 * <p>The epoch proceeds in four steps:
	 * <ol>
	 * <li>Forward and backward propagation run on every batch in parallel.</li>
	 * <li>Per-sample cost, accuracy and input deltas are collected into one aggregated
	 * result.</li>
	 * <li>The gradients are averaged over all samples and passed to the learner.</li>
	 * <li>The epoch's mean cost and accuracy are recorded and printed.</li>
	 * </ol>
	 *
	 * @param net      the network to train
	 * @param provider source of the input and target matrices (samples are columns)
	 */
	public void trainOne(final Net net, DataProvider provider) {
		final Propagation propagation = new Propagation(net);
		DoubleMatrix input = provider.getInput();
		DoubleMatrix target = provider.getTarget();
		final int allLen = target.columns;
		final int[] nodesNum = net.getNodesNum();
		final int layersNum = net.getLayersNum();

		// NOTE(review): getBatches is not defined in this class as shown; it presumably splits
		// the matrix column-wise into batches of batchSize samples — confirm.
		List<DoubleMatrix> inputBatches = this.getBatches(input);
		final List<DoubleMatrix> targetBatches = this.getBatches(target);

		// Used below as the exclusive end column of each batch within the full data set.
		final List<Integer> batchLen = MatrixUtil.getEndPosition(targetBatches);

		// Aggregated result for the whole data set; every batch fills in its own columns.
		final BackwardResult backwardResult = new BackwardResult(net, allLen);

		// Train the batches in parallel.
		Parallel.For(inputBatches, new Parallel.Operation<DoubleMatrix>() {
			@Override
			public void perform(int index, DoubleMatrix subInput) {
				ForwardResult subResult = propagation.forward(subInput);
				DoubleMatrix subTarget = targetBatches.get(index);
				BackwardResult backResult = propagation.backward(subTarget,
						subResult);

				DoubleMatrix cost = backwardResult.cost;
				DoubleMatrix accuracy = backwardResult.accuracy;
				DoubleMatrix inputDeltas = backwardResult.getInputDelta();

				// Columns [start, end] of the aggregated result belong to this batch.
				int start = index == 0 ? 0 : batchLen.get(index - 1);
				int end = batchLen.get(index) - 1;
				int[] cIndexs = ArraysHelper.makeArray(start, end);

				// Batches write disjoint column ranges here, so these puts do not overlap.
				cost.put(cIndexs, backResult.cost);

				if (accuracy != null) {
					accuracy.put(cIndexs, backResult.accuracy);
				}

				inputDeltas.put(ArraysHelper.makeArray(0, nodesNum[0] - 1),
						cIndexs, backResult.getInputDelta());

				// Scale each batch's gradients by its sample count (they are presumably
				// batch means) so that dividing the sum by allLen yields the overall mean.
				int subLen = backResult.cost.columns;
				for (int i = 0; i < layersNum; i++) {
					DoubleMatrix subGradients = backResult.gradients.get(i)
							.muli(subLen);
					DoubleMatrix subBiasGradients = backResult.biasGradients
							.get(i).muli(subLen);

					// Every batch accumulates into the same shared matrices, and addi is a
					// non-atomic read-modify-write, so serialize the accumulation per matrix.
					DoubleMatrix gradients = backwardResult.gradients.get(i);
					synchronized (gradients) {
						gradients.addi(subGradients);
					}
					DoubleMatrix biasGradients = backwardResult.biasGradients
							.get(i);
					synchronized (biasGradients) {
						biasGradients.addi(subBiasGradients);
					}
				}
			}
		});

		// Convert the weighted sums into means over all samples.
		for (DoubleMatrix gradient : backwardResult.gradients) {
			gradient.divi(allLen);
		}
		for (DoubleMatrix gradient : backwardResult.biasGradients) {
			gradient.divi(allLen);
		}

		TrainResult trainResult = new TrainResult(null, backwardResult);

		learner.learn(net, trainResult);

		Double cost = backwardResult.getMeanCost();
		Double accuracy = backwardResult.getMeanAccuracy();
		if (cost != null) {
			costs.add(cost);
		}
		if (accuracy != null) {
			accuracys.add(accuracy);
		}

		System.out.println(cost);
		System.out.println(accuracy);
	}

	/**
	 * Trains {@code net} for {@link #ecophs} epochs by calling {@link #trainOne} repeatedly.
	 *
	 * @param net      the network to train
	 * @param provider source of the input and target matrices
	 */
	@Override
	public void train(Net net, DataProvider provider) {
		for (int i = 0; i < this.ecophs; i++) {
			this.trainOne(net, provider);
		}
	}
}

 

Learner

Learner是具体的调整算法,当梯度计算出来后,它负责对网络权重进行调整。调整算法的选择直接影响着网络收敛的快慢。本文的实现采用简单的动量-自适应学习率算法。

其迭代公式如下。其中$G(t)$为第$t$次迭代时对全部样本平均后的梯度;$rate(0)$与$moment(0)$为构造时给定的初始学习率eta和初始动量moment,例如0.01与0.9:

$$W(t+1)=W(t)-\Delta W(t)$$

$$\Delta W(t)=rate(t)(1-moment(t))G(t)+moment(t)\Delta W(t-1)$$

$$rate(t)=\begin{cases} rate(t-1)\times 1.05 & \mbox{if } cost(t)<cost(t-1)\\ rate(t-1)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ rate(0) & \mbox{else} \end{cases}$$

$$moment(t)=\begin{cases} moment(0) & \mbox{if } cost(t)<cost(t-1)\\ moment(t-1)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 1-moment(0) & \mbox{else} \end{cases}$$

示例代码如下:

/**
 * Learner implementing gradient descent with momentum and an adaptive learning rate.
 *
 * <p>On every call to {@link #learn}, the current rate and moment are first adjusted from the
 * change in mean cost (see {@link #modifyParameter(double)}). Then, for every weight and bias
 * matrix, the step is computed and subtracted from the parameter:
 * {@code step = currentEta * (1 - currentMoment) * gradient + currentMoment * previousStep}.
 */
public class MomentAdaptLearner implements Learner {

	/** Network this learner was initialised for; set on the first call to learn. */
	Net net;
	/** Base momentum; restored after the cost decreases. */
	double moment = 0.9;
	/** Factor by which the learning rate grows after the cost decreases. */
	double lmd = 1.05;
	/**
	 * Cost passed to the previous modifyParameter call. It starts at 0, so for a non-negative
	 * cost the first call takes the "reset" branch.
	 */
	double preCost = 0;
	/** Base learning rate; restored when the cost grows by 4% or more. */
	double eta = 0.01;
	/** Learning rate used for the next update. */
	double currentEta = eta;
	/** Momentum used for the next update. */
	double currentMoment = moment;
	/** Its backwardResult holds the previous update steps (gradients / biasGradients). */
	TrainResult preTrainResult;

	/** Creates a learner with moment 0.9 and learning rate 0.01. */
	public MomentAdaptLearner() {
		this(0.9, 0.01);
	}

	/**
	 * Creates a learner.
	 *
	 * @param moment base momentum
	 * @param eta    base learning rate
	 */
	public MomentAdaptLearner(double moment, double eta) {
		super();
		this.moment = moment;
		this.eta = eta;
		this.currentEta = eta;
		this.currentMoment = moment;
	}

	/**
	 * Adapts the rate and moment from the mean cost, then applies one momentum step to every
	 * weight and bias matrix of {@code net}.
	 *
	 * <p>Note: this method mutates the gradient matrices inside {@code trainResult} and keeps
	 * references to them as the previous step.
	 *
	 * @param net         network to update (the step buffers are sized from it on first use)
	 * @param trainResult result whose backwardResult provides the averaged gradients and cost
	 */
	@Override
	public void learn(Net net, TrainResult trainResult) {
		if (this.net == null) {
			init(net);
		}

		BackwardResult backwardResult = trainResult.backwardResult;
		BackwardResult preBackwardResult = preTrainResult.backwardResult;
		// The mean cost is treated as nullable elsewhere; unboxing null would throw, so only
		// adapt the rate and moment when a cost is available.
		Double cost = backwardResult.getMeanCost();
		if (cost != null) {
			this.modifyParameter(cost);
		}
		System.out.println("current eta:" + this.currentEta);
		System.out.println("current moment:" + this.currentMoment);
		for (int j = 0; j < net.getLayersNum(); j++) {
			DoubleMatrix weight = net.getWeights().get(j);
			DoubleMatrix gradient = backwardResult.gradients.get(j);

			// step = eta * (1 - moment) * gradient + moment * previousStep
			gradient = gradient.muli(currentEta * (1 - this.currentMoment)).addi(
					preBackwardResult.gradients.get(j).muli(this.currentMoment));
			preBackwardResult.gradients.set(j, gradient);

			weight.subi(gradient);

			DoubleMatrix b = net.getBs().get(j);
			DoubleMatrix bgradient = backwardResult.biasGradients.get(j);

			bgradient = bgradient.muli(currentEta * (1 - this.currentMoment)).addi(
					preBackwardResult.biasGradients.get(j).muli(this.currentMoment));
			preBackwardResult.biasGradients.set(j, bgradient);

			b.subi(bgradient);
		}

	}

	/**
	 * Adjusts the current learning rate and momentum by comparing {@code cost} with the
	 * previous cost, then remembers {@code cost} for the next call.
	 *
	 * @param cost mean cost of the current training pass
	 */
	public void modifyParameter(double cost) {
		if (cost < this.preCost) {
			// Cost decreased: grow the rate and restore the base momentum.
			this.currentEta *= lmd;
			this.currentMoment = moment;
		} else if (cost < 1.04 * this.preCost) {
			// Cost did not decrease but stayed below 1.04x the previous cost: damp both.
			this.currentEta *= 0.7;
			this.currentMoment *= 0.7;
		} else {
			// Cost grew too much: reset the rate and switch the momentum to 1 - moment.
			this.currentEta = eta;
			this.currentMoment = 1 - moment;
		}
		this.preCost = cost;
	}

	/**
	 * Remembers {@code net} and creates zero-filled "previous step" matrices shaped like its
	 * weights and biases.
	 *
	 * @param net network whose weight/bias shapes size the step buffers
	 */
	public void init(Net net) {
		this.net = net;
		BackwardResult bResult = new BackwardResult();

		for (DoubleMatrix weight : net.getWeights()) {
			bResult.gradients.add(DoubleMatrix.zeros(weight.rows,
					weight.columns));
		}
		for (DoubleMatrix b : net.getBs()) {
			bResult.biasGradients.add(DoubleMatrix.zeros(b.rows, b.columns));
		}
		preTrainResult = new TrainResult(null, bResult);
	}


}

现在,一个简单的神经网络从生成到训练已经简单实现完毕。

用java写bp神经网络(二)

标签:des   style   blog   io   ar   color   os   sp   java   

原文地址:http://www.cnblogs.com/wuseguang/p/4126226.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!