常用tensor值处理变换(tensor组成的list转多维tensor;多维tensor数组数据变成二维)

常见一些python包对数据输入格式有要求,要求为数组narray,但有的输入需要是tensor格式.

1.tensor 转数组格式,在之前文章有描述过:tensor转narray
2.当多个tensor值通过.append()拼接成list格式时,改系列数据展示的则是list属性,此时需要继续变为tensor值之后再进行相关操作,可进行如下操作:final_output = torch.stack(output_list)
利用stack()函数;另外也可以使用torch.cat()函数拼接,但是需要指定维度。
3.将多维数据变成二维再进行操作:

def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    # number of channels
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)

然后调用上述函数即可:

input=flatten(final_output)
4.多维tensor数组之间的相乘(*)默认是对应位置相乘,得到新的相同维度的tensor数组


版权声明:本文为ZXQ_JAY原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。