net.apply(fn: Callable[[ForwardRef(‘Module’)], NoneType]) -> ~T
model.apply(fn)
会递归地将函数 fn
应用到父模块的每个子模块以及model这个父模块自身。通常用于初始化模型的参数。fn (:class:`Module` -> None)
->将应用于每个子模块的函数Module: self
来看以下示例:
@torch.no_grad()
def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.fill_(1.0)
print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
在这个网络示例中,模块 net
有两个子模块,均为 Linear(2,4)
。函数首先对这两个子模块调用 init_weights
函数,然后再对 net
模块进行同样的操作。
会打印以下信息:
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1., 1.],
[ 1., 1.]])
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1., 1.],
[ 1., 1.]])
Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
可以看见的使,不仅仅分别打印两个Linear,还加上父模块自身自己父模块的返回Sequential。
或者下面这个例子:
net = nn.Sequential(nn.Flatten(),
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10))
def init_weights(m):
print(m)
net.apply(init_weights)
会打印以下信息:
Flatten(start_dim=1, end_dim=-1)
Linear(in_features=784, out_features=256, bias=True)
ReLU()
Linear(in_features=256, out_features=10, bias=True)
Sequential(
(0): Flatten(start_dim=1, end_dim=-1)
(1): Linear(in_features=784, out_features=256, bias=True)
(2): ReLU()
(3): Linear(in_features=256, out_features=10, bias=True)
)
如果我们想对某些特定的子模块submodule做一些针对性的处理,该怎么做呢?我们可以加入type(m) == nn.Linear:
这类判断语句,从而对特定子模块m进行处理:
net = nn.Sequential(nn.Flatten(),
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10))
def init_weights(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights)