
Gradient Clipping

Open jafioti opened this issue 2 years ago • 7 comments

When training large or deep models, exploding gradients are frequent and cause instability. Clipping them to a certain small amount is an effective way of stabilizing training.

To implement this, I believe a method on the Gradients struct would be needed (correct me if I'm wrong)

jafioti avatar Mar 21 '23 17:03 jafioti

I know there are multiple ways to clip gradients (e.g. pytorch has clip_grad_norm_ and clip_grad_value_).

Do we know if one of these is more widely used than the other?

chelsea0x3b avatar Mar 21 '23 17:03 chelsea0x3b

I think clip_grad_norm_ is more widely used; however, it is also more complex, since it takes the norm of all the gradients first. clip_grad_value_ is used less, but is far more straightforward to implement, so I think it makes sense to add that first.
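For concreteness, the difference between the two can be sketched on a plain slice. This is only an illustration of the PyTorch-style semantics, not the dfdx API; the function names here are hypothetical:

```rust
/// Value clipping: clamp every gradient element into [-max_value, max_value].
fn clip_grad_value(grads: &mut [f32], max_value: f32) {
    for g in grads.iter_mut() {
        *g = g.clamp(-max_value, max_value);
    }
}

/// Norm clipping: scale all gradients uniformly so their global L2 norm
/// is at most max_norm; returns the pre-clipping norm.
fn clip_grad_norm(grads: &mut [f32], max_norm: f32) -> f32 {
    let total_norm = grads.iter().map(|g| g * g).sum::<f32>().sqrt();
    if total_norm > max_norm {
        let scale = max_norm / total_norm;
        for g in grads.iter_mut() {
            *g *= scale;
        }
    }
    total_norm
}

fn main() {
    // Value clipping clamps each element independently.
    let mut a = vec![3.0f32, -4.0];
    clip_grad_value(&mut a, 1.0);
    assert_eq!(a, vec![1.0, -1.0]);

    // Norm clipping preserves direction: [3, 4] has norm 5,
    // so it scales by 1/5 to [0.6, 0.8].
    let mut b = vec![3.0f32, 4.0];
    let norm = clip_grad_norm(&mut b, 1.0);
    assert!((norm - 5.0).abs() < 1e-6);
    assert!((b[0] - 0.6).abs() < 1e-6 && (b[1] - 0.8).abs() < 1e-6);
}
```

Note the key behavioral difference: value clipping changes the direction of the gradient vector, while norm clipping only shrinks its magnitude.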

jafioti avatar Mar 21 '23 17:03 jafioti

It should be possible to implement a general Gradients::map function that takes a FnMut(&mut Tensor<(usize,), E, D>) -> Result<(), D::Err> and applies it to each D::Vec after wrapping it in a tensor.
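A minimal sketch of that idea, assuming a simplified Gradients that stores plain Vec&lt;f32&gt; buffers keyed by a tensor id (the real dfdx type maps unique ids to device-specific D::Vec storage, and the closure would receive a tensor view rather than a raw Vec):

```rust
use std::collections::HashMap;

// Hypothetical, simplified stand-in for dfdx's Gradients struct.
struct Gradients {
    buffers: HashMap<usize, Vec<f32>>, // tensor id -> gradient data
}

impl Gradients {
    /// Apply `f` to every gradient buffer, stopping at the first error.
    fn map<F, E>(&mut self, mut f: F) -> Result<(), E>
    where
        F: FnMut(&mut Vec<f32>) -> Result<(), E>,
    {
        for buf in self.buffers.values_mut() {
            f(buf)?;
        }
        Ok(())
    }
}

fn main() {
    let mut grads = Gradients {
        buffers: HashMap::from([(0usize, vec![5.0f32, -7.0]), (1, vec![0.25])]),
    };
    // clip_grad_value falls out of map: clamp every element into [-1, 1].
    grads
        .map::<_, ()>(|buf| {
            for g in buf.iter_mut() {
                *g = g.clamp(-1.0, 1.0);
            }
            Ok(())
        })
        .unwrap();
    assert_eq!(grads.buffers[&0], vec![1.0, -1.0]);
    assert_eq!(grads.buffers[&1], vec![0.25]);
}
```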

nkoppel avatar Mar 21 '23 18:03 nkoppel

That seems like all that would be needed for clip_grad_value_.

jafioti avatar Mar 21 '23 18:03 jafioti

pytorch implementations of the above are pretty straightforward https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py

I would say clip_grad_norm would be required to go through the TensorCollection api, so that:

  1. Only the norm of the model's gradients is considered
  2. We can get access to the gradient's tensor
model.clip_grad_norm(&mut grads, 0.5);
model.clip_grad_value(&mut grads, 0.5);

Then we could implement clip_grad_norm with two passes with RecursiveWalker:

  1. Accumulate each gradient's norm. For each tensor & gradient:
    1. Create a tensor out of the gradient using Gradients::get
    2. Compute the norm of the tensor with g.square().sum().sqrt()
    3. Append this 0d norm tensor to a Vec as the walker goes
  2. Call stack on the Vec of 0d norms
  3. Call stacked.square().sum().sqrt() to compute the total norm
  4. Multiply each gradient by max_norm / total_norm, as done in the pytorch code
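The two-pass plan above can be sketched on plain Vec&lt;f32&gt; buffers standing in for per-parameter gradients (no RecursiveWalker or stacking here; the "stack then square-sum-sqrt" of the 0d norms is mathematically the same as summing per-tensor squared norms and taking one sqrt):

```rust
/// Two-pass global norm clipping over a collection of gradient buffers.
/// Returns the total norm before clipping.
fn clip_grad_norm_two_pass(grad_buffers: &mut [Vec<f32>], max_norm: f32) -> f32 {
    // Pass 1: accumulate each gradient's (squared) norm.
    let per_tensor_sq: Vec<f32> = grad_buffers
        .iter()
        .map(|g| g.iter().map(|x| x * x).sum::<f32>())
        .collect();

    // Steps 2-3: reduce the per-tensor norms into one total norm.
    let total_norm = per_tensor_sq.iter().sum::<f32>().sqrt();

    // Step 4 (pass 2): scale every gradient by max_norm / total_norm.
    if total_norm > max_norm {
        let scale = max_norm / total_norm;
        for g in grad_buffers.iter_mut() {
            for x in g.iter_mut() {
                *x *= scale;
            }
        }
    }
    total_norm
}

fn main() {
    // Two parameters with gradients [3] and [4]: total norm is 5.
    let mut grads = vec![vec![3.0f32], vec![4.0f32]];
    let norm = clip_grad_norm_two_pass(&mut grads, 1.0);
    assert!((norm - 5.0).abs() < 1e-6);
    // After clipping to max_norm = 1, the buffers are [0.6] and [0.8].
    assert!((grads[0][0] - 0.6).abs() < 1e-6);
    assert!((grads[1][0] - 0.8).abs() < 1e-6);
}
```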

If we wanted this all to be in-place:

  • For clip_grad_norm, we'd need a way to in-place multiply a D::Vec<E>.
  • For clip_grad_value, we'd need a way to in-place clamp a D::Vec<E>.

Also separately, the .square().sum().sqrt() way of taking the norm may be expensive, since .square() will allocate another tensor with the same size as the gradient. I think this can be addressed separately though.
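The allocation concern can be seen on a plain slice: the eager version materializes a full squared copy before reducing, while a fused reduction squares each element inside a single pass and never allocates a temporary:

```rust
/// Mirrors .square().sum().sqrt(): builds a squared temporary first.
fn norm_allocating(g: &[f32]) -> f32 {
    let squared: Vec<f32> = g.iter().map(|x| x * x).collect();
    squared.iter().sum::<f32>().sqrt()
}

/// Fused version: single pass, no temporary buffer.
fn norm_fused(g: &[f32]) -> f32 {
    g.iter().map(|x| x * x).sum::<f32>().sqrt()
}

fn main() {
    let g = [3.0f32, 4.0];
    // Both agree on the result; only the memory traffic differs.
    assert!((norm_allocating(&g) - 5.0).abs() < 1e-6);
    assert!((norm_fused(&g) - 5.0).abs() < 1e-6);
}
```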

chelsea0x3b avatar Mar 22 '23 01:03 chelsea0x3b

Has any work been done on this?

opfromthestart avatar Sep 24 '23 21:09 opfromthestart

I've submitted a draft PR, and once the examples are added I'll mark it as ready for review. But so far I think it's working correctly; I've been able to avoid exploding gradients.

swfsql avatar Dec 14 '23 01:12 swfsql