码迷,mamicode.com
首页 > 其他好文 > 详细

Tensor的合并与分割

时间:2020-01-22 22:14:08      阅读:88      评论:0      收藏:0      [点我收藏+]

标签:shape   concat   行合并   tac   nes   ack   接口   spl   堆叠   

先来看一下有哪些接口用来进行张量的合并与分割:

tf.concat用来进行张量的拼接,tf.stack用来进行张量的堆叠,tf.split用来进行张量的分割,tf.unstack是tf.split的一种,也用来进行张量分割

1.tf.concat

参数axis代表将要合并的维度

# 假设a代表四个班的成绩(每班35人,8个科目),b代表2个班的成绩
a = tf.ones([4,35,8])
b = tf.ones([2,35,8])
# 使用concat进行合并得到6个班的成绩
c = tf.concat([a,b],axis=0)
# (6,35,8)
print(c.shape)

2.tf.stack(用于创建一个新的维度)

# 假设a代表A学校的四个班的成绩(每班35人,8个科目),b代表B学校四个班的成绩
a = tf.ones([4,35,8])
b = tf.ones([4,35,8])
# 使用stack进行合并得到6个班的成绩
c = tf.stack([a,b],axis=0)
# (2,4,35,8)
print(c.shape)

3.tf.unstack(对某维度进行等分)

# 假设a代表A学校的四个班的成绩(每班35人,8个科目),b代表B学校四个班的成绩
a = tf.ones([4,35,8])
b = tf.ones([4,35,8])
# 使用stack进行合并得到6个班的成绩
c = tf.stack([a,b],axis=0)
# (2,4,35,8)
print(c.shape)
aa,bb=tf.unstack(c,axis=0)
# (4,35,8)
print(aa.shape,bb.shape)
res=tf.unstack(c,axis=3)
# (2,4,35)
print(res[0].shape,res[7].shape)

4.tf.split(按比例打散)

# 假设a代表A学校的四个班的成绩(每班35人,8个科目),b代表B学校四个班的成绩
a = tf.ones([4,35,8])
b = tf.ones([4,35,8])
# 使用stack进行合并得到6个班的成绩
c = tf.stack([a,b],axis=0)
# (2,4,35,8)
print(c.shape)
res = tf.split(c,axis=3,num_or_size_splits=2)
# 2,(2,4,35,4)
print(len(res),res[0].shape,res[1].shape)
res = tf.split(c,axis=3,num_or_size_splits=[2,2,4])
# 3 (2,4,35,2) (2,4,35,2) (2,4,35,4)
print(len(res),res[0].shape,res[1].shape,res[2].shape)

Tensor的合并与分割

标签:shape   concat   行合并   tac   nes   ack   接口   spl   堆叠   

原文地址:https://www.cnblogs.com/zdm-code/p/12229527.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!