Compilation time on GPU is proportional to batch size for grad of vmapped Cholesky solve

[Open] vallis opened this issue 1 year ago · 5 comments

Description

The problem is with the grad of the mean of a vmapped Cholesky solve. If I define

def func(pars):
  # build an SPD matrix, Cholesky-factor it, and solve against a vector of ones
  ftf = fmat @ jax.numpy.diag(pars**2) @ fmat.T + one
  cf = jax.scipy.linalg.cho_factor(ftf)
  b = jax.scipy.linalg.cho_solve(cf, ones)
  return b.mean()

and then transform/compile

jvg = jax.value_and_grad(lambda pars: jax.vmap(func)(pars).mean())
pars = jax.random.normal(jax.random.PRNGKey(0), (nbatch,2*ngp,))
jjvg = jax.jit(jvg).lower(pars).compile()

I find that the compilation time grows with nbatch. For the example matrices listed below I measure (nbatch, time in s): (16, 0.532), (32, 0.507), (64, 0.516), (128, 0.580), (256, 0.652), (512, 0.822), (1024, 1.7), (2048, 2.75).
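
[Editor's note: a minimal sketch of a timing harness along these lines — not the original script — assuming func, ngp, and the matrices below are already defined:]

import time

# sketch of a timing harness (hypothetical; not the author's exact script)
for nbatch in [16, 32, 64, 128, 256, 512, 1024, 2048]:
  pars = jax.random.normal(jax.random.PRNGKey(0), (nbatch, 2*ngp))
  jvg = jax.value_and_grad(lambda pars: jax.vmap(func)(pars).mean())
  t0 = time.perf_counter()
  jax.jit(jvg).lower(pars).compile()  # measure lowering + XLA compilation only
  print(nbatch, time.perf_counter() - t0)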

What's happening here?

To run this example, you need matrices such as these:

import numpy as np
import jax

jax.config.update("jax_enable_x64", True)  # required for the float64 dtypes below

nobs, ngp = 256, 64
t = np.linspace(0, 1, nobs)
f = np.arange(1, ngp + 1, dtype=np.float64)

fmat = np.zeros((nobs, 2*ngp), dtype=np.float64)
fmat[:,  ::2] = np.sin(2.0 * np.pi * f * t[:, np.newaxis])
fmat[:, 1::2] = np.cos(2.0 * np.pi * f * t[:, np.newaxis])

one, ones = jax.numpy.identity(nobs, dtype=np.float64), jax.numpy.ones(nobs, dtype=np.float64)

System info (python version, jaxlib version, accelerator, etc.)

JAX 0.4.26, CUDA 12.2 with driver 535.104.05, NVIDIA V100; Python 3.10.12 on Linux (Colab).

vallis · May 20 '24 19:05

Thanks for the report. I'm not sure what's going on, but it seems others are also noticing this: https://stackoverflow.com/questions/78486071/why-does-jax-compilation-time-grow-with-vmap-batch-size

jakevdp · May 20 '24 20:05

Thanks @jakevdp; both queries are from me :) In this case, note also that the Cholesky solve without the grad compiles in constant time, so it must be something about the high-level gradient algorithm for Cholesky. The comparison is sketched below.
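
[Editor's note: a minimal sketch of that comparison, assuming func and pars as defined above:]

# forward-only version: compile time stays roughly constant in nbatch
fwd = jax.jit(lambda p: jax.vmap(func)(p).mean())
fwd.lower(pars).compile()

# value_and_grad version: compile time grows with nbatch
jvg = jax.jit(jax.value_and_grad(lambda p: jax.vmap(func)(p).mean()))
jvg.lower(pars).compile()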

vallis · May 20 '24 21:05

I can repro on a Colab A100; I thought it might somehow have to do with constant folding, but even passing fmat as an argument and defining one and ones inside the function, I still see the batch-dependent compile time. A sketch of that check follows.
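
[Editor's note: a sketch of that variant, with a hypothetical func2; fmat becomes a traced argument and the constants are built inside the function:]

def func2(pars, fmat):  # func2 is a hypothetical name for this check
  nobs = fmat.shape[0]
  one = jax.numpy.identity(nobs, dtype=fmat.dtype)
  ones = jax.numpy.ones(nobs, dtype=fmat.dtype)
  ftf = fmat @ jax.numpy.diag(pars**2) @ fmat.T + one
  cf = jax.scipy.linalg.cho_factor(ftf)
  return jax.scipy.linalg.cho_solve(cf, ones).mean()

jvg2 = jax.value_and_grad(lambda p, f: jax.vmap(func2, in_axes=(0, None))(p, f).mean())
jax.jit(jvg2).lower(pars, fmat).compile()  # compile time still grows with nbatch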

jakevdp · May 20 '24 21:05

Interestingly, using your make_hlo helper (which I just found in https://github.com/google/jax/issues/7949) shows that the XLA code for, say, nbatch = 64 and nbatch = 512 is essentially identical, apart from the 64 -> 512 in the shapes.

Would this mean that the problem is at the LLVM level? (Or in another NVIDIA-specific representation?)
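
[Editor's note: for anyone reproducing this without that helper, JAX's standard AOT API can dump both stages, assuming jvg and pars as above:]

lowered = jax.jit(jvg).lower(pars)
print(lowered.as_text())            # staged-out IR before XLA optimization
print(lowered.compile().as_text())  # optimized HLO after XLA passes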

vallis · May 20 '24 23:05

Another clue is that the linear compilation time also happens for a function that's written directly with the extra batch dimension instead of being vmapped (sketched below). So the problem must be the batched grad Cholesky.
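
[Editor's note: a sketch of such a manually batched variant, with a hypothetical func_batched; it assumes the jax.scipy.linalg wrappers accept leading batch dimensions, which they do since they lower to lax.linalg:]

def func_batched(pars):  # func_batched is a hypothetical name for this check
  # pars: (nbatch, 2*ngp); ftf: (nbatch, nobs, nobs), built with no vmap involved
  ftf = jax.numpy.einsum('ij,bj,kj->bik', fmat, pars**2, fmat) + one
  cf = jax.scipy.linalg.cho_factor(ftf)  # batched Cholesky factorization
  rhs = jax.numpy.broadcast_to(ones[:, None], ftf.shape[:-1] + (1,))
  b = jax.scipy.linalg.cho_solve(cf, rhs)  # batched triangular solves
  return b.mean()

jvgb = jax.value_and_grad(func_batched)
jax.jit(jvgb).lower(pars).compile()  # compile time still grows with nbatch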

vallis · May 21 '24 03:05