码迷,mamicode.com
首页 > 其他好文 > 详细

动手学深度学习第一课:从上手到多类分类-Autograd

时间:2019-09-01 13:03:57      阅读:121      评论:0      收藏:0      [点我收藏+]

标签:导入   命令式   lse   war   介绍   梯度   梯度下降   控制   根据   

使用autograd来自动求导

在机器学习中,我们通常使用梯度下降来更新模型参数从而求解。损失函数关于模型参数的梯度指向一个可以降低损失函数值的方向,我们不断地沿着梯度的方向更新模型从而最小化损失函数。虽然梯度计算比较直观,但对于复杂的模型,例如多达数十层的神经网络,手动计算梯度非常困难。

为此MXNet提供autograd包来自动化求导过程。虽然大部分的深度学习框架要求编译计算图来自动求导,mxnet.autograd可以对正常的命令式程序进行求导,它每次在后端实时创建计算图从而可以立即得到梯度的计算方法。

下面让我们一步步介绍这个包。我们先导入autograd。

import mxnet.ndarray as nd
import mxnet.autograd as ag

为变量附上梯度

假设假设我们想对函数f = 2 * (x ** 2)求关于x的导数。我们先创建变量x,并赋初值。

x = nd.array([[1, 2], [3, 4]])
x
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>

当进行求导的时候,我们需要一个地方来存x的导数,这个可以通过NDArray的方法attach_grad()来要求系统申请对应的空间。

x.attach_grad()

下面定义f。默认条件下,MXNet不会自动记录和构建用于求导的计算图,我们需要使用autograd里的record()函数来显式的要求MXNet记录我们需要求导的程序。

with ag.record():
    y = x * 2
    z = y * x
z
[[  2.   8.]
 [ 18.  32.]]
<NDArray 2x2 @cpu(0)>

接下来我们可以通过z.backward()来进行求导。如果z不是一个标量,那么z.backward()等价于nd.sum(z).backward()。

z.backward()
x.grad
[[  4.   8.]
 [ 12.  16.]]
<NDArray 2x2 @cpu(0)>

现在我们来看求出来的导数是不是正确的。注意到y = x * 2和z = x * y,所以z等价于2 * x * x。那么它的导数就是dz/dx = 4 * x。

x.grad == 4 * x
[[ 1.  1.]
 [ 1.  1.]]
<NDArray 2x2 @cpu(0)>
x.grad == 3 * x
[[ 0.  0.]
 [ 0.  0.]]
<NDArray 2x2 @cpu(0)>

对控制流求导

命令式的编程的一个便利之处是几乎可以对任意的可导程序进行求导,即使里面包含了Python的控制流。考虑下面程序,里面包含控制流for和if,但循环迭代的次数和判断语句的执行都是取决于输入的值。不同的输入会导致这个程序的执行不一样。(对于计算图框架来说,这个对应于动态图,就是图的结构会根据输入数据不同而改变)。

def f(a):
    b = a * 2
    while nd.norm(b).asscalar() < 1000:
        b = b * 2
    if nd.sum(b).asscalar() > 0:
        c = b
    else:
        c = 100 * b
    return c

我们可以跟之前一样使用record记录和backward求导。

a = nd.random_normal(shape=3)
a.attach_grad()
with ag.record():
    c = f(a)
c.backward()
a.grad
[ 51200.  51200.  51200.]
<NDArray 3 @cpu(0)>

注意到给定输入a,其输出f(a) = xa,x的值取决于输入a。所以有df/da = x,我们可以很简单地评估自动求导的导数:

c/a
[ 51200.  51200.  51200.]
<NDArray 3 @cpu(0)>
a.grad == c/a
[ 1.  1.  1.]
<NDArray 3 @cpu(0)>

动手学深度学习第一课:从上手到多类分类-Autograd

标签:导入   命令式   lse   war   介绍   梯度   梯度下降   控制   根据   

原文地址:https://www.cnblogs.com/KisInfinite/p/11441836.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!