Different gradients for HyperNet training
TL;DR: Is there a way to optimize a model created by combine_state_for_ensemble using loss.backward()?
Hi, I am using combine_state_for_ensemble for HyperNet training.
```python
from functorch import combine_state_for_ensemble, vmap

# Stack K hypernetworks into one stateless model.
fmodel, fparams, fbuffers = combine_state_for_ensemble([HyperMLP() for _ in range(K)])
for p in fparams:
    p.requires_grad_()
weights_and_biases = vmap(fmodel)(fparams, fbuffers, z.expand(K, -1, -1))  # parallelizes over K
```
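For reference, HyperMLP, K, and z are not shown here; a hypothetical stand-in consistent with the call above could be:

```python
import torch
import torch.nn as nn

class HyperMLP(nn.Module):
    """Hypothetical hypernetwork: maps a latent code z to one flat
    vector of weights and biases for a SimpleMLP (sizes made up)."""
    def __init__(self, z_dim=16, n_wb=4 * 8 + 8):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, n_wb))

    def forward(self, z):
        return self.net(z)

K = 5
z = torch.randn(3, 16)  # (batch, z_dim); expand(K, -1, -1) lifts it to (K, batch, z_dim)
```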
After creating weights_and_biases, I reshape them into ws_and_bs and use them as the parameters of another ensemble.
```python
# Renamed to fmodel2/fparams2/fbuffers2 so the hypernetwork's fparams above are not clobbered.
fmodel2, fparams2, fbuffers2 = combine_state_for_ensemble([SimpleMLP() for _ in range(K)])
outputs = vmap(fmodel2)(ws_and_bs, fbuffers2, inputs)  # generated weights act as the parameters
```
This approach produces exactly the same outputs as running a loop instead of vmap (sketched below); however, the gradients are (somehow) different.
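The loop version I compared against looks roughly like this (reconstructed, so the indexing is an assumption):

```python
# Hypothetical loop-based equivalent of the vmap call above: run the
# stateless model once per ensemble member, then stack the results.
loop_outputs = torch.stack([
    fmodel2(
        [w[k] for w in ws_and_bs],   # member k's generated weights/biases
        [b[k] for b in fbuffers2],   # member k's buffers (empty if none)
        inputs[k],                   # member k's batch of inputs
    )
    for k in range(K)
])
assert torch.allclose(outputs, loop_outputs)  # forward passes agree
```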
```python
loss = compute_loss(outputs)
loss.backward()
```
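For context, the full training step I am aiming for would look roughly like this (the optimizer and learning rate are illustrative):

```python
# Hypothetical training step over the hypernetwork parameters.
opt = torch.optim.Adam(fparams, lr=1e-3)
opt.zero_grad()
loss = compute_loss(outputs)
loss.backward()  # should backprop through ws_and_bs into fparams
opt.step()
```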
Do you have any idea why?
Update: It seems that ws_and_bs does not hold any gradient even though requires_grad is set on it.
Update 2: It seems that I can run a forward pass through the stateless model with my generated weights, but I cannot backprop through them using loss.backward(). Is there any trick I can use?
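For what it's worth, a quick check (same names as above) suggests the generated weights are non-leaf tensors:

```python
print(ws_and_bs[0].requires_grad)  # True
print(ws_and_bs[0].is_leaf)        # False: produced by the hypernet forward pass
loss.backward()
print(ws_and_bs[0].grad)           # None: .grad is only populated on leaf tensors
```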
Hey @bkoyuncu,
Do you have a longer script that we can use to reproduce the problem?
From what you are saying, it sounds like you want ws_and_bs to get gradients. You can do this by detaching them and creating new leaf tensors in the autograd graph:
```python
def create_new_leaf(x):
    return x.detach().requires_grad_()

# tree_map returns a new pytree, so keep the result.
ws_and_bs = torch.utils._pytree.tree_map(create_new_leaf, ws_and_bs)
```
or by using Tensor.retain_grad(), which tells autograd to populate .grad on a non-leaf tensor.
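A minimal sketch of the retain_grad() route, with the same assumed names (unlike detaching, this keeps gradients flowing back to the hypernetwork):

```python
for w in ws_and_bs:
    w.retain_grad()  # ask autograd to also populate .grad on these non-leaf tensors

loss = compute_loss(outputs)
loss.backward()
print(ws_and_bs[0].grad)  # now populated; the hypernetwork's fparams receive grads too
```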
Thank you so much for the suggestion, @zou3519! I will check this and get back to you.