Matrix factorization on multiple GPUs causes memory overflow
Description
My original aim is to compute a huge number of determinants whose inputs can't fit into the memory of a single GPU. Whenever I run the computation on multiple GPUs I get memory overflows, and the problem seems to be the matrix factorization.
Here is a simple example that performs matrix factorization in parallel. The intention is that more machines let me parallelize more computations.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Mesh, PartitionSpec
mesh = Mesh(jax.devices(), "x")
pspecs = PartitionSpec("x")
sharding = NamedSharding(mesh, pspecs)
a = jnp.zeros((jax.device_count(), 1500000, 40, 40), device=sharding)
out = jax.lax.linalg.lu(a)
It works well on a single A100-80GB GPU, but causes the memory overflow below on 3 GPUs, so it seems I can never parallelize more work by adding machines. Other matrix factorizations such as qr or cholesky hit the same problem.
2024-10-15 16:30:37.290084: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 21.15GiB (22705764841 bytes) by rematerialization; only reduced to 90.75GiB (97440000040 bytes), down from 90.75GiB (97440000040 bytes) originally
2024-10-15 16:30:47.781059: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_0_bfc) ran out of memory trying to allocate 26.83GiB (rounded to 28804500480)requested by op
2024-10-15 16:30:47.781185: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_2_bfc) ran out of memory trying to allocate 26.83GiB (rounded to 28804500480)requested by op
2024-10-15 16:30:47.781317: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] ***************************************************************_____________________________________
E1015 16:30:47.781354 473320 pjrt_stream_executor_client.cc:3084] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 28804500256 bytes.
2024-10-15 16:30:47.781345: W external/xla/xla/tsl/framework/bfc_allocator.cc:497] Allocator (GPU_1_bfc) ran out of memory trying to allocate 26.83GiB (rounded to 28804500480)requested by op
2024-10-15 16:30:47.781454: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] ***************************************************************_____________________________________
E1015 16:30:47.781486 473326 pjrt_stream_executor_client.cc:3084] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 28804500256 bytes.
2024-10-15 16:30:47.781549: W external/xla/xla/tsl/framework/bfc_allocator.cc:508] ***************************************************************_____________________________________
E1015 16:30:47.781597 473323 pjrt_stream_executor_client.cc:3084] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 28804500256 bytes.
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[1], line 13
10 sharding = NamedSharding(mesh, pspecs)
12 a = jnp.zeros((jax.device_count(), 1500000, 40, 40), device=sharding)
---> 13 out = jax.lax.linalg.lu(a)
File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/lax/linalg.py:277, in lu(x)
247 def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
248 """LU decomposition with partial pivoting.
249
250 Computes the matrix decomposition:
(...)
275 ``[..., m]``.
276 """
--> 277 lu, pivots, permutation = lu_p.bind(x)
278 return lu, pivots, permutation
File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/core.py:438, in Primitive.bind(self, *args, **params)
435 def bind(self, *args, **params):
436 assert (not config.enable_checks.value or
437 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 438 return self.bind_with_trace(find_top_trace(args), args, params)
File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/core.py:442, in Primitive.bind_with_trace(self, trace, args, params)
440 def bind_with_trace(self, trace, args, params):
441 with pop_level(trace.level):
--> 442 out = trace.process_primitive(self, map(trace.full_raise, args), params)
443 return map(full_lower, out) if self.multiple_results else full_lower(out)
File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/core.py:948, in EvalTrace.process_primitive(self, primitive, tracers, params)
946 return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
947 else:
--> 948 return primitive.impl(*tracers, **params)
File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/lax/linalg.py:1408, in _lu_impl(operand)
1407 def _lu_impl(operand):
-> 1408 lu, pivot, perm = dispatch.apply_primitive(lu_p, operand)
1409 return lu, pivot, perm
File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/dispatch.py:90, in apply_primitive(prim, *args, **params)
88 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
89 try:
---> 90 outs = fun(*args)
91 finally:
92 lib.jax_jit.swap_thread_local_state_disable_jit(prev)
[... skipping hidden 10 frame]
File /hpc/gpfs2/scratch/u/chenao/.conda/envs/quantax_env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1288, in ExecuteReplicated.__call__(self, *args)
1286 self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
1287 else:
-> 1288 results = self.xla_executable.execute_sharded(input_bufs)
1290 if dispatch.needs_check_special():
1291 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 28804500256 bytes.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
I have tried jax.transfer_guard to confirm that no data is transferred between devices during the matrix factorization. Maybe I have made some simple mistake. I would really appreciate any help with this problem, or any suggestion for computing the determinants in parallel.
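For reference, here is a minimal sketch of the check I mean (assuming the sharded array a from the example above); jax.transfer_guard("disallow") makes any implicit transfer raise an error instead of happening silently:
import jax
# Error out if any implicit host<->device or device-to-device transfer happens
# while the factorization runs (sketch only; `a` is the sharded array above).
with jax.transfer_guard("disallow"):
    out = jax.lax.linalg.lu(a)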
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct 4 2024, 13:27:36) [GCC 11.2.0]
jax.devices (3 total, 3 local): [CudaDevice(id=0) CudaDevice(id=1) CudaDevice(id=2)]
process_count: 1
platform: uname_result(system='Linux', node='alcc145', release='5.15.0-94-generic', version='#104-Ubuntu SMP Tue Jan 9 15:25:40 UTC 2024', machine='x86_64')
$ nvidia-smi
Tue Oct 15 16:58:25 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07 Driver Version: 535.161.07 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100 80GB PCIe Off | 00000000:21:00.0 Off | 0 |
| N/A 35C P0 66W / 300W | 425MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA A100 80GB PCIe Off | 00000000:81:00.0 Off | 0 |
| N/A 31C P0 63W / 300W | 425MiB / 81920MiB | 2% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA A100 80GB PCIe Off | 00000000:E2:00.0 Off | 0 |
| N/A 33C P0 63W / 300W | 425MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
Unfortunately, the JAX/XLA compiler currently does not know how to shard the jax.lax.linalg.lu operator.
As a workaround, you can still do the sharding manually with shard_map; see the JAX docs.
For your example, the manually parallelized version could look something like this:
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices(), "x")
pspecs = PartitionSpec("x", None)
sharding = NamedSharding(mesh, pspecs)
a = jnp.zeros((jax.device_count(), 1500000, 40, 40), device=sharding)
sharded_lu = shard_map(jax.lax.linalg.lu,
                       mesh=mesh,
                       in_specs=(pspecs,),
                       out_specs=(pspecs, pspecs, pspecs))
sharded_jitted_lu = jax.jit(sharded_lu)
out = sharded_jitted_lu(a)
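Since your end goal is determinants, the same pattern can be applied to the determinant itself; a minimal sketch reusing the mesh and pspecs from above (jnp.linalg.det here, but slogdet works similarly):
# Each device sees only its local (1, 1500000, 40, 40) block and returns a
# (1, 1500000) block of determinants; results are reassembled along "x".
sharded_det = shard_map(jnp.linalg.det,
                        mesh=mesh,
                        in_specs=(pspecs,),
                        out_specs=PartitionSpec("x"))
dets = jax.jit(sharded_det)(a)  # shape (jax.device_count(), 1500000)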
@jaro-sevcik Thanks very much for your reply! This solves my problem.
By the way, is there any plan to solve the sharding problem in the future? It looks straightforward.
@ChenAo-Phys — Thanks for the report! And thanks to @jaro-sevcik for suggesting this workaround.
By the way, is there any plan to solve the sharding problem in the future? It looks straightforward.
JAX doesn't currently have a great API for customizing the sharding behavior of custom calls (which is how most of the linear algebra operations are implemented). custom_partitioning is one option, but the fact that it is implemented using Python callbacks means that it can introduce some surprising issues that are probably beyond the scope of the discussion here. All that to say, a better solution would be great (and should be possible with upstream changes in XLA), but I'd say it's probably not straightforward. For now, @jaro-sevcik's suggestion to use shard_map is the best approach!
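For orientation only, here is a minimal sketch of what the custom_partitioning route looks like, using a batched determinant as a stand-in rather than lu_p itself; the callback signatures follow the current JAX docs and may change between versions:
import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import NamedSharding, PartitionSpec

@custom_partitioning
def batched_det(x):  # x: [batch, n, n]
    return jnp.linalg.det(x)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    # Keep the result sharded along the batch axis (assumes a mesh axis named "x").
    return NamedSharding(mesh, PartitionSpec("x"))

def partition(mesh, arg_shapes, result_shape):
    x_sharding = arg_shapes[0].sharding
    out_sharding = NamedSharding(mesh, PartitionSpec("x"))
    def lower_fn(x):  # runs independently on each device's local shard
        return jnp.linalg.det(x)
    return mesh, lower_fn, out_sharding, (x_sharding,)

batched_det.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition)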
@ChenAo-Phys — The LU factorization on GPU should now (with JAX v0.6.0) automatically shard properly in cases like this. Want to try again in your environment if this is still a blocker?
Thanks very much for the fix! It now works perfectly on multiple devices.
Excellent! I'm going to close this now, but I will note that the Cholesky factorization still doesn't shard properly on GPU. All the other factorizations should work! It's on my to-do list to fix Cholesky as well.
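In the meantime, the shard_map workaround above applies to Cholesky as well; a minimal sketch, reusing the mesh and pspecs from the earlier snippet:
# Per-device batched Cholesky until the sharded GPU kernel is in place.
sharded_chol = shard_map(jnp.linalg.cholesky,
                         mesh=mesh,
                         in_specs=(pspecs,),
                         out_specs=pspecs)
chol = jax.jit(sharded_chol)(a)  # same (batch, 40, 40) layout as the input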