functorch
How to update the original model parameters after calling make_functional?
As per the title, I find that updating the tensors pointed to by the params returned by make_functional does not update the real parameters in the original model.
Is there a way to do this? It would be extremely useful for implementing optimization algorithms in a way that closely follows their mathematical description.
To provide more context, here is an example script of what standard gradient descent would look like when written this way:
import torch
from torch import nn
from functorch import make_functional

learning_rate = 0.1

def optstep(params, jacobians):
    with torch.no_grad():
        for i, param in enumerate(params):
            param.add_(jacobians[i], alpha=-learning_rate)

if __name__ == '__main__':
    model = nn.Linear(3, 5)
    x, targets = torch.randn(2, 3), torch.randn(2, 5)
    criterion = nn.MSELoss()
    print("INITIAL LOSS:", criterion(model(x), targets).item())

    # Render the model functional and compute the jacobian
    func_model, params = make_functional(model)

    def f(*params):
        out = func_model(params, x)
        return criterion(out, targets)

    jacobian = torch.autograd.functional.jacobian(f, params)

    # Ideally would train on the current input
    optstep(params, jacobian)

    # Now compute the new loss
    print("NEW LOSS:", criterion(model(x), targets).item())
Executing the script shows that the parameters are not updated, since the loss doesn't change:
INITIAL LOSS: 1.2894147634506226
NEW LOSS: 1.2894147634506226
After looking a bit in the source code I've found functorch._src.make_functional.extract_weights and load_weights, which allow me to do exactly what I wanted to do.
Maybe those methods can be exposed and documented to allow the suggested use case?
Couldn't you do

def optstep(model, jacobians):
    with torch.no_grad():
        for i, param in enumerate(model.parameters()):
            param.add_(jacobians[i], alpha=-learning_rate)

?
(Also, you might want to try functorch.jacrev instead of torch.autograd.functional.jacobian -- it may be faster)
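To make the jacrev suggestion concrete, here is a minimal sketch comparing the two calls on the same toy setup as the script above. It assumes PyTorch 2.0 or later, where the functorch transforms are available as torch.func; functional_call stands in for the functional model returned by make_functional, and the names list is introduced here only for illustration.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev

torch.manual_seed(0)
model = nn.Linear(3, 5)
x, targets = torch.randn(2, 3), torch.randn(2, 5)
criterion = nn.MSELoss()

# Detached parameter tensors, in model.parameters() order.
names = [name for name, _ in model.named_parameters()]
params = tuple(p.detach() for p in model.parameters())

def f(*params):
    # Run the module with the given tensors substituted for its parameters.
    out = functional_call(model, dict(zip(names, params)), (x,))
    return criterion(out, targets)

# Reverse-mode Jacobian with respect to every positional argument.
jac_rev = jacrev(f, argnums=tuple(range(len(params))))(*params)

# The same quantity through the older autograd helper.
jac_ref = torch.autograd.functional.jacobian(f, params)

matches = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(jac_rev, jac_ref))
print("jacrev agrees with autograd.functional.jacobian:", matches)
```

Because the loss is a scalar, each Jacobian entry has the same shape as the parameter it belongs to, so either result can be passed to optstep unchanged. Whether jacrev is actually faster depends on the model and is worth measuring.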
Is model.parameters() guaranteed to return parameters in the same order as make_functional?
If this is the case then I can surely do this; however, I would like to ask that it be documented as behaviour one can rely on.
Thank you very much
> Is model.parameters() guaranteed to return parameters in the same order as make_functional?

Yes

> If this is the case then I can surely do this; however, I would like to ask that it be documented as behaviour one can rely on.

Yes, we should document this
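One way to avoid depending on positional order at all is to key everything by parameter name. The sketch below is not the functorch API discussed in this thread: it assumes PyTorch 2.0 or later and uses torch.func.functional_call and torch.func.grad. Since the dict holds the model's own Parameter objects rather than copies, the in-place update is visible through the original model.

```python
import torch
from torch import nn
from torch.func import functional_call, grad

torch.manual_seed(0)
learning_rate = 0.1
model = nn.Linear(3, 5)
x, targets = torch.randn(2, 3), torch.randn(2, 5)
criterion = nn.MSELoss()

# The dict values are the model's own Parameter objects, not copies.
params = dict(model.named_parameters())

def loss_fn(params):
    out = functional_call(model, params, (x,))
    return criterion(out, targets)

initial_loss = criterion(model(x), targets).item()

# grad returns a dict with the same keys as params.
grads = grad(loss_fn)(params)

# Plain gradient descent step, matched by name instead of by index.
with torch.no_grad():
    for name, param in params.items():
        param.add_(grads[name], alpha=-learning_rate)

new_loss = criterion(model(x), targets).item()
print("INITIAL LOSS:", initial_loss)
print("NEW LOSS:", new_loss)
```

With this setup the second printed loss should be lower than the first, which is the behaviour the original script was expecting.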
Thank you very much again for all this work. I think the issue can be closed as soon as the behaviour is documented.
@trenta3 out of curiosity, what are you using make_functional for? Are you using any of the other functorch APIs?
I'm currently using make_functional together with other functorch APIs, in particular jvp and jacrev, to write more complex optimizers that also take second-order information about a neural network into account, which is unmanageable in plain PyTorch. Earlier this year I needed to compute eigenvectors of the linearizations of some neural networks, and the ability to obtain gradients for each example separately was crucial.
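For readers unfamiliar with the per-example gradients mentioned here, the following is a minimal sketch of the usual vmap-over-grad pattern. It assumes PyTorch 2.0 or later (torch.func) rather than the functorch package used in this thread, and the toy model and the helper names sample_loss and batch_loss are illustrative only.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
model = nn.Linear(3, 5)
x, targets = torch.randn(4, 3), torch.randn(4, 5)
criterion = nn.MSELoss()

# Detached tensors keyed by parameter name.
params = {name: p.detach() for name, p in model.named_parameters()}

def sample_loss(params, sample, target):
    # vmap strips the batch dimension, so add it back for the module call.
    out = functional_call(model, params, (sample.unsqueeze(0),))
    return criterion(out, target.unsqueeze(0))

# Map the gradient computation over the batch dimension of x and targets only.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, targets)

def batch_loss(params):
    return criterion(functional_call(model, params, (x,)), targets)

# Full-batch gradient, used here only as a sanity check.
batch_grads = grad(batch_loss)(params)

for name, g in per_sample_grads.items():
    print(name, tuple(g.shape))
```

Each entry of per_sample_grads gains a leading dimension equal to the batch size. Since the batch loss is the mean of the per-example losses, averaging over that dimension recovers the ordinary full-batch gradient.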
One thing I do miss is the ability to "lazily" compute parts of the Hessian, such as its diagonal, without paying the full memory (and compute) cost of the whole Hessian. More generally, the ability for a PyTorch user to manipulate "lazy tensors" (i.e. a thunk of computation that depends on some data but is not eagerly executed) would be extremely useful for computing the diagonal of the Hessian, as well as for many computations in kernel methods (as PyKeOps does), but I sincerely don't know how efficient this can be made.
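A partial workaround is to build the diagonal from Hessian-vector products, so the full matrix is never stored. The sketch below assumes PyTorch 2.0 or later (torch.func); the toy function f and the helpers hvp and hessian_diagonal are hypothetical names written for this illustration. It addresses only the memory side of the request: it still needs one Hessian-vector product per coordinate.

```python
import torch
from torch.func import grad, jvp

def f(x):
    # Toy scalar function whose Hessian is diag(6 * x) plus a dense matrix of 2s.
    return (x ** 3).sum() + x.sum() ** 2

def hvp(func, x, v):
    # Forward-over-reverse: differentiate grad(func) along the direction v.
    _, tangent = jvp(grad(func), (x,), (v,))
    return tangent

def hessian_diagonal(func, x):
    # Exact diagonal, one Hessian-vector product per coordinate.
    # Only a single Hessian row is held in memory at any time.
    diag = torch.empty_like(x)
    for i in range(x.numel()):
        e_i = torch.zeros_like(x)
        e_i[i] = 1.0
        diag[i] = hvp(func, x, e_i)[i]
    return diag

x = torch.tensor([1.0, 2.0, 3.0])
diag = hessian_diagonal(f, x)

# Reference: materialize the full Hessian and read off its diagonal.
full = torch.autograd.functional.hessian(f, x)
print(diag)
print(torch.diagonal(full))
```

For large parameter counts the loop becomes the bottleneck. Stochastic estimators that probe the Hessian with random vectors trade exactness for fewer products, but they are outside the scope of this sketch.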
Hi! Thanks a lot for building this awesome functorch!
I have the same issue. I'm using fmodel, params, buffers = combine_state_for_ensemble(models) to stack models, and I optimize the params in a training loop. After this, I wish to update each original model's state_dict(). I can't find a nice way to achieve this. What I am actually doing is
with torch.no_grad():
    for idx, model in enumerate(models):
        for i, param in enumerate(model.parameters()):
            param.set_(params[i][idx])
I hope there is a nicer way to achieve this, ideally with a good tutorial. Thanks!
@kxhit thank you for your feedback. Could you give a little more context about why you want to update each original model's state_dict?
@zou3519 Hi, thanks for your quick reply.
In my case, I'm training many tiny networks and need to use the up-to-date weights of each network every few steps. So I need to assign the batched weights back to the original models frequently.
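A possible shape for this write-back step is sketched below. It assumes PyTorch 2.0 or later and uses torch.func.stack_module_state, the torch.func counterpart of combine_state_for_ensemble, which returns dicts keyed by parameter name instead of tuples. The helper write_back is a hypothetical name, and the add_ call only stands in for a real training loop.

```python
import torch
from torch import nn
from torch.func import stack_module_state

torch.manual_seed(0)
models = [nn.Linear(3, 5) for _ in range(4)]

# Each stacked tensor has a new leading dimension of size len(models).
params, buffers = stack_module_state(models)

# Stand-in for a training loop that updates the stacked tensors.
with torch.no_grad():
    for stacked in params.values():
        stacked.add_(1.0)

def write_back(models, params):
    # Copy slice idx of every stacked tensor into the parameter of the same name.
    with torch.no_grad():
        for idx, model in enumerate(models):
            for name, param in model.named_parameters():
                param.copy_(params[name][idx])

write_back(models, params)
print(torch.equal(models[2].weight, params["weight"][2]))
```

Using copy_ rather than set_ keeps each model's parameter in its own storage, so later in-place updates to the stacked tensors do not silently change the individual models. If sharing storage is what you want, set_ as in the snippet above does that instead.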