解决办法 pred = torch.max(a,1,keepdim=True)[1]

pred = torch.max(a,1,keepdim=True)[1]
TypeError: torch.max received an invalid combination of arguments - got (torch.LongTensor, int, keepdim=bool), but expected one of:
 * (torch.LongTensor source)
 * (torch.LongTensor source, torch.LongTensor other)
 * (torch.LongTensor source, int dim)

      didn't match because some of the keywords were incorrect: keepdim


上面错误解决办法

把原先的代码

 pred = output.data.max(1, keepdim=True)[1]

改为

 pred = torch.max(output.data,1)[1] 


参考文章

torch.max


版权声明:本文为zhuoyuezai原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。