# Tags: cal ack for ict mod git rac var result
# Next-letter prediction with Embedding + SimpleRNN.
#
# The network is trained on every window of four consecutive letters of the
# alphabet ("abcd" -> "e", "bcde" -> "f", ...), the training curves are
# plotted, and the trained model is then queried interactively from stdin.
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, SimpleRNN

input_word = "abcdefghijklmnopqrstuvwxyz"
# Letter -> integer id ('a' -> 0 ... 'z' -> 25), derived from input_word so
# the two can never drift apart.
w_to_id = {letter: idx for idx, letter in enumerate(input_word)}

training_set_scaled = list(range(26))

# Number of letters fed to the RNN per sample (time steps the cell is unrolled).
TIME_STEPS = 4

x_train = []
y_train = []

# Each sample is TIME_STEPS consecutive ids; its label is the id that follows.
for i in range(TIME_STEPS, 26):
    x_train.append(training_set_scaled[i - TIME_STEPS:i])
    y_train.append(training_set_scaled[i])

# Re-seeding with the same value before each shuffle applies the same
# permutation to both lists, so samples and labels stay aligned.
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)

# Shape x_train the way Embedding expects: [number of samples, time steps].
# The whole data set is fed at once, so the number of samples is
# len(x_train); four letters are read per prediction, so time steps is 4.
x_train = np.reshape(x_train, (len(x_train), TIME_STEPS))
y_train = np.array(y_train)

model = tf.keras.Sequential([
    Embedding(26, 2),
    SimpleRNN(10),
    Dense(26, activation="softmax"),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.01),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["sparse_categorical_accuracy"],
)

checkpoint_save_path = "./checkpoint/rnn_embedding_epred1.ckpt"

# Resume from a previous run when a checkpoint index file is present.
if os.path.exists(checkpoint_save_path + ".index"):
    print("-----------load the model----------")
    model.load_weights(checkpoint_save_path)

# There is no validation split, so the best checkpoint is chosen by
# training loss.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
    monitor="loss",
)

history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=100,
    callbacks=[cp_callback],
)

model.summary()

# Dump every trainable variable (name, shape, values) for inspection.
with open("./weights.txt", "w", encoding="utf-8") as f:
    for v in model.trainable_variables:
        f.write(str(v.name) + "\n")
        f.write(str(v.shape) + "\n")
        f.write(str(v.numpy()) + "\n")

acc = history.history["sparse_categorical_accuracy"]
loss = history.history["loss"]

plt.subplot(1, 2, 1)
plt.plot(acc, label="Training Accuracy")
plt.title("Training Accuracy")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label="Training Loss")
plt.title("Training Loss")
plt.legend()
plt.show()

preNum = int(input("Input the number of test alphabet:"))
for _ in range(preNum):
    alphabet1 = input("input test alphabet:")
    # Reject anything that is not exactly TIME_STEPS known letters; without
    # this check the id lookup raises KeyError or the reshape raises
    # ValueError and the whole session is lost.
    if len(alphabet1) != TIME_STEPS or any(ch not in w_to_id for ch in alphabet1):
        print(
            f"expected exactly {TIME_STEPS} lowercase letters a-z, "
            f"got {alphabet1!r}"
        )
        continue
    alphabet = [w_to_id[a] for a in alphabet1]
    # Shape alphabet the way Embedding expects: [number of samples, time
    # steps]. One sample is checked at a time, so the number of samples is 1.
    alphabet = np.reshape(alphabet, (1, TIME_STEPS))
    # Pass the array itself; wrapping it in a list makes Keras treat it as a
    # list of model inputs.
    result = model.predict(alphabet)
    # result has one row per sample; take the class index of the only row.
    pred = int(np.argmax(result, axis=1)[0])
    tf.print(alphabet1 + "->" + input_word[pred])
# Tags: cal ack for ict mod git rac var result
# Original article: https://www.cnblogs.com/wbloger/p/12879790.html