
Understanding dim in PyTorch

霍襦宗
2023-12-01

Definition of dim

dim stands for dimension.

import torch

x = torch.randn(2, 3, 3)  # random tensor of shape (2, 3, 3)

print(x)
print(x.size())  # shape of x
print(x.dim())   # number of dimensions

Output:

tensor([[[-1.6943, -2.1487,  1.2332],
         [-0.2261, -0.1596,  1.5513],
         [ 2.0383, -0.6982, -2.1481]],

        [[ 0.4201, -2.7373,  0.2424],
         [-1.1152,  1.3682, -1.8322],
         [ 0.1957, -0.2920,  0.1845]]])
torch.Size([2, 3, 3])
3

This is hard to read as printed, but it becomes clearer if we lay out the brackets:


[
    [
        [-1.6943, -2.1487,  1.2332],

        [-0.2261, -0.1596,  1.5513],

        [ 2.0383, -0.6982, -2.1481]
    ],

    [
        [ 0.4201, -2.7373,  0.2424],

        [-1.1152,  1.3682, -1.8322],

        [ 0.1957, -0.2920,  0.1845]
    ]
]

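  • The shape (2, 3, 3) is now obvious: the sizes are listed from the outermost bracket level to the innermost
  • x.dim() = 3 means x has three dimensions, indexed dim = 0, 1, 2:
    • dim 0 corresponds to the first entry of x.size() = (2, 3, 3), i.e. 2
    • dim 1 corresponds to the second entry, i.e. 3
    • dim 2 corresponds to the third entry, i.e. 3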
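The following is a minimal sketch that uses only the standard torch API. It reads the size along each dim index, which shows that dim is simply an index into x.size(). The nested list `nested` is a made-up example. It shows that the outermost level of a nested Python list becomes dim 0.

```python
import torch

x = torch.randn(2, 3, 3)

# x.size(d) returns the size of dimension d, so dim indexes into x.size()
sizes = [x.size(d) for d in range(x.dim())]
print(sizes)  # [2, 3, 3]

# Building a tensor from nested lists: the outermost list is dim 0,
# the innermost list is the last dim
nested = [[[1, 2, 3], [4, 5, 6]]]  # 1 block, 2 rows, 3 numbers per row
t = torch.tensor(nested)
print(t.size())  # torch.Size([1, 2, 3])
```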

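Understanding dim

With dim = 0, we are talking about the elements x[i] taken along the first dimension, each of which is a (3, 3) tensor.
That is: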

import torch

x = torch.randn(2, 3, 3)
print(x)

# iterating over a tensor walks along dim 0
for i in x:
    print(i)
    print(i.size())

Output:

tensor([[[-1.4251, -0.8321,  1.0230],
         [ 0.2008,  0.5929, -0.7696],
         [-0.3721, -1.0837, -0.6642]],

        [[-0.5337,  0.7808,  0.4419],
         [-0.4683,  0.3847,  0.0747],
         [ 1.0156, -0.4933,  1.5340]]])


tensor(
    [
        [-1.4251, -0.8321,  1.0230],
        [ 0.2008,  0.5929, -0.7696],
        [-0.3721, -1.0837, -0.6642]
    ]
)
torch.Size([3, 3])

tensor(
    [
        [-0.5337,  0.7808,  0.4419],
        [-0.4683,  0.3847,  0.0747],
        [ 1.0156, -0.4933,  1.5340]
    ]
)
torch.Size([3, 3])

So dim = 0 amounts to stripping the dim = 0 level off x, which leaves two tensors of size (3, 3).
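Here is the same idea in code, as a sketch. Iterating over x yields exactly x.select(0, i). More generally, Tensor.select(d, i) drops dimension d, so selecting along any dim removes that entry from the shape.

```python
import torch

x = torch.randn(2, 3, 3)

# Iterating over a tensor walks along dim 0; each item equals x.select(0, i)
for i, item in enumerate(x):
    assert torch.equal(item, x.select(0, i))

# select(d, 0) takes index 0 along dim d, removing that dimension from the shape
shapes = {d: tuple(x.select(d, 0).shape) for d in range(x.dim())}
print(shapes)  # {0: (3, 3), 1: (2, 3), 2: (2, 3)}
```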

Verification

  • torch.argmax(tensor)
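    Returns the index of the largest value in the tensor. Without dim, the index refers to the flattened tensor. With dim, the same-shaped slices along that dimension are compared element by element.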
    Example:
    >>> x = torch.tensor([1, 5, 8, 4, 6])
    >>> torch.argmax(x)
    tensor(2)
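Below is a small sketch of the two argmax modes on a hand-picked 2x3 tensor. The values are chosen only for illustration. Without dim, the result indexes the flattened tensor. With dim, the slices along that dimension are compared.

```python
import torch

x = torch.tensor([[1, 9, 3],
                  [7, 2, 8]])

# No dim: index into x.flatten() = [1, 9, 3, 7, 2, 8]
flat = torch.argmax(x)
row, col = divmod(flat.item(), x.size(1))  # convert back to (row, col)
print(flat, row, col)  # tensor(1) 0 1

# dim=0: compare row 0 with row 1, column by column
print(torch.argmax(x, dim=0))  # tensor([1, 0, 1])

# dim=1: compare the entries within each row
print(torch.argmax(x, dim=1))  # tensor([1, 2])
```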
import torch

x = torch.randn(2, 3, 3)

print(x)

print('='*50, end='\n\n')
for i in x:
    print(i)
    print(i.size())

print('='*50, end='\n\n')

print(x.size())
print(x.dim())

print('='*50, end='\n\n')

y = torch.argmax(x, dim=0)

print(y)
print(y.size())

Output:

tensor(
    [
        [
            [-1.3918,  0.0620, -0.4111],
            [ 1.9623, -1.3399, -0.4673],
            [-0.0185, -1.9024,  0.1340]
        ],

        [
            [ 0.7135, -0.5290, -0.7656],
            [ 0.2642,  0.5956, -0.0718],
            [-0.7465, -0.8098, -0.0874]
        ]
    ]
)
==================================================

tensor([[-1.3918,  0.0620, -0.4111],
        [ 1.9623, -1.3399, -0.4673],
        [-0.0185, -1.9024,  0.1340]])
torch.Size([3, 3])

tensor([[ 0.7135, -0.5290, -0.7656],
        [ 0.2642,  0.5956, -0.0718],
        [-0.7465, -0.8098, -0.0874]])
torch.Size([3, 3])
==================================================

torch.Size([2, 3, 3])
3
==================================================

tensor([[1, 0, 0],
        [0, 1, 1],
        [0, 1, 0]])
torch.Size([3, 3])
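The output above depends on randn. To check it deterministically, this sketch rebuilds x from the printed values and confirms y. The values are rounded to 4 decimals, so this is a copy of the printout rather than the original random tensor. With only two slices along dim 0, argmax returns 1 exactly where x[1] is larger than x[0], because no ties occur in these values.

```python
import torch

# Values copied from the printed output above
x = torch.tensor([
    [[-1.3918,  0.0620, -0.4111],
     [ 1.9623, -1.3399, -0.4673],
     [-0.0185, -1.9024,  0.1340]],
    [[ 0.7135, -0.5290, -0.7656],
     [ 0.2642,  0.5956, -0.0718],
     [-0.7465, -0.8098, -0.0874]],
])

y = torch.argmax(x, dim=0)
print(y)
# tensor([[1, 0, 0],
#         [0, 1, 1],
#         [0, 1, 0]])

# With two slices along dim 0, argmax is 1 wherever x[1] beats x[0]
same = torch.equal(y, (x[1] > x[0]).long())
print(same)  # True
```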
  • Why is y[0] = [1, 0, 0]?
    There are two possible interpretations:

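    1. It compares [-1.3918, 0.0620, -0.4111] with [ 0.7135, -0.5290, -0.7656] element by element:
      [-1.3918, 0.7135]: 0.7135 is larger, so it returns 1
      [0.0620, -0.5290]: 0.0620 is larger, so it returns 0
      [-0.4111, -0.7656]: -0.4111 is larger, so it returns 0
    2. It compares each column within x[i], which would give a 2x3 output. For example, x[0]: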
        [-1.3918,  0.0620, -0.4111],
        [ 1.9623, -1.3399, -0.4673],
        [-0.0185, -1.9024,  0.1340]
    

    Taking argmax over each column of x[0] gives [1, 0, 2], which does not match y[0] = [1, 0, 0].

  • If we strip off the dim = 0 level, x' is:

    [
        [-1.3918,  0.0620, -0.4111],
        [ 1.9623, -1.3399, -0.4673],
        [-0.0185, -1.9024,  0.1340]
    ],
    
    [
        [ 0.7135, -0.5290, -0.7656],
        [ 0.2642,  0.5956, -0.0718],
        [-0.7465, -0.8098, -0.0874]
    ]
    

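    That is, two tensors of size (3, 3). This shows why the second interpretation is wrong:
    argmax compares the two tensors against each other, whereas interpretation 2 compares values inside each tensor separately and then merges the two results.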

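    • Summary: argmax(x, dim=d) compares, position by position, the slices obtained by indexing x along dimension d. The result's size is x.size() with dimension d removed, here (2, 3, 3) -> (3, 3).
  • With only two dimensions, this may be easier to see: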

    import torch
    
    x = torch.randn(2,3)
    
    print(x)
    
    y = torch.argmax(x, dim=0)
    
    print(y)
    print(y.size())
    

    Output:

    tensor(
        [
            [ 0.0251, -0.3640,  0.1965],
            [ 0.6902,  0.9846,  0.2035]
        ]
    )
    
    tensor([1, 1, 1])
    torch.Size([3])
    

    Removing dim = 0, the comparison is between [ 0.0251, -0.3640, 0.1965] and [ 0.6902, 0.9846, 0.2035]
    size (2, 3) -> (3,)

  • Now return to the three-dimensional example above:

    [                                           
        [-1.3918,  0.0620, -0.4111],
        [ 1.9623, -1.3399, -0.4673],
        [-0.0185, -1.9024,  0.1340]
    ],
    [
        [ 0.7135, -0.5290, -0.7656],
        [ 0.2642,  0.5956, -0.0718],
        [-0.7465, -0.8098, -0.0874]
    ]
    

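    For the first row, comparing the two amounts to calling torch.argmax(..., dim=0) on the tensor below. This gives [1, 0, 0], i.e. y[0], and the other rows work the same way: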

    [
        [-1.3918,  0.0620, -0.4111],
        [ 0.7135, -0.5290, -0.7656]
    ]
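The summary rule says the result size is x.size() with dimension dim removed. A short sketch can check this for every dim of a random tensor. It also confirms the row-by-row view: row i of argmax(x, dim=0) equals the dim=0 argmax of the stacked rows x[0][i] and x[1][i].

```python
import torch

x = torch.randn(2, 3, 3)

for d in range(x.dim()):
    y = torch.argmax(x, dim=d)
    expected = list(x.shape)
    del expected[d]  # argmax removes dimension d
    assert list(y.shape) == expected
    print(d, tuple(y.shape))
# 0 (3, 3)
# 1 (2, 3)
# 2 (2, 3)

# Row i of argmax(x, dim=0) equals argmax over the stacked rows x[0][i], x[1][i]
y0 = torch.argmax(x, dim=0)
rows_match = all(
    torch.equal(y0[i], torch.argmax(torch.stack([x[0][i], x[1][i]]), dim=0))
    for i in range(x.size(1))
)
print(rows_match)  # True
```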
    