attention


代码实现:这里的数据是二维的 假设是三维的,一般在seq2seq中会有 ,而且第一维度的 也是1
所以不影响!
a=torch.rand(1,256)
a=a.repeat(10,1)
print(a.size())
b=torch.rand(10,256)
weights = torch.tanh(attn(torch.cat([a, b],1)))
print(weights.size())
v=nn.Linear(10,1)
attention = F.softmax(v(weights),dim=0)
print(attention)
版权声明:本文为qq_38735017原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。