Very slow constant folding of very-large integer arrays, for instance when working with sparse matrices
This relates to JAX issue #14655; I'm copying various details from that thread below.
I've got a use case where I'd like to store the nonzero entries of a very large sparse matrix, and then access them later during a machine learning training loop. Unfortunately, using JIT compilation results in constant-folding of this array, making it extremely slow on large problems. Here's an MWE that runs on my laptop and captures the typical behavior:
import jax
import jax.numpy as jnp
import jax.experimental.sparse as sparse
from jax.experimental.sparse import BCOO

n = 10000000

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)

    def product(other):
        matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other

    return product

operator = build_sparse_linear_operator()

def fn(x):
    return operator(jnp.ones(n) / x).sum()

fn(1.0)           # executes in ~0.1 s
jax.jit(fn)(1.0)  # executes in almost one minute
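To confirm that the time is going into compilation rather than execution, one option is JAX's ahead-of-time lowering API; the snippet below is only a sketch of that measurement (it assumes a recent JAX where jax.jit(...).lower(...).compile() is available, and was not part of the original report):

import time

lowered = jax.jit(fn).lower(1.0)   # tracing/lowering only

start = time.perf_counter()
compiled = lowered.compile()       # XLA compilation, where constant folding happens
print("compile time:", time.perf_counter() - start)

start = time.perf_counter()
compiled(1.0).block_until_ready()  # pure execution, no retracing or recompilation
print("run time:", time.perf_counter() - start)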
Calling the function without JIT executes in about a tenth of a second, but calling it with JIT takes almost a minute. On the larger problems in the codebase that prompted this MWE, I have had it crash after running out of memory after about an hour. Compilation also produces warnings similar to the following:
Constant folding an instruction is taking > 8s:
slice.22 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).
This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
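To make that trade-off concrete, here is a minimal hypothetical example (not from the thread) of the kind of program the folder has to decide about: any subexpression that depends only on compile-time constants can be evaluated once during compilation instead of on every call, which is a win for small constants but can dominate compile time for arrays this large.

import jax
import jax.numpy as jnp

big = jnp.arange(10_000_000)  # closed-over constant, embedded in the traced program

@jax.jit
def f(x):
    # (big * 2).sum() does not depend on x, so XLA may evaluate it at compile
    # time; for a constant this large, the folding itself becomes the bottleneck.
    return (big * 2).sum() + x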
Switching to the following JAX code bypasses the issue:
def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)

    def product(other):
        # Hide nonzeroes behind an optimization barrier so that XLA does not
        # try to constant-fold it during compilation.
        nz = _optimization_barrier(nonzeroes)
        matrix = BCOO((jnp.ones(n), nz), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other

    return product
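For reference, _optimization_barrier isn't defined in the snippet above; in recent JAX versions an equivalent public primitive, jax.lax.optimization_barrier, should give the same effect. A sketch under that assumption:

from jax import lax

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)

    def product(other):
        # The barrier keeps XLA from treating nonzeroes as a foldable constant.
        nz = lax.optimization_barrier(nonzeroes)
        matrix = BCOO((jnp.ones(n), nz), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other

    return product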
From this, the problem seems to be that XLA tries to constant-fold nonzeroes despite its large size, and then runs out of resources while doing so. I haven't yet been able to replicate this for float arrays, so I'm not sure whether or not the issue is integer-specific.
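For anyone who wants to probe the integer-vs-float question, one experiment (just a sketch; I have not run this at scale) would be to close over a similarly large float constant and slice it, mirroring the slice instruction named in the warning above:

def build_float_operator():
    weights = jnp.ones(n, dtype=jnp.float32)  # large closed-over float constant

    def apply(other):
        # Slicing a closed-over constant gives the folder the same kind of
        # constant-only op to chew on, but on floats instead of integer indices.
        w = weights[: n // 2]
        return (w * other[: n // 2]).sum()

    return apply

float_operator = build_float_operator()
jax.jit(lambda x: float_operator(jnp.ones(n) / x))(1.0)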