Custom weight initialization in PyTorch

What would be the right way to implement a custom weight initialization method in PyTorch?
I believe I can't directly add any method to `torch.nn.init`, but I wish to initialize my model's weights with my own proprietary method.
2 Answers
You can define a function that initializes the weights according to the layer type:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        # Conv2d layers: weights drawn from N(0.0, 0.02)
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm layers: weights drawn from N(1.0, 0.02), biases set to 0
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
And then just apply it to your network:
model = create_your_model()
model.apply(weights_init)
See https://discuss.pytorch.org/t/how-to-initialize-weights-bias-of-rnn-lstm-gru/2879/2 for reference.
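For example, a minimal sketch (the toy model below is only illustrative; weights_init is the function defined above):

import torch.nn as nn

# Hypothetical toy network containing Conv2d and BatchNorm2d layers
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# Apply the custom initialization to every submodule
model.apply(weights_init)

# Quick sanity check: the BatchNorm weights should now be centred around 1.0
print(model[1].weight.mean().item())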
You can do it via the model's state_dict:
weight_dict = net.state_dict()
new_weight_dict = {}
for param_key in weight_dict:
    # fill new_weight_dict with your custom initial values here;
    # you can initialize only some of the parameters and leave the others as they are
    pass
weight_dict.update(new_weight_dict)
net.load_state_dict(weight_dict)
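As a concrete sketch of that pattern (the model and the initialization rule below are purely hypothetical), re-initializing only the weight tensors and leaving the biases untouched:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))  # hypothetical model

weight_dict = net.state_dict()
new_weight_dict = {}
for param_key, param in weight_dict.items():
    # example rule: re-initialize only the weight tensors, leave the biases as they are
    if param_key.endswith('.weight'):
        new_weight_dict[param_key] = torch.rand_like(param) * 0.01

weight_dict.update(new_weight_dict)
net.load_state_dict(weight_dict)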