Python:flatten()函数用法

介绍:

  • flatten()是对多维数据的降维函数
  • flatten(),默认缺省参数为0
  • 适用:numpy对象,即数组array或者矩阵MAT,普通的list列表不可以
  • 出于:flatten是numpy.ndarray.flatten的一个函数
  • 详细:python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。

举例:

import torch
a = torch.rand(2,3,4) # 0,1,2维
print(a.size())

# torch.Size([2, 3, 4])

flatten()函数的默认缺省参数为0,即可以理解为将所有维度数相乘,恢复原始数据

b = a.flatten(0)  # 2*3*4
print(b.size()) 

# torch.Size([24])

保留第0维的维度,其他维度上的数字转成一维

b = a.flatten(1) # 2,3*4
print(b.size())

# torch.Size([2, 12])

保留第0,1维的维度,其他维度上的数字转成一维

b = a.flatten(2) # 2,3,4(除了第0,1维,其他维只有4)
print(b.size())

# torch.Size([2, 3, 4])

为深入理解,定义一个5维的数据。保留第0,1,2维的维度,将其他维度上的数字转成一维。可以直接想到,5维的数据,即包含0,1,2,3,4维,保留三个维度的数,则5维数据变成了4维。若保留0,1维,则返回一个3维的数据。

import torch
a = torch.rand(2,3,4,5,6)  # 2,3,4,5*6
b = a.flatten(3)
b.size()

# torch.Size([2, 3, 4, 30])

参考:
皮皮宽:python:flatten()参数详解

Mingsheng Zhang:flatten()函数用法


版权声明:本文为qq_43083762原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。