metaseq
Any recommended way to improve training speed on hardware with low VRAM?
❓ Questions and Help
Before asking:
- search the issues.
- search the docs.
What is your question?
Hi, I am working on training a 10B model with limited resources (16 × A100 40GB). The problem I am facing is that I cannot reach the FLOP/s target (130~150 TFLOP/s per A100). I have tuned the parameters, but the most prominent speed-up comes from reducing the model size and increasing the batch size. So I suspect the cause is the small batch size I have to use to accommodate the limited VRAM.
| n params (B) | hidden | FFN dim | # heads | # layers | tensor parallel size | batch size | wps | TFLOP/s/A100 |
|---|---|---|---|---|---|---|---|---|
| 8.172 | 4096 | 16384 | 32 | 40 | 2 | 8 | 17k | 69 |
| 8.172 | 4096 | 16384 | 32 | 40 | 4 | 16 | OOM | OOM |
| 4.144 | 4096 | 16384 | 32 | 20 | 2 | 16 | 43k | 89 |
| 4.144 | 4096 | 16384 | 32 | 20 | 4 | 32 | 27k | 56 |
The most straightforward way to validate this is to increase the parallel size. However, from my observations, increasing the tensor parallel size from 2 to 4 only slows training down. Is that expected? If it is, is there any other way to improve the training speed here?
Seq len: 2048. FLOP/s calculation: `wps * n_params * 8 / n_gpus`.
Code
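For reference, a minimal script reproducing the TFLOP/s column of the table from the formula above (a sketch; the factor of 8 is assumed to cover forward + backward plus activation recomputation, i.e. 6N + 2N FLOPs per token):

```python
def tflops_per_gpu(wps: float, n_params: float, n_gpus: int) -> float:
    """Achieved TFLOP/s per GPU from words/sec, parameter count, and GPU count."""
    return wps * n_params * 8 / n_gpus / 1e12

# First table row: ~69 TFLOP/s/A100.
print(tflops_per_gpu(wps=17e3, n_params=8.172e9, n_gpus=16))  # ≈ 69.5
# Third table row: ~89 TFLOP/s/A100.
print(tflops_per_gpu(wps=43e3, n_params=4.144e9, n_gpus=16))  # ≈ 89.1
```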
What have you tried?
What's your environment?
- metaseq Version (e.g., 1.0 or master): master
- PyTorch Version (e.g., 1.0): nightly 1.13.0a0+d321be6
- OS (e.g., Linux): Ubuntu 20.04
- How you installed metaseq (pip, source): source
- Build command you used (if compiling from source): `pip install -e .`
- Python version: 3.9
- Driver/CUDA/cuDNN version: 450.80.02 / 11.7 / 8.6.0
- GPU models and configuration: A100 40GB * 16
- Any other relevant information:
Checkpointing activations and FSDP will significantly lower memory pressure, which should let you fit a larger per-GPU batch size.
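As a rough illustration (generic PyTorch, not metaseq code; the `Block` module and its sizes are placeholders), the two techniques combine like this:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim: int = 4096, ffn_dim: int = 16384):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(), nn.Linear(ffn_dim, dim)
        )

    def forward(self, x):
        # Discard intermediate activations in the forward pass and recompute
        # them during backward: extra compute, much less activation memory.
        return x + checkpoint(self.ffn, x, use_reentrant=False)

# Shard parameters, gradients, and optimizer state across ranks so each GPU
# holds only a slice of the model state (requires torch.distributed to be
# initialized, e.g. via torchrun).
model = FSDP(nn.Sequential(*[Block() for _ in range(40)]))
```

With the activation memory freed up, you can raise the per-GPU batch size, which is what drove the biggest speed-ups in your table.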