- flatten
- reshape
- squeeze explained
shape
关于shape:
t = torch.tensor([
[1,1,1,1],
[2,2,2,2],
[3,3,3,3]
])
#get shape
t.size() -> torch.Size([3, 4])
t.shape -> torch.Size([3, 4])
#get rank
len(t.shape) -> 2
#get number of elements 获取元素数
torch.tensor(t.shape).prod() -> tensor(12)
t.numel() -> 12
reshape()
t.reshape()
# keep rank 保持维数的reshape
t.reshape(12, 1)
t.reshape(4, 3)
t.squeeze()和t.unsqueeze()
t.squeeze():移除长度为1的维度。
t.unsqueeze():增加一个长度为1的维度。
# change rank 改变维数的reshape
print(t.reshape(1,12).shape) -> torch.Size([1, 12])
print(t.reshape(1,12).squeeze().shape) -> torch.Size([12])
print(t.reshape(1,12).squeeze().unsqueeze(dim=0).shape) -> torch.Size([1, 12])
flatten()
Conv卷积层->FC全连接,中间需要使用flatten(),将每张图片转为一维。
def flatten(t):
t = t.reshape(1, -1) #-> torch.Size([1, 12])
t = t.squeeze() #-> torch.Size([12])
return t
一个batch:
t = torch.stack((t1, t2, t3)) #-> torch.Size([3, 4, 4])
t = t.reshape(3,1,4,4) #-> torch.Size([3, 1, 4, 4])
由于图片通常以batch形式转为tensor传入FC,所以每个batch对应的tensor尺寸为[batch_size, H x W],为了将输入的[batch_size, n_chanels, H, W]转为[batch_size, H x W x C]:
t.flatten(start_dim=1) #-> torch.Size([3, 16])
版权声明:本文为Ztomepic原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。