pred = torch.max(a,1,keepdim=True)[1]
TypeError: torch.max received an invalid combination of arguments - got (torch.LongTensor, int, keepdim=bool), but expected one of:
* (torch.LongTensor source)
* (torch.LongTensor source, torch.LongTensor other)
* (torch.LongTensor source, int dim)
didn't match because some of the keywords were incorrect: keepdim
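Fix for the error above: the installed PyTorch version's torch.max does not accept the keepdim keyword, so drop it. Change the original code
pred = output.data.max(1, keepdim=True)[1]
to
pred = torch.max(output.data, 1)[1]
(reference article)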
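If the same script has to run on both older and newer PyTorch installations, one option is a small helper that tries keepdim first and falls back when torch.max rejects it. This is a minimal sketch, not the author's fix. The helper name argmax_keepdim is hypothetical, and the fallback assumes that older releases raise TypeError for the keyword, as in the trace above.

```python
import torch


def argmax_keepdim(output, dim=1):
    """Return the indices of the max along `dim`, keeping that dimension (size 1).

    Hypothetical helper: tries the keepdim keyword first and falls back for
    PyTorch versions whose torch.max rejects it with a TypeError.
    """
    try:
        # Newer PyTorch: keepdim is supported directly.
        return torch.max(output, dim, keepdim=True)[1]
    except TypeError:
        # Older PyTorch: call without keepdim, as in the fix above.
        idx = torch.max(output, dim)[1]
        # Re-insert the reduced dimension only if this version dropped it.
        if idx.dim() < output.dim():
            idx = idx.unsqueeze(dim)
        return idx


scores = torch.tensor([[0.1, 0.9, 0.0], [0.7, 0.2, 0.1]])
pred = argmax_keepdim(scores)  # shape (2, 1): one predicted class per row
print(pred)
```

The result keeps the reduced dimension either way, so downstream code that expects a column of predicted indices does not have to branch on the PyTorch version.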
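Copyright notice: This is an original article by zhuoyuezai, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.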