以下是本人根据Pytorch学习过程中总结出的经验,如果有错误,请指正。
在Pytorch建立神经元网络模型的时候,经常用到forward方法,表示在建立模型后,进行神经元网络的前向传播。说的直白点,forward就是专门用来计算给定输入,得到神经元网络输出的方法。
在代码实现中,也是用def forward
来写forward前向传播的方法,我原来以为这是一种约定熟成的名字,也可以换成任意一个自己喜欢的名字。
但是看的多了之后发现并非如此:Pytorch对于forward方法赋予了一些特殊“功能”
(这里不禁再吐槽,一些看起来挺厉害的Pytorch“大神”,居然不知道这个。。。只能草草解释一下:“就是这样的。。。”)
我最开始发现forward()的与众不同之处就是在此,首先举个例子:
import torch.nn as nn
class test(nn.Module):
def __init__(self, input):
super(test,self).__init__()
self.input = input
def forward(self,x):
return self.input * x
T = test(8)
print(T(6))
# print(T.forward(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
Process finished with exit code 0
可以发现,T(6)是可以输出的!而且不用指定,默认了调用forward方法
。当然如果非要写上.forward()这也是可以正常运行的,和不写是一样的。
如果不调用Pytorch(正常的Python语法规则),这样肯定会报错的
# import torch.nn as nn #不再调用torch
class test():
def __init__(self, input):
self.input = input
def forward(self,x):
return self.input * x
T = test(8)
print(T.forward(6))
print("************************")
print(T(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
************************
Traceback (most recent call last):
File "C:\Users\Lenovo\Desktop\DL\pythonProject\tt.py", line 77, in <module>
print(T(6))
TypeError: 'test' object is not callable
Process finished with exit code 1
这里会报:‘test’ object is not callable
因为class不能被直接调用,不知道你想调用哪个方法。
如果在class中再增加一个方法:
import torch.nn as nn
class test(nn.Module):
def __init__(self, input):
super(test,self).__init__()
self.input = input
def byten(self):
return self.input * 10
def forward(self,x):
return self.input * x
T = test(8)
print(T(6))
print(T.byten())
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
80
Process finished with exit code 0
可以见到,在class中有多个method的时候,如果不指定method,forward是会被优先执行的。
在Pytorch中,forward方法是一个特殊的方法,被专门用来进行前向传播。