How to update the original model parameters after calling make_functional?

Open trenta3 opened this issue 3 years ago • 10 comments

As per the title, I find that updating the tensors pointed to by the params returned by make_functional does not update the real parameters in the original model. Is there a way to do this? It would be extremely useful for implementing optimization algorithms in a way that is closer to their mathematical description.

To provide more context, here is an example script of what standard gradient descent would look like in this style:

import torch
from torch import nn
from functorch import make_functional

learning_rate = 0.1

def optstep(params, jacobians):
    with torch.no_grad():
        for i, param in enumerate(params):
            param.add_(jacobians[i], alpha=-learning_rate)

if __name__ == '__main__':
    model = nn.Linear(3, 5)
    x, targets = torch.randn(2, 3), torch.randn(2, 5)
    criterion = nn.MSELoss()

    print("INITIAL LOSS:", criterion(model(x), targets).item())
    # Render the model functional and compute the jacobian
    func_model, params = make_functional(model)
    def f(*params):
        out = func_model(params, x)
        return criterion(out, targets)
    jacobian = torch.autograd.functional.jacobian(f, params)

    # Ideally would train on the current input
    optstep(params, jacobian)
    # Now compute the new loss
    print("NEW LOSS:", criterion(model(x), targets).item())

Executing the script shows that the parameters are not updated, since the loss doesn't change:

INITIAL LOSS: 1.2894147634506226
NEW LOSS: 1.2894147634506226

trenta3 avatar Nov 19 '21 08:11 trenta3

After looking a bit at the source code, I've found functorch._src.make_functional.extract_weights and load_weights, which let me do exactly what I wanted. Maybe those methods could be exposed and documented to support this use case?
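
In the meantime, a copy-back helper built only on public APIs does the job. This is just a minimal sketch (the name write_back is mine), and it assumes that model.parameters() and the params from make_functional line up one-to-one, which is confirmed below:

import torch
from torch import nn
from functorch import make_functional

def write_back(model, new_params):
    # Copy each (possibly updated) functional parameter back into the
    # corresponding parameter of the original module, in place.
    with torch.no_grad():
        for p_model, p_new in zip(model.parameters(), new_params):
            p_model.copy_(p_new)

model = nn.Linear(3, 5)
func_model, params = make_functional(model)
updated = tuple(p - 0.1 for p in params)  # stand-in for a real optimizer step
write_back(model, updated)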

trenta3 avatar Nov 19 '21 09:11 trenta3

Couldn't you do

def optstep(model, jacobians):
    with torch.no_grad():
        for i, param in enumerate(model.parameters()):
            param.add_(jacobians[i], alpha=-learning_rate)

?

(Also, you might want to try functorch.jacrev instead of torch.autograd.functional.jacobian -- it may be faster)
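
Roughly, the original script would change like this under both suggestions. This is only a sketch; in particular, calling jacrev on a function that takes the whole params tuple reflects my reading of the API (functorch.grad would work equally well for a scalar loss):

import torch
from torch import nn
from functorch import make_functional, jacrev

learning_rate = 0.1

def optstep(model, grads):
    # Update the original module's parameters in place; the order of
    # model.parameters() matches the params returned by make_functional.
    with torch.no_grad():
        for param, g in zip(model.parameters(), grads):
            param.add_(g, alpha=-learning_rate)

model = nn.Linear(3, 5)
x, targets = torch.randn(2, 3), torch.randn(2, 5)
criterion = nn.MSELoss()

func_model, params = make_functional(model)

def loss_fn(params):
    return criterion(func_model(params, x), targets)

# For a scalar-valued loss, jacrev gives the gradient w.r.t. each parameter tensor.
grads = jacrev(loss_fn)(params)
optstep(model, grads)
print("NEW LOSS:", criterion(model(x), targets).item())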

zou3519 avatar Nov 19 '21 22:11 zou3519

Is model.parameters() guaranteed to return parameters in the same order as make_functional?

If this is the case then I can surely do this; however, I would like to ask that it be documented as proper behaviour that one can rely on.

Thank you very much

trenta3 avatar Nov 20 '21 10:11 trenta3

Is model.parameters() guaranteed to return parameters in the same order as make_functional?

Yes

If this is the case then I can surely do this; however, I would like to ask that it be documented as proper behaviour that one can rely on.

Yes, we should document this
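
Until that lands in the docs, a quick sanity check can guard against the assumption silently breaking; this is just a suggestion, relying on the original module keeping its parameters after make_functional:

import torch
from torch import nn
from functorch import make_functional

model = nn.Linear(3, 5)
func_model, params = make_functional(model)

# The i-th functional parameter should correspond to the i-th module parameter.
for p_func, p_mod in zip(params, model.parameters()):
    assert p_func.shape == p_mod.shape
    assert torch.equal(p_func, p_mod)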

zou3519 avatar Nov 22 '21 02:11 zou3519

Thank you very much again for all this work. I think the issue can be closed as soon as the behaviour is documented.

trenta3 avatar Nov 22 '21 08:11 trenta3

@trenta3 out of curiosity, what are you using make_functional for? Are you using any of the other functorch APIs?

zou3519 avatar Nov 29 '21 20:11 zou3519

I'm currently using make_functional as well as other functorch APIs, in particular jvp and jacrev, to write more complex optimizers that also need second-order information about a neural network, which is unmanageable to do in plain PyTorch. Earlier this year I needed to compute eigenvectors of the linearizations of some neural networks, and the ability to obtain gradients for each example separately was crucial.
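
For reference, the per-example gradient pattern I mean looks roughly like this (the layer, shapes, and MSE loss are placeholders of my own choosing):

import torch
from torch import nn
from functorch import make_functional, vmap, grad

model = nn.Linear(3, 5)
func_model, params = make_functional(model)
x, targets = torch.randn(8, 3), torch.randn(8, 5)

def sample_loss(params, sample, target):
    # Treat a single example as a batch of one so the module's shapes work out.
    out = func_model(params, sample.unsqueeze(0)).squeeze(0)
    return nn.functional.mse_loss(out, target)

# grad differentiates w.r.t. the first argument (params);
# vmap maps the computation over the leading dimension of x and targets.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, targets)
# Each entry of per_sample_grads now has a leading dimension of 8: one gradient per example.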

If I had to name something, one thing I miss is the ability to "lazily" compute parts of the Hessian, like extracting its diagonal, without the full memory (and compute) cost of calculating the whole Hessian. More generally, the ability for a PyTorch user to manipulate "lazy tensors" (i.e. a thunk of computation depending on some data, but not eagerly executed) would be extremely useful for computing the diagonal of the Hessian, as well as for many computations on kernel methods (as pyKeops does), but I sincerely don't know how efficient this can be made.
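
Until something like that exists, one workaround is to estimate the diagonal from Hessian-vector products alone, e.g. with Hutchinson-style probing. The sketch below uses plain autograd; the estimator, names, and probe count are my own choices, not a functorch API:

import torch
from torch import nn

def hessian_diag_estimate(loss, params, n_probes=16):
    # Estimate diag(H) of `loss` w.r.t. `params` as E[v * (H v)] over Rademacher v,
    # using only Hessian-vector products (the full Hessian is never materialized).
    params = tuple(params)
    grads = torch.autograd.grad(loss, params, create_graph=True)
    estimate = [torch.zeros_like(p) for p in params]
    for _ in range(n_probes):
        vs = [torch.randint_like(p, 0, 2) * 2 - 1 for p in params]  # +/-1 probes
        hvps = torch.autograd.grad(grads, params, grad_outputs=vs, retain_graph=True)
        for est, v, hvp in zip(estimate, vs, hvps):
            est.add_(v * hvp, alpha=1.0 / n_probes)
    return estimate

model = nn.Linear(3, 5)
x, targets = torch.randn(8, 3), torch.randn(8, 5)
loss = nn.functional.mse_loss(model(x), targets)
diag = hessian_diag_estimate(loss, model.parameters())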

trenta3 avatar Nov 30 '21 20:11 trenta3

Hi! Thanks a lot for building this awesome functorch!

I have the same issue. I'm using fmodel, params, buffers = combine_state_for_ensemble(models) to stack models and I'm optimizing the params in a training loop. After this, I wish to update each original model's state_dict(). I can't find a nice way to achieve this. What I am actually doing is

with torch.no_grad():
    for idx, model in enumerate(models):
        for i, param in enumerate(model.parameters()):
            param.set_(params[i][idx])
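
A copy_-based variant of the same loop, as a sketch that spells out the shape assumption (each params[i] stacks the i-th parameter of all models along dimension 0):

with torch.no_grad():
    for idx, model in enumerate(models):
        for model_param, stacked in zip(model.parameters(), params):
            # stacked has shape [num_models, *param_shape]; take this model's slice.
            model_param.copy_(stacked[idx])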

I hope there is a nicer way to achieve this, ideally with a good tutorial. Thanks!

kxhit avatar Apr 13 '22 15:04 kxhit

@kxhit thank you for your feedback. Could you give a little more context about why you want to update each original model's state_dict?

zou3519 avatar Apr 13 '22 21:04 zou3519

@zou3519 Hi, thanks for your quick reply.

In my case, I'm training many tiny networks and need the up-to-date network weights every few steps, so I need to assign the batched weights back to the original models frequently.

kxhit avatar Apr 13 '22 22:04 kxhit