Compilation time on GPU is proportional to batch size for grad of vmapped Cholesky solve
Description
The problem is with the grad of the mean of a vmapped Cholesky solution. If I define
def func(pars):
    ftf = fmat @ jax.numpy.diag(pars**2) @ fmat.T + one
    cf = jax.scipy.linalg.cho_factor(ftf)
    b = jax.scipy.linalg.cho_solve(cf, ones)
    return b.mean()
and then transform/compile
jvg = jax.value_and_grad(lambda pars: jax.vmap(func)(pars).mean())
pars = jax.random.normal(jax.random.PRNGKey(0), (nbatch,2*ngp,))
jjvg = jax.jit(jvg).lower(pars).compile()
I find that the compilation time grows with nbatch. For instance, with the example matrices listed below, I get (nbatch, time in s) = [16, 0.532], [32, 0.507], [64, 0.516], [128, 0.580], [256, 0.652], [512, 0.822], [1024, 1.7], [2048, 2.75].
What's happening here?
To run this example you need matrices such as
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)  # the float64 dtypes below need 64-bit mode enabled

nobs, ngp = 256, 64
t = np.linspace(0, 1, nobs)
f = np.arange(1, ngp + 1, dtype=np.float64)
fmat = np.zeros((nobs, 2*ngp), dtype=np.float64)
fmat[:, ::2] = np.sin(2.0 * np.pi * f * t[:, np.newaxis])
fmat[:, 1::2] = np.cos(2.0 * np.pi * f * t[:, np.newaxis])
one, ones = jax.numpy.identity(nobs, dtype=np.float64), jax.numpy.ones(nobs, dtype=np.float64)
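Putting the pieces together, a minimal sketch of how such compile times can be measured (the loop and the use of time.perf_counter are mine, not necessarily the exact script behind the numbers above):

import time

for nbatch in [16, 32, 64, 128, 256, 512, 1024, 2048]:
    pars = jax.random.normal(jax.random.PRNGKey(0), (nbatch, 2*ngp))
    start = time.perf_counter()
    jax.jit(jvg).lower(pars).compile()   # lower + compile only, never executed
    print(nbatch, time.perf_counter() - start)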
System info (python version, jaxlib version, accelerator, etc.)
JAX 0.4.26, CUDA 12.2 and driver 535.104.05, Nvidia V100. Python 3.10.12 on Linux (Colab)
Thanks for the report. I'm not sure what's going on, but it seems others are also noticing this: https://stackoverflow.com/questions/78486071/why-does-jax-compilation-time-grow-with-vmap-batch-size
Thanks @jakevdp; both queries are from me :) In this case, note also that the Cholesky without the grad compiles in constant time, so it must be something about the high-level gradient algorithm for Cholesky.
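For reference, the comparison I have in mind is compiling the same vmapped solve without the grad, roughly:

fwd = lambda pars: jax.vmap(func)(pars).mean()   # forward pass only, no grad
jax.jit(fwd).lower(pars).compile()               # this compile time stays roughly flat in nbatch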
I can repro on a Colab A100. I thought it might somehow have to do with constant folding, but even when passing fmat as an argument and defining one and ones inside the function I still see the batch-dependent compile time.
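Roughly the variant I mean (a sketch; func2 and jvg2 are just names for this illustration, not the exact code I ran):

def func2(pars, fmat):
    # build the constants inside the function so nothing is baked in as a compile-time constant
    nobs = fmat.shape[0]
    one = jax.numpy.identity(nobs, dtype=fmat.dtype)
    ones = jax.numpy.ones(nobs, dtype=fmat.dtype)
    ftf = fmat @ jax.numpy.diag(pars**2) @ fmat.T + one
    cf = jax.scipy.linalg.cho_factor(ftf)
    return jax.scipy.linalg.cho_solve(cf, ones).mean()

jvg2 = jax.value_and_grad(
    lambda pars, fmat: jax.vmap(func2, in_axes=(0, None))(pars, fmat).mean(), argnums=0)
jax.jit(jvg2).lower(pars, fmat).compile()   # compile time still grows with nbatch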
Interestingly, using your make_hlo helper (which I just found in https://github.com/google/jax/issues/7949) shows that the XLA code for, say, nbatch = 64 and nbatch = 512 is essentially the same, apart from 64 being replaced by 512 in the shapes.
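The same comparison can also be done with the public lowering API, e.g. (pars_64 and pars_512 being hypothetical arrays of shape (64, 2*ngp) and (512, 2*ngp)):

hlo_64 = jax.jit(jvg).lower(pars_64).compile().as_text()    # optimized HLO for nbatch = 64
hlo_512 = jax.jit(jvg).lower(pars_512).compile().as_text()  # optimized HLO for nbatch = 512
# the two dumps are essentially identical apart from 64 vs 512 in the shapes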
Would this mean that the problem is at the LLVM level (or in another NVIDIA-specific representation)?
Another clue: the linear compilation time also shows up for a function that is written directly with the extra batch dimension instead of being vmapped (sketch below), so the problem must lie in the gradient of the batched Cholesky.
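For concreteness, a sketch of what I mean by "already written with the extra batch dimension" (not the exact code I ran; it uses jax.numpy.linalg.cholesky plus two triangular solves, which amounts to the same computation as cho_factor/cho_solve):

def func_batched(pars):
    # pars has shape (nbatch, 2*ngp); build the whole batch of matrices at once
    ftf = jax.numpy.einsum('ij,bj,kj->bik', fmat, pars**2, fmat) + one
    L = jax.numpy.linalg.cholesky(ftf)                  # batched Cholesky factor, (nbatch, nobs, nobs)
    rhs = jax.numpy.broadcast_to(ones, ftf.shape[:2])   # (nbatch, nobs)
    y = jax.scipy.linalg.solve_triangular(L, rhs, lower=True)
    b = jax.scipy.linalg.solve_triangular(L, y, lower=True, trans='T')
    return b.mean()

jvg_batched = jax.value_and_grad(func_batched)
jax.jit(jvg_batched).lower(pars).compile()   # compile time still grows with nbatch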