Memory complexity issue with pmap
Hey!
I'm trying to compute the results of multiple kernel ridge regressions in parallel. I've written the code and generated jax expressions of my functions using jax.make_jaxpr. According to the jax expressions, the data and computation should fit on my GPUs (I'm using 4 V100 GPUs with 16GB of RAM each, which amounts to 64GB of GPU RAM) and should be very far from the actual limits of what I have, but surprisingly it throws an OOM. (I'm using 64-bit precision.)
Basically, what I expect from the jax expressions is that the most expensive items here (memory-wise) should be the 4000 x 2000 x 10 x 10 array along with the 20000 x 20000 matrix that are broadcast to each GPU, which amounts to ~9GB of GPU RAM per device; other than that, I can't see why this code can't fit on the GPU. (P.S.: before entering the pmap, the GPU is in the state shown in the picture below.)
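For reference, the ~9GB figure is just a back-of-the-envelope count of those two largest operands (a rough sketch; it ignores whatever intermediates XLA materializes on top):

```python
# Rough per-device estimate for the two largest operands (f64 = 8 bytes/element).
bytes_per_f64 = 8
big_kernel_block = 4000 * 2000 * 10 * 10 * bytes_per_f64  # 6.4e9 bytes ~= 6.4 GB
big_square_matrix = 20000 * 20000 * bytes_per_f64          # 3.2e9 bytes ~= 3.2 GB
print((big_kernel_block + big_square_matrix) / 1e9)        # ~= 9.6 GB in total
```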
Error:
2021-11-18 00:23:23.244442: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:802] failed to alloc 34156314624 bytes on host: CUDA_ERROR_INVALID_VALUE: invalid argument
2021-11-18 00:23:23.244517: W external/org_tensorflow/tensorflow/core/common_runtime/device/device_host_allocator.h:46] could not allocate pinned host memory of size: 34156314624
Killed
My compiled functions:
{ lambda a:bool[10,10] b:f64[10,10] c:f64[10,10] d:bool[1,10,10] e:f64[1,10,10] f:bool[1,10,10]
g:f64[1,10,10]; h:f64[4,1,10] i:f64[4,1,1,2000,10,10] j:f64[4,1,1,1,10,10] k:f64[4,1,1,4000,10,10]
l:f64[4,1,1,10] m:f64[10,10] n:f64[2000,10] o:f64[20000,20000] p:bool[] q:f64[2000,10]
r:f64[4000,10] s:f64[4000,2000,10,10]. let
t:f64[4,1] = xla_pmap[
axis_name=<axis 0x2aad77a59550>
slice_sizes=(1, 1, 1, 1)
unique_indices=False
] ez fz
gb:f64[1,10,4000] = broadcast_in_dim[
broadcast_dimensions=(0, 1, 2)
shape=(1, 10, 4000)
] ga
gc:bool[1,10,4000] = ge gb 1.0
gd:f64[4000] = broadcast_in_dim[broadcast_dimensions=() shape=(4000,)] 1.0
ge:f64[1,10,4000] = xla_call[
backend=None
call_jaxpr={ lambda ; gf:bool[1,10,4000] gg:f64[4000] gh:f64[1,10,4000]. let
gi:f64[10,4000] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(10, 4000)
] gg
gj:f64[1,10,4000] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 10, 4000)
] gi
gk:f64[1,10,4000] = select gf gj gh
in (gk,) }
device=None
donated_invars=(False, False, False)
inline=False
name=vmap(vmap(_where))
] gc gd gb
gl:f64[1,10,4000] = sub 1.0 ge
gm:f64[1,10] = reduce_sum[axes=(2,)] gl
gn:f64[1] = dot_general[
dimension_numbers=(((1,), (1,)), ((0,), (0,)))
precision=None
preferred_element_type=None
] bb gm
go:f64[1] = mul gn -1.0
in (go,) }
devices=None
donated_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False)
global_arg_shapes=(None, None, None, None, None, None, None, None, None, None, None, None)
global_axis_size=None
in_axes=(None, None, None, None, None, None, None, 0, 0, 0, 0, 0, None, None, None, None, None, None, None)
name=compute_batch_uncertainty
out_axes=(0,)
] a b c d e f g h i j k l m n o p q r s
in (t,) }
In the compiled function above, there is an xla_call that calls this compiled function:
{ lambda a:bool[10,10] b:f64[10,10] c:f64[10,10] d:bool[10,10] e:f64[10,10]; f:f64[10,1,10]
g:f64[4000,1,10,10] h:f64[2000,1,10,10] i:f64[1,10,10] j:f64[10] k:f64[2000,10]
l:f64[20000,20000] m:bool[] n:f64[2000,10] o:f64[4000,10] p:f64[4000,2000,10,10]. let
q:f64[2000,10,1,10] = transpose[permutation=(0, 2, 1, 3)] h
r:f64[20000,10] = reshape[dimensions=None new_sizes=(20000, 10)] q
s:f64[20000,10] = xla_call[
backend=None
call_jaxpr={ lambda ; t:f64[20000,20000] u:f64[20000,10]. let
v:f64[20000,10] = triangular_solve[
conjugate_a=False
left_side=True
lower=False
transpose_a=True
unit_diagonal=False
] t u
w:f64[20000,10] = triangular_solve[
conjugate_a=False
left_side=True
lower=False
transpose_a=False
unit_diagonal=False
] t v
in (w,) }
device=None
donated_invars=(False, False)
inline=False
name=_cho_solve
] l r
x:f64[2000,10,1,10] = reshape[dimensions=None new_sizes=(2000, 10, 1, 10)] s
y:f64[2000,1,10,10] = transpose[permutation=(0, 2, 1, 3)] x
z:f64[1,10,1,10] = dot_general[
dimension_numbers=(((0, 2), (0, 2)), ((), ()))
precision=None
preferred_element_type=None
] h y
ba:f64[1,1,10,10] = transpose[permutation=(0, 2, 1, 3)] z
bb:f64[1,1,10,10] = broadcast_in_dim[
broadcast_dimensions=(1, 2, 3)
shape=(1, 1, 10, 10)
] i
bc:f64[1,1,10,10] = sub bb ba
bd:f64[1,10,1,10] = transpose[permutation=(0, 2, 1, 3)] bc
be:f64[10,10] = reshape[dimensions=None new_sizes=(10, 10)] bd
dt:i32[4000] = convert_element_type[new_dtype=int32 weak_type=False] dp
du:i32[10,4000] = convert_element_type[new_dtype=int32 weak_type=False] ds
dv:i32[4000,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4000, 1)] dt
dw:i32[10,4000,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(10, 4000, 1)
] du
dx:i32[10,4000,1] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(10, 4000, 1)
] dv
dy:i32[10,4000,2] = concatenate[dimension=2] dx dw
dz:i32[10,4000,1] = iota[dimension=0 dtype=int32 shape=(10, 4000, 1)]
ea:i32[10,4000,3] = concatenate[dimension=2] dz dy
eb:f64[10,4000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1, 2), start_index_map=(0, 1, 2))
fill_value=None
indices_are_sorted=False
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1, 1, 1)
unique_indices=False
] df ea
ec:f64[10,4000] = broadcast_in_dim[
broadcast_dimensions=(0, 1)
shape=(10, 4000)
] eb
ed:bool[10,4000] = ge ec 1.0
ee:f64[4000] = broadcast_in_dim[broadcast_dimensions=() shape=(4000,)] 1.0
ef:f64[10,4000] = xla_call[
backend=None
call_jaxpr={ lambda ; eg:bool[10,4000] eh:f64[4000] ei:f64[10,4000]. let
ej:f64[10,4000] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(10, 4000)
] eh
ek:f64[10,4000] = select eg ej ei
in (ek,) }
device=None
donated_invars=(False, False, False)
inline=False
name=vmap(_where)
] ed ee ec
el:f64[10,4000] = sub 1.0 ef
em:f64[10] = reduce_sum[axes=(1,)] el
in (em,) }
I guess the question is: does XLA's triangular_solve operator really work in place? If not, this OOM could be expected, but shouldn't it work in place?
I guess it's not triangular_solve. I checked the source code. The problem is that dot_general is not operating as expected: when computing a multiplication of shapes (a, b) x (b, c), it creates a huge temporary as big as (a, b). In my case, (a, b) is way bigger than the (a, c) result. I guess this is not an issue anymore (unless dot_general is not expected to have this behaviour?).
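A minimal way to inspect that lowering, with toy shapes standing in for my real arrays:

```python
import jax
import jax.numpy as jnp

# Toy shapes only; the real (a, b) operand is many gigabytes.
X = jnp.ones((8, 6))  # (a, b)
Y = jnp.ones((6, 4))  # (b, c)

# The matmul lowers to a single dot_general; the jaxpr only shows abstract
# shapes, not the extra buffers XLA allocates while executing it.
print(jax.make_jaxpr(lambda x, y: x @ y)(X, Y))
```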
Is there any suggestion on how I can avoid memory problems while computing an (a, b) x (b, c) dot product? Say (a, b) takes more than half of my memory; then computing (a, b) x (b, c) will be impossible using jax.jit, even though it is perfectly possible with an explicit for loop. This gives rise to three questions:
- Is this behaviour expected in jax?
- If so, is it okay to be this way? I think it shouldn't be, since jax is designed for scientific computing, and in scientific computing we often hit situations where one side of a matrix multiplication is huge (taking more than half of memory) while the final result is actually very small, and a sequential implementation would still be fast.
- Is there any workaround for me to avoid this memory problem? (I can't really use the sequential version here; jax arrays are immutable, so it's a bit of a hassle...)
Well, again I feel like this is an issue, but not the one I mentioned in the first post. It is now a problem with huge matrix multiplications in JAX.
Well, the workaround (at least the one that comes to my mind) is this:
import jax.numpy as np
from jax import lax
# Multiply X (a, b) by Y (b, c) one row at a time instead of as a single matmul.
Z1 = lax.map(lambda X_i: np.einsum('j,jk->k', X_i, Y), X)
Z2 = X @ Y
np.alltrue(Z1 == Z2)
# True
This works fine, but shouldn't this be automated when X @ Y doesn't fit in GPU memory (or whatever memory is being used)?
Edit: I just realized that even this wouldn't work! JAX will compile this back down to dot_general again!
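The only workaround I can think of now is to keep the chunking loop at the Python level, so XLA never gets a chance to fuse it back into one big dot_general. A rough sketch (chunked_matmul and num_chunks are just illustrative names):

```python
import jax
import jax.numpy as jnp

@jax.jit
def _chunk_matmul(x_chunk, y):
    # Only a (chunk_rows, c) result plus its inputs are live here.
    return x_chunk @ y

def chunked_matmul(x, y, num_chunks=8):
    # Python-level loop: each chunk is staged and executed separately,
    # so the full (a, b) x (b, c) product is never built as a single op.
    parts = [_chunk_matmul(chunk, y) for chunk in jnp.array_split(x, num_chunks)]
    return jnp.concatenate(parts)
```

Chunks of different sizes trigger separate compilations of _chunk_matmul, so picking num_chunks to divide the leading dimension evenly avoids retracing.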
Moreover, I noticed that calling lax_linalg.cholesky(A, False) on a 10 x 10 matrix A causes JAX to soak up 100MB (!!) of memory. I'm really curious why lax needs 100MB of memory to compute the Cholesky factorization of a 10 x 10 matrix!
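I haven't figured out what that 100MB actually is; one way to dig into it is JAX's device-memory profiler (a sketch using the public-API Cholesky; the output file name is arbitrary and the dump needs pprof to view):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # I'm running in 64-bit mode

A = 2.0 * jnp.eye(10)            # small SPD test matrix
L = jnp.linalg.cholesky(A)       # public-API route to the same lax-level cholesky
L.block_until_ready()

# Dump what is currently resident on the device in pprof format.
jax.profiler.save_device_memory_profile("cholesky_memory.prof")
```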
Suggestion:
- maybe jax should also have something like https://pytorch.org/docs/stable/generated/torch.bmm.html and automatically use it in vmaps (or pmaps and xmaps) whenever the vectorization causes OOM errors. This could maybe be configured through a parameter passed to vmap (optimize=False or True, maybe?)
It's impossible for us to debug your problem without complete, self-contained Python code that reproduces it. I don't know what is happening in place and what is happening out of place without debugging it, and I can't do that without a way to run the code.
I note that JAX does have a batched matrix multiplication operator, and vmap and einsum will use it where applicable. So that's probably not the problem.
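For example, with toy shapes, both the vmap and einsum spellings lower to a single dot_general with batch dimensions, which is the analogue of torch.bmm:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4, 10, 10))
y = jnp.ones((4, 10, 10))

# Both spellings produce one dot_general whose dimension_numbers include a
# batch dimension, rather than a Python loop of per-example matmuls.
print(jax.make_jaxpr(jax.vmap(jnp.matmul))(x, y))
print(jax.make_jaxpr(lambda a, b: jnp.einsum('bij,bjk->bik', a, b))(x, y))
```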
@mohamad-amin, did you ever find a solution to this? We're hitting a similar OOM issue from pmap when trying to use a neural_tangents linearized vision transformer, even though we shouldn't be having any issues.
@hawkinsp, I have a Colab I can share: https://colab.research.google.com/drive/184moQLq3tjo-wEpc8gD7fXCFguAVDBOm#scrollTo=k4CjYqp5qLvj
@RylanSchaeffer Not yet, though I solved my problem another way. I was also using pmap to compute the results of many kernel ridge regressions in parallel, with the kernel ridge regression code mainly taken from the neural_tangents library's predict functions. How do you make sure that you shouldn't have any issues? Did you check the generated jax expressions?