functorch

Different gradients for HyperNet training

Status: Open. bkoyuncu opened this issue 3 years ago (2 comments).

TL;DR: Is there a way to optimize a model created by combine_state_for_ensemble using loss.backward()?

Hi, I am using combine_state_for_ensemble for HyperNet training.

from functorch import combine_state_for_ensemble, vmap

fmodel, fparams, fbuffers = combine_state_for_ensemble([HyperMLP() for _ in range(K)])
for p in fparams:
    p.requires_grad_()  # the stacked parameters are returned detached
weights_and_biases = vmap(fmodel)(fparams, fbuffers, z.expand(K, -1, -1))  # vmap parallelizes over the K models
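For readers on newer PyTorch, here is a minimal, self-contained sketch of this first step, assuming PyTorch 2.x, where torch.func.stack_module_state and torch.func.functional_call take over the role of functorch's combine_state_for_ensemble. HyperMLP below is a hypothetical stand-in (a single nn.Linear), and K, Z_DIM, N_OUT and z are made-up values, not the original poster's code.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

K, Z_DIM, N_OUT = 3, 4, 10  # hypothetical sizes: K hypernets, latent dim, generated-parameter count

class HyperMLP(nn.Module):
    """Hypothetical hypernetwork: maps a latent code z to a flat parameter vector."""
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(Z_DIM, N_OUT)

    def forward(self, z):
        return self.net(z)

hypernets = [HyperMLP() for _ in range(K)]
# Stack each parameter across the K models; the stacked tensors are leaves with requires_grad=True.
hparams, hbuffers = stack_module_state(hypernets)
base = copy.deepcopy(hypernets[0])  # only supplies the module structure to functional_call

def hyper_forward(params, buffers, z):
    # Run `base` with one ensemble member's parameters swapped in.
    return functional_call(base, (params, buffers), (z,))

z = torch.randn(1, 5, Z_DIM)  # one shared batch of 5 latent codes
# As in z.expand(K, -1, -1) above: each of the K members gets its own (identical) copy of z.
weights_and_biases = vmap(hyper_forward)(hparams, hbuffers, z.expand(K, -1, -1))
print(weights_and_biases.shape)  # torch.Size([3, 5, 10])
```

The generated tensor is a non-leaf output of the hypernetworks, so gradients from any downstream loss flow back into hparams.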

After creating weights_and_biases, I reshape them into the right shapes (ws_and_bs) and use them as the parameters of another ensemble.

fmodel, fparams, fbuffers = combine_state_for_ensemble([SimpleMLP() for _ in range(K)])
outputs = vmap(fmodel)(ws_and_bs, fbuffers, inputs)  # generated ws_and_bs replace the stacked fparams
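A minimal sketch of the reshaping step, assuming PyTorch 2.x (torch.func) and a hypothetical SimpleMLP with a single nn.Linear layer. The tensor theta is a made-up stand-in for the hypernetwork parameters, and the parameter names in ws_and_bs must match the target module's named_parameters(). The sketch only illustrates how a flat generated tensor can be split into per-parameter tensors that keep the leading K dimension.

```python
import torch
from torch import nn
from torch.func import functional_call, vmap

K, D_IN, D_OUT = 3, 4, 2  # hypothetical sizes

class SimpleMLP(nn.Module):
    """Hypothetical target network with one linear layer."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(D_IN, D_OUT)

    def forward(self, x):
        return self.fc(x)

base = SimpleMLP()  # its own weights are never used; functional_call swaps them out
n_w, n_b = D_OUT * D_IN, D_OUT

# Hypothetical stand-in for the hypernetwork: a leaf whose (non-leaf) output plays weights_and_biases.
theta = torch.randn(K, n_w + n_b, requires_grad=True)
weights_and_biases = theta * 1.0

# Split and reshape into the names/shapes SimpleMLP expects, keeping the leading K dimension.
ws_and_bs = {
    "fc.weight": weights_and_biases[:, :n_w].reshape(K, D_OUT, D_IN),
    "fc.bias": weights_and_biases[:, n_w:],
}

def target_forward(params, x):
    return functional_call(base, params, (x,))

inputs = torch.randn(K, 8, D_IN)  # one batch of 8 inputs per ensemble member
outputs = vmap(target_forward)(ws_and_bs, inputs)
outputs.sum().backward()
print(theta.grad.shape)  # torch.Size([3, 10])
```

The gradient lands on theta (the leaf), not on the entries of ws_and_bs, which are intermediate results.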

My vmap-based approach produces exactly the same outputs as an equivalent implementation that uses Python loops instead of vmap. However, the gradients somehow differ.

loss = compute_loss(outputs)
loss.backward()

Do you have any idea why?
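One way to check the loop-versus-vmap question is to compare gradients on the leaf parameters rather than on intermediate tensors. The toy sketch below (assuming PyTorch 2.x; H, z, x and member are hypothetical and much simpler than HyperMLP/SimpleMLP) builds a linear "hypernetwork" whose output weights a linear projection, then compares outputs and H.grad between a vmap version and a Python loop.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
K, Z_DIM, D_IN = 3, 4, 5  # hypothetical sizes

# Hypothetical stacked hypernet weights (a leaf) and a shared latent code.
H = torch.randn(K, Z_DIM, D_IN, requires_grad=True)
z = torch.randn(Z_DIM)
x = torch.randn(K, 8, D_IN)  # per-member inputs

def member(h, xk):
    w = z @ h       # generated weight vector (plays the role of ws_and_bs)
    return xk @ w   # target output: one linear projection per input row

# vmap version
out_vmap = vmap(member)(H, x)
out_vmap.pow(2).mean().backward()
grad_vmap = H.grad.clone()

# Python-loop version on the same leaf
H.grad = None
out_loop = torch.stack([member(H[k], x[k]) for k in range(K)])
out_loop.pow(2).mean().backward()
grad_loop = H.grad.clone()

print(torch.allclose(out_vmap, out_loop, atol=1e-6), torch.allclose(grad_vmap, grad_loop, atol=1e-6))  # True True
```

If the leaf gradients agree but an intermediate tensor shows no gradient, the difference is in where the gradient is read, not in the math.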

Update: It seems that ws_and_bs does not hold any gradient, even though it has requires_grad=True.

Update 2: It seems I can run the forward pass through the stateless model with my generated weights, but I cannot backpropagate to them with loss.backward(). Is there a trick I can use?
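The behavior in these updates matches how autograd treats non-leaf tensors: ws_and_bs is computed from the hypernetwork parameters, so it is not a leaf, and PyTorch only populates .grad on leaves by default. The gradient still flows through it to the hypernetwork parameters. A toy illustration (hyper_weight, z, x and generated are hypothetical stand-ins, not the original code):

```python
import torch

torch.manual_seed(0)
hyper_weight = torch.randn(4, 6, requires_grad=True)  # leaf: stands in for the hypernet parameters
z = torch.randn(4)

generated = z @ hyper_weight  # non-leaf: stands in for ws_and_bs (output of the hypernet)
x = torch.randn(6)
loss = (x * generated).sum()
loss.backward()

print(generated.is_leaf, hyper_weight.is_leaf)  # False True
print(hyper_weight.grad is not None)            # True: the gradient reached the hypernet leaf
# Reading generated.grad here returns None (PyTorch also emits a UserWarning), because .grad
# is only populated on leaf tensors unless retain_grad() was called before backward().
```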

bkoyuncu, Nov 30 '22 21:11

Hey @bkoyuncu,

Do you have a longer script that we can use to reproduce the problem?

From what you are saying, it sounds like you want ws_and_bs to get gradients. You can do this by detaching them and creating new leaf tensors in the autograd graph:

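import torch.utils._pytree as pytree

def create_new_leaf(x):
    # Detach x from the existing graph and make it a fresh leaf whose .grad gets populated
    return x.detach().requires_grad_()

ws_and_bs = pytree.tree_map(create_new_leaf, ws_and_bs)  # tree_map returns a new structure; keep it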

or by calling Tensor.retain_grad() on them before calling loss.backward().
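The two options behave differently for hypernetwork training. Detaching turns ws_and_bs into new leaves, so they get a .grad, but it also cuts them out of the graph, so the hypernetwork parameters no longer receive any gradient. retain_grad() keeps the graph intact and additionally stores .grad on the intermediate tensors. A toy sketch of both (run, hyper_weight, z, x and generated are hypothetical):

```python
import torch

torch.manual_seed(0)

def run(keep_graph: bool):
    """Toy hypernet step; returns (grad on hypernet leaf, grad on generated weights)."""
    hyper_weight = torch.randn(4, 6, requires_grad=True)  # hypothetical hypernet parameters
    z, x = torch.randn(4), torch.randn(6)
    generated = z @ hyper_weight                           # stands in for ws_and_bs
    if keep_graph:
        generated.retain_grad()                            # keep .grad on this non-leaf
    else:
        generated = generated.detach().requires_grad_()    # same as create_new_leaf
    (x * generated).sum().backward()
    return hyper_weight.grad, generated.grad

hyper_grad, gen_grad = run(keep_graph=True)
print(hyper_grad is not None, gen_grad is not None)  # True True

hyper_grad, gen_grad = run(keep_graph=False)
print(hyper_grad is None, gen_grad is not None)      # True True
```

For training the hypernetwork end to end, the retain_grad() variant (or simply reading gradients on the hypernetwork parameters) is the one that keeps gradients flowing upstream.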

zou3519, Dec 02 '22 15:12

Thank you so much for the suggestion @zou3519, I will check this and get back to you!

bkoyuncu, Dec 03 '22 13:12