BUG[Pytorch]:RuntimeError: Expected object of scalar type Int but got scalar type Float for argument

RuntimeError: Expected object of scalar type Int but got scalar type Float for argument #3 ‘mat1’…

torch.float64
Traceback (most recent call last):
  File "G:/Graph-master/train.py", line 173, in <module>
    main()
  File "G:/Graph-master/train.py", line 81, in main
    metrics = engine.train(trainx, trainy[:, 0, :, :])  # trainy.shape=(64,207,12)
  File "G:\Graph-master\engine.py", line 38, in train
    output = self.model(input)
  File "F:\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "G:\Graph-master\model.py", line 181, in forward
    x = self.start_conv(x)
  File "F:\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "F:\Anaconda3\lib\site-packages\torch\nn\modules\conv.py", line 349, in forward
    return self._conv_forward(input, self.weight)
  File "F:\Anaconda3\lib\site-packages\torch\nn\modules\conv.py", line 346, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #3 'mat1' in call to _th_addmm_

该错误是在给模型传数据、准备训练时产生的,我检查了所有的数据来源,其数据类型都是numpy.float64,根据错误的显示,需要传入double类型的参数,emmm,改成double还是同样的错误,百思不得其解。
解决:将传入的数据后面加上.float(),就好了~

x = self.start_conv(x.float())

float()可将数据转为float32()

import torch
a = torch.ones([2, 4], dtype=torch.float64)
"""
>>> a
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]], dtype=torch.float64)
"""
a.float().dtype # torch.float32 

官网的类型解释

最终,将float64转为float32解决了问题,但是和报错信息有冲突


版权声明:本文为qq_33866063原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。