Support JAX as a backend
I would support JAX as another backend.
Would probably be worth opening another issue for tracking JAX as a backend. For IO, I was thinking that could be done with #659
Originally posted by @ivirshup in https://github.com/theislab/anndata/issues/695#issuecomment-1026039360
I wonder if, apart from the modifiers (#659), we should consider saving the original object class as an attribute. Would one expect to restore the same object class that was saved? If so, that would make sense.
import h5py
f = h5py.File("adata.h5ad", "r")
f["X"].attrs["origin-class"]
# => jaxlib.xla_extension.DeviceArray
f.close()
import anndata
adata = anndata.read_h5ad("adata.h5ad")
type(adata.X)
# => jaxlib.xla_extension.DeviceArray
If the respective library is not available, we use the default backend (NumPy in this case).
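For illustration, a minimal sketch of that fallback, assuming the hypothetical "origin-class" attribute from above (none of this is existing anndata API):
import importlib
import numpy as np

def resolve_constructor(origin_class: str):
    # Hypothetical: pick an array constructor from the stored class name,
    # falling back to NumPy when the library is not importable.
    top_level = origin_class.split(".")[0]
    try:
        importlib.import_module(top_level)
    except ImportError:
        return np.asarray  # default backend
    if top_level in ("jax", "jaxlib"):
        import jax.numpy as jnp
        return jnp.asarray
    return np.asarray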
I don't think IO operations should return different types depending on the environment. That's too surprising a behavior, and it will definitely lead to bugs in user code.
It also leaves very little room for cross-language interoperability.
I guess it would then be nice to include a «global switch» for the backend when reading files, e.g.
adata = anndata.read("adata.h5ad", matrix_backend="jax")
That would include sparse matrices.
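A rough sketch of what such a switch could dispatch to, with matrix_backend as a hypothetical parameter (anndata's setters would also need to accept these types; this only illustrates the dispatch):
import anndata
import scipy.sparse as sp

def read_with_backend(path, matrix_backend="numpy"):
    # Hypothetical helper: read normally, then convert X to the requested backend
    adata = anndata.read_h5ad(path)
    if matrix_backend == "jax":
        import jax.numpy as jnp
        from jax.experimental import sparse as jsparse
        if sp.issparse(adata.X):
            # BCOO is JAX's experimental sparse format
            adata.X = jsparse.BCOO.from_scipy_sparse(adata.X)
        else:
            adata.X = jnp.asarray(adata.X)
    return adata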
The tricky thing here is that if a GPU is available, JAX will try to put all the data on the device, which can easily become problematic. While I really do think JAX is promising, a Torch backend is probably far more practical as a result (combined with the fact that most single-cell ML applications use PyTorch). PyTorch also has sparse CSR support.
I am not sure how much of a problem this is in practice, though. Of course we'll have to test it on different setups, but I expect it to work smoothly on most of them. Especially considering GPU by default can be turned off.
Also, implementing one alternative backend should make adding more backends easier.
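For reference, JAX's GPU-by-default placement can indeed be switched off globally:
import jax
# Keep all JAX arrays on CPU even when a GPU is present
jax.config.update("jax_platform_name", "cpu")
# Equivalently, set JAX_PLATFORM_NAME=cpu in the environment before importing jax.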
> Especially considering GPU by default can be turned off.
Yes, but can you then easily turn it back on for only some tensors? Agreed that it's not mutually exclusive; JAX and Torch would both be great.
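For what it's worth, JAX does allow per-array placement via jax.device_put, so moving only selected tensors back to GPU looks roughly like:
import jax
import jax.numpy as jnp

x = jnp.ones((1000, 200))  # created on the default (here CPU) device
try:
    gpu = jax.devices("gpu")[0]
    x_gpu = jax.device_put(x, gpu)  # move only this array to GPU
except RuntimeError:
    pass  # no GPU backend available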
Also, there is already a loader with a built-in converter to PyTorch tensors: https://github.com/theislab/anndata/tree/master/anndata/experimental/pytorch https://anndata.readthedocs.io/en/latest/api.html#experimental-api
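For context, usage of that experimental loader looks roughly like this (assuming an existing adata object; AnnLoader converts arrays to torch tensors by default):
from anndata.experimental.pytorch import AnnLoader

loader = AnnLoader(adata, batch_size=128, shuffle=True, use_cuda=False)
for batch in loader:
    x = batch.X  # torch.Tensor with batch_size rows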
> Also, there is already a loader with a built-in converter to PyTorch tensors
This would be different from a complete backend, right? We have our own PyTorch dataloader in scvi-tools as well. Backends could support things like a PyTorch-based PCA in Scanpy, etc.
I agree, it's a different thing if that's what you have in mind. But then it becomes unclear whether things like PyTorch (or JAX) tensors should be handled in loaders or directly in AnnData objects.
> But then it becomes unclear whether things like PyTorch (or JAX) tensors should be handled in loaders or directly in AnnData objects.
At least on our side, in the Torch case, we would only need to add one line of code to skip the format conversion in our dataloader. In the JAX case, same thing, but I think it would be bad for a dataloader if you had JAX arrays on GPU and then had to convert them to Torch GPU tensors.
Hm, I think the zero-copy conversion can be done via DLPack:
https://jax.readthedocs.io/en/latest/jax.dlpack.html
https://pytorch.org/docs/stable/dlpack.html
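A minimal round trip, zero-copy as long as both frameworks share the device:
import jax.dlpack
import jax.numpy as jnp
import torch.utils.dlpack

x = jnp.arange(10.0)
# JAX -> PyTorch without copying the underlying buffer
t = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
# PyTorch -> JAX
x2 = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))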
JAX is our group's primary mode of development, and it would be fantastic if anndata also supported this! JAX now has experimental support for sparse matrices via BCOO and BCSR as well.
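For example:
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[0.0, 1.0], [2.0, 0.0]])
m = sparse.BCOO.fromdense(dense)  # batched COO; sparse.BCSR.fromdense also exists
m_dense = m.todense()             # back to a dense jax array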
This issue has been automatically marked as stale because it has not had recent activity. Please add a comment if you want to keep the issue open. Thank you for your contributions!
Not marking as stale, but also not sure where this sits on a roadmap until there is an implementation plan.
At the moment, I would suggest splitting this into:
- Add in-memory support for JAX/PyTorch arrays
- Add initial write support via conversion to numpy/scipy arrays
- Eventual direct GPU IO
- Read support via read_dispatched (a rough sketch follows below)
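As mentioned in the last item, a rough sketch of the read_dispatched route; read_dispatched is part of anndata.experimental, while converting X to a JAX array in the callback is just an illustration:
import h5py
import jax.numpy as jnp
from anndata.experimental import read_dispatched

def callback(func, elem_name, elem, iospec):
    result = func(elem)  # default reading behaviour
    if elem_name == "/X":  # assuming elem_name paths like "/X"
        return jnp.asarray(result)  # hand X to the JAX backend
    return result

with h5py.File("adata.h5ad", "r") as f:
    adata = read_dispatched(f, callback=callback)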