码迷,mamicode.com
首页 > 其他好文 > 详细

Pytorch的scatter()函数用法

时间:2020-12-23 11:41:16      阅读:0      评论:0      收藏:0      [点我收藏+]

标签:开始   参考   article   code   函数   rgb   div   修改   width   

scatter(dim, index, src)的三个参数为:

(1)dim:沿着哪个维度进行索引

(2)index: 用来scatter的元素索引

(3)src: 用来scatter的源元素,可以使一个标量也可以是一个张量

注:带_表示在原张量上修改。

二维例子如下:

1 y = y.scatter(dim,index,src)
2  
3 y [ index[i][j] ] [j] = src[i][j] #if dim==0
4 y[i] [ index[i][j] ]  = src[i][j] #if dim==1

实例如下:

 1 x = torch.rand(2, 5)
 2 
 3 #tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
 4 #        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
 5 
 6 y = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
 7 
 8 #tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
 9 #        [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
10 #        [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

说明:

需要根据index(即 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])) 来查找src的元素(即x ),从而得到结果y。

一开始进行 self[index[0][0]][0],其中 index[0][0] 的值是0,所以执行 self[0][0]=x[0][0]=0.1940 ,self[index[i][j]][j]=src[i][j]
再比如self[index[1][0]][0],其中 index[1][0] 的值是2,所以执行 self[2][0]=x[1][0]=0.2078 

 

如何确定最终需要修改y中的哪些元素呢?

个人认为根据index中的值及其索引。因为index有10个元素,所以最终y中有10个元素会被修改,具体如下:

技术图片

 

参考:https://www.cnblogs.com/dogecheng/p/11938009.html

           https://blog.csdn.net/t20134297/article/details/105755817

Pytorch的scatter()函数用法

标签:开始   参考   article   code   函数   rgb   div   修改   width   

原文地址:https://www.cnblogs.com/vvzhang/p/14152210.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!