XLA decides to transpose huge parameter matrices
Description
I've encountered an interesting situation that I've described in more detail here: https://github.com/google/jax/discussions/20284#discussioncomment-8815174
Basically, the problem is as follows:
- I've written some code for running inference on a large model that does a lot of large matmuls.
- The model is so large that only one copy of model parameters will fit into GPU memory.
- Unfortunately, for some reason XLA decides to transpose all model parameters at the beginning of inference, thus effectively doubling memory consumption.
I've tried to produce a small repro for my problem but, unfortunately, was unable to reproduce the behavior in a simplified example. I can, however, provide the HLO for my inference code if that would help.
My main questions are:
- Is there any mechanism I can use to prevent XLA from transposing my parameter matrices? I'm happy with less efficient matmuls if I can avoid OOM.
- XLA could have done the transposes on the fly, allocating memory just before each matmul and deallocating it right after, which would also have been acceptable. Instead it does them once, in advance, for all parameters, presumably because the parameters are reused across multiple iterations of a scan loop. Is there a way to avoid this behavior, so that transposes happen only when needed and the memory is freed immediately afterwards?
Overall, it seems that in my case there is a tradeoff between memory consumption and efficiency, and XLA optimizes for efficiency while ignoring the memory constraints of the device, which is not ideal. Perhaps it could be made to respect device memory limits when deciding on such optimizations?
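For context, here is a minimal sketch of the inference pattern described above: a stack of weight matrices reused across iterations of jax.lax.scan. All names and shapes are made up for illustration; this does not reproduce the memory blow-up (the author was unable to build a small repro either).

```python
import jax
import jax.numpy as jnp

def run_layers(params, x):
    # params: (num_layers, d, d) stacked weights, reused on every scan step.
    def step(carry, w):
        return jnp.tanh(carry @ w), None
    out, _ = jax.lax.scan(step, x, params)
    return out

params = jnp.ones((4, 8, 8))   # 4 layers of tiny 8x8 weights
x = jnp.ones((1, 8))           # batch size 1, as in the issue
y = jax.jit(run_layers)(params, x)
print(y.shape)  # (1, 8)
```

Because the scan body consumes every weight matrix, XLA sees all parameters as live for the whole loop, which is consistent with the up-front (rather than per-step) transposes observed.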
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.24.3
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='hr0nix-dev-8-gpu', release='5.4.0-155-generic', version='#172-Ubuntu SMP Fri Jul 7 16:10:02 UTC 2023', machine='x86_64')
$ nvidia-smi
Mon Mar 18 19:22:10 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12 Driver Version: 535.104.12 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 |
| N/A 35C P0 116W / 700W | 539MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 |
| N/A 31C P0 114W / 700W | 539MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA H100 80GB HBM3 On | 00000000:95:00.0 Off | 0 |
| N/A 34C P0 114W / 700W | 539MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA H100 80GB HBM3 On | 00000000:99:00.0 Off | 0 |
| N/A 30C P0 112W / 700W | 539MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 |
| N/A 35C P0 114W / 700W | 539MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA H100 80GB HBM3 On | 00000000:AF:00.0 Off | 0 |
| N/A 31C P0 115W / 700W | 539MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA H100 80GB HBM3 On | 00000000:B3:00.0 Off | 0 |
| N/A 34C P0 117W / 700W | 539MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA H100 80GB HBM3 On | 00000000:B7:00.0 Off | 0 |
| N/A 30C P0 111W / 700W | 539MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
CC: @nouiz
Could you transpose the original weight directly? That way, XLA won't do the transpose at all. All other options XLA could do would be slower.
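A hedged sketch of this suggestion: store each weight already transposed at load time and rewrite the contraction to match the new layout, so XLA has no reason to insert its own transpose. The shapes and names here are illustrative only.

```python
import jax.numpy as jnp

w = jnp.arange(6.0).reshape(2, 3)   # original weight layout
x = jnp.ones((1, 2))                # batch-1 activation

y_original = x @ w                  # original contraction: (1, 2) @ (2, 3)

w_t = w.T                           # pre-transposed storage: (3, 2)
y_transposed = (w_t @ x.T).T        # equivalent contraction on the new layout

print(jnp.allclose(y_original, y_transposed))  # True
```

For a whole parameter tree, the same transformation can be applied once at load time with jax.tree_util.tree_map.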
Otherwise, you could dump the HLO after each optimization pass. With that, find which pass does this, then use an XLA flag to disable it (and hope this doesn't cause other issues).
You can dump the HLO like this:
XLA_FLAGS="--xla_dump_to=A_DIRECTORY_NAME --xla_dump_hlo_as_text"
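The same flags can also be set from Python, as long as it happens before jax is imported. A minimal sketch (the directory name "hlo_dump" is just an illustration):

```python
import os
# Must be set before jax is imported, or the flags are ignored.
os.environ["XLA_FLAGS"] = "--xla_dump_to=hlo_dump --xla_dump_hlo_as_text"

import jax
import jax.numpy as jnp

# Compiling any jitted function now writes that module's HLO as text
# files into hlo_dump/ (before- and after-optimization versions).
out = jax.jit(lambda x: x @ x.T)(jnp.ones((4, 4)))
```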
How can I distinguish between different optimization passes? Looks like these flags only give me HLO with and without all optimizations:
$ ls -l ./hlo_dump/ | grep decode_fn
-rw-r--r-- 1 root root 124718 Mar 19 16:49 module_0016.jit__decode_fn.autotune_results.pbtxt
-rw-r--r-- 1 root root 9361826 Mar 19 16:49 module_0016.jit__decode_fn.before_optimizations.txt
-rw-r--r-- 1 root root 623 Mar 19 16:49 module_0016.jit__decode_fn.gpu_target_config.pbtxt
-rw-r--r-- 1 root root 6398565 Mar 19 16:49 module_0016.jit__decode_fn.ir-no-opt.ll
-rw-r--r-- 1 root root 3782920 Mar 19 16:49 module_0016.jit__decode_fn.ir-with-opt.ll
-rw-r--r-- 1 root root 1850107 Mar 19 16:49 module_0016.jit__decode_fn.ptx
-rw-r--r-- 1 root root 14278686 Mar 19 16:49 module_0016.jit__decode_fn.sm_9.0_gpu_after_optimizations-buffer-assignment.txt
-rw-r--r-- 1 root root 8605183 Mar 19 16:49 module_0016.jit__decode_fn.sm_9.0_gpu_after_optimizations.txt
-rw-r--r-- 1 root root 6039 Mar 19 16:49 module_0016.jit__decode_fn.thunk_sequence.txt
Never mind, found it: --xla_dump_hlo_pass_re=.*
It turned out that in my case the undesired behavior can be turned off with --xla_gpu_enable_dot_strength_reduction=false. It does come with a significant performance hit, though.
The problem, as I understand it, was that I was running inference with batch size 1, which meant that some of my matmuls were really matrix-vector products. One of the XLA transforms rewrites such matmuls to make them more efficient, but doing so may require transposing the matrix; --xla_gpu_enable_dot_strength_reduction=false appears to disable this rewrite.
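A small sketch of the shape pattern involved and of disabling the rewrite via the flag (set before jax is imported; as noted above, the flag applies globally and makes matmuls slower):

```python
import os
# Disable XLA's dot-strength-reduction rewrite globally: matmuls get
# slower, but the up-front weight transposes go away.
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_dot_strength_reduction=false"

import jax
import jax.numpy as jnp

w = jnp.ones((8, 8))
x = jnp.ones((1, 8))   # batch size 1: this dot is really a matrix-vector product
y = jax.jit(lambda x: x @ w)(x)
print(y.shape)  # (1, 8)
```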
How much is the slowdown when you use that flag?
Could it be because that optimization is disabled at many places in the graph?
If so, you could modify XLA to allow disabling it only for that node rather than globally for the graph.
If you are ready to try that, you could modify the file xla/service/algebraic_simplifier.cc: search for enable_dot_strength_reduction. There is a condition there; maybe you can add an extra && with something like dot->name() != "the_name".
This is an XLA issue, so could be worthwhile to report on the openxla/xla repo.