
Any recommended way to improve training speed on hardware with low VRAM?

Open · QIU-Shuo opened this issue 2 years ago · 1 comment

❓ Questions and Help

Before asking:

  1. search the issues.
  2. search the docs.

What is your question?

Hi, I am working on training a 10B model with limited resources (16 × A100 40GB). The problem I am facing is that I cannot reach the FLOP/s target (130–150 TFLOP/s per GPU). I have tuned the parameters, but the most prominent speed-up comes from reducing the model size and increasing the batch size. So I suspect the cause is the small batch size I have to use to accommodate the lower VRAM.

| n params (B) | hidden | ffw | # heads | # layers | # tensor parallel | batch size | wps | TFLOP/s/A100 |
|---|---|---|---|---|---|---|---|---|
| 8.172 | 4096 | 16384 | 32 | 40 | 2 | 8 | 17k | 69 |
| 8.172 | 4096 | 16384 | 32 | 40 | 4 | 16 | OOM | OOM |
| 4.144 | 4096 | 16384 | 32 | 20 | 2 | 16 | 43k | 89 |
| 4.144 | 4096 | 16384 | 32 | 20 | 4 | 32 | 27k | 56 |

The most straightforward way to validate this is to increase the parallel size. However, from my observations, increasing the tensor-parallel size from 2 to 4 only slows training down. Is this expected? If so, is there any other way to improve the training speed here?

Sequence length: 2048. FLOP/s calculation: `wps * n_params * 8 / n_gpus`.
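For reference, a minimal sketch (not from the issue itself) of the throughput formula above, checked against the first table row; the function name is just for illustration:

```python
def tflops_per_gpu(wps, n_params, n_gpus, flops_per_param_per_token=8):
    """Approximate achieved TFLOP/s per GPU as wps * n_params * 8 / n_gpus."""
    return wps * n_params * flops_per_param_per_token / n_gpus / 1e12

# First table row: 17k words/s, 8.172B params, 16 GPUs -> roughly 69 TFLOP/s per A100
print(tflops_per_gpu(wps=17_000, n_params=8.172e9, n_gpus=16))  # ~69.5
```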

Code

What have you tried?

What's your environment?

  • metaseq Version (e.g., 1.0 or master): master
  • PyTorch Version (e.g., 1.0): nightly 1.13.0a0+d321be6
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed metaseq (pip, source): source
  • Build command you used (if compiling from source): pip install -e .
  • Python version: 3.9
  • CUDA/cuDNN version: 450.80.02/11.7/8600
  • GPU models and configuration: A100 40GB * 16
  • Any other relevant information:

QIU-Shuo · Sep 20 '22 07:09

Checkpoint activations and FSDP will significantly lower memory pressure.

stephenroller · Sep 24 '22 17:09
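Not part of the original reply, but for readers looking for the concrete mechanism: a minimal PyTorch sketch of the two techniques named above, activation checkpointing plus FSDP parameter sharding. metaseq wires these up through its own trainer and command-line options, so this is only an illustration of the general approach; `blocks` and the helper names are stand-ins, not metaseq APIs.

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class CheckpointedBlock(nn.Module):
    """Recompute this block's activations in the backward pass instead of
    storing them, trading extra compute for a large drop in activation memory."""
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x):
        # use_reentrant=False selects the non-reentrant checkpointing path in recent PyTorch
        return checkpoint(self.block, x, use_reentrant=False)

def build_sharded_model(blocks):
    # Requires torch.distributed to be initialized (e.g. launched via torchrun).
    # Wrapping each block in FSDP shards its parameters, gradients, and optimizer
    # state across ranks, so each GPU only materializes one block at a time.
    return FSDP(nn.Sequential(*[FSDP(CheckpointedBlock(b)) for b in blocks]))
```

With both enabled, the freed memory can go toward a larger per-GPU batch size, which is where the table above shows the biggest throughput gains.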