pytorch中名目繁多的乘法
pytorch中表达乘法的方法有很多,如torch.mul(),以下为笔记兼试错心得。既然选择了半小时入门玩法,基础总是要撞坑的!
1. 读作torch.mul()写作*的乘法
简单来说就是 tensor 元素逐个相乘,他还有一个洋气的名字叫哈达玛积。又学了一个憨憨叫法,又可以用专有名词吓坏小朋友了。
import numpy as np
import torch
a = range(12)
n = np.array(a).reshape([3, 4])
X = torch.from_numpy(n).float()
print(X * X)
print(X.mul(X)) # torch.mul(X, X)
结果:
tensor([[ 0., 1., 4., 9.],
[ 16., 25., 36., 49.],
[ 64., 81., 100., 121.]])
tensor([[ 0., 1., 4., 9.],
[ 16., 25., 36., 49.],
[ 64., 81., 100., 121.]])
2. 写作torch.matmul()写作@的乘法
我才是矩阵乘法!我还支持广播机制!不行了,我好骚啊,我得叉下腰!
import numpy as np
import torch
a = range(12)
n = np.array(a).reshape([3, 4])
X = torch.from_numpy(n).float()
print(X @ X.t())
print(X.matmul(X.t())) # torch.matmul(X, X.t())
结果:
tensor([[ 14., 38., 62.],
[ 38., 126., 214.],
[ 62., 214., 366.]])
tensor([[ 14., 38., 62.],
[ 38., 126., 214.],
[ 62., 214., 366.]])
广播
广播可好玩了,玩好了可以骚出天际,玩不好原地升天!
Y = torch.ones(4)
print(X @ Y)
Z = torch.ones(3)
print(Z @ X)
tensor([ 6., 22., 38.])
tensor([12., 15., 18., 21.])
X = X.reshape([3, 2, 2])
print(Z @ X)
H = torch.ones(2)
print(X @ H)
print(H @ X)
第一个报错,mat1 and mat2 shapes cannot be multiplied (6x2 and 3x1)
tensor([[ 1., 5.],
[ 9., 13.],
[17., 21.]])
tensor([[ 2., 4.],
[10., 12.],
[18., 20.]])
这告诉我们三维数据的存储形式是若干个二位矩阵,第一个维度的维数它仅仅代表着(我想要你快乐)有多少个这样的二维矩阵
那么三维乘二维i呢?
H = torch.ones([2,2])
print(H @ X)
print(X @ H)
这两个根据上面的分析显然没有问题对吧!
那么三维乘三维也类似。
特别的,当第一个维度是1时,可以拓展为对应的维度!
torch.mm()
单纯的二维矩阵乘法,不如matmul好用,不支持广播,很菜。
torch.bmm()
按batch进行torch.mm(),仅用与三个维度都相同的三维张量,没得广播,不够骚。
torch.dot()
向量点积,对应位置相乘再相加,如果是p*n(p个n维)的向量,torch.diag(x@y.t())
x = torch.randn([5,3])
print(torch.diag(x@x.t()))
tensor([1.2279, 3.7674, 3.7071, 1.1913, 2.3835])
分别为每个向量的点积
两个骚到飞起的乘法
a = torch.arange(60.).reshape(3, 4, 5)
b = torch.arange(24.).reshape(4, 3, 2)
print(torch.tensordot(a, b, dims=([1, 0], [0, 1]))) # 支持广播
print(torch.einsum("jik,ijl->kl", (x, y)))
dims=([1,0],[0,1]),表明对a的0,1两维转置,对b的0,1两维不变,使用得到的矩阵对应位置按位乘再相加得到的矩阵,而"ijk,jil->kl"表达的意思类似,但是进行转置的固定为后面的矩阵,同时失去了对广播的支持!
还有更好玩的用法
x = torch.range(1, 12).reshape([3,4])
y = torch.range(1, 8).reshape([4,2])
print(torch.einsum("nm,mk->nk", (x, y))) # torch.tensordot(x,y,dims=([1],[0]))
print(x@y)
发现是一样的,这就是矩阵乘法呀!
那么是不是可以有更多的骚操作呢?
x = torch.range(1, 12).reshape([3,2,2])
y = torch.range(1, 8).reshape([4,2])
print(torch.einsum("ijk,lk->ijl", (x, y))) # torch.tensordot(x,y,dims=([2],[1]))
这样如果我们分不清楚matmul的操作方法,就可以用这种方法糊弄过去!