张量的拼接

张量的拼接

通过 torch.cat() 函数,可以把若干个张量按照指定的维度拼接起来。

  • dim=0 表示沿着行索引(row index)的方向拼接,即竖直拼接
  • dim=1 表示沿着列索引(column index)的方向拼接,即水平拼接
x = torch.rand(5, 3)
y = torch.rand(2, 3)
z = torch.rand(5, 4)
out_1 = torch.cat([x, y], dim=0)  # 沿着行索引的方向拼接,两个张量的列数需一致
out_2 = torch.cat((x, z), dim=1)  # 沿着列索引的方向拼接,两个张量的行数需一致
print(x)
print(y)
print(z)
print(out_1)
print(out_2)
---------
tensor([[0.1910, 0.3205, 0.8193],
        [0.7118, 0.2205, 0.3091],
        [0.9017, 0.2311, 0.7343],
        [0.4640, 0.1793, 0.8087],
        [0.7302, 0.6814, 0.3770]])
tensor([[0.7730, 0.2388, 0.4679],
        [0.1334, 0.9005, 0.8307]])
tensor([[0.2889, 0.0683, 0.1202, 0.1321],
        [0.7592, 0.1536, 0.6919, 0.1438],
        [0.5249, 0.0094, 0.9866, 0.4137],
        [0.0555, 0.1812, 0.7475, 0.6379],
        [0.0140, 0.7772, 0.5915, 0.1765]])
tensor([[0.1910, 0.3205, 0.8193],
        [0.7118, 0.2205, 0.3091],
        [0.9017, 0.2311, 0.7343],
        [0.4640, 0.1793, 0.8087],
        [0.7302, 0.6814, 0.3770],
        [0.7730, 0.2388, 0.4679],
        [0.1334, 0.9005, 0.8307]])
tensor([[0.1910, 0.3205, 0.8193, 0.2889, 0.0683, 0.1202, 0.1321],
        [0.7118, 0.2205, 0.3091, 0.7592, 0.1536, 0.6919, 0.1438],
        [0.9017, 0.2311, 0.7343, 0.5249, 0.0094, 0.9866, 0.4137],
        [0.4640, 0.1793, 0.8087, 0.0555, 0.1812, 0.7475, 0.6379],
        [0.7302, 0.6814, 0.3770, 0.0140, 0.7772, 0.5915, 0.1765]])

# torch.concat() 与 torch.cat() 是一样的

torch.stack() 函数也可以进行张量的拼接,该函数沿着一个新的维度对输入的张量序列进行拼接,所有张量的 size 必须相同。
换句话说,就是把多个具有相同 size 的 2 维张量拼接成一个 3 维张量,把多个具有相同 size 的 3 维张量拼接成一个 4 维张量,以此类推,沿着新的维度进行堆叠。

参数:

  • inputs:待拼接的张量序列
    注:序列一般为 list 和 tuple
  • dim:新的维度,必须在 [0, tensor.dim()] 内
    注:通过 tensor.dim() 可以知道张量的维度
x = torch.tensor([[1, 2, 3, 4],
                 [5, 6, 7, 8],
                 [9, 10, 11, 12]])
y = torch.tensor([[10, 20, 30, 40],
                 [50, 60, 70, 80],
                 [90, 100, 110, 120]])
out_1 = torch.stack([x, y], dim=0)  # 在第 0 个维度上增加新维度 --> [2, 3, 4]
out_2 = torch.stack([x, y], dim=1)  # 在第 1 个维度上增加新维度 --> [3, 2, 4]
out_3 = torch.stack([x, y], dim=2)  # 在第 2 个维度上增加新维度 --> [3, 4, 2]
print(out_1, out_1.shape)
print(out_2, out_2.shape)
print(out_3, out_3.shape)
---------
tensor([[[  1,   2,   3,   4],
         [  5,   6,   7,   8],
         [  9,  10,  11,  12]],

        [[ 10,  20,  30,  40],
         [ 50,  60,  70,  80],
         [ 90, 100, 110, 120]]]) torch.Size([2, 3, 4])
tensor([[[  1,   2,   3,   4],
         [ 10,  20,  30,  40]],

        [[  5,   6,   7,   8],
         [ 50,  60,  70,  80]],

        [[  9,  10,  11,  12],
         [ 90, 100, 110, 120]]]) torch.Size([3, 2, 4])
tensor([[[  1,  10],
         [  2,  20],
         [  3,  30],
         [  4,  40]],

        [[  5,  50],
         [  6,  60],
         [  7,  70],
         [  8,  80]],

        [[  9,  90],
         [ 10, 100],
         [ 11, 110],
         [ 12, 120]]]) torch.Size([3, 4, 2])

版权声明:本文为weixin_48158964原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。