当前位置: 首页 > 工具软件 > arg > 使用案例 >

torch.argmax()函数

文华美
2023-12-01

argmax函数:torch.argmax(input, dim=None, keepdim=False) 返回指定维度最大值的序号,dim给定的定义是:the demention to reduce,就是把dim这个维度,变成这个维度的最大值的index。

1)dim表示不同维度。特别的在dim=0表示二维矩阵中的列,dim=1在二维矩阵中的行。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:(d0,d1,…,dn−1) ,那么dim=0就表示对应到d0 也就是第一个维度,dim=1表示对应到也就是第二个维度,以此类推。

2)知道dim的值是什么意思还不行,还要知道函数中这个dim给出来会发生什么。

例子一:二维数组

import torch

x = torch.randn(2, 4)
print(x)
'''
tensor([[ 1.2864, -0.5955,  1.5042,  0.5398],
        [-1.2048,  0.5106, -2.0288,  1.4782]])
'''

# y0表示矩阵dim=0维度上(每一列)张量最大值的索引
y0 = torch.argmax(x, dim=0)
print(y0)
'''
tensor([0, 1, 0, 1])
'''

# y1表示矩阵dim=1维度上(每一行)张量最大值的索引
y1 = torch.argmax(x, dim=1)
print(y1)
'''
tensor([2, 3])
'''

例子二:三维数组

x = torch.randn(2, 4, 5)
print(x)
'''
tensor([[[-1.2204, -0.6428, -0.2278,  0.5589,  1.1589],
         [ 0.4235,  1.9663,  0.5055, -1.3472,  1.3523],
         [ 1.4220,  0.7886, -1.0821,  0.6268, -0.9465],
         [-0.3950,  1.3275,  0.3369,  1.0224, -0.9944]],

        [[ 0.6024, -0.2604, -0.8631,  0.8113, -0.3140],
         [ 0.3487, -0.1941, -0.3955, -0.1719, -1.3734],
         [ 0.2467, -0.4268, -1.3428,  0.7346,  1.0932],
         [-0.5799,  0.0976, -1.9403, -0.2643,  0.7657]]])
'''

# dim=0,将第一个维度消除,也就是将两个[4*5]矩阵只保留一个,因此要在上下两个[3*4]的矩阵分别在对应位置上比较
y0 = torch.argmax(x, dim=0)
print(y0)
'''
tensor([[1, 1, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1]])
'''

# dim=1,将第二个维度消除,也就是将四个[2*5]矩阵只保留一个
y1 = torch.argmax(x, dim=1)
print(y1)
'''
tensor([[2, 1, 1, 3, 1],
        [0, 3, 1, 0, 2]])
'''

y2 = torch.argmax(x, dim=2)
print(y2)
'''
tensor([[4, 1, 0, 1],
        [3, 0, 4, 4]])
'''

 类似资料: