
Add support for gradient accumulation.

Open mitchellnw opened this issue 2 years ago • 6 comments

Added a new flag --accum-freq (accumulation frequency) which defaults to 1.

If this is greater than 1, then the optimizer is only stepped every --accum-freq batches.

Can be combined with gradient checkpointing.

The feature was requested for cases where people only have a few GPUs but want to train with a large batch size.

We don't have to merge this if people think it makes things too complicated; we could instead close it and just point people to it on request. Either way, I'm curious to hear thoughts.

For a per-gpu batch size of m and --accum-freq k, the effective per-gpu batch size is mk.
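For example (hypothetical numbers, just to illustrate the bookkeeping):

per_gpu_batch = 256   # m, the per-gpu batch size
accum_freq = 4        # k, from --accum-freq
num_gpus = 4
effective_per_gpu = per_gpu_batch * accum_freq    # 1024
effective_global = effective_per_gpu * num_gpus   # 4096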

The basic pseudocode, when --accum-freq > 1, is:

import torch

accum_data, accum_features = [], []
for i, data in enumerate(dataloader):

    # first, get the features for this batch without gradient tracking
    with torch.no_grad():
        features = model(data)
    accum_data.append(data)
    accum_features.append(features)

    # keep caching until accum_freq batches have been accumulated
    if (i + 1) % accum_freq > 0:
        continue

    optimizer.zero_grad()

    # now re-compute the forward pass for each cached batch with gradient tracking,
    # splicing its fresh features into the cached ones; each backward() contributes
    # the gradient flowing through batch j, so the sum over j equals the gradient
    # of the loss over the full effective batch
    for j, data in enumerate(accum_data):
        features = model(data)
        all_features = torch.cat(accum_features[:j] + [features] + accum_features[j + 1:])
        loss = get_loss(all_features)
        loss.backward()

    optimizer.step()
    accum_data, accum_features = [], []
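
To make the mechanics concrete, here is a toy, self-contained version of the same pattern. The two tiny linear "towers" and the InfoNCE-style loss below are illustrative stand-ins, not the actual open_clip model or loss code:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
image_enc = torch.nn.Linear(16, 8)
text_enc = torch.nn.Linear(12, 8)
optimizer = torch.optim.SGD(list(image_enc.parameters()) + list(text_enc.parameters()), lr=0.1)
accum_freq = 4

def get_loss(img_feats, txt_feats):
    # symmetric contrastive loss over the full accumulated batch
    img_feats = F.normalize(img_feats, dim=-1)
    txt_feats = F.normalize(txt_feats, dim=-1)
    logits = 100.0 * img_feats @ txt_feats.t()
    labels = torch.arange(logits.shape[0])
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2

accum_images, accum_texts, accum_img_f, accum_txt_f = [], [], [], []
for i in range(16):  # 16 micro-batches of 4 samples each
    images, texts = torch.randn(4, 16), torch.randn(4, 12)

    # cache features without gradient tracking
    with torch.no_grad():
        accum_img_f.append(image_enc(images))
        accum_txt_f.append(text_enc(texts))
    accum_images.append(images)
    accum_texts.append(texts)

    if (i + 1) % accum_freq > 0:
        continue

    optimizer.zero_grad()
    for j in range(accum_freq):
        # re-forward micro-batch j with gradients, splice it into the cached features
        img_f = image_enc(accum_images[j])
        txt_f = text_enc(accum_texts[j])
        all_img = torch.cat(accum_img_f[:j] + [img_f] + accum_img_f[j + 1:])
        all_txt = torch.cat(accum_txt_f[:j] + [txt_f] + accum_txt_f[j + 1:])
        loss = get_loss(all_img, all_txt)
        loss.backward()  # gradients from each j sum to the full effective-batch gradient
    optimizer.step()
    accum_images, accum_texts, accum_img_f, accum_txt_f = [], [], [], []

Note that no division by accum_freq is needed: each inner backward only flows through micro-batch j's compute path while the other features are detached, so the summed gradients already equal the gradient of the loss on the full effective batch.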

mitchellnw avatar Nov 29 '22 00:11 mitchellnw

Interesting! Does it work with --local-loss --gather-with-grad too?

usuyama avatar Nov 29 '22 02:11 usuyama

@usuyama yep! If you check out the pseudocode above, it doesn't really depend on how the loss is implemented.

mitchellnw avatar Nov 29 '22 02:11 mitchellnw

Very nice!

For users, it could be good to have some guidance on how much training-time overhead this adds.

ludwigschmidt avatar Nov 29 '22 03:11 ludwigschmidt

Sounds good. Using --accum-freq k is just over k times slower than using --accum-freq 1.

mitchellnw avatar Nov 29 '22 03:11 mitchellnw

Hi, looks cool!

It's not obvious to me what this PR introduces in terms of CPU and memory overhead. What recomputation gets done, what temporary storage is used, and are there any network consequences? This could be answered either by analysing the code in detail or by running experiments at multiple scales.

rom1504 avatar Dec 01 '22 15:12 rom1504

Cool! Is this an implementation of GradAccum in BASIC?

Quan-Sun avatar Dec 02 '22 09:12 Quan-Sun

Here is a screenshot verifying that training on 8 gpus with per-gpu batch size 512 behaves the same as training on 4 gpus with per-gpu batch size 512 and accum freq 2. However, it's 2x as slow. I've also updated the readme to clarify samples/s and other information about this feature.

[Screenshot: Screen Shot 2022-12-03 at 10.05.38 AM]

mitchellnw avatar Dec 03 '22 18:12 mitchellnw

Cool! Is this an implementation of GradAccum in BASIC?

Not exactly, but it looks like an overall similar approach.

mitchellnw avatar Dec 03 '22 18:12 mitchellnw

Any thoughts on if this can be merged?

mitchellnw avatar Dec 08 '22 17:12 mitchellnw

Yeah lgtm, let's go

rom1504 avatar Dec 08 '22 21:12 rom1504

Hello! Thanks a lot for adding this functionality. I think there is an error in the computation of the number of samples during the logging process. It's missing the multiplication by the accum-freq argument. I made a PR to correct it. #327
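
Roughly, the missing factor looks like the sketch below (variable names are illustrative, not the exact ones in the code; see #327 for the actual diff):

# each optimizer step now consumes accum_freq micro-batches per gpu
# num_samples_seen += batch_size * world_size             # before
num_samples_seen += batch_size * accum_freq * world_size  # after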

rfbr avatar Dec 29 '22 09:12 rfbr