码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow-底层梯度(2)

时间:2018-12-14 21:16:09      阅读:183      评论:0      收藏:0      [点我收藏+]

标签:stop   影响   import   ted   env   方法   usr   4.0   pre   

#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ Created on Mon Aug 27 11:16:32 2018 @author: myhaspl """ import tensorflow as tf x = tf.constant(1.) a = 12*x b = 2*a c = 3*b g1 = tf.gradients([a + b + c], [a, b, c]) g2 = tf.gradients([a + b + c], [a, b, c],stop_gradients=[a, b, c]) sess=tf.Session() with sess: print sess.run(g1) print sess.run(g2)

[9.0, 4.0, 1.0]
[1.0, 1.0, 1.0]
与全导数g1 = tf.gradients([a + b + c], [a, b, c])相比,偏导数g2 = tf.gradients([a + b + c], [a, b, c],stop_gradients=[a, b, c])
的值是[1.0,1.0 , 1.0],而全导数g1 = tf.gradients([a + b + c], [a, b, c])考虑了a对b和c的影响,并求值为[9.0, 4.0, 1.0],例如:

(a+b+c)‘a=9,其中:

b=2*a

c=3b=32*a=6a

相比于在图构造期间使用的tf.stop_gradient。stop_gradients提供了已经构造完图之后停止梯度的方法,当将这两种方法结合起来时,反向传播在tf.stop_gradient节点和stop_gradients节点处都停止,无论首先遇到哪个节点。

tensorflow-底层梯度(2)

标签:stop   影响   import   ted   env   方法   usr   4.0   pre   

原文地址:http://blog.51cto.com/13959448/2330673

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!