Pytorch 常用函数 cheatsheet (不定期更新)

万俟渝
2023-12-01

Pytorch 常用函数 cheatsheet (不定期更新)

为啥我会做这么一个cheatsheet?

这原本是我自己复现代码时候做的笔记,原本只是想要keep it personal,在写代码的时候方便查阅。但是最近在读James Clear 的《Atomic Habits》深受启发,也开始希望自己平日里的一点点积累能够帮助到和我一样在学习之路上苦修的萌新们。

日后这个cheatsheet会随着我自己的学习进程不定期更新,我也会在适当时候给这些函数做一个归纳和分类,以便逐渐改善阅读体验。

pytorch 常用函数

torch.where(condition, x, y)
  • 如果满足condition,返回x中的元素,反之返回y中的元素
torch.clamp(x : torch.tensor, min, max)
  • 根据条件修改x中元素的值,如果小于min则改为min,大于max则改为max
torch.split(x : torch.tensor, split_size_or_sections : int or list, dim : int)
  • 这个函数的功能就是将一个tensor分块,分块后的各个子tensor保存在一个tuple里面返回。
  • 理解这个函数需要稍微花费一点功夫。这里面最关键的两个参数分别是:split_size_or_sections (暂且缩写成s) 和 dim。可以这么理解:
    • 当第二个参数为int时,该函数把输入x切分成dim维上大小为s的多个子tensor。举个栗子:
      >>> import torch
      >>> x = torch.rand(4,8,6)
      >>> y = torch.split(x,2,dim = 1)
      >>> for i in y:
      ...     print(y.shape)
      torch.Size([4, 2, 6])
      torch.Size([4, 2, 6])
      torch.Size([4, 2, 6])
      torch.Size([4, 2, 6])
      
      看见没,原来shape = [4,8,6] 的输入被切分成了4个shape = [4,2,6]的子tensor。
    • 当第二个参数为list时,该函数按照list每个元素指定的大小在dim维上分割输入数组。
      >>> z = torch.split(x,[1,2,5],dim=1)
      >>> for i in z:
      ...     print(i.shape)
      ...
      torch.Size([4, 1, 6])
      torch.Size([4, 2, 6])
      torch.Size([4, 5, 6])
      >>>
      
 类似资料: