码迷,mamicode.com
首页 > 其他好文 > 详细

CIFAR10-网络训练技术

时间:2020-02-29 13:20:47      阅读:85      评论:0      收藏:0      [点我收藏+]

标签:optimizer   scalar   aries   def   ima   pix   variable   训练   pre   

1、数据增强

  1)随机裁剪

  在原始图片的每一边pad 4个 pixels,然后再裁切成32*32的图片

distorted_images = tf.image.resize_image_with_crop_or_pad(record_images, 
                                                          imageHeight+8, imageWidth+8)
distorted_images = tf.random_crop(distorted_images, size = [batch_size, imageHeight, imageHeight, 3])

  2)随机翻转、调节亮度和对比度、标准化

for i in xrange(len(distorted_images)):
    distorted_images[i] = tf.image.random_flip_left_right(distorted_images[i])
    distorted_images[i] = tf.image.random_brightness(distorted_images[i], max_delta=63)
    distorted_images[i] = tf.image.random_contrast(distorted_images[i], lower=0.2, upper=1.8)
    distorted_images[i] = tf.image.per_image_standardization(distorted_images[i])

 

2、学习率

  1)线性衰减

  2)指数衰减

  3)按区间衰减

global_step = tf.Variable(0, trainable=False)
boundaries = [10000, 15000, 20000, 25000]
values = [0.1, 0.05, 0.01, 0.005, 0.001]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

 

3、weight decay

#Add the l2 weights to the loss
#Add weight decay to the loss.
l2_loss = weight_decay * tf.add_n(
# loss is computed using fp32 for numerical stability.
[tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()])
tf.summary.scalar(‘l2_loss‘, l2_loss)
loss = cross_entropy_mean + l2_loss

 

4、优化器

#Define the optimizer
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
 
#Relate to the batch normalization
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    opt_op = optimizer.minimize(loss, global_step)

 

CIFAR10-网络训练技术

标签:optimizer   scalar   aries   def   ima   pix   variable   训练   pre   

原文地址:https://www.cnblogs.com/wt-seu/p/12382130.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!