码迷,mamicode.com
首页 > 其他好文 > 详细

Tensorflow - tf.split使用

时间:2019-07-31 19:04:34      阅读:401      评论:0      收藏:0      [点我收藏+]

标签:image   var   one   net   imp   数值   size   none   shape   

XDeepFM的CIN中第一层实现需要使两个二维矩阵相乘得到一个三维张量,于是来复习下split函数(需要用到):
首先看下函数原理:

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name=split
)

这个函数是用来切割张量的:输入切割的张量和参数,返回切割的结果。
value传入的就是需要切割的张量,axis的数值代表切割哪个维度。
这个函数有两种切割的方式:

以三个维度的张量为例,比如说一个20 * 30 * 40的张量my_tensor,就如同一个长20厘米宽30厘米高40厘米的蛋糕,每立方厘米都是一个分量。

有两种切割方式:
1. 如果num_or_size_splits传入的是一个整数,这个整数代表这个张量最后会被切成几个小张量。此时,传入axis的数值就代表切割哪个维度(从0开始计数)。调用tf.split(my_tensor, 2,0)返回两个10 * 30 * 40的小张量。
2. 如果num_or_size_splits传入的是一个向量,那么向量有几个分量就分成几份,切割的维度还是由axis决定。比如调用tf.split(my_tensor, [10, 5, 25], 2),则返回三个张量分别大小为 20 * 30 * 10、20 * 30 * 5、20 * 30 * 25。很显然,传入的这个向量各个分量加和必须等于axis所指示原张量维度的大小 (10 + 5 + 25 = 40)。

一个实例:

import tensorflow as tf
import numpy as np

arr1 = tf.convert_to_tensor(np.arange(1,25).reshape(2,4,3),dtype=tf.int32)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    split_arr1 = tf.split(arr1,[1,1,1],2) # 切割成2个2*4*1的张量
   print(sess.run(split_arr1)

可以看到原来的2*4*3的张量被切割为了3个2*4*1的张量

技术图片

 

Reference:

https://blog.csdn.net/SangrealLilith/article/details/80272346

Tensorflow - tf.split使用

标签:image   var   one   net   imp   数值   size   none   shape   

原文地址:https://www.cnblogs.com/Jesee/p/11277868.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!