
Support JAX as a backend

Open gtca opened this issue 2 years ago • 11 comments

I would support JAX as another backend.

Would probably be worth opening another issue for tracking JAX as a backend. For IO, I was thinking that could be done with #659

Originally posted by @ivirshup in https://github.com/theislab/anndata/issues/695#issuecomment-1026039360

gtca avatar Jan 31 '22 17:01 gtca

I wonder if, apart from the modifiers (#659), we should consider saving the original object class as an attribute. Would one expect to restore the same object class that was saved? If so, this approach should work.

import h5py
f = h5py.File("adata.h5ad", "r")
f['X'].attrs['origin-class']
# => jaxlib.xla_extension.DeviceArray
f.close()

import anndata
adata = anndata.read("adata.h5ad")
type(adata.X)
# => jaxlib.xla_extension.DeviceArray

If the respective library is not available, we use the default backend (NumPy in this case).
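That fallback could be sketched roughly as below. This is a hypothetical helper, not part of anndata: the function name, the module-name check, and the `jnp.asarray` conversion are all illustrative assumptions about how the proposal might work.

```python
import importlib

import numpy as np


def load_with_origin(arr: np.ndarray, origin_class: str):
    """Hypothetical sketch: hand the data to the library named in the
    'origin-class' attribute, falling back to NumPy if it is missing."""
    module_name = origin_class.split(".", 1)[0]
    try:
        importlib.import_module(module_name)
    except ImportError:
        return arr  # library not available: keep the default NumPy backend
    if module_name in ("jax", "jaxlib"):
        import jax.numpy as jnp
        return jnp.asarray(arr)  # move the data into a JAX array
    return arr
```

With JAX missing from the environment, the array would simply come back as a plain `numpy.ndarray`.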

gtca avatar Jan 31 '22 23:01 gtca

I don't think IO operations should return different types depending on the environment. I think that's too surprising a behavior, and it will definitely lead to bugs in user code.

It also leaves very little room for cross-language interoperability.

ivirshup avatar Feb 01 '22 00:02 ivirshup

I guess it would be nice to include a "global switch" for the backend when reading files then, e.g.

adata = anndata.read("adata.h5ad", matrix_backend="jax")

That would include sparse matrices.

gtca avatar Feb 01 '22 00:02 gtca

The tricky thing here is that if a GPU is available, JAX will try to put all the data on device, which can easily become problematic. While I really do think JAX is promising, a Torch backend is probably much more practical as a result (combined with the fact that most single-cell ML applications use PyTorch). PyTorch also has sparse CSR support.
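For reference, the sparse CSR support mentioned above uses PyTorch's public `torch.sparse_csr_tensor` constructor. A minimal sketch:

```python
import torch

# Build a 2x2 CSR matrix: row 0 holds 3.0 at column 1, row 1 holds 4.0 at column 0
crow_indices = torch.tensor([0, 1, 2])  # row pointer array (CSR "indptr")
col_indices = torch.tensor([1, 0])      # column index of each stored value
values = torch.tensor([3.0, 4.0])
m = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
dense = m.to_dense()
```

This mirrors the `indptr`/`indices`/`data` layout scipy uses for `csr_matrix`, which is also how anndata serializes sparse matrices today.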

adamgayoso avatar Mar 07 '22 01:03 adamgayoso

I am not sure how much of a problem this is in practice, though. Of course we'll have to test it on different setups, but I expect it to work smoothly on most of them. Especially considering that JAX's default GPU placement can be turned off.

Also, implementing one alternative backend should make adding more backends easier.

gtca avatar Mar 07 '22 11:03 gtca

Especially considering that JAX's default GPU placement can be turned off.

Yes, but can you then easily turn it back on for only some tensors? I agree that they're not mutually exclusive; JAX and Torch would both be great.

adamgayoso avatar Mar 07 '22 16:03 adamgayoso

Also, there is already a loader with a built-in converter to PyTorch tensors: https://github.com/theislab/anndata/tree/master/anndata/experimental/pytorch (see https://anndata.readthedocs.io/en/latest/api.html#experimental-api)

Koncopd avatar Mar 07 '22 16:03 Koncopd

Also, there is already a loader with a built-in converter to PyTorch tensors

This would be different from a complete backend, right? We have our own PyTorch dataloader as well in scvi-tools. Backends could support things like a PyTorch-based PCA in Scanpy, etc.

adamgayoso avatar Mar 07 '22 16:03 adamgayoso

I agree, it is a different thing if that's what you have in mind. But then it becomes unclear whether things like PyTorch (or JAX) tensors should be handled in loaders or directly in AnnData objects.

Koncopd avatar Mar 07 '22 16:03 Koncopd

But then it becomes unclear whether things like PyTorch (or JAX) tensors should be handled in loaders or directly in AnnData objects.

At least on our side, in the Torch case, we would only need to add one line of code to skip the format conversion in our dataloader. In the JAX case, same thing, but I think it would be bad for a dataloader if you had JAX arrays on GPU and then had to convert them to Torch GPU tensors.

adamgayoso avatar Mar 07 '22 16:03 adamgayoso

Hm, I think the zero-copy conversion can be done via DLPack: https://jax.readthedocs.io/en/latest/jax.dlpack.html https://pytorch.org/docs/stable/dlpack.html
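A minimal sketch of that zero-copy hand-off, assuming recent versions of both libraries (`torch.from_dlpack` accepts any object implementing the DLPack protocol, which modern JAX arrays do):

```python
import jax.numpy as jnp
import torch

x = jnp.arange(6.0).reshape(2, 3)  # a JAX array (on CPU here)
t = torch.from_dlpack(x)           # zero-copy view as a torch.Tensor
```

On older versions you may need to go through an explicit capsule via `jax.dlpack.to_dlpack` instead. Either way, the tensor shares the underlying buffer with the JAX array rather than copying it.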

Koncopd avatar Mar 07 '22 16:03 Koncopd

JAX is our group's primary mode of development, and it would be fantastic if anndata supported it as well! JAX now has experimental support for sparse matrices via BCOO and BCSR, too.
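For anyone unfamiliar, the experimental BCOO format lives under `jax.experimental.sparse`. A quick sketch of a round trip:

```python
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[0.0, 1.0],
                   [2.0, 0.0]])
sp = sparse.BCOO.fromdense(dense)  # batched-COO sparse representation
back = sp.todense()                # materialize back to a dense array
```

Note the API is still marked experimental upstream, so signatures may shift between JAX releases.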

quattro avatar Apr 18 '23 17:04 quattro

This issue has been automatically marked as stale because it has not had recent activity. Please add a comment if you want to keep the issue open. Thank you for your contributions!

github-actions[bot] avatar Jun 21 '23 02:06 github-actions[bot]

Not marking as stale, but also not sure where this sits on a roadmap until there is an implementation plan.

At the moment, I would suggest we can split this into:

  • Add in-memory support for JAX/PyTorch arrays
  • Add initial write support via conversion to NumPy/SciPy arrays
    • Eventual direct GPU IO
  • Read support via read_dispatched

ivirshup avatar Jun 21 '23 09:06 ivirshup