Matrix factorization on multiple GPUs causes memory overflow

ChenAo-Phys opened this issue.

Description

My original aim is to compute a huge number of determinants that can't fit into the memory of a single GPU. I always run out of memory when I run this on multiple GPUs, and the culprit seems to be the matrix factorization.

Here is a minimal example that performs matrix factorization in parallel. The batch grows with the number of devices, so each GPU should get the same amount of work no matter how many GPUs I use.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Mesh, PartitionSpec

mesh = Mesh(jax.devices(), "x")
pspecs = PartitionSpec("x")
sharding = NamedSharding(mesh, pspecs)

a = jnp.zeros((jax.device_count(), 1500000, 40, 40), device=sharding)
out = jax.lax.linalg.lu(a)

It works well on a single A100 80GB GPU, but on 3 GPUs it fails with the following out-of-memory error. So adding GPUs never seems to let me parallelize more computations. Other matrix factorizations, such as qr and cholesky, cause the same problem.

2024-10-15 16:30:37.290084: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 21.15GiB (22705764841 bytes) by rematerialization; only reduced to 90.75GiB (97440000040 bytes), down from 90.75GiB (97440000040 bytes) originally
2024-10-15 16:30:47.781059: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_0_bfc) ran out of memory trying to allocate 26.83GiB (rounded to 28804500480)requested by op 
2024-10-15 16:30:47.781185: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_2_bfc) ran out of memory trying to allocate 26.83GiB (rounded to 28804500480)requested by op 
2024-10-15 16:30:47.781317: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] ***************************************************************_____________________________________
E1015 16:30:47.781354  473320 pjrt_stream_executor_client.cc:3084] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 28804500256 bytes.
2024-10-15 16:30:47.781345: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_1_bfc) ran out of memory trying to allocate 26.83GiB (rounded to 28804500480)requested by op 
2024-10-15 16:30:47.781454: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] ***************************************************************_____________________________________
E1015 16:30:47.781486  473326 pjrt_stream_executor_client.cc:3084] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 28804500256 bytes.
2024-10-15 16:30:47.781549: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] ***************************************************************_____________________________________
E1015 16:30:47.781597  473323 pjrt_stream_executor_client.cc:3084] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 28804500256 bytes.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[1], line 13
     10 sharding = NamedSharding(mesh, pspecs)
     12 a = jnp.zeros((jax.device_count(), 1500000, 40, 40), device=sharding)
---> 13 out = jax.lax.linalg.lu(a)

File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/lax/linalg.py:277, in lu(x)
    247 def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
    248   """LU decomposition with partial pivoting.
    249 
    250   Computes the matrix decomposition:
   (...)
    275     ``[..., m]``.
    276   """
--> 277   lu, pivots, permutation = lu_p.bind(x)
    278   return lu, pivots, permutation

File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/core.py:438, in Primitive.bind(self, *args, **params)
    435 def bind(self, *args, **params):
    436   assert (not config.enable_checks.value or
    437           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 438   return self.bind_with_trace(find_top_trace(args), args, params)

File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/core.py:442, in Primitive.bind_with_trace(self, trace, args, params)
    440 def bind_with_trace(self, trace, args, params):
    441   with pop_level(trace.level):
--> 442     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    443   return map(full_lower, out) if self.multiple_results else full_lower(out)

File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/core.py:948, in EvalTrace.process_primitive(self, primitive, tracers, params)
    946   return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
    947 else:
--> 948   return primitive.impl(*tracers, **params)

File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/lax/linalg.py:1408, in _lu_impl(operand)
   1407 def _lu_impl(operand):
-> 1408   lu, pivot, perm = dispatch.apply_primitive(lu_p, operand)
   1409   return lu, pivot, perm

File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/dispatch.py:90, in apply_primitive(prim, *args, **params)
     88 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     89 try:
---> 90   outs = fun(*args)
     91 finally:
     92   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 10 frame]

File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1288, in ExecuteReplicated.__call__(self, *args)
   1286   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1287 else:
-> 1288   results = self.xla_executable.execute_sharded(input_bufs)
   1290 if dispatch.needs_check_special():
   1291   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 28804500256 bytes.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

I have tried jax.transfer_guard to make sure that no data is transferred between devices during the matrix factorization. Maybe I am making a simple mistake. I would really appreciate any help with this problem, or any suggestion for computing determinants in parallel.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.0.2
python: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct  4 2024, 13:27:36) [GCC 11.2.0]
jax.devices (3 total, 3 local): [CudaDevice(id=0) CudaDevice(id=1) CudaDevice(id=2)]
process_count: 1
platform: uname_result(system='Linux', node='alcc145', release='5.15.0-94-generic', version='#104-Ubuntu SMP Tue Jan 9 15:25:40 UTC 2024', machine='x86_64')


$ nvidia-smi
Tue Oct 15 16:58:25 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100 80GB PCIe          Off | 00000000:21:00.0 Off |                    0 |
| N/A   35C    P0              66W / 300W |    425MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off | 00000000:81:00.0 Off |                    0 |
| N/A   31C    P0              63W / 300W |    425MiB / 81920MiB |      2%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100 80GB PCIe          Off | 00000000:E2:00.0 Off |                    0 |
| N/A   33C    P0              63W / 300W |    425MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

ChenAo-Phys, Oct 15 '24 15:10
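A quick back-of-the-envelope check of the numbers in the log hints at why this fails. Assuming float32 (the default dtype of jnp.zeros when x64 is disabled), the full array is 3 × 1,500,000 × 40 × 40 × 4 bytes ≈ 26.82 GiB, while each GPU's shard is only ≈ 8.94 GiB. The 28,804,500,256-byte allocation every GPU attempts is almost exactly the full array, which is consistent with each device working on a replicated copy of the whole batch rather than on its own shard. This is an inference from the log, not something stated in the thread.

```python
# Sizes for the repro above; float32 itemsize is an assumption (jnp default without x64).
n_devices = 3             # GPUs in the report
batch, n = 1_500_000, 40  # matrices per device, matrix dimension
itemsize = 4              # bytes per float32 element

full_bytes = n_devices * batch * n * n * itemsize  # whole (unsharded) array
shard_bytes = full_bytes // n_devices              # what one GPU should hold
requested_bytes = 28_804_500_256                   # from the RESOURCE_EXHAUSTED message

print(f"full array : {full_bytes:,} B = {full_bytes / 2**30:.2f} GiB")
print(f"one shard  : {shard_bytes:,} B = {shard_bytes / 2**30:.2f} GiB")
print(f"requested / full array = {requested_bytes / full_bytes:.4f}")
```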

Unfortunately, the JAX/XLA compiler currently does not know how to shard the jax.lax.linalg.lu operator.

As a workaround, you can still shard the computation manually with shard_map; see the JAX docs.

For your example, the manually parallelized version could be something like this:

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), "x")
pspecs = PartitionSpec("x", None)
sharding = NamedSharding(mesh, pspecs)

a = jnp.zeros((jax.device_count(), 1500000, 40, 40), device=sharding)

# shard_map runs jax.lax.linalg.lu independently on each device's slice of
# the leading axis, so each device only factorizes its own block. All three
# outputs (lu, pivots, permutation) keep the same leading-axis sharding.
sharded_lu = shard_map(jax.lax.linalg.lu,
                       mesh=mesh,
                       in_specs=(pspecs,),
                       out_specs=(pspecs, pspecs, pspecs))
sharded_jitted_lu = jax.jit(sharded_lu)
out = sharded_jitted_lu(a)

jaro-sevcik, Oct 16 '24 12:10
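Since the original goal was determinants, the same shard_map pattern can in principle wrap jnp.linalg.slogdet directly, so each device returns only the signs and log-determinants for its own slice of the batch. This is a minimal sketch, not code from the thread: the small batch of 2·I matrices is purely illustrative, the leading axis is assumed to be divisible by the device count, and the import fallback assumes newer JAX releases expose shard_map as jax.shard_map while older ones (such as 0.4.34) only provide jax.experimental.shard_map.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Mesh, PartitionSpec

try:  # assumption: newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX, e.g. 0.4.34
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("x",))
spec = PartitionSpec("x")  # shard only the leading (batch) axis


def local_slogdet(x):
    # Runs on one device's block; slogdet handles the remaining batch axes.
    sign, logabsdet = jnp.linalg.slogdet(x)
    return sign, logabsdet


sharded_slogdet = jax.jit(
    shard_map(local_slogdet, mesh=mesh, in_specs=(spec,), out_specs=(spec, spec))
)

n_dev = jax.device_count()
# Toy batch: det(2 * I_4) = 16 for every matrix.
a = jnp.broadcast_to(2.0 * jnp.eye(4), (n_dev, 8, 4, 4))
a = jax.device_put(a, NamedSharding(mesh, spec))
sign, logabsdet = sharded_slogdet(a)
```

slogdet is used rather than det because multiplying many pivots of a large batch of 40×40 matrices can easily overflow or underflow in float32; sign * jnp.exp(logabsdet) recovers the determinant when it is known to be in range.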

@jaro-sevcik Thanks very much for your reply! This solves my problem.

By the way, is there any plan to solve the sharding problem in the future? It looks straightforward.

ChenAo-Phys, Oct 16 '24 14:10

@ChenAo-Phys — Thanks for the report! And thanks to @jaro-sevcik for suggesting this workaround.

> By the way, is there any plan to solve the sharding problem in the future? It looks straightforward.

JAX doesn't currently have a great API for customizing the sharding behavior of custom calls (which is how most of the linear algebra operations are implemented). custom_partitioning is one option, but because it is implemented using Python callbacks, it can introduce some surprising issues that are probably beyond the scope of this discussion. In short, a better solution would be great (and should be possible with upstream changes in XLA), but it's probably not straightforward. For now, @jaro-sevcik's suggestion to use shard_map is the best approach!

dfm, Oct 16 '24 14:10

@ChenAo-Phys — The LU factorization on GPU should now (with JAX v0.6.0) automatically shard properly in cases like this. Want to try again in your environment if this is still a blocker?

dfm, Apr 17 '25 15:04
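One way to check whether the automatic partitioning is actually in effect is to inspect the per-device shards of the output: if LU is partitioned along the batch axis, each shard should hold only its own slice of the leading axis, whereas a replicated computation would give every device the full batch. This sketch assumes JAX v0.6.0 or later on GPU for the sharded behavior described in the comment above; on other backends or older versions the result may differ, and the tiny batch of 3·I matrices is purely illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Mesh, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))  # shard the batch axis

n_dev = jax.device_count()
a = jax.device_put(jnp.broadcast_to(3.0 * jnp.eye(4), (n_dev, 16, 4, 4)), sharding)

# Plain jit, no shard_map: relies on XLA partitioning the LU custom call.
lu, pivots, perm = jax.jit(jax.lax.linalg.lu)(a)

# Partitioned: each shard has leading dimension 1.
# Replicated: each shard has leading dimension n_dev.
for shard in lu.addressable_shards:
    print(shard.device, shard.data.shape)
```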

> @ChenAo-Phys — The LU factorization on GPU should now (with JAX v0.6.0) automatically shard properly in cases like this. Want to try again in your environment if this is still a blocker?

Thanks very much for the fix! It now works perfectly on multiple devices.

ChenAo-Phys, Apr 17 '25 15:04

Excellent! I'm going to close this now, but note that the Cholesky factorization still doesn't shard properly on GPU. All the other factorizations should work! Fixing Cholesky is on my to-do list as well.

dfm, Apr 17 '25 16:04
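Until Cholesky shards automatically, the shard_map workaround used for LU earlier in the thread should carry over, since each matrix in the batch is factorized independently. A minimal sketch, assuming symmetric positive-definite inputs (here a toy batch of 4·I, whose lower Cholesky factor is 2·I), a leading axis divisible by the device count, and the same import fallback between jax.shard_map (newer JAX) and jax.experimental.shard_map (older JAX):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Mesh, PartitionSpec

try:  # assumption: newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX, e.g. 0.4.34
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("x",))
spec = PartitionSpec("x")

# Each device factorizes only its own block of the batch; one output.
sharded_cholesky = jax.jit(
    shard_map(jax.lax.linalg.cholesky, mesh=mesh, in_specs=(spec,), out_specs=spec)
)

n_dev = jax.device_count()
a = jax.device_put(
    jnp.broadcast_to(4.0 * jnp.eye(5), (n_dev, 6, 5, 5)), NamedSharding(mesh, spec)
)
chol = sharded_cholesky(a)  # lower-triangular factors, same sharding as a
```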