
Support (nonsymmetric) np.linalg.eig on GPU

Open clemisch opened this issue 5 years ago • 31 comments

Dear jax team,

this is just a friendly bump on the implementation of eigendecomposition and batched SVD on GPU. Are you planning on implementing these?

Should I want to implement it myself, would I be able to do it with the primitives in jax.lax, or would I have to hook up a new part of cuSolver? I am willing to spend the time as I would benefit a lot from these features, but I have no experience with expanding jax and would not know where to look.

clemisch, Aug 28 '19 20:08

Thanks for the ping! Are there open issues for these already?

@hawkinsp is the expert on these things and can provide the best advice, but for GPU implementations of linalg my understanding is that we set up some wrappers in jaxlib, then set up backend-specific translation rules for the appropriate primitives in lax_linalg.py. As for adding batching specifically, I think we just need to make sure batch dimensions are plumbed through properly, which, if the cuSOLVER kernels themselves don't support batch dimensions, might mean adding some kind of loop over cuSOLVER calls. It looks like Peter added batched triangular solve and LU decomposition for GPU in #1144, so that might provide hints for the plumbing needed.

What do you think? Questions welcome! I can only provide high-level pointers to the right places, but if we sniff around there I bet we'll find things.

mattjj, Aug 28 '19 20:08
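Some readers may wonder what "a loop over solver calls" looks like next to native batching from the user side. Here is a minimal sketch using only public JAX APIs. jax.lax.map runs the unbatched SVD once per matrix in sequence, while jax.vmap hands the batch dimension to JAX's batching rule. This is only a Python-level analogy that assumes a (batch, m, n) input. It is not how jaxlib wires up cuSOLVER, and the function names are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp


def singular_values_looped(x):
    # Sequential loop over the leading batch axis: one unbatched SVD per matrix,
    # analogous to issuing one solver call per batch element.
    return jax.lax.map(lambda m: jnp.linalg.svd(m, compute_uv=False), x)


def singular_values_vmapped(x):
    # vmap lets JAX's batching rule handle the leading batch axis.
    return jax.vmap(lambda m: jnp.linalg.svd(m, compute_uv=False))(x)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 3, 3))
s_loop = jax.jit(singular_values_looped)(x)
s_vmap = jax.jit(singular_values_vmapped)(x)
```

lax.map keeps one solver call per matrix, which roughly matches the fallback described above. vmap leaves the choice of a batched kernel to the backend.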

Thanks for your quick response! I think there are already some issues concerning linear ops, but not specifically eigendecomp or batched SVD.

Also thanks a lot for the explanation! I'll try to get oriented and come here if I have questions.

clemisch, Aug 29 '19 15:08

@clemisch I can take a look at these if you aren't already working on them.

hawkinsp, Sep 03 '19 14:09

Hey @hawkinsp, thank you for getting back on this! Tbh I have not looked into this so far. It would be great if you could have a look too!

clemisch, Sep 04 '19 12:09

PR #1314 adds batched SVD on CPU and GPU. On CPU, or for large matrices on GPU, it merely calls the current code in a loop. On GPU for small matrices, it calls the batched Jacobi kernel from cuSOLVER.

Unfortunately np.linalg.eig is a little harder. I can add a "batched" implementation on CPU (simply looping over the batch elements). However, there is no support for non-symmetric eigendecomposition in cuSOLVER (batched or unbatched). If you really need this, then we'd need to add another dependency (probably MAGMA), which is a bunch more work. Do SVD and symmetric eigendecomposition satisfy your needs for now?

hawkinsp, Sep 05 '19 22:09
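For inputs known to be symmetric, the symmetric path mentioned here already applies. Below is a minimal sketch. It assumes a real symmetric matrix, built here by symmetrizing a random one for illustration. jnp.linalg.eigh handles it, and NumPy's general eigvals on the host serves only as a cross-check.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4)).astype(np.float32)
sym = (a + a.T) / 2  # symmetrize so the symmetric/Hermitian solver applies

# eigh: symmetric eigendecomposition, returns eigenvalues in ascending order.
w_sym, v_sym = jnp.linalg.eigh(sym)

# General (nonsymmetric) solver from NumPy on the host, for comparison only.
w_gen = np.linalg.eigvals(sym)
```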

I merged the PR that adds batched SVD support. You'll need to rebuild jaxlib (or wait for us to make a release).

I retitled the issue to reflect the open action item (nonsymmetric eigendecomposition on GPU).

hawkinsp, Sep 06 '19 13:09

GPU Eigendecomposition via MAGMA might fall into the "contributions welcome" category, unless it proves to be a popular request.

hawkinsp, Sep 06 '19 13:09

Thank you @hawkinsp, this is great! Non-symmetric eigendecomposition is not very urgent for me, especially if it's so cumbersome to add to jax.

Concerning batched SVD, I have a question about speed: in this little test I only see a 4x speedup vs. single-core numpy. Is this expected?

import jax
import jax.numpy as np
import numpy as onp

x_host = onp.random.rand(100000, 3, 3).astype(onp.float32)
x_gpu = np.array(x_host)

svd_batch = jax.jit(jax.vmap(np.linalg.svd, 0, 0))

u1, s1, v1 = onp.linalg.svd(x_host)
u2, s2, v2 = np.linalg.svd(x_gpu)
u3, s3, v3 = svd_batch(x_gpu)

%timeit onp.linalg.svd(x_host)                       # 495 ms
%timeit np.linalg.svd(x_gpu)[0].block_until_ready()  # 122 ms
%timeit svd_batch(x_gpu)[1].block_until_ready()      # 123 ms

(sorry about the repost, I deleted the original comment by mistake)

clemisch, Sep 12 '19 19:09
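The %timeit lines above need IPython. Here is a plain-script sketch of the same comparison, assuming any recent JAX. It makes one warm-up call so compilation is excluded, blocks on every output array, and reports the best of several runs. The batch size (1000) is smaller than in the test above to keep it quick. best_time and _block are hypothetical helper names. No timings are claimed, since they depend entirely on the hardware.

```python
import time

import numpy as onp
import jax
import jax.numpy as jnp


def _block(out):
    # Wait until every JAX array in the (possibly nested) output is computed;
    # NumPy arrays have no block_until_ready and are already materialized.
    for leaf in jax.tree_util.tree_leaves(out):
        if hasattr(leaf, "block_until_ready"):
            leaf.block_until_ready()
    return out


def best_time(fn, *args, repeats=5):
    """Best wall-clock time in seconds over `repeats` calls of fn(*args)."""
    _block(fn(*args))  # warm-up call, so jit compilation is not timed
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        _block(fn(*args))
        best = min(best, time.perf_counter() - start)
    return best


x_host = onp.random.rand(1000, 3, 3).astype(onp.float32)
x_dev = jnp.asarray(x_host)
svd_batch = jax.jit(jax.vmap(jnp.linalg.svd))

t_numpy = best_time(onp.linalg.svd, x_host)
t_jax = best_time(svd_batch, x_dev)
print(f"numpy: {t_numpy * 1e3:.1f} ms, jax: {t_jax * 1e3:.1f} ms")
```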

Bump 😺

> In this little test I only see x4 speedup vs. single-core numpy. Is this expected?

clemisch, Sep 27 '19 11:09

I believe that's just how fast NVIDIA's cuSOLVER batched Jacobi implementation is. On my GPU, we seem to spend 99.9% of the time in the batched Jacobi kernel:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   99.90%  3.20600s        12  267.17ms  233.68ms  305.82ms  void batched_svd_parallel_jacobi_32x16<float, float>(int, int, int, int, float*, unsigned long, int, float*, float*, unsigned long, int, float*, unsigned long, int, float, int, int*, float, int, int*, int, float)

The algorithm does have some tunable parameters that one might explore setting: https://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-gesvdjbatch

If you wanted to try that, I think you just need to call the functions that modify the Jacobi parameters at this line and then rebuild Jaxlib. https://github.com/google/jax/blob/master/jaxlib/cusolver.cc#L731

hawkinsp, Sep 27 '19 17:09
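To give a feel for what the tunable knobs of a Jacobi SVD control, here is a NumPy sketch of one-sided (Hestenes) Jacobi on a single matrix. It has a convergence tolerance and a cap on the number of sweeps. It assumes m >= n and returns only singular values. It illustrates the general technique and is not cuSOLVER's kernel, whose parameter definitions may differ (see the documentation linked above).

```python
import numpy as np


def one_sided_jacobi_svd(a, tol=1e-10, max_sweeps=30):
    """Singular values (descending) of an (m, n) matrix with m >= n.

    Repeatedly rotates pairs of columns until all pairs are orthogonal to
    within `tol` (relative), or until `max_sweeps` full sweeps have run.
    """
    u = np.array(a, dtype=np.float64)  # columns are rotated in place
    n = u.shape[1]
    for _ in range(max_sweeps):
        converged = True
        for p in range(n - 1):
            for q in range(p + 1, n):
                alpha = u[:, p] @ u[:, p]
                beta = u[:, q] @ u[:, q]
                gamma = u[:, p] @ u[:, q]
                # Skip column pairs that are already orthogonal to within tol.
                if abs(gamma) <= tol * np.sqrt(alpha * beta):
                    continue
                converged = False
                # Rotation angle that makes columns p and q orthogonal.
                zeta = (beta - alpha) / (2.0 * gamma)
                t = np.copysign(1.0, zeta) / (abs(zeta) + np.sqrt(1.0 + zeta * zeta))
                c = 1.0 / np.sqrt(1.0 + t * t)
                s = c * t
                up = u[:, p].copy()
                u[:, p] = c * up - s * u[:, q]
                u[:, q] = s * up + c * u[:, q]
        if converged:
            break
    # Once the columns are mutually orthogonal, their norms are the singular values.
    return np.sort(np.linalg.norm(u, axis=0))[::-1]
```

A looser tol or a smaller max_sweeps presumably trades accuracy for fewer rotations. That is the kind of trade-off such solver parameters expose.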

Thank you very much for clarifying!

clemisch, Sep 28 '19 07:09

Hi! Just popping up to ask whether there is any progress regarding eig. I'm currently preparing a JAX implementation of the implicitly restarted Arnoldi method (for non-symmetric operators). The working CPU implementation relies on jax.numpy.linalg.eig to compute the eigenvalues of the Hessenberg matrix returned by Arnoldi. It would be great to have this run on GPU eventually.

mganahl, Sep 18 '20 09:09
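This is why eig shows up inside Arnoldi. After k steps, the eigenvalues of the small k x k upper Hessenberg matrix H (the Ritz values) approximate eigenvalues of the large operator, and H is not symmetric in general. Below is a minimal NumPy sketch of plain Arnoldi with modified Gram-Schmidt, assuming a dense matrix a. It does not implement implicit restarts.

```python
import numpy as np


def arnoldi(a, v0, k):
    """k Arnoldi steps; returns Q (n, m) and the square Hessenberg H (m, m), m <= k."""
    n = a.shape[0]
    dtype = np.result_type(a, v0, np.float64)
    q = np.zeros((n, k + 1), dtype=dtype)
    h = np.zeros((k + 1, k), dtype=dtype)
    q[:, 0] = v0 / np.linalg.norm(v0)
    for j in range(k):
        w = a @ q[:, j]
        for i in range(j + 1):  # modified Gram-Schmidt against earlier basis vectors
            h[i, j] = np.vdot(q[:, i], w)
            w = w - h[i, j] * q[:, i]
        h[j + 1, j] = np.linalg.norm(w)
        if h[j + 1, j] < 1e-12:  # Krylov subspace became invariant; stop early
            return q[:, : j + 1], h[: j + 1, : j + 1]
        q[:, j + 1] = w / h[j + 1, j]
    return q[:, :k], h[:k, :k]


def ritz_values(a, v0, k):
    # Eigenvalues of the small nonsymmetric Hessenberg matrix: the step that
    # needs a general eig (jnp.linalg.eig in the CPU implementation described above).
    _, h = arnoldi(a, v0, k)
    return np.linalg.eigvals(h)
```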

Hey, I thought I'd also express my interest in this; my use case is finding the poles of many autoregressive models in parallel with np.roots. Thanks to all the contributors to JAX for getting it to where it already is; it's amazing.

joncarter1, May 17 '21 11:05
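NumPy documents np.roots as computing the eigenvalues of the polynomial's companion matrix, so GPU support for roots depends on a nonsymmetric eig. Below is a minimal NumPy sketch for the AR use case. It assumes the convention x_t = a_1 x_{t-1} + ... + a_p x_{t-p} + noise, builds one companion matrix per model, and solves all of them with a single batched eigvals call. ar_poles and is_stable are illustrative names.

```python
import numpy as np


def ar_poles(coeffs):
    """Poles of a batch of AR(p) models.

    coeffs has shape (batch, p) and holds a_1..a_p. The poles are the roots of
    z^p - a_1 z^(p-1) - ... - a_p, i.e. the eigenvalues of the companion matrix.
    """
    coeffs = np.atleast_2d(np.asarray(coeffs, dtype=np.float64))
    batch, p = coeffs.shape
    comp = np.zeros((batch, p, p))
    comp[:, 0, :] = coeffs            # first row: the AR coefficients
    comp[:, 1:, :-1] = np.eye(p - 1)  # ones on the subdiagonal
    return np.linalg.eigvals(comp)    # one batched nonsymmetric eigenvalue solve


def is_stable(coeffs):
    # An AR model is stationary when every pole lies strictly inside the unit circle.
    return np.all(np.abs(ar_poles(coeffs)) < 1.0, axis=-1)
```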

I'm curious how folks would feel about the following: suppose MAGMA were an optional dependency of JAX. i.e., we don't bundle it in jaxlib builds, but if you install it yourself (or perhaps via conda?) and JAX can find the shared library in your library path, then jnp.linalg.eig works on GPU.

(I'm a bit reluctant to bundle it with jaxlib unconditionally for just one function!)

hawkinsp, May 21 '21 13:05
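Here is a rough Python sketch of the "use it if it is on the library path" idea, using only the standard library's ctypes. The helper names and the "magma" library name are assumptions for illustration. This is not how JAX actually locates libraries.

```python
import ctypes
import ctypes.util


def find_optional_library(name):
    """Return a loaded handle to lib<name> if it is on the search path, else None."""
    path = ctypes.util.find_library(name)
    if path is None:
        return None
    try:
        return ctypes.CDLL(path)
    except OSError:
        # Found on disk but could not be loaded (e.g. missing dependencies).
        return None


# Hypothetical: probe once at import time and remember the result.
_MAGMA = find_optional_library("magma")


def eig_backend():
    # Fall back to the host LAPACK path when the optional library is absent.
    return "gpu-magma" if _MAGMA is not None else "cpu-lapack"
```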

I'd be totally fine with this. Could always be bundled in later down the line but as you say I feel the critical threshold for functional usage is perhaps a bit higher than one! :)

joncarter1, May 21 '21 17:05

+1 that support for GPU-backed eig would be great.

ianwilliamson, Aug 09 '21 20:08

+1 for GPU-support for nonsymmetric eig to allow GPU-enabled numpy.roots

drscook, Sep 29 '22 16:09

I also strongly support implementing this feature, so that jnp.roots can be used on GPU. I am training a network whose loss function requires computing the roots of a polynomial, and training on CPU is far too slow.

melsophos, Oct 04 '22 20:10

I developed a workaround for my use case, which involves using the jax.experimental.host_callback module. Just sharing it in case it's useful.

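from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback

def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: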
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)
    return host_callback.call(
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        jax.jit(jnp.linalg.eig, device=jax.devices("cpu")[0]),
        matrix.astype(complex),
        result_shape=[eigenvalues_shape, eigenvectors_shape],
    )

m = jnp.array([[2.0, 1.0], [0.0, 3.0]])  # Example nonsymmetric input matrix.
jax.jit(_eig_host, device=jax.devices("gpu")[0])(m)  # This works, we can jit on GPU.

mfschubert, Oct 31 '22 16:10

A brief update: we now use a slightly modified version that avoids specifying the device in the call to jax.jit, as is now the recommended practice:

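from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback

def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: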
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return host_callback.call(
        _eig_cpu,
        matrix.astype(complex),
        result_shape=(eigenvalues_shape, eigenvectors_shape),
    )

mfschubert, Aug 29 '23 02:08

Hi there! I was just wondering whether there has been any progress on this issue, since it is quite a common and essential function for scientific work.

tsunhopang, Mar 27 '24 13:03

Note for anyone using the above workaround with host_callback nowadays: host_callback has been deprecated; use external callbacks instead (most likely pure_callback()). These also work nicely under function transformations.

ju-kreber, May 03 '24 16:05
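Below is a minimal sketch of the pure_callback route. It assumes a JAX version where jax.pure_callback accepts jax.ShapeDtypeStruct result shapes. The callback receives host NumPy arrays, so it can call np.linalg.eig directly (which handles leading batch dimensions) instead of re-entering jax.jit. eig_via_numpy is a hypothetical name.

```python
import numpy as np
import jax
import jax.numpy as jnp


def eig_via_numpy(matrix):
    """jit-compatible nonsymmetric eig that always runs NumPy's LAPACK on the host."""
    # complex64 for float32 input; complex128 for float64 input (needs jax_enable_x64).
    out_dtype = jnp.result_type(matrix.dtype, jnp.complex64)
    result_shapes = (
        jax.ShapeDtypeStruct(matrix.shape[:-1], out_dtype),  # eigenvalues
        jax.ShapeDtypeStruct(matrix.shape, out_dtype),       # eigenvectors
    )

    def _host_eig(m):
        # Runs on the host with NumPy inputs; cast to the declared output dtype.
        w, v = np.linalg.eig(np.asarray(m))
        return w.astype(out_dtype), v.astype(out_dtype)

    return jax.pure_callback(_host_eig, result_shapes, matrix)


a = jnp.array([[2.0, 1.0], [0.0, 3.0]])
w, v = jax.jit(eig_via_numpy)(a)
```

To vmap this, pure_callback needs to be told how to batch. Older releases use vectorized=True; as far as I know, recent releases use a vmap_method argument instead.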

I have implemented (matrix-free) eigs in JAX for scientific purposes in jaxeigs. I have borrowed some code from TensorNetwork and perform the Arnoldi decomposition on the GPU. However, I have kept the last step, which solves the eigenproblem in the projected Krylov space, on the CPU (via a callback), since the algorithm is divide-and-conquer and thus not efficient on GPU.

I must admit that this code is currently extremely unstable, and the documentation is incomplete. Despite these limitations, it is functional for my own use.

qiyang-ustc, May 05 '24 08:05

> Note for anyone using the above workaround with the host callback nowadays: host_callback has been deprecated, use external callbacks instead (most probably pure_callback()). These also work nicely under function transformations.

As pure_callback does not seem to support fp64 at the moment, you need additional tricks (in case you are using fp32).

Looking forward to the implementation of eig on GPU.

from typing import Tuple

import jax
import jax.numpy as jnp

def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], matrix.dtype)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            val, vec = jax.jit(jnp.linalg.eig)(matrix)
            return (val.real, val.imag), (vec.real, vec.imag)

    val, vec = jax.pure_callback(_eig_cpu,
                                 ((eigenvalues_shape, eigenvalues_shape),
                                  (eigenvectors_shape, eigenvectors_shape)),
                                 matrix)
    return val[0] + 1j * val[1], vec[0] + 1j * vec[1]

moskomule, May 15 '24 09:05

> As pure_callback does not seem to support fp64 at the moment, you need additional tricks (in case you are using fp32).

We don't seem to have issues supporting fp32 and fp64 with the following implementation in fmmax:

from typing import Tuple

import jax
import jax.numpy as jnp

def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return jax.pure_callback(
        _eig_cpu,
        (
            jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
            jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
        ),
        matrix.astype(complex),
        vectorized=True,
    )

mfschubert, May 15 '24 15:05

I did some tests comparing eig performance for scipy, numpy, jax, and torch and found that they can differ quite a bit, with torch generally being the fastest. In lieu of a GPU-accelerated eig, simply using the torch version may be of benefit.

I also created a pip-installable jeig package which wraps all of these for use with jax. All implementations can be jit-compiled, including on machines with GPUs.

Here is an example of the performance difference I am seeing. This was generated on CPU colab, but torch comes out ahead also on my Apple and Intel machines. I didn't investigate the origin of the difference, but presumably there's a different linear algebra library being used in each of these packages.

[Image: chart comparing eig run times for scipy, numpy, jax, and torch]

mfschubert, Aug 28 '24 15:08

JAX just calls scipy's copy of LAPACK. You can probably speed it up by installing, e.g., a scipy build linked against Intel's MKL.

Torch, as far as I know, also just calls LAPACK. It may be linking it with a different BLAS library; JAX will just be using openblas from scipy.

hawkinsp, Aug 28 '24 15:08

Hi, bumping this to ask if there is any plan from the jax team to implement this feature? @jakevdp

We'd also need this feature for dynamiqs, for the simulation of quantum systems in the so-called Floquet basis (time-periodic quantum systems).

Thanks!

gautierronan, Sep 13 '24 07:09
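Floquet problems need a general eig because the one-period propagator U(T) of a periodically driven system is unitary, not Hermitian. Its eigenvalues exp(-i ε T) give the quasienergies ε. Below is a minimal NumPy sketch for a hypothetical two-step drive (h1 then h2, each for T/2), with ħ = 1. This is not dynamiqs' API.

```python
import numpy as np


def propagator(h, t):
    """exp(-i h t) for a Hermitian h, built from its eigendecomposition (hbar = 1)."""
    w, v = np.linalg.eigh(h)
    return (v * np.exp(-1j * w * t)) @ v.conj().T


def quasienergies(h1, h2, period):
    """Floquet quasienergies for a drive applying h1, then h2, each for period / 2."""
    u = propagator(h2, period / 2) @ propagator(h1, period / 2)  # one-period propagator
    # u is unitary but not Hermitian, so a general (nonsymmetric) eig is needed.
    lam = np.linalg.eigvals(u)
    # lam = exp(-i * eps * period); eps is only defined modulo 2 * pi / period.
    return np.sort(-np.angle(lam) / period)
```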

Hello, I was wondering if there has been any progress on this issue?

cwoolfo1, Oct 15 '24 04:10