坐下搬运工,记录几个重要的函数:
- Returns the maximum value of each row of the
inputTensor in the givendimensiondim. - The second return value is the index location of eachmaximum value found (argmax).
1.torch. max ( input, dim, keepdim=False, out=None) -> (Tensor, LongTensor )例如:
>> a = torch.randn(4, 4) >> a 0.0692 0.3142 1.2513 -0.5428 0.9288 0.8552 -0.2073 0.6409 1.0695 -0.0101 -2.4507 -1.2230 0.7426 -0.7666 0.4862 -0.6628 torch.FloatTensor of size 4x4] >>> torch.max(a, 1) ( 1.2513 0.9288 1.0695 0.7426 [torch.FloatTensor of size 4] , 2 0 0 0 [torch.LongTensor of size 4] ) 一般分类的话 都是使用location 或者index 作为类别 _, predicted = torch.max(outputs.data, 1)
版权声明:本文为weixin_37541676原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。