Operations in PyTorch that change a tensor's shape include torch.reshape, Tensor.view, Tensor.repeat, Tensor.expand, torch.permute, and torch.transpose.
torch.reshape and Tensor.view
import torch
# For a contiguous tensor, view() and reshape() behave the same: both return a tensor that shares the underlying storage with the input, so changing one also changes the other
t = torch.rand(4,4)
print(t)
b = t.view(2,8) # reinterprets the original tensor's elements in row-major order
print(b)
print(t)
c = t.reshape(2,8)
print(c==b)
print(t.storage().data_ptr()==b.storage().data_ptr())
print(t.storage().data_ptr()==c.storage().data_ptr())
print(t.is_contiguous(),b.is_contiguous(),c.is_contiguous())
tensor([[0.9270, 0.6583, 0.7644, 0.2253],
[0.2529, 0.6430, 0.2869, 0.6918],
[0.9237, 0.3050, 0.9217, 0.7309],
[0.5239, 0.6444, 0.9402, 0.9589]])
tensor([[0.9270, 0.6583, 0.7644, 0.2253, 0.2529, 0.6430, 0.2869, 0.6918],
[0.9237, 0.3050, 0.9217, 0.7309, 0.5239, 0.6444, 0.9402, 0.9589]])
tensor([[0.9270, 0.6583, 0.7644, 0.2253],
[0.2529, 0.6430, 0.2869, 0.6918],
[0.9237, 0.3050, 0.9217, 0.7309],
[0.5239, 0.6444, 0.9402, 0.9589]])
tensor([[True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True]])
True
True
True True True
t[3,3]=1.0
print('t:',t)
print('b:',b)
print('c:',c)
b[0,4]=2.0
print('t:',t)
print('b:',b)
print('c:',c)
c[1,0]=3.0
print('t:',t)
print('b:',b)
print('c:',c)
t: tensor([[0.9270, 0.6583, 0.7644, 0.2253],
[0.2529, 0.6430, 0.2869, 0.6918],
[0.9237, 0.3050, 0.9217, 0.7309],
[0.5239, 0.6444, 0.9402, 1.0000]])
b: tensor([[0.9270, 0.6583, 0.7644, 0.2253, 0.2529, 0.6430, 0.2869, 0.6918],
[0.9237, 0.3050, 0.9217, 0.7309, 0.5239, 0.6444, 0.9402, 1.0000]])
c: tensor([[0.9270, 0.6583, 0.7644, 0.2253, 0.2529, 0.6430, 0.2869, 0.6918],
[0.9237, 0.3050, 0.9217, 0.7309, 0.5239, 0.6444, 0.9402, 1.0000]])
t: tensor([[0.9270, 0.6583, 0.7644, 0.2253],
[2.0000, 0.6430, 0.2869, 0.6918],
[0.9237, 0.3050, 0.9217, 0.7309],
[0.5239, 0.6444, 0.9402, 1.0000]])
b: tensor([[0.9270, 0.6583, 0.7644, 0.2253, 2.0000, 0.6430, 0.2869, 0.6918],
[0.9237, 0.3050, 0.9217, 0.7309, 0.5239, 0.6444, 0.9402, 1.0000]])
c: tensor([[0.9270, 0.6583, 0.7644, 0.2253, 2.0000, 0.6430, 0.2869, 0.6918],
[0.9237, 0.3050, 0.9217, 0.7309, 0.5239, 0.6444, 0.9402, 1.0000]])
t: tensor([[0.9270, 0.6583, 0.7644, 0.2253],
[2.0000, 0.6430, 0.2869, 0.6918],
[3.0000, 0.3050, 0.9217, 0.7309],
[0.5239, 0.6444, 0.9402, 1.0000]])
b: tensor([[0.9270, 0.6583, 0.7644, 0.2253, 2.0000, 0.6430, 0.2869, 0.6918],
[3.0000, 0.3050, 0.9217, 0.7309, 0.5239, 0.6444, 0.9402, 1.0000]])
c: tensor([[0.9270, 0.6583, 0.7644, 0.2253, 2.0000, 0.6430, 0.2869, 0.6918],
[3.0000, 0.3050, 0.9217, 0.7309, 0.5239, 0.6444, 0.9402, 1.0000]])
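As a small supplement to the demo above: both view() and reshape() accept -1 for at most one dimension and infer that size from the total element count. A minimal sketch (the example tensor here is illustrative, not from the runs above):
import torch
t = torch.arange(12).reshape(3,4)  # 12 elements, laid out row-major
print(t.view(-1).shape)            # torch.Size([12]): flattened, size inferred
print(t.reshape(2,-1).shape)       # torch.Size([2, 6]): 12/2 = 6 inferred
#t.view(5,-1)                      # would raise RuntimeError: 12 is not divisible by 5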
Tensor.expand() and Tensor.repeat()
import torch
# expand() and repeat() both enlarge a tensor along one or more dimensions
# expand() allocates no new memory; the result is just a view of the original tensor
# repeat() tiles the tensor along the given dimensions, copying the original data into a new tensor
x = torch.tensor([[1],[2],[3]])
print(x.shape,x)
c = x.expand(3,4)
d = x.expand(-1,3) # -1 means leave that dimension's size unchanged; equivalent here to x.expand(3,3)
print(c.shape,c)
print(d.shape,d)
# since expand() returns a view, changing a value through any of x, c, d changes the other two as well
x[1,0]=5
print(x,c,d)
c[1,3]=6
print(x,c,d)
torch.Size([3, 1]) tensor([[1],
[2],
[3]])
torch.Size([3, 4]) tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
torch.Size([3, 3]) tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
tensor([[1],
[5],
[3]]) tensor([[1, 1, 1, 1],
[5, 5, 5, 5],
[3, 3, 3, 3]]) tensor([[1, 1, 1],
[5, 5, 5],
[3, 3, 3]])
tensor([[1],
[6],
[3]]) tensor([[1, 1, 1, 1],
[6, 6, 6, 6],
[3, 3, 3, 3]]) tensor([[1, 1, 1],
[6, 6, 6],
[3, 3, 3]])
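One caveat worth noting: expand() can only broadcast dimensions of size 1 (or prepend new leading dimensions); asking it to grow a non-singleton dimension raises an error. A minimal sketch of both cases (illustrative, not from the runs above):
import torch
x = torch.tensor([[1],[2],[3]])  # shape (3, 1)
print(x.expand(2,3,4).shape)     # torch.Size([2, 3, 4]): prepending a new leading dim is allowed
#x.expand(6,4)                   # would raise RuntimeError: non-singleton dim 0 (size 3) cannot expand to 6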
x = torch.tensor([[1],[2],[3]])
print(x.shape,x)
e = x.repeat(3,4) # tile the tensor: each dimension is repeated the given number of times
print(e.shape,e)
f = x.repeat(2,3,4)
print(f.shape,f)
# since repeat() copies the data, changing a value through x, e, or f does not affect the others
x[1,0]=5
print(x,e,f)
e[1,3]=6
print(x,e,f)
torch.Size([3, 1]) tensor([[1],
[2],
[3]])
torch.Size([9, 4]) tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
torch.Size([2, 9, 4]) tensor([[[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]],
[[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]]])
tensor([[1],
[5],
[3]]) tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]]) tensor([[[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]],
[[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]]])
tensor([[1],
[5],
[3]]) tensor([[1, 1, 1, 1],
[2, 2, 2, 6],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]]) tensor([[[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]],
[[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]]])
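Mirroring the data_ptr() check from the view/reshape section, one can verify directly that expand() shares storage while repeat() allocates new storage. A minimal sketch:
import torch
x = torch.tensor([[1],[2],[3]])
print(x.storage().data_ptr() == x.expand(3,4).storage().data_ptr()) # True: expand() is a view
print(x.storage().data_ptr() == x.repeat(3,4).storage().data_ptr()) # False: repeat() copied the data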
torch.permute and torch.transpose
#torch.transpose(input, dim0, dim1) swaps the given dims dim0 and dim1 and returns a view of the original tensor, so changing a value in one also changes it in the other
x = torch.randn(2,3)
c = torch.transpose(x,0,1)
print(x)
print(c)
print(c.is_contiguous())
c[0,1]=2.0
print(x)
print(c)
tensor([[-0.6876, -0.4264, 0.3218],
[-0.2164, -2.4182, 2.0601]])
tensor([[-0.6876, -0.2164],
[-0.4264, -2.4182],
[ 0.3218, 2.0601]])
False
tensor([[-0.6876, -0.4264, 0.3218],
[ 2.0000, -2.4182, 2.0601]])
tensor([[-0.6876, 2.0000],
[-0.4264, -2.4182],
[ 0.3218, 2.0601]])
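transpose() is not limited to 2-D tensors: it swaps any two dimensions of an n-dimensional tensor, and Tensor.t() is a 2-D-only shorthand for transpose(0,1). A minimal sketch (the shapes are illustrative):
import torch
x = torch.randn(2,3,4)
print(torch.transpose(x,0,2).shape) # torch.Size([4, 3, 2]): dims 0 and 2 swapped
y = torch.randn(2,3)
print(y.t().shape)                  # torch.Size([3, 2]): same as y.transpose(0,1)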
#permute() can reorder several dimensions at once, but every call must list all of the tensor's dimensions. It was originally available only as a Tensor method x.permute() (recent PyTorch versions also provide torch.permute). It returns a view of the original tensor, so changing a value in one also changes it in the other
x = torch.randn(1,2,3)
print(x)
c = x.permute(0,2,1)
print(c)
print(c.is_contiguous())
c[0,1,0]=3.0
print(x)
print(c)
tensor([[[-1.3573, 2.5450, -1.4865],
[ 0.9934, 0.4449, -0.1048]]])
tensor([[[-1.3573, 0.9934],
[ 2.5450, 0.4449],
[-1.4865, -0.1048]]])
False
tensor([[[-1.3573, 3.0000, -1.4865],
[ 0.9934, 0.4449, -0.1048]]])
tensor([[[-1.3573, 0.9934],
[ 3.0000, 0.4449],
[-1.4865, -0.1048]]])
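A common practical use of permute() is switching channel layouts, e.g. converting an NCHW image batch to NHWC. A minimal sketch (the batch shape is illustrative):
import torch
imgs = torch.randn(8,3,32,32)           # NCHW: batch, channels, height, width
nhwc = imgs.permute(0,2,3,1)            # NHWC: channels moved to the last dim
print(nhwc.shape, nhwc.is_contiguous()) # torch.Size([8, 32, 32, 3]) False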
#After transpose() or permute(), the tensor's memory is no longer contiguous, and calling view() on it then fails with:
# RuntimeError: view size is not compatible with input tensor's size and stride ...
#Call contiguous() first to force a contiguous layout, then view(). reshape() has no such requirement: x.reshape() is equivalent to x.contiguous().view()
x=torch.randn(3,4)
x=x.transpose(0,1) # the in-place x.transpose_(0,1) is equivalent to x=x.transpose(0,1), analogous to x+=1 vs x=x+1
print(x.shape)
#x.view(3,4) # would raise the RuntimeError above
x = x.contiguous().view(3,4) # equivalent to x=x.reshape(3,4)
print(x.shape)
torch.Size([4, 3])
torch.Size([3, 4])
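As a final check, the data_ptr() comparison from the first section makes the copy performed by reshape() visible: on a non-contiguous input, reshape() silently allocates new storage. A minimal sketch:
import torch
x = torch.randn(3,4).transpose(0,1)                     # non-contiguous view, shape (4, 3)
y = x.reshape(3,4)                                      # works without calling contiguous()
print(x.storage().data_ptr() == y.storage().data_ptr()) # False: reshape() had to copy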