
Pure callback with jax types hangs with jax > 0.4.31

mfschubert opened this issue Oct 11 '24 · 6 comments

Description

I am seeing hangs with jax versions newer than 0.4.31, as referenced in an earlier issue that I created and subsequently closed (https://github.com/jax-ml/jax/issues/24219). I have managed to simplify the reproduction.

The issue seems to be related to jax calculations within a function called by pure_callback. The code below reproduces the issue.

import jax
print(f"jax_version={jax.__version__}")
import jax.numpy as jnp
import numpy as onp

def _eig_jax(matrix):
    """Eigendecomposition using `jax.numpy.linalg.eig`."""
    eigval, eigvec = jax.pure_callback(
        _eig_cpu,
        (
            jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
            jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
        ),
        matrix.astype(complex),
        vectorized=True,
    )
    return jnp.asarray(eigval), jnp.asarray(eigvec)

with jax.default_device(jax.devices("cpu")[0]):
    _eig_jax_cpu = jax.jit(jnp.linalg.eig)

def _eig_cpu(matrix):
    eigvals, eigvecs = _eig_jax_cpu(matrix)
    return onp.asarray(eigvals), onp.asarray(eigvecs)

# This loop hangs, typically at < 10 steps on a colab CPU runtime. Larger matrices
# cause the loop to hang at earlier steps.
for i in range(100):
    print(i)
    _eig_jax(jnp.ones((500, 500)))

This method of wrapping jnp.linalg.eig has worked with jax 0.4.31 and earlier, and has come up in discussions several times (https://github.com/jax-ml/jax/discussions/23079, https://github.com/jax-ml/jax/issues/1259).

System info (python version, jaxlib version, accelerator, etc.)

Python 3.10, jax 0.4.33, Colab CPU runtime

mfschubert commented Oct 11 '24 18:10

OK, I am currently working around this as follows:

from typing import Tuple

def _eig_jax(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Eigendecomposition using `jax.numpy.linalg.eig`."""
    if jax.devices()[0] == jax.devices("cpu")[0]:
        return jnp.linalg.eig(matrix)
    else:
        eigvals, eigvecs = jax.pure_callback(
            _eig_jax_cpu,
            (
                jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
                jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
            ),
            matrix.astype(complex),
            vectorized=True,
        )
        return jnp.asarray(eigvals), jnp.asarray(eigvecs)

This is fine for me, but it may be nice to have a graceful failure mode other than simply hanging.

mfschubert commented Oct 11 '24 20:10

Thanks for reporting this! I'm not totally surprised that things don't go so well if we nest JAX functions inside of a pure_callback, but I agree that hanging isn't a good outcome!

I'll take a look into this because I have some ideas about what might be causing it. But, regardless, I would recommend avoiding the use of JAX functions within pure_callback. A better workaround than the one you came up with for this specific example would be to use numpy.linalg.eig instead of the JAX version:

def _eig_jax(matrix):
  matrix = matrix.astype(complex)
  return jax.pure_callback(
      np.linalg.eig,  # <-- np instead of jnp
      ...  # The rest of the arguments are the same as before
  )

This works on both GPU and CPU without relying on querying the first device. I expect the performance will be equivalent, without causing any hangs.
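
Spelled out fully, that might look something like the following (a sketch rather than exact code: _eig_host is a hypothetical wrapper that returns a plain tuple so the result's pytree structure matches the declared output shapes, and vectorized=True matches the original snippet, though newer JAX versions replace it with vmap_method):

import jax
import jax.numpy as jnp
import numpy as np

def _eig_host(matrix):
    # Plain numpy on the host -- no JAX calls inside the callback.
    eigvals, eigvecs = np.linalg.eig(matrix)
    return np.asarray(eigvals), np.asarray(eigvecs)

def _eig_jax(matrix):
    matrix = matrix.astype(complex)
    return jax.pure_callback(
        _eig_host,
        (
            jax.ShapeDtypeStruct(matrix.shape[:-1], matrix.dtype),  # eigenvalues
            jax.ShapeDtypeStruct(matrix.shape, matrix.dtype),  # eigenvectors
        ),
        matrix,
        vectorized=True,
    )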

In the long run, I'm not sure we'll want to support including JAX code in the callback executed by pure_callback for reasons that are probably outside the scope of this discussion. jax.experimental.compute_on should eventually provide the needed API, but unfortunately it doesn't yet work on GPU.
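
For reference, a sketch of what that might eventually look like (assumptions: the experimental compute_on decorator with the 'device_host' compute type, and that eig is available for host offload; neither is guaranteed today, and as noted it doesn't yet work on GPU):

import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

# Sketch only: annotate a jitted computation to run on the host.
@compute_on('device_host')
@jax.jit
def _eig_on_host(matrix):
    return jnp.linalg.eig(matrix.astype(complex))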

dfm commented Oct 13 '24 19:10

Thanks for your suggestion and I'll look forward to the new API.

Unfortunately, I have found performance to be quite different (also when using scipy.linalg.eig). This is the case both on my machine and on colab, so I suspect something other than my specific setup is responsible. But I suppose that is a different issue entirely.

mfschubert commented Oct 14 '24 16:10

Unfortunately, I have found performance to be quite different (also when using scipy.linalg.eig).

Interesting! Can you say more about what performance differences you're seeing? I believe that the JAX CPU implementation of eig calls exactly the same LAPACK function that the scipy version does (JAX actually uses scipy to find LAPACK!), so I'm surprised that you would find significant performance differences. When I compare the performance of JAX on CPU with the version that uses a pure callback to scipy I actually get exactly the same performance.

Sample code run on CPU
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg

@jax.jit
def eig_jax(x):
  x = x.astype(np.complex64)
  return jnp.linalg.eig(x)

@jax.jit
def eig_scipy(x):
  x = x.astype(np.complex64)
  eigvals = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
  return jax.pure_callback(scipy.linalg.eig, (eigvals, x), x)

Regardless, it will be great when jax.experimental.compute_on works on GPU, but I'm just surprised that you're finding significant performance differences!

dfm commented Oct 15 '24 20:10

Sure, here is some benchmarking code that I ran on colab CPU. I avoided using %%timeit (which reports mean time) since free tier colab seems to be quite noisy.

import jax
import jax.numpy as jnp
import numpy as onp
import scipy
import time

onp.random.seed(0)
matrix = onp.random.randn(500, 500).astype(onp.float32)

times = {"scipy": onp.inf, "numpy": onp.inf, "jax": onp.inf}
for _ in range(20):
  start = time.time()
  onp.linalg.eig(matrix)
  times["numpy"] = min(time.time() - start, times["numpy"])

  start = time.time()
  scipy.linalg.eig(matrix)
  times["scipy"] = min(time.time() - start, times["scipy"])

  start = time.time()
  jax.block_until_ready(jnp.linalg.eig(matrix))
  times["jax"] = min(time.time() - start, times["jax"])

print(times)
# {'scipy': 0.2719097137451172, 'numpy': 0.41986513137817383, 'jax': 0.227125883102417}

I also observe speed differences on my local machine, but I figure that the odds of operator error on the setup are lower with colab. :-)

mfschubert commented Oct 15 '24 20:10

I am thinking this is possibly related to some issues we are having with the new CPU client's async dispatch. One issue is that, when the client chooses async, the initial call returns immediately (in the main thread) and the callback happens in a separate/new thread created by the client. If the function does not return anything, or its return value is not used in some way that would block the main thread (print, conversion to an array, block_until_ready, etc.), nothing stops the main thread from running ahead to the end of the script and exiting.

According to https://docs.python.org/3/library/threading.html:

A thread can be flagged as a “daemon thread”. The significance of this flag is that the entire Python program exits when only daemon threads are left.

It also mentions (though it is not clear whether this applies here):

Daemon threads are abruptly stopped at shutdown.

Using the threading module in the callback, e.g.

thread = threading.current_thread()
print(f"{type(thread)}: {thread!r}")

will give something like

<class 'threading._DummyThread'>: <_DummyThread(Dummy-3, started daemon 128068297426496)>

The expected behavior, which seems to match what actually happens, is that when the main thread reaches the end it starts tearing down and calls everything registered with atexit, including jax._src.api.clean_up and jax._src.dispatch.wait_for_tokens. In my case at least, clean_up is called first, and when the callback then fires, JAX starts trying to re-initialize itself. I'm not exactly sure what all of that entails, but I can imagine it produces an inconsistent state that is a mixture of teardown and startup.
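
If this is what is happening, one mitigation in a script (a sketch, assuming the diagnosis above) is to block on the callback's result so the main thread cannot begin teardown while the callback is still in flight:

# Using _eig_jax and jnp from the reproduction at the top of this issue.
result = _eig_jax(jnp.ones((500, 500)))
jax.block_until_ready(result)  # main thread waits for the dispatch to finish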

If this is not what is happening in this issue, I can start a new one to document this information.

kcdodd commented Oct 17 '24 16:10

This was fixed in https://github.com/openxla/xla/pull/21990, with a test added in https://github.com/jax-ml/jax/pull/26173. The approach is to disable CPU asynchronous dispatch when running within the body of a callback. See those PRs, and the PRs linked therein, for more discussion. If you run into other callback deadlocks after the next JAX release, please report them!

dfm commented Jan 29 '25 14:01