Pytorch 扩展单张图片维度@Elaine
训练数据一般都是(b,c,h,w),在测试时只输入一张图片,所以需要扩展维度,下面是扩展维度方法
import cv2
import torch
image = cv2.imread(img_path)
image = torch.tensor(image)
print(image.size())
img = image.view(1, *image.size())
print(img.size())
#output:
#torch.Size([h, w, c])
#torch.Size([1, h, w, c])
或
import cv2
import numpy as np
image = cv2.imread(img_path)
print(image.shape)
img = image[np.newaxis, :, :, :]
print(img.shape)
# output:
# (h, w, c)
# (1, h, w, c)
或
import cv2
import torch
image = cv2.imread(img_path)
image = torch.tensor(image)
print(image.size())
img = image.unsqueeze(dim=0)
print(img.size())
img = img.squeeze(dim=0)
print(img.size())
# output:
# torch.Size([(h, w, c)])
# torch.Size([1, h, w, c])
tensor.unsqueeze(dim):扩展维度,dim指定扩展哪个维度。
tensor.squeeze(dim):去除dim指定的且size为1的维度,维度大于1时,squeeze()不起作用,不指定dim时,去除所有size为1的维度。