
Parallel compilation on the GPU backend slows down execution performance.

cicirori opened this issue 2 years ago • 4 comments

🐛 Bug

Parallel compilation slows down execution performance. This is not specific to PJRT:GPU: testing on XRT showed the same performance drop with --xla_gpu_force_compilation_parallelism=96.

To Reproduce

Steps to reproduce the behavior:

The test script (train_sst_bert.py) is a simple BERT model from huggingface/transformers.

  1. Compile with the thread pool: `PJRT_DEVICE=GPU CUDA_VISIBLE_DEVICES=0 GPU_NUM_DEVICES=1 python train_sst_bert.py --batch_size=30 --xla_enabled --amp_enabled --syncfree_optimizer`
  2. Compile without the thread pool: `XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1" PJRT_DEVICE=GPU CUDA_VISIBLE_DEVICES=0 GPU_NUM_DEVICES=1 python train_sst_bert.py --batch_size=30 --xla_enabled --amp_enabled --syncfree_optimizer`

with thread pool:

[ 2022-08-08-02:51:20.798009 ] Iteration 0       complete in 0.0m 0.4970126152038574 s
[ 2022-08-08-02:52:51.567088 ] Iteration 16      complete in 1.0m 30.769078254699707 s
[ 2022-08-08-02:52:56.524109 ] Iteration 32      complete in 0.0m 4.957021713256836 s
[ 2022-08-08-02:53:01.489270 ] Iteration 48      complete in 0.0m 4.96515965461731 s
[ 2022-08-08-02:53:06.447775 ] Iteration 64      complete in 0.0m 4.958504915237427 s
[ 2022-08-08-02:53:11.402642 ] Iteration 80      complete in 0.0m 4.9548656940460205 s
[ 2022-08-08-02:53:16.353444 ] Iteration 96      complete in 0.0m 4.950800895690918 s
[ 2022-08-08-02:53:21.302015 ] Iteration 112     complete in 0.0m 4.948570013046265 s
[ 2022-08-08-02:53:26.029324 ] Iteration 128     complete in 0.0m 4.72730827331543 s
[ 2022-08-08-02:53:30.919764 ] Iteration 144     complete in 0.0m 4.890439510345459 s
[ 2022-08-08-02:53:35.853633 ] Iteration 160     complete in 0.0m 4.933867931365967 s
[ 2022-08-08-02:53:40.864197 ] Iteration 176     complete in 0.0m 5.0105626583099365 s
[ 2022-08-08-02:53:45.626250 ] Iteration 192     complete in 0.0m 4.762052297592163 s

without thread pool:

[ 2022-08-08-02:54:39.386595 ] Iteration 0       complete in 0.0m 0.5132155418395996 s
[ 2022-08-08-02:58:05.775467 ] Iteration 16      complete in 3.0m 26.38887643814087 s
[ 2022-08-08-02:58:10.557267 ] Iteration 32      complete in 0.0m 4.78179931640625 s
[ 2022-08-08-02:58:15.226885 ] Iteration 48      complete in 0.0m 4.669617652893066 s
[ 2022-08-08-02:58:19.707870 ] Iteration 64      complete in 0.0m 4.480984210968018 s
[ 2022-08-08-02:58:24.489684 ] Iteration 80      complete in 0.0m 4.781812906265259 s
[ 2022-08-08-02:58:29.169496 ] Iteration 96      complete in 0.0m 4.679811000823975 s
[ 2022-08-08-02:58:33.733199 ] Iteration 112     complete in 0.0m 4.56370210647583 s
[ 2022-08-08-02:58:38.362863 ] Iteration 128     complete in 0.0m 4.6296632289886475 s
[ 2022-08-08-02:58:42.963132 ] Iteration 144     complete in 0.0m 4.600268840789795 s
[ 2022-08-08-02:58:47.473824 ] Iteration 160     complete in 0.0m 4.510691165924072 s
[ 2022-08-08-02:58:52.093992 ] Iteration 176     complete in 0.0m 4.6201653480529785 s
[ 2022-08-08-02:58:56.660320 ] Iteration 192     complete in 0.0m 4.566328048706055 s
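To put numbers on the gap, here is a minimal Python sketch that parses the "complete in <m>m <s> s" lines above. The helper names `intervals_seconds` and `summarize` are hypothetical and not part of torch_xla. It assumes that each line reports the time since the previous line (16 iterations), that the second line absorbs graph compilation, and that later lines are steady state.

```python
import re
import sys
from statistics import mean

# Matches the "complete in <minutes>m <seconds> s" suffix of each log line.
_INTERVAL_RE = re.compile(r"complete in ([\d.]+)m ([\d.]+) s")


def intervals_seconds(log_text: str) -> list[float]:
    """Return every logged interval converted to seconds."""
    return [float(m) * 60 + float(s) for m, s in _INTERVAL_RE.findall(log_text)]


def summarize(log_text: str) -> tuple[float, float]:
    """Return (compile-dominated interval, mean steady-state interval).

    Assumes the layout of the logs in this issue: entry 0 is iteration 0,
    entry 1 (up to iteration 16) absorbs compilation, later entries are steady state.
    """
    xs = intervals_seconds(log_text)
    if len(xs) < 3:
        raise ValueError("need at least three logged intervals")
    return xs[1], mean(xs[2:])


if __name__ == "__main__" and len(sys.argv) == 3:
    # Hypothetical usage: python summarize_logs.py parallel.log serial.log
    with open(sys.argv[1]) as f:
        par_compile, par_step = summarize(f.read())
    with open(sys.argv[2]) as f:
        ser_compile, ser_step = summarize(f.read())
    print(f"compile-dominated interval: {par_compile:.1f} s vs {ser_compile:.1f} s")
    print(f"steady state per 16 iterations: {par_step:.3f} s vs {ser_step:.3f} s "
          f"({(par_step / ser_step - 1) * 100:+.1f}%)")
```

Applied to the two logs above, this reports a compile-dominated interval of about 90.8 s with the thread pool versus 206.4 s without it, and a steady-state interval of about 4.914 s versus 4.626 s. In other words, parallel compilation more than halves compile time but leaves execution roughly 6% slower.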

Expected behavior

Parallel compilation should not cause an execution performance drop, just as it does not on CPU.

Environment

  • Reproducible on XLA backend [CPU/TPU]: GPU
  • torch_xla version: v1.12.0
  • tensorflow version: fc1c08fabd

Additional context

HLO dump:

parallel_compile.tgz

serial_compile.tgz

cicirori • Aug 08 '22 03:08

Thanks, we have someone from the internal xla:gpu team looking into this issue.

JackCaoG • Aug 08 '22 17:08

@cicirori We have a PR to fix this issue https://github.com/tensorflow/tensorflow/pull/57108. If it works for you, please let me know.

ymwangg • Aug 11 '22 20:08

Amazing, thanks @ymwangg! After this fix, should we enable parallel thread-pool compilation for pt/xla:gpu by default?

JackCaoG • Aug 11 '22 23:08

@JackCaoG I think we could make it the default by setting XLA_FLAGS="--xla_gpu_force_compilation_parallelism=?" here. Enabling multi-threaded compilation significantly improved the GPU user experience, reducing compilation time by more than 50% for the models we have tested.

ymwangg • Aug 13 '22 00:08

@ymwangg, remind me: did we ever make --xla_gpu_force_compilation_parallelism part of the default XLA_FLAGS?

JackCaoG • Mar 03 '23 18:03

Yes, it's now set to 8 by default: https://github.com/pytorch/xla/blob/32da64f69a9c246186603a168291d7e42e0d3884/torch_xla/__init__.py#L51-L52

ymwangg • Mar 03 '23 18:03
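A default like the one described above only helps if it does not override a user's explicit choice. For example, an explicit `--xla_gpu_force_compilation_parallelism=1`, as in the serial reproduction, should still win. The following is a minimal Python sketch of a "set only if missing" rule. The helper name `set_missing_xla_flag` is hypothetical, and this is not the torch_xla source. The sketch assumes the variable is updated before torch_xla initializes the XLA backend and reads XLA_FLAGS.

```python
import os
from collections.abc import MutableMapping


def set_missing_xla_flag(name: str, value: str,
                         env: MutableMapping[str, str] = os.environ) -> None:
    """Append --<name>=<value> to XLA_FLAGS unless --<name> is already present.

    Hypothetical helper for illustration: it only edits the given environment
    mapping, so it must run before the XLA backend reads XLA_FLAGS.
    """
    flags = env.get("XLA_FLAGS", "").split()
    prefix = f"--{name}"
    # Respect an explicit user setting such as --xla_gpu_force_compilation_parallelism=1.
    if any(f == prefix or f.startswith(prefix + "=") for f in flags):
        return
    flags.append(f"{prefix}={value}")
    env["XLA_FLAGS"] = " ".join(flags)


if __name__ == "__main__":
    # Default to 8 compilation threads, the value mentioned in this thread.
    set_missing_xla_flag("xla_gpu_force_compilation_parallelism", "8")
    print(os.environ["XLA_FLAGS"])
```

Other flags already in XLA_FLAGS are preserved, and the new flag is appended after them.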

Thanks @ymwangg! I will close this issue for now.

JackCaoG • Mar 03 '23 18:03