
[POC] functorch integration

Open vmoens opened this issue 2 years ago • 5 comments

This is a proof of concept of integrating functorch into MAML.

vmoens avatar Apr 13 '22 14:04 vmoens

I see your point @Benjamin-eecs. I see little advantage in the current API, though, since users still need to save the params in a state_dict and pass it back to their model afterwards. Hence I don't really see how it hides away the magic of meta-learning algorithms. Can you elaborate on what the advantage is, in your opinion?

Are you aware that functorch is now part of torch core? As such you don't need an extra dependency.

vmoens avatar Aug 11 '22 14:08 vmoens

@vmoens Hi Vincent, have you checked our low-level API? You can find information in our README and in our docs: https://torchopt.readthedocs.io/en/latest/api/api.html#functional-optimizers. The extract_state_dict API is mainly for the PyTorch-like (high-level) API. For instance, when conducting iterative multi-task training for MAML, I need to reset the neural network parameters to their initial values between tasks, hence the need for such an API.
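
For reference, a minimal sketch of the low-level functional loop the docs describe (names such as torchopt.adam, optimizer.init/update and torchopt.apply_updates are taken from the linked documentation; net, batch, target and loss_fn are placeholders):

        import functorch
        import torch
        import torchopt

        fmodel, params = functorch.make_functional(net)    # stateless model + parameter tuple
        optimizer = torchopt.adam(lr=1e-3)                  # functional optimizer
        opt_state = optimizer.init(params)                   # optimizer state lives outside the module

        loss = loss_fn(fmodel(params, batch), target)
        grads = torch.autograd.grad(loss, params)            # explicit gradients w.r.t. the params
        updates, opt_state = optimizer.update(grads, opt_state)
        params = torchopt.apply_updates(params, updates)     # returns the updated parameters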

waterhorse1 avatar Aug 12 '22 01:08 waterhorse1

Thanks for this @waterhorse1. I'm familiar with how the low- and high-level APIs work. I'm just trying to point out that, from a user perspective, the following two code snippets require the same mental effort, the same low-level understanding of what a meta-learning algorithm is and does, and the same amount of coding (the number of lines of code is identical).

Example 1 (current)

        policy_state_dict = torchopt.extract_state_dict(policy)
        for idx in range(TASK_NUM):
            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], policy)
                inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
                inner_opt.step(inner_loss)
            post_trajs = sample_traj(env, tasks[idx], policy)
            outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
            outer_loss.backward()
            torchopt.recover_state_dict(policy, policy_state_dict)

Example 2 (functorch)

        fpolicy, policy_params = functorch.make_functional(policy)
        for idx in range(TASK_NUM):
            policy_params_new = policy_params
            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], fpolicy, policy_params_new)
                inner_loss = a2c_loss(pre_trajs, fpolicy, policy_params_new, value_coef=0.5)
                policy_params_new = inner_opt.step(inner_loss, policy_params_new)
            post_trajs = sample_traj(env, tasks[idx], fpolicy, policy_params_new)
            outer_loss = a2c_loss(post_trajs, fpolicy, policy_params_new, value_coef=0.5)
            outer_loss.backward()
#            not necessary since the policy parameters do not need to be reset
#            torchopt.recover_state_dict(policy, policy_state_dict)

It is my personal opinion that functorch offers more clarity about what is happening: with the current API, trying to make it look as if everything were just another PyTorch optimizer may push users to overlook extract_state_dict and recover_state_dict, or worse, use them where they should not. I do not think using functorch makes things less clear; on the contrary. From an OOP perspective, the current API shows the user the same object (e.g. the policy) twice: once with a set of regular parameters, and once with tensors that are no longer parameters but the result of some optimization. I personally find it confusing, as I am looking at the same object but with different content, and it is not apparent that the objects that once were its attributes aren't anymore:

        policy_state_dict = torchopt.extract_state_dict(policy)  # <-- here policy has nn.Parameter attributes
        for idx in range(TASK_NUM):
            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], policy)
                inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
                inner_opt.step(inner_loss)  # <-- after this, policy no longer has nn.Parameters
            post_trajs = sample_traj(env, tasks[idx], policy)
            outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
            outer_loss.backward()
            torchopt.recover_state_dict(policy, policy_state_dict)  # <-- after this, policy has nn.Parameter attributes again

Basically, we're showing the user one thing that is really two, and this may lead users to expect behaviours that will not hold in practice (e.g. what should parameters() return? In "normal" PyTorch this iterator always returns the very same list of items).
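
To make that expectation concrete: with a regular torch.optim optimizer, parameters() keeps returning the very same Parameter objects because step() updates them in place. A tiny standalone sketch:

        import torch

        model = torch.nn.Linear(2, 1)
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        before = list(model.parameters())

        model(torch.randn(4, 2)).sum().backward()
        opt.step()                                           # updates the leaf Parameters in place

        after = list(model.parameters())
        assert all(a is b for a, b in zip(before, after))    # same objects, new values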

vmoens avatar Aug 12 '22 08:08 vmoens

@vmoens Hi Vincent, thanks for your advice. We had a discussion about what you mentioned, and here is what we concluded:

For the functional high-level API, we can easily do what your snippet does by building a wrapper around some of the low-level APIs (including torch.autograd.grad, optimizer.update and apply_updates). Bo is working on that in FuncOptimizer.
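
Roughly, such a wrapper could look like the sketch below (hypothetical: FuncOptimizer is still work in progress at this point, so the class name, signature and defaults are illustrative only, mirroring the inner_opt.step(loss, params) call in the functorch snippet above):

        import torch
        import torchopt

        class FunctionalStep:
            """Hypothetical wrapper around the low-level functional pieces."""

            def __init__(self, impl):
                self.impl = impl             # e.g. torchopt.adam(lr=1e-3)
                self.state = None

            def step(self, loss, params):
                if self.state is None:
                    self.state = self.impl.init(params)
                # keep the graph so the outer loss can differentiate through the update
                grads = torch.autograd.grad(loss, params, create_graph=True)
                updates, self.state = self.impl.update(grads, self.state)
                return torchopt.apply_updates(params, updates)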

For the OOP API, it's impossible to keep the parameters within the inner-loop process as nn.Parameters, because they are non-leaf nodes. We can offer an alternative solution by wrapping the tensor so that it behaves like an nn.Parameter: it is still a tensor, but you can treat it as an nn.Parameter.
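
A tiny plain-PyTorch illustration of the non-leaf issue (no torchopt involved):

        import torch

        w = torch.nn.Parameter(torch.randn(3))                 # leaf parameter
        loss = (w ** 2).sum()
        (grad,) = torch.autograd.grad(loss, w, create_graph=True)
        w_new = w - 0.1 * grad                                  # differentiable update -> non-leaf tensor

        print(w_new.is_leaf)                                    # False
        print(isinstance(w_new, torch.nn.Parameter))            # False
        # re-wrapping w_new in nn.Parameter would create a fresh leaf and cut the graph back to w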

waterhorse1 avatar Aug 12 '22 16:08 waterhorse1

Codecov Report

Merging #6 (fa2a38c) into main (d8f90cc) will increase coverage by 0.46%. The diff coverage is 86.20%.

@@            Coverage Diff             @@
##             main       #6      +/-   ##
==========================================
+ Coverage   70.16%   70.63%   +0.46%     
==========================================
  Files          31       33       +2     
  Lines        1391     1420      +29     
==========================================
+ Hits          976     1003      +27     
- Misses        415      417       +2     
Flag        Coverage Δ
unittests   70.63% <86.20%> (+0.46%) ↑

Flags with carried forward coverage won't be shown.

Impacted Files                             Coverage Δ
torchopt/_src/alias.py                      77.67% <ø> (+0.89%) ↑
torchopt/_src/optimizer/meta/base.py        31.42% <0.00%> (-1.91%) ↓
torchopt/_src/optimizer/base.py             84.61% <50.00%> (-1.39%) ↓
torchopt/_src/optimizer/func/base.py        95.45% <95.45%> (ø)
torchopt/__init__.py                       100.00% <100.00%> (ø)
torchopt/_src/optimizer/__init__.py        100.00% <100.00%> (ø)
torchopt/_src/optimizer/func/__init__.py   100.00% <100.00%> (ø)
torchopt/_src/transform.py                  81.01% <0.00%> (+0.31%) ↑


codecov-commenter avatar Aug 14 '22 09:08 codecov-commenter

> @vmoens Hi Vincent, thanks for your advice. We had a discussion about what you mentioned, and here is what we concluded:
>
> For the functional high-level API, we can easily do what your snippet does by building a wrapper around some of the low-level APIs (including torch.autograd.grad, optimizer.update and apply_updates). Bo is working on that in FuncOptimizer.
>
> For the OOP API, it's impossible to keep the parameters within the inner-loop process as nn.Parameters, because they are non-leaf nodes. We can offer an alternative solution by wrapping the tensor so that it behaves like an nn.Parameter: it is still a tensor, but you can treat it as an nn.Parameter.

Sure, this is what I had in mind! Obviously we can't work with non-leaf nn.Parameters :-) Great work guys, I love it.

vmoens avatar Sep 09 '22 20:09 vmoens