Custom weight initialization in PyTorch




What would be the right way to implement a custom weight initialization method in PyTorch?





I believe I can't directly add a method to 'torch.nn.init', but I want to initialize my model's weights with my own proprietary method.




2 Answers



You can define a method to initialize the weights according to each layer:


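import torch.nn as nn

def weights_init(m):
    # model.apply() calls this once for every submodule
    classname = m.__class__.__name__

    if classname.find('Conv2d') != -1:
        # Convolution weights: normal with mean 0.0, std 0.02
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif classname.find('BatchNorm') != -1:
        # Batch norm scale: normal with mean 1.0, std 0.02; shift set to zero
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)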



And then just apply it to your network:


model = create_your_model()
model.apply(weights_init)
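
To plug in your own scheme, the initializer can be an ordinary function that fills a tensor in place, in the same spirit as the functions in torch.nn.init. The following is a minimal sketch under some assumptions: scaled_uniform_ is a hypothetical initializer standing in for your proprietary method, and the small nn.Sequential model exists only for illustration.

```python
import torch
import torch.nn as nn

def scaled_uniform_(tensor, scale=0.05):
    """Hypothetical custom initializer: fill `tensor` in place with U(-scale, scale)."""
    with torch.no_grad():
        return tensor.uniform_(-scale, scale)

def weights_init(m):
    # model.apply() calls this once for every submodule, including the model itself.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        scaled_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

# Small stand-in model (an assumption; expects 3x8x8 inputs).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 6 * 6, 10),
)
model.apply(weights_init)
```

Checking the layer with isinstance instead of matching on the class name avoids accidental matches on unrelated classes, and the bias check covers layers built with bias=False.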



See https://discuss.pytorch.org/t/how-to-initialize-weights-bias-of-rnn-lstm-gru/2879/2 for reference.



You can also overwrite entries in the model's state dict and load the result back into the model:


weight_dict = net.state_dict()
new_weight_dict = {}
for param_key in weight_dict:
    # Put your custom initialization for param_key into new_weight_dict.
    # You can initialize partially, i.e. only some of the parameters,
    # and leave the others as they are.
    pass
# Merge the new values into the full state dict, then load the full dict
weight_dict.update(new_weight_dict)
net.load_state_dict(weight_dict)
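
As a concrete illustration of the state dict approach, here is a minimal sketch. The two-layer network and the rule "set every weight matrix to 0.01, leave biases alone" are assumptions chosen only to keep the example short; substitute your own initialization for the torch.full_like call.

```python
import torch
import torch.nn as nn

# Small stand-in network, used here only for illustration.
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

weight_dict = net.state_dict()
new_weight_dict = {}
for param_key, tensor in weight_dict.items():
    # Hypothetical rule: re-initialize only the weight matrices, skip the biases.
    if param_key.endswith("weight"):
        new_weight_dict[param_key] = torch.full_like(tensor, 0.01)

# Merge the new values into the full state dict and load the full dict,
# so that the untouched entries are still present.
weight_dict.update(new_weight_dict)
net.load_state_dict(weight_dict)
```

Loading the merged weight_dict rather than new_weight_dict matters: by default load_state_dict is strict and raises an error when keys are missing.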





