码迷,mamicode.com
首页 > 其他好文 > 详细

pytorch中的scatter_()函数

时间:2020-03-21 13:15:53      阅读:338      评论:0      收藏:0      [点我收藏+]

标签:proc   代码   上进   ros   就是   att   ESS   下标   code   

最近在学习pytorch函数时需要做独热码,然后遇到了scatter_()函数,不太明白意思,现在懂了记录一下以免以后忘记。

这个函数是用一个src的源张量或者标量以及索引来修改另一个张量。这个函数主要有三个参数scatter_(dim,index,src)

dim:沿着哪个维度来进行索引(一会儿举个例子就明白了)

index:用来进行索引的张量

src:源张量或者标量

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

这个是官网给出的例子,但是一般在做独热码的时候通常是采用二维张量所以应该是这样

#dim=0
self[index[x][y]][y]=src[x][y]  

#dim=1
self[x][index[x][y]]=src[x][y]

这个是什么意思呢。首先请看下面的程序,程序是我瞎编的,想试试的话可以自己编数据哈

import torch
x=torch.rand(3,5)
print(x)
print(-------------------)
y=torch.zeros(3,5)
print(y)
print(-------------------)
inx=torch.tensor([[0,4,3,1,2],[3,2,1,4,3]])
output_y=y.scatter_(dim=1,index=inx,src=x)
print(output_y)

下面是运行的结果

tensor([[0.1380, 0.6030, 0.2396, 0.0066, 0.7116],
        [0.5755, 0.2856, 0.4862, 0.2132, 0.2475],
        [0.5145, 0.4753, 0.2736, 0.2623, 0.8532]])
-------------------
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
-------------------
tensor([[0.1380, 0.0066, 0.7116, 0.2396, 0.6030],
        [0.0000, 0.4862, 0.2856, 0.2475, 0.2132],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

Process finished with exit code 0

那么是什么意思呢,举个例子,这里我要强调一下,index即这个程序中的inx里面的每个数值,不能超过该dim的张量的最大下标,不然的话就会越界,找不到src中的源数据。因为是在dim=1上进行索引,所以采用第二个式子。

我们在索引表中找到index[1][3]=4,那么此时x=1,y=3,即output_y[1][index[1][3]]=src[1][3],即output_y[1][4]=src[1][3]。即x[1][3]。以此类推就可以得到其他的值。

src不仅仅可以是张量,也可以是标量,下面这段代码是模仿怎么生成独热码

import torch
x=torch.zeros(4,8)
label=torch.tensor([[1],[5],[7],[6]])
one_hot=x.scatter_(1,label,1)
print(one_hot)

其中x的第一个参数代表的是batch_size,第二个参数代表的是classnum,而label有batch_size行只有一列,就是将x每一行的label值指向的位置置成1,这就是独热码。当然其他位置都是0啦,下面看一下结果吧。

tensor([[0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.]])

Process finished with exit code 0

好啦,这就是scatter_()函数的用法。

ps:本来坚持不下去了快,但是把scatter弄清楚了发现还有一点动力学下去,加油吧。

 

pytorch中的scatter_()函数

标签:proc   代码   上进   ros   就是   att   ESS   下标   code   

原文地址:https://www.cnblogs.com/daremosiranaihana/p/12538512.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!