JAX 0.4.34 raises an error for None as a tree prefix
Hello :)
The most recent JAX release, 0.4.34, no longer supports calls of the form `jax.tree.map(f, None, non-None)`.
This also affects optax, for example in `add_decayed_weights`.
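To illustrate the change, here is a minimal sketch using made-up `updates`/`params` dicts (hypothetical, not taken from optax). By default `None` is a pytree node with no children, so it contributes no leaves; putting it where the other tree has a leaf means `None` acts as a tree prefix, which is the pattern that now fails. The `try` block is expected to print "rejected" on JAX 0.4.34 and may print "accepted" on older releases. Passing `is_leaf=lambda x: x is None` makes the two structures line up leaf by leaf in either case.

```python
import jax

# Hypothetical example trees: "w" is masked out (None) in updates.
updates = {"w": None, "b": 1.0}
params = {"w": 2.0, "b": 3.0}

# By default None is a pytree node with no children, so it contributes no leaves.
default_leaves = jax.tree_util.tree_leaves(updates)
# With is_leaf, None itself becomes a leaf. Dict keys are flattened in sorted order.
none_as_leaf = jax.tree_util.tree_leaves(updates, is_leaf=lambda x: x is None)
print(default_leaves)  # [1.0]
print(none_as_leaf)    # [1.0, None]

try:
    # None is used as a tree prefix of the non-None leaf params["w"].
    jax.tree.map(lambda g, p: g, updates, params)
    print("accepted (older JAX)")
except ValueError as err:
    print("rejected:", type(err).__name__)

# With is_leaf, f receives the None explicitly and can decide what to return.
fixed = jax.tree.map(
    lambda g, p: p if g is None else g + p,
    updates,
    params,
    is_leaf=lambda x: x is None,
)
print(fixed["b"], fixed["w"])  # 4.0 2.0
```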
The error can be avoided by changing the line to:

```python
# Treat None as a leaf so it is paired with the matching entry in params.
updates = jtu.tree_map(
    lambda g, p: None if g is None else g + weight_decay * p,
    updates, params, is_leaf=lambda x: x is None)
```
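For a self-contained check of the fix outside optax, here is a sketch that wraps the same `tree_map` call in a hypothetical helper, `decay_updates` (not optax's actual code), applied to made-up trees where the `"b"` entry of `updates` is masked out with `None`. Without `is_leaf`, `None` would contribute no leaves, so the `g is None` branch could never run and the call would hit the new error on JAX 0.4.34.

```python
import jax
import jax.numpy as jnp


def decay_updates(updates, params, weight_decay):
    """Hypothetical helper: add weight_decay * p to each non-None update."""
    return jax.tree_util.tree_map(
        lambda g, p: None if g is None else g + weight_decay * p,
        updates,
        params,
        # None becomes a leaf, so it is paired with the matching params entry.
        is_leaf=lambda x: x is None,
    )


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
updates = {"w": jnp.array([0.1, 0.1]), "b": None}  # "b" is masked out

new_updates = decay_updates(updates, params, weight_decay=0.1)
print(new_updates["w"])  # roughly [0.2, 0.3]
print(new_updates["b"])  # None
```

The masked entry stays `None` in the output, so downstream transformations still see which parameters were excluded.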
I'm happy to check whether any other parts are affected and to open a pull request if desired.
Best, Alex
Hello @a1302z, thank you very much for pointing this out. We used an automatic tool to catch this error and made a pass over our codebase to fix it (see e.g. here), but it seems that it did not catch all occurrences. If you're willing to open a PR for this, that would be great! If not, I'd be curious to know how you flagged them.
I found it by reproducing this Equinox tutorial.
I found eight more occurrences and fixed them accordingly. I'll open a PR :)