码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow 测量工具,与自定义训练

时间:2020-05-05 17:36:29      阅读:113      评论:0      收藏:0      [点我收藏+]

标签:gdi   test   lock   variable   rac   tor   enumerate   red   写入   

 

# 新建测量器
m = tf.keras.metrics.Accuracy()
# 写入测量器
m.update_state([0,1,1],[0,1,2])
# 读取统计信息
m.result() # 准确率为0.66
# 清除
m.reset_states()
acc_meter = tf.keras.metrics.Accuracy()
loss_meter = tf.keras.metrics.Mean() # 求平均loss
op = tf.keras.optimizers.Adam(0.01)
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/"+current_time
summary_writer = tf.summary.create_file_writer(logdir)
for epoch in range(10):
   for step,(x,y) in enumerate(train_data):
       with tf.GradientTape() as tape:
           loss = tf.losses.categorical_crossentropy(y,model(x))
           loss_meter.update_state(loss) # 准确率
       grads = tape.gradient(loss,model.train_variables) # 求梯度
       op.apply_gradients(zip(grads,model.train_variables)) # 更新梯度 w = w - delta
       
       with summary_writer.as_default()
           tf.summary.scalar(name="loss",data=loss_meter.result().numpy(),step=xxxx)
       print(epoch,step,loss,loss_meter.result().numpy())   # numpy() 将tensor转化为变量
       loss_meter.reset_states()
   
   for step,(x,y) in enumerate(test_data):
       out = model(x)
       pred = tf.argmax(out,axis=-1)
       pred = tf.cast(pred,dtype=tf.int32)
       y = tf.cast(tf.argmax(y,axis=-1),dtype=tf.int32)
       acc_meter.update_state(y,pred)
   with summary_writer.as_default()
       tf.summary.scalar(name="acc",data=acc_meter.result().numpy(),step=xxxx)    
   print(epoch,acc_meter.result().numpy())
   acc_meter.reset_states()

 

tensorflow 测量工具,与自定义训练

标签:gdi   test   lock   variable   rac   tor   enumerate   red   写入   

原文地址:https://www.cnblogs.com/Dean0731/p/12831518.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!