常见报错:RuntimeError: expected scalar type Long but found Float

RuntimeError: expected scalar type Long but found Float

这是一个非常常见的报错,我已经遇到过这个报错很多次了,但是之前没有仔细研究过,今天好好好看了看,终于找到了原因。
首先把导致报错的代码写出来:

import torch
import torch.nn as nn

v = torch.tensor([0])
m = nn.Linear(1, 10)
m(v)

短短的几行代码,就是初始化了一个值为0的v、一个网络m,运行后爆出了一大堆错:

Traceback (most recent call last):
  File "D:\ProgramData\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-183-2ddaa24c9bb3>", line 1, in <module>
    m(v)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "D:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1692, in linear
    output = input.matmul(weight.t())
RuntimeError: expected scalar type Long but found Float

注意到导致报错的代码: output = input.matmul(weight.t())
因为input也就是我们的v是torch.long类型的而weight是torch.float类型
所以在做矩阵乘法的时候这两种类型的不一致导致了报错
解决方法就是把v的dtype显示地设置成torch.float代码就成功运行了:

import torch
import torch.nn as nn
# dtype=torch.float必不可少
v = torch.tensor([0], dtype=torch.float)
m = nn.Linear(1, 10)
m(v)
Out[11]: 
tensor([-0.0628, -0.2544,  0.1313, -0.9293, -0.1259, -0.3151,  0.0729, -0.3097,
         0.8988,  0.1230], grad_fn=<AddBackward0>)

版权声明:本文为rambo_csdn_123原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。