官方文档地址:https://pytorch.org/docs/stable/generated/torch.bmm.html?highlight=bmm#torch.bmm
形式:torch.bmm(input, mat2, *, out=None) → Tensor
作用:矩阵的批量相乘,支持TensorFloat32数据的操作。
要求:input 和 mat2 必须是 3-D 张量,每个张量都包含相同数量的矩阵。如果input的维度是
(
b
×
n
×
m
)
(b\times n\times m)
(b×n×m),mat2维度是
(
b
×
m
×
p
)
(b\times m \times p)
(b×m×p),那么返回的结果out就是:
(
b
×
n
×
p
)
(b\times n \times p)
(b×n×p),那么有:
o
u
t
i
=
i
n
p
u
t
i
@
m
a
t
2
i
out_i = input_i @ mat2_i
outi=inputi@mat2i
注:该操作不支持广播。支持广播的矩阵相乘:torch.matmul()
使用案例:
input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size() # torch.Size([10, 3, 5])