码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow2.0 squeeze出错

时间:2019-09-25 01:00:41      阅读:64      评论:0      收藏:0      [点我收藏+]

标签:图片   代码   flow   info   none   img   elf   删掉   put   

用tf.keras写了自定义层,但在调用自定义层的时候总是报错,找了好久才发现问题所在,所以记下此问题。

问题代码

u=tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel,axis=3)

其中inputs的第一维为None,这里的代码为自定义的前向传播。我是想将得到的输出张量维度为1的维度删掉,因此调用了tf.squeeze方法,这时虽然没有报错但出现了问题。我分别打印了下面内容。

print(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3).shape)
print(self.kernel.shape)
print((tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel).shape)
print(tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1))@self.kernel,axis=3))

技术图片

可以发现,当张量第一维为None的时候tf.squeeze使结果变为了0。我想要的结果是删除第三个输出的大小为1的维度,即得到下面的结果

技术图片

解决使用tf.squeeze的时候加上删除的维度。

tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel,axis=3)

tensorflow2.0 squeeze出错

标签:图片   代码   flow   info   none   img   elf   删掉   put   

原文地址:https://www.cnblogs.com/lolybj/p/11581917.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!