码迷,mamicode.com
首页 > 编程语言 > 详细

PyQt训练BP模型时,显示waiting动图(多线程)

时间:2018-06-16 23:35:03      阅读:274      评论:0      收藏:0      [点我收藏+]

标签:finish   技术   thread   ram   AC   cti   代码   else   initial   

1、实现效果

技术分享图片

2、相关代码

实现BP训练模型的线程类

 1 class WorkThread(QtCore.QThread):
 2     finish_trigger = QtCore.pyqtSignal()  # 关闭waiting_gif
 3     result_trigger = QtCore.pyqtSignal(pd.Series)  # 传递预测结果信号
 4     evaluate_trigger = QtCore.pyqtSignal(list)  # 传递正确率信号
 5 
 6     def __int__(self):
 7         super(WorkThread, self).__init__()
 8 
 9     def init(self, dataset, feature, label, info):
10         self.dataset = dataset
11         self.feature = feature
12         self.label = label
13         self.info = info
14 
15     # 可以认为,run()函数就是新的线程需要执行的代码
16     def run(self):
17         self.BP()
18 
19     def BP(self):
20         """
21         BP神经网络,返回标签的预测数据
22         :param parent:
23         :param dataset:
24         :param feature:
25         :param label:
26         :param info:
27         :return:
28         """
29         dataset = self.dataset
30         feature = self.feature
31         label = self.label
32         info = self.info
33 
34         input_dim = len(feature)
35         data_x = dataset[feature]  # 特征数据
36         data_y = dataset[label]  # 标签数据
37 
38         x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=info[0][3])
39     
40         # **********************建立一个简单BP神经网络模型*********************************
41         self.model = Sequential()  # 声明一个顺序模型
42         count = len(info)
43         for i in range(1, count-1):
44             if i == 1:
45                 self.model.add(Dense(info[i][0], activation=info[i][1], input_dim=input_dim, kernel_initializer=info[i][2]))  # 输入层,Dense表示BP层
46             else:
47                 self.model.add(Dense(info[i][0], activation=info[i][1], kernel_initializer=info[i][2]))
48 
49         # 添加输出层
50         self.model.add(Dense(info[count-1][0], activation=info[count-1][1], kernel_initializer=info[count-1][2]))
51 
52         sgd = SGD(lr=info[0][0], decay=1e-6, momentum=0.9, nesterov=True)
53         self.model.compile(loss=binary_crossentropy,  optimizer=sgd,  metrics=[accuracy])  # 编译模型
54 
55         self.model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=info[0][1], batch_size=info[0][2])  # 训练模型1000次
56 
57         scores_train = self.model.evaluate(x_train, y_train, batch_size=10)
58         scores_test = self.model.evaluate(x_test, y_test, batch_size=10)
59         scores = self.model.evaluate(data_x, data_y, batch_size=10)
60 
61         self.finish_trigger.emit()         # 循环完毕后发出信号
62         list = [scores_train[1]*100, scores_test[1]*100, scores[1]*100]
63         self.evaluate_trigger.emit(list)
64         result = pd.Series(self.model.predict(data_x).T[0])
65         result.name = 预测(BP)
66         self.result_trigger.emit(result)
67         K.clear_session()  # 反复调用model 模型
68 
69     def save_model(self, save_dir):
70         self.model.save(save_dir)  # 保存模型

GUI显示代码(部分):

 1 class MainWindow(QtGui.QMainWindow):
 2     save_dir_signal = QtCore.pyqtSignal(str)  # 传递保存目录信号
 3 
 4 def show_evaluate_result(self, evaluate_result):
 5         help = QtGui.QMessageBox.information(self, 评价结果,
 6                                              "训练集正确率:  %.2f%%\n测试集正确率:  %.2f%%\n数据集正确率:  %.2f%%" %
 7                                              (evaluate_result[0], evaluate_result[1], evaluate_result[2]),
 8                                              QtGui.QMessageBox.Yes)
 9 
10         self.pop_save_dir()
11 
12     def pop_save_dir(self):
13         msg = QtGui.QMessageBox.information(self, 提示, 是否保存模型?, QtGui.QMessageBox.Yes | QtGui.QMessageBox.No)
14         if msg == QtGui.QMessageBox.Yes:
15                 save_dir = QtGui.QFileDialog.getSaveFileName(self, 选择保存目录, C:\\Users\\fuqia\\Desktop)
16 
17                 if save_dir != ‘‘:
18                     save_dir = save_dir + .model
19                     self.save_dir_signal.emit(save_dir)
20 
21     def show_bp_result(self, result):
22 
23         self.predict_data = result
24         TableWidgetDeal.add_predict_data(self.table, result)
25 
26     def waiting_label_close(self):
27         self.label.close()
28 
29     def show_waiting(self):
30         self.label = QtGui.QLabel(self)
31         self.label.setFixedSize(640, 480)  # 不加的话有问题???
32         self.label.setWindowFlags(QtCore.Qt.FramelessWindowHint)  # 无边框
33         self.label.setAttribute(QtCore.Qt.WA_TranslucentBackground)  # 背景透明
34 
35         screen = QtGui.QDesktopWidget().screenGeometry()
36         size = self.label.geometry()
37         # 如果是self.label.move((screen.width() - size.width()) / 2 , (screen.height() - size.height()) / 2)无法居中
38         self.label.move((screen.width() - size.width()) / 2 + 240, (screen.height() - size.height()) / 2)
39 
40         # 打开gif文件
41         movie = QtGui.QMovie("./Icon/waiting.gif")
42         # 设置cacheMode为CacheAll时表示gif无限循环,注意此时loopCount()返回-1
43         movie.setCacheMode(QtGui.QMovie.CacheAll)
44         # 播放速度
45         movie.setSpeed(100)
46         self.label.setMovie(movie)
47         # 开始播放,对应的是movie.start()
48         movie.start()
49         self.label.show()
50         q = QtCore.QEventLoop()
51         q.exec_()

 

1 w = WorkThread()
2 w.init(self.object.data_set, feature, label, self.bp_ui.bp_info)
3 w.start()
4 w.finish_trigger.connect(self.waiting_label_close)
5 w.result_trigger.connect(self.show_bp_result)
6 w.evaluate_trigger.connect(self.show_evaluate_result)
7 self.save_dir_signal.connect(w.save_model)
8 self.show_waiting()

 

PyQt训练BP模型时,显示waiting动图(多线程)

标签:finish   技术   thread   ram   AC   cti   代码   else   initial   

原文地址:https://www.cnblogs.com/fuqia/p/9191696.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!