写在前面
最近看代码,发现别人的代码里用到了一个神奇的操作,torch.pdist()
。查阅许久之后,对于他们的描述都不是很明白,遂结合描述,自行测试,结果记录于此,便于理解。
一、文档描述
关于torch.pdist()
的官方文档如下:
计算输入中每对行向量之间的p范数距离。这与
torch.norm(input[:,None]-input, dim=2, p=p])
的上三角部分相同,不包括对角线。如果行上是连续的,这个函数将很快。(Computes the p-norm distance between every pair of row vectors in the input. This is identical to the upper triangular portion, excluding the diagonal, of torch.norm(input[:, None] - input, dim=2, p=p). This function will be faster if the rows are contiguous.)
其实理解之后,对于他的描述才会感觉认同。但是不理解的时候,也看不太懂他的描述。
二、代码测试
我看到的代码如下:
torch.pdist(x, p=2) # 其中x为二维矩阵
因此,为了更好的理解torch.pdist()
,我需要去建立一个简单的二维矩阵,然后根据torch.pdist()
的原理,手写出其计算过程。(PS:之所以建立简单的二维矩阵,就是为了更好理解)
import torch
import numpy as np
_x = np.asarray([[1,2,3],[4,5,6],[7,8,9]])
# print(_x)
# array([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
# 官方的pdist
x = torch.Tensor(_x)
res0 = torch.pdist(x, p=2)
# print(res0)
# tensor([ 5.1962, 10.3923, 5.1962])
# 官方的解释
res1 = torch.norm(x[:,None]-x,dim=2,p=2)
# print(res0)
# tensor([[ 0.0000, 5.1962, 10.3923],
# [ 5.1962, 0.0000, 5.1962],
# [10.3923, 5.1962, 0.0000]])
# 取上三角部分,剔除掉对角线,就是[5.1962, 10.3923, 5.1962],但是这又需要看懂torch.norm(x[:,None]-x,dim=2,p=2)是什么意思
我的理解
文档里讲到,他是算两行之间的p norm,p是参数,容易知道2-范数的公式:
x = ∣ x 1 − x 2 ∣ 2 + ∣ y 1 − y 2 ∣ 2 + ∣ z 1 − z 2 ∣ 2 x = \sqrt{|x_1 - x_2|^2 + |y_1 - y_2|^2 + |z_1 - z_2|^2}x=∣x1−x2∣2+∣y1−y2∣2+∣z1−z2∣2
看到公式,计算下两行之间的2-范数,就知道结果了:
5.1962 ≊ ∣ 1 − 4 ∣ 2 + ∣ 2 − 5 ∣ 2 + ∣ 3 − 6 ∣ 2 5.1962 \approxeq \sqrt{|1 - 4|^2 + |2 - 5|^2 + |3 - 6|^2}5.1962≊∣1−4∣2+∣2−5∣2+∣3−6∣2
10.3923 ≊ ∣ 1 − 7 ∣ 2 + ∣ 2 − 8 ∣ 2 + ∣ 3 − 9 ∣ 2 10.3923\approxeq \sqrt{|1-7|^2 + |2-8|^2 + |3-9|^2}10.3923≊∣1−7∣2+∣2−8∣2+∣3−9∣2
于是,就很明了了。