当前位置: 首页 > 编程笔记 >

浅析PyTorch中nn.Linear的使用

寇开畅
2023-03-14
本文向大家介绍浅析PyTorch中nn.Linear的使用,包括了浅析PyTorch中nn.Linear的使用的使用技巧和注意事项,需要的朋友参考一下

查看源码

Linear 的初始化部分:

class Linear(Module):
 ...
 __constants__ = ['bias']
 
 def __init__(self, in_features, out_features, bias=True):
   super(Linear, self).__init__()
   self.in_features = in_features
   self.out_features = out_features
   self.weight = Parameter(torch.Tensor(out_features, in_features))
   if bias:
     self.bias = Parameter(torch.Tensor(out_features))
   else:
     self.register_parameter('bias', None)
   self.reset_parameters()
 ...
 

需要实现的内容:

计算步骤:

@weak_script_method
  def forward(self, input):
    return F.linear(input, self.weight, self.bias)

返回的是:input * weight + bias

对于 weight

weight: the learnable weights of the module of shape
  :math:`(\text{out\_features}, \text{in\_features})`. The values are
  initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
  :math:`k = \frac{1}{\text{in\_features}}`

对于 bias

bias:  the learnable bias of the module of shape :math:`(\text{out\_features})`.
    If :attr:`bias` is ``True``, the values are initialized from
    :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
    :math:`k = \frac{1}{\text{in\_features}}`

实例展示

举个例子:

>>> import torch
>>> nn1 = torch.nn.Linear(100, 50)
>>> input1 = torch.randn(140, 100)
>>> output1 = nn1(input1)
>>> output1.size()
torch.Size([140, 50])
 

张量的大小由 140 x 100 变成了 140 x 50

执行的操作是:

[140,100]×[100,50]=[140,50]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持小牛知识库。

 类似资料:
  • 本文向大家介绍浅析Node.js 中 Stream API 的使用,包括了浅析Node.js 中 Stream API 的使用的使用技巧和注意事项,需要的朋友参考一下 本文由浅入深给大家介绍node.js stream api,具体详情请看下文吧。 基本介绍 在 Node.js 中,读取文件的方式有两种,一种是用 fs.readFile ,另外一种是利用 fs.createReadStream 来

  • 本文向大家介绍浅谈Pytorch中的torch.gather函数的含义,包括了浅谈Pytorch中的torch.gather函数的含义的使用技巧和注意事项,需要的朋友参考一下 pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验。 立个flag开始学习pytorch,新开一个分类整理学习pytorch中的一些踩到的泥

  • 本文向大家介绍ThinkPHP中U方法的使用浅析,包括了ThinkPHP中U方法的使用浅析的使用技巧和注意事项,需要的朋友参考一下 thinkPHP中U方法的定义规则如下(方括号内参数根据实际应用决定): U('[项目://][路由@][分组名-模块/]操作? 参数1=值1[&参数N=值N]') 或者用数组的方式传入参数: U('[项目://][路由@][分组名-模块/]操作',array('参数

  • 本文向大家介绍浅谈pytorch torch.backends.cudnn设置作用,包括了浅谈pytorch torch.backends.cudnn设置作用的使用技巧和注意事项,需要的朋友参考一下 cuDNN使用非确定性算法,并且可以使用torch.backends.cudnn.enabled = False来进行禁用 如果设置为torch.backends.cudnn.enabled =Tru

  • 本文向大家介绍Python yield 使用浅析,包括了Python yield 使用浅析的使用技巧和注意事项,需要的朋友参考一下 初学 Python 的开发者经常会发现很多 Python 函数中用到了 yield 关键字,然而,带有 yield 的函数执行流程却和普通函数不一样,yield 到底用来做什么,为什么要设计 yield ?本文将由浅入深地讲解 yield 的概念和用法,帮助读者体会

  • 本文向大家介绍Python中optparse模块使用浅析,包括了Python中optparse模块使用浅析的使用技巧和注意事项,需要的朋友参考一下 最近遇到一个问题,是指定参数来运行某个特定的进程,这很类似Linux中一些命令的参数了,比如ls -a,为什么加上-a选项会响应。optparse模块实现的也是类似的功能,它是为脚本传递命令参数。 使用此模块前,首先需要导入模块中的类OptionPar