dim
dim denotes a tensor dimension.
import torch

x = torch.randn(2, 3, 3)
print(x)
print(x.size())
print(x.dim())
Output:
tensor([[[-1.6943, -2.1487,  1.2332],
         [-0.2261, -0.1596,  1.5513],
         [ 2.0383, -0.6982, -2.1481]],

        [[ 0.4201, -2.7373,  0.2424],
         [-1.1152,  1.3682, -1.8322],
         [ 0.1957, -0.2920,  0.1845]]])
torch.Size([2, 3, 3])
3
This is not very easy to read, but if we format the brackets:
[
    [
        [-1.6943, -2.1487,  1.2332],
        [-0.2261, -0.1596,  1.5513],
        [ 2.0383, -0.6982, -2.1481]
    ],
    [
        [ 0.4201, -2.7373,  0.2424],
        [-1.1152,  1.3682, -1.8322],
        [ 0.1957, -0.2920,  0.1845]
    ]
]
(2, 3, 3)
the structure becomes obvious: we read the brackets from the outside in. x.dim() = 3 means that x has three dimensions, dim = (0, 1, 2):

dim = 0 corresponds to the first entry of x.size(), the 2 in (2, 3, 3)
dim = 1 corresponds to the second entry, the first 3 in (2, 3, 3)
dim = 2 corresponds to the third entry, the second 3 in (2, 3, 3)

When dim = 0, we are referring to the (3, 3) blocks of x, that is:
import torch

x = torch.randn(2, 3, 3)
print(x)
for i in x:
    print(i)
    print(i.size())
Output:
tensor([[[-1.4251, -0.8321,  1.0230],
         [ 0.2008,  0.5929, -0.7696],
         [-0.3721, -1.0837, -0.6642]],

        [[-0.5337,  0.7808,  0.4419],
         [-0.4683,  0.3847,  0.0747],
         [ 1.0156, -0.4933,  1.5340]]])
tensor(
    [
        [-1.4251, -0.8321,  1.0230],
        [ 0.2008,  0.5929, -0.7696],
        [-0.3721, -1.0837, -0.6642]
    ]
)
torch.Size([3, 3])
tensor(
    [
        [-0.5337,  0.7808,  0.4419],
        [-0.4683,  0.3847,  0.0747],
        [ 1.0156, -0.4933,  1.5340]
    ]
)
torch.Size([3, 3])
So when dim = 0, the operation effectively removes the dim = 0 dimension of x.
import torch

x = torch.randn(2, 3, 3)
print(x)
print('=' * 50, end='\n\n')

for i in x:
    print(i)
    print(i.size())
print('=' * 50, end='\n\n')

print(x.size())
print(x.dim())
print('=' * 50, end='\n\n')

y = torch.argmax(x, dim=0)
print(y)
print(y.size())
Output:
tensor(
    [
        [
            [-1.3918,  0.0620, -0.4111],
            [ 1.9623, -1.3399, -0.4673],
            [-0.0185, -1.9024,  0.1340]
        ],
        [
            [ 0.7135, -0.5290, -0.7656],
            [ 0.2642,  0.5956, -0.0718],
            [-0.7465, -0.8098, -0.0874]
        ]
    ]
)
==================================================
tensor([[-1.3918,  0.0620, -0.4111],
        [ 1.9623, -1.3399, -0.4673],
        [-0.0185, -1.9024,  0.1340]])
torch.Size([3, 3])
tensor([[ 0.7135, -0.5290, -0.7656],
        [ 0.2642,  0.5956, -0.0718],
        [-0.7465, -0.8098, -0.0874]])
torch.Size([3, 3])
==================================================
torch.Size([2, 3, 3])
3
==================================================
tensor([[1, 0, 0],
        [0, 1, 1],
        [0, 1, 0]])
torch.Size([3, 3])
Let's analyze why y[0] = [1, 0, 0]. There are two possible interpretations:

1. Compare [-1.3918, 0.0620, -0.4111] (the first row of x[0]) with [ 0.7135, -0.5290, -0.7656] (the first row of x[1]) element by element: 0.7135 > -1.3918 gives index 1, 0.0620 > -0.5290 gives index 0, and -0.4111 > -0.7656 gives index 0, i.e. [1, 0, 0].

2. Compare the columns within each x[i], which would give a 2x3 output. For example, for x[0]:
   [-1.3918,  0.0620, -0.4111],
   [ 1.9623, -1.3399, -0.4673],
   [-0.0185, -1.9024,  0.1340]
   comparing within each column, torch.argmax would give [1, 0, 2].
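To see what the second interpretation would actually compute, here is a small sketch with x[0] from the run above hard-coded (the variable name x0 is just for illustration):

import torch

# x[0] from the example run above, hard-coded so the result is reproducible.
x0 = torch.tensor([[-1.3918,  0.0620, -0.4111],
                   [ 1.9623, -1.3399, -0.4673],
                   [-0.0185, -1.9024,  0.1340]])

# Interpretation 2: a per-column argmax *inside* x[0].
print(torch.argmax(x0, dim=0))  # tensor([1, 0, 2]) -- not the [1, 0, 0] we actually got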
If we follow the "remove dim = 0" view, x' is:

[
    [-1.3918,  0.0620, -0.4111],
    [ 1.9623, -1.3399, -0.4673],
    [-0.0185, -1.9024,  0.1340]
],
[
    [ 0.7135, -0.5290, -0.7656],
    [ 0.2642,  0.5956, -0.0718],
    [-0.7465, -0.8098, -0.0874]
]

that is, two tensors of size (3, 3). Seen this way, it becomes clear why the second interpretation is wrong: the comparison happens between the two tensors, whereas the second interpretation compares within each tensor separately and only then merges the two results.
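To confirm the first interpretation, here is a small sketch (with the values from the run above hard-coded) checking that torch.argmax over dim = 0 is exactly the element-wise comparison between the two (3, 3) slices:

import torch

# The (2, 3, 3) tensor from the example run above, hard-coded.
x = torch.tensor([[[-1.3918,  0.0620, -0.4111],
                   [ 1.9623, -1.3399, -0.4673],
                   [-0.0185, -1.9024,  0.1340]],
                  [[ 0.7135, -0.5290, -0.7656],
                   [ 0.2642,  0.5956, -0.0718],
                   [-0.7465, -0.8098, -0.0874]]])

y = torch.argmax(x, dim=0)  # [[1, 0, 0], [0, 1, 1], [0, 1, 0]]

# With only two slices along dim 0, the argmax is 1 wherever x[1] is larger
# and 0 wherever x[0] is larger (there are no ties in this data).
print(torch.equal(y, (x[1] > x[0]).long()))  # True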
(2, 3, 3) -> (3, 3): the size of the result is the original size with the specified dim removed.
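As a quick check of that rule, a minimal sketch (only the shapes matter here, not the values):

import torch

x = torch.randn(2, 3, 3)

# argmax removes the dimension it reduces over.
print(torch.argmax(x, dim=0).size())  # torch.Size([3, 3])
print(torch.argmax(x, dim=1).size())  # torch.Size([2, 3])
print(torch.argmax(x, dim=2).size())  # torch.Size([2, 3])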
It may be easier to understand with only two dimensions:
import torch

x = torch.randn(2, 3)
print(x)

y = torch.argmax(x, dim=0)
print(y)
print(y.size())
Output:
tensor(
    [
        [ 0.0251, -0.3640,  0.1965],
        [ 0.6902,  0.9846,  0.2035]
    ]
)
tensor([1, 1, 1])
torch.Size([3])
With dim = 0 removed, the comparison is between [ 0.0251, -0.3640, 0.1965] and [ 0.6902, 0.9846, 0.2035]: size (2, 3) -> (3,).
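Written out explicitly, with the values from the run above hard-coded (a small sketch):

import torch

# The (2, 3) tensor from the example run above, hard-coded.
x = torch.tensor([[0.0251, -0.3640, 0.1965],
                  [0.6902,  0.9846, 0.2035]])

# argmax over dim 0 picks, for each column, the row index holding the larger value.
print(torch.argmax(x, dim=0))  # tensor([1, 1, 1])
print((x[1] > x[0]).long())    # the same comparison made explicit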
Now let's come back to the three-dimensional example above:
[
    [-1.3918,  0.0620, -0.4111],
    [ 1.9623, -1.3399, -0.4673],
    [-0.0185, -1.9024,  0.1340]
],
[
    [ 0.7135, -0.5290, -0.7656],
    [ 0.2642,  0.5956, -0.0718],
    [-0.7465, -0.8098, -0.0874]
]
Comparing the two (to produce the first row of the result) is equivalent to running torch.argmax() along dim = 0 on the following tensor:
[
    [-1.3918,  0.0620, -0.4111],
    [ 0.7135, -0.5290, -0.7656]
]
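As a sanity check (a sketch with the example values hard-coded), y[0] can be reproduced by slicing out those two rows and taking the argmax over dim 0:

import torch

x = torch.tensor([[[-1.3918,  0.0620, -0.4111],
                   [ 1.9623, -1.3399, -0.4673],
                   [-0.0185, -1.9024,  0.1340]],
                  [[ 0.7135, -0.5290, -0.7656],
                   [ 0.2642,  0.5956, -0.0718],
                   [-0.7465, -0.8098, -0.0874]]])

# x[:, 0, :] stacks the first row of x[0] and the first row of x[1]
# into exactly the (2, 3) tensor shown above.
pair = x[:, 0, :]
print(torch.argmax(pair, dim=0))   # tensor([1, 0, 0])
print(torch.argmax(x, dim=0)[0])   # tensor([1, 0, 0]) -- the same as y[0]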