
Training RetNet on A100 GPUs

Open · Antoine-Bergerault opened this issue 7 months ago · 1 comment

Hello,

I followed the blog post https://zenn.dev/selllous/articles/retnet_tutorial shared in #52 to train RetNet, and it works well for small models (< 3B).

But I am unable to train retnet_3b without running into memory issues. For now I just want to get it running at all, yet even with a very small batch size and max-tokens I hit out-of-memory errors.

cd torchscale/examples/fairseq/
python train.py ../../../fairseq/examples/language_model/data-bin/wikitext-103 \
  --task language_modeling \
  --save-dir checkpoints/retnet_3b/transformer_wikitext-103 \
  --arch retnet_3b --share-decoder-input-output-embed \
  --save-interval 1 \
  --dropout 0.1 \
  --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
  --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
  --max-tokens 512 --update-freq 16 \
  --fp16 \
  --batch-size 2 \
  --max-update 1 \
  --tokens-per-sample 512

The backward pass always runs out of memory: the call to optimizer.step() in fairseq_task.py, line 498, exits with:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 11.96 GiB. GPU 0 has a total capacty of 79.15 GiB of which 7.20 GiB is free. Process 75705 has 71.94 GiB memory in use. Of the allocated memory 65.77 GiB is allocated by PyTorch, and 3.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
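As the error message itself suggests, one thing I could try is capping the allocator's split size to reduce fragmentation (the 128 MB below is just an example value, not a recommendation):

PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 python train.py ...  # same arguments as above

But I suspect fragmentation is not the main problem here, given how much memory is already allocated.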

What would you recommend for training a model of this size? Is there a way to train it on one or more A100 GPUs with 80 GiB of memory each?

I understand that I might need to partition the model across multiple GPUs, but I am very unfamiliar with this, and any help would be appreciated. From what I can gather from the fairseq docs, a run sharded across (say) two GPUs would look something like the sketch below. This is untested: --ddp-backend fully_sharded is fairseq's fairscale-based FSDP integration and requires fairscale to be installed, and I do not know whether it works with the retnet_3b architecture.
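python train.py ../../../fairseq/examples/language_model/data-bin/wikitext-103 \
  --task language_modeling \
  --arch retnet_3b --share-decoder-input-output-embed \
  --distributed-world-size 2 \
  --ddp-backend fully_sharded --checkpoint-activations --fp16 \
  --optimizer adam --adam-betas '(0.9, 0.98)' \
  --max-tokens 512 --tokens-per-sample 512 --batch-size 2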

Antoine-Bergerault · Nov 26 '23

You can try --memory-efficient-fp16 --checkpoint-activations, which can significantly reduce the memory consumption.
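Applied to the command above, that would look something like this (an untested sketch, keeping the rest of the flags unchanged; --memory-efficient-fp16 avoids keeping an extra fp32 copy of the model weights, and --checkpoint-activations recomputes activations in the backward pass instead of storing them):

python train.py ../../../fairseq/examples/language_model/data-bin/wikitext-103 \
  --task language_modeling \
  --arch retnet_3b --share-decoder-input-output-embed \
  --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
  --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
  --max-tokens 512 --update-freq 16 --batch-size 2 --tokens-per-sample 512 \
  --fp16 --memory-efficient-fp16 --checkpoint-activations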

shumingma · Dec 20 '23