
tensorflow API _ 3 (tf.train.polynomial_decay)

Time: 2018-04-24 11:17:42

Tags: training   ber   type   replica   method   false   err   conf   rate   

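There are three ways to set the learning rate: fixed, exponential decay, and polynomial decay. The function below selects one of them according to FLAGS.learning_rate_decay_type.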
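As a rough illustration of what the three schedules compute at a given step, here is a minimal pure-Python sketch. It assumes the formulas given in the TensorFlow 1.x documentation for tf.train.exponential_decay and tf.train.polynomial_decay; the helper names fixed_lr, exponential_lr and polynomial_lr are hypothetical and not part of TensorFlow, and the real ops work on tensors rather than Python floats.

```python
import math

# Hypothetical helpers mirroring the documented TF 1.x formulas (not TensorFlow API).


def fixed_lr(learning_rate, global_step):
    # The 'fixed' schedule ignores the step entirely.
    return learning_rate


def exponential_lr(learning_rate, global_step, decay_steps, decay_rate,
                   staircase=True):
    # learning_rate * decay_rate ** (global_step / decay_steps);
    # staircase=True floors the exponent so the rate drops in discrete steps.
    exponent = global_step / decay_steps
    if staircase:
        exponent = math.floor(exponent)
    return learning_rate * decay_rate ** exponent


def polynomial_lr(learning_rate, global_step, decay_steps, end_learning_rate,
                  power=1.0, cycle=False):
    if cycle:
        # Stretch decay_steps to the next multiple of itself at or above
        # global_step (multiplier 1 at step 0 avoids a zero-length period).
        multiplier = 1 if global_step == 0 else math.ceil(global_step / decay_steps)
        decay_steps = decay_steps * multiplier
    else:
        # Without cycling, the rate stays at end_learning_rate after decay_steps.
        global_step = min(global_step, decay_steps)
    fraction = 1 - global_step / decay_steps
    return (learning_rate - end_learning_rate) * fraction ** power + end_learning_rate
```

For example, with learning_rate=0.1, end_learning_rate=0.0 and decay_steps=100, the linear (power=1.0) polynomial schedule gives 0.05 at step 50 and stays at 0.0 from step 100 onward unless cycle=True.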

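import tensorflow as tf  # TensorFlow 1.x API

# Assumes batch_size, num_epochs_per_decay, sync_replicas, replicas_to_aggregate,
# learning_rate, learning_rate_decay_type, learning_rate_decay_factor and
# end_learning_rate are defined elsewhere with tf.app.flags.DEFINE_*.
FLAGS = tf.app.flags.FLAGS


def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if FLAGS.learning_rate_decay_type is not recognized.
  """
  # Training steps between decays: (steps per epoch) * (epochs per decay).
  decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                    FLAGS.num_epochs_per_decay)
  if FLAGS.sync_replicas:
    # Each global step aggregates several replica batches, so fewer
    # global steps make up one decay period. Integer division keeps it an int.
    decay_steps //= FLAGS.replicas_to_aggregate

  if FLAGS.learning_rate_decay_type == 'exponential':
    # staircase=True: lr * decay_factor ** (global_step // decay_steps).
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    # Constant rate; global_step is not used.
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    # power=1.0 gives a linear decay from learning_rate to end_learning_rate
    # over decay_steps; with cycle=False it then stays at end_learning_rate.
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     FLAGS.learning_rate_decay_type)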
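The decay_steps arithmetic at the top of the function can be checked by hand. The sketch below is a hypothetical standalone helper (compute_decay_steps is not part of TensorFlow) that takes the flag values as plain arguments instead of reading FLAGS; the numbers in the example are made up for illustration.

```python
def compute_decay_steps(num_samples_per_epoch, batch_size,
                        num_epochs_per_decay, sync_replicas=False,
                        replicas_to_aggregate=1):
    # Hypothetical helper: steps per epoch times epochs per decay, truncated to int.
    decay_steps = int(num_samples_per_epoch / batch_size * num_epochs_per_decay)
    if sync_replicas:
        # One synchronous global step consumes replicas_to_aggregate batches.
        decay_steps //= replicas_to_aggregate
    return decay_steps


print(compute_decay_steps(50000, 32, 2.0))        # 3125
print(compute_decay_steps(50000, 32, 2.0, True, 4))  # 781
```

With 50,000 samples per epoch and a batch size of 32 there are 1562.5 steps per epoch, so decaying every 2 epochs gives decay_steps = 3125; aggregating 4 replicas per global step reduces that to 781.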


Original post: https://www.cnblogs.com/Libo-Master/p/8926136.html
