optax

unitwise_norm fails for 3D convolutions

Open · froody opened this issue 10 months ago · 1 comment

unitwise_norm, used by adaptive_grad_clip, supports only a few values of ndim and raises a ValueError when applied to a conv3d kernel, whose ndim is 5 (HWDIO). Would it be acceptable to add an optional axis kwarg to adaptive_grad_clip and unitwise_norm? This would let the caller specify the reduction axes at the callsite instead of baking every possible combination into the implementation of unitwise_norm.
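A minimal sketch of what the proposed kwarg might look like, assuming a hypothetical signature `unitwise_norm(x, axis=None)`: with `axis=None` it falls back to ndim-based defaults (per-vector norm for ndim <= 1, axis 0 for IO or multihead linear weights, axes (0, 1, 2) for HWIO conv kernels) and still raises for other ranks, while an explicit `axis` overrides them. This is an illustration, not optax's actual code.

```python
import jax.numpy as jnp


def unitwise_norm(x, axis=None):
    """Per-output-unit L2 norm of `x`, broadcast back to `x.shape`.

    Hypothetical signature: `axis=None` keeps ndim-based defaults,
    an explicit tuple of axes overrides them (e.g. for conv3d kernels).
    """
    if axis is None:
        if jnp.squeeze(x).ndim <= 1:  # scalars and vectors: a single norm
            reduce_axes = None  # None here means "sum over every axis"
        elif x.ndim in (2, 3):  # linear (IO) or multihead linear weights
            reduce_axes = (0,)
        elif x.ndim == 4:  # conv2d kernels, HWIO
            reduce_axes = (0, 1, 2)
        else:
            raise ValueError(
                f"Expected parameter with ndim in (1, 2, 3, 4), got shape "
                f"{x.shape}; pass `axis` explicitly.")
    else:
        reduce_axes = axis
    squared_norm = jnp.sum(jnp.abs(x) ** 2, axis=reduce_axes, keepdims=True)
    return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape)


# A conv3d kernel laid out as HWDIO: reduce over everything except O.
kernel = jnp.ones((3, 3, 3, 4, 8))
norms = unitwise_norm(kernel, axis=(0, 1, 2, 3))
```

For an HWDIO kernel, reducing over (0, 1, 2, 3) yields one norm per output channel, the 3D analogue of reducing over (0, 1, 2) for an HWIO kernel; existing callers that never pass `axis` see no change.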

I'm happy to submit a PR.

froody · Apr 05 '24 00:04

Hello @froody,

Good catch. adaptive_grad_clip does indeed hide some logic that could mislead users. If you are willing to open a PR that lets this function handle ndim=5, that would be great. I don't know exactly how you can add an axis argument while keeping the current default behavior; I'll let you try and see :)
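To illustrate how a caller-supplied axis could flow through the clipping step itself, here is a per-array sketch of adaptive gradient clipping: each unit's gradient is rescaled so that its norm is at most `clipping * max(||w||, eps)`. The helper names `_norm` and `clip_unitwise` are hypothetical, the helper takes `axis` explicitly (no defaults) for brevity, and `eps=1e-3` is an assumed default.

```python
import jax.numpy as jnp


def _norm(x, axis):
    # Per-unit L2 norm over `axis`, broadcast back to x.shape.
    squared = jnp.sum(jnp.abs(x) ** 2, axis=axis, keepdims=True)
    return jnp.broadcast_to(jnp.sqrt(squared), x.shape)


def clip_unitwise(grad, param, clipping, axis, eps=1e-3):
    """Hypothetical per-array AGC step with explicit reduction axes."""
    g_norm = _norm(grad, axis)
    # Largest gradient norm allowed for each unit.
    max_norm = clipping * jnp.maximum(_norm(param, axis), eps)
    clip_mask = g_norm > max_norm
    # Substitute 1.0 where no clipping happens so the division below is safe.
    safe_norm = jnp.where(clip_mask, g_norm, 1.0)
    return jnp.where(clip_mask, grad * max_norm / safe_norm, grad)


# conv3d kernel (HWDIO) and a large gradient: every unit gets clipped.
param = jnp.ones((3, 3, 3, 4, 8))
big_grad = 10.0 * jnp.ones_like(param)
clipped = clip_unitwise(big_grad, param, clipping=0.1, axis=(0, 1, 2, 3))
```

Here each output channel's gradient norm (10·√108) exceeds 0.1·√108, so every entry is scaled down to 0.1; a gradient already within the bound passes through unchanged.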

Thank you!

vroulet · Apr 05 '24 00:04