Custom weight initialization in PyTorch




What would be the right way to implement a custom weight initialization method in PyTorch?





I believe I can't directly add a method to `torch.nn.init`, but I want to initialize my model's weights with my own proprietary method.




2 Answers



You can define a function that initializes the weights according to each layer's type:


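import torch.nn as nn

def weights_init(m):
    # Dispatch on the layer's class name
    classname = m.__class__.__name__

    if classname.find('Conv2d') != -1:
        # Conv weights drawn from N(0, 0.02)
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm scale drawn from N(1, 0.02), shift set to 0
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)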
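Since nothing needs to be added to `torch.nn.init`, a proprietary scheme can simply be your own function that follows the same convention as the built-in initializers: take a tensor, overwrite it in place without autograd recording, and return it. The sketch below is only an illustration; `my_custom_init_` is a hypothetical name, and its rule (uniform in ±1/sqrt(fan_in)) is a placeholder for your own method.

```python
import math

import torch
import torch.nn as nn


def my_custom_init_(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical custom initializer in the style of torch.nn.init:
    it overwrites `tensor` in place and returns it.

    The rule (uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)]) is only a
    placeholder for your own scheme.
    """
    # fan_in = number of inputs feeding one output unit: for a Linear weight
    # (out, in) that is `in`; for a Conv2d weight (out, in, kH, kW) it is in*kH*kW.
    fan_in = tensor[0].numel() if tensor.dim() > 1 else tensor.numel()
    bound = 1.0 / math.sqrt(fan_in)
    # In-place writes to a leaf tensor that requires grad are only allowed
    # while autograd is not recording
    with torch.no_grad():
        tensor.uniform_(-bound, bound)
    return tensor


layer = nn.Linear(16, 4)
my_custom_init_(layer.weight)
```

Inside `weights_init` you would then call `my_custom_init_(m.weight)` in place of the `normal_` calls. Keeping the in-place, return-the-tensor convention means your function composes with `Module.apply` exactly like the built-in initializers.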



Then apply `weights_init` to your network. `Module.apply` calls the function recursively on every submodule (and on the model itself):


model = create_your_model()
model.apply(weights_init)
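A self-contained version of this answer can be run directly. The small `nn.Sequential` model below is an arbitrary stand-in for `create_your_model()`, chosen only so the example has a `Conv2d` and a `BatchNorm2d` for `weights_init` to match.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def weights_init(m: nn.Module) -> None:
    # Called once per submodule by Module.apply; dispatch on the class name
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)


# Stand-in for create_your_model(): an arbitrary small conv block
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)
# Visits the Sequential container, Conv2d, BatchNorm2d and ReLU in turn
model.apply(weights_init)

conv, bn = model[0], model[1]
print(conv.weight.std().item())
print(bn.bias.abs().sum().item())
```

Modules that match neither branch keep PyTorch's default initialization; here that includes the convolution's bias, which `weights_init` never touches.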



See https://discuss.pytorch.org/t/how-to-initialize-weights-bias-of-rnn-lstm-gru/2879/2 for reference.



Alternatively, you can build a new state dict with your custom values and load it back into the network:


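weight_dict = net.state_dict()
new_weight_dict = {}
for param_key in weight_dict:
    # Custom initialization: store a new tensor of the same shape in
    # new_weight_dict[param_key]. You can initialize only some of the
    # entries and let the others stay as they are.
    pass
# Merge the new values into the full dict and load that; loading only a
# partial new_weight_dict would fail with the default strict=True
weight_dict.update(new_weight_dict)
net.load_state_dict(weight_dict)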
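The loop body is left to you. As a runnable sketch of a partial update, assume a small stand-in `net` (substitute your own) and a hypothetical rule that sets every entry whose key ends in `weight` to 0.01 while skipping the biases.

```python
import torch
import torch.nn as nn

# Stand-in network (an assumption for this example)
net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

weight_dict = net.state_dict()  # keys: '0.weight', '0.bias', '2.weight', '2.bias'
new_weight_dict = {}
for param_key, tensor in weight_dict.items():
    if param_key.endswith('weight'):
        # Hypothetical custom rule: every weight set to 0.01.
        # full_like keeps the original shape and dtype.
        new_weight_dict[param_key] = torch.full_like(tensor, 0.01)
    # Biases are skipped, so they keep their current values

weight_dict.update(new_weight_dict)
net.load_state_dict(weight_dict)
```

`load_state_dict` copies values into the existing parameters and raises an error on a shape mismatch, so each replacement tensor must have the same shape as the entry it replaces.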





