unitwise_norm fails for 3D convolutions
unitwise_norm, used by adaptive_grad_clip, only supports a few values of ndim, and raises ValueError when applied to a conv3d kernel since ndim=5 (HWDIO). Would it be acceptable to add an optional axis kwarg to adaptive_grad_clip and unitwise_norm? This would allow specifying the reduction axes at the callsite instead of baking every possible combination into the implementation of unitwise_norm.
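To make the proposal concrete, here is a rough sketch of what I have in mind. `unitwise_norm_sketch` and `_default_axis` are hypothetical names, not part of the library, and the default branch is only my reading of the current ndim-based dispatch, so treat it as an assumption rather than the actual implementation. With `axis=None` the behavior should stay as it is today; passing `axis` explicitly covers layouts such as a 5D conv3d kernel.

```python
import jax.numpy as jnp


def _default_axis(x):
    """Assumed defaults, based on my reading of the current ndim dispatch."""
    if jnp.squeeze(x).ndim <= 1:
        return None  # scalars and vectors: reduce over everything
    if x.ndim in (2, 3):
        return (0,)  # linear kernels: reduce over the input axis
    if x.ndim == 4:
        return (0, 1, 2)  # conv2d kernels: reduce over all but the output axis
    raise ValueError(
        f"No default reduction axes for shape {x.shape}; pass `axis` explicitly."
    )


def unitwise_norm_sketch(x, axis=None):
    """Hypothetical unit-wise norm with caller-specified reduction axes."""
    if axis is None:
        axis = _default_axis(x)
    squared_norm = jnp.sum(jnp.square(jnp.abs(x)), axis=axis, keepdims=True)
    # Broadcast back so the result lines up elementwise with the parameter.
    return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape)


# conv3d kernel with the output channel on the last axis: reduce over the rest.
kernel = jnp.ones((3, 3, 3, 2, 4))
norms = unitwise_norm_sketch(kernel, axis=(0, 1, 2, 3))
```

adaptive_grad_clip would then only need to accept the same optional `axis` and forward it, so existing callers that never pass it would see no change.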
I'm happy to submit a PR.
Hello @froody,
Good catch. The behavior of adaptive_grad_clip does indeed hide some logic that could mislead users.
If you are willing to open a PR so that this function handles ndim=5, that would be great. I'm not sure exactly how you would add an axis argument while keeping the current default behavior, so I'll let you try and see :)
Thank you!