torchopt
[POC] functorch integration
This is a proof of concept of integrating functorch in MAML.
I see your point @Benjamin-eecs. I see little advantage of the current API though, since users still need to save the params in a state_dict and pass them back to their model afterwards. Hence I don't really see how it hides away the magic of meta-learning algos. Can you elaborate on what the advantage is in your opinion?
Are you aware that functorch is now part of torch core? As such you don't need an extra dependency.
@vmoens Hi Vincent, have you checked our low-level API? You can find information in our README and also our doc https://torchopt.readthedocs.io/en/latest/api/api.html#functional-optimizers. The extract_state_dict API is mainly for the PyTorch-like API (which is the high-level API). For instance, when conducting iterative multi-task training for MAML, I need to reset my neural network parameters to the initial ones; that is why we need such an API.
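For reference, the low-level usage looks roughly like this, assuming the optax-style init/update/apply_updates interface from the linked docs; the toy linear model, data and learning rate are placeholders, and the exact signatures are in the documentation:

import torch
import torchopt

net = torch.nn.Linear(4, 1)
params = tuple(net.parameters())

optimizer = torchopt.sgd(0.1)        # functional optimizer: its state lives outside the module
opt_state = optimizer.init(params)

loss = net(torch.randn(8, 4)).pow(2).mean()
grads = torch.autograd.grad(loss, params)
updates, opt_state = optimizer.update(grads, opt_state)
params = torchopt.apply_updates(params, updates)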
Thanks for this @waterhorse1. I'm familiar with how the low- and high-level APIs work. I'm just trying to point out that, from a user perspective, the following two code snippets require the same mental effort, the same low-level understanding of what a meta-learning algorithm is and does, and the same amount of coding (the number of lines of code is identical).
Example 1 (current)
policy_state_dict = torchopt.extract_state_dict(policy)
for idx in range(TASK_NUM):
    for _ in range(inner_iters):
        pre_trajs = sample_traj(env, tasks[idx], policy)
        inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
        inner_opt.step(inner_loss)
    post_trajs = sample_traj(env, tasks[idx], policy)
    outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
    outer_loss.backward()
    torchopt.recover_state_dict(policy, policy_state_dict)
Example 2 (functorch)
fpolicy, policy_params = functorch.make_functional(policy)
for idx in range(TASK_NUM):
    policy_params_new = policy_params
    for _ in range(inner_iters):
        pre_trajs = sample_traj(env, tasks[idx], fpolicy, policy_params_new)
        inner_loss = a2c_loss(pre_trajs, fpolicy, policy_params_new, value_coef=0.5)
        policy_params_new = inner_opt.step(inner_loss, policy_params_new)
    post_trajs = sample_traj(env, tasks[idx], fpolicy, policy_params_new)
    outer_loss = a2c_loss(post_trajs, fpolicy, policy_params_new, value_coef=0.5)
    outer_loss.backward()
    # not necessary since the policy parameters need not be reset
    # torchopt.recover_state_dict(policy, policy_state_dict)
It is my personal opinion that functorch offers more clarity on what is happening: with the current API, trying to make it look "as if" everything behaved like any other pytorch optimizer may push users to overlook extract_state_dict and recover_state_dict, or worse, use them where they should not. I do not think using functorch makes it less clear; on the contrary.
From an OOP perspective, the current API shows the user the same object (e.g. the policy) twice: once with a set of regular parameters, and once with tensors that are no longer parameters but the result of some optimization. I personally find it confusing, as I am looking at the same object but with different content, and it is not apparent that the objects that once were attributes of it aren't anymore:
policy_state_dict = torchopt.extract_state_dict(policy)  # << HERE POLICY HAS nn.Parameter ATTRIBUTES
for idx in range(TASK_NUM):
    for _ in range(inner_iters):
        pre_trajs = sample_traj(env, tasks[idx], policy)
        inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
        inner_opt.step(inner_loss)  # << AFTER THIS, POLICY NO LONGER HAS nn.Parameter ATTRIBUTES
    post_trajs = sample_traj(env, tasks[idx], policy)
    outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
    outer_loss.backward()
    torchopt.recover_state_dict(policy, policy_state_dict)  # << AFTER THIS, POLICY HAS nn.Parameter ATTRIBUTES AGAIN
Basically, we're showing the user one thing that is not one thing but two, and this may lead users to expect behaviours that are not going to work in practice (e.g. what should parameters() return? In "normal" pytorch this iterator will always return the very same list of items).
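As a quick standalone check of that last claim (plain PyTorch only, nothing torchopt-specific): the iterator returns the very same nn.Parameter objects before and after an optimizer step, because vanilla optimizers update parameters in place.

import torch
from torch import nn

model = nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

before = [id(p) for p in model.parameters()]
model(torch.randn(4, 2)).sum().backward()
opt.step()
after = [id(p) for p in model.parameters()]

print(before == after)  # True: the same objects, updated in place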
@vmoens Hi Vincent, thanks for your advice. We had a discussion about what you mentioned and here is what we got:
For the functional high-level API, we can easily do what your snippet shows by building a wrapper around some low-level APIs (including torch.autograd.grad, optimizer.update and apply_updates). Bo is working on that in FuncOptimizer; a rough sketch of the idea follows the next point.
For the OOP API, it's impossible to keep the parameters within the inner-loop process as nn.Parameters because they are non-leaf nodes. We can offer an alternative solution by wrapping the tensor so it behaves like an nn.Parameter: it's still a tensor, but you can treat it as an nn.Parameter.
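For illustration only, a rough sketch of what such a wrapper around torch.autograd.grad, optimizer.update and apply_updates could look like. The class name FunctionalOptimizerSketch is hypothetical, and the inplace keyword arguments are assumed from the torchopt functional API; the actual FuncOptimizer may look different.

import torch
import torchopt

class FunctionalOptimizerSketch:  # hypothetical wrapper, not the real torchopt FuncOptimizer
    def __init__(self, impl):
        self.impl = impl          # a torchopt functional optimizer, e.g. torchopt.adam(1e-3)
        self.opt_state = None

    def step(self, loss, params):
        if self.opt_state is None:
            self.opt_state = self.impl.init(params)
        # create_graph=True keeps the inner update differentiable for the outer loss
        grads = torch.autograd.grad(loss, params, create_graph=True)
        updates, self.opt_state = self.impl.update(grads, self.opt_state, inplace=False)
        # out-of-place apply returns new, non-leaf tensors (hence they cannot be nn.Parameters)
        return torchopt.apply_updates(params, updates, inplace=False)

With a wrapper like this, the inner-loop line in Example 2, policy_params_new = inner_opt.step(inner_loss, policy_params_new), keeps the same shape while the optimizer-state bookkeeping stays hidden.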
Codecov Report
Merging #6 (fa2a38c) into main (d8f90cc) will increase coverage by 0.46%. The diff coverage is 86.20%.
@@ Coverage Diff @@
## main #6 +/- ##
==========================================
+ Coverage 70.16% 70.63% +0.46%
==========================================
Files 31 33 +2
Lines 1391 1420 +29
==========================================
+ Hits 976 1003 +27
- Misses 415 417 +2
Flag | Coverage Δ | |
---|---|---|
unittests | 70.63% <86.20%> (+0.46%) | :arrow_up: |
Flags with carried forward coverage won't be shown.
Impacted Files | Coverage Δ | |
---|---|---|
torchopt/_src/alias.py | 77.67% <ø> (+0.89%) | :arrow_up: |
torchopt/_src/optimizer/meta/base.py | 31.42% <0.00%> (-1.91%) | :arrow_down: |
torchopt/_src/optimizer/base.py | 84.61% <50.00%> (-1.39%) | :arrow_down: |
torchopt/_src/optimizer/func/base.py | 95.45% <95.45%> (ø) | |
torchopt/__init__.py | 100.00% <100.00%> (ø) | |
torchopt/_src/optimizer/__init__.py | 100.00% <100.00%> (ø) | |
torchopt/_src/optimizer/func/__init__.py | 100.00% <100.00%> (ø) | |
torchopt/_src/transform.py | 81.01% <0.00%> (+0.31%) | :arrow_up: |
Sure, this is what I had in mind! Obviously we can't work with non-leaf nn.Parameters :-) Great work guys, I love it.