张量的拼接
通过 torch.cat() 函数,可以把若干个张量按照指定的维度拼接起来。
- dim=0 表示沿着行索引(row index)的方向拼接,即竖直拼接
- dim=1 表示沿着列索引(column index)的方向拼接,即水平拼接
x = torch.rand(5, 3)
y = torch.rand(2, 3)
z = torch.rand(5, 4)
out_1 = torch.cat([x, y], dim=0) # 沿着行索引的方向拼接,两个张量的列数需一致
out_2 = torch.cat((x, z), dim=1) # 沿着列索引的方向拼接,两个张量的行数需一致
print(x)
print(y)
print(z)
print(out_1)
print(out_2)
---------
tensor([[0.1910, 0.3205, 0.8193],
[0.7118, 0.2205, 0.3091],
[0.9017, 0.2311, 0.7343],
[0.4640, 0.1793, 0.8087],
[0.7302, 0.6814, 0.3770]])
tensor([[0.7730, 0.2388, 0.4679],
[0.1334, 0.9005, 0.8307]])
tensor([[0.2889, 0.0683, 0.1202, 0.1321],
[0.7592, 0.1536, 0.6919, 0.1438],
[0.5249, 0.0094, 0.9866, 0.4137],
[0.0555, 0.1812, 0.7475, 0.6379],
[0.0140, 0.7772, 0.5915, 0.1765]])
tensor([[0.1910, 0.3205, 0.8193],
[0.7118, 0.2205, 0.3091],
[0.9017, 0.2311, 0.7343],
[0.4640, 0.1793, 0.8087],
[0.7302, 0.6814, 0.3770],
[0.7730, 0.2388, 0.4679],
[0.1334, 0.9005, 0.8307]])
tensor([[0.1910, 0.3205, 0.8193, 0.2889, 0.0683, 0.1202, 0.1321],
[0.7118, 0.2205, 0.3091, 0.7592, 0.1536, 0.6919, 0.1438],
[0.9017, 0.2311, 0.7343, 0.5249, 0.0094, 0.9866, 0.4137],
[0.4640, 0.1793, 0.8087, 0.0555, 0.1812, 0.7475, 0.6379],
[0.7302, 0.6814, 0.3770, 0.0140, 0.7772, 0.5915, 0.1765]])
# torch.concat() 与 torch.cat() 是一样的
torch.stack() 函数也可以进行张量的拼接,该函数沿着一个新的维度对输入的张量序列进行拼接,所有张量的 size 必须相同。
换句话说,就是把多个具有相同 size 的 2 维张量拼接成一个 3 维张量,把多个具有相同 size 的 3 维张量拼接成一个 4 维张量,以此类推,沿着新的维度进行堆叠。
参数:
- inputs:待拼接的张量序列
注:序列一般为 list 和 tuple - dim:新的维度,必须在 [0, tensor.dim()] 内
注:通过 tensor.dim() 可以知道张量的维度
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
y = torch.tensor([[10, 20, 30, 40],
[50, 60, 70, 80],
[90, 100, 110, 120]])
out_1 = torch.stack([x, y], dim=0) # 在第 0 个维度上增加新维度 --> [2, 3, 4]
out_2 = torch.stack([x, y], dim=1) # 在第 1 个维度上增加新维度 --> [3, 2, 4]
out_3 = torch.stack([x, y], dim=2) # 在第 2 个维度上增加新维度 --> [3, 4, 2]
print(out_1, out_1.shape)
print(out_2, out_2.shape)
print(out_3, out_3.shape)
---------
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[ 10, 20, 30, 40],
[ 50, 60, 70, 80],
[ 90, 100, 110, 120]]]) torch.Size([2, 3, 4])
tensor([[[ 1, 2, 3, 4],
[ 10, 20, 30, 40]],
[[ 5, 6, 7, 8],
[ 50, 60, 70, 80]],
[[ 9, 10, 11, 12],
[ 90, 100, 110, 120]]]) torch.Size([3, 2, 4])
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30],
[ 4, 40]],
[[ 5, 50],
[ 6, 60],
[ 7, 70],
[ 8, 80]],
[[ 9, 90],
[ 10, 100],
[ 11, 110],
[ 12, 120]]]) torch.Size([3, 4, 2])
版权声明:本文为weixin_48158964原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。