问题描述
使用pytorch的函数 torch.nn.CrossEntropyLoss()计算Loss时报错:
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed
报错原因
直观上看,函数要求目标分类数大于等于0并且小于等于输入的类别。所以一般而言,都是网络中输出的种类数和标签中设置的种类数量不同造成的。
解决方案
针对于不同原因,主要从两方面考虑解决。
方向一:模型输出与分类数不一致
- 看一下模型的输出尺寸与分类数差异是否明显,核查代码是否存在错误。
- 如果没有错误,只是映射维度不对,可以考虑在模型的最后一层加一层FC层,将输出尺寸映射到分类大小。
方向二:标签的设置不是从0开始
- 如果模型的输出尺寸与分类数大小相同,看一下标签的设定是否是从0开始的。
- 如果标签是从1开始设置的,重新设置标签。这里存在的坑是:在使用CrossEntropyLoss()这个函数进行验证时,标签必须从0开始设置,否则便会报错。
版权声明:本文为m0_37369043原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。

