The difference between mm, bmm, and matmul

徐鸿文
2023-12-01

All three are PyTorch functions for multiplying tensors.

torch.mm

Computes the matrix product of two 2D tensors. (Note that both arguments must be 2D tensors; mm accepts nothing else.)

import torch
x = torch.tensor([[1,2,3]])
y = torch.tensor([[1,2,3,4],
                  [5,6,7,8],
                  [9,10,11,12]])
z = torch.mm(x, y)
print(z)
result:
tensor([[38, 44, 50, 56]])

If you use this instead:

x = torch.tensor([1,2,3])
y = torch.tensor([[1,2,3,4],
                  [5,6,7,8],
                  [9,10,11,12]])
z = torch.mm(x, y)  # raises RuntimeError because x is 1D

x is 1D and y is 2D, so the call raises an error. Both arguments must be 2D.
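
One way to make the failing call work is to give x a leading dimension of size 1 with unsqueeze. The following is a minimal sketch; the names x2d and err are only illustrative.

```python
import torch

x = torch.tensor([1, 2, 3])            # 1D, shape (3,)
y = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])    # 2D, shape (3, 4)

# torch.mm rejects the 1D argument.
try:
    torch.mm(x, y)
    err = None
except RuntimeError as e:
    err = e
print("mm with a 1D argument failed:", err is not None)

# Adding a leading dimension turns x into a 1x3 matrix, which mm accepts.
x2d = x.unsqueeze(0)                   # shape (1, 3)
z = torch.mm(x2d, y)                   # shape (1, 4)
print(z)
```

The result has shape (1, 4); calling z.squeeze(0) gives a plain 1D tensor if that is what you need.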

torch.bmm

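Computes the batched matrix product of two 3D tensors. (Note that both arguments must be 3D tensors; bmm accepts nothing else.)

The first dimension of both tensors must be the same size, because the first dimension is the batch dimension.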

x = torch.tensor([[[1,2,3]]])
y = torch.tensor([[[1,2,3,4],
                  [5,6,7,8],
                  [9,10,11,12]]])
z = torch.bmm(x, y)
print(z)
result:
tensor([[[38, 44, 50, 56]]])
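
The example above has a batch size of 1, which hides what the batch dimension does. As a rough sketch with a batch of 2 (random inputs, illustrative names z_loop and batch_error), bmm gives the same result as calling mm on each batch element, and it refuses mismatched batch sizes:

```python
import torch

x = torch.randn(2, 1, 3)   # batch of 2 matrices, each 1x3
y = torch.randn(2, 3, 4)   # batch of 2 matrices, each 3x4

z = torch.bmm(x, y)        # shape (2, 1, 4)

# The same result, built one batch element at a time with torch.mm.
z_loop = torch.stack([torch.mm(x[i], y[i]) for i in range(x.size(0))])
print(z.shape, torch.allclose(z, z_loop, atol=1e-6))

# Mismatched batch sizes are rejected: bmm does not broadcast.
try:
    torch.bmm(x, torch.randn(3, 3, 4))
    batch_error = None
except RuntimeError as e:
    batch_error = e
print("mismatched batch failed:", batch_error is not None)
```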

torch.matmul

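Computes a dot product or a matrix product, depending on the number of dimensions of its arguments.

The rules are as follows: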

1D * 1D

1D * 1D is a dot product, and the result is a scalar (0D) tensor.

x = torch.tensor([1,1,1])
y = torch.tensor([2,2,2])
z = torch.matmul(x,y)
print(z)
result:
tensor(6)
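
For two 1D tensors, matmul should agree with torch.dot. A small check, reusing the same inputs:

```python
import torch

x = torch.tensor([1, 1, 1])
y = torch.tensor([2, 2, 2])

z = torch.matmul(x, y)

# The result is a 0-dimensional (scalar) tensor, equal to torch.dot(x, y).
print(z.dim(), torch.equal(z, torch.dot(x, y)))
```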

1D * 2D

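1D * 2D is a matrix product: the 1D tensor is treated as a row vector (a dimension of size 1 is prepended for the multiplication and removed from the result).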

x = torch.tensor([1,2,3])
y = torch.tensor([[1,2,3,4],
                  [5,6,7,8],
                  [9,10,11,12]])
z = torch.matmul(x, y)
print(z)
result:
tensor([38, 44, 50, 56])
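
To see how this case relates to mm, the sketch below reproduces the same result by hand: add a leading dimension to x, call mm, then drop that dimension again. The name z_mm is only illustrative.

```python
import torch

x = torch.tensor([1, 2, 3])            # shape (3,)
y = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])    # shape (3, 4)

z = torch.matmul(x, y)                 # shape (4,)

# Manual equivalent: (1, 3) @ (3, 4) -> (1, 4), then squeeze to (4,).
z_mm = torch.mm(x.unsqueeze(0), y).squeeze(0)
print(z.shape, torch.equal(z, z_mm))
```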

2D * 1D

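2D * 1D is a matrix-vector product: each row of the 2D tensor is dotted with the 1D tensor, and the results are collected into a 1D tensor.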

x = torch.tensor([1,2,3])
y = torch.tensor([[1,2,3],
                  [5,6,7],
                  [9,10,11],
                  [1,1,1]])
z = torch.matmul(y, x)
print(z)
result:
tensor([14, 38, 62,  6])
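
The row-by-row dot product can be checked in two ways: against torch.mv, which is PyTorch's matrix-vector product, and against an explicit elementwise multiply followed by a sum over each row. This is a sketch with illustrative names z_mv and z_rows.

```python
import torch

x = torch.tensor([1, 2, 3])            # shape (3,)
y = torch.tensor([[1, 2, 3],
                  [5, 6, 7],
                  [9, 10, 11],
                  [1, 1, 1]])          # shape (4, 3)

z = torch.matmul(y, x)                 # shape (4,)

# Matrix-vector product.
z_mv = torch.mv(y, x)

# Explicit form: multiply each row by x elementwise, then sum along the row.
z_rows = (y * x).sum(dim=1)

print(torch.equal(z, z_mv), torch.equal(z, z_rows))
```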

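1D * ND

The leading dimensions of the ND tensor (all but the last two) are batch dimensions, so the computation is batched: the 1D tensor is multiplied, as a row vector, with the matrix formed by the last two dimensions of each batch element.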

import torch
x = torch.randn(2, 3, 4)
y = torch.randn(3)
print(torch.matmul(y, x),'\n',torch.matmul(y, x).size()) #1D*3D
 
output:
tensor([[-0.9747, -0.6660, -1.1704, -1.0522],
        [ 0.0901, -1.5353,  1.5601, -0.0252]]) 
 torch.Size([2, 4])
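
Because the inputs are random, the numbers above will differ from run to run; only the shape is fixed. A sketch that checks the batched behaviour regardless of the values compares the result with a loop over the batch dimension (z_loop is an illustrative name):

```python
import torch

x = torch.randn(2, 3, 4)   # batch of 2 matrices, each 3x4
y = torch.randn(3)         # 1D tensor

z = torch.matmul(y, x)     # shape (2, 4)

# Apply the 1D * 2D rule to each batch element and stack the results.
z_loop = torch.stack([torch.matmul(y, x[i]) for i in range(x.size(0))])
print(z.shape, torch.allclose(z, z_loop, atol=1e-6))
```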

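ND * 1D

The leading dimensions of the ND tensor (all but the last two) are batch dimensions, so the computation is batched: each row of the matrix formed by the last two dimensions is dotted with the 1D tensor.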

import torch
x = torch.randn(2, 3, 4)
y = torch.randn(4)
 
print(torch.matmul(x, y),'\n',torch.matmul(x, y).size()) # 3D*1D
 
output:
tensor([[ 0.6217, -0.1259, -0.2377],
        [ 0.6874,  0.0733,  0.1793]]) 
 torch.Size([2, 3])
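
As in the previous case, the printed values depend on the random inputs. A value-independent check, as a sketch, writes out the row-wise dot product explicitly and compares it with matmul (z_rows is an illustrative name):

```python
import torch

x = torch.randn(2, 3, 4)   # batch of 2 matrices, each 3x4
y = torch.randn(4)         # 1D tensor

z = torch.matmul(x, y)     # shape (2, 3)

# Each row of length 4 is multiplied elementwise by y and summed.
z_rows = (x * y).sum(dim=-1)
print(z.shape, torch.allclose(z, z_rows, atol=1e-5))
```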

3D * 3D

When both tensors are 3D with the same batch size, this is the same computation as bmm.

import torch
x = torch.randn(2,2,4)
y = torch.randn(2,4,5)
 
print(torch.matmul(x, y).size(),'\n',torch.bmm(x, y).size())
print(torch.equal(torch.matmul(x,y),torch.bmm(x,y)))
 
output:
torch.Size([2, 2, 5]) 
 torch.Size([2, 2, 5])
True
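
One practical difference remains between the two functions: matmul broadcasts batch dimensions, while bmm requires two 3D tensors with equal batch sizes. The sketch below multiplies a 3D tensor by a 2D tensor; matmul accepts it directly, whereas bmm needs the 2D tensor expanded to 3D first. The names z_bmm and bmm_error are only illustrative.

```python
import torch

x = torch.randn(2, 2, 4)   # batch of 2 matrices, each 2x4
y = torch.randn(4, 5)      # a single 4x5 matrix, no batch dimension

# matmul broadcasts y across the batch dimension of x.
z = torch.matmul(x, y)     # shape (2, 2, 5)

# bmm needs an explicit 3D tensor with a matching batch size.
z_bmm = torch.bmm(x, y.unsqueeze(0).expand(2, -1, -1))
print(z.shape, torch.allclose(z, z_bmm, atol=1e-6))

# Passing the 2D tensor to bmm directly fails.
try:
    torch.bmm(x, y)
    bmm_error = None
except RuntimeError as e:
    bmm_error = e
print("bmm with a 2D argument failed:", bmm_error is not None)
```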