scvi-tools

JAX version of VAEMixin

Open · opened by justjhong · 0 comments

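It should be straightforward to reimplement `VAEMixin` for JAX models. It will require a whole new class, because the forward pass is called completely differently: a Flax module is invoked through `module.apply`, with its parameters, state, and RNG keys passed in explicitly (the `vars_in` and `rngs` arguments in the example below).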

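Example implementation of `get_reconstruction_error`:

```python
from __future__ import annotations

from anndata import AnnData
from scvi import REGISTRY_KEYS


class JaxVAEMixin:  # hypothetical name for the new JAX counterpart of VAEMixin
    def get_reconstruction_error(
        self,
        adata: AnnData | None = None,
        indices: list[int] | None = None,
        batch_size: int | None = None,
        **kwargs,
    ) -> dict[str, float]:
        adata = self._validate_anndata(adata)
        # iter_ndarray=True yields each minibatch as a dict of NumPy arrays for JAX
        dataloader = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True
        )

        # Flax modules are stateless: parameters and state are passed to `apply`
        vars_in = {"params": self.module.params, **self.module.state}

        reconstruction_loss_sum = 0.0
        n_obs = 0
        for batch in dataloader:
            # Fresh RNG keys per minibatch; outputs[2] holds the losses (a LossOutput)
            outputs = self.module.apply(vars_in, batch, rngs=self.module.rngs, **kwargs)
            reconstruction_loss_sum += outputs[2].reconstruction_loss_sum.item()
            # Count cells actually iterated over, so a subset given via `indices`
            # is averaged correctly (len(dataloader.dataset) may be the full dataset)
            n_obs += batch[REGISTRY_KEYS.X_KEY].shape[0]

        # The reconstruction loss is a negative log-likelihood: negate the per-cell
        # mean and return a dict to match the annotated return type
        return {"reconstruction_loss": -(reconstruction_loss_sum / n_obs)}
```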
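The accumulation step can be checked in isolation. The sketch below replaces the model with hypothetical stand-ins: `minibatches` mimics an `iter_ndarray=True` loader, and `batch_reconstruction_loss_sum` plays the role of `outputs[2].reconstruction_loss_sum` using a toy squared-error loss. Neither is scvi-tools API. It shows that summing per-batch losses and dividing by the number of cells actually seen gives the same per-cell mean as a one-shot computation, even when the final batch is short.

```python
from __future__ import annotations

import jax.numpy as jnp
import numpy as np


def minibatches(x: np.ndarray, batch_size: int):
    """Hypothetical stand-in for an `iter_ndarray=True` data loader."""
    for start in range(0, x.shape[0], batch_size):
        yield {"X": x[start : start + batch_size]}


def batch_reconstruction_loss_sum(batch: dict) -> jnp.ndarray:
    """Hypothetical stand-in for `outputs[2].reconstruction_loss_sum`:
    squared error against a fixed all-zeros reconstruction."""
    x = jnp.asarray(batch["X"])
    per_cell = jnp.sum(x**2, axis=-1)  # one loss value per cell
    return per_cell.sum()


def reconstruction_error(batches) -> dict[str, float]:
    total, n_obs = 0.0, 0
    for batch in batches:
        total += float(batch_reconstruction_loss_sum(batch))
        n_obs += batch["X"].shape[0]  # includes a short final batch
    # Negate the per-cell mean, as the issue's example does
    return {"reconstruction_loss": -(total / n_obs)}


x = np.arange(10 * 3, dtype=np.float32).reshape(10, 3)
result = reconstruction_error(minibatches(x, batch_size=4))  # batches of 4, 4, 2
expected = -float(np.sum(x**2, axis=-1).mean())
print(result["reconstruction_loss"], expected)  # -855.5 -855.5
```

Counting cells from each batch, rather than taking `len(dataloader.dataset)`, keeps the divisor tied to the cells that were actually iterated over.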

Posted by justjhong, May 06 '24 21:05