当前位置: 首页 > 工具软件 > Axis Mundi > 使用案例 >

python中numpy的axis和torch的dim

吕皓
2023-12-01

举一个例子:

import torch
A = torch.rand((3,4))
print(A)
#tensor([[0.3602, 0.2583, 0.1758, 0.3575],
#        [0.9582, 0.2092, 0.6829, 0.8663],
#        [0.3922, 0.1360, 0.3733, 0.3477]])
z = A.sum(dim=1, keepdim=True)
print(z)
#tensor([[1.1518],
#        [2.7166],
#        [1.2492]])
y = A.sum(dim=1)
print(y)
#tensor([1.1518, 2.7166, 1.2492])

我们看keepdim=True的情形(此时最清楚,没有做置换强行改变行列下标),dim=1就相当于sum是在dim=1上做,于是dim=1的下标就没有了,只剩下dim=0的下标。即,z的元素下标为 [0,0],[1,0],[2,0],dim=1的下标全部为0。

从这个角度看,不用画图,只考虑运算。另外,也可以认为sum(dim=1),就是 ∑ d i m = 1 \sum\limits_{dim=1} dim=1

 类似资料: