_,predicted = torch.max(outputs.data,dim)

 dim=1时,按返回最大值所在索引

 dim=0时,按返回最大值所在索引

_,predicted = torch.max(outputs.data,dim):返回最大值所在索引

predicted = torch.max(outputs.data,dim):返回最大值


import torch
tensor = torch.tensor([[1.2,-3,-6.,-9.0,-4.56,5.678]])
_,predicted = torch.max(tensor,1)
print(predicted)

'''
返回最大值所在的索引
tensor([5])
'''


tensor = torch.tensor([[1.2,-3,-6.,-9.0,-4.56,5.678]])
predicted1 = torch.max(tensor,1)
print(predicted1)

'''
返回最大值和其所在索引
torch.return_types.max(
values=tensor([5.6780]),
indices=tensor([5]))
'''
import torch
#0按列返回
tensor = torch.tensor([[1.2,-3,-6.,-9.0,-4.56,5.678],[1,2,3,4,5,0]])
_,predicted = torch.max(tensor,0)
print(predicted)

'''
按列返回最大值所在的索引,此处只有两个分类结果,即0,1列
tensor([0, 1, 1, 1, 1, 0])
'''


tensor = torch.tensor([[1.2,-3,-6.,-9.0,-4.56,5.678],[1,2,3,4,5,0]])
predicted1 = torch.max(tensor,0)
print(predicted1)

'''
返回最大值和所在索引
torch.return_types.max(
values=tensor([1.2000, 2.
0000, 3.0000, 4.0000, 5.0000, 5.6780]),
indices=tensor([0, 1, 1, 1, 1, 0]))
'''

 


版权声明:本文为csweettea原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。