optax
unitwise_norm fails for 3D convolutions
`unitwise_norm`, used by `adaptive_grad_clip`, only supports parameters with `ndim` up to 4, so it raises a `ValueError` when applied to a conv3d kernel, whose `ndim` is 5 (HWDIO). Would it be acceptable to add an optional `axis` kwarg to `adaptive_grad_clip` and `unitwise_norm`? This would allow specifying the reduction axes at the call site instead of baking every possible combination into the implementation of `unitwise_norm`.
I'm happy to submit a PR.
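A minimal sketch of how explicit reduction axes could be used for the per-parameter clipping step, assuming the unitwise AGC rule "rescale each output unit of the gradient so that its norm is at most `clipping * max(||p_unit||, eps)`". The helpers `unitwise_norm_axis` and `adaptive_clip_leaf` are hypothetical and not part of optax; the default `clipping`, `eps`, and the `1e-6` floor on the gradient norm are illustrative assumptions.

```python
import jax.numpy as jnp

def unitwise_norm_axis(x, axis):
    # Hypothetical helper: L2 norm over the given reduction axes, kept as
    # size-1 dims and broadcast back so it can be combined elementwise with x.
    squared = jnp.sum(jnp.square(jnp.abs(x)), axis=axis, keepdims=True)
    return jnp.broadcast_to(jnp.sqrt(squared), x.shape)

def adaptive_clip_leaf(g, p, clipping=0.01, eps=1e-3, axis=(0, 1, 2, 3)):
    # Hypothetical per-leaf AGC step. The default axis reduces every dim of a
    # 5D kernel except the last (output channel), giving one norm per unit.
    g_norm = unitwise_norm_axis(g, axis)
    p_norm = unitwise_norm_axis(p, axis)
    # Largest gradient norm allowed for each unit.
    max_norm = clipping * jnp.maximum(p_norm, eps)
    # Rescale units whose gradient norm exceeds the limit; floor avoids 0/0.
    clipped = g * (max_norm / jnp.maximum(g_norm, 1e-6))
    return jnp.where(g_norm < max_norm, g, clipped)
```

With a conv3d-shaped kernel such as `(3, 3, 3, 4, 8)`, each of the 8 output units is clipped independently, and a gradient that is already within the limit passes through unchanged.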
Hello @froody,
Good catch. The behavior of `adaptive_grad_clip` does indeed hide some logic that could mislead users.
If you are willing to open a PR that lets this function handle `ndim=5`, that would be great. I don't know exactly how you can add an `axis` argument and keep the current default behavior; I'll let you try and see :)
Thank you!
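One possible way to add `axis` while keeping the current default: treat `axis=None` as "use the shape-based rule", extend that rule with a 5D branch, and let an explicit `axis` override it. The names `unitwise_norm_with_axis` and `_default_reduction_axes` are hypothetical. The default branch is an assumption approximating optax's existing rules (scalars and vectors: one norm over the whole array; 2D/3D: reduce axis 0; 4D kernels: reduce all but the last axis), so it should be checked against the installed optax version.

```python
import jax.numpy as jnp

def _default_reduction_axes(x):
    # Assumed shape-based default, plus a new branch for 5D conv kernels.
    if jnp.squeeze(x).ndim <= 1:
        return None  # scalars/vectors: None makes jnp.sum reduce every axis
    if x.ndim in (2, 3):
        return (0,)  # IO linear / multihead linear: reduce the input axis
    if x.ndim in (4, 5):
        return tuple(range(x.ndim - 1))  # conv kernels: all but the output axis
    raise ValueError(f"Pass an explicit axis for a parameter of shape {x.shape}.")

def unitwise_norm_with_axis(x, axis=None):
    # Hypothetical variant: axis=None keeps the default, otherwise use axis.
    axes = _default_reduction_axes(x) if axis is None else axis
    squared = jnp.sum(jnp.square(jnp.abs(x)), axis=axes, keepdims=True)
    return jnp.broadcast_to(jnp.sqrt(squared), x.shape)
```

Because `axis=None` routes through the shape-based rule, existing callers would see no change, a 5D kernel with the output channel last gets one norm per output channel, and any other layout can still be handled by passing `axis` explicitly.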