# Tags (标签): plot, pyplot, check, time, ESS, repr, sha, sum, class
"""Train a two-layer LSTM to predict Kweichow Moutai (600519) daily prices.

Pipeline: download daily bars with tushare, cache them to CSV, scale one
price column to [0, 1], build sliding windows of ``WINDOW`` days, train (or
resume) an LSTM with a weights checkpoint, then plot and score predictions
on the held-out tail of the series.

Requires ``pip install tushare`` (originally a notebook ``!pip`` cell, which
is not valid Python in a .py file).
"""
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tushare as ts
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import LSTM, Dense, Dropout

WINDOW = 60        # days of history in each input sample
TRAIN_SIZE = 3000  # rows [0, TRAIN_SIZE) train; the remaining rows are the test set
SEED = 7           # seed for the numpy shuffle and TensorFlow RNG


def _make_windows(series, window=WINDOW):
    """Turn an ``(n, 1)`` array into sliding-window samples.

    Returns ``(x, y)``. ``x[k]`` is ``series[k:k + window, 0]`` reshaped to
    ``(window, 1)``, and ``y[k]`` is ``series[k + window, 0]`` (the value
    right after the window). So ``x`` has shape ``(n - window, window, 1)``
    and ``y`` has shape ``(n - window,)``. If ``n <= window``, both are empty.
    """
    values = np.asarray(series)[:, 0]
    x = np.array([values[i - window:i] for i in range(window, len(values))])
    y = np.array(values[window:])
    return x.reshape(-1, window, 1), y


# --- Download and cache the raw data ---------------------------------------
# NOTE(review): ts.get_k_data is tushare's legacy interface; confirm it is
# still available in the installed tushare version.
df1 = ts.get_k_data('600519', ktype='D', start='2004-01-01', end='2020-05-12')

datapath1 = "./SH600519.csv"
df1.to_csv(datapath1)

maotai = pd.read_csv(datapath1)
print(maotai.head())
print(maotai.tail())

# --- Train/test split and scaling ------------------------------------------
# Column 2 of the CSV is presumably 'open' (column 0 is the index written by
# to_csv, column 1 the date). Verify against the CSV header.
training_set = maotai.iloc[0:TRAIN_SIZE, 2:3].values
test_set = maotai.iloc[TRAIN_SIZE:, 2:3].values

# Fit the scaler on the training data only, then apply the same transform to
# the test data. Test values can therefore fall outside [0, 1].
sc = MinMaxScaler(feature_range=(0, 1))
training_set_scaled = sc.fit_transform(training_set)
test_set = sc.transform(test_set)  # test_set holds scaled values from here on

print(training_set_scaled.shape)
print(test_set.shape)

# --- Sliding-window samples --------------------------------------------------
x_train, y_train = _make_windows(training_set_scaled)

# Shuffle inputs and targets with ONE permutation so they stay aligned.
# The alternative of re-seeding and shuffling each array separately only
# works while both have equal length.
np.random.seed(SEED)
perm = np.random.permutation(len(x_train))
x_train, y_train = x_train[perm], y_train[perm]
tf.random.set_seed(SEED)

print(x_train.shape)
print(y_train.shape)

# Test windows are kept in chronological order so predictions can be plotted.
x_test, y_test = _make_windows(test_set)

# --- Model -------------------------------------------------------------------
model = tf.keras.Sequential([
    LSTM(80, return_sequences=True),
    Dropout(0.2),
    LSTM(100),
    Dropout(0.2),
    Dense(1),
])

model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
              loss='mean_squared_error')

# NOTE(review): this TF-checkpoint path (and the '.index' probe below) assumes
# tf.keras 2.x. Keras 3 may require weight-only checkpoints to end in
# '.weights.h5'. Confirm against the installed version.
checkpoint_save_path = "./checkpoint/LSTM_stock.ckpt"

# Resume from a previous run's weights if a checkpoint exists.
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-------------')
    model.load_weights(checkpoint_save_path)

# NOTE(review): validation_data is the test set, so save_best_only picks the
# on-disk weights using test data. The in-memory model used for the
# predictions below holds the LAST epoch's weights, not the best ones.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
    monitor='val_loss')

history = model.fit(x_train, y_train, batch_size=64, epochs=24,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

model.summary()

# Dump each trainable variable's name, shape and values for inspection.
with open("./weights.txt", "w", encoding="utf-8") as f:
    for v in model.trainable_variables:
        f.write(f"{v.name}\n{v.shape}\n{v.numpy()}\n")

# --- Loss curves -------------------------------------------------------------
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

# --- Predictions in original price units ---------------------------------------
predicted_stock_price = sc.inverse_transform(model.predict(x_test))
# Targets start WINDOW rows into the test set, matching y_test.
real_stock_price = sc.inverse_transform(test_set[WINDOW:])

plt.plot(real_stock_price, color='red', label='real_stock_price')
plt.plot(predicted_stock_price, color='blue', label='predicted_stock_price')
plt.title('Maotai Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('Maotai Stock Price')
plt.legend()
plt.show()

# --- Metrics (sklearn convention: y_true first, y_pred second) ------------------
mse = mean_squared_error(real_stock_price, predicted_stock_price)
rmse = math.sqrt(mse)
mae = mean_absolute_error(real_stock_price, predicted_stock_price)
print(f'均方误差: {mse:.6f}')      # mean squared error
print(f'均方根误差: {rmse:.6f}')   # root mean squared error
print(f'平均绝对误差: {mae:.6f}')  # mean absolute error
# Tags (标签): plot, pyplot, check, time, ESS, repr, sha, sum, class
# Source article (原文地址): https://www.cnblogs.com/wbloger/p/12885496.html