Pin update March 2024
Update xla pin to HEAD
Summary:
- Update bazel to 6.5.0
- Rename `PJRT_Structure_Base` to `PJRT_Extension_Base` to accommodate a change in XLA.
Hit the following error:
File "/home/lsiyuan/.cache/bazel/_bazel_lsiyuan/9d8c0c9d904275861907f86bf4a21dbc/external/llvm-project/mlir/BUILD.bazel", line 40, column 7, in <toplevel>
} | if_cuda_available(
Error: unsupported binary operation: dict | select
This requires a Bazel version above 6.0.0, hence the upgrade to 6.5.0.
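For context, Starlark's `|` operator on dictionaries mirrors Python 3.9's dict merge, and merging a dict with a `select()` this way is only accepted by newer Bazel. A minimal Python sketch of the merge semantics (plain dicts and illustrative keys standing in for the `select()` branch, since this is illustration rather than actual Starlark):

```python
# Python 3.9+ dict merge, the semantics that `dict | select` in the
# mlir BUILD file relies on: the right-hand side wins on duplicate keys.
base = {"copts": "-O2", "defines": "NDEBUG"}
cuda_extra = {"defines": "CUDA_ENABLED"}  # stands in for the taken select() branch

merged = base | cuda_extra
assert merged == {"copts": "-O2", "defines": "CUDA_ENABLED"}
```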
Tested performance with the following command on a v4-8 TPU:
python test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1 --metrics_debug
After pin update
| Training Device=xla:0/1 Epoch=1 Step=2280 Loss=0.00135 Rate=425.92 GlobalRate=370.64 Time=23:07:35
| Training Device=xla:0/2 Epoch=1 Step=2280 Loss=0.00135 Rate=425.92 GlobalRate=371.41 Time=23:07:35
| Training Device=xla:0/3 Epoch=1 Step=2280 Loss=0.00135 Rate=425.92 GlobalRate=371.07 Time=23:07:35
| Training Device=xla:0/1 Epoch=1 Step=2300 Loss=0.00135 Rate=425.11 GlobalRate=371.04 Time=23:07:41
| Training Device=xla:0/2 Epoch=1 Step=2300 Loss=0.00135 Rate=425.12 GlobalRate=371.81 Time=23:07:41
| Training Device=xla:0/0 Epoch=1 Step=2300 Loss=0.00135 Rate=425.12 GlobalRate=371.80 Time=23:07:41
| Training Device=xla:0/3 Epoch=1 Step=2300 Loss=0.00135 Rate=425.12 GlobalRate=371.48 Time=23:07:41
Before pin update
| Training Device=xla:0/1 Epoch=1 Step=2260 Loss=0.00135 Rate=453.15 GlobalRate=401.45 Time=00:43:45
| Training Device=xla:0/0 Epoch=1 Step=2260 Loss=0.00135 Rate=453.15 GlobalRate=400.79 Time=00:43:45
| Training Device=xla:0/2 Epoch=1 Step=2260 Loss=0.00135 Rate=453.14 GlobalRate=400.77 Time=00:43:45
| Training Device=xla:0/1 Epoch=1 Step=2280 Loss=0.00135 Rate=456.66 GlobalRate=401.89 Time=00:43:50
| Training Device=xla:0/0 Epoch=1 Step=2280 Loss=0.00135 Rate=456.66 GlobalRate=401.23 Time=00:43:50
| Training Device=xla:0/2 Epoch=1 Step=2280 Loss=0.00135 Rate=456.67 GlobalRate=401.21 Time=00:43:50
| Training Device=xla:0/3 Epoch=1 Step=2280 Loss=0.00135 Rate=456.65 GlobalRate=401.81 Time=00:43:50
| Training Device=xla:0/3 Epoch=1 Step=2300 Loss=0.00135 Rate=458.70 GlobalRate=402.25 Time=00:43:56
| Training Device=xla:0/2 Epoch=1 Step=2300 Loss=0.00135 Rate=458.70 GlobalRate=401.66 Time=00:43:56
| Training Device=xla:0/1 Epoch=1 Step=2300 Loss=0.00135 Rate=458.69 GlobalRate=402.34 Time=00:43:56
| Training Device=xla:0/0 Epoch=1 Step=2300 Loss=0.00135 Rate=458.68 GlobalRate=401.68 Time=00:43:56
There is a perf regression after the pin update.
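To compare the two runs quantitatively, a hypothetical helper like the one below (the regex and function name are illustrative, not part of the test script) can average the per-step `Rate` values from logs in the format above:

```python
# Hypothetical helper to average per-step Rate values from the logs above.
import re

# A word boundary before "Rate" keeps "GlobalRate=" from matching too.
RATE_RE = re.compile(r"\bRate=([\d.]+)")

def mean_rate(log_text: str) -> float:
    rates = [float(m.group(1)) for m in RATE_RE.finditer(log_text)]
    return sum(rates) / len(rates)

after = """\
| Training Device=xla:0/1 Epoch=1 Step=2300 Loss=0.00135 Rate=425.11 GlobalRate=371.04 Time=23:07:41
| Training Device=xla:0/2 Epoch=1 Step=2300 Loss=0.00135 Rate=425.12 GlobalRate=371.81 Time=23:07:41
"""
before = """\
| Training Device=xla:0/1 Epoch=1 Step=2300 Loss=0.00135 Rate=458.69 GlobalRate=402.34 Time=00:43:56
"""
print(f"after:  {mean_rate(after):.2f}")   # ~425.1
print(f"before: {mean_rate(before):.2f}")  # ~458.7
```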
Update: the perf results above came from a debug build; redone with a release build.
After pin update
| Training Device=xla:0/1 Epoch=1 Step=2300 Loss=0.00135 Rate=1792.04 GlobalRate=1229.30 Time=02:04:04
| Training Device=xla:0/1 Epoch=1 Step=2320 Loss=0.00135 Rate=1794.55 GlobalRate=1232.65 Time=02:04:06
| Training Device=xla:0/2 Epoch=1 Step=2320 Loss=0.00135 Rate=1794.40 GlobalRate=1229.71 Time=02:04:06
| Training Device=xla:0/3 Epoch=1 Step=2320 Loss=0.00135 Rate=1794.37 GlobalRate=1239.40 Time=02:04:06
| Training Device=xla:0/0 Epoch=1 Step=2320 Loss=0.00135 Rate=1794.41 GlobalRate=1237.04 Time=02:04:06
| Training Device=xla:0/1 Epoch=1 Step=2340 Loss=0.00135 Rate=1795.36 GlobalRate=1235.96 Time=02:04:07
| Training Device=xla:0/2 Epoch=1 Step=2340 Loss=0.00135 Rate=1795.27 GlobalRate=1233.04 Time=02:04:07
| Training Device=xla:0/3 Epoch=1 Step=2340 Loss=0.00135 Rate=1795.29 GlobalRate=1242.69 Time=02:04:07
| Training Device=xla:0/0 Epoch=1 Step=2340 Loss=0.00135 Rate=1795.25 GlobalRate=1240.33 Time=02:04:07
Before pin update
| Training Device=xla:0/0 Epoch=1 Step=2300 Loss=0.00135 Rate=1792.56 GlobalRate=1229.77 Time=04:31:01
| Training Device=xla:0/1 Epoch=1 Step=2300 Loss=0.00135 Rate=1792.61 GlobalRate=1225.71 Time=04:31:01
| Training Device=xla:0/2 Epoch=1 Step=2300 Loss=0.00135 Rate=1792.51 GlobalRate=1235.31 Time=04:31:01
| Training Device=xla:0/3 Epoch=1 Step=2300 Loss=0.00135 Rate=1792.52 GlobalRate=1234.37 Time=04:31:01
| Training Device=xla:0/0 Epoch=1 Step=2320 Loss=0.00135 Rate=1794.64 GlobalRate=1233.12 Time=04:31:02
| Training Device=xla:0/2 Epoch=1 Step=2320 Loss=0.00135 Rate=1794.66 GlobalRate=1238.64 Time=04:31:02
| Training Device=xla:0/1 Epoch=1 Step=2320 Loss=0.00135 Rate=1794.68 GlobalRate=1229.07 Time=04:31:02
| Training Device=xla:0/3 Epoch=1 Step=2320 Loss=0.00135 Rate=1794.65 GlobalRate=1237.70 Time=04:31:02
| Training Device=xla:0/3 Epoch=1 Step=2340 Loss=0.00135 Rate=1794.95 GlobalRate=1240.99 Time=04:31:04
| Training Device=xla:0/2 Epoch=1 Step=2340 Loss=0.00135 Rate=1794.83 GlobalRate=1241.93 Time=04:31:04
| Training Device=xla:0/0 Epoch=1 Step=2340 Loss=0.00135 Rate=1794.70 GlobalRate=1236.43 Time=04:31:04
| Training Device=xla:0/1 Epoch=1 Step=2340 Loss=0.00135 Rate=1794.77 GlobalRate=1232.39 Time=04:31:04
With the release build, throughput is essentially identical before and after the pin update (Rate ≈ 1795 at step 2340 in both runs), so the regression observed earlier was an artifact of the debug build.
The PT2E test failed because the converter patch is commented out for now. The xla pin will need to move again after https://github.com/pytorch/xla/blob/master/openxla_patches/quant_dequant_converter.diff is upstreamed.
The following GPU tests hit OOM in CI after the pin update:
PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=2 test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
Example error message:
E0000 00:00:1709794596.611335 87900 pjrt_stream_executor_client.cc:2804] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5571021088 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 134.46MiB
constant allocation: 4B
maybe_live_out allocation: 183.47MiB
preallocated temp allocation: 5.19GiB
preallocated temp fragmentation: 31.49MiB (0.59%)
total allocation: 5.42GiB
total fragmentation: 119.12MiB (2.15%)
Peak buffers:
Buffer 1:
Size: 196.00MiB
XLA Label: custom-call
Shape: f32[64,256,56,56]
==========================
...
Buffer 15:
Size: 98.00MiB
XLA Label: fusion
Shape: f32[64,128,56,56]
==========================
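As a quick sanity check on the numbers in that dump, each reported buffer size follows directly from its shape at 4 bytes per f32 element:

```python
# Buffer size = product of the shape's dimensions * 4 bytes per f32 element.
from math import prod

def f32_buffer_mib(shape):
    return prod(shape) * 4 / (1024 ** 2)

print(f32_buffer_mib((64, 256, 56, 56)))  # 196.0 MiB -> matches Buffer 1
print(f32_buffer_mib((64, 128, 56, 56)))  # 98.0 MiB  -> matches Buffer 15
```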
cc @will-cromar for the PJRT changes that accommodate the PJRT interface change in upstream XLA.
Thanks @sdasgup3 for pointing out that we need to generate custom calls to `mhlo.uniform_quantize/dequantize` in the HLO->MHLO converter, to accommodate the incoming `stablehlo.uniform_quantize/dequantize`. Noting it here for reference.