介绍:
- flatten()是对多维数据的降维函数
- flatten(),默认缺省参数为0
- 适用:numpy对象,即数组array或者矩阵MAT,普通的list列表不可以
- 出于:flatten是numpy.ndarray.flatten的一个函数
- 详细:python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。
举例:
import torch
a = torch.rand(2,3,4) # 0,1,2维
print(a.size())
# torch.Size([2, 3, 4])
flatten()函数的默认缺省参数为0,即可以理解为将所有维度数相乘,恢复原始数据
b = a.flatten(0) # 2*3*4
print(b.size())
# torch.Size([24])
保留第0维的维度,其他维度上的数字转成一维
b = a.flatten(1) # 2,3*4
print(b.size())
# torch.Size([2, 12])
保留第0,1维的维度,其他维度上的数字转成一维
b = a.flatten(2) # 2,3,4(除了第0,1维,其他维只有4)
print(b.size())
# torch.Size([2, 3, 4])
为深入理解,定义一个5维的数据。保留第0,1,2维的维度,将其他维度上的数字转成一维。可以直接想到,5维的数据,即包含0,1,2,3,4维,保留三个维度的数,则5维数据变成了4维。若保留0,1维,则返回一个3维的数据。
import torch
a = torch.rand(2,3,4,5,6) # 2,3,4,5*6
b = a.flatten(3)
b.size()
# torch.Size([2, 3, 4, 30])
版权声明:本文为qq_43083762原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。