
vmap of cond's predicate results in select, leading to unexpected compute/memory use

Open · aespielberg opened this issue on Oct 30, 2021 · 32 comments

I have been playing around with converting diffmpm from the difftaichi package into a jax version, and while the forward pass has been working wonderfully, the backward pass has been using way too much GPU memory.

Today, I was able to track the memory usage down to the grid op. The grid op step is a series of nested if statements. At first, I was using jnp.where, which evaluates all branches; that is extremely inefficient and can lead to OOM errors. I simplified my code and switched to jax.lax.cond, but my only conclusion is that cond is also evaluating both branches, because otherwise I cannot see why this would run into OOM issues.
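(As a toy illustration of the jnp.where issue, not from my actual code: both operands are materialized and differentiated even when the predicate rules one of them out, which is also where the well-known nan-through-where gradient gotcha comes from.)

import jax
import jax.numpy as jnp

def f_where(x):
  # the sqrt branch is "irrelevant" when x <= 0, but it is still evaluated,
  # and its infinite derivative at 0 leaks through the where as a nan
  return jnp.where(x > 0, jnp.sqrt(x), 0.)

print(jax.grad(f_where)(0.))  # prints nan, not 0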

Below is a modified version of the grid op, composed with itself 4,000 times, like a simulation. Even when run with the XLA_PYTHON_CLIENT_PREALLOCATE=false flag, this quickly leads to the whole GPU's memory being used, and more if the loop length is increased. This is not true if every line from lin = .... until right before the return of grid_op is commented out; in that case, memory usage is practically negligible. Note that because bound = 0, every one of the v_out = jax.lax.cond ... predicates evaluates to False by definition, and so most of the expressions, including the v_out_gates and their dependencies, shouldn't even need to be evaluated in the jitted function.

Maybe I am misunderstanding cond; if so, what is the proper way to get this sparse branching behavior? I don't want to evaluate and hang onto a bunch of expensive tensors that are never actually needed and crash my GPU with OOM, especially in the backward pass. This is a core bottleneck to practical deployment of my code and a feature that I think should be supported. FWIW, I am using jax version 0.1.69+cuda101.
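Here is a tiny standalone sketch of what I suspect is happening (toy function and names, not my real code): with a scalar predicate the cond survives as a cond in the jaxpr, but once the predicate is batched by vmap it appears to be turned into a select over the results of both branches, which would explain both branches being computed for every grid cell.

import jax
import jax.numpy as jnp

def f(pred, x):
  return jax.lax.cond(pred, lambda x: x * 2., lambda x: jnp.zeros_like(x), x)

x = jnp.ones(4)
preds = jnp.array([True, False, True])
xs = jnp.ones((3, 4))

print(jax.make_jaxpr(f)(True, x))              # the jaxpr contains a cond
print(jax.make_jaxpr(jax.vmap(f))(preds, xs))  # the cond is gone; both branches
                                               # appear, combined with a select

In my grid op, i and j come from index_array, which is vmapped over both grid axes, so every cond predicate ends up batched.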

Code to reproduce is below.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import os
import math
import numpy as np
import matplotlib.pyplot as plt
import jax.nn as jnn
import jax.lax as jlax
import timeit
import jax

dim = 2
n_grid = 128
dt = 1e-3
gravity = 3.8


def allocate_arrays():
  global grid_m_in, grid_v_in, grid_v_out, loss, index_array
  grid_m_in = jnp.ones((n_grid, n_grid))
  grid_v_in = jnp.zeros((n_grid, n_grid, dim))
  grid_v_out = jnp.zeros((n_grid, n_grid, dim))


  index_array = np.zeros((n_grid, n_grid, dim))
  
  for i in range(n_grid):
    for j in range(n_grid):
      index_array[i, j] = np.array([i, j])
 
  index_array = jnp.array(index_array)

  

def grid_op(grid_v_in, grid_m_in, index_tuple):
  bound = 0
  coeff = 0.5
  
  i = index_tuple[0]
  j = index_tuple[1]
  
  normal = jnp.array([0., 1.])
  
  inv_m = 1 / (grid_m_in + 1e-10)
  v_out = jnp.expand_dims(inv_m, -1) * grid_v_in
  v_out -= dt * gravity * jnp.array([0., 1.])
  
  v_out = jax.lax.cond(jnp.logical_and(i < bound, v_out[0] < 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)
  
  
  v_out = jax.lax.cond(jnp.logical_and(i > n_grid - bound, v_out[0] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)
  
  lin = (v_out.transpose() @ normal)
  
  vit = v_out - lin * normal
  lit = jnp.linalg.norm(vit + 1e-10)  + 1e-10
  
  
  v_out_gate_2 = jax.lax.cond(lit + coeff * lin <= 0, lambda _: jnp.zeros_like(v_out), lambda _: (1 + coeff * lin / lit) * vit, operand=None)
  v_out_gate_1 = jax.lax.cond(lin < 0, lambda _: v_out_gate_2, lambda _: jnp.zeros_like(v_out), operand=None)
  v_out = jax.lax.cond(jnp.logical_and(j < bound, v_out[1] < 0), lambda _: v_out_gate_1, lambda _: v_out, operand=None)          
  v_out = jax.lax.cond(jnp.logical_and(j > n_grid - bound, v_out[1] > 0), lambda _: jnp.zeros_like(v_out), lambda _: v_out, operand=None)
  
  return v_out

go_j = jit(vmap(vmap(grid_op)))


def advance2(t, args):
  grid_v_in = args[0]
  grid_m_in = args[1]
  index_array = args[2]
  grid_v_in = go_j(grid_v_in, grid_m_in, index_array)
  
  return grid_v_in, grid_m_in, index_array
  
  
# (leftover from my full simulation code; not used in this repro -- p1_j and actuator_id are not defined here)
def advance(t, args):
  x = args[0]
  v = args[1]
  C = args[2]
  F = args[3]
  x, v, C, F = p1_j(x, v, C, F, actuator_id)
  
  return x, v, C, F
  
a = jit(advance)

def forward2(grid_v_in, grid_m_in, index_array):
  grid_v_in, grid_m_in, index_array = jlax.fori_loop(0, 4000, advance2, (grid_v_in, grid_m_in, index_array))
  return jnp.mean(grid_v_in)


def main():
  # initialization
  allocate_arrays()
  
  f2 = jit(forward2)
  forward_grad2 = jit(grad(forward2))

  number = 10
  

  print(timeit.timeit(lambda : f2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : f2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : forward_grad2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  print(timeit.timeit(lambda : forward_grad2(grid_v_in, grid_m_in, index_array).block_until_ready(), number=number) / number)
  
if __name__ == "__main__":
  main()
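For what it's worth, one mitigation I have been considering (a sketch only, untested, and the forward2_remat name is just a placeholder; it would not avoid computing both branches, only reduce how much the backward pass keeps alive) is to rematerialize each loop step with jax.checkpoint:

def forward2_remat(grid_v_in, grid_m_in, index_array):
  # jax.checkpoint (a.k.a. jax.remat) makes the reverse pass recompute each
  # step's intermediates instead of storing all 4,000 of them
  step = jax.checkpoint(advance2)
  grid_v_in, grid_m_in, index_array = jlax.fori_loop(
      0, 4000, step, (grid_v_in, grid_m_in, index_array))
  return jnp.mean(grid_v_in)

forward_grad2_remat = jit(grad(forward2_remat))

But that feels like working around the symptom rather than getting the sparse branching I actually want.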

aespielberg · Oct 30 '21 20:10