NestedTensor 是 DETR 代码中使用的一个数据结构。
它包括 tensors 和 mask 两个成员:tensors 就是输入的图像;mask 与 tensors 的高、宽相同,但只有单通道。
tensors:取整个 batch 中最大的 H 和 W,其余图像在右侧和下方用 0 padding 补齐到这个尺寸。
mask:宽高与图像对应,padding 的位置为 True,其余(真实图像)位置都为 False。最后使用时会取反,即补全的位置为 0,真实图像的位置为 1,这样更符合直觉(make sense)。mask 是用在 Transformer 中的。
如需把 tensors 和 mask 分开取出使用,可以调用:
x.decompose()
源码
class NestedTensor(object):
    """A zero-padded image tensor bundled with its padding mask.

    Attributes:
        tensors: the padded image tensor. The methods below accept a 3-D
            ``(C, H, W)`` tensor or a 4-D ``(B, C, H, W)`` batch.
        mask: boolean tensor with the spatial shape of ``tensors`` and no
            channel axis -- ``(H, W)`` or ``(B, H, W)`` -- where True marks a
            padded pixel. May be None, in which case no pixel is treated as
            padding. Passing the string ``'auto'`` builds an all-False mask.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask
        # Only a str can spell 'auto'. Comparing a Tensor with a str is not a
        # meaningful test (and is torch-version dependent), so check the type.
        if isinstance(mask, str) and mask == 'auto':
            if tensors.dim() == 3:
                # (C, H, W) -> (H, W)
                mask_shape = tensors.shape[1:]
            elif tensors.dim() == 4:
                # (B, C, H, W) -> (B, H, W)
                mask_shape = (tensors.shape[0], *tensors.shape[2:])
            else:
                raise ValueError("tensors dim must be 3 or 4 but {}({})".format(self.tensors.dim(), self.tensors.shape))
            # Nothing is padding: build the all-False mask directly instead of
            # materialising a full-size zeros copy of ``tensors`` and reducing it.
            self.mask = torch.zeros(mask_shape, dtype=torch.bool, device=tensors.device)

    @staticmethod
    def _valid_hw(tensor, mask):
        """Return the (height, width) of the unpadded region as Python ints.

        Args:
            tensor: image whose last two axes are (H, W); only used when
                ``mask`` is None.
            mask: 2-D ``(H, W)`` bool mask with True on padding, or None.
        """
        if mask is None:
            return tensor.shape[-2], tensor.shape[-1]
        valid = ~mask
        # Column sums count valid rows (-> height); row sums count valid
        # columns (-> width). Padding is right/bottom, so the max is the size.
        return int(valid.sum(0).max()), int(valid.sum(1).max())

    def _iter_images(self):
        """Yield ``(tensor_i, mask_i)`` per image; ``mask_i`` is None without a mask."""
        if self.tensors.dim() == 3:
            yield self.tensors, self.mask
            return
        for i, tensor_i in enumerate(self.tensors):
            yield tensor_i, (None if self.mask is None else self.mask[i])

    def imgsize(self):
        """Return the unpadded size of every image.

        Returns:
            list of 1-D float tensors ``[H, W]``, one per image. An unbatched
            (3-D) NestedTensor yields a one-element list.
        """
        return [torch.Tensor(list(self._valid_hw(tensor_i, mask_i)))
                for tensor_i, mask_i in self._iter_images()]

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with tensors (and mask, if any) moved to ``device``."""
        cast_tensor = self.tensors.to(device)
        cast_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(cast_tensor, cast_mask)

    def to_img_list_single(self, tensor, mask):
        """Crop the right/bottom padding off one 3-D image.

        Args:
            tensor: a single 3-D image tensor.
            mask: its 2-D padding mask (True = padding), or None for no padding.

        Returns:
            ``tensor[:, :H, :W]`` where (H, W) is the unpadded size.
        """
        assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
        max_h, max_w = self._valid_hw(tensor, mask)
        return tensor[:, :max_h, :max_w]

    def to_img_list(self):
        """Remove the padding and convert to image(s).

        Returns:
            For a 3-D NestedTensor, the single cropped image tensor; otherwise a
            list with one cropped image tensor per batch element.
        """
        if self.tensors.dim() == 3:
            return self.to_img_list_single(self.tensors, self.mask)
        return [self.to_img_list_single(tensor_i, mask_i)
                for tensor_i, mask_i in self._iter_images()]

    @property
    def device(self):
        """Device of the underlying ``tensors``."""
        return self.tensors.device

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

    @property
    def shape(self):
        """Shapes of both members; ``'mask.shape'`` is None when there is no mask."""
        return {
            'tensors.shape': self.tensors.shape,
            'mask.shape': None if self.mask is None else self.mask.shape,
        }
版权声明:本文为hxxjxw原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。