【Pytorch】torch.bmm()方法使用

邹华皓
2023-12-01

官方文档地址:https://pytorch.org/docs/stable/generated/torch.bmm.html?highlight=bmm#torch.bmm

形式:torch.bmm(input, mat2, *, out=None) → Tensor

作用:矩阵的批量相乘,支持TensorFloat32数据的操作。

要求:input 和 mat2 必须是 3-D 张量,每个张量都包含相同数量的矩阵。如果input的维度是 ( b × n × m ) (b\times n\times m) (b×n×m),mat2维度是 ( b × m × p ) (b\times m \times p) (b×m×p),那么返回的结果out就是: ( b × n × p ) (b\times n \times p) (b×n×p),那么有:
o u t i = i n p u t i @ m a t 2 i out_i = input_i @ mat2_i outi=inputi@mat2i

注:该操作不支持广播。支持广播的矩阵相乘:torch.matmul()

使用案例:

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size()  # torch.Size([10, 3, 5])
 类似资料: