scvi-tools
JAX version of `VAEMixin`
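It should be straightforward to reimplement `VAEMixin` for JAX models. This will require a whole new class, because the JAX forward-pass call (`module.apply` with explicit parameters, state and RNG keys) is completely different from the PyTorch one.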
Example implementation of `get_reconstruction_error`:
def get_reconstruction_error(
    self,
    adata: AnnData | None = None,
    indices: list[int] | None = None,
    batch_size: int | None = None,
    **kwargs,
) -> float:
    """Return the negated mean per-cell reconstruction loss of a JAX model.

    Runs the module's forward pass over ``adata`` (or the subset selected by
    ``indices``) in minibatches, accumulates the reconstruction loss reported
    by the module, and divides it by the number of cells actually evaluated.

    Parameters
    ----------
    adata
        AnnData object to evaluate. It is passed through
        ``self._validate_anndata``; presumably ``None`` falls back to the
        AnnData the model was set up with (confirm against the base class).
    indices
        Indices of the cells in ``adata`` to evaluate. ``None`` evaluates
        every cell the data loader yields.
    batch_size
        Minibatch size forwarded to ``self._make_data_loader``.
    **kwargs
        Extra keyword arguments forwarded to ``self.module.apply``.

    Returns
    -------
    float
        ``-(summed reconstruction loss / number of evaluated cells)``.

    Raises
    ------
    ValueError
        If the data loader yields no cells.
    """
    adata = self._validate_anndata(adata)
    dataloader = self._make_data_loader(
        adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True
    )
    # Parameters and state do not change while evaluating, so build the
    # variable collection once rather than once per minibatch.
    vars_in = {"params": self.module.params, **self.module.state}

    reconstruction_loss_sum = 0.0
    n_cells = 0
    for batch in dataloader:
        # `self.module.rngs` is read inside the loop on purpose: it is
        # presumably a property that hands out fresh RNG keys on each access.
        outputs = self.module.apply(vars_in, batch, rngs=self.module.rngs, **kwargs)
        # outputs[2] carries the losses; its reconstruction_loss_sum is
        # presumably summed over the cells of this minibatch, since the
        # total is divided by a cell count below.
        reconstruction_loss_sum += outputs[2].reconstruction_loss_sum.item()
        # Count the cells actually evaluated instead of using
        # len(dataloader.dataset), which spans the whole AnnData even when
        # `indices` selects a subset. Assumes the batch is a dict whose
        # arrays all have cells along axis 0 — TODO confirm for every
        # registered field.
        n_cells += len(next(iter(batch.values())))

    if n_cells == 0:
        raise ValueError("The data loader yielded no cells to evaluate.")
    return -(reconstruction_loss_sum / n_cells)