
JAX 0.4.34 raises an error for None as a tree prefix

Open a1302z opened this issue 1 year ago • 2 comments

Hello:)

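The most recent JAX release, 0.4.34, no longer supports calls of the form `jax.tree.map(f, None, non-None)`, i.e. a `None` in the first tree at a position where another tree holds a non-`None` value. Such calls now raise an error.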

This also affects optax, for example in `add_decayed_weights`.
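Here is a minimal sketch of the failing pattern. It assumes toy `updates`/`params` dicts in which a hypothetical `"frozen"` entry has no update (`None`) but does have a parameter, which is the situation `add_decayed_weights` runs into. `None` is a pytree node with no leaves, so it is only a tree prefix of `None` itself. On JAX 0.4.34 the call below should raise, while older releases let it through. Catching `ValueError` specifically is an assumption about how JAX reports the structure mismatch.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy trees: "frozen" has no update (None) but does have a parameter.
updates = {"w": jnp.ones(3), "frozen": None}
params = {"w": jnp.ones(3), "frozen": jnp.zeros(2)}

# None is a pytree node with zero leaves, so it only matches another None.
print(jax.tree_util.tree_leaves(None))  # []

try:
    # The pattern from the issue: None in the first tree, an array in the second.
    jax.tree_util.tree_map(lambda g, p: g + 0.1 * p, updates, params)
    outcome = "accepted (older JAX)"
except ValueError as err:  # assumed exception type for the structure mismatch
    outcome = f"error: {err}"

print(outcome)
```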

The error can be avoided by changing the line to:

    # jtu is jax.tree_util; is_leaf turns each None in `updates` into a leaf
    updates = jtu.tree_map(
        lambda g, p: None if g is None else g + weight_decay * p,
        updates, params, is_leaf=lambda x: x is None)
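The fix can be checked in isolation with a small sketch. `add_decayed_weights_sketch` is a hypothetical stand-in that wraps only the `tree_map` call above; it is not optax's actual `add_decayed_weights` transformation. Because `is_leaf` makes every `None` in `updates` a leaf, the matching subtree of `params` is handed to the lambda as-is rather than being structure-matched against `None`. The lambda then returns `None` for that entry.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def add_decayed_weights_sketch(updates, params, weight_decay):
    """Hypothetical helper mirroring the tree_map from the proposed fix."""
    # Treat None entries of `updates` as leaves so that the first tree is a
    # valid prefix of `params`; None updates stay None in the result.
    return jtu.tree_map(
        lambda g, p: None if g is None else g + weight_decay * p,
        updates,
        params,
        is_leaf=lambda x: x is None,
    )


# Hypothetical toy trees: "frozen" has no update but does have a parameter.
updates = {"w": jnp.ones(3), "frozen": None}
params = {"w": jnp.full(3, 2.0), "frozen": jnp.zeros(2)}

new_updates = add_decayed_weights_sketch(updates, params, weight_decay=0.1)
print(new_updates["w"])       # 1 + 0.1 * 2 for each entry
print(new_updates["frozen"])  # None
```

Returning `None` from the lambda keeps the output tree aligned with the input `updates`, so downstream transformations still see `None` for parameters that receive no update.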

I'm happy to check whether any other parts are affected and open a pull request if desired.

Best, Alex

a1302z · Oct 07 '24 14:10

Hello @a1302z, thank you very much for pointing this out. We used an automatic tool to catch this error and made a pass over our codebase to fix it (see e.g. here), but it seems it did not catch all occurrences. If you are willing to open a PR for that, that would be great! And if not, I'd be curious to know how you flagged them.

vroulet · Oct 07 '24 15:10

I found it by reproducing this Equinox tutorial.

I found eight more occurrences and have modified them appropriately. I'll create a PR :)

a1302z · Oct 07 '24 15:10