Using FSDP
This branch is my attempt to squeeze in the largest model I can with BlockSparse vs. standard dot-product attention + FSDP, with optimal training-from-scratch throughput. In this paper they mention 147 TFLOP/s per GPU on an A100, so I'm hoping to see something similar (albeit with no tensor parallelism here).
Currently I'm wrapping each transformer encoder block individually; however, I'm hoping to get some advice on parameters/setup for optimal throughput!
cc @zhaojuanmao who may be able to help!
Will need to use this fix: https://github.com/PyTorchLightning/pytorch-lightning/pull/12965
For the native FSDP version, feel free to use `transformer_auto_wrap_policy` to wrap your model, and also try the new `MixedPrecision` config for bfloat16 :)
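For illustration, here is a minimal sketch of those two pieces of configuration. The actual FSDP wrapping call is shown only as a comment, since it needs an initialized process group, and `nn.TransformerEncoderLayer` stands in for the model's real transformer block class:

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap every transformer block as its own FSDP unit so parameter all-gathers
# can overlap with compute layer by layer.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)

# Keep parameters, gradient reduction, and buffers in bfloat16.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Inside an initialized process group this would be:
# model = FSDP(model, auto_wrap_policy=auto_wrap_policy, mixed_precision=bf16_policy)
```

Per-block auto-wrapping is what gives FSDP its communication/computation overlap; a single flat wrap around the whole model loses that.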
While trying FSDP I'm seeing poor performance out of the box, and I'm not entirely sure what's causing it:
```
python train.py --devices 8 --n_embd 8192 --n_layer 14 --n_head 32
# Number of parameters: 11.29 billion
```
Recorded values:
- Estimated throughput: 24.49 TFLOP/s
- Average iteration time: 11.33 s
- Average peak CUDA memory: 31225.20 MiB
- Average peak virtual memory: 28.66 GiB
- Average peak swap memory: 0.00 GiB
Environment:
- 8× A100 GPUs (Hyperplane)
- PyTorch Version: 1.12.0.dev20220517+cu113
- PyTorch Lightning Version: https://github.com/PyTorchLightning/pytorch-lightning/pull/12985
- Triton: 2.0.0.dev20220505
- xFormers: 0.0.11.dev0
Any ideas on what to tweak to improve the estimated FLOP throughput, @zhaojuanmao?
@SeanNaren thanks for trying PTD FSDP!
- Would you please `print(model)` after constructing the whole model? We found some bugs in Lightning: the outermost model doesn't appear to be wrapped, which disables the overlap of communication and computation in the forward and backward passes. The wrapping should look like `FSDP(FSDP(transformer_layer_1), FSDP(transformer_layer_2), ...)`. Also, the device doesn't seem to be set properly, which easily causes rank 0 to OOM; basically, call `torch.cuda.set_device()` properly or set `CUDA_VISIBLE_DEVICES` properly.
- What is the memory utilization right now? FSDP can help scale up the batch size without degrading latency much, thanks to the communication/computation overlap. Would you please maximize the batch size to maximize throughput here?
- In real large-model training workloads, FSDP is usually used with activation checkpointing to save more memory and scale to a larger batch size, which can push throughput much higher. Basically, do something like `FSDP(activation_checkpoint_wrapper(transformer_block))`; we released a prototype checkpoint wrapper in `torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py`.
- After the above fixes, tuning, and experiments, if the throughput is still not ideal, could you please enable the autograd profiler to get a trace of the training run? The profiler should not be enabled in real training; enable it only while debugging.
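As a sketch of the device-setup point above (assuming a torchrun-style `LOCAL_RANK` environment variable; the tiny model here is just a stand-in for the real one):

```python
import os

import torch
import torch.nn as nn

# Pin each process to its own GPU *before* constructing/wrapping the model;
# otherwise every rank allocates on cuda:0 and rank 0 OOMs.
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # set by torchrun/launcher
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)

# Stand-in model; the real one would be wrapped with FSDP + an auto-wrap policy.
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4), num_layers=2
)

# After FSDP wrapping, print(model) should show one FSDP unit per layer, e.g.
#   FSDP(FSDP(transformer_layer_1), FSDP(transformer_layer_2), ...)
# A single flat FSDP around the whole model means auto-wrapping didn't fire,
# so forward/backward communication cannot overlap with compute.
print(model)
```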
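The activation-checkpointing suggestion can be sketched on a single block; this runs standalone (no process group), and in real training the checkpointed block would then be wrapped again with FSDP as `FSDP(checkpoint_wrapper(block))`:

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

block = nn.TransformerEncoderLayer(d_model=64, nhead=4)
block = checkpoint_wrapper(block)  # activations are dropped and recomputed

x = torch.randn(8, 2, 64, requires_grad=True)  # (seq_len, batch, d_model)
out = block(x)
out.sum().backward()  # forward is re-run here, trading compute for memory
print(out.shape)
```

The memory saved on stored activations is what lets the batch size grow, which is where the throughput win comes from.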
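A minimal sketch of capturing a profiler trace for a few training-like iterations (debugging only, as noted above; on GPU you would add `ProfilerActivity.CUDA` to the activities list):

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

model = nn.TransformerEncoderLayer(d_model=64, nhead=4)
x = torch.randn(8, 2, 64)

# Profile a handful of forward/backward iterations.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    for _ in range(3):
        model(x).sum().backward()

# Export a trace viewable in chrome://tracing or Perfetto to inspect
# whether communication and computation actually overlap.
prof.export_chrome_trace("trace.json")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```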