Gradient Clipping
When training large or deep models, exploding gradients are common and cause instability. Clipping gradients to a certain small magnitude is an effective way of stabilizing training.
To implement this, I believe a method on the Gradients struct would be needed (correct me if I'm wrong).
I know there are multiple ways to clip gradients (e.g. pytorch has clip_grad_norm_ and clip_grad_value_).
Do we know if one of these is more widely used than the other?
I think clip_grad_norm_ is more widely used; however, it is also more complex, since it takes the norm of all the gradients first. clip_grad_value_ is used less, but is far more straightforward to implement, so I think it makes sense to add that first.
It should be possible to implement a general Gradients::map function that takes a FnMut(&mut Tensor<(usize,), E, D>) -> Result<(), D::Err> and applies it to each D::Vec after wrapping it in a tensor.
That seems like all that would be needed for clip_grad_value_.
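To make that concrete, here's a minimal self-contained sketch (not dfdx's actual API; plain Vec<f32> buffers stand in for D::Vec): a map helper in the spirit of the proposed Gradients::map, and a clip_grad_value_-style clamp built on top of it.

```rust
// Toy sketch, not dfdx's actual API: gradients stored as flat f32 buffers,
// with a `map` helper in the spirit of the proposed `Gradients::map`.
// `clip_value` is then just an element-wise clamp over every buffer.
struct Gradients {
    // One flat buffer per parameter tensor, standing in for `D::Vec`.
    buffers: Vec<Vec<f32>>,
}

impl Gradients {
    /// Apply `f` to every gradient buffer, stopping at the first error.
    fn map<F, E>(&mut self, mut f: F) -> Result<(), E>
    where
        F: FnMut(&mut [f32]) -> Result<(), E>,
    {
        for buf in self.buffers.iter_mut() {
            f(buf.as_mut_slice())?;
        }
        Ok(())
    }

    /// clip_grad_value_-style clipping: clamp each element into [-clip, clip].
    fn clip_value(&mut self, clip: f32) {
        let _ = self.map::<_, ()>(|buf: &mut [f32]| {
            for g in buf.iter_mut() {
                *g = g.clamp(-clip, clip);
            }
            Ok(())
        });
    }
}

fn main() {
    let mut grads = Gradients {
        buffers: vec![vec![0.1, -3.0, 7.5], vec![0.4, -0.2]],
    };
    grads.clip_value(0.5);
    assert_eq!(grads.buffers[0], vec![0.1, -0.5, 0.5]);
}
```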
The pytorch implementations of the above are pretty straightforward: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py
I would say clip_grad_norm would need to go through the TensorCollection api so that:
- only the norm of the model's gradients is considered
- we can get access to each gradient's tensor
```rust
model.clip_grad_norm(&mut grads, 0.5);
model.clip_grad_value(&mut grads, 0.5);
```
Then we could implement clip_grad_norm in two passes with RecursiveWalker (see the sketch after this list):
- Accumulate each gradient's norm. For each tensor & gradient:
  - Create a tensor out of the gradient using Gradients::get
  - Compute the norm of the tensor with `g.square().sum().sqrt()`
  - Append this 0d norm tensor to a Vec along the walker
- Call stack on the Vec of 0d norms
- Call `stacked.square().sum().sqrt()` to compute the total norm
- Multiply each gradient by `max_norm / total_norm` as done in the pytorch code
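Here's a rough, self-contained sketch of that two-pass idea over plain f32 buffers (stand-ins for the per-tensor D::Vecs the walker would visit; none of this is dfdx's actual API). The `+ 1e-6` and the "only scale when the coefficient is below 1" behaviour mirror pytorch's clip_grad_norm_:

```rust
// Rough sketch (not dfdx's API) of the two-pass clip_grad_norm over flat f32
// buffers, which stand in for the per-tensor gradients the walker would visit.
fn clip_grad_norm(buffers: &mut [Vec<f32>], max_norm: f32) {
    // Pass 1: per-tensor norms, then the total norm as the norm of those norms
    // (equivalent to stacking the 0d norms and calling square().sum().sqrt()).
    let total_norm: f32 = buffers
        .iter()
        .map(|buf| buf.iter().map(|g| g * g).sum::<f32>().sqrt()) // per-tensor norm
        .map(|n| n * n)
        .sum::<f32>()
        .sqrt();

    // Pass 2: scale every gradient by max_norm / total_norm, but only when the
    // gradients actually exceed max_norm (pytorch clamps the coefficient to 1).
    let clip_coef = max_norm / (total_norm + 1e-6);
    if clip_coef < 1.0 {
        for buf in buffers.iter_mut() {
            for g in buf.iter_mut() {
                *g *= clip_coef;
            }
        }
    }
}

fn main() {
    let mut grads = vec![vec![3.0, 4.0], vec![0.0]]; // total norm = 5.0
    clip_grad_norm(&mut grads, 0.5);
    // Each gradient is scaled by ~0.1, so the total norm is now ~0.5.
}
```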
If we wanted this all to be in-place (sketched below):
- For clip_grad_norm, we'd need a way to in-place multiply a `D::Vec<E>`.
- For clip_grad_value, we'd need a way to in-place clamp a `D::Vec<E>`.
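For illustration, one hypothetical shape those in-place operations could take (all names here are made up, not dfdx's device API) is a small trait with scale/clamp methods that mutate a buffer directly, shown for a toy CPU device whose buffers are Vec<f32>:

```rust
// Hypothetical trait sketch (names invented for illustration) for in-place
// multiply / clamp of a device buffer, i.e. what "in-place ops on a
// `D::Vec<E>`" could look like.
trait InPlaceGradOps<E> {
    type Vec;
    /// Multiply every element of the buffer by `scale` without allocating.
    fn scale_in_place(&self, buf: &mut Self::Vec, scale: E);
    /// Clamp every element of the buffer into [min, max] without allocating.
    fn clamp_in_place(&self, buf: &mut Self::Vec, min: E, max: E);
}

/// A toy CPU-style device whose buffers are plain Vec<f32>.
struct Cpu;

impl InPlaceGradOps<f32> for Cpu {
    type Vec = Vec<f32>;
    fn scale_in_place(&self, buf: &mut Vec<f32>, scale: f32) {
        for x in buf.iter_mut() {
            *x *= scale;
        }
    }
    fn clamp_in_place(&self, buf: &mut Vec<f32>, min: f32, max: f32) {
        for x in buf.iter_mut() {
            *x = x.clamp(min, max);
        }
    }
}

fn main() {
    let dev = Cpu;
    let mut grad = vec![2.0_f32, -9.0, 0.25];
    dev.clamp_in_place(&mut grad, -0.5, 0.5);
    dev.scale_in_place(&mut grad, 0.1);
    println!("{grad:?}");
}
```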
Also, as a separate concern, the .square().sum().sqrt() way of taking the norm may be expensive, since .square() will allocate another tensor with the same size as the gradient. I think this can be addressed separately though.
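For reference, on a flat CPU buffer the allocation-free alternative is just a fold that accumulates the sum of squares before taking the square root; other devices would presumably need a dedicated op/kernel for the same thing:

```rust
// Sketch: compute the norm without allocating a squared copy of the gradient
// by folding over the elements and accumulating the sum of squares.
fn norm_no_alloc(grad: &[f32]) -> f32 {
    grad.iter().fold(0.0f32, |acc, g| acc + g * g).sqrt()
}

fn main() {
    let g = vec![3.0, 4.0];
    assert_eq!(norm_no_alloc(&g), 5.0);
}
```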
Has any work been done on this?
I've submitted a draft PR, and once the examples are added I'll mark it as ready for review. So far I think it's working correctly; I've been able to avoid exploding gradients with it.