Tags: des style blog io ar color os sp java
Continued from the previous post.
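With Net and Propagation in place, we can start training. The trainer has three jobs. First, it decides how to split a large set of samples into mini-batches and how to merge the mini-batch results back into a result for the whole set (batch vs. incremental training). Second, it decides when to call the learner, which uses the training result to update the network's weights and state. Third, it decides when training ends.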
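The split-and-merge step can be checked on plain numbers. The sketch below assumes each sample contributes one scalar "gradient" (here simply its own value) and that merging should reproduce the full-batch mean: each batch mean is multiplied by its batch length to recover a sum, and the summed values are divided by the total sample count. `BatchMerge` and its methods are hypothetical names for illustration, not part of the author's code.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration with plain arrays instead of jblas matrices.
class BatchMerge {
    // Returns [start, end) index pairs that cover 0..n in chunks of at most batchSize.
    static List<int[]> batches(int n, int batchSize) {
        List<int[]> out = new ArrayList<>();
        for (int start = 0; start < n; start += batchSize) {
            out.add(new int[] {start, Math.min(start + batchSize, n)});
        }
        return out;
    }

    // Mean of the toy per-sample "gradients" in [start, end) (assumption: gradient = sample value).
    static double batchMean(double[] samples, int start, int end) {
        double sum = 0;
        for (int i = start; i < end; i++) sum += samples[i];
        return sum / (end - start);
    }

    // Merge batch means into the full-batch mean: mean * length gives a sum, then divide by n.
    static double mergedMean(double[] samples, int batchSize) {
        double acc = 0;
        for (int[] b : batches(samples.length, batchSize)) {
            int len = b[1] - b[0];
            acc += batchMean(samples, b[0], b[1]) * len;
        }
        return acc / samples.length;
    }

    public static void main(String[] args) {
        double[] g = {1, 2, 3, 4, 5, 6, 7};
        System.out.println(batches(g.length, 3).size()); // 3 (batches of 3, 3 and 1)
        System.out.println(mergedMean(g, 3));            // 4.0, the plain mean of 1..7
    }
}
```

Weighting by batch length matters only when the last batch is shorter; averaging the batch means directly would give (2 + 5 + 7) / 3 instead of 4.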
So what do these two "teachers" look like, and how do they do their work?
```java
// Trainer.java
public interface Trainer {
    public void train(Net net, DataProvider provider);
}

// Learner.java
public interface Learner {
    public void learn(Net net, TrainResult trainResult);
}
```
A Trainer takes data and trains the given network on it; a Learner takes a training result and adjusts the weights of the given network.
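To make the division of labour concrete, here is a self-contained toy version. It assumes a one-weight linear model y = w·x and plain gradient descent in place of the author's `Net` and momentum learner; `ToyNet`, `ToyResult`, `ToyTrainer`, `ToyLearner`, `FixedEpochTrainer` and `PlainGdLearner` are hypothetical names.

```java
// Hypothetical stand-in for Net: a single weight w in y = w * x.
class ToyNet {
    double w = 0.0;
}

// Hypothetical stand-in for TrainResult.
class ToyResult {
    final double gradient; // d(mean squared error)/dw
    final double cost;     // mean squared error
    ToyResult(double gradient, double cost) { this.gradient = gradient; this.cost = cost; }
}

// Same split of responsibilities as the Trainer / Learner interfaces.
interface ToyTrainer { void train(ToyNet net, double[] xs, double[] ys); }
interface ToyLearner { void learn(ToyNet net, ToyResult result); }

// Learner: only knows how to turn a gradient into a weight change.
class PlainGdLearner implements ToyLearner {
    final double eta;
    PlainGdLearner(double eta) { this.eta = eta; }
    public void learn(ToyNet net, ToyResult r) { net.w -= eta * r.gradient; }
}

// Trainer: computes the full-batch result each epoch and stops after a fixed number of epochs.
class FixedEpochTrainer implements ToyTrainer {
    final int epochs;
    final ToyLearner learner;
    FixedEpochTrainer(int epochs, ToyLearner learner) { this.epochs = epochs; this.learner = learner; }

    public void train(ToyNet net, double[] xs, double[] ys) {
        for (int e = 0; e < epochs; e++) {
            double grad = 0, cost = 0;
            for (int i = 0; i < xs.length; i++) {
                double err = net.w * xs[i] - ys[i];
                cost += err * err;
                grad += 2 * err * xs[i];
            }
            learner.learn(net, new ToyResult(grad / xs.length, cost / xs.length));
        }
    }
}

class ToyTrainDemo {
    public static void main(String[] args) {
        ToyNet net = new ToyNet();
        double[] xs = {1, 2, 3};
        double[] ys = {2, 4, 6}; // generated by w = 2
        new FixedEpochTrainer(200, new PlainGdLearner(0.05)).train(net, xs, ys);
        System.out.println(net.w);
    }
}
```

The demo prints a value very close to 2.0. Because the trainer never touches the update rule, swapping `PlainGdLearner` for a momentum-based learner would not require any change to `FixedEpochTrainer`.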
Simple implementations of both interfaces follow.
Trainer
The Trainer implementation below does simple batch training and stops after a given number of epochs. The code is as follows.
```java
import java.util.ArrayList;
import java.util.List;

import org.jblas.DoubleMatrix;

public class CommonTrainer implements Trainer {
    int ecophs;                  // number of training epochs
    Learner learner;
    List<Double> costs = new ArrayList<>();
    List<Double> accuracys = new ArrayList<>();
    int batchSize = 1;

    public CommonTrainer(int ecophs, Learner learner) {
        this.ecophs = ecophs;
        // MomentAdaptLearner declares only a (moment, eta) constructor, so pass its defaults explicitly
        this.learner = learner == null ? new MomentAdaptLearner(0.9, 0.01) : learner;
    }

    public CommonTrainer(int ecophs, Learner learner, int batchSize) {
        this(ecophs, learner);
        this.batchSize = batchSize;
    }

    // getBatches is not shown in the original post; this is an assumed implementation
    // that cuts the samples (one per column) into consecutive blocks of at most batchSize columns.
    List<DoubleMatrix> getBatches(DoubleMatrix m) {
        List<DoubleMatrix> batches = new ArrayList<>();
        for (int start = 0; start < m.columns; start += batchSize) {
            int end = Math.min(start + batchSize, m.columns);
            int[] cols = new int[end - start];
            for (int c = start; c < end; c++) cols[c - start] = c;
            batches.add(m.getColumns(cols));
        }
        return batches;
    }

    public void trainOne(final Net net, DataProvider provider) {
        final Propagation propagation = new Propagation(net);
        DoubleMatrix input = provider.getInput();
        DoubleMatrix target = provider.getTarget();
        final int allLen = target.columns;
        final int[] nodesNum = net.getNodesNum();
        final int layersNum = net.getLayersNum();
        List<DoubleMatrix> inputBatches = this.getBatches(input);
        final List<DoubleMatrix> targetBatches = this.getBatches(target);
        // end position (exclusive) of each batch within the full sample set
        final List<Integer> batchLen = MatrixUtil.getEndPosition(targetBatches);
        // accumulator for the whole sample set
        final BackwardResult backwardResult = new BackwardResult(net, allLen);

        // train the mini-batches in parallel
        Parallel.For(inputBatches, new Parallel.Operation<DoubleMatrix>() {
            @Override
            public void perform(int index, DoubleMatrix subInput) {
                ForwardResult subResult = propagation.forward(subInput);
                DoubleMatrix subTarget = targetBatches.get(index);
                BackwardResult backResult = propagation.backward(subTarget, subResult);

                // columns of the full result that belong to this batch
                int start = index == 0 ? 0 : batchLen.get(index - 1);
                int end = batchLen.get(index) - 1;
                int[] cIndexs = ArraysHelper.makeArray(start, end);
                int subLen = backResult.cost.columns;

                // batches run concurrently and addi() is not thread-safe,
                // so serialize all writes into the shared accumulator
                synchronized (backwardResult) {
                    backwardResult.cost.put(cIndexs, backResult.cost);
                    if (backwardResult.accuracy != null) {
                        backwardResult.accuracy.put(cIndexs, backResult.accuracy);
                    }
                    backwardResult.getInputDelta().put(
                            ArraysHelper.makeArray(0, nodesNum[0] - 1), cIndexs,
                            backResult.getInputDelta());
                    for (int i = 0; i < layersNum; i++) {
                        // batch gradients are batch means: scale by the batch size to get sums
                        backwardResult.gradients.get(i)
                                .addi(backResult.gradients.get(i).muli(subLen));
                        backwardResult.biasGradients.get(i)
                                .addi(backResult.biasGradients.get(i).muli(subLen));
                    }
                }
            }
        });

        // divide the summed gradients by the total sample count to get the mean
        for (DoubleMatrix gradient : backwardResult.gradients) {
            gradient.divi(allLen);
        }
        for (DoubleMatrix gradient : backwardResult.biasGradients) {
            gradient.divi(allLen);
        }

        TrainResult trainResult = new TrainResult(null, backwardResult);
        learner.learn(net, trainResult);

        Double cost = backwardResult.getMeanCost();
        Double accuracy = backwardResult.getMeanAccuracy();
        if (cost != null) costs.add(cost);
        if (accuracy != null) accuracys.add(accuracy);
        System.out.println(cost);
        System.out.println(accuracy);
    }

    @Override
    public void train(Net net, DataProvider provider) {
        for (int i = 0; i < this.ecophs; i++) {
            this.trainOne(net, provider);
        }
    }
}
```
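In `trainOne`, every `perform` call adds into the same `backwardResult` matrices, so concurrent batches need either a lock around the accumulation or no shared accumulator at all. The sketch below shows the second option with plain arrays: each batch builds its own gradient sum, and `Stream.reduce` combines the partial sums. The two-parameter toy gradient `(x*x, x)` and the name `ParallelMerge` are assumptions for illustration; this is an alternative design, not the author's code.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

class ParallelMerge {
    // Toy per-sample gradient of a two-parameter model (assumption): (x*x, x).
    static double[] sampleGradient(double x) {
        return new double[] {x * x, x};
    }

    // Element-wise sum that returns a new array and never mutates its inputs.
    static double[] add(double[] a, double[] b) {
        double[] c = new double[a.length];
        for (int i = 0; i < a.length; i++) c[i] = a[i] + b[i];
        return c;
    }

    static double[] meanGradient(double[] samples, int batchSize) {
        int n = samples.length;
        int numBatches = (n + batchSize - 1) / batchSize;
        double[] sum = IntStream.range(0, numBatches).parallel()
                .mapToObj(b -> {
                    // each task builds its own per-batch sum; nothing is shared between tasks
                    double[] acc = new double[2];
                    int end = Math.min((b + 1) * batchSize, n);
                    for (int i = b * batchSize; i < end; i++) {
                        acc = add(acc, sampleGradient(samples[i]));
                    }
                    return acc;
                })
                // combine partial sums; the zero array is a true identity for add()
                .reduce(new double[2], ParallelMerge::add);
        for (int i = 0; i < sum.length; i++) sum[i] /= n; // sum over all samples -> mean
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(meanGradient(new double[] {1, 2, 3, 4}, 3))); // [7.5, 2.5]
    }
}
```

`reduce` requires an associative combiner and a genuine identity; element-wise addition into a fresh array satisfies both, so the result does not depend on how the batches are scheduled.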
Learner
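The Learner implements the actual update algorithm: once the gradients have been computed, it adjusts the network's weights. The choice of update algorithm directly affects how fast the network converges. This implementation uses a simple momentum method with an adaptive learning rate.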
The update rules are:
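$$W(t+1)=W(t)-\Delta W(t)$$
$$\Delta W(t)=rate(t)\,(1-moment(t))\,G(t)+moment(t)\,\Delta W(t-1)$$
where $G(t)$ is the mean gradient of the cost with respect to $W$ at $W(t)$, and the update before the first step is zero.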
$$rate(t+1)=\begin{cases} rate(t)\times 1.05 & \mbox{if } cost(t)<cost(t-1)\\ rate(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 0.01 & \mbox{else} \end{cases}$$
$$moment(t+1)=\begin{cases} 0.9 & \mbox{if } cost(t)<cost(t-1)\\ moment(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 1-0.9 & \mbox{else} \end{cases}$$
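The schedule can be traced on a short cost sequence. The sketch below restates the two case formulas as a small class, assuming the article's defaults (rate 0.01, moment 0.9) and, as in `MomentAdaptLearner`, an initial previous cost of 0, which sends the very first call into the reset branch. `AdaptSchedule` is a hypothetical name.

```java
class AdaptSchedule {
    final double eta0;     // base rate, the article's default 0.01
    final double moment0;  // base momentum, the article's default 0.9
    double eta;
    double moment;
    double preCost;        // cost of the previous step

    AdaptSchedule(double eta0, double moment0, double initialPreCost) {
        this.eta0 = eta0;
        this.moment0 = moment0;
        this.eta = eta0;
        this.moment = moment0;
        this.preCost = initialPreCost;
    }

    // Apply one step of the schedule given the cost just measured.
    void update(double cost) {
        if (cost < preCost) {                // cost fell: grow the rate, restore momentum
            eta *= 1.05;
            moment = moment0;
        } else if (cost < 1.04 * preCost) {  // cost rose by less than 4%: damp both
            eta *= 0.7;
            moment *= 0.7;
        } else {                             // cost rose by 4% or more: reset
            eta = eta0;
            moment = 1 - moment0;
        }
        preCost = cost;
    }

    public static void main(String[] args) {
        AdaptSchedule s = new AdaptSchedule(0.01, 0.9, 0.0); // preCost = 0 as in MomentAdaptLearner
        for (double cost : new double[] {1.0, 0.9, 0.92, 1.5}) {
            s.update(cost);
            System.out.println("cost=" + cost + " rate=" + s.eta + " moment=" + s.moment);
        }
    }
}
```

For the costs 1.0, 0.9, 0.92, 1.5 the (rate, moment) pairs come out at roughly (0.01, 0.1), (0.0105, 0.9), (0.00735, 0.63) and (0.01, 0.1): one reset, one speed-up, one damping step (0.92 is below 1.04 × 0.9 = 0.936), and a final reset.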
Example code:
```java
import org.jblas.DoubleMatrix;

public class MomentAdaptLearner implements Learner {
    Net net;
    double moment = 0.9;          // base momentum
    double lmd = 1.05;            // learning-rate growth factor when the cost falls
    double preCost = 0;           // cost of the previous step
    double eta = 0.01;            // base learning rate
    double currentEta = eta;
    double currentMoment = moment;
    TrainResult preTrainResult;   // holds the previous updates, i.e. ΔW(t-1)

    // uses the defaults moment = 0.9, eta = 0.01
    public MomentAdaptLearner() {
    }

    public MomentAdaptLearner(double moment, double eta) {
        this.moment = moment;
        this.eta = eta;
        this.currentEta = eta;
        this.currentMoment = moment;
    }

    @Override
    public void learn(Net net, TrainResult trainResult) {
        if (this.net == null) init(net);
        BackwardResult backwardResult = trainResult.backwardResult;
        BackwardResult preBackwardResult = preTrainResult.backwardResult;
        double cost = backwardResult.getMeanCost();
        this.modifyParameter(cost);
        System.out.println("current eta:" + this.currentEta);
        System.out.println("current moment:" + this.currentMoment);
        for (int j = 0; j < net.getLayersNum(); j++) {
            // ΔW(t) = eta * (1 - moment) * G(t) + moment * ΔW(t-1);  W(t+1) = W(t) - ΔW(t)
            DoubleMatrix weight = net.getWeights().get(j);
            DoubleMatrix gradient = backwardResult.gradients.get(j);
            gradient = gradient.muli(currentEta * (1 - this.currentMoment)).addi(
                    preBackwardResult.gradients.get(j).muli(this.currentMoment));
            preBackwardResult.gradients.set(j, gradient);
            weight.subi(gradient);

            // the same rule for the biases
            DoubleMatrix b = net.getBs().get(j);
            DoubleMatrix bgradient = backwardResult.biasGradients.get(j);
            bgradient = bgradient.muli(currentEta * (1 - this.currentMoment)).addi(
                    preBackwardResult.biasGradients.get(j).muli(this.currentMoment));
            preBackwardResult.biasGradients.set(j, bgradient);
            b.subi(bgradient);
        }
    }

    // adapt the learning rate and momentum to how the cost changed
    public void modifyParameter(double cost) {
        if (cost < this.preCost) {                // cost fell: speed up
            this.currentEta *= lmd;
            this.currentMoment = moment;
        } else if (cost < 1.04 * this.preCost) {  // cost rose by less than 4%: slow down
            this.currentEta *= 0.7;
            this.currentMoment *= 0.7;
        } else {                                  // cost rose sharply: reset
            this.currentEta = eta;
            this.currentMoment = 1 - moment;
        }
        this.preCost = cost;
    }

    // allocate zero "previous updates" with the same shapes as the weights and biases
    public void init(Net net) {
        this.net = net;
        BackwardResult bResult = new BackwardResult();
        for (DoubleMatrix weight : net.getWeights()) {
            bResult.gradients.add(DoubleMatrix.zeros(weight.rows, weight.columns));
        }
        for (DoubleMatrix b : net.getBs()) {
            bResult.biasGradients.add(DoubleMatrix.zeros(b.rows, b.columns));
        }
        preTrainResult = new TrainResult(null, bResult);
    }
}
```
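A scalar analogue of `learn()` shows how the stored previous update feeds into the next step. The sketch assumes a single weight minimising f(w) = (w − 3)² and uses plain doubles in place of jblas matrices; `ScalarMomentAdapt` and `step` are hypothetical names, while the schedule constants and the update rule follow the listing above.

```java
class ScalarMomentAdapt {
    final double moment0, eta0;
    double eta, moment;
    double preCost = 0;    // same starting value as MomentAdaptLearner.preCost
    double prevDelta = 0;  // ΔW(t-1), zero before the first step

    ScalarMomentAdapt(double moment0, double eta0) {
        this.moment0 = moment0;
        this.eta0 = eta0;
        this.eta = eta0;
        this.moment = moment0;
    }

    // One learn() call: adapt rate and momentum from the cost, then apply the momentum update.
    double step(double w, double cost, double gradient) {
        if (cost < preCost) { eta *= 1.05; moment = moment0; }
        else if (cost < 1.04 * preCost) { eta *= 0.7; moment *= 0.7; }
        else { eta = eta0; moment = 1 - moment0; }
        preCost = cost;
        double delta = eta * (1 - moment) * gradient + moment * prevDelta;
        prevDelta = delta;   // kept for the next step, like preBackwardResult.gradients.set(j, gradient)
        return w - delta;    // like weight.subi(gradient)
    }

    public static void main(String[] args) {
        ScalarMomentAdapt learner = new ScalarMomentAdapt(0.9, 0.01);
        double w = 0.0;
        for (int t = 0; t < 5; t++) {
            double cost = (w - 3) * (w - 3);
            w = learner.step(w, cost, 2 * (w - 3));
            System.out.println("step " + t + ": w=" + w);
        }
    }
}
```

The first two steps give w ≈ 0.054 and w ≈ 0.1088. The first call lands in the reset branch (momentum 0.1) because `preCost` starts at 0; on the second call the cost has fallen, momentum returns to 0.9, and the carried-over ΔW(t−1) supplies most of the update (0.0486 of roughly 0.0548).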
At this point, a simple neural network has been implemented end to end, from construction to training.
Original post: http://www.cnblogs.com/wuseguang/p/4126226.html