Pytree-based Optimizers

Open cgarciae opened this issue 3 years ago • 3 comments

This topic comes to mind every once in a while. It has already been discussed extensively (e.g. https://github.com/deepmind/optax/issues/197#issuecomment-982548377), but I feel it needs new life because it could resolve the last remaining quirks in optax.

Optax optimizers have a well-defined API and, unlike neural networks, a clear way to update their state, which makes them well suited to pytree/dataclass interfaces. Similar to what @NeilGirdhar has done here, one could express a pytree version of every optimizer by wrapping functional optax, with these added benefits:

  1. The optimizer can pass through JAX's function transformation boundaries, e.g. jit.
  2. Hyper-parameters can be updated using immutable APIs like .replace().
  3. The optimizer vs opt_state separation goes away.
  4. Hyper-parameter updates can be inspected, e.g. to log the learning rate under a schedule.

Example

For this example I'll be using Flax's PyTreeNode, but any pytree implementation would work just as well.

from typing import Any, Optional, Tuple, TypeVar

import optax
from flax.struct import PyTreeNode, field
from optax import OptState, Params, ScalarOrSchedule, Updates

A = TypeVar('A', bound='SGD')

class SGD(PyTreeNode):
  learning_rate: ScalarOrSchedule
  momentum: Optional[float] = None
  nesterov: bool = False
  accumulator_dtype: Optional[Any] = field(pytree_node=False, default=None)
  opt_state: Optional[OptState] = None

  @property
  def tx(self):
    # Rebuild the functional optax optimizer from the hyper-parameter fields.
    return optax.sgd(**{k: v for k, v in vars(self).items() if k != 'opt_state'})

  def init(self: A, params: Params) -> A:
    return self.replace(opt_state=self.tx.init(params))

  def update(
    self: A, updates: Updates, params: Optional[Params] = None
  ) -> Tuple[Updates, A]:
    updates, opt_state = self.tx.update(updates, self.opt_state, params=params)
    return updates, self.replace(opt_state=opt_state)

# sample usage (params and grads are assumed to exist already)
tx = SGD(3e-4)
tx = tx.init(params)
updates, tx = tx.update(grads)
params = optax.apply_updates(params, updates)

Proposal

Given that any community shim will probably not succeed, how about an optax.pytree namespace (naming suggestions are welcome) where a shim could officially live and be discussed with the core team?

cgarciae • Oct 04 '22 03:10

@mkunesch @rosshemsley @hbq1 let us know what you think 👍

8bitmp3 • Oct 04 '22 15:10

@cgarciae Sorry for taking a while to respond! And thanks for sharing your design proposal!

There have been a few projects recently working on attaching functions to custom pytrees in JAX (e.g. equinox), and this proposal seems to share some of their ideas.

For the reasons you mentioned, this factoring can be attractive! It's worth highlighting, though, that there are some downsides to this approach, too:

  1. JAX function transformations on classes can be a source of confusion for users. Functions are generally easier to reason about under JAX transformations than class methods.
  2. Many JAX users checkpoint their states using pickle, and custom pytrees cause problems with this. Keeping the state tree as close to 'a dictionary of arrays' as possible has generally been a useful goal for many of the teams currently using optax. Furthermore, any changes to the optax API would have to support reloading existing checkpoints in a backwards-compatible way.
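
One way the pickle concern in point 2 can show up, sketched with the standard library only: pickle stores a custom object by the import path of its class, so a checkpoint stops loading once that class is renamed or moved, while a plain dict is unaffected. This is an assumption about the kind of failure meant here, and OptimizerV1 is a hypothetical class, not an optax type.

```python
import pickle
from dataclasses import dataclass


@dataclass
class OptimizerV1:  # hypothetical custom container for optimizer state
    learning_rate: float
    momentum_buffer: list


# Checkpoint the same information twice: as a custom object and as a plain dict.
blob_custom = pickle.dumps(OptimizerV1(0.1, [0.0, 0.0]))
blob_plain = pickle.dumps({"learning_rate": 0.1, "momentum_buffer": [0.0, 0.0]})

# Simulate a later refactor in which the class is renamed or removed.
del OptimizerV1

# The plain dict does not depend on any user-defined class.
restored_plain = pickle.loads(blob_plain)

try:
    pickle.loads(blob_custom)
    custom_loaded = True
except AttributeError:
    # pickle looks the class up by module and name, and it is gone.
    custom_loaded = False

print("plain dict restored:", restored_plain)
print("custom object restored:", custom_loaded)
```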

We are continuing to polish the optax API, although we are deliberately conservative about the changes we make. Part of what makes optax successful is its ruthless simplicity, and forking the API into two alternative factorings would increase the API surface area and could make it harder for us to support.

We're currently working on improving the package factoring, which will hopefully leave optax in a better place for trying out some more experimental ideas (such as this kind of API factoring), but it may be a little while before we would want to introduce big changes like this to the core library.

We'd encourage you to keep thinking about this idea though, especially with regard to 2) above. Optax has thousands of users at the moment, so charting a path forwards whilst retaining checkpoint compatibility is probably the biggest barrier we have to making these kinds of changes.

It would also be a good idea to try and "break" this design. For example, what happens when using more esoteric JAX transforms (such as vmap, pmap, pjit, or grad)? Can you break this design through unexpected jit placement? (As a rule, someone has done one of these things somewhere to all optax optimizers already.)
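
One candidate way to break the design, sketched with plain JAX rather than the proposed class: in the SGD example above, nesterov is a regular pytree field, so it would be traced when the optimizer crosses a jit boundary, and a Python `if` on a traced value fails. momentum_step below is a hypothetical stand-in for an update rule that branches on such a flag. Whether a real wrapper hits this depends on which fields are marked pytree_node=False.

```python
import jax
import jax.numpy as jnp


def momentum_step(grad, velocity, nesterov):
    """Hypothetical update rule that branches on a Python-level flag."""
    velocity = 0.9 * velocity + grad
    if nesterov:  # needs a concrete bool at trace time
        return grad + 0.9 * velocity
    return velocity


# Flag treated as static metadata: the branch is resolved while tracing.
static_step = jax.jit(momentum_step, static_argnames="nesterov")
out = static_step(jnp.ones(2), jnp.zeros(2), nesterov=True)

# Flag treated as a traced value (what a pytree leaf becomes under jit).
traced_step = jax.jit(momentum_step)
try:
    traced_step(jnp.ones(2), jnp.zeros(2), True)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False

print("static flag result:", out)
print("traced flag worked:", traced_ok)
```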

rosshemsley • Nov 21 '22 08:11

Just some thoughts:

> JAX function transformations on classes can be a source of confusion for users. Functions are generally easier to reason about under JAX transformations than class methods.

This proposal doesn't change anything for users since the interface is identical except for the four benefits mentioned. The reasoning that users have to do is exactly the same. If anything, the user reasoning is simpler since the sequence interface is not exposed.

It's unfortunate that we didn't reconcile this issue back when it was suggested in the very first Optax issue.

> It would also be a good idea to try and "break" this design - e.g. what happens when using more esoteric JAX transforms (such as vmap, pmap, pjit, or grad)

Why don't you try breaking it? I think it might make the benefits more apparent.

I also think it would be good to at least block uses of the sequence interface, which are misuses of the current optax design. This will make it easier to improve your design in the future.

NeilGirdhar • Feb 10 '23 08:02