Pytorch-2:张量关于shape的操作

  • flatten
  • reshape
  • squeeze explained

shape

关于shape:

t = torch.tensor([
		[1,1,1,1],
		[2,2,2,2],
		[3,3,3,3]
])

#get shape 
t.size()				->		torch.Size([3, 4])
t.shape			->		torch.Size([3, 4])

#get rank 
len(t.shape)		->		2

#get number of elements 获取元素数
torch.tensor(t.shape).prod()	->		tensor(12)
t.numel()									->		12

reshape()

t.reshape()

# keep rank 保持维数的reshape
t.reshape(12, 1)
t.reshape(4, 3)

t.squeeze()t.unsqueeze()
t.squeeze():移除长度为1的维度。
t.unsqueeze():增加一个长度为1的维度。

# change rank 改变维数的reshape
print(t.reshape(1,12).shape)													->	torch.Size([1, 12])

print(t.reshape(1,12).squeeze().shape)									->	torch.Size([12])

print(t.reshape(1,12).squeeze().unsqueeze(dim=0).shape)	->	torch.Size([1, 12])

flatten()

Conv卷积层->FC全连接,中间需要使用flatten(),将每张图片转为一维。

def flatten(t):
	t = t.reshape(1, -1)		#->	torch.Size([1, 12])
	t = t.squeeze()				#->	torch.Size([12])
	return t

一个batch:

t = torch.stack((t1, t2, t3))		#->	torch.Size([3, 4, 4])
t = t.reshape(3,1,4,4)				#->	torch.Size([3, 1, 4, 4])

由于图片通常以batch形式转为tensor传入FC,所以每个batch对应的tensor尺寸为[batch_size, H x W],为了将输入的[batch_size, n_chanels, H, W]转为[batch_size, H x W x C]:

t.flatten(start_dim=1)				#->	torch.Size([3, 16])

版权声明:本文为Ztomepic原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。