码迷,mamicode.com
首页 > 其他好文 > 详细

Tensorflow2.0之Minist手写数字识别运行通过源码

时间:2021-06-15 17:35:48      阅读:0      评论:0      收藏:0      [点我收藏+]

标签:amp   loading   转换   val   test   +=   sign   lib   range   

 
 

 

from matplotlib import pyplot as mp
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets ,layers ,optimizers


def preprocess(x,y):
    x=tf.cast(x,dtype=tf.float32)/255#归一化,使0-255之间的灰度值,归一化为0-1
    x=tf.reshape(x,[-1,28*28])
    y=tf.cast(y,dtype=tf.int32)
    y=tf.one_hot(y,depth=10)
    return x,y

(x,y),(x_test,y_test)=datasets.mnist.load_data()
print(x:,x.shape,y:,y.shape,x test:,x_test.shape,y test:,y_test)

batchsz=512
train_db=tf.data.Dataset.from_tensor_slices((x,y))#数据加载后转换成dataset对象
train_db=train_db.shuffle(1000)#缓冲区
train_db=train_db.batch(batchsz)#批训练,一次并行计算的样本个数
train_db=train_db.map(preprocess)#map是调用函数的意思
train_db=train_db.repeat(20)
test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db=test_db.shuffle(1000).batch(batchsz).map(preprocess)
x,y=next(iter(train_db))
print(train sample:,x.shape,y.shape)

#主函数
def main():
    lr=1e-2
    accs,losses=[],[]
    w1,b1=tf.Variable(tf.random.normal([784,256],stddev=0.1)),          tf.Variable(tf.zeros([256]))#random.normal是初始化
    w2, b2 = tf.Variable(tf.random.normal([256, 128], stddev=0.1)),              tf.Variable(tf.zeros([128]))
    w3, b3 = tf.Variable(tf.random.normal([128, 10], stddev=0.1)),              tf.Variable(tf.zeros([10]))

    for step,(x,y)in enumerate(train_db):
        x=tf.reshape(x,(-1,784))
        with tf.GradientTape() as tape:
            h1=x@w1+b1
            h1=tf.nn.relu(h1)
            h2=h1@w2+b2
            h2=tf.nn.relu(h2)
            out=h2@w3+b3
            #损失函数
            loss=tf.square(y-out)
            loss=tf.reduce_mean(loss)#tensor指定轴方向上的平均值
        #计算梯度
        grads=tape.gradient(loss,[w1,b1,w2,b2,w3,b3])
        #参数更新
        for p,g in zip([w1,b1,w2,b2,w3,b3],grads):
            p.assign_sub(lr*g)

        if step%100==0:
            print(step,loss:,float(loss))
            losses.append(float(loss))#列表尾部追加数据
        if step%100==0:
            total,total_correct=0,0
            for x,y in test_db:
                #layer1
                h1 = x @ w1 + b1
                h1 = tf.nn.relu(h1)
                #layer2
                h2 = h1 @ w2 + b2
                h2 = tf.nn.relu(h2)
                #output
                out = h2 @ w3 + b3
                pred=tf.argmax(out,axis=1)#寻找具有最大评分的数
                y=tf.argmax(y,axis=1)
                correct=tf.equal(pred,y)
                total_correct+=tf.reduce_sum(tf.cast(correct,dtype=tf.int32)).numpy()
                total+=x.shape[0]
            print(step,Evaluate Acc:,total_correct/total)
            accs.append(total_correct/total)
    mp.figure()
    x = [i * 80 for i in range(len(losses))]
    mp.plot(x, losses, color=C0, marker=s, label=train)
    mp.ylabel(MSE)
    mp.xlabel(Step)
    mp.legend()
    mp.savefig(train.svg)

    mp.figure()
    mp.plot(x, accs, color=C1, marker=s, label=test)
    mp.ylabel(Acc)
    mp.xlabel(Step)
    mp.legend()
    mp.savefig(test.svg)
    mp.show()

if __name__ == __main__:
        main()

 技术图片

技术图片

Tensorflow2.0之Minist手写数字识别运行通过源码

标签:amp   loading   转换   val   test   +=   sign   lib   range   

原文地址:https://www.cnblogs.com/sunshine-66/p/14882787.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!