/**
 * One trading day's quote values, filled in field by field through the setters.
 *
 * <p>Naming note: {@code turnover} is the value that {@link #toString()} labels as the trading
 * volume, and {@code volume} is the value it labels as the traded amount, i.e. the English field
 * names are the reverse of what they hold. The accessor names are kept as-is for compatibility.
 */
public class DailyData {
    private double openPrice;
    private double closeprice;
    private double maxPrice;
    private double minPrice;
    private double turnover;
    private double volume;

    /** Creates a record whose values are all 0. */
    public DailyData() {
    }

    /** @return the opening price */
    public double getOpenPrice() {
        return openPrice;
    }

    /** @param openPrice the opening price */
    public void setOpenPrice(double openPrice) {
        this.openPrice = openPrice;
    }

    /** @return the closing price */
    public double getCloseprice() {
        return closeprice;
    }

    /** @param closeprice the closing price */
    public void setCloseprice(double closeprice) {
        this.closeprice = closeprice;
    }

    /** @return the value stored as the highest price */
    public double getMaxPrice() {
        return maxPrice;
    }

    /** @param maxPrice the value to store as the highest price */
    public void setMaxPrice(double maxPrice) {
        this.maxPrice = maxPrice;
    }

    /** @return the value stored as the lowest price */
    public double getMinPrice() {
        return minPrice;
    }

    /** @param minPrice the value to store as the lowest price */
    public void setMinPrice(double minPrice) {
        this.minPrice = minPrice;
    }

    /** @return the value printed as trading volume by {@link #toString()} */
    public double getTurnover() {
        return turnover;
    }

    /** @param turnover the value printed as trading volume by {@link #toString()} */
    public void setTurnover(double turnover) {
        this.turnover = turnover;
    }

    /** @return the value printed as traded amount by {@link #toString()} */
    public double getVolume() {
        return volume;
    }

    /** @param volume the value printed as traded amount by {@link #toString()} */
    public void setVolume(double volume) {
        this.volume = volume;
    }

    /**
     * Renders all six values as {@code label=value} pairs separated by {@code ", "}, in the
     * order open, close, highest, lowest, trading volume, traded amount.
     */
    @Override
    public String toString() {
        return new StringBuilder()
                .append("开盘价=").append(openPrice).append(", ")
                .append("收盘价=").append(closeprice).append(", ")
                .append("最高价=").append(maxPrice).append(", ")
                .append("最低价=").append(minPrice).append(", ")
                .append("成交量=").append(turnover).append(", ")
                .append("成交额=").append(volume)
                .toString();
    }
}
public class StockDataIterator implements DataSetIterator {
/**
*
*/
private static final long serialVersionUID = 1L;
private static final int VECTOR_SIZE = 6;
//每批次的训练数据组数
private int batchNum;
//每组训练数据长度(DailyData的个数)
private int exampleLength;
//数据集
private List<DailyData> dataList;
//存放剩余数据组的index信息
private List<Integer> dataRecord;
private double[] maxNum;
/**
* 构造方法
* */
public StockDataIterator(){
dataRecord = new ArrayList<>();
}
/**
* 加载数据并初始化
* */
public boolean loadData(String fileName, int batchNum, int exampleLength){
this.batchNum = batchNum;
this.exampleLength = exampleLength;
maxNum = new double[6];
//加载文件中的股票数据
try {
readDataFromFile(fileName);
}catch (Exception e){
e.printStackTrace();
return false;
}
//重置训练批次列表
resetDataRecord();
return true;
}
/**
* 重置训练批次列表
* */
private void resetDataRecord(){
dataRecord.clear();
int total = dataList.size()/exampleLength+1;
for( int i=0; i<total; i++ ){
dataRecord.add(i * exampleLength);
}
}
/**
* 从文件中读取股票数据
* */
public List<DailyData> readDataFromFile(String fileName) throws IOException{
dataList = new ArrayList<>();
BufferedReader in = new BufferedReader(new InputStreamReader(StockDataIterator.class.getResourceAsStream(fileName) ,"UTF-8"));
String line = in.readLine();
for(int i=0;i<maxNum.length;i++){
maxNum[i] = 0;
}
System.out.println("读取数据..");
while(line!=null){
String[] strArr = line.split(",");
if(strArr.length>=7) {
DailyData data = new DailyData();
//获得最大值信息,用于归一化
double[] nums = new double[6];
for(int j=0;j<6;j++){
nums[j] = Double.valueOf(strArr[j+2]);
if( nums[j]>maxNum[j] ){
maxNum[j] = nums[j];
}
}
//构造data对象
data.setOpenPrice(Double.valueOf(nums[0]));
data.setCloseprice(Double.valueOf(nums[1]));
data.setMaxPrice(Double.valueOf(nums[2]));
data.setMinPrice(Double.valueOf(nums[3]));
data.setTurnover(Double.valueOf(nums[4]));
data.setVolume(Double.valueOf(nums[5]));
dataList.add(data);
}
line = in.readLine();
}
in.close();
System.out.println("反转list...");
Collections.reverse(dataList);
return dataList;
}
public double[] getMaxArr(){
return this.maxNum;
}
public void reset(){
resetDataRecord();
}
public boolean hasNext(){
return dataRecord.size() > 0;
}
public DataSet next(){
return next(batchNum);
}
/**
* 获得接下来一次的训练数据集
* */
public DataSet next(int num){
if( dataRecord.size() <= 0 ) {
throw new NoSuchElementException();
}
int actualBatchSize = Math.min(num, dataRecord.size());
int actualLength = Math.min(exampleLength,dataList.size()-dataRecord.get(0)-1);
INDArray input = Nd4j.create(new int[]{actualBatchSize,VECTOR_SIZE,actualLength}, ‘f‘);
INDArray label = Nd4j.create(new int[]{actualBatchSize,1,actualLength}, ‘f‘);
DailyData nextData = null,curData = null;
//获取每批次的训练数据和标签数据
for(int i=0;i<actualBatchSize;i++){
int index = dataRecord.remove(0);
int endIndex = Math.min(index+exampleLength,dataList.size()-1);
curData = dataList.get(index);
for(int j=index;j<endIndex;j++){
//获取数据信息
nextData = dataList.get(j+1);
//构造训练向量
int c = endIndex-j-1;
input.putScalar(new int[]{i, 0, c}, curData.getOpenPrice()/maxNum[0]);
input.putScalar(new int[]{i, 1, c}, curData.getCloseprice()/maxNum[1]);
input.putScalar(new int[]{i, 2, c}, curData.getMaxPrice()/maxNum[2]);
input.putScalar(new int[]{i, 3, c}, curData.getMinPrice()/maxNum[3]);
input.putScalar(new int[]{i, 4, c}, curData.getTurnover()/maxNum[4]);
input.putScalar(new int[]{i, 5, c}, curData.getVolume()/maxNum[5]);
//构造label向量
label.putScalar(new int[]{i, 0, c}, nextData.getCloseprice()/maxNum[1]);
curData = nextData;
}
if(dataRecord.size()<=0) {
break;
}
}
return new DataSet(input, label);
}
public int batch() {
return batchNum;
}
public int cursor() {
return totalExamples() - dataRecord.size();
}
public int numExamples() {
return totalExamples();
}
public void setPreProcessor(DataSetPreProcessor preProcessor) {
throw new UnsupportedOperationException("Not implemented");
}
public int totalExamples() {
return (dataList.size()) / exampleLength;
}
public int inputColumns() {
return dataList.size();
}
public int totalOutcomes() {
return 1;
}
@Override
public List<String> getLabels() {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
@Override
public boolean resetSupported() {
// TODO Auto-generated method stub
return false;
}
@Override
public boolean asyncSupported() {
// TODO Auto-generated method stub
return false;
}
@Override
public DataSetPreProcessor getPreProcessor() {
// TODO Auto-generated method stub
return null;
}
}
/**
 * Trains a two-layer LSTM on the quote series produced by {@link StockDataIterator} and prints a
 * few {@code rnnTimeStep} outputs after each epoch.
 */
public class Dtest {
    // input features per time step fed to the network
    private static final int IN_NUM = 6;
    // outputs per time step
    private static final int OUT_NUM = 1;
    // number of full passes over the training data
    private static final int EPOCHS = 1;
    // hidden units in the first and second LSTM layers
    private static final int LSTM_LAYER1_SIZE = 50;
    private static final int LSTM_LAYER2_SIZE = 100;
    // number of rnnTimeStep outputs printed after each epoch
    private static final int PREDICTION_STEPS = 20;
    // classpath CSV resource used when no path is given on the command line
    private static final String DEFAULT_INPUT_FILE = "sz399905.csv";

    /**
     * Builds and initializes the network: two tanh GravesLSTM layers followed by an
     * identity-activation RnnOutputLayer with MSE loss, optimized by SGD with the RMSProp
     * updater, L2 = 0.001 and seed 12345. A ScoreIterationListener with frequency 1 is attached.
     *
     * @param nIn  number of input features per time step
     * @param nOut number of outputs per time step
     * @return the initialized network
     */
    public static MultiLayerNetwork getNetModel(int nIn, int nOut) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .seed(12345)
                .l2(0.001)
                .updater(Updater.RMSPROP)
                .list()
                .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(LSTM_LAYER1_SIZE)
                        .activation(Activation.TANH).build())
                .layer(1, new GravesLSTM.Builder().nIn(LSTM_LAYER1_SIZE).nOut(LSTM_LAYER2_SIZE)
                        .activation(Activation.TANH).build())
                .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(LSTM_LAYER2_SIZE).nOut(nOut).build())
                .pretrain(false).backprop(true)
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1));
        return net;
    }

    /**
     * Trains {@code net} for {@link #EPOCHS} passes over {@code iterator}. After each pass it
     * prints {@link #PREDICTION_STEPS} outputs of {@code rnnTimeStep} on a fixed input, rescaled
     * by the iterator's normalization maximum at index 1 (the close-price column), then clears
     * the stored RNN state.
     *
     * @param net      network to train
     * @param iterator loaded training data; it is reset after every pass
     */
    public static void train(MultiLayerNetwork net, StockDataIterator iterator) {
        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            while (iterator.hasNext()) {
                net.fit(iterator.next());
            }
            iterator.reset();
            System.out.println();
            System.out.println("=================>完成第" + (epoch + 1) + "次完整训练");
            INDArray initArray = getInitArray(iterator);
            double closeMax = iterator.getMaxArr()[1];
            System.out.println("预测结果:");
            // NOTE(review): the same input is fed at every step and outputs are not fed back,
            // so this sequence is not an autoregressive multi-day forecast.
            for (int step = 0; step < PREDICTION_STEPS; step++) {
                INDArray output = net.rnnTimeStep(initArray);
                System.out.print(output.getDouble(0) * closeMax + " ");
            }
            System.out.println();
            net.rnnClearPreviousState();
        }
    }

    /**
     * Builds a single-time-step input of shape {@code [1, IN_NUM, 1]} from hard-coded quote
     * values, each divided by the iterator's maximum for the same feature index. The values
     * follow the CSV column order; the third and fourth are presumably the day's low and high.
     */
    private static INDArray getInitArray(StockDataIterator iter) {
        double[] maxNums = iter.getMaxArr();
        double[] raw = {3433.85, 3445.41, 3327.81, 3470.37, 304197903.0, 3.8750365e+11};
        INDArray initArray = Nd4j.zeros(1, IN_NUM, 1);
        for (int k = 0; k < raw.length; k++) {
            initArray.putScalar(new int[]{0, k, 0}, raw[k] / maxNums[k]);
        }
        return initArray;
    }

    /**
     * Entry point: loads the CSV resource (first argument, or {@link #DEFAULT_INPUT_FILE}),
     * builds the network and trains it. Exits with status 1 if the data cannot be loaded.
     */
    public static void main(String[] args) {
        String inputFile = args.length > 0 ? args[0] : DEFAULT_INPUT_FILE;
        int batchSize = 1;
        int exampleLength = 30;
        StockDataIterator iterator = new StockDataIterator();
        if (!iterator.loadData(inputFile, batchSize, exampleLength)) {
            // training on an unloaded iterator would divide by zero maxima
            System.err.println("Could not load training data from " + inputFile);
            System.exit(1);
        }
        MultiLayerNetwork net = getNetModel(IN_NUM, OUT_NUM);
        train(net, iterator);
    }
}
数据格式如下(实际 CSV 文件中各列以英文逗号分隔,依次为:代码、日期、开盘价、收盘价、最低价、最高价、成交量、成交额、涨跌幅):
sz399905 2015/12/11 7320.16 7290.7 7253.84 7347.36 72132287 1.12E+11 -0.008096367
sz399905 2015/12/10 7374.35 7350.21 7332.98 7437.71 78990424 1.30E+11 -0.003262696
sz399905 2015/12/9 7369.11 7374.27 7322.87 7431.04 83299991 1.32E+11 -0.004034229
sz399905 2015/12/8 7555.46 7404.14 7398.56 7555.46 94938823 1.47E+11 -0.026056828
sz399905 2015/12/7 7526.22 7602.23 7476.19 7602.77 92881296 1.47E+11 0.012055908
sz399905 2015/12/4 7533.61 7511.67 7464.28 7600.34 101362535 1.55E+11 -0.007772264
sz399905 2015/12/3 7413.22 7570.51 7412.65 7571.45 95329412 1.43E+11 0.022232394
sz399905 2015/12/2 7423.5 7405.86 7201.66 7444.22 102647475 1.50E+11 -0.005115571
sz399905 2015/12/1 7403.94 7443.94 7358.37 7519.94 113008679 1.73E+11 0.004797257
sz399905 2015/11/30 7388.28 7408.4 7035.55 7467.47 129234023 1.97E+11 0.004376285
sz399905 2015/11/27 7839.31 7376.12 7317.65 7852 152970489 2.34E+11 -0.063240404
sz399905 2015/11/26 7962.17 7874.08 7859.63 7974.73 140404615 2.29E+11 -0.006096653
sz399905 2015/11/25 7803.29 7922.38 7795.16 7925.54 124435501 2.07E+11 0.015885106
sz399905 2015/11/24 7739.09 7798.5 7635.78 7799.01 110258558 1.69E+11 0.0070143
参考文章:
https://blog.csdn.net/a398942089/article/details/52294082
原文地址:http://blog.51cto.com/12597095/2119576