当前位置: 首页 > 编程笔记 >

浅谈Pytorch中的torch.gather函数的含义

柴华灿
2023-03-14
本文向大家介绍浅谈Pytorch中的torch.gather函数的含义,包括了浅谈Pytorch中的torch.gather函数的含义的使用技巧和注意事项,需要的朋友参考一下

pytorch中的gather函数

pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验。

立个flag开始学习pytorch,新开一个分类整理学习pytorch中的一些踩到的泥坑。

今天刚开始接触,读了一下documentation,写一个一开始每太搞懂的函数gather

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)

观察它的输出结果:

 1 2 3
 4 5 6
[torch.FloatTensor of size 2x3]


 1 2
 6 4
[torch.FloatTensor of size 2x2]


 1 5 6
 1 2 3
[torch.FloatTensor of size 2x3]

这里是官方文档的解释

torch.gather(input, dim, index, out=None) → Tensor

 Gathers values along an axis specified by dim.

 For a 3-D tensor the output is specified by:

 out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
 out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
 out[i][j][k] = input[i][j][index[i][j][k]] # dim=2

 Parameters: 

  input (Tensor) – The source tensor
  dim (int) – The axis along which to index
  index (LongTensor) – The indices of elements to gather
  out (Tensor, optional) – Destination tensor

 Example:

 >>> t = torch.Tensor([[1,2],[3,4]])
 >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
  1 1
  4 3
 [torch.FloatTensor of size 2x2]

可以看出,gather的作用是这样的,index实际上是索引,具体是行还是列的索引要看前面dim 的指定,比如对于我们的栗子,【1,2,3;4,5,6,】,指定dim=1,也就是横向,那么索引就是列号。index的大小就是输出的大小,所以比如index是【1,0;0,0】,那么看index第一行,1列指的是2, 0列指的是1,同理,第二行为4,4 。这样就输入为【2,1;4,4】,参考这样的解释看上面的输出结果,即可理解gather的含义。

gather在one-hot为输出的多分类问题中,可以把最大值坐标作为index传进去,然后提取到每一行的正确预测结果,这也是gather可能的一个作用。

以上这篇浅谈Pytorch中的torch.gather函数的含义就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持小牛知识库。

 类似资料:
  • 本文向大家介绍浅谈Shell中的函数,包括了浅谈Shell中的函数的使用技巧和注意事项,需要的朋友参考一下 函数可以让我们将一个复杂功能划分成若干模块,让程序结构更加清晰,代码重复利用率更高。像其他编程语言一样,Shell也支持函数。Shell函数必须先定义后使用。 1.Shell函数的定义格式 可以带function关键字使用function fun_name()来定义,也可以直接给出函数名fu

  • 本文向大家介绍浅谈python中的getattr函数 hasattr函数,包括了浅谈python中的getattr函数 hasattr函数的使用技巧和注意事项,需要的朋友参考一下 hasattr(object, name) 作用:判断对象object是否包含名为name的特性(hasattr是通过调用getattr(ojbect, name)是否抛出异常来实现的)。 示例: getattr(obj

  • 本文向大家介绍浅谈Python中函数的参数传递,包括了浅谈Python中函数的参数传递的使用技巧和注意事项,需要的朋友参考一下 1.普通的参数传递 2.参数个数可选,参数有默认值的传递 参数sep的缺省值是'_' 如果这个参数不给定值就会使用缺省值 如果给定 则使用给定的值 需要注意 如果一个参数是可选参数 那么它后面所有的参数都应该是可选的,另外 可选参数的顺序颠倒依然可以正确的给对应的参数赋值

  • 本文向大家介绍浅谈function(函数)中的动态参数,包括了浅谈function(函数)中的动态参数的使用技巧和注意事项,需要的朋友参考一下 我们可向函数传递动态参数,*args,**kwargs,首先我们来看*args,示例如下: 1.show(*args) 首先我们定义了一个函数,函数show(*args)里面的*args可以接收动态参数,这里我们接收一个元组形式的参数,我们可以向show(

  • 本文向大家介绍浅谈Mysql中类似于nvl()函数的ifnull()函数,包括了浅谈Mysql中类似于nvl()函数的ifnull()函数的使用技巧和注意事项,需要的朋友参考一下 IFNULL(expr1,expr2) 如果expr1不是NULL,IFNULL()返回expr1,否则它返回expr2。IFNULL()返回一个数字或字符串值,取决于它被使用的上下文环境。 如果expr1是TRUE(e

  • 本文向大家介绍浅谈pytorch torch.backends.cudnn设置作用,包括了浅谈pytorch torch.backends.cudnn设置作用的使用技巧和注意事项,需要的朋友参考一下 cuDNN使用非确定性算法,并且可以使用torch.backends.cudnn.enabled = False来进行禁用 如果设置为torch.backends.cudnn.enabled =Tru