torch.gather的用法

torch.gather()函数用于收集数据。有两种用法,假如有一个tensor p, 则有:

torch.gather(p, dim = 1, index = p_i)

或者

p.gather(dim=1, index=p_i)

这里的dim是指在哪个维度上搜集值,1表示在行上搜集值,0表示在列上搜集值。

这里举例:

p = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 根据行搜集数据组成新的矩阵
p_gather_index_dim1 = torch.tensor([[1, 2, 1, 2, 1], [2, 1, 2, 1, 2]])  # 形状为2*n
p1 = p.gather(dim=1, index=p_gather_index_dim1)
print(p1)

得到结果:

p为:

 p1为:

通过以上的例子说明几点:

1、传入.gather()方法的数据必须也是tensor,并且维度(ndim)与p相同;

2、p1可以无限制的从p中拿数据,但是有规则,当dim=1时,为行上取元素,因此p_gather_index_dim1的形状为(2,n),即必须也是2行,n表示第二个维度可以无穷多数。

3、这里的index = torch.tensor([1,2,1,2,1], [2,1,2,1,2])表示从p的第一行选取索引为1, 2, 1, 2, 1的元素(即2, 3, 2, 3, 2)组成p1的第一行,而从p的第二行选取索引为2,1,2,1,2的元素(即6, 5, 6, 5, 6)作为p2的第二行。

同理,当从dim = 0即从列取元素时,传入gather的index形状为(n, 3),即可以组成无穷多行,代码如下:

# 根据列搜集数据组成新的矩阵
p_gather_index_dim2 = torch.tensor([[0, 1, 0], [1, 0, 1], [1, 1, 0], [0, 0, 1]])  # 形状为n*3
p2 = p.gather(dim=0, index=p_gather_index_dim2)
print(p2)

 结果为:

p2:

全部代码如下:

import torch

p = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 形状为2*3

# 根据行搜集数据组成新的矩阵
p_gather_index_dim1 = torch.tensor([[1, 2, 1, 2, 1], [2, 1, 2, 1, 2]])  # 形状为2*n
p1 = p.gather(dim=1, index=p_gather_index_dim1)
print(p)
print(p1)

# 根据列搜集数据组成新的矩阵
p_gather_index_dim2 = torch.tensor([[0, 1, 0], [1, 0, 1], [1, 1, 0], [0, 0, 1]])  # 形状为n*3
p2 = p.gather(dim=0, index=p_gather_index_dim2)
print(p2)

最后需要说明的是,对于本例(p的形状为2*3),当dim = 1时,p_gather_index_dim1的形状可以是1*n即只取第一行的数据,但是不能是3*n,因为p只有2行数据;而当dim = 0时,p_gather_index_dim2的形状还可以是n*1,即只按第一列取值,n*2,即只按前两列取值。但不能是n*4,因为p只有3列数据。


版权声明:本文为mengdeer_Q原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。