
A simple classifier for Iris and other tabular data (the model can be swapped)

Posted: 2018-10-04 18:58:34


Tested successfully with the following package versions:

Keras               2.2.4
Keras-Applications  1.0.6
Keras-Preprocessing 1.0.5
tensorflow          1.11.0
numpy               1.15.2
pandas              0.23.4
scikit-learn        0.20.0

# -*- coding: utf-8 -*-
import numpy
import pandas
from keras.layers import Dense, Dropout, Activation
from keras.models import Sequential
from keras.utils import np_utils
from keras.utils import plot_model
from sklearn import utils
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import LabelEncoder


def load_data():
    """
    Load the data set and split it into training and test parts.
    :return: x_train, y_train, x_test, y_test, encoder
    """
    # Load the CSV (no header row: 4 feature columns followed by the class label)
    data_frame = pandas.read_csv("iris.csv", header=None)
    data_set = data_frame.values
    # All rows, columns 0 to 3 (column 4 excluded): the features
    x_data = data_set[:, 0:4].astype(float)
    # All rows, column 4: the string class labels
    y_data = data_set[:, 4]
    # Label encoding
    encoder = LabelEncoder()
    # Map the string labels to integer class ids 0, 1, 2, ...
    # The mapping lives in encoder.classes_ and can be saved/reloaded with NumPy
    # (it is an object array here, hence allow_pickle=True when loading):
    #   numpy.save('encoder.npy', encoder.classes_)
    #   encoder.classes_ = numpy.load('encoder.npy', allow_pickle=True)
    encoded_transform_y = encoder.fit_transform(y_data)
    # One-hot encode the integer class ids
    y_data = np_utils.to_categorical(encoded_transform_y)
    # Shuffle the data set
    x_data, y_data = utils.shuffle(x_data, y_data)
    # Stratified 80/20 split: class proportions stay the same in both parts
    # (scikit-learn also accepts the one-hot rows as stratification labels)
    train_idx, test_idx = next(
        StratifiedShuffleSplit(n_splits=1, test_size=0.2,
                               random_state=0).split(x_data, y_data))
    x_train = x_data[train_idx]
    y_train = y_data[train_idx]
    x_test = x_data[test_idx]
    y_test = y_data[test_idx]
    return x_train, y_train, x_test, y_test, encoder


def compile_model():
    # Model: 4 inputs -> 10 tanh units -> dropout -> 3-way softmax
    _model = Sequential()
    _model.add(Dense(10, input_shape=(4,)))
    _model.add(Activation('tanh'))
    _model.add(Dropout(0.2))
    _model.add(Dense(3))
    _model.add(Activation('softmax'))
    _model.compile(
        loss="categorical_crossentropy",
        optimizer='adam',
        metrics=['accuracy'])
    # Save a diagram of the model (requires pydot and Graphviz to be installed)
    plot_model(_model, to_file='model.png', show_shapes=True)
    return _model


def train_model(_model, _x_train, _y_train, _x_test, _y_test):
    # Train, reporting loss/accuracy on the test split after every epoch
    history = _model.fit(_x_train, _y_train, epochs=100, batch_size=12,
                         verbose=1, validation_data=(_x_test, _y_test))
    # Evaluate on the test set: returns [loss, accuracy]
    score = _model.evaluate(_x_test, _y_test, verbose=1)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])
    return history


def test(_model, _encoder, _x_test):
    # Predict, then convert the class probabilities back to label strings
    result = _model.predict(_x_test)
    result = numpy.argmax(result, axis=1)
    result = _encoder.inverse_transform(result)
    print(result)


if __name__ == '__main__':
    x_train, y_train, x_test, y_test, encoder = load_data()
    model = compile_model()
    train_model(model, x_train, y_train, x_test, y_test)
    test(model, encoder, x_test)
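load_data() and test() together form a round trip: LabelEncoder turns the string labels into integer ids, to_categorical turns those ids into one-hot rows, and test() goes back via argmax and inverse_transform. The sketch below reproduces that round trip in plain NumPy, only to illustrate what those library calls do, not as a replacement for them. The sample labels are made up, and an in-memory buffer stands in for the encoder.npy file mentioned in the code comments.

```python
import io

import numpy as np

# Made-up sample labels in the same style as the iris.csv class column.
labels = np.array(["Iris-setosa", "Iris-virginica", "Iris-versicolor",
                   "Iris-setosa", "Iris-virginica"])

# Like LabelEncoder.fit_transform: sorted unique classes plus an integer id per row.
classes, encoded = np.unique(labels, return_inverse=True)

# Like to_categorical: row i of the identity matrix is the one-hot vector of class i.
one_hot = np.eye(len(classes))[encoded]

# Save and reload the class order, as the encoder.npy comment suggests
# (an in-memory buffer is used here instead of a real file).
buffer = io.BytesIO()
np.save(buffer, classes)
buffer.seek(0)
restored = np.load(buffer)

# Like test(): argmax picks the most likely class id, which indexes the names.
predicted_ids = np.argmax(one_hot, axis=1)
decoded = restored[predicted_ids]
print(decoded)
```

Because the class order comes from sorting the label strings, reloading the saved classes restores exactly the id-to-name mapping used during training.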

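StratifiedShuffleSplit in load_data() keeps the class proportions the same in the training and test sets. A simplified NumPy version of that idea might look like the following, assuming integer labels (the script itself passes one-hot rows, which scikit-learn also accepts) and 50 samples per class, as in the standard iris data. The helper name stratified_split is hypothetical, and its per-class rounding is simpler than scikit-learn's exact allocation rule.

```python
import numpy as np


def stratified_split(y, test_size=0.2, seed=0):
    """Hypothetical, simplified stratified split returning (train_idx, test_idx).

    Each class contributes about test_size of its own samples to the test
    set, so both parts keep the original class proportions. This shows the
    idea only; it is not scikit-learn's exact algorithm.
    """
    rng = np.random.RandomState(seed)
    train_parts, test_parts = [], []
    for cls in np.unique(y):
        idx = np.flatnonzero(y == cls)  # positions of this class
        rng.shuffle(idx)                # random order within the class
        n_test = int(round(test_size * len(idx)))
        test_parts.append(idx[:n_test])
        train_parts.append(idx[n_test:])
    # Mix the classes together again inside each part.
    train_idx = rng.permutation(np.concatenate(train_parts))
    test_idx = rng.permutation(np.concatenate(test_parts))
    return train_idx, test_idx


# 150 integer labels, 50 per class, like the iris data after label encoding.
y = np.repeat([0, 1, 2], 50)
train_idx, test_idx = stratified_split(y)
print(np.bincount(y[train_idx]), np.bincount(y[test_idx]))  # [40 40 40] [10 10 10]
```

With only 30 test rows, a plain random split could easily leave one class under-represented; stratifying guarantees 10 test rows per class here.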
 

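compile_model() builds a 4-10-3 network: a Dense(10) layer with tanh, dropout, and a Dense(3) layer with softmax. The NumPy-only sketch below computes the same forward pass at prediction time with random, untrained weights (Keras uses its own initializers, so only the shapes carry over). It makes the shapes drawn by plot_model(..., show_shapes=True) concrete and shows where the parameter count comes from. The names W1, b1, W2, b2 and forward are illustrative only.

```python
import numpy as np

rng = np.random.RandomState(0)

# Random weights with the shapes used in compile_model():
# Dense(10) on 4 input features, then Dense(3) on the 10 hidden units.
W1, b1 = rng.randn(4, 10), np.zeros(10)
W2, b2 = rng.randn(10, 3), np.zeros(3)


def softmax(z):
    # Subtract the row maximum for numerical stability before exponentiating.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def forward(x):
    # Dropout is only active during training, so prediction skips it.
    h = np.tanh(x @ W1 + b1)      # shape (batch, 10)
    return softmax(h @ W2 + b2)   # shape (batch, 3), each row sums to 1


x = rng.rand(12, 4)  # one batch of 12 rows, matching batch_size=12 in train_model()
probs = forward(x)
n_params = W1.size + b1.size + W2.size + b2.size
print(probs.shape, n_params)  # (12, 3) 83
```

The two Dense layers hold 4×10+10 = 50 and 10×3+3 = 33 parameters. If you swap in a different model, as the title suggests, its input width must still equal the number of feature columns (4) and its output width the number of classes (3).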

Original article: https://www.cnblogs.com/yzpopulation/p/9742742.html
