Some additional notes on torch.argmax

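torch.argmax returns integer indices and is not differentiable, so no gradient flows back to its input. The values selected with those indices, however, can still receive gradients: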

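import torch

# s only supplies the indices; d supplies the values that get selected
s = torch.rand(1, 3, 6, 6, requires_grad=True)
d = torch.rand(1, 3, 6, 6, requires_grad=True)
# argmax returns integer indices, which cuts the graph back to s
p = torch.argmax(s, dim=1).unsqueeze(1)
# gather is differentiable with respect to d, not with respect to the index
q = torch.gather(d, dim=1, index=p)
q = q.sum()
loss = (q - 1) * (q - 1)
loss.backward()
print(s.grad)  # None: no gradient reaches s
print(d.grad)  # non-zero only at the positions picked by p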

Output:

None
tensor([[[[ 0.0000,  0.0000,  0.0000, 38.2169,  0.0000,  0.0000],
          [38.2169,  0.0000,  0.0000, 38.2169, 38.2169,  0.0000],
          [ 0.0000, 38.2169,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, 38.2169,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, 38.2169,  0.0000, 38.2169,  0.0000,  0.0000],
          [38.2169,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

         [[38.2169,  0.0000, 38.2169,  0.0000, 38.2169,  0.0000],
          [ 0.0000, 38.2169,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, 38.2169, 38.2169,  0.0000, 38.2169],
          [38.2169,  0.0000,  0.0000, 38.2169, 38.2169, 38.2169],
          [ 0.0000,  0.0000, 38.2169,  0.0000,  0.0000, 38.2169],
          [ 0.0000,  0.0000, 38.2169,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000, 38.2169,  0.0000,  0.0000,  0.0000, 38.2169],
          [ 0.0000,  0.0000, 38.2169,  0.0000,  0.0000, 38.2169],
          [38.2169,  0.0000,  0.0000,  0.0000, 38.2169,  0.0000],
          [ 0.0000,  0.0000, 38.2169,  0.0000,  0.0000,  0.0000],
          [38.2169,  0.0000,  0.0000,  0.0000, 38.2169,  0.0000],
          [ 0.0000, 38.2169,  0.0000, 38.2169, 38.2169, 38.2169]]]])
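
The pattern in the output above can be checked programmatically. The following is only a minimal sketch, not part of the original experiment: the fixed seed, the `keepdim=True` shortcut, and the helper names `selected` and `expected` are assumptions introduced for illustration. It builds a mask of the channel chosen by argmax at each pixel and compares d.grad against the analytic gradient 2 * (q - 1), which should appear at the selected positions and be zero everywhere else.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed so the run is repeatable

s = torch.rand(1, 3, 6, 6, requires_grad=True)
d = torch.rand(1, 3, 6, 6, requires_grad=True)

# keepdim=True gives the same shape as argmax(...).unsqueeze(1)
p = torch.argmax(s, dim=1, keepdim=True)
q = torch.gather(d, dim=1, index=p).sum()
loss = (q - 1) ** 2
loss.backward()

# Boolean mask (hypothetical helper): True where the channel index equals p
selected = torch.arange(d.size(1)).view(1, -1, 1, 1) == p

# Analytic gradient: d(loss)/d(q) = 2 * (q - 1) at selected positions, 0 elsewhere
expected = selected.to(d.dtype) * (2 * (q.detach() - 1))

print(p.dtype, p.requires_grad)         # the indices are integers and carry no gradient
print(s.grad is None)                   # s never receives a gradient
print(torch.allclose(d.grad, expected))
```

Every non-zero entry of d.grad has the same value because q is a plain sum, so each selected element contributes equally to the loss.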

 


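Copyright notice: This is an original article by qq_39861441, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.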