码迷,mamicode.com
首页 > 其他好文 > 详细

NDArray自动求导

时间:2017-11-04 16:27:02      阅读:134      评论:0      收藏:0      [点我收藏+]

标签:es2017   结果   port   logs   cpu   log   导数   技术   http   

NDArray可以很方便的求解导数,比如下面的例子:(代码主要参考自https://zh.gluon.ai/chapter_crashcourse/autograd.html

技术分享

 用代码实现如下:

 1 import mxnet.ndarray as nd
 2 import mxnet.autograd as ag
 3 x = nd.array([[1,2],[3,4]])
 4 print(x)
 5 x.attach_grad() #附加导数存放的空间
 6 with ag.record():
 7     y = 2*x**2
 8 y.backward() #求导
 9 z = x.grad #将导数结果(也是一个矩阵)赋值给z
10 print(z) #打印结果
[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>

[[  4.   8.]
 [ 12.  16.]]
<NDArray 2x2 @cpu(0)>

 

对控制流求导

NDArray还能对诸如if的控制分支进行求导,比如下面这段代码:

1 def f(a):
2     if nd.sum(a).asscalar()<15: #如果矩阵a的元数和<15
3         b = a*2 #则所有元素*2
4     else:
5         b = a 
6     return b

数学公式等价于:

技术分享

这样就转换成本文最开头示例一样,变成单一函数求导,显然导数值就是x前的常数项,验证一下:

import mxnet.ndarray as nd
import mxnet.autograd as ag

def f(a):
    if nd.sum(a).asscalar()<15: #如果矩阵a的元数和<15
        b = a*2 #则所有元素平方
    else:
        b = a 
    return b

#注:1+2+3+4<15,所以进入b=a*2的分支
x = nd.array([[1,2],[3,4]])
print("x1=")
print(x)
x.attach_grad()
with ag.record():
    y = f(x)
print("y1=")
print(y)
y.backward() #dy/dx = y/x 即:2
print("x1.grad=")
print(x.grad)


x = x*2
print("x2=")
print(x)
x.attach_grad()
with ag.record():
    y = f(x)
print("y2=")
print(y)
y.backward()
print("x2.grad=")
print(x.grad)
x1=

[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>
y1=

[[ 2.  4.]
 [ 6.  8.]]
<NDArray 2x2 @cpu(0)>
x1.grad=

[[ 2.  2.]
 [ 2.  2.]]
<NDArray 2x2 @cpu(0)>
x2=

[[ 2.  4.]
 [ 6.  8.]]
<NDArray 2x2 @cpu(0)>
y2=

[[ 2.  4.]
 [ 6.  8.]]
<NDArray 2x2 @cpu(0)>
x2.grad=

[[ 1.  1.]
 [ 1.  1.]]
<NDArray 2x2 @cpu(0)>

 

NDArray自动求导

标签:es2017   结果   port   logs   cpu   log   导数   技术   http   

原文地址:http://www.cnblogs.com/yjmyzz/p/7783286.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!