标签:amp loading 转换 val test += sign lib range
from matplotlib import pyplot as mp
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers

# Number of training steps between two consecutive loss/accuracy records.
# main() uses the same value to build the x axis of both plots, so the
# plotted step numbers always match the steps that were actually recorded.
LOG_EVERY = 100


def preprocess(x, y):
    """Turn a batch of raw images and integer labels into model inputs.

    Args:
        x: Batch of grey-scale images with values in 0-255.
        y: Batch of integer class labels.

    Returns:
        A tuple ``(x, y)`` where ``x`` is float32, scaled to 0-1 and
        flattened to ``[-1, 784]``, and ``y`` is one-hot encoded with
        depth 10.
    """
    x = tf.cast(x, dtype=tf.float32) / 255  # normalise grey levels to 0-1
    x = tf.reshape(x, [-1, 28 * 28])  # flatten each image to a 784-vector
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


(x, y), (x_test, y_test) = datasets.mnist.load_data()
print('x:', x.shape, 'y:', y.shape, 'x test:', x_test.shape, 'y test:', y_test)

batchsz = 512  # number of samples processed together in one step

# Wrap the loaded arrays in a Dataset object.
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(1000)  # size of the shuffle buffer
train_db = train_db.batch(batchsz)
# map() is applied after batch(), so preprocess receives whole batches.
train_db = train_db.map(preprocess)
train_db = train_db.repeat(20)  # iterate over the training data 20 times

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(1000).batch(batchsz).map(preprocess)

# Peek at one preprocessed batch to show the shapes fed to the model.
x, y = next(iter(train_db))
print('train sample:', x.shape, y.shape)


def _forward(x, params):
    """Run the three-layer fully connected network and return its raw output.

    Args:
        x: Input batch of shape ``[batch, 784]``.
        params: Sequence ``[w1, b1, w2, b2, w3, b3]`` of layer variables.

    Returns:
        The output of the last layer (no activation applied).
    """
    w1, b1, w2, b2, w3, b3 = params
    h1 = tf.nn.relu(x @ w1 + b1)
    h2 = tf.nn.relu(h1 @ w2 + b2)
    return h2 @ w3 + b3


def _evaluate(params):
    """Return the fraction of samples in ``test_db`` classified correctly."""
    total, total_correct = 0, 0
    for images, onehot in test_db:
        out = _forward(images, params)
        pred = tf.argmax(out, axis=1)  # class with the highest score
        label = tf.argmax(onehot, axis=1)  # undo the one-hot encoding
        correct = tf.equal(pred, label)
        total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
        total += images.shape[0]
    return total_correct / total


def main():
    """Train the network with plain gradient descent and plot the results.

    Every ``LOG_EVERY`` steps the training loss and the test accuracy are
    printed and recorded. After training, the loss curve is saved to
    ``train.svg`` and the accuracy curve to ``test.svg``, then both figures
    are shown.
    """
    lr = 1e-2
    accs, losses = [], []

    # Weights are drawn from a normal distribution; biases start at zero.
    w1, b1 = tf.Variable(tf.random.normal([784, 256], stddev=0.1)), tf.Variable(tf.zeros([256]))
    w2, b2 = tf.Variable(tf.random.normal([256, 128], stddev=0.1)), tf.Variable(tf.zeros([128]))
    w3, b3 = tf.Variable(tf.random.normal([128, 10], stddev=0.1)), tf.Variable(tf.zeros([10]))
    params = [w1, b1, w2, b2, w3, b3]

    for step, (x, y) in enumerate(train_db):
        x = tf.reshape(x, (-1, 784))

        with tf.GradientTape() as tape:
            out = _forward(x, params)
            # Mean squared error between the one-hot labels and the output.
            loss = tf.reduce_mean(tf.square(y - out))

        # Compute gradients and apply one gradient-descent update in place.
        grads = tape.gradient(loss, params)
        for p, g in zip(params, grads):
            p.assign_sub(lr * g)

        if step % LOG_EVERY == 0:
            print(step, 'loss:', float(loss))
            losses.append(float(loss))

            acc = _evaluate(params)
            print(step, 'Evaluate Acc:', acc)
            accs.append(acc)

    # Records were taken at steps 0, LOG_EVERY, 2 * LOG_EVERY, ...
    steps = [i * LOG_EVERY for i in range(len(losses))]

    mp.figure()
    mp.plot(steps, losses, color='C0', marker='s', label='train')
    mp.ylabel('MSE')
    mp.xlabel('Step')
    mp.legend()
    mp.savefig('train.svg')

    mp.figure()
    mp.plot(steps, accs, color='C1', marker='s', label='test')
    mp.ylabel('Acc')
    mp.xlabel('Step')
    mp.legend()
    mp.savefig('test.svg')
    mp.show()


if __name__ == '__main__':
    main()
TensorFlow 2.0 之 MNIST 手写数字识别（运行通过的源码）
标签:amp loading 转换 val test += sign lib range
原文地址:https://www.cnblogs.com/sunshine-66/p/14882787.html