torch.gather()函数用于收集数据。有两种用法,假如有一个tensor p, 则有:
torch.gather(p, dim = 1, index = p_i)
或者
p.gather(dim=1, index=p_i)
这里的dim是指在哪个维度上搜集值,1表示在行上搜集值,0表示在列上搜集值。
这里举例:
p = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 根据行搜集数据组成新的矩阵
p_gather_index_dim1 = torch.tensor([[1, 2, 1, 2, 1], [2, 1, 2, 1, 2]]) # 形状为2*n
p1 = p.gather(dim=1, index=p_gather_index_dim1)
print(p1)得到结果:
p为:

p1为:

通过以上的例子说明几点:
1、传入.gather()方法的数据必须也是tensor,并且维度(ndim)与p相同;
2、p1可以无限制的从p中拿数据,但是有规则,当dim=1时,为行上取元素,因此p_gather_index_dim1的形状为(2,n),即必须也是2行,n表示第二个维度可以无穷多数。
3、这里的index = torch.tensor([1,2,1,2,1], [2,1,2,1,2])表示从p的第一行选取索引为1, 2, 1, 2, 1的元素(即2, 3, 2, 3, 2)组成p1的第一行,而从p的第二行选取索引为2,1,2,1,2的元素(即6, 5, 6, 5, 6)作为p2的第二行。
同理,当从dim = 0即从列取元素时,传入gather的index形状为(n, 3),即可以组成无穷多行,代码如下:
# 根据列搜集数据组成新的矩阵
p_gather_index_dim2 = torch.tensor([[0, 1, 0], [1, 0, 1], [1, 1, 0], [0, 0, 1]]) # 形状为n*3
p2 = p.gather(dim=0, index=p_gather_index_dim2)
print(p2)结果为:
p2:

全部代码如下:
import torch
p = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状为2*3
# 根据行搜集数据组成新的矩阵
p_gather_index_dim1 = torch.tensor([[1, 2, 1, 2, 1], [2, 1, 2, 1, 2]]) # 形状为2*n
p1 = p.gather(dim=1, index=p_gather_index_dim1)
print(p)
print(p1)
# 根据列搜集数据组成新的矩阵
p_gather_index_dim2 = torch.tensor([[0, 1, 0], [1, 0, 1], [1, 1, 0], [0, 0, 1]]) # 形状为n*3
p2 = p.gather(dim=0, index=p_gather_index_dim2)
print(p2)
最后需要说明的是,对于本例(p的形状为2*3),当dim = 1时,p_gather_index_dim1的形状可以是1*n即只取第一行的数据,但是不能是3*n,因为p只有2行数据;而当dim = 0时,p_gather_index_dim2的形状还可以是n*1,即只按第一列取值,n*2,即只按前两列取值。但不能是n*4,因为p只有3列数据。
版权声明:本文为mengdeer_Q原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。