标签:load rand join groups ups required ble 含义 清空
本节讲述Pytorch中torch.optim优化器包,学习率、参数Momentum动量的含义,以及常用的几类优化器。【Latex公式采用在线编码器】 优化器概念:管理并更新模型所选中的网络参数,使得模型输出更加接近真实标签。 |
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 第一项也可自定义参数,用list封装
# 后面介绍的基本方法,都是利用optimizer.方法
? 所有的optim中的优化器都继承Optimizer父类,即:
class Optimizer(object):
def __init__(self, params, defaults):
torch._C._log_api_usage_once("python.optimizer")
self.defaults = defaults # 1 保存优化器本身的参数,例如
if isinstance(params, torch.Tensor):
raise TypeError("params argument given to the optimizer should be "
"an iterable of Tensors or dicts, but got " +
torch.typename(params))
self.state = defaultdict(dict) #2
self.param_groups = [] #3
param_groups = list(params)
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
if not isinstance(param_groups[0], dict):
param_groups = [{‘params‘: param_groups}]
for param_group in param_groups:
self.add_param_group(param_group) # 调用add_param_group函数,将default优化器本身参数,送入param_groups中
? 由上式代码注释#可知,重要参数如下:
self.defaults:优化器本身参数,如学习率、动量等等
self.state:参数缓存,如动量缓存
self.param_groups:管理的参数组,注意这里是list(dict)形式,即列表中字典。
? ? 例如:<class ‘list‘>: [{‘params‘: [网络参数], ‘lr‘: 0.1, ‘momentum‘: 0, ‘dampening‘: 0, ‘weight_decay‘: 0, ‘nesterov‘: False}]
? 注意:这里模型中的参数(如W)与param_groups中保存的W,地址相同。
? 清空所管理的网络参数的梯度
class Optimizer(object):
def zero_grad(self):
r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
for group in self.param_groups: # 对于self.param_groups的list中字典key=‘params‘对应的value
for p in group[‘params‘]:
if p.grad is not None:
p.grad.detach_() # 脱离原来的计算图,被计算机捕捉到
p.grad.zero_()
? 执行一步更新,根据对应的梯度下降策略。
? 添加参数组,经常用于finetune,又例如设置两部分参数,e.g. 网络分为:特征提取层+全连接分类层,设置两组优化参数。
class Optimizer(object):
def add_param_group(self, param_group):
"""
Arguments:
param_group (dict): Specifies what Tensors should be optimized along with group
specific optimization options.
"""
params = param_group[‘params‘]
param_set = set()
for group in self.param_groups:
param_set.update(set(group[‘params‘]))
self.param_groups.append(param_group)
? 同一个优化器,添加新的优化参数:
weight = torch.randn((2, 2), requires_grad=True)
optimizer = optim.SGD([weight], lr=0.1)
print(‘添加之后未添加之前:{}‘.format(optimizer.param_groups))
‘‘‘
添加之后未添加之前:[{‘params‘: [tensor([[ 0.4523, 0.2895],
[-0.4283, 1.0688]], requires_grad=True)], ‘lr‘: 0.1, ‘momentum‘: 0, ‘dampening‘: 0, ‘weight_decay‘: 0, ‘nesterov‘: False}]
‘‘‘
w2 = torch.randn((3, 3), requires_grad=True)
optimizer.add_param_group({"params": w2, ‘lr‘: 0.0001})
print("添加之后{}".format(optimizer.param_groups))
‘‘‘
添加之后[{‘params‘: [tensor([[ 0.4523, 0.2895],
[-0.4283, 1.0688]], requires_grad=True)], ‘lr‘: 0.1, ‘momentum‘: 0, ‘dampening‘: 0, ‘weight_decay‘: 0, ‘nesterov‘: False}, {‘params‘: [tensor([[-1.0346, 1.2396, -1.4738],
[ 0.8029, -1.1723, 0.0783],
[ 0.7809, 0.4156, 0.3127]], requires_grad=True)], ‘lr‘: 0.0001, ‘momentum‘: 0, ‘dampening‘: 0, ‘weight_decay‘: 0, ‘nesterov‘: False}]
‘‘‘
? 可以看到添加之后,optimizer.param_groups list中含有两个字典,一个字典是之前的参数,另一个字典是新添加的一系列优化器参数
? 获取当前优化器的一系列信息参数。由代码可知,返回的是字典,两个key:‘state‘和‘param_groups‘
class Optimizer(object):
def state_dict(self):
...
...
return {
‘state‘: packed_state,
‘param_groups‘: param_groups,
}
? self.state:参数缓存,如动量缓存,当网络没有经过optimizer.step(),即没有根据loss.backward()得到的梯度去更新网络参数时,state为空:
print(optimizer.state_dict())
‘‘‘
{‘state‘: {}, ‘param_groups‘: [{‘lr‘: 0.1, ‘momentum‘: 0, ‘dampening‘: 0, ‘weight_decay‘: 0, ‘nesterov‘: False, ‘params‘: [140306069859856]}]}
‘‘‘
当更新之后,‘state‘将保存‘params‘中value的地址以及{‘momentun_buffer‘:tensor()}动量缓存,用于后续断点恢复。
? 加载保存的状态信息字典
‘‘‘保存优化器状态信息‘‘‘
torch.save(optimizer.state_dict(), os.path.join(address, "name.pkl"))
‘‘‘加载优化器状态信息‘‘‘
state_dict = torch.load(os.path.join(address, "name.pkl"))
optimizer.load_state_dict(state_dict)
? 学习率可以看作是对梯度的缩小因子,用来控制梯度更新的步伐:
lr不能过大(易loss激增);
lr不能过小(收敛较慢);
当设置lr适当小时,如0.01,此时可通过增加网络训练时间,进行弥补;
? 结合当前梯度与上一时刻更新的信息,来更新当前梯度信息。Momentum 梯度下降法 可追溯到指数加权平均:
其中 \(\theta _{t}\) 为当前时刻的参数,因为 \(\beta < 1\) ,从上述公式可知,距离当前t时刻越远的时刻参数,权重越小,对t时刻影响越小。
可以看到, 当 \(Momentum\) 太大时,由于受到前面时刻梯度线性影响,会有一定的震荡。
optimizer = optim.SGD(params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False)
? 下次来补充啦??!
标签:load rand join groups ups required ble 含义 清空
原文地址:https://www.cnblogs.com/zpc1001/p/13195928.html