torch.cat()张量拼接

torch.cat((a,b),dim),两个向量在dim维度上拼接,要求被拼接的矩阵在另一个维度相等

import torch
A=torch.ones(2,3) #2x3的张量(矩阵)
B=2*torch.ones(4,3)#4x3的张量(矩阵)
C=torch.cat((A,B),0)#按维数0(行)拼接,要求两个矩阵另一个维度(列)相等
print(C)#
# tensor([[1., 1., 1.],
#         [1., 1., 1.],
#         [2., 2., 2.],
#         [2., 2., 2.],
#         [2., 2., 2.],
#         [2., 2., 2.]])
print(C.size())#torch.Size([6, 3])
A=torch.ones(4,1) #2x3的张量(矩阵)
B=2*torch.ones(4,3)#4x3的张量(矩阵)
d=torch.cat((A,B),1)#按维数1(列)拼接,要求连个矩阵另一个维度相等,即这里要求两个矩阵行相等
print('d',d)
#tensor([[1., 2., 2., 2.],
        # [1., 2., 2., 2.],
        # [1., 2., 2., 2.],
        # [1., 2., 2., 2.]])
print(d.size())#([4, 4])

 


版权声明:本文为weixin_38145317原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。