
scvi-tools Jax support

Open · martinkim0 opened this issue 3 years ago · 3 comments

This issue will serve as the parent thread for all issues related to supporting a Jax backend for scvi-tools models.

  • [ ] FCLayers analog #1620
  • [ ] PeakVI implementation #1578
  • [ ] Granular device management #1553
  • [x] Default Jax seed #1584

martinkim0 · Oct 06 '22 19:10

I think in the long term, it would be useful if our high-level API is framework-agnostic. Something like the following:

model = scvi.model.SCVI(adata, backend="jax")

or

model = scvi.model.SCVI(adata, backend="torch")
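One way a `backend` argument like this could work is a thin front-end class that looks up a framework-specific module in a registry and forwards every high-level call to it. The sketch below is purely hypothetical: `SCVIFrontEnd`, `TorchModule`, `JaxModule`, and `_BACKENDS` are illustrative names, not scvi-tools classes, and the stand-in modules return placeholder values instead of wrapping real PyTorch or JAX networks. It also shows where the feature-matching cost comes from: a method implemented for only one backend has to fail loudly on the other.

```python
from typing import Dict, List, Type


class TorchModule:
    """Hypothetical stand-in for a PyTorch-based model implementation."""

    name = "torch"

    def train(self, max_epochs: int) -> str:
        return f"training with {self.name} for {max_epochs} epochs"

    def get_latent_representation(self) -> List[List[float]]:
        return [[0.0, 0.0]]  # placeholder latent coordinates


class JaxModule:
    """Hypothetical stand-in for a JAX-based model implementation.

    Deliberately lacks get_latent_representation to mimic a feature gap.
    """

    name = "jax"

    def train(self, max_epochs: int) -> str:
        return f"training with {self.name} for {max_epochs} epochs"


# Hypothetical registry mapping the user-facing backend string to a class.
_BACKENDS: Dict[str, Type] = {"torch": TorchModule, "jax": JaxModule}


class SCVIFrontEnd:
    """Hypothetical framework-agnostic model: one API, pluggable backend."""

    def __init__(self, adata, backend: str = "torch"):
        if backend not in _BACKENDS:
            raise ValueError(
                f"Unknown backend {backend!r}; expected one of {sorted(_BACKENDS)}"
            )
        self.adata = adata
        self.backend = backend
        self.module = _BACKENDS[backend]()

    def train(self, max_epochs: int = 10) -> str:
        # The high-level call is identical regardless of backend.
        return self.module.train(max_epochs)

    def get_latent_representation(self) -> List[List[float]]:
        # Without exact feature parity, unsupported methods must fail clearly.
        impl = getattr(self.module, "get_latent_representation", None)
        if impl is None:
            raise NotImplementedError(
                f"get_latent_representation is not implemented "
                f"for the {self.backend} backend"
            )
        return impl()


if __name__ == "__main__":
    model = SCVIFrontEnd(adata=None, backend="jax")
    print(model.train(max_epochs=10))  # prints "training with jax for 10 epochs"
```

With a registry, adding a backend means registering one class, but every front-end method still needs an implementation per backend, which is the parity burden that makes this approach hard to adopt.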

martinkim0 · Oct 06 '22 19:10

> I think in the long term, it would be useful if our high-level API is framework-agnostic

We thought about this but it requires matching features exactly, which might be hard. But we should continue to evaluate this option.

adamgayoso · Oct 06 '22 19:10

> We thought about this but it requires matching features exactly, which might be hard. But we should continue to evaluate this option.

Oh, I see. Yeah, in that case we can revisit this once more code is implemented.

martinkim0 · Oct 06 '22 20:10