
XLA decides to transpose huge parameter matrices

Open hr0nix opened this issue 3 months ago • 4 comments

Description

I've encountered an interesting situation that I've described in more detail here: https://github.com/google/jax/discussions/20284#discussioncomment-8815174

Basically, the problem is as follows:

  • I've written some code for running inference on a large model that does a lot of large matmuls.
  • The model is so large that only one copy of model parameters will fit into GPU memory.
  • Unfortunately, for some reason XLA decides to transpose all model parameters at the beginning of inference, thus effectively doubling memory consumption.

I've tried to produce a small repro for my problem but, unfortunately, was unable to reproduce the same behavior in a simplified example. I can, however, provide the HLO for my inference code if that would be of any help.
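To make the setup concrete, here is a minimal NumPy stand-in for the general shape of the code (names, sizes, and step count are invented; this is not the actual model and does not reproduce the issue): the same parameter matrices are reused on every step of a scan-style decode loop, which is presumably why XLA chooses to materialize all transposed copies once, up front.

```python
import numpy as np

# Hypothetical stand-in for the inference pattern: the same parameter
# matrices are reused on every step of a scan-style decode loop.
rng = np.random.default_rng(0)
params = [rng.normal(size=(8, 8)).astype(np.float32) for _ in range(3)]

def decode_step(x, params):
    # One pass through the layer stack; in the real model these are
    # large matmuls over huge parameter matrices.
    for w in params:
        x = np.tanh(w @ x)
    return x

x = np.ones(8, dtype=np.float32)
for _ in range(5):  # the scan over decode steps
    x = decode_step(x, params)
```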

My main questions are:

  • Is there any mechanism I can use to prevent XLA from transposing my parameter matrices? I'm happy with less efficient matmuls if I can avoid OOM.
  • XLA could have done the transposes on the fly, allocating memory just before each matmul and deallocating it right after, which would also have been acceptable. Instead, it does them once, up front, for all parameters, presumably because the parameters are reused multiple times in a scan loop. Is there a way to avoid this behavior, doing the transposes only when they are needed and freeing the memory right after?

Overall, it seems that in my case there is a tradeoff between memory consumption and efficiency, and XLA optimizes for efficiency while ignoring the memory constraints of the device, which is not ideal. Perhaps it could be made to respect device constraints when deciding on optimizations?

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.24.3
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='hr0nix-dev-8-gpu', release='5.4.0-155-generic', version='#172-Ubuntu SMP Fri Jul 7 16:10:02 UTC 2023', machine='x86_64')


$ nvidia-smi
Mon Mar 18 19:22:10 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   35C    P0             116W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:91:00.0 Off |                    0 |
| N/A   31C    P0             114W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:95:00.0 Off |                    0 |
| N/A   34C    P0             114W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:99:00.0 Off |                    0 |
| N/A   30C    P0             112W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:AB:00.0 Off |                    0 |
| N/A   35C    P0             114W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:AF:00.0 Off |                    0 |
| N/A   31C    P0             115W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:B3:00.0 Off |                    0 |
| N/A   34C    P0             117W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:B7:00.0 Off |                    0 |
| N/A   30C    P0             111W / 700W |    539MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

hr0nix avatar Mar 18 '24 19:03 hr0nix

CC: @nouiz

mjsML avatar Mar 18 '24 20:03 mjsML

Could you transpose the original weights directly? That way, XLA won't need to do the transpose at all. Any other option XLA could take would be slower.
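A minimal NumPy sketch of this suggestion (shapes are invented): store the weight pre-transposed at parameter-loading time, so no second, transposed copy ever needs to be materialized in device memory during inference.

```python
import numpy as np

# Hypothetical weight stored as (in_dim, out_dim); suppose the compiled
# matmul wants it as (out_dim, in_dim). Transposing once at load time
# avoids a second copy living alongside the original at run time.
rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3)).astype(np.float32)
w_t = np.ascontiguousarray(w.T)   # one-time transpose on load

x = np.ones(4, dtype=np.float32)
# Same result either way; only the stored layout differs.
assert np.allclose(x @ w, w_t @ x)
```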

Otherwise, you could dump the HLO after each optimization pass, find which pass does this, and then use an XLA flag to disable it (and hope this doesn't cause other issues).

You can dump the HLO like this: XLA_FLAGS="--xla_dump_to=A_DIRECTORY_NAME --xla_dump_hlo_as_text"
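A full invocation might look like this (the dump directory and script name are examples, not from the original report):

```shell
# Dump the HLO as text into /tmp/hlo_dump while running the program.
XLA_FLAGS="--xla_dump_to=/tmp/hlo_dump --xla_dump_hlo_as_text" python run_inference.py
```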

nouiz avatar Mar 19 '24 04:03 nouiz

How can I distinguish between different optimization passes? It looks like these flags only give me the HLO before and after all optimizations:

$ ls -l ./hlo_dump/ | grep decode_fn
-rw-r--r-- 1 root root   124718 Mar 19 16:49 module_0016.jit__decode_fn.autotune_results.pbtxt
-rw-r--r-- 1 root root  9361826 Mar 19 16:49 module_0016.jit__decode_fn.before_optimizations.txt
-rw-r--r-- 1 root root      623 Mar 19 16:49 module_0016.jit__decode_fn.gpu_target_config.pbtxt
-rw-r--r-- 1 root root  6398565 Mar 19 16:49 module_0016.jit__decode_fn.ir-no-opt.ll
-rw-r--r-- 1 root root  3782920 Mar 19 16:49 module_0016.jit__decode_fn.ir-with-opt.ll
-rw-r--r-- 1 root root  1850107 Mar 19 16:49 module_0016.jit__decode_fn.ptx
-rw-r--r-- 1 root root 14278686 Mar 19 16:49 module_0016.jit__decode_fn.sm_9.0_gpu_after_optimizations-buffer-assignment.txt
-rw-r--r-- 1 root root  8605183 Mar 19 16:49 module_0016.jit__decode_fn.sm_9.0_gpu_after_optimizations.txt
-rw-r--r-- 1 root root     6039 Mar 19 16:49 module_0016.jit__decode_fn.thunk_sequence.txt

hr0nix avatar Mar 19 '24 16:03 hr0nix

Nevermind, found it: --xla_dump_hlo_pass_re=.*

hr0nix avatar Mar 19 '24 17:03 hr0nix

It turned out that, in my case, the undesired behavior can be turned off with xla_gpu_enable_dot_strength_reduction=false. It does come with a significant performance cost, though.
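For reference, the flag can be passed the same way as the dump flags (the script name is an example):

```shell
# Disable the dot strength-reduction rewrite globally,
# trading matmul speed for lower peak memory.
XLA_FLAGS="--xla_gpu_enable_dot_strength_reduction=false" python run_inference.py
```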

hr0nix avatar Mar 20 '24 16:03 hr0nix

The problem, as I understand it, was that I was running inference with batch size 1, which meant that some of my matmuls were matrix-vector products. One of the XLA transforms rewrites such matmuls to make them more efficient, but the rewrite may require transposing the matrix. xla_gpu_enable_dot_strength_reduction=false seems to disable this rewrite.
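A small NumPy sketch of why batch size 1 matters (shapes are invented): with a single batch row, the batched matmul degenerates to a matrix-vector product, which is exactly the case a strength-reduction rewrite targets.

```python
import numpy as np

W = np.arange(12, dtype=np.float32).reshape(3, 4)  # parameter matrix
x_batched = np.ones((1, 4), dtype=np.float32)      # batch size 1

y_batched = x_batched @ W.T   # batched matmul, shape (1, 3)
y_vec = W @ x_batched[0]      # equivalent matrix-vector product, shape (3,)
assert np.allclose(y_batched[0], y_vec)
```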

hr0nix avatar Mar 22 '24 00:03 hr0nix

How much is the slowdown with that flag? Could it be because the optimization is disabled in many places in the graph? If so, you could modify XLA to disable it only for that node rather than globally. If you are ready to try that, edit the file xla/service/algebraic_simplifier.cc and search for enable_dot_strength_reduction. There is a condition there; maybe you can add an extra && with something like dot->name() != "the_name".
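The kind of guard described here might look roughly like this (pseudocode, not the actual XLA source; the accessor spelling and the dot name are placeholders):

```
// Near the existing enable_dot_strength_reduction condition in
// xla/service/algebraic_simplifier.cc:
if (options_.enable_dot_strength_reduction() &&
    dot->name() != "the_name") {   // extra guard exempting one dot
  // ... perform the strength-reduction rewrite ...
}
```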

This is an XLA issue, so it could be worthwhile to report it on the openxla/xla repo.

nouiz avatar Mar 28 '24 16:03 nouiz