
Memory complexity issue with pmap

Open mohamad-amin opened this issue 4 years ago • 8 comments

Hey!

I'm trying to compute the results of multiple kernel ridge regressions in parallel. I've written the code and generated jax expressions of my functions using jax.make_jaxpr. According to those jaxprs, the data and computation should fit comfortably on my GPUs (I'm using 4 V100 GPUs with 16 GB of RAM each, i.e. 64 GB of GPU RAM in total) and should be very far from the actual limits of what I have, but surprisingly it throws an OOM. (I'm using 64-bit precision.)
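
For reference, this is roughly how I'm generating and inspecting the jaxprs; the function body and shapes below are just placeholders standing in for my real code:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)        # 64-bit precision, as mentioned above

def kernel_ridge_batch(xs, ys):                  # placeholder for the real per-device function
    return jnp.sum(xs @ ys, axis=(-1, -2))

n = jax.device_count()
xs = jnp.ones((n, 8, 8))                         # toy shapes: one leading slice per device
ys = jnp.ones((n, 8, 8))
print(jax.make_jaxpr(kernel_ridge_batch)(xs, ys))   # the jaxprs below come from jax.make_jaxpr like this
print(jax.pmap(kernel_ridge_batch)(xs, ys).shape)   # (n,)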

Basically, based on the jax expressions, the most expensive items memory-wise should be the 4000 x 2000 x 10 x 10 array along with the 20000 x 20000 matrix that are broadcast to each GPU, which together amount to ~9 GB of GPU RAM per device; beyond that, I can't see why this code wouldn't fit on the GPU. (P.S.: before entering the pmap, the GPU is in the state shown in the picture below.)

[Screenshot: GPU memory state before entering the pmap (Screen Shot 2021-11-18 at 12.21.16 AM)]
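
For scale, the ~9 GB figure comes from counting just the two largest f64 buffers mentioned above (the variable names here are only labels):

bytes_per_f64 = 8
arr_4d = 4000 * 2000 * 10 * 10 * bytes_per_f64   # ~6.0 GiB
mat_2d = 20000 * 20000 * bytes_per_f64           # ~3.0 GiB
print((arr_4d + mat_2d) / 2**30)                 # ~8.9 GiB per device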

Error:

2021-11-18 00:23:23.244442: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:802] failed to alloc 34156314624 bytes on host: CUDA_ERROR_INVALID_VALUE: invalid argument
2021-11-18 00:23:23.244517: W external/org_tensorflow/tensorflow/core/common_runtime/device/device_host_allocator.h:46] could not allocate pinned host memory of size: 34156314624

Killed
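
(For scale, the failed allocation is 34,156,314,624 bytes ≈ 31.8 GiB of pinned host memory, i.e. roughly twice the RAM of a single V100.)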

My compiled functions:

{ lambda a:bool[10,10] b:f64[10,10] c:f64[10,10] d:bool[1,10,10] e:f64[1,10,10] f:bool[1,10,10]
    g:f64[1,10,10]; h:f64[4,1,10] i:f64[4,1,1,2000,10,10] j:f64[4,1,1,1,10,10] k:f64[4,1,1,4000,10,10]
    l:f64[4,1,1,10] m:f64[10,10] n:f64[2000,10] o:f64[20000,20000] p:bool[] q:f64[2000,10]
    r:f64[4000,10] s:f64[4000,2000,10,10]. let
    t:f64[4,1] = xla_pmap[
      axis_name=<axis 0x2aad77a59550>
            slice_sizes=(1, 1, 1, 1)
            unique_indices=False
          ] ez fz
          gb:f64[1,10,4000] = broadcast_in_dim[
            broadcast_dimensions=(0, 1, 2)
            shape=(1, 10, 4000)
          ] ga
          gc:bool[1,10,4000] = ge gb 1.0
          gd:f64[4000] = broadcast_in_dim[broadcast_dimensions=() shape=(4000,)] 1.0
          ge:f64[1,10,4000] = xla_call[
            backend=None
            call_jaxpr={ lambda ; gf:bool[1,10,4000] gg:f64[4000] gh:f64[1,10,4000]. let
                gi:f64[10,4000] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(10, 4000)
                ] gg
                gj:f64[1,10,4000] = broadcast_in_dim[
                  broadcast_dimensions=(1, 2)
                  shape=(1, 10, 4000)
                ] gi
                gk:f64[1,10,4000] = select gf gj gh
              in (gk,) }
            device=None
            donated_invars=(False, False, False)
            inline=False
            name=vmap(vmap(_where))
          ] gc gd gb
          gl:f64[1,10,4000] = sub 1.0 ge
          gm:f64[1,10] = reduce_sum[axes=(2,)] gl
          gn:f64[1] = dot_general[
            dimension_numbers=(((1,), (1,)), ((0,), (0,)))
            precision=None
            preferred_element_type=None
          ] bb gm
          go:f64[1] = mul gn -1.0
        in (go,) }
      devices=None
      donated_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False)
      global_arg_shapes=(None, None, None, None, None, None, None, None, None, None, None, None)
      global_axis_size=None
      in_axes=(None, None, None, None, None, None, None, 0, 0, 0, 0, 0, None, None, None, None, None, None, None)
      name=compute_batch_uncertainty
      out_axes=(0,)
    ] a b c d e f g h i j k l m n o p q r s
  in (t,) }
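
As a side note on reading the jaxpr above: in_axes=None replicates an argument on every device, while in_axes=0 splits it along its leading axis, so the operands marked None (including the 20000x20000 matrix) really are copied to each GPU. A toy illustration with made-up shapes:

import jax
import jax.numpy as jnp

replicated = jnp.ones((16, 16))              # stands in for a broadcast (in_axes=None) operand
sharded = jnp.ones((jax.device_count(), 4))  # one leading slice per device (in_axes=0)

f = jax.pmap(lambda r, s: r.sum() + s.sum(), in_axes=(None, 0))
print(f(replicated, sharded).shape)          # (jax.device_count(),)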

In the compiled function above, there is an xla_call that calls this compiled function:

{ lambda a:bool[10,10] b:f64[10,10] c:f64[10,10] d:bool[10,10] e:f64[10,10]; f:f64[10,1,10]
    g:f64[4000,1,10,10] h:f64[2000,1,10,10] i:f64[1,10,10] j:f64[10] k:f64[2000,10]
    l:f64[20000,20000] m:bool[] n:f64[2000,10] o:f64[4000,10] p:f64[4000,2000,10,10]. let
    q:f64[2000,10,1,10] = transpose[permutation=(0, 2, 1, 3)] h
    r:f64[20000,10] = reshape[dimensions=None new_sizes=(20000, 10)] q
    s:f64[20000,10] = xla_call[
      backend=None
    call_jaxpr={ lambda ; t:f64[20000,20000] u:f64[20000,10]. let
        v:f64[20000,10] = triangular_solve[
        conjugate_a=False
        left_side=True
        lower=False
        transpose_a=True
        unit_diagonal=False
        ] t u
        w:f64[20000,10] = triangular_solve[
        conjugate_a=False
        left_side=True
        lower=False
        transpose_a=False
        unit_diagonal=False
        ] t v
    in (w,) }
      device=None
      donated_invars=(False, False)
      inline=False
      name=_cho_solve
    ] l r
    x:f64[2000,10,1,10] = reshape[dimensions=None new_sizes=(2000, 10, 1, 10)] s
    y:f64[2000,1,10,10] = transpose[permutation=(0, 2, 1, 3)] x
    z:f64[1,10,1,10] = dot_general[
      dimension_numbers=(((0, 2), (0, 2)), ((), ()))
      precision=None
      preferred_element_type=None
    ] h y
    ba:f64[1,1,10,10] = transpose[permutation=(0, 2, 1, 3)] z
    bb:f64[1,1,10,10] = broadcast_in_dim[
      broadcast_dimensions=(1, 2, 3)
      shape=(1, 1, 10, 10)
    ] i
    bc:f64[1,1,10,10] = sub bb ba
    bd:f64[1,10,1,10] = transpose[permutation=(0, 2, 1, 3)] bc
    be:f64[10,10] = reshape[dimensions=None new_sizes=(10, 10)] bd
    dt:i32[4000] = convert_element_type[new_dtype=int32 weak_type=False] dp
    du:i32[10,4000] = convert_element_type[new_dtype=int32 weak_type=False] ds
    dv:i32[4000,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4000, 1)] dt
    dw:i32[10,4000,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(10, 4000, 1)
    ] du
    dx:i32[10,4000,1] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(10, 4000, 1)
    ] dv
    dy:i32[10,4000,2] = concatenate[dimension=2] dx dw
    dz:i32[10,4000,1] = iota[dimension=0 dtype=int32 shape=(10, 4000, 1)]
    ea:i32[10,4000,3] = concatenate[dimension=2] dz dy
    eb:f64[10,4000] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1, 2), start_index_map=(0, 1, 2))
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1, 1)
      unique_indices=False
    ] df ea
    ec:f64[10,4000] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(10, 4000)
    ] eb
    ed:bool[10,4000] = ge ec 1.0
    ee:f64[4000] = broadcast_in_dim[broadcast_dimensions=() shape=(4000,)] 1.0
    ef:f64[10,4000] = xla_call[
      backend=None
      call_jaxpr={ lambda ; eg:bool[10,4000] eh:f64[4000] ei:f64[10,4000]. let
          ej:f64[10,4000] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(10, 4000)
          ] eh
          ek:f64[10,4000] = select eg ej ei
        in (ek,) }
      device=None
      donated_invars=(False, False, False)
      inline=False
      name=vmap(_where)
    ] ed ee ec
    el:f64[10,4000] = sub 1.0 ef
    em:f64[10] = reduce_sum[axes=(1,)] el
  in (em,) }
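
For readers of the jaxpr above: the _cho_solve call is the usual Cholesky solve, which lowers to the pair of triangular_solve primitives shown. A small stand-alone reproduction of that lowering (toy shapes only):

import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = 2.0 * jnp.eye(4) + 0.1 * jnp.ones((4, 4))   # small SPD matrix
b = jnp.ones((4, 3))
c, lower = cho_factor(A)
# The printed jaxpr contains the same two triangular_solve calls as above.
print(jax.make_jaxpr(lambda c, b: cho_solve((c, lower), b))(c, b))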

mohamad-amin avatar Nov 18 '21 08:11 mohamad-amin

I guess the question is: is XLA's triangular_solve operator really working in-place? If not, this could be expected, but shouldn't it work in-place?

mohamad-amin avatar Nov 19 '21 07:11 mohamad-amin

I guess it's not triangular_solve. I checked the source code. The problem is that dot_general is not operating as I expected: for a multiplication of shape (a, b) x (b, c), it creates a huge intermediate as big as (a, b). In my case, (a, b) is way bigger than (a, c). I guess this is not an issue anymore (unless dot_general is not expected to behave this way?).

Is there any suggestion on how I can avoid memory problems while computing an (a, b) x (b, c) dot? Say (a, b) takes more than half of my memory; then computing (a, b) x (b, c) is impossible under jax.jit, even though it is perfectly possible with an explicit for loop. This gives rise to three questions:

  • Is this behaviour expected in JAX?
  • If so, is it okay for it to be this way? I don't think it should be: JAX is designed for scientific computing, and we often hit situations where one operand of a matrix multiplication is huge (taking more than half of memory) while the final result is very small, so a sequential implementation would still be fast.
  • Is there any workaround for me to avoid this memory problem? (I can't really use a sequential version here; jax arrays are immutable, so it's a bit of a hassle. See the sketch after this list.)
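
For concreteness, here is a sketch of the kind of explicit loop this alludes to; the block size and names are illustrative only, and whether it actually lowers the peak memory in a given case would still have to be checked against the compiled program:

import jax.numpy as jnp
from jax import lax

def blocked_matmul(X, Y, block=256):
    # Compute X @ Y one row-block at a time (assumes X.shape[0] % block == 0),
    # so each step only feeds a (block, b) slice of X into dot_general.
    out = jnp.zeros((X.shape[0], Y.shape[1]), dtype=X.dtype)

    def body(i, out):
        start = i * block
        xb = lax.dynamic_slice_in_dim(X, start, block, axis=0)
        return lax.dynamic_update_slice_in_dim(out, xb @ Y, start, axis=0)

    return lax.fori_loop(0, X.shape[0] // block, body, out)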

Again, I feel like this is an issue, just not the one I mentioned in the first post. It's now a problem with huge matrix multiplications in JAX.

mohamad-amin avatar Nov 19 '21 08:11 mohamad-amin

Well, the workaround is this (at least, this is what comes to mind):

import jax.numpy as np  # the thread uses `np` for jax.numpy
from jax import lax
# Map over the rows of X so only one row is contracted with Y at a time.
Z1 = lax.map(lambda X_i: np.einsum('j,jk->k', X_i, Y), X)
Z2 = X @ Y
np.alltrue(Z1 == Z2)
# True

This works fine, but shouldn't this be automated whenever X @ Y doesn't fit in GPU (or whatever) memory?

Edit: Just realized that even this wouldn't work! JAX will compile this back to dot_general again!

mohamad-amin avatar Nov 19 '21 20:11 mohamad-amin

Moreover, I noticed that calling lax_linalg.cholesky(A, False) on a 10x10 matrix A causes JAX to soak up 100 MB (!!) of memory. I'm really curious why lax needs 100 MB of memory to compute the Cholesky factorization of a 10x10 matrix!

mohamad-amin avatar Nov 20 '21 03:11 mohamad-amin

Suggestion:

  • Maybe JAX should also have something like https://pytorch.org/docs/stable/generated/torch.bmm.html and use it automatically in vmaps (or pmaps and xmaps) whenever the vectorization would cause OOM errors. This could maybe be configured through a parameter passed to vmap (optimize=False or True, perhaps?).

mohamad-amin avatar Nov 20 '21 04:11 mohamad-amin

It's impossible for us to debug your problem without a complete, self-contained Python program that reproduces it. I don't know what is happening in-place and what is happening out-of-place without debugging it, and I can't do that without a way to run the code.

I note that JAX does have a batched matrix multiplication operator, and vmap and einsum will use it where applicable. So that's probably not the problem.
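
As a quick illustration of that point (toy shapes only): both spellings below express a batched matmul, and printing the jaxpr shows the batch dimension carried in dot_general's dimension_numbers rather than a Python-level loop.

import jax
import jax.numpy as jnp

A = jnp.ones((8, 10, 10))
B = jnp.ones((8, 10, 10))

print(jax.make_jaxpr(lambda a, b: jnp.einsum('bij,bjk->bik', a, b))(A, B))
print(jax.make_jaxpr(jax.vmap(jnp.matmul))(A, B))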

hawkinsp avatar Nov 22 '21 14:11 hawkinsp

@mohamad-amin, did you ever find a solution to this? We're hitting a similar OOM issue from pmap when trying to use a neural_tangents linearized vision transformer, even though we shouldn't have any issues.

@hawkinsp, I have a colab I can share: https://colab.research.google.com/drive/184moQLq3tjo-wEpc8gD7fXCFguAVDBOm#scrollTo=k4CjYqp5qLvj

RylanSchaeffer avatar Mar 07 '22 19:03 RylanSchaeffer

@RylanSchaeffer Not yet; I solved my problem another way, though. I was also using pmap to compute the results of many kernel ridge regressions in parallel, and the kernel ridge regression code was mostly taken from the neural_tangents library's predict functions. How do you make sure that you shouldn't have any issues? Did you check the generated jax expressions?

mohamad-amin avatar Mar 07 '22 23:03 mohamad-amin