Pytree-based Optimizers
This topic comes to mind every once in a while. It has already been discussed extensively (e.g. https://github.com/deepmind/optax/issues/197#issuecomment-982548377), but I feel it deserves new life because it could resolve the last remaining quirks in optax.
Optax optimizers have a well-defined API and, unlike neural networks, they have a clear way of updating their state, making them perfectly suitable for pytree/dataclass interfaces. Similar to what @NeilGirdhar has done here, one could express a pytree version of all optimizers by wrapping functional optax, with the added benefits:
- Optimizers can now pass through JAX's function transformation boundaries, e.g. `jit`.
- Hyper-parameters can be updated using immutable APIs like `.replace()`.
- You can get rid of the optimizer vs `opt_state` separation.
- You can now inspect hyper-parameter updates, e.g. log the learning rate under a schedule.
Example
For this example I'll be using Flax's `PyTreeNode`, but any pytree implementation would work just as well.
```python
from typing import Any, Optional, Tuple, TypeVar

import optax
from flax.struct import PyTreeNode, field
from optax import OptState, Params, ScalarOrSchedule, Updates

A = TypeVar('A', bound='SGD')


class SGD(PyTreeNode):
    learning_rate: ScalarOrSchedule
    momentum: Optional[float] = None
    nesterov: bool = False
    accumulator_dtype: Optional[Any] = field(pytree_node=False, default=None)
    opt_state: Optional[OptState] = None

    @property
    def tx(self) -> optax.GradientTransformation:
        # Rebuild the functional optimizer from the hyper-parameter fields.
        return optax.sgd(**{k: v for k, v in vars(self).items() if k != 'opt_state'})

    def init(self: A, params: Params) -> A:
        return self.replace(opt_state=self.tx.init(params))

    def update(
        self: A, updates: Updates, params: Optional[Params] = None
    ) -> Tuple[Updates, A]:
        updates, opt_state = self.tx.update(updates, self.opt_state, params=params)
        return updates, self.replace(opt_state=opt_state)


# sample usage
tx = SGD(3e-4)
tx = tx.init(params)
updates, tx = tx.update(grads)
params = optax.apply_updates(params, updates)
```
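To make the first two benefits concrete, here is a minimal sketch (toy shapes and names, assuming the `SGD` class above): the optimizer crosses the `jit` boundary as an ordinary argument, and a hyper-parameter is swapped immutably with `.replace()`:

```python
import jax
import jax.numpy as jnp
import optax

@jax.jit  # the optimizer is a pytree, so it passes straight through jit
def train_step(params, tx, batch):
    def loss_fn(p):
        return jnp.mean((p['w'] * batch['x'] - batch['y']) ** 2)
    grads = jax.grad(loss_fn)(params)
    updates, tx = tx.update(grads)
    return optax.apply_updates(params, updates), tx

params = {'w': jnp.ones(3)}
batch = {'x': jnp.ones(3), 'y': jnp.zeros(3)}
tx = SGD(3e-4).init(params)

params, tx = train_step(params, tx, batch)
# Immutable hyper-parameter update mid-training; the tree structure is
# unchanged, so this does not trigger a recompile.
tx = tx.replace(learning_rate=1e-4)
params, tx = train_step(params, tx, batch)
```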
Proposal
Given that any community shim will probably not succeed, how about an `optax.pytree` namespace (naming suggestions are welcome) where a shim could officially live and be discussed with the core team?
@mkunesch @rosshemsley @hbq1 let us know what you think 👍
@cgarciae Sorry for taking a while to respond! And thanks for sharing your design proposal!
There have been a few projects recently working on attaching functions to custom pytrees in JAX (e.g. equinox), and this proposal seems to have some similar ideas.
For the reasons you mentioned, this factoring can be attractive! It's worth highlighting, though, that there are some downsides to this approach, too:
1. JAX function transformations on classes can be a source of confusion for users. Functions are generally easier to reason about under JAX transformations than class methods.
2. Many JAX users checkpoint their state using pickle, and custom pytrees cause problems with this. Keeping the state tree as close to 'a dictionary of arrays' as possible has generally been a useful goal for many of the teams currently using optax. Furthermore, any changes to the optax API would have to be able to support reloading existing checkpoints in a backwards-compatible way (see the sketch after this list).
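To make the checkpointing concern concrete, here is a sketch of one possible mitigation (an assumption, not a recommendation: it relies on `flax.serialization`, which can turn `flax.struct` dataclasses into nested state dicts) that avoids pickling the class instance:

```python
import jax.numpy as jnp
from flax import serialization

params = {'w': jnp.zeros(3)}
tx = SGD(3e-4).init(params)

# The checkpoint is a plain nested dict of values, decoupled from pickle
# and largely from the SGD class definition itself.
ckpt = serialization.to_state_dict(tx)

# Restoring: build a template instance with the right structure, then pour
# the dict back in.
tx_restored = serialization.from_state_dict(SGD(3e-4).init(params), ckpt)
```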
We are continuing to work on polishing the optax API, although we are also deliberately conservative about the changes we make. Part of what makes optax successful is its ruthless simplicity, and forking the API into two sets of alternative factorings would increase the API surface area and could make it harder for us to support.
We're currently working on improving the package factoring, which will hopefully leave optax in a better place for trying out some more experimental ideas (such as this kind of API factoring), but it may be a little while before we would want to introduce big changes like this to the core library.
We'd encourage you to keep thinking about this idea though! Especially with regard to point 2 above: optax has thousands of users at the moment, and so charting a path forwards whilst retaining checkpoint compatibility is probably the biggest barrier we have to making these kinds of changes.
It would also be a good idea to try to "break" this design. For example, what happens when using more esoteric JAX transforms such as vmap, pmap, pjit, or grad? Can you break it through unexpected jit placement? (As a rule, someone somewhere has already done each of these things to every optax optimizer.)
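For instance, one such stress test (a sketch only, assuming the `SGD` class above; shapes and values are illustrative) is to vmap a full init-and-update step over a batch of learning rates:

```python
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}
grads = {'w': jnp.ones(3)}
lrs = jnp.array([1e-1, 1e-2, 1e-3])

def one_step(lr):
    # Each learning rate gets its own optimizer instance inside the vmap trace.
    tx = SGD(learning_rate=lr).init(params)
    updates, tx = tx.update(grads)
    return optax.apply_updates(params, updates), tx

new_params, txs = jax.vmap(one_step)(lrs)
# new_params['w'] has shape (3, 3): one row of updated params per learning rate.
```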
Just some thoughts:
> JAX function transformations on classes can be a source of confusion for users. Functions are generally easier to reason about under JAX transformations than class methods.
This proposal doesn't change anything for users: the interface is identical apart from the four benefits mentioned. The reasoning that users have to do is exactly the same. If anything, it is simpler, since the sequence interface is not exposed.
It's unfortunate that we didn't reconcile this issue back when it was suggested in the very first Optax issue.
> It would also be a good idea to try to "break" this design. For example, what happens when using more esoteric JAX transforms such as vmap, pmap, pjit, or grad?
Why don't you try breaking it? I think it might make the benefits more apparent.
I also think it would be good to at least block the sequence interface, since using it is a misuse of the current optax design (illustrated below). That would make it easier to improve the design in the future.
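To show what I mean by the sequence interface (this is my reading of it): `optax.GradientTransformation` is a `NamedTuple`, so besides `.init`/`.update` it also supports positional access, which a pytree wrapper could refuse to expose:

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}
grads = {'w': jnp.ones(3)}

tx = optax.sgd(3e-4)

# Intended, named interface:
state = tx.init(params)
updates, state = tx.update(grads, state)

# The same calls through positional tuple access; this works today because
# GradientTransformation is a NamedTuple, but it is easy to misuse:
state = tx[0](params)                   # tx[0] is init
updates, state = tx[1](grads, state)    # tx[1] is update
```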