pytorch pdist

写在前面

最近看代码,发现别人的代码里用到了一个神奇的操作,torch.pdist()。查阅许久之后,对于他们的描述都不是很明白,遂结合描述,自行测试,结果记录于此,便于理解。

一、文档描述

关于torch.pdist()官方文档如下:
pdist官方文档

计算输入中每对行向量之间的p范数距离。这与torch.norm(input[:,None]-input, dim=2, p=p])的上三角部分相同,不包括对角线。如果行上是连续的,这个函数将很快。(Computes the p-norm distance between every pair of row vectors in the input. This is identical to the upper triangular portion, excluding the diagonal, of torch.norm(input[:, None] - input, dim=2, p=p). This function will be faster if the rows are contiguous.)

其实理解之后,对于他的描述才会感觉认同。但是不理解的时候,也看不太懂他的描述。

二、代码测试

我看到的代码如下:

    torch.pdist(x, p=2)	# 其中x为二维矩阵

因此,为了更好的理解torch.pdist(),我需要去建立一个简单的二维矩阵,然后根据torch.pdist()的原理,手写出其计算过程。(PS:之所以建立简单的二维矩阵,就是为了更好理解)

import torch
import numpy as np
_x = np.asarray([[1,2,3],[4,5,6],[7,8,9]])
# print(_x)
# array([[1, 2, 3],
#       [4, 5, 6],
#       [7, 8, 9]])
# 官方的pdist
x = torch.Tensor(_x)
res0 = torch.pdist(x, p=2)
# print(res0)
# tensor([ 5.1962, 10.3923,  5.1962])
# 官方的解释
res1 = torch.norm(x[:,None]-x,dim=2,p=2)
# print(res0)
# tensor([[ 0.0000,  5.1962, 10.3923],
#        [ 5.1962,  0.0000,  5.1962],
#        [10.3923,  5.1962,  0.0000]])
# 取上三角部分,剔除掉对角线,就是[5.1962, 10.3923, 5.1962],但是这又需要看懂torch.norm(x[:,None]-x,dim=2,p=2)是什么意思

我的理解

文档里讲到,他是算两行之间的p norm,p是参数,容易知道2-范数的公式:
x = ∣ x 1 − x 2 ∣ 2 + ∣ y 1 − y 2 ∣ 2 + ∣ z 1 − z 2 ∣ 2 x = \sqrt{|x_1 - x_2|^2 + |y_1 - y_2|^2 + |z_1 - z_2|^2}x=x1x22+y1y22+z1z22
看到公式,计算下两行之间的2-范数,就知道结果了:
5.1962 ≊ ∣ 1 − 4 ∣ 2 + ∣ 2 − 5 ∣ 2 + ∣ 3 − 6 ∣ 2 5.1962 \approxeq \sqrt{|1 - 4|^2 + |2 - 5|^2 + |3 - 6|^2}5.1962142+252+362
10.3923 ≊ ∣ 1 − 7 ∣ 2 + ∣ 2 − 8 ∣ 2 + ∣ 3 − 9 ∣ 2 10.3923\approxeq \sqrt{|1-7|^2 + |2-8|^2 + |3-9|^2}10.3923172+282+392
于是,就很明了了。


版权声明:本文为yyhaohaoxuexi原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。