
Machine Learning: Recognizing Handwritten Digits 0-9

Posted: 2021-03-15 10:35:44


```python
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # hide TensorFlow INFO/WARNING messages; must be set before importing tensorflow
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
# MNIST is the dataset of handwritten digits 0-9. Keras can also download it itself:
# from tensorflow.keras.datasets import mnist
import sys


def load_mnist(path):
    # Read the four arrays stored in a local copy of mnist.npz; the with-block closes the file
    with np.load(path) as file:
        x_train, y_train = file['x_train'], file['y_train']
        x_test, y_test = file['x_test'], file['y_test']
    return (x_train, y_train), (x_test, y_test)
```

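The script reads a local mnist.npz instead of downloading MNIST, and `load_mnist` assumes the file stores four arrays under the keys `x_train`, `y_train`, `x_test` and `y_test`. The sketch below checks that round trip on a tiny synthetic file: it writes random stand-in arrays (not real MNIST data) with `np.savez` under those keys and reads them back with the same logic. The file name `tiny_mnist.npz` is made up for the example.

```python
import os
import tempfile

import numpy as np


def load_mnist(path):
    # Same loading logic as the script: read the four arrays by key, then close the file
    with np.load(path) as file:
        x_train, y_train = file['x_train'], file['y_train']
        x_test, y_test = file['x_test'], file['y_test']
    return (x_train, y_train), (x_test, y_test)


# Synthetic stand-ins with MNIST-like dtypes: uint8 28x28 images, small integer labels 0-9
rng = np.random.default_rng(0)
arrays = {
    'x_train': rng.integers(0, 256, size=(6, 28, 28)).astype(np.uint8),
    'y_train': rng.integers(0, 10, size=6).astype(np.uint8),
    'x_test': rng.integers(0, 256, size=(2, 28, 28)).astype(np.uint8),
    'y_test': rng.integers(0, 10, size=2).astype(np.uint8),
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'tiny_mnist.npz')  # hypothetical file name
    np.savez(path, **arrays)  # each keyword becomes a key inside the .npz archive
    (x_train, y_train), (x_test, y_test) = load_mnist(path)

print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
```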
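```python
def check(N, pages):
    # Show `pages` figures, each an N x N grid of consecutive images from the global x_train.
    # reshape(28, 28) lets this work both before and after x_train is flattened to 784 columns.
    idx = 0
    for page in range(pages):
        for i in range(N):
            for j in range(N):
                num = i * N + j
                plt.subplot(N, N, num + 1)
                plt.imshow(x_train[num + idx].reshape(28, 28), cmap='gray')
        idx += N * N
        plt.show()  # display the current page
```
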
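In `check`, page `p` and grid cell `(i, j)` show training image `p*N*N + i*N + j`, placed at the 1-based subplot position `i*N + j + 1`. The pure-Python sketch below lists that mapping so the paging arithmetic can be checked without opening any figure windows. It does no plotting, and `grid_indices` is a hypothetical helper, not part of the script.

```python
def grid_indices(N, pages):
    # Return (page, subplot_position, image_index) for every cell drawn by check(N, pages)
    cells = []
    idx = 0
    for page in range(pages):
        for i in range(N):
            for j in range(N):
                num = i * N + j  # row-major position inside the N x N grid
                cells.append((page, num + 1, num + idx))
        idx += N * N  # the next page starts after the N*N images already shown
    return cells


for cell in grid_indices(2, 2):
    print(cell)
```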
```python
path = "C:/Users/77007/Desktop/python/pythonProject1/mnist.npz"  # local copy of mnist.npz
(x_train, y_train), (x_test, y_test) = load_mnist(path)
# print(x_train.shape)  # 60000 images of 28x28 pixels
# print(y_train.shape)  # 60000 labels, each one of the digits 0-9
# Flatten each image into 784 values and scale pixels from 0-255 to 0.0-1.0
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
# print(x_train.shape)
# print(x_test.shape)
```

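The two preprocessing lines flatten each 28x28 image into one 784-element row (`reshape(-1, 784)`, where `-1` lets NumPy infer the number of images) and map pixel values from 0-255 to 0.0-1.0. A small NumPy check on two synthetic images (random pixels, not MNIST) shows that the flattening is row-major, so pixel `(r, c)` ends up in column `r*28 + c`, and that the step is reversible.

```python
import numpy as np

# Two synthetic 28x28 uint8 images (random pixels, not real MNIST digits)
rng = np.random.default_rng(1)
images = rng.integers(0, 256, size=(2, 28, 28)).astype(np.uint8)

# Same preprocessing as the script: flatten to 784 values and scale to [0, 1]
flat = images.reshape(-1, 784).astype("float32") / 255.0
print(flat.shape)  # (2, 784)

# Row-major flattening: pixel (r, c) of an image lands in column r*28 + c
r, c = 5, 17
print(np.isclose(flat[0, r * 28 + c], images[0, r, c] / 255.0))

# Reshaping a row back to 28x28 and undoing the scaling recovers the image
print(np.allclose(flat[0].reshape(28, 28) * 255.0, images[0]))
```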
```python
# Sequential API (convenient, but less flexible)
model = keras.Sequential(
    [
        keras.Input(shape=(784,)),  # declaring the input shape lets model.summary() run before training
        layers.Dense(512, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(10),  # no activation: this layer outputs raw logits
    ]
)
# An equivalent way to build the same model:
# model = keras.Sequential()
# model.add(keras.Input(shape=(784,)))
# model.add(layers.Dense(512, activation='relu'))
# model.add(layers.Dense(256, activation='relu'))
# model.add(layers.Dense(10))
# model.summary()
# sys.exit()
```

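The Sequential model's last layer, `layers.Dense(10)`, has no activation, so it outputs raw scores (logits), while the Functional model below ends in a softmax. Sparse categorical cross-entropy gives the same loss in both cases, provided the loss is told what it receives: `from_logits=True` for logits, `from_logits=False` for probabilities. The NumPy sketch below computes the loss by hand both ways on made-up logits. It illustrates the math, not Keras's own implementation, and the helper names are hypothetical.

```python
import numpy as np


def softmax(logits):
    # Subtract the row max for numerical stability; the result is mathematically unchanged
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def sparse_ce_from_probs(probs, labels):
    # Mean of -log(probability assigned to the true class)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))


def sparse_ce_from_logits(logits, labels):
    # Same loss via log-softmax: log p_k = z_k - log(sum_j exp(z_j))
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(labels)), labels])


logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # made-up scores for 3 classes
labels = np.array([0, 2])
probs = softmax(logits)

print(sparse_ce_from_logits(logits, labels))  # logits + "from_logits=True"
print(sparse_ce_from_probs(probs, labels))    # probabilities + "from_logits=False": same value
print(probs.argmax(axis=1), logits.argmax(axis=1))  # softmax never changes the predicted class
# Mismatch: treating probabilities as logits applies softmax twice and gives a different loss
print(sparse_ce_from_logits(probs, labels))
```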
```python
# Functional API (more flexible)
inputs = keras.Input(shape=(784,))
x = layers.Dense(512, activation='relu', name='first_layer')(inputs)
x = layers.Dense(256, activation='relu', name='second_layer')(x)
outputs = layers.Dense(10, activation='softmax')(x)
# This replaces the Sequential model above; comment this line out to train that one instead
model = keras.Model(inputs=inputs, outputs=outputs)
# model.summary()
# The name='...' arguments set the layer names shown by model.summary()
```

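Each `Dense` layer holds an `inputs x units` weight matrix plus `units` biases, so the parameter count that `model.summary()` reports can be checked by hand. The short computation below does this for the 784-512-256-10 architecture shared by both models. Activations add no parameters, so the total should match the "Total params" line of the summary for either version.

```python
# (inputs, units) for each Dense layer in the 784-512-256-10 network
layer_shapes = [(784, 512), (512, 256), (256, 10)]

total = 0
for n_in, n_out in layer_shapes:
    params = n_in * n_out + n_out  # weight matrix + bias vector
    print(f"Dense({n_out}): {params} parameters")
    total += params

print(f"Total: {total} parameters")  # Total: 535818 parameters
```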
```python
model.compile(
    # from_logits=False because the Functional model ends in softmax and outputs probabilities;
    # use from_logits=True with the Sequential model, whose last layer outputs raw logits
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer=keras.optimizers.Adam(learning_rate=0.001),  # learning rate 0.001
    metrics=["accuracy"],
)
```

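The optimizer is Adam with learning rate 0.001. As a rough picture of what one step does, the sketch below applies the textbook Adam update (Kingma and Ba) to a single parameter vector, assuming `beta_1=0.9`, `beta_2=0.999` and a small `epsilon`. Keras's internal implementation may differ in details such as where epsilon enters, so treat this as an illustration, not its exact code. `adam_step` is a hypothetical name.

```python
import numpy as np


def adam_step(theta, grad, m, v, t, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1e-7):
    # Exponential moving averages of the gradient and of the squared gradient
    m = beta_1 * m + (1 - beta_1) * grad
    v = beta_2 * v + (1 - beta_2) * grad ** 2
    # Bias correction: m and v start at zero, so early averages are scaled up
    m_hat = m / (1 - beta_1 ** t)
    v_hat = v / (1 - beta_2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v


theta = np.array([1.0, -2.0])
grad = np.array([0.5, -0.1])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
theta, m, v = adam_step(theta, grad, m, v, t=1)
print(theta)  # on the first step each entry moves by about lr = 0.001 against its gradient's sign
```

On the first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is roughly the sign of the gradient, which is why the step size is close to the learning rate regardless of the gradient's magnitude.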
```python
model.fit(x_train, y_train, batch_size=32, epochs=5, verbose=2)
model.evaluate(x_test, y_test, batch_size=32, verbose=2)
```
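With `batch_size=32`, one epoch over the 60,000 training images takes `ceil(60000 / 32)` gradient updates, and evaluation on the 10,000 test images runs in `ceil(10000 / 32)` batches, the last of which is partial. The arithmetic:

```python
import math

train_samples, test_samples, batch_size, epochs = 60000, 10000, 32, 5

steps_per_epoch = math.ceil(train_samples / batch_size)
eval_steps = math.ceil(test_samples / batch_size)
last_eval_batch = test_samples - (eval_steps - 1) * batch_size

print(steps_per_epoch)           # 1875 optimizer updates per epoch
print(steps_per_epoch * epochs)  # 9375 updates over 5 epochs
print(eval_steps)                # 313 evaluation batches
print(last_eval_batch)           # 16 images in the final, partial batch
```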

 

```text
Train on 60000 samples
Epoch 1/5
60000/60000 - 6s - loss: 0.1865 - accuracy: 0.9425
Epoch 2/5
60000/60000 - 5s - loss: 0.0800 - accuracy: 0.9749
Epoch 3/5
60000/60000 - 5s - loss: 0.0541 - accuracy: 0.9826
Epoch 4/5
60000/60000 - 5s - loss: 0.0395 - accuracy: 0.9872
Epoch 5/5
60000/60000 - 5s - loss: 0.0341 - accuracy: 0.9890
10000/1 - 0s - loss: 0.0392 - accuracy: 0.9792
```
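The last log line comes from `model.evaluate` on the 10,000 test images. Turning accuracies back into counts makes them easier to read. One caveat: in TF 2.x the accuracy printed by `fit` is accumulated over all batches of the epoch while the weights are still changing, so the 0.9890 training figure is a running average, not the accuracy of the final model. The gap to 0.9792 is therefore only a rough hint of overfitting.

```python
# Accuracy figures copied from the training log
train_images, train_accuracy = 60000, 0.9890  # epoch-5 running average from fit
test_images, test_accuracy = 10000, 0.9792    # from model.evaluate

train_errors = round((1 - train_accuracy) * train_images)
test_errors = round((1 - test_accuracy) * test_images)

print(train_errors)  # 660
print(test_errors)   # 208
```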

 

Original post: https://www.cnblogs.com/JasonCow/p/14524922.html
