pytorch

坐下搬运工,记录几个重要的函数:


1.torch. max ( input, dim, keepdim=False, out=None) -> (Tensor, LongTensor )
Returns the maximum value of each row of the input Tensor in the givendimension dim.
The second return value is the index location of eachmaximum value found (argmax).


例如:

>> a = torch.randn(4, 4)
>> a

0.0692  0.3142  1.2513 -0.5428
0.9288  0.8552 -0.2073  0.6409
1.0695 -0.0101 -2.4507 -1.2230
0.7426 -0.7666  0.4862 -0.6628
torch.FloatTensor of size 4x4]

>>> torch.max(a, 1)
(
 1.2513
 0.9288
 1.0695
 0.7426
[torch.FloatTensor of size 4]
,
 2
 0
 0
 0
[torch.LongTensor of size 4]
)


一般分类的话 都是使用location 或者index 作为类别

_, predicted = torch.max(outputs.data, 1)








 


版权声明:本文为weixin_37541676原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。