
Create an example for documentation for freezing parameters

Open rosshemsley opened this issue 3 years ago • 14 comments

Users often ask how to freeze parameters (https://github.com/deepmind/optax/issues/290).

It would be nice to add a runnable example to the documentation under Examples (next to the meta-learning example) that shows how to do this.

rosshemsley avatar Feb 03 '22 10:02 rosshemsley

@rosshemsley Do you have a specific "setting" of interest? Or do you want just the most basic implementation possible?

pharringtonp19 avatar Feb 03 '22 14:02 pharringtonp19

I suspect there's more than one valid way to freeze params, and we could potentially show each option in a notebook.

@mkunesch also gave some ideas over in https://github.com/deepmind/optax/issues/290.

rosshemsley avatar Feb 03 '22 15:02 rosshemsley

@rosshemsley As a starting point, would it be helpful to extend the meta-learning example to showcase how parameters can be frozen during the inner loop? I would be happy to write something up this weekend.

pharringtonp19 avatar Feb 04 '22 00:02 pharringtonp19

Hey @pharringtonp19, that would be welcome!

Sounds good! We could use the same trick as the meta-learning example - making the example a notebook - but for the actual code it might be good to use the simplest possible model, so that users who are just learning about deep learning can follow the example without having to worry about the meta-learning part.

(maybe one of the examples from https://github.com/deepmind/optax/tree/master/examples could be made to work?)

I don't know if anyone has tried building the docs outside of the main contributors team, so do let us know if you encounter any difficulties!

rosshemsley avatar Feb 04 '22 08:02 rosshemsley

@rosshemsley I admit, I am not entirely sure what you mean by

building the docs outside of the main contributors team

I have never contributed to a repository before, so I am eager to see how this works. I'll start (on Thursday) with a small colab example, and then go from there (if that makes sense) -- apologies for the delay

pharringtonp19 avatar Feb 08 '22 21:02 pharringtonp19

@rosshemsley - If you think the following is suitable, I would be happy to write up a more detailed tutorial.

Here is a colab notebook where we fit a linear model via maximum likelihood. In this notebook, we show how to freeze parameters by applying the set_to_zero gradient transformation to the nuisance parameter sigma. As expected, we recover the least squares estimate for beta, since for a linear model with Gaussian errors the MLE coincides with the least squares solution.

-- The inspiration for this notebook was this very nice example of Jax + MLE
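
For reference, a minimal sketch of the pattern described above (not the notebook's exact code; the data, learning rate, and loss details here are placeholders):

import jax
import jax.numpy as jnp
import optax

# Placeholder data: y is linear in x plus a constant offset.
x = jnp.linspace(0.0, 1.0, 100)[:, None]
y = 2.0 * x[:, 0] + 0.1

# beta is the coefficient of interest; sigma is the nuisance scale parameter.
params = {'beta': jnp.zeros(1), 'sigma': jnp.ones(())}

def negative_log_likelihood(params, x, y):
    mu = x @ params['beta']
    sigma = params['sigma']
    return jnp.sum(jnp.log(sigma) + 0.5 * ((y - mu) / sigma) ** 2)

# Freeze sigma by routing its updates through set_to_zero.
optimizer = optax.multi_transform(
    {'train': optax.adam(1e-2), 'freeze': optax.set_to_zero()},
    {'beta': 'train', 'sigma': 'freeze'},
)
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state):
    grads = jax.grad(negative_log_likelihood)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

With sigma held fixed, minimizing the negative log-likelihood is equivalent to minimizing the sum of squared residuals, which is why the recovered beta matches least squares.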

pharringtonp19 avatar Feb 10 '22 19:02 pharringtonp19

Hey @pharringtonp19 , sorry for dropping your previous message!

@rosshemsley I admit, I am not entirely sure what you mean

Sorry, it indeed wasn't that clear! - In the Optax docs, we use the "trick" that you can convert a notebook into a documentation page - the code and content are nicely formatted into a webpage, see e.g. https://optax.readthedocs.io/en/latest/optax-101.html, which is generated automatically from a colab.

There's a command for running this process locally to check it works and visualize the result, but I don't think anyone outside of Alphabet has run this before, so it may not be that obvious how to get it working.

Does that help? (One implication is that it can make sense to add a lot more text into the example, so that readers can "read along" with the example a bit more).

Thanks for the example! I will take a look now, and forward it along to the other optax maintainers.

rosshemsley avatar Feb 11 '22 10:02 rosshemsley

Hi! I think the example looks great in terms of demonstrating how to use multi_transform with set_to_zero.

In this particular example, since sigma is never optimized, it might be worth pointing out that we could also use jax.lax.stop_gradient (or just not include sigma in params). Or we could add an example where the freezing only happens for certain steps (e.g. the first N), but that could be an additional example we add in the future. What do you think?
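
For concreteness, the stop_gradient route might look like this (a sketch with placeholder names, not code from the notebook):

import jax
import jax.numpy as jnp

# Keep sigma inside params, but stop its gradient so the gradient with
# respect to sigma is always zero and the optimizer never moves it.
def negative_log_likelihood(params, x, y):
    beta = params['beta']
    sigma = jax.lax.stop_gradient(params['sigma'])  # treated as a constant
    residuals = y - x @ beta
    return jnp.sum(jnp.log(sigma) + 0.5 * (residuals / sigma) ** 2)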

mkunesch avatar Feb 14 '22 11:02 mkunesch

@mkunesch I think my example is a rather poor one for the reasons that you point out -- the simplest thing to do would be not to include sigma in params.

multi_transform + set_to_zero "shines" when we are given a parameterized function + training procedure for the entire array of params, but we are only interested in updating/training a subset of these params.

If we have more direct control over the parameterized function + training procedure, then as you suggest, there are more "direct" (conceptually simpler) ways to train over the parameters of interest.
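
To make that concrete, here is a hypothetical sketch (the 'backbone'/'head' names and the learning rate are made up) of using a label function so an existing training procedure can keep passing the full params tree while only the head is trained:

import jax
import optax

# Hypothetical params pytree: {'backbone': {...}, 'head': {...}}.
# The label function returns a label for every leaf, so it works for
# arbitrarily nested subtrees.
def label_fn(params):
    return {
        'head': jax.tree_util.tree_map(lambda _: 'train', params['head']),
        'backbone': jax.tree_util.tree_map(lambda _: 'freeze', params['backbone']),
    }

optimizer = optax.multi_transform(
    {'train': optax.adam(1e-3), 'freeze': optax.set_to_zero()},
    label_fn,
)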

pharringtonp19 avatar Feb 14 '22 14:02 pharringtonp19

Or we could add an example where the freezing only happens for certain steps (e.g. the first N)

@mkunesch I'm trying to implement an optimizer where some parameters are frozen for the first N steps. Could you give me a hint/example of how to achieve this behavior? So far I have mainly been thinking of using a zero/small learning rate for the frozen parameters.

trologat avatar Apr 09 '22 15:04 trologat

Hi, thanks for the discussions above.

I myself wanted to freeze params in one of my projects, and referred to several issues (https://arc.net/e/A4919E55-946B-483F-9274-15AD5F00D6B8) to understand multi_transform and set_to_zero.

I actually missed @pharringtonp19's colab example above (until I wanted to add this comment), and I'm not sure whether it's been added to the docs already.

I also made a notebook of my own 😅 with an example where:

  • At the start, param_2 is fixed; param_1 is updated during the first 500 steps and then frozen.
  • param_2 is unfrozen at step 500 and optimized until step 1000.

Notebook : https://colab.research.google.com/drive/1TJh-OSk5cqLWoHaVGAhPDWJERbtRS0b_?usp=sharing
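
For context, a rough sketch of one way this kind of step-dependent freezing can be expressed with stock optax transformations (this is not necessarily how the notebook does it; the learning rate and the 500-step switch point are placeholders):

import jax.numpy as jnp
import optax

SWITCH_STEP = 500

# Schedules that multiply the update by 1.0 while a parameter is trainable
# and by 0.0 while it is frozen.
train_then_freeze = lambda step: jnp.where(step < SWITCH_STEP, 1.0, 0.0)
freeze_then_train = lambda step: jnp.where(step < SWITCH_STEP, 0.0, 1.0)

optimizer = optax.multi_transform(
    {
        'first_half': optax.chain(
            optax.adam(1e-2), optax.scale_by_schedule(train_then_freeze)),
        'second_half': optax.chain(
            optax.adam(1e-2), optax.scale_by_schedule(freeze_then_train)),
    },
    {'param_1': 'first_half', 'param_2': 'second_half'},
)

One caveat of scaling the updates after Adam: Adam's moment estimates keep accumulating for a frozen parameter even though its final update is zeroed, which may or may not be the behavior you want.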

If the example is of help, would love to add more text to explain and make it documentation ready for a PR.

ramithuh avatar Feb 23 '23 19:02 ramithuh

+1 for this issue. I also struggled with this, and I feel the documentation should be improved with examples. After searching, here are two ways to freeze parameters:

Assuming:

import jax.numpy as jnp
import optax

params = {
    'param0': jnp.array([10.0]),
    'param1': jnp.array([10.0]),
}

Using optax.masked:

mask = {
    'param0': False,
    'param1': True,  # Param 1 will be frozen
}

optimizer = optax.chain(
    optax.adam(learning_rate=1e-3),  # optax.adam requires a learning rate
    optax.masked(optax.set_to_zero(), mask),
)

Using optax.multi_transform:

optimizer = optax.multi_transform(
    {
        'train': optax.adam(learning_rate=1e-3),  # optax.adam requires a learning rate
        'freeze': optax.set_to_zero(),
    },
    {
        'param0': 'train',
        'param1': 'freeze',  # Param 1 will be frozen
    },
)
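
A quick sanity check that either variant really freezes param1 (continuing from the definitions above; the gradients are just placeholders):

import jax.numpy as jnp
import optax

opt_state = optimizer.init(params)
grads = {'param0': jnp.array([1.0]), 'param1': jnp.array([1.0])}
updates, opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
# new_params['param1'] equals params['param1'] because its update was zeroed;
# new_params['param0'] has moved.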

I personally feel both of these options involve too much boilerplate for such a common operation. I think optax would benefit from a higher-level abstraction. Something like:

optimizer = optax.apply_and_freeze(  # Not sure what's the best name though.
    optax.adam(),
    mask,
)
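
(For what it's worth, a small helper along these lines can already be sketched with existing primitives, using the hypothetical name from the proposal above:)

import optax

def apply_and_freeze(tx, mask):
    # Hypothetical helper: apply `tx` to all params, then zero the updates of
    # the params selected by `mask` (True = frozen).
    return optax.chain(tx, optax.masked(optax.set_to_zero(), mask))

optimizer = apply_and_freeze(optax.adam(1e-3), mask)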

Conchylicultor avatar Dec 16 '24 11:12 Conchylicultor