Support (nonsymmetric) np.linalg.eig on GPU
Dear jax team,
this is just a friendly bump on the implementation of eigendecomposition and batched SVD on GPU. Are you planning on implementing these?
Should I want to implement it myself, would I be able to do it with the primitives in jax.lax, or would I have to hook up a new part of cuSolver? I am willing to spend the time, as I would benefit a lot from these features, but I have no experience with extending jax and would not know where to look.
Thanks for the ping! Are there open issues for these already?
@hawkinsp is the expert on these things and can provide the best advice, but for GPU implementations of linalg my understanding is that we set up some wrappers in jaxlib, then set up backend-specific translation rules for the appropriate primitives in lax_linalg.py. As for adding batching specifically, I think we just need to make sure batch dimensions are plumbed through properly, which, if the cusolver kernels themselves don't support batch dimensions, might mean adding some kind of loop over cusolver calls. It looks like Peter added batched triangular solve and LU decomposition for GPU in #1144, so that might provide hints for the plumbing needed.
What do you think? Questions welcome! I can only provide high-level pointers to the right places, but if we sniff around there I bet we'll find things.
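To make the "loop over unbatched calls" idea concrete at the user level, here is a rough sketch (illustrative only, not the actual jaxlib/cusolver plumbing; the function name and shapes are made up):

import jax
import jax.numpy as jnp

def batched_svd_via_loop(x):
    # lax.map lowers to a loop over the leading axis, calling the unbatched
    # SVD once per matrix -- the same shape of workaround a backend batching
    # rule could use when the underlying kernel isn't batched.
    return jax.lax.map(jnp.linalg.svd, x)

u, s, vh = jax.jit(batched_svd_via_loop)(jnp.ones((8, 3, 3)))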
Thanks for your quick response! I think there are already some issues concerning linear ops, but not specifically eigendecomp or batched SVD.
Also thanks a lot for the explanation! I'll try to get oriented and come here if I have questions.
@clemisch I can take a look at these if you aren't already working on them.
Hey @hawkinsp, thank you for getting back to me on this! To be honest, I have not looked into this so far. It would be great if you could have a look too!
PR #1314 adds batched SVD on CPU and GPU. On CPU or for large matrices on GPU it merely calls the current code in a loop. On GPU for small matrices it calls the batched Jacobi kernel from Cusolver.
Unfortunately np.linalg.eig is a little harder. I can add a "batched" implementation on CPU (simply looping over the batch elements). However, there is no support for non-symmetric eigendecomposition in Cusolver (batched or unbatched). If you really need this, we'd need to add another dependency (probably MAGMA), which is a bunch more work. Do SVD and symmetric eigendecomposition suffice for now?
I merged the PR that adds batched SVD support. You'll need to rebuild Jaxlib (or wait for us to make a release).
I retitled the issue to reflect the open action item (nonsymmetric eigendecomposition on GPU).
GPU Eigendecomposition via MAGMA might fall into the "contributions welcome" category, unless it proves to be a popular request.
Thank you @hawkinsp, this is great! Non-symmetric eigendecomposition is not very urgent for me, especially if it's so cumbersome to add to jax.
Concerning batched SVD, I have a question about speed: in this little test I only see a ~4x speedup vs. single-core numpy. Is this expected?
import jax
import jax.numpy as np
import numpy as onp
x_host = onp.random.rand(100000, 3, 3).astype(onp.float32)
x_gpu = np.array(x_host)
svd_batch = jax.jit(jax.vmap(np.linalg.svd, 0, 0))
u1, s1, v1 = onp.linalg.svd(x_host)
u2, s2, v2 = np.linalg.svd(x_gpu)
u3, s3, v3 = svd_batch(x_gpu)
%timeit onp.linalg.svd(x_host) # 495 ms
%timeit np.linalg.svd(x_gpu)[0].block_until_ready() # 122 ms
%timeit svd_batch(x_gpu)[1].block_until_ready() # 123 ms
(sorry about the repost, I deleted the original comment by mistake)
Bump :smiley_cat:
In this little test I only see a ~4x speedup vs. single-core numpy. Is this expected?
I believe that's just how fast NVIDIA's Cusolver batched Jacobi implementation is. On my GPU, it seems we spend 99.9% of the time in the batched Jacobi kernel:
GPU activities:
99.90% 3.20600s 12 267.17ms 233.68ms 305.82ms void batched_svd_parallel_jacobi_32x16<float, float>(int, int, int, int, float*, unsigned long, int, float*, float*, unsigned long, int, float*, unsigned long, int, float, int, int*, float, int, int*, int, float)
The algorithm does have some tunable parameters that one might explore setting: https://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-gesvdjbatch
If you wanted to try that, I think you just need to call the functions that modify the Jacobi parameters at this line and then rebuild Jaxlib. https://github.com/google/jax/blob/master/jaxlib/cusolver.cc#L731
Thank you very much for clarifying!
Hi! Just popping up to ask if there is any progress regarding eig.
I'm currently preparing a JAX implementation of implicitly restarted Arnoldi (for non-symmetric operators). The working CPU implementation relies on jax.numpy.linalg.eig to compute eigenvalues of the Hessenberg matrix returned by Arnoldi. It would be great to have this run on GPU eventually.
Hey, I thought I'd also express my interest in this; my use case is finding the poles of many auto-regressive models in parallel with np.roots. Thanks to all the contributors for getting JAX to where it already is, it's amazing.
I'm curious how folks would feel about the following: suppose MAGMA were an optional dependency of JAX, i.e., we don't bundle it in jaxlib builds, but if you install it yourself (or perhaps via conda?) and JAX can find the shared library in your library path, then jnp.linalg.eig works on GPU.
(I'm a bit reluctant to bundle it with jaxlib unconditionally for just one function!)
I'd be totally fine with this. Could always be bundled in later down the line but as you say I feel the critical threshold for functional usage is perhaps a bit higher than one! :)
+1, support for GPU-backed eig would be great.
+1 for GPU-support for nonsymmetric eig to allow GPU-enabled numpy.roots
I also strongly support implementing this feature, in order to be able to use jnp.roots on GPU. I am training a network whose loss function requires computing the roots of a polynomial, and training on CPU is really too slow.
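(For context on the roots use case: the roots of a polynomial are the eigenvalues of its companion matrix, which is nonsymmetric, so np.roots ultimately needs exactly the eig support discussed here. A minimal sketch with an arbitrary example polynomial; eigvals currently only runs on CPU:)

import jax.numpy as jnp

coeffs = jnp.array([1.0, -6.0, 11.0, -6.0])        # x^3 - 6x^2 + 11x - 6, roots 1, 2, 3
n = coeffs.shape[0] - 1
companion = jnp.zeros((n, n)).at[1:, :-1].set(jnp.eye(n - 1))
companion = companion.at[:, -1].set(-coeffs[::-1][:-1] / coeffs[0])
roots = jnp.linalg.eigvals(companion)              # nonsymmetric eig -> CPU-only today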
I developed a workaround for my use case, which involves using the jax.experimental.host_callback module. Just sharing it in case it's useful.
from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)
    return host_callback.call(
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        jax.jit(jnp.linalg.eig, device=jax.devices("cpu")[0]),
        matrix.astype(complex),
        result_shape=[eigenvalues_shape, eigenvectors_shape],
    )


jax.jit(_eig_host, device=jax.devices("gpu")[0])(m)  # This works, we can jit on GPU.
A brief update to this: we have a slightly modified version which avoids the device specification in the call to jax.jit, which is the new recommended practice:
def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return host_callback.call(
        _eig_cpu,
        matrix.astype(complex),
        result_shape=(eigenvalues_shape, eigenvectors_shape),
    )
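A possible usage sketch for the wrapper above (untested; m is just a placeholder square matrix): because the device is no longer pinned at the call site, plain jax.jit works on a GPU machine while the decomposition itself still runs on the host CPU.

m = jnp.ones((4, 4))
eigenvalues, eigenvectors = jax.jit(_eig_host)(m)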
Hi there! I was just wondering if there has been any progress made on this particular issue, since it is quite a common and essential function in scientific work.
Note for anyone using the above workaround with the host callback nowadays: host_callback has been deprecated; use external callbacks instead (most probably pure_callback()). These also work nicely under function transformations.
I have implemented (matrix-free) eigs in JAX for scientific purposes in jaxeigs. I have borrowed some code from TensorNetwork and performed the Arnoldi decomposition on the GPU. However, I have kept the last step, which solves the eigenproblem in the projected Krylov space, on the CPU (via callback), since the algorithm is divide-and-conquer and thus not efficient on GPU.
I must admit that this code is currently extremely unstable and the documentation is incomplete. Despite these limitations, it is functional for my own use.
Note for anyone using the above workaround with the host callback nowadays: host_callback has been deprecated; use external callbacks instead (most probably pure_callback()). These also work nicely under function transformations.
As pure_callback does not seem to support fp64 at the moment, you need additional tricks (in case you are using fp32).
Looking forward to the implementation of eig on GPU.
def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], matrix.dtype)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            val, vec = jax.jit(jnp.linalg.eig)(matrix)
            return (val.real, val.imag), (vec.real, vec.imag)

    val, vec = jax.pure_callback(
        _eig_cpu,
        ((eigenvalues_shape, eigenvalues_shape),
         (eigenvectors_shape, eigenvectors_shape)),
        matrix,
    )
    return val[0] + 1j * val[1], vec[0] + 1j * vec[1]
As pure_callback does not seem to support fp64 at the moment, you need additional tricks (in case you are using fp32).
We don't seem to have issues supporting fp32 and fp64 with the following implementation in fmmax:
def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return jax.pure_callback(
        _eig_cpu,
        (
            jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
            jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
        ),
        matrix.astype(complex),
        vectorized=True,
    )
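A hedged usage sketch for this variant (shapes arbitrary, untested): since the callback is marked vectorized, the wrapper can also sit under vmap and jit on a GPU host while the decomposition itself runs on CPU.

matrices = jnp.ones((16, 4, 4), dtype=complex)     # arbitrary batch of square matrices
eigenvalues, eigenvectors = jax.jit(jax.vmap(_eig_host))(matrices)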
I did some tests comparing eig performance for scipy, numpy, jax, and torch and found that they can differ quite a bit, with torch generally being the fastest. In lieu of a GPU-accelerated eig, simply using the torch version may be of benefit.
I also created a pip-installable jeig package which wraps all of these for use with jax. All implementations can be jit-compiled, including on machines with GPUs.
Here is an example of the performance difference I am seeing. This was generated on CPU colab, but torch also comes out ahead on my Apple and Intel machines. I didn't investigate the origin of the difference, but presumably a different linear algebra library is being used by each of these packages.
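For anyone wanting to reproduce that kind of comparison, a rough sketch along these lines (sizes arbitrary; absolute timings depend heavily on the BLAS/LAPACK build each package links against):

import time
import numpy as onp
import torch

x = onp.random.rand(256, 64, 64)

t0 = time.perf_counter()
onp.linalg.eig(x)
print("numpy eig:", time.perf_counter() - t0)

t0 = time.perf_counter()
torch.linalg.eig(torch.from_numpy(x))
print("torch eig:", time.perf_counter() - t0)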
JAX just calls scipy's copy of LAPACK. You can probably accelerate it by installing, e.g., an MKL-based scipy build from Intel.
Torch, as far as I know, also just calls LAPACK. It may be linking it with a different BLAS library; JAX will just be using openblas from scipy.
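If it helps with digging into that, one way to check which BLAS/LAPACK build numpy and scipy are actually linked against (standard introspection helpers, nothing JAX-specific):

import numpy as onp
import scipy

onp.show_config()    # prints the BLAS/LAPACK libraries numpy was built against
scipy.show_config()  # same for scipy, whose LAPACK the CPU eig path calls into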
Hi, bumping this to ask whether the jax team has any plans to implement this feature? @jakevdp
We'd also need this feature for dynamiqs, for the simulation of quantum systems in the so-called Floquet basis (time-periodic quantum systems).
Thanks!
Hello, I was wondering if there has been any progress on this issue?