Parallel compile on GPU backend would slow down execution performance.
🐛 Bug
Parallel compile ~~on PJRT:GPU~~ would slow down execution performance.
~~(Probably not about PJRT, I haven't tested the performance of parallel compilation on XRT.)~~
(Just tested on XRT and found the same performance drop with --xla_gpu_force_compilation_parallelism=96.)
To Reproduce
Steps to reproduce the behavior:
The test script was posted below; it's a simple BERT model from huggingface/transformers.
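Since the actual train_sst_bert.py attachment is not reproduced in this thread, here is a minimal hypothetical stand-in showing the general shape of the benchmark: a BERT model from huggingface/transformers trained on an XLA device with per-iteration timing. The script name, hyperparameters, and synthetic data are assumptions; the --amp_enabled and --syncfree_optimizer options from the commands below are not modeled.

```python
# Hypothetical stand-in for train_sst_bert.py (the real script is an attachment
# and is not shown here). It times a simple BERT training loop on an XLA device.
import time

import torch
import torch_xla.core.xla_model as xm
from transformers import BertForSequenceClassification


def main(batch_size=30, seq_len=128, num_iters=200):
    device = xm.xla_device()
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    # Synthetic SST-like batch: token ids and binary labels.
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device=device)
    labels = torch.randint(0, 2, (batch_size,), device=device)

    start = time.time()
    for step in range(num_iters):
        optimizer.zero_grad()
        loss = model(input_ids=input_ids, labels=labels).loss
        loss.backward()
        # barrier=True also marks the XLA step, so the graph gets compiled and executed.
        xm.optimizer_step(optimizer, barrier=True)
        if step % 16 == 0:
            elapsed = time.time() - start
            print(f"Iteration {step} complete in {elapsed // 60}m {elapsed % 60} s")
            start = time.time()


if __name__ == "__main__":
    main()
```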
- compile with thread pool:
PJRT_DEVICE=GPU CUDA_VISIBLE_DEVICES=0 GPU_NUM_DEVICES=1 python train_sst_bert.py --batch_size=30 --xla_enabled --amp_enabled --syncfree_optimizer
- compile without thread pool:
XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1" PJRT_DEVICE=GPU CUDA_VISIBLE_DEVICES=0 GPU_NUM_DEVICES=1 python train_sst_bert.py --batch_size=30 --xla_enabled --amp_enabled --syncfree_optimizer
with thread pool:
[ 2022-08-08-02:51:20.798009 ] Iteration 0 complete in 0.0m 0.4970126152038574 s
[ 2022-08-08-02:52:51.567088 ] Iteration 16 complete in 1.0m 30.769078254699707 s
[ 2022-08-08-02:52:56.524109 ] Iteration 32 complete in 0.0m 4.957021713256836 s
[ 2022-08-08-02:53:01.489270 ] Iteration 48 complete in 0.0m 4.96515965461731 s
[ 2022-08-08-02:53:06.447775 ] Iteration 64 complete in 0.0m 4.958504915237427 s
[ 2022-08-08-02:53:11.402642 ] Iteration 80 complete in 0.0m 4.9548656940460205 s
[ 2022-08-08-02:53:16.353444 ] Iteration 96 complete in 0.0m 4.950800895690918 s
[ 2022-08-08-02:53:21.302015 ] Iteration 112 complete in 0.0m 4.948570013046265 s
[ 2022-08-08-02:53:26.029324 ] Iteration 128 complete in 0.0m 4.72730827331543 s
[ 2022-08-08-02:53:30.919764 ] Iteration 144 complete in 0.0m 4.890439510345459 s
[ 2022-08-08-02:53:35.853633 ] Iteration 160 complete in 0.0m 4.933867931365967 s
[ 2022-08-08-02:53:40.864197 ] Iteration 176 complete in 0.0m 5.0105626583099365 s
[ 2022-08-08-02:53:45.626250 ] Iteration 192 complete in 0.0m 4.762052297592163 s
without thread pool:
[ 2022-08-08-02:54:39.386595 ] Iteration 0 complete in 0.0m 0.5132155418395996 s
[ 2022-08-08-02:58:05.775467 ] Iteration 16 complete in 3.0m 26.38887643814087 s
[ 2022-08-08-02:58:10.557267 ] Iteration 32 complete in 0.0m 4.78179931640625 s
[ 2022-08-08-02:58:15.226885 ] Iteration 48 complete in 0.0m 4.669617652893066 s
[ 2022-08-08-02:58:19.707870 ] Iteration 64 complete in 0.0m 4.480984210968018 s
[ 2022-08-08-02:58:24.489684 ] Iteration 80 complete in 0.0m 4.781812906265259 s
[ 2022-08-08-02:58:29.169496 ] Iteration 96 complete in 0.0m 4.679811000823975 s
[ 2022-08-08-02:58:33.733199 ] Iteration 112 complete in 0.0m 4.56370210647583 s
[ 2022-08-08-02:58:38.362863 ] Iteration 128 complete in 0.0m 4.6296632289886475 s
[ 2022-08-08-02:58:42.963132 ] Iteration 144 complete in 0.0m 4.600268840789795 s
[ 2022-08-08-02:58:47.473824 ] Iteration 160 complete in 0.0m 4.510691165924072 s
[ 2022-08-08-02:58:52.093992 ] Iteration 176 complete in 0.0m 4.6201653480529785 s
[ 2022-08-08-02:58:56.660320 ] Iteration 192 complete in 0.0m 4.566328048706055 s
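To make the regression concrete, here is a quick back-of-the-envelope comparison using the numbers from the two runs above (iteration 16 is dominated by the first big compilation; the later iterations measure steady-state execution):

```python
# Values are read (and rounded) from the logs above.
compile_with_pool = 90.77      # seconds, iteration 16, with thread pool
compile_without_pool = 206.39  # seconds, iteration 16, parallelism=1

steady_with_pool = (4.957 + 4.965 + 4.959 + 4.955 + 4.951 + 4.949) / 6
steady_without_pool = (4.782 + 4.670 + 4.481 + 4.782 + 4.680 + 4.564) / 6

print(f"compilation speedup with thread pool: {compile_without_pool / compile_with_pool:.2f}x")
print(f"steady-state slowdown with thread pool: {steady_with_pool / steady_without_pool - 1:.1%}")
```

So the thread pool cuts compilation time by more than half, but steady-state iteration time regresses by a few percent, which is the slowdown this issue is about.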
Expected behavior
Parallel compilation shouldn't cause an execution performance drop, just like it doesn't on the CPU backend.
Environment
- Reproducible on XLA backend [CPU/TPU]: GPU
- torch_xla version: v1.12.0
- tensorflow version: fc1c08fabd
Additional context
HLO dump:
Thanks, we have someone from the internal xla:gpu team looking into this issue.
@cicirori We have a PR to fix this issue https://github.com/tensorflow/tensorflow/pull/57108. If it works for you, please let me know.
Amazing! Thanks @ymwangg. After this fix, should we enable parallel thread pool compilation for pt/xla:gpu by default?
@JackCaoG I think we may consider making it the default by setting XLA_FLAGS="--xla_gpu_force_compilation_parallelism=?" here. Enabling multi-threading for compilation did significantly improve the GPU user experience, reducing compilation time by more than 50% for the models we have tested.
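As a rough illustration of that proposal (not the actual torch_xla code, which a later comment in this thread links to), a wrapper could inject the flag into XLA_FLAGS only when the user hasn't already set it; the helper name is made up, and the default of 8 is the value the thread later settles on:

```python
# Sketch only: inject a default compilation parallelism into XLA_FLAGS,
# respecting any value the user has already chosen. The helper name and the
# default of 8 (mentioned later in this thread) are illustrative.
import os

_PARALLELISM_FLAG = "--xla_gpu_force_compilation_parallelism"


def _set_default_compilation_parallelism(default: int = 8) -> None:
    flags = os.environ.get("XLA_FLAGS", "")
    if _PARALLELISM_FLAG not in flags:
        os.environ["XLA_FLAGS"] = f"{flags} {_PARALLELISM_FLAG}={default}".strip()


_set_default_compilation_parallelism()
```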
@ymwangg remind me, did we ever make XLA_FLAGS="--xla_gpu_force_compilation_parallelism" the default?
Yes, it's set to 8 by default now: https://github.com/pytorch/xla/blob/32da64f69a9c246186603a168291d7e42e0d3884/torch_xla/__init__.py#L51-L52.
Thanks @ymwangg! I will close this issue for now.