optax
Allow specifying the axis for `optax.projections`
Currently, you can't specify the axis along which the projection is taken. Would the Optax authors be open to contributions that allow specifying the axis of the projections? This would be much easier than having to wrap the projections in `jax.vmap`.
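For context, the `vmap` workaround looks roughly like the sketch below. It is only an illustration: to stay self-contained it uses a local, sort-based simplex projection (`projection_simplex_1d`, a hypothetical stand-in) where you would otherwise pass an Optax projection such as `optax.projections.projection_simplex`. Choosing the axis means choosing `in_axes`/`out_axes` by hand for each call site.

```python
import jax
import jax.numpy as jnp


def projection_simplex_1d(v, scale=1.0):
    """Euclidean projection of a 1-D array onto {w >= 0, sum(w) = scale}.

    Hypothetical local stand-in for an Optax projection (sort-based algorithm).
    """
    u = jnp.sort(v)[::-1]                  # entries in descending order
    cssv = jnp.cumsum(u) - scale           # shifted cumulative sums
    ind = jnp.arange(1, v.shape[0] + 1)
    rho = jnp.sum(u - cssv / ind > 0)      # number of nonzero outputs
    theta = cssv[rho - 1] / rho            # threshold subtracted from every entry
    return jnp.maximum(v - theta, 0.0)


# Project each row onto the simplex: map over axis 0.
project_rows = jax.vmap(projection_simplex_1d, in_axes=0, out_axes=0)

# Project each column: map over axis 1 and put the results back on axis 1.
project_cols = jax.vmap(projection_simplex_1d, in_axes=1, out_axes=1)

x = jnp.array([[0.5, 0.2, 0.9],
               [3.0, -1.0, 0.0]])
rows = project_rows(x)  # each row sums to 1
cols = project_cols(x)  # each column sums to 1
```

For inputs with more dimensions, the `vmap` calls have to be nested or the array reshaped, which is the boilerplate an `axis` argument would remove.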
Thanks Jesse! That makes sense to me. Looping in @mblondel, who is the original author of the projection module.
Hi, can I work on this issue?
The functions in `optax.projections` operate on entire pytrees rather than on individual arrays. How would an axis be defined in that case?
Perhaps we need to create versions of these functions that operate on arrays specifically.
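One possible shape for an array-specific variant is sketched below. Everything here is an assumption for illustration: `projection_simplex_along_axis` and its `axis`/`scale` parameters are hypothetical and not part of the Optax API. The idea is to move the requested axis to the end, run a batched sort-based simplex projection over the last axis, and move the axis back, so no `vmap` is needed.

```python
import jax.numpy as jnp


def projection_simplex_along_axis(x, axis=-1, scale=1.0):
    """Project every 1-D slice of `x` taken along `axis` onto the simplex.

    Hypothetical array-only variant; not part of the Optax API.
    """
    x = jnp.moveaxis(x, axis, -1)          # work on the last axis
    u = -jnp.sort(-x, axis=-1)             # descending sort within each slice
    cssv = jnp.cumsum(u, axis=-1) - scale  # shifted cumulative sums per slice
    ind = jnp.arange(1, x.shape[-1] + 1)
    # Number of nonzero outputs per slice, kept as a trailing length-1 axis.
    rho = jnp.sum(u - cssv / ind > 0, axis=-1, keepdims=True)
    theta = jnp.take_along_axis(cssv, rho - 1, axis=-1) / rho
    out = jnp.maximum(x - theta, 0.0)
    return jnp.moveaxis(out, -1, axis)     # restore the original layout


x = jnp.arange(24.0).reshape(2, 3, 4) / 10
p = projection_simplex_along_axis(x, axis=1)  # slices along axis 1 sum to 1
```

A pytree-level function could then stay as it is, and users who need an axis would call the array version on the leaves they care about.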