st_gcn网络定义中使用了torch.nn.BatchNorm1d,在运行网络时出现RuntimeError: Tensor for argument #2 'weight' is on CPU, but expected it to be on GPU (while checking arguments for cudnn_batch_norm)
论坛中的解释是因为在forward函数中直接定义并使用了一个函数,但是源代码中是在__init__函数中定义,forward函数中再使用的
github上好像说是batchnorm1d的通病
最后在__init__函数中将定义的函数转到GPU上得到解决
self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
self.data_bn.cuda()
版权声明:本文为YoJayC原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。