338 字
2 分钟
参数管理
5.2 参数管理
在确定了模型的“层”和“块”后,模型的基本框架就完成了。
下一步,寻找损失函数最小化的参数。
5.2.1 参数访问
torch.nn.Module.state_dict()
:访问Module类下的所有关于模型状态的信息(参数等)。
# 例子OrderedDict([('weight', tensor([[-0.3309, -0.1469, -0.1062, -0.2565, 0.0803, -0.2798, 0.2136, 0.1604]])), ('bias', tensor([-0.1101]))])
torch.nn.Linear.bias
, torch.nn.Linear.weight
来直接访问指定模型的参数。
print(type(net[2].bias))print(net[2].bias)print(net[2].bias.data)
net.named_parameters()
,以字典形式,输出所有参数,包括子数组的参数。key为参数名,value为参数本身。
对于嵌套的块,使用
named_parameters()
,会按照顺序诸葛输出每个子块的参数信息。
5.3 参数初始化
nn.init
下包括许多初始化方案,能够对参数进行初始化。
https://pytorch.org/docs/stable/nn.init.html#torch-nn-init
init.normal_()
init.zeros_()
init.constant_()
通常可以自定义一个初始化函数来实现指定类型模型的初始化.
def init_function(module): if type(module) == nn.Linear nn.init.normal_(module.weight, mean=0, std=0.01) # init的参数不会参与到计算梯度的过程。 nn.init.zeros_(module.bias)
module.apply()
,让模型使用指定的初始化函数来初始化参数。
net.apply(init_function)
延后初始化
在训练的过程中,深度学习框架无法直接判断网络的输入维度,因此会采用一种延后初始化的方法。
延后初始化:当数据第一次通过模型时,框架才会动态地推断每个层的大小。