3. Spark MLlib Deep Learning Convolution Neural Network (Deep Learning: Convolutional Neural Networks) 3.3
Chapter 3 Convolution Neural Network (Convolutional Neural Networks)
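3 Example
3.1 Test Data
Use the data from the previous example, or build a new image-recognition dataset. The code in 3.2 expects each line to hold tab-separated numbers: a 10-value one-hot label followed by the 784 pixel values of a 28×28 image.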
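The line format that step 2 of the code parses can be sketched in plain Scala, with no Spark or Breeze. The sketch assumes what the parsing code implies: 10 one-hot label values, then 784 raw pixel values that are presumably in 0 to 255 (hence the division by 255.0). The names DataFormatSketch, makeLine, parseLine, labelOf and pixel are hypothetical.

```scala
// Plain-Scala sketch of the train_d.txt line format (hypothetical names).
// ASSUMPTION: 10 one-hot label values, then 784 pixel values in 0 to 255.
object DataFormatSketch {
  val NumClasses = 10
  val Side = 28

  // Build one tab-separated line from a digit and 784 raw pixel values.
  def makeLine(digit: Int, pixels: Array[Int]): String = {
    require(digit >= 0 && digit < NumClasses, s"digit must be in 0..${NumClasses - 1}")
    require(pixels.length == Side * Side, s"expected ${Side * Side} pixels")
    val label = Array.tabulate(NumClasses)(i => if (i == digit) 1.0 else 0.0)
    (label ++ pixels.map(_.toDouble)).mkString("\t")
  }

  // Same steps as step 2: split on tabs, the first 10 values are the label,
  // the rest are pixels scaled into [0, 1] by dividing by 255.0.
  def parseLine(line: String): (Array[Double], Array[Double]) = {
    val f = line.split("\t").map(_.toDouble)
    val y = f.slice(0, NumClasses)
    val x = f.slice(NumClasses, f.length).map(_ / 255.0)
    (y, x)
  }

  // Index of the largest label entry, i.e. the digit encoded by a one-hot label.
  def labelOf(y: Array[Double]): Int =
    y.indices.reduce((best, i) => if (y(i) > y(best)) i else best)

  // Breeze's DenseMatrix is column-major, so after reshape(28, 28)
  // entry (row, col) holds flat element row + 28 * col.
  def pixel(x: Array[Double], row: Int, col: Int): Double = x(row + Side * col)
}
```

For example, makeLine(3, pixels) produces a line whose first ten fields are a one-hot label with 1.0 in position 3, and labelOf recovers 3 from the parsed label.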
3.2 CNN Example
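    // Imports used below; sc is an existing SparkContext (for example, the one spark-shell provides).
    import org.apache.log4j.{Level, Logger}
    import breeze.linalg.{DenseMatrix => BDM}

    //2 Test data
    Logger.getRootLogger.setLevel(Level.WARN)
    val data_path = "/user/tmp/deeplearn/train_d.txt"
    val examples = sc.textFile(data_path).cache()
    val train_d1 = examples.map { line =>
      val f1 = line.split("\t")
      val f = f1.map(f => f.toDouble)
      val y = f.slice(0, 10)          // one-hot label (10 classes)
      val x = f.slice(10, f.length)   // pixel values of a 28 x 28 image
      // label as a 1 x 10 matrix; image reshaped to 28 x 28 and scaled by 1/255
      (new BDM(1, y.length, y), (new BDM(1, x.length, x)).reshape(28, 28) / 255.0)
    }
    val train_d = train_d1.map(f => (f._1, f._2))
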
    //3 Set training parameters and build the model
    // opts: iteration step size, number of iterations, cross-validation ratio
    val opts = Array(100.0, 1.0, 0.0)
    train_d.cache()
    val numExamples = train_d.count()
    println(s"numExamples = $numExamples.")
    // CNN is the custom class from this series; it is not part of Spark MLlib.
    // Layer types: i = input, c = convolution, s = subsampling.
    val CNNmodel = new CNN().
      setMapsize(new BDM(1, 2, Array(28.0, 28.0))).
      setTypes(Array("i", "c", "s", "c", "s")).
      setLayer(5).
      setOnum(10).
      setOutputmaps(Array(0.0, 6.0, 0.0, 12.0, 0.0)).
      setKernelsize(Array(0.0, 5.0, 0.0, 5.0, 0.0)).
      setScale(Array(0.0, 0.0, 2.0, 0.0, 2.0)).
      setAlpha(1.0).
      setBatchsize(50.0).
      setNumepochs(1.0).
      CNNtrain(train_d, opts)

    //4 Model testing
    val CNNforecast = CNNmodel.predict(train_d)
    val CNNerror = CNNmodel.Loss(CNNforecast)
    println(s"CNNerror = $CNNerror.")
    // first 200 samples: first entry of the label vs. first entry of the prediction
    val printf1 = CNNforecast.map(f => (f.label.data(0), f.predict_label.data(0))).take(200)
    println("Prediction results: actual value : predicted value : error")
    for (i <- 0 until printf1.length)
      println(printf1(i)._1 + "\t" + printf1(i)._2 + "\t" + (printf1(i)._2 - printf1(i)._1))
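The step 3 settings fix the shape of the network. The following sketch walks the layer types and computes each layer's map size and map count, using Int arrays in place of the Double arrays passed to the CNN setters. It rests on an assumption not confirmed by this section: that "c" layers use valid convolution, so a map shrinks by kernelsize - 1, and "s" layers subsample by scale without changing the number of maps, as in the classic LeNet-style design. LayerSizes, walk and featureLength are hypothetical names.

```scala
// Sketch of the feature-map sizes implied by the step 3 settings.
// ASSUMPTION: "c" is a valid convolution (size - kernelsize + 1) and
// "s" is non-overlapping subsampling (size / scale); not read from the CNN source.
object LayerSizes {
  final case class Layer(kind: String, size: Int, maps: Int)

  def walk(input: Int, types: Array[String], outputmaps: Array[Int],
           kernelsize: Array[Int], scale: Array[Int]): List[Layer] = {
    require(types.nonEmpty && types(0) == "i", "first layer must be the input layer")
    val start = Layer("i", input, 1) // one grayscale input map
    types.indices.drop(1).scanLeft(start) { (prev, l) =>
      types(l) match {
        case "c" =>
          Layer("c", prev.size - kernelsize(l) + 1, outputmaps(l))
        case "s" =>
          require(prev.size % scale(l) == 0, s"map size ${prev.size} not divisible by ${scale(l)}")
          Layer("s", prev.size / scale(l), prev.maps)
        case other =>
          throw new IllegalArgumentException(s"unknown layer type: $other")
      }
    }.toList
  }

  // Length of the flattened last layer (size x size x maps).
  def featureLength(layers: List[Layer]): Int = {
    val last = layers.last
    last.size * last.size * last.maps
  }
}
```

Under these assumptions the 28×28 input becomes 24×24×6, 12×12×6, 8×8×12 and finally 4×4×12, i.e. 192 features, which the model presumably maps to the 10 outputs set by setOnum(10).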