Allow for specifying axis for `optax.projections`

Open • JesseFarebro opened this issue 10 months ago • 3 comments

Currently, there is no way to specify the axis along which a projection is taken. Would the Optax authors be open to contributions that allow the axis of the projections to be specified? This would be much easier than having to vmap the projections.
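The vmap workaround mentioned above might look like the following minimal sketch. To keep it self-contained, `project_simplex_1d` is a hypothetical stand-in for a projection applied to a single vector (the standard sort-based Euclidean projection onto the simplex). With Optax installed, `optax.projections.projection_simplex` could presumably be passed to `jax.vmap` in its place, since each call then receives one 1-D slice. Projecting along a non-leading axis means setting `in_axes`/`out_axes` by hand, which is the bookkeeping an axis argument would hide.

```python
import jax
import jax.numpy as jnp


def project_simplex_1d(v, scale=1.0):
    """Project a 1-D float vector onto {x >= 0, sum(x) = scale}.

    Hypothetical helper (not Optax API) using the sort-based algorithm.
    """
    n = v.shape[0]
    u = jnp.sort(v)[::-1]                     # entries in descending order
    css = jnp.cumsum(u) - scale               # running sums minus the target
    k = jnp.arange(1, n + 1, dtype=v.dtype)   # 1-based positions
    cond = u - css / k > 0                    # entries that stay in the support
    rho = jnp.max(jnp.where(cond, jnp.arange(n), -1))  # last such index
    theta = css[rho] / (rho + 1).astype(v.dtype)       # shift to subtract
    return jnp.maximum(v - theta, 0.0)


x = jnp.array([[0.5, 0.2, 1.0],
               [2.0, -1.0, 0.0]])

# Row-wise: vmap maps over axis 0, so each call sees one row.
rows = jax.vmap(project_simplex_1d)(x)

# Column-wise: map over axis 1 and place the results back on axis 1.
cols = jax.vmap(project_simplex_1d, in_axes=1, out_axes=1)(x)
```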

JesseFarebro, Jan 26 '25 18:01

Thanks Jesse! That makes sense to me. Looping in @mblondel, who is the original author of the projection module.

fabianp, Jan 27 '25 08:01

Hi, can I work on this issue?

shreyans413, Mar 06 '25 11:03

The functions in `optax.projections` operate on entire general pytrees, not arrays. How would axes be defined in that case?

Perhaps we need to create versions of these functions that operate on arrays specifically.

carlosgmartin, Apr 24 '25 20:04
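One way to read the array-specific suggestion above is a small wrapper that turns any single-vector projection into one with an `axis` argument. For pytrees, one possible answer is to apply that wrapper leaf by leaf, whereas, as noted in the thread, the existing functions act on the pytree as a whole. This is a minimal sketch: `along_axis`, `along_axis_tree`, and `project_l2_ball_1d` are hypothetical names, not Optax API, and per-leaf semantics is only one option.

```python
import jax
import jax.numpy as jnp


def project_l2_ball_1d(v, scale=1.0):
    """Project a 1-D vector onto the L2 ball of radius `scale` (hypothetical helper)."""
    norm = jnp.linalg.norm(v)
    # Factor is 1 inside the ball and scale / norm outside it.
    return v * (scale / jnp.maximum(norm, scale))


def along_axis(proj_1d, x, axis=-1, **kwargs):
    """Apply a 1-D projection to every slice of array `x` taken along `axis`."""
    x_last = jnp.moveaxis(x, axis, -1)            # bring the target axis to the end
    flat = x_last.reshape(-1, x_last.shape[-1])   # (num_slices, axis_length)
    out = jax.vmap(lambda v: proj_1d(v, **kwargs))(flat)
    # Restore the original layout.
    return jnp.moveaxis(out.reshape(x_last.shape), -1, axis)


def along_axis_tree(proj_1d, tree, axis=-1, **kwargs):
    """One possible pytree semantics: project each leaf independently."""
    return jax.tree_util.tree_map(
        lambda leaf: along_axis(proj_1d, leaf, axis, **kwargs), tree)


x = jnp.array([[3.0, 4.0],
               [0.3, 0.4]])
by_row = along_axis(project_l2_ball_1d, x, axis=-1)  # each row ends with norm <= 1
by_col = along_axis(project_l2_ball_1d, x, axis=0)   # each column ends with norm <= 1
proj = along_axis_tree(project_l2_ball_1d, {"w": x, "b": jnp.array([3.0, 4.0])})
```

Moving the target axis last and flattening the remaining axes lets a single `jax.vmap` handle inputs of any rank, so the same wrapper would work for any projection that accepts one vector.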