torch.tensor().reshape在求梯度问题

x = torch.tensor([[2,3]],dtype=torch.float,requires_grad=True)

......

x.grad

与x = torch.tensor([2,3],dtype=torch.float,requires_grad=True).reshape((1,2))

......

x.grad

是不一样的,前者可以正常求出梯度,后者的x经过reshape就不是叶子节点了,所以平时要注意需要求梯度的变量的创立格式。(注意第一个tensor是两个中括号,这样x刚创立就是2维的)

如果仍然想用reshape可以这样:

x = torch.tensor([2,3],dtype=torch.float,requires_grad=True)

x1 = x.reshape((1,2))

......

x.grad


版权声明:本文为mahui6144原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。