Pytorch学习-torch.max()和min()深度解析

max的使用 min同理

参考链接:
参考链接:

对于tensorA和tensorB:
1)torch.max(tensorA) 返回tensor中的最大值
2)torch.max(tensorA,dim)  返回指定维度的最大数和对应下标
3)torch.max(tensorA,tensorB) 比较tensorA和tensorB相对较大的元素

dim参数理解

搞清楚dim参数
第0维是行,第1维是列!!!
结论:
1)dim=0 查找每列的最大值,返回行下标索引
2)dim=1 查找每行的最大值,返回列下标索引
3)不添加dim参数,返回所有值中的最大值,且无索引

二维张量使用max()

t=torch.randn(2,3)
print(t)
print("-------max dim=0 -------")
print(torch.max(t,dim=0))
print("-------max dim=1 -------")
print(torch.max(t,dim=1))

结果:

tensor([[ 0.0231,  0.2109, -1.6104],
        [-0.5777, -1.3870, -0.9925]])
-------max dim=0 -------
torch.return_types.max(
values=tensor([ 0.0231,  0.2109, -0.9925]),
indices=tensor([0, 0, 1]))
-------max dim=1 -------
torch.return_types.max(
values=tensor([ 0.2109, -0.5777]),
indices=tensor([1, 0]))

**???疑问:**为什么0维是行,但是max时返回是列中的最大值呢?
理解:!!在其他维度均确定的情况下,比较所有dim维对应的数据,找到其中的最大值,并返回索引。
比如:
dim=0时 除了[0]维 还有[1]两个维度
第一列 遍历两行得到 [0][0] 和 [1][0] max为0.0231
第二列 遍历两行得到 [0][1] 和 [0][2] max为0.2109
第三列 遍历两行得到 [1][1] 和 [1][2] max为-0.9925

三维张量使用max()

第0维顺着层,第1维顺着行,第2维度顺着列

t = torch.randn(2,2,2)
print(t)
print("-------max dim=0 -------")
print(torch.max(t,dim=0))

print("-------max dim=1 -------")
print(torch.max(t,dim=1))

print("-------max dim=2 -------")
print(torch.max(t,dim=2))

结果:

tensor([[[-1.6519, -0.3087],
         [-0.6982,  0.4515]],

        [[-0.4648,  0.8958],
         [-1.4150, -1.4633]]])
-------max dim=0 -------  [[-0.4648,  0.8958],[-0.6982,  0.4515]] [[1,1],[0,0]] 列确定 比较行
torch.return_types.max(
values=tensor([[-0.4648,  0.8958],
        [-0.6982,  0.4515]]),
indices=tensor([[1, 1],
        [0, 0]]))
-------max dim=1 ------- [[-0.6982, 0.4515],[-0.4648,0.8958]] [[],[]][0][0][0][0][1][0]),([0][0][1][0][1][1]),([1][0][0][1][1][0]),([1][0][1][1][1][1])
torch.return_types.max(
values=tensor([[-0.6982,  0.4515],
        [-0.4648,  0.8958]]),
indices=tensor([[1, 1],
        [0, 0]]))
-------max dim=2 ------- [0][0]_,[0][1]_,[1][0]_,[1][1]_  ([0][0][0],[0][0][1]) ([0][1][0],[0][1][1]) ([1][0][0],[1][0][1]) ([1][1][0],[1][1][1])
torch.return_types.max(
values=tensor([[-0.3087,  0.4515],
        [ 0.8958, -1.4150]]),
indices=tensor([[1, 1],
        [1, 0]]))

版权声明:本文为qq_36930921原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。