Lecture 6: Recurrent Neural Networks -- Embedding -- 4pred1

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, SimpleRNN, Embedding
import matplotlib.pyplot as plt
import os


input_word = "abcdefghijklmnopqrstuvwxyz"
# Map each letter to an integer id (0-25).
w_to_id = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4,
           'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9,
           'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15,
           'q': 16, 'r': 17, 's': 18, 't': 19,
           'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25}
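# (Sketch, not in the original post) The same mapping could also be built
# programmatically from input_word instead of being written out by hand:
#     w_to_id = {ch: idx for idx, ch in enumerate(input_word)}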

# Use the integers 0-25 as the scaled training sequence.
training_set_scaled = [x for x in range(26)]


x_train = []
y_train = []

# Slide a window over the sequence: 4 consecutive ids form the input,
# and the id that follows them is the label.
for i in range(4, 26):
  x_train.append(training_set_scaled[i - 4: i])
  y_train.append(training_set_scaled[i])


# Shuffle inputs and labels with the same seed so the pairs stay aligned.
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)


# Reshape x_train to match the Embedding layer's expected input:
# [number of samples, number of time steps].
# The whole dataset is fed in at once, so the number of samples is len(x_train);
# four letters are used to predict one, so the number of time steps is 4.
x_train = np.reshape(x_train, (len(x_train), 4))
y_train = np.array(y_train)
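# (Sketch, not in the original post) Quick shape check: 26 letters with a
# window of 4 yield 22 samples, so x_train is (22, 4) and y_train is (22,).
assert x_train.shape == (22, 4)
assert y_train.shape == (22,)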


model = tf.keras.Sequential([
    Embedding(26, 2),                # map each of the 26 letter ids to a 2-dim vector
    SimpleRNN(10),                   # recurrent layer with 10 hidden units
    Dense(26, activation='softmax')  # probability distribution over the 26 letters
])
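# (Sketch, not in the original post) How shapes flow through the model:
# Embedding(26, 2) turns a (batch, 4) array of letter ids into (batch, 4, 2)
# vectors, SimpleRNN(10) condenses each sequence to (batch, 10), and
# Dense(26, 'softmax') outputs (batch, 26) probabilities over the alphabet.
_shape_demo = model(np.zeros((1, 4), dtype=np.int32))
print(_shape_demo.shape)  # expected: (1, 26)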

model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/rnn_embedding_epred1.ckpt"

# Resume from the checkpoint if one exists.
if os.path.exists(checkpoint_save_path + '.index'):
  print('-----------load the model----------')
  model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
    monitor='loss')

history = model.fit(x_train, y_train, batch_size=32, epochs=100, callbacks=[cp_callback])

model.summary()


# Dump the trained variables (name, shape, values) to a text file.
with open('./weights.txt', 'w') as f:
  for v in model.trainable_variables:
    f.write(str(v.name) + '\n')
    f.write(str(v.shape) + '\n')
    f.write(str(v.numpy()) + '\n')


# Plot the training accuracy and loss curves.
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.title('Training Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.title('Training Loss')
plt.legend()
plt.show()


preNum = int(input("Input the number of test alphabet:"))
for i in range(preNum):
  alphabet1 = input("input test alphabet:")
  alphabet = [w_to_id[a] for a in alphabet1]
  # Reshape alphabet to match the Embedding layer's expected input:
  # [number of samples, number of time steps]. A single sample is fed in for
  # this check, so the number of samples is 1; four letters are used to
  # predict one, so the number of time steps is 4.
  alphabet = np.reshape(alphabet, (1, 4))
  result = model.predict([alphabet])
  pred = tf.argmax(result, axis=1)
  pred = int(pred)
  tf.print(alphabet1 + '->' + input_word[pred])
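For a quick check without the interactive input() prompts, the same prediction step can be run on a few fixed 4-letter windows. This is a minimal sketch, not part of the original post; it assumes the trained model, w_to_id, and input_word defined above are already in scope, and the test_windows list is made up for illustration.

# A minimal sketch (not from the original post): batch prediction over fixed
# 4-letter windows instead of interactive input().
test_windows = ["abcd", "mnop", "qrst"]          # hypothetical test inputs
ids = np.array([[w_to_id[ch] for ch in w] for w in test_windows])  # shape (3, 4)
probs = model.predict(ids)                       # shape (3, 26)
preds = np.argmax(probs, axis=1)                 # predicted letter ids
for w, p in zip(test_windows, preds):
  # After training converges, this should print e.g. "abcd -> e".
  print(w, '->', input_word[p])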

 

Original source: https://www.cnblogs.com/wbloger/p/12879790.html
